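"""Matcher: training-free few-shot semantic segmentation.

Matches DINOv2 patch features between reference and target images
(bidirectional assignment via scipy's linear_sum_assignment), prompts SAM with
the matched points, and selects/merges the proposed masks using purity,
coverage, and EMD scores.
"""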
import os
from os import path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment

from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator

from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform

from matcher.k_means import kmeans_pp

import random

class Matcher:
    """Matches reference and target patch features, prompts SAM with the
    matched points, and scores/merges the resulting masks."""

    def __init__(
        self,
        encoder,
        generator=None,
        input_size=518,
        num_centers=8,
        use_box=False,
        use_points_or_centers=True,
        sample_range=(4, 6),
        max_sample_iterations=30,
        alpha=1.,
        beta=0.,
        exp=0.,
        score_filter_cfg=None,
        num_merging_mask=10,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ):
        # models
        self.encoder = encoder
        self.generator = generator
        self.rps = None

        if not isinstance(input_size, tuple):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # transforms for the image encoder
        self.encoder_transform = transforms.Compose([
            MaybeToTensor(),
            make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None

        self.num_centers = num_centers
        self.use_box = use_box
        self.use_points_or_centers = use_points_or_centers
        self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations

        self.alpha, self.beta, self.exp = alpha, beta, exp

        assert score_filter_cfg is not None
        self.score_filter_cfg = score_filter_cfg
        self.num_merging_mask = num_merging_mask

        self.device = device
    def set_reference(self, imgs, masks):

        def reference_masks_verification(masks):
            # Fall back to a small central square if the reference mask is empty.
            if masks.sum() == 0:
                _, _, sh, sw = masks.shape
                masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
            return masks

        imgs = imgs.flatten(0, 1)  # ns, 3, h, w
        img_size = imgs.shape[-1]
        assert img_size == self.input_size[-1]
        feat_size = img_size // self.encoder.patch_size

        self.encoder_img_size = img_size
        self.encoder_feat_size = feat_size

        # process reference masks: pool to patch resolution and binarize
        masks = reference_masks_verification(masks)
        masks = masks.permute(1, 0, 2, 3)  # ns, 1, h, w
        ref_masks_pool = F.avg_pool2d(masks.float(), (self.encoder.patch_size, self.encoder.patch_size))
        nshot = ref_masks_pool.shape[0]
        ref_masks_pool = (ref_masks_pool > self.generator.predictor.model.mask_threshold).float()
        ref_masks_pool = ref_masks_pool.reshape(-1)  # (nshot * N,)

        self.ref_imgs = imgs
        self.ref_masks_pool = ref_masks_pool
        self.nshot = nshot
    def set_target(self, img, tar_img_ori_size):
        img_h, img_w = img.shape[-2:]
        assert img_h == self.input_size[0] and img_w == self.input_size[1]

        # convert the target image to a uint8 numpy array as input for SAM
        img_np = img.mul(255).byte()
        img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()

        self.tar_img = img
        self.tar_img_np = img_np
        self.tar_img_ori_size = tar_img_ori_size
    def set_rps(self):
        if self.rps is None:
            assert self.encoder_feat_size is not None
            self.rps = RobustPromptSampler(
                encoder_feat_size=self.encoder_feat_size,
                sample_range=self.sample_range,
                max_iterations=self.max_sample_iterations,
            )
    def predict(self):
        # 1) encode images, 2) bidirectional patch-level matching,
        # 3) optional clustering of matched points, 4) SAM prompting + mask selection
        ref_feats, tar_feat = self.extract_img_feats()
        all_points, box, S, C, reduced_points_num = self.patch_level_matching(ref_feats=ref_feats, tar_feat=tar_feat)
        points = self.clustering(all_points) if not self.use_points_or_centers else all_points
        self.set_rps()
        mask, mask_list = self.mask_generation(self.tar_img_np, points, box, all_points, self.ref_masks_pool, C)
        return mask, mask_list
    def extract_img_feats(self):
        ref_imgs = torch.cat([self.encoder_transform(rimg)[None, ...] for rimg in self.ref_imgs], dim=0)
        tar_img = torch.cat([self.encoder_transform(timg)[None, ...] for timg in self.tar_img], dim=0)

        ref_feats = self.encoder.forward_features(ref_imgs.to(self.device))["x_prenorm"][:, 1:]
        tar_feat = self.encoder.forward_features(tar_img.to(self.device))["x_prenorm"][:, 1:]

        ref_feats = ref_feats.reshape(-1, self.encoder.embed_dim)  # ns*N, c
        tar_feat = tar_feat.reshape(-1, self.encoder.embed_dim)  # N, c

        # L2-normalize so that dot products are cosine similarities
        ref_feats = F.normalize(ref_feats, dim=1, p=2)
        tar_feat = F.normalize(tar_feat, dim=1, p=2)

        return ref_feats, tar_feat
    def patch_level_matching(self, ref_feats, tar_feat):
        # forward matching: assign each reference foreground patch to a target patch
        S = ref_feats @ tar_feat.t()  # ns*N, N
        C = (1 - S) / 2  # cosine distance in [0, 1]

        S_forward = S[self.ref_masks_pool.flatten().bool()]
        indices_forward = linear_sum_assignment(S_forward.cpu(), maximize=True)
        indices_forward = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_forward]
        sim_scores_f = S_forward[indices_forward[0], indices_forward[1]]
        indices_mask = self.ref_masks_pool.flatten().nonzero()[:, 0]

        # reverse matching: keep only forward matches whose reverse match falls
        # back inside the reference mask
        S_reverse = S.t()[indices_forward[1]]
        indices_reverse = linear_sum_assignment(S_reverse.cpu(), maximize=True)
        indices_reverse = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_reverse]
        retain_ind = torch.isin(indices_reverse[1], indices_mask)
        if retain_ind.any():
            indices_forward = [indices_forward[0][retain_ind], indices_forward[1][retain_ind]]
            sim_scores_f = sim_scores_f[retain_ind]
        inds_matched, sim_matched = indices_forward, sim_scores_f

        # keep the top half of the matches (by similarity) when there are many
        reduced_points_num = len(sim_matched) // 2 if len(sim_matched) > 40 else len(sim_matched)
        sim_sorted, sim_idx_sorted = torch.sort(sim_matched, descending=True)
        sim_filter = sim_idx_sorted[:reduced_points_num]
        points_matched_inds = indices_forward[1][sim_filter]

        # map target patch indices to pixel coordinates at patch centers
        points_matched_inds_set = torch.tensor(list(set(points_matched_inds.cpu().tolist())))
        points_matched_inds_set_w = points_matched_inds_set % self.encoder_feat_size
        points_matched_inds_set_h = points_matched_inds_set // self.encoder_feat_size
        idxs_mask_set_x = (points_matched_inds_set_w * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()
        idxs_mask_set_y = (points_matched_inds_set_h * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()

        points_matched = []
        for x, y in zip(idxs_mask_set_x, idxs_mask_set_y):
            if int(x) < self.input_size[1] and int(y) < self.input_size[0]:
                points_matched.append([int(x), int(y)])
        points = np.array(points_matched)

        if self.use_box:
            box = np.array([
                max(points[:, 0].min(), 0),
                max(points[:, 1].min(), 0),
                min(points[:, 0].max(), self.input_size[1] - 1),
                min(points[:, 1].max(), self.input_size[0] - 1),
            ])
        else:
            box = None

        return points, box, S, C, reduced_points_num
    def clustering(self, points):
        num_centers = min(self.num_centers, len(points))
        # kmeans++ can occasionally return empty clusters; re-run until every
        # cluster is non-empty
        while True:
            centers, cluster_assignment = kmeans_pp(points, num_centers)
            ids, freq = torch.unique(cluster_assignment, return_counts=True)
            if ids.shape[0] == num_centers:
                break
            print('kmeans++ produced an empty cluster, re-running')
        centers = np.array(centers).astype(np.int64)
        return centers
    def mask_generation(self, tar_img_np, points, box, all_points, ref_masks_pool, C):

        samples_list, label_list = self.rps.sample_points(points)

        tar_masks_ori = self.generator.generate(
            tar_img_np,
            select_point_coords=samples_list,
            select_point_labels=label_list,
            select_box=[box] if self.use_box else None,
        )
        tar_masks = torch.cat(
            [torch.from_numpy(qmask['segmentation']).float()[None, None, ...].to(self.device)
             for qmask in tar_masks_ori], dim=0).cpu().numpy() > 0

        # placeholders for per-mask scores and prompt bookkeeping
        purity = torch.zeros(tar_masks.shape[0])
        coverage = torch.zeros(tar_masks.shape[0])
        emd = torch.zeros(tar_masks.shape[0])
        samples = samples_list[-1]
        labels = torch.ones(tar_masks.shape[0], samples.shape[1])
        samples = torch.ones(tar_masks.shape[0], samples.shape[1], 2)

        # compute purity / coverage / EMD scores for each candidate mask
        for i in range(len(tar_masks)):
            purity_, coverage_, emd_, sample_, label_, mask_ = \
                self.rps.get_mask_scores(
                    points=points,
                    masks=tar_masks[i],
                    all_points=all_points,
                    emd_cost=C,
                    ref_masks_pool=ref_masks_pool,
                )
            assert np.all(mask_ == tar_masks[i])
            purity[i] = purity_
            coverage[i] = coverage_
            emd[i] = emd_
        pred_masks = tar_masks.squeeze(1)
        metric_preds = {
            "purity": purity,
            "coverage": coverage,
            "emd": emd,
        }
        scores = self.alpha * emd + self.beta * purity * coverage ** self.exp

        def check_pred_mask(pred_masks):
            if len(pred_masks.shape) < 3:  # keep the mask dimension when only one mask remains
                pred_masks = pred_masks[None, ...]
            return pred_masks

        pred_masks = check_pred_mask(pred_masks)
        # filter false-positive mask fragments using the proposed metrics
        for metric in ["coverage", "emd", "purity"]:
            if self.score_filter_cfg[metric] > 0:
                thres = min(self.score_filter_cfg[metric], metric_preds[metric].max())
                idx = torch.where(metric_preds[metric] >= thres)[0]
                scores = scores[idx]
                samples = samples[idx]
                labels = labels[idx]
                pred_masks = check_pred_mask(pred_masks[idx])
                for key in metric_preds.keys():
                    metric_preds[key] = metric_preds[key][idx]

        # score-based mask selection and merging
        if self.score_filter_cfg["score_filter"]:

            distances = 1 - scores
            distances, rank = torch.sort(distances, descending=False)
            distances_norm = distances - distances.min()
            distances_norm = distances_norm / (distances.max() + 1e-6)

            filter_dis = distances < self.score_filter_cfg["score"]
            filter_dis[..., 0] = True  # always keep the best-scoring mask
            filter_dis_norm = distances_norm < self.score_filter_cfg["score_norm"]
            filter_dis = filter_dis * filter_dis_norm

            pred_masks = check_pred_mask(pred_masks)
            masks = pred_masks[rank[filter_dis][:self.num_merging_mask]]
            masks = check_pred_mask(masks)
            mask_list = masks

            masks = masks.sum(0) > 0
            masks = masks[None, ...]
        else:
            topk = min(self.num_merging_mask, scores.size(0))
            topk_idx = scores.topk(topk)[1]
            topk_samples = samples[topk_idx].cpu().numpy()
            topk_scores = scores[topk_idx].cpu().numpy()
            topk_pred_masks = pred_masks[topk_idx]
            topk_pred_masks = check_pred_mask(topk_pred_masks)

            if self.score_filter_cfg["topk_scores_threshold"] > 0:
                # map scores to [0, 1]
                topk_scores = topk_scores / (topk_scores.max())
                idx = topk_scores > self.score_filter_cfg["topk_scores_threshold"]
                topk_samples = topk_samples[idx]
                topk_pred_masks = check_pred_mask(topk_pred_masks)
                topk_pred_masks = topk_pred_masks[idx]

            mask_list = []
            for i in range(len(topk_samples)):
                mask = topk_pred_masks[i][None, ...]
                mask_list.append(mask)
            mask_list = np.concatenate(mask_list, axis=0)
            masks = np.sum(mask_list, axis=0) > 0
            masks = check_pred_mask(masks)

        # resize the merged mask (and each kept mask) back to the original image size
        tar_img_ori_size = self.tar_img_ori_size
        mask = torch.tensor(masks, device=self.device)[None, ...]
        mask = F.interpolate(mask.float(), tar_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()

        if mask_list is not None:
            mask_list = torch.tensor(mask_list, device=self.device)[:, None, ...]
            mask_list = F.interpolate(mask_list.float(), tar_img_ori_size, mode="bilinear", align_corners=False)
            # squeeze the channel dim added above (the original squeeze(0) was a
            # no-op for more than one mask), giving (num_masks, H, W)
            mask_list = mask_list.squeeze(1).cpu().numpy()

        return mask, mask_list
    def clear(self):
        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None

class RobustPromptSampler:
    """Samples point-prompt subsets for SAM and scores the resulting masks
    with purity, coverage, and EMD."""

    def __init__(
        self,
        encoder_feat_size,
        sample_range,
        max_iterations,
    ):
        self.encoder_feat_size = encoder_feat_size
        self.sample_range = sample_range
        self.max_iterations = max_iterations
    def get_mask_scores(self, points, masks, all_points, emd_cost, ref_masks_pool):

        def is_in_mask(point, mask):
            # input: point: n*2 (x, y), mask: h*w
            # output: n boolean values
            h, w = mask.shape
            point = point.astype(np.int64)  # np.int was removed in NumPy 1.24
            point = point[:, ::-1]  # to (y, x)
            point = np.clip(point, 0, [h - 1, w - 1])
            return mask[point[:, 0], point[:, 1]]

        ori_masks = masks
        masks = cv2.resize(
            masks[0].astype(np.float32),
            (self.encoder_feat_size, self.encoder_feat_size),
            interpolation=cv2.INTER_AREA)
        # keep at least one patch even if downsampling wiped out the mask
        if masks.max() <= 0:
            thres = masks.max() - 1e-6
        else:
            thres = 0
        masks = masks > thres

        # 1. EMD between reference-mask patches and candidate-mask patches, with
        # uniform weights on both sides and the cosine-distance matrix as cost
        emd_cost_pool = emd_cost[ref_masks_pool.flatten().bool(), :][:, masks.flatten()]
        emd = ot.emd2(a=[1. / emd_cost_pool.shape[0] for i in range(emd_cost_pool.shape[0])],
                      b=[1. / emd_cost_pool.shape[1] for i in range(emd_cost_pool.shape[1])],
                      M=emd_cost_pool.cpu().numpy())
        emd_score = 1 - emd

        labels = np.ones((points.shape[0],))

        # 2. purity and coverage
        assert all_points is not None
        points_in_mask = is_in_mask(all_points, ori_masks[0])
        points_in_mask = all_points[points_in_mask]
        # two metrics for local matching:
        # purity: points_in / mask_area; higher means denser matched points inside the mask
        # coverage: points_in / all_points; higher means the mask covers more of the matched points
        mask_area = max(float(masks.sum()), 1.0)
        purity = points_in_mask.shape[0] / mask_area
        coverage = points_in_mask.shape[0] / all_points.shape[0]
        purity = torch.tensor([purity]) + 1e-6
        coverage = torch.tensor([coverage]) + 1e-6

        return purity, coverage, emd_score, points, labels, ori_masks
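
    # Worked example for the two metrics above (hypothetical numbers): a mask
    # covering 10 patches that contains 5 of 20 matched points gives
    # purity = 5 / 10 = 0.5 and coverage = 5 / 20 = 0.25.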
    def combinations(self, n, k):
        # all k-element subsets of range(n), as sorted index lists
        if k > n:
            return []
        if k == 0:
            return [[]]
        if k == n:
            return [[i for i in range(n)]]
        res = []
        for i in range(n):
            for j in self.combinations(i, k - 1):
                res.append(j + [i])
        return res
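
    # A quick check of combinations() above (illustrative, not in the original
    # file): combinations(3, 2) -> [[0, 1], [0, 2], [1, 2]].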
    def sample_points(self, points):
        # returns a list of arrays, one entry per subset size in sample_range
        sample_list = []
        label_list = []
        for i in range(min(self.sample_range[0], len(points)), min(self.sample_range[1], len(points)) + 1):
            if len(points) > 8:
                # too many combinations: sample subsets randomly
                index = [random.sample(range(len(points)), i) for j in range(self.max_iterations)]
                sample = np.take(points, index, axis=0)  # max_iterations * i * 2
            else:
                # few points: enumerate all i-element subsets
                index = self.combinations(len(points), i)
                sample = np.take(points, index, axis=0)  # n_subsets * i * 2
            # positive (foreground) labels for every sampled point
            label = np.ones((sample.shape[0], i))
            sample_list.append(sample)
            label_list.append(label)
        return sample_list, label_list
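

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Checkpoint paths, the score_filter_cfg values, and the dummy inputs below are
# assumptions. Note that mask_generation() above calls generator.generate()
# with select_point_coords / select_point_labels / select_box, which the stock
# SamAutomaticMaskGenerator does not accept; Matcher relies on its modified
# generator, so the construction below is only a stand-in for that setup.
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # DINOv2 ViT-L/14 encoder at the 518x518 input size used by Matcher.
    encoder = vits.vit_large(patch_size=14, img_size=518, init_values=1.0, block_chunks=0)
    dinov2_utils.load_pretrained_weights(encoder, "dinov2_vitl14_pretrain.pth", "teacher")
    encoder.eval().to(device)

    # SAM ViT-H mask generator (see the note above about the modified generate()).
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
    generator = SamAutomaticMaskGenerator(sam)

    matcher = Matcher(
        encoder,
        generator=generator,
        input_size=518,
        score_filter_cfg={  # illustrative thresholds, not tuned values
            "purity": 0.02,
            "coverage": 0.01,
            "emd": 0.9,
            "score_filter": True,
            "score": 0.33,
            "score_norm": 0.1,
            "topk_scores_threshold": 0.0,
        },
        device=device,
    )

    # One reference image/mask pair (bs=1, nshot=1) and one target image,
    # random here purely to exercise the tensor shapes.
    ref_imgs = torch.rand(1, 1, 3, 518, 518)
    ref_masks = torch.zeros(1, 1, 518, 518)
    ref_masks[..., 200:320, 180:300] = 1
    tar_img = torch.rand(1, 3, 518, 518)

    matcher.set_reference(ref_imgs, ref_masks)
    matcher.set_target(tar_img, tar_img_ori_size=(480, 640))
    mask, mask_list = matcher.predict()
    print("merged mask:", mask.shape, "kept masks:", mask_list.shape)
    matcher.clear()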