r""" HyperAverageMetercorrelation Squeeze testing code """ | |
import argparse
import sys
import os
from os.path import join

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from segment_anything import SamPredictor, SamAutomaticMaskGenerator
from gradio_demo.Matcher import Matcher
from matcher.common import utils

import random

random.seed(0)


def default_argument_parser():
    # Arguments parsing
    parser = argparse.ArgumentParser(description='Matcher Pytorch Implementation for One-shot Segmentation')

    # Dataset parameters
    parser.add_argument('--datapath', type=str, default='datasets')
    parser.add_argument('--benchmark', type=str, default='coco',
                        choices=['fss', 'coco', 'lvis', 'paco_part', 'pascal_part'])
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--nworker', type=int, default=0)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--nshot', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=518)
    parser.add_argument('--use_original_imgsize', action='store_true')
    parser.add_argument('--log-root', type=str, default='output/coco/fold0')
    parser.add_argument('--visualize', type=int, default=0)

    # DINOv2 and SAM parameters
    parser.add_argument('--dinov2-weights', type=str, default="models/dinov2_vitl14_pretrain.pth")
    parser.add_argument('--sam-weights', type=str, default="models/sam_vit_h_4b8939.pth")
    parser.add_argument('--points_per_side', type=int, default=64)
    parser.add_argument('--pred_iou_thresh', type=float, default=0.88)
    parser.add_argument('--sel_stability_score_thresh', type=float, default=0.0)
    parser.add_argument('--stability_score_thresh', type=float, default=0.95)
    parser.add_argument('--iou_filter', type=float, default=0.0)
    parser.add_argument('--box_nms_thresh', type=float, default=1.0)
    parser.add_argument('--output_layer', type=int, default=3)
    parser.add_argument('--dense_multimask_output', type=int, default=0)
    parser.add_argument('--use_dense_mask', type=int, default=0)
    parser.add_argument('--multimask_output', type=int, default=0)

    # Matcher parameters
    parser.add_argument('--num_centers', type=int, default=8, help='K centers for kmeans')
    parser.add_argument('--use_box', action='store_true', help='use box as an extra prompt for sam')
    parser.add_argument('--use_points_or_centers', action='store_true', help='points: True, centers: False')
    parser.add_argument('--sample-range', type=tuple, default=(4, 6), help='sample points number range')
    parser.add_argument('--max_sample_iterations', type=int, default=30)
    parser.add_argument('--alpha', type=float, default=1.)
    parser.add_argument('--beta', type=float, default=0.)
    parser.add_argument('--exp', type=float, default=0.)
    parser.add_argument('--emd_filter', type=float, default=0.0, help='use emd_filter')
    parser.add_argument('--purity_filter', type=float, default=0.0, help='use purity_filter')
    parser.add_argument('--coverage_filter', type=float, default=0.0, help='use coverage_filter')
    parser.add_argument('--use_score_filter', action='store_true')
    parser.add_argument('--deep_score_norm_filter', type=float, default=0.1)
    parser.add_argument('--deep_score_filter', type=float, default=0.33)
    parser.add_argument('--topk_scores_threshold', type=float, default=0.7)
    parser.add_argument('--num_merging_mask', type=int, default=10, help='topk masks for merging')

    args = parser.parse_args()
    return args


def definite_argument_parser(args, version=1):
    if version == 1:
        args.max_sample_iterations = 64
        args.box_nms_thresh = 0.65
        args.sample_range = (1, 6)
        args.topk_scores_threshold = 0.0
        args.use_dense_mask = 1
        args.use_points_or_centers = True
        args.purity_filter = 0.02
        args.iou_filter = 0.85
        args.multimask_output = 1
        args.sel_stability_score_thresh = 0.90
        args.use_score_filter = True
        args.alpha = 1.0
        args.beta = 0.
        args.exp = 0.
        args.num_merging_mask = 9
    elif version == 2:
        args.max_sample_iterations = 30
        args.sample_range = (4, 6)
        args.multimask_output = 0
        args.alpha = 0.8
        args.beta = 0.2
        args.exp = 1.
        args.num_merging_mask = 10
    elif version == 3:
        args.max_sample_iterations = 128
        args.sample_range = (3, 6)
        args.use_box = True
        args.use_points_or_centers = True
        args.coverage_filter = 0.3
        args.alpha = 0.5
        args.beta = 0.5
        args.exp = 0.
        args.num_merging_mask = 5
    return args
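
# What the presets above change, read directly from the code: version 1 turns on
# dense SAM masks, point sampling, and the purity/IoU/score filters (plus a few
# threshold tweaks); version 2 keeps the CLI defaults except for the
# alpha/beta/exp score weights; version 3 adds box prompts, point sampling, a
# coverage filter, and rebalanced alpha/beta.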


def preprocess_data(kwargs, args=None):
    img_size = args.img_size
    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor()
    ])

    support_img = Image.fromarray(kwargs.get("support_img"))
    query_img_1 = Image.fromarray(kwargs.get("query_img_1"))
    query_img_2 = Image.fromarray(kwargs.get("query_img_2"))
    support_img_ori_size = (support_img.size[1], support_img.size[0])  # H, W
    query_img_1_ori_size = (query_img_1.size[1], query_img_1.size[0])
    query_img_2_ori_size = (query_img_2.size[1], query_img_2.size[0])

    support_img = transform(support_img)
    query_img_1 = transform(query_img_1)
    query_img_2 = transform(query_img_2)

    support_mask = torch.tensor(kwargs.get("support_mask"))
    support_mask = F.interpolate(support_mask.unsqueeze(0).float(), support_img.size()[-2:],
                                 mode='nearest') > 0

    query_imgs = torch.stack([query_img_1, query_img_2], dim=0)

    data = {
        "support_img": support_img[None, ...],
        "support_mask": support_mask,
        "query_imgs": query_imgs,
        "support_img_ori_size": support_img_ori_size,
        "query_imgs_ori_size": (query_img_1_ori_size, query_img_2_ori_size),
    }
    return data


def preprocess_support_mask(data, predictor, version=1):
    if version == 3:
        # Version 3 uses the resized support mask directly, without SAM refinement.
        return data

    sup_mask = data['support_mask'].squeeze()
    H, W = sup_mask.shape[-2:]
    # Take a single foreground pixel of the support mask and flip it from
    # (y, x) to (x, y) so it can serve as a point prompt for SAM.
    input_points = sup_mask.nonzero().numpy()[:1, ::-1]
    input_label = np.array([1] * len(input_points))

    support_img_np = data['support_img'].mul(255).byte()
    support_img_np = support_img_np.squeeze().permute(1, 2, 0).cpu().numpy()

    # forward encoder to obtain image feature
    predictor.reset_image()
    predictor.set_image(support_img_np)

    # mask, _, _ = predictor.predict(
    #     point_coords=input_points,
    #     point_labels=input_label,
    #     multimask_output=False
    # )
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=True
    )
    predictor.reset_image()

    # show_img_point_box_mask(
    #     support_img_np,
    #     masks=mask,
    #     save_path='test1.png',
    #     mode='mask'
    # )

    # Replace the support mask with the last of the three masks returned by
    # SAM's multimask output.
    # data['support_mask'] = torch.tensor(mask[:1])[None, ...]
    data['support_mask'] = torch.tensor(mask[-1:])[None, ...]
    return data


def main_oss_ops(**kwargs):
    args = default_argument_parser()
    args = definite_argument_parser(args, kwargs.get("version"))

    # Model initialization
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    # create sam
    sam = kwargs.get("sam")
    predictor = SamPredictor(sam)
    generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=args.points_per_side,
        points_per_batch=64,
        pred_iou_thresh=args.pred_iou_thresh,
        stability_score_thresh=args.stability_score_thresh,
        stability_score_offset=1.0,
        sel_stability_score_thresh=args.sel_stability_score_thresh,
        sel_pred_iou_thresh=args.iou_filter,
        box_nms_thresh=args.box_nms_thresh,
        sel_output_layer=args.output_layer,
        output_layer=args.dense_multimask_output,
        dense_pred=args.use_dense_mask,
        multimask_output=args.dense_multimask_output > 0,
        sel_multimask_output=args.multimask_output > 0,
    )

    # create dinov2, large
    dinov2 = kwargs.get("dinov2")

    # create matcher
    score_filter_cfg = {
        "emd": args.emd_filter,
        "purity": args.purity_filter,
        "coverage": args.coverage_filter,
        "score_filter": args.use_score_filter,
        "score": args.deep_score_filter,
        "score_norm": args.deep_score_norm_filter,
        "topk_scores_threshold": args.topk_scores_threshold
    }
    matcher = Matcher(
        encoder=dinov2,
        generator=generator,
        num_centers=args.num_centers,
        use_box=args.use_box,
        use_points_or_centers=args.use_points_or_centers,
        sample_range=args.sample_range,
        max_sample_iterations=args.max_sample_iterations,
        alpha=args.alpha,
        beta=args.beta,
        exp=args.exp,
        score_filter_cfg=score_filter_cfg,
        num_merging_mask=args.num_merging_mask,
        device=args.device
    )

    # process data
    data = preprocess_data(kwargs, args=args)
    data = preprocess_support_mask(data, predictor, version=kwargs.get("version"))

    # inference
    with torch.no_grad():
        utils.fix_randseed(0)
        pred_masks, pred_mask_lists = [], []

        # support mask, resized back to its original resolution
        support_img_ori_size = data['support_img_ori_size']
        mask = data['support_mask'].to(predictor.model.device).float()
        mask = F.interpolate(mask, support_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()
        pred_masks.append(mask)
        pred_mask_lists.append(None)

        for query_img, query_img_ori_size in zip(data['query_imgs'], data['query_imgs_ori_size']):
            data['query_img'], data['query_img_ori_size'] = query_img[None, ...], query_img_ori_size

            support_imgs, support_masks = data["support_img"].to(matcher.device)[None, ...], data["support_mask"].to(matcher.device)  # (1, 1, 3, H, W), (1, 1, H, W)
            query_img, query_img_ori_size = data['query_img'].to(matcher.device), data['query_img_ori_size']  # (1, 3, H, W), img_size

            # 1. Matcher prepares references and target
            matcher.set_reference(support_imgs, support_masks)
            matcher.set_target(query_img, query_img_ori_size)

            # 2. Predict mask of target
            pred_mask, pred_mask_list = matcher.predict()
            matcher.clear()

            pred_masks.append(pred_mask)
            pred_mask_lists.append(pred_mask_list)

    return pred_masks, pred_mask_lists
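

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original demo). Assumptions: the repo's
# modified segment_anything still exposes the upstream `sam_model_registry`
# helper; a torch.hub DINOv2 ViT-L/14 is used as a stand-in for the encoder the
# demo builds from --dinov2-weights; the example image paths are placeholders.
# In the Gradio app, `sam` and `dinov2` are built once at startup and passed in
# through kwargs.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from segment_anything import sam_model_registry

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h_4b8939.pth").to(device)
    dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14").to(device)  # assumed loader

    # Images as (H, W, 3) uint8 arrays, as expected by Image.fromarray above.
    support_img = np.array(Image.open("examples/support.jpg").convert("RGB"))
    query_img_1 = np.array(Image.open("examples/query_1.jpg").convert("RGB"))
    query_img_2 = np.array(Image.open("examples/query_2.jpg").convert("RGB"))
    # Binary support mask with a leading channel axis, i.e. (1, H, W), matching
    # the F.interpolate call in preprocess_data.
    support_mask = (np.array(Image.open("examples/support_mask.png").convert("L")) > 0)[None, ...]

    pred_masks, pred_mask_lists = main_oss_ops(
        support_img=support_img,
        support_mask=support_mask,
        query_img_1=query_img_1,
        query_img_2=query_img_2,
        sam=sam,
        dinov2=dinov2,
        version=1,
    )
    # pred_masks[0] is the (possibly SAM-refined) support mask at its original
    # size; pred_masks[1:] are the predicted masks for the two query images.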