r""" HyperAverageMetercorrelation Squeeze testing code """
import argparse
import random

import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from segment_anything import SamPredictor, SamAutomaticMaskGenerator
from gradio_demo.Matcher import Matcher
from matcher.common import utils

random.seed(0)

def default_argument_parser():
    # Arguments parsing
    parser = argparse.ArgumentParser(description='Matcher PyTorch implementation for one-shot segmentation')
# Dataset parameters
parser.add_argument('--datapath', type=str, default='datasets')
parser.add_argument('--benchmark', type=str, default='coco',
choices=['fss', 'coco', 'lvis', 'paco_part', 'pascal_part'])
parser.add_argument('--bsz', type=int, default=1)
parser.add_argument('--nworker', type=int, default=0)
parser.add_argument('--fold', type=int, default=0)
parser.add_argument('--nshot', type=int, default=1)
parser.add_argument('--img-size', type=int, default=518)
parser.add_argument('--use_original_imgsize', action='store_true')
parser.add_argument('--log-root', type=str, default='output/coco/fold0')
parser.add_argument('--visualize', type=int, default=0)
# DINOv2 and SAM parameters
parser.add_argument('--dinov2-weights', type=str, default="models/dinov2_vitl14_pretrain.pth")
parser.add_argument('--sam-weights', type=str, default="models/sam_vit_h_4b8939.pth")
parser.add_argument('--points_per_side', type=int, default=64)
parser.add_argument('--pred_iou_thresh', type=float, default=0.88)
parser.add_argument('--sel_stability_score_thresh', type=float, default=0.0)
parser.add_argument('--stability_score_thresh', type=float, default=0.95)
parser.add_argument('--iou_filter', type=float, default=0.0)
parser.add_argument('--box_nms_thresh', type=float, default=1.0)
parser.add_argument('--output_layer', type=int, default=3)
parser.add_argument('--dense_multimask_output', type=int, default=0)
parser.add_argument('--use_dense_mask', type=int, default=0)
parser.add_argument('--multimask_output', type=int, default=0)
# Matcher parameters
parser.add_argument('--num_centers', type=int, default=8, help='K centers for kmeans')
parser.add_argument('--use_box', action='store_true', help='use box as an extra prompt for sam')
    parser.add_argument('--use_points_or_centers', action='store_true',
                        help='if set, prompt SAM with raw points; otherwise use k-means centers')
    # NOTE: argparse's type=tuple would split a CLI string into single characters;
    # accept two ints instead (the version presets below override this anyway).
    parser.add_argument('--sample-range', type=int, nargs=2, default=(4, 6), help='sample points number range')
parser.add_argument('--max_sample_iterations', type=int, default=30)
parser.add_argument('--alpha', type=float, default=1.)
parser.add_argument('--beta', type=float, default=0.)
parser.add_argument('--exp', type=float, default=0.)
parser.add_argument('--emd_filter', type=float, default=0.0, help='use emd_filter')
parser.add_argument('--purity_filter', type=float, default=0.0, help='use purity_filter')
parser.add_argument('--coverage_filter', type=float, default=0.0, help='use coverage_filter')
parser.add_argument('--use_score_filter', action='store_true')
parser.add_argument('--deep_score_norm_filter', type=float, default=0.1)
parser.add_argument('--deep_score_filter', type=float, default=0.33)
parser.add_argument('--topk_scores_threshold', type=float, default=0.7)
parser.add_argument('--num_merging_mask', type=int, default=10, help='topk masks for merging')
    # Parse only known flags so importing scripts (e.g., the gradio demo) can
    # carry their own CLI arguments without tripping this parser.
    args, _ = parser.parse_known_args()
    return args

def definite_argument_parser(args, version=1):
    """Override the parsed defaults with one of the three Matcher presets used by the demo."""
    if version == 1:
args.max_sample_iterations = 64
args.box_nms_thresh = 0.65
args.sample_range = (1, 6)
args.topk_scores_threshold = 0.0
args.use_dense_mask = 1
args.use_points_or_centers = True
args.purity_filter = 0.02
args.iou_filter = 0.85
args.multimask_output = 1
args.sel_stability_score_thresh = 0.90
args.use_score_filter = True
args.alpha = 1.0
args.beta = 0.
args.exp = 0.
args.num_merging_mask = 9
elif version == 2:
args.max_sample_iterations = 30
args.sample_range = (4, 6)
args.multimask_output = 0
args.alpha = 0.8
args.beta = 0.2
args.exp = 1.
args.num_merging_mask = 10
elif version == 3:
args.max_sample_iterations = 128
args.sample_range = (3, 6)
args.use_box = True
args.use_points_or_centers = True
args.coverage_filter = 0.3
args.alpha = 0.5
args.beta = 0.5
args.exp = 0.
args.num_merging_mask = 5
return args
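
# Example (hedged): a full preset config can be built with
#   args = definite_argument_parser(default_argument_parser(), version=2)
# main_oss_ops below does exactly this with its "version" kwarg.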

def preprocess_data(kwargs, args=None):
    """Resize the support/query images, binarize the support mask, and pack a batch dict."""
    img_size = args.img_size
transform = transforms.Compose([
transforms.Resize(size=(img_size, img_size)),
transforms.ToTensor()
])
support_img = Image.fromarray(kwargs.get("support_img"))
query_img_1 = Image.fromarray(kwargs.get("query_img_1"))
query_img_2 = Image.fromarray(kwargs.get("query_img_2"))
support_img_ori_size = (support_img.size[1], support_img.size[0]) # H, W
query_img_1_ori_size = (query_img_1.size[1], query_img_1.size[0])
query_img_2_ori_size = (query_img_2.size[1], query_img_2.size[0])
support_img = transform(support_img)
query_img_1 = transform(query_img_1)
query_img_2 = transform(query_img_2)
    support_mask = torch.tensor(kwargs.get("support_mask"))
    # Nearest-neighbor keeps the mask binary while matching the resized image size.
    support_mask = F.interpolate(support_mask.unsqueeze(0).float(), support_img.size()[-2:],
                                 mode='nearest') > 0
query_imgs = torch.stack([query_img_1, query_img_2], dim=0)
data = {
"support_img": support_img[None, ...],
"support_mask": support_mask,
"query_imgs": query_imgs,
"support_img_ori_size": support_img_ori_size,
"query_imgs_ori_size": (query_img_1_ori_size, query_img_2_ori_size),
}
return data
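
# Hedged smoke sketch (assumed input shapes, not stated in the source): with the
# default --img-size 518, preprocess_data should produce
#   data["support_img"] : (1, 3, 518, 518)    data["support_mask"]: (1, 1, 518, 518)
#   data["query_imgs"]  : (2, 3, 518, 518)
# e.g.:
#   _args = default_argument_parser()
#   _d = preprocess_data({"support_img": np.zeros((480, 640, 3), np.uint8),
#                         "support_mask": np.zeros((1, 480, 640), np.uint8),
#                         "query_img_1": np.zeros((480, 640, 3), np.uint8),
#                         "query_img_2": np.zeros((480, 640, 3), np.uint8)}, args=_args)
#   assert _d["query_imgs"].shape == (2, 3, 518, 518)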

def preprocess_support_mask(data, predictor, version=1):
    """Refine the user-provided support mask with SAM (version 3 uses it as-is)."""
    if version == 3:
        return data
    sup_mask = data['support_mask'].squeeze()
    # Take the first foreground pixel and flip (row, col) -> (x, y) for SAM.
    input_points = sup_mask.nonzero().numpy()[:1, ::-1]
    input_label = np.array([1] * len(input_points))
support_img_np = data['support_img'].mul(255).byte()
support_img_np = support_img_np.squeeze().permute(1,2,0).cpu().numpy()
# forward encoder to obtain image feature
predictor.reset_image()
predictor.set_image(support_img_np)
    # Single positive point prompt; let SAM propose multiple candidate masks.
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=True
    )
predictor.reset_image()
    # Keep the last of SAM's candidate masks (often the whole-object proposal)
    # as the refined support mask, shaped (1, 1, H, W).
    data['support_mask'] = torch.tensor(mask[-1:])[None, ...]
return data

def main_oss_ops(**kwargs):
    """One-shot segmentation: segment two query images given one support image and mask."""
    args = default_argument_parser()
    args = definite_argument_parser(args, kwargs.get("version"))
# Model initialization
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.device = device
# create sam
sam = kwargs.get("sam")
predictor = SamPredictor(sam)
    # NOTE: the sel_* and dense_pred arguments below belong to Matcher's
    # customized SamAutomaticMaskGenerator, not the upstream segment-anything API.
    generator = SamAutomaticMaskGenerator(
sam,
points_per_side=args.points_per_side,
points_per_batch=64,
pred_iou_thresh=args.pred_iou_thresh,
stability_score_thresh=args.stability_score_thresh,
stability_score_offset=1.0,
sel_stability_score_thresh=args.sel_stability_score_thresh,
sel_pred_iou_thresh=args.iou_filter,
box_nms_thresh=args.box_nms_thresh,
sel_output_layer=args.output_layer,
output_layer=args.dense_multimask_output,
dense_pred=args.use_dense_mask,
multimask_output=args.dense_multimask_output > 0,
sel_multimask_output=args.multimask_output > 0,
)
# create dinov2, large
dinov2 = kwargs.get("dinov2")
# create matcher
score_filter_cfg = {
"emd": args.emd_filter,
"purity": args.purity_filter,
"coverage": args.coverage_filter,
"score_filter": args.use_score_filter,
"score": args.deep_score_filter,
"score_norm": args.deep_score_norm_filter,
"topk_scores_threshold": args.topk_scores_threshold
}
matcher = Matcher(
encoder=dinov2,
generator=generator,
num_centers=args.num_centers,
use_box=args.use_box,
use_points_or_centers=args.use_points_or_centers,
sample_range=args.sample_range,
max_sample_iterations=args.max_sample_iterations,
alpha=args.alpha,
beta=args.beta,
exp=args.exp,
score_filter_cfg=score_filter_cfg,
num_merging_mask=args.num_merging_mask,
device=args.device
)
# process data
data = preprocess_data(kwargs, args=args)
data = preprocess_support_mask(data, predictor, version=kwargs.get("version"))
# inference
with torch.no_grad():
utils.fix_randseed(0)
pred_masks, pred_mask_lists = [], []
        # support mask: resize back to its original resolution and return it first
support_img_ori_size = data['support_img_ori_size']
mask = data['support_mask'].to(predictor.model.device).float()
mask = F.interpolate(mask, support_img_ori_size, mode="bilinear", align_corners=False) > 0
mask = mask.squeeze(0).cpu().numpy()
pred_masks.append(mask)
pred_mask_lists.append(None)
for query_img, query_img_ori_size in zip(data['query_imgs'], data['query_imgs_ori_size']):
data['query_img'], data['query_img_ori_size'] = query_img[None, ...], query_img_ori_size
support_imgs, support_masks = data["support_img"].to(matcher.device)[None, ...], data["support_mask"].to(matcher.device) # (1, 1, 3, H, W), (1, 1, H, W)
query_img, query_img_ori_size = data['query_img'].to(matcher.device), data['query_img_ori_size'] # (1, 3, H, W), img_size
# 1. Matcher prepare references and target
matcher.set_reference(support_imgs, support_masks)
matcher.set_target(query_img, query_img_ori_size)
# 2. Predict mask of target
pred_mask, pred_mask_list = matcher.predict()
matcher.clear()
pred_masks.append(pred_mask)
pred_mask_lists.append(pred_mask_list)
return pred_masks, pred_mask_lists
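
# --- Hedged usage sketch (not part of the original file) ----------------------
# A minimal driver for main_oss_ops. Assumptions: segment-anything's
# sam_model_registry is available (Matcher vendors a modified copy), a torch.hub
# DINOv2 ViT-L/14 can stand in for the demo's encoder, and the image paths
# below are placeholders.
if __name__ == "__main__":
    from segment_anything import sam_model_registry

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h_4b8939.pth").to(dev)
    dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14").to(dev)

    support_img = np.array(Image.open("examples/support.jpg").convert("RGB"))
    # Binary (1, H, W) mask, matching what preprocess_data expects.
    support_mask = (np.array(Image.open("examples/support_mask.png").convert("L")) > 127)[None, ...]
    query_img_1 = np.array(Image.open("examples/query_1.jpg").convert("RGB"))
    query_img_2 = np.array(Image.open("examples/query_2.jpg").convert("RGB"))

    pred_masks, pred_mask_lists = main_oss_ops(
        sam=sam,
        dinov2=dinov2,
        version=1,
        support_img=support_img,
        support_mask=support_mask,
        query_img_1=query_img_1,
        query_img_2=query_img_2,
    )
    # pred_masks[0] is the (resized) support mask; the rest are query predictions.
    print(len(pred_masks), pred_masks[0].shape)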