# Matcher.py
import os
from os import path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment
from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform
from matcher.k_means import kmeans_pp
import random
class Matcher:
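    """One-shot segmentation via all-purpose feature matching.

    Matches DINOv2 patch features between reference and target images, turns
    the matched patches into point/box prompts for SAM, and scores the
    resulting candidate masks with purity, coverage and EMD metrics before
    merging the best ones into the final prediction.
    """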
def __init__(
self,
encoder,
generator=None,
input_size=518,
num_centers=8,
use_box=False,
use_points_or_centers=True,
sample_range=(4, 6),
max_sample_iterations=30,
alpha=1.,
beta=0.,
exp=0.,
score_filter_cfg=None,
num_merging_mask=10,
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
):
# models
self.encoder = encoder
self.generator = generator
self.rps = None
if not isinstance(input_size, tuple):
input_size = (input_size, input_size)
self.input_size = input_size
# transforms for image encoder
self.encoder_transform = transforms.Compose([
MaybeToTensor(),
make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
self.tar_img = None
self.tar_img_np = None
self.tar_img_ori_size = None
self.ref_imgs = None
self.ref_masks_pool = None
self.nshot = None
self.encoder_img_size = None
self.encoder_feat_size = None
self.num_centers = num_centers
self.use_box = use_box
self.use_points_or_centers = use_points_or_centers
self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations
self.alpha, self.beta, self.exp = alpha, beta, exp
assert score_filter_cfg is not None
self.score_filter_cfg = score_filter_cfg
self.num_merging_mask = num_merging_mask
self.device = device
def set_reference(self, imgs, masks):
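        """Store reference images and pool their masks to patch resolution.

        An all-zero reference mask is replaced by a small square at the image
        center so that matching always has at least one foreground patch.
        """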
def reference_masks_verification(masks):
if masks.sum() == 0:
_, _, sh, sw = masks.shape
masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
return masks
imgs = imgs.flatten(0, 1) # bs, 3, h, w
img_size = imgs.shape[-1]
assert img_size == self.input_size[-1]
feat_size = img_size // self.encoder.patch_size
self.encoder_img_size = img_size
self.encoder_feat_size = feat_size
# process reference masks
masks = reference_masks_verification(masks)
masks = masks.permute(1, 0, 2, 3) # ns, 1, h, w
ref_masks_pool = F.avg_pool2d(masks.float(), (self.encoder.patch_size, self.encoder.patch_size))
nshot = ref_masks_pool.shape[0]
ref_masks_pool = (ref_masks_pool > self.generator.predictor.model.mask_threshold).float()
        ref_masks_pool = ref_masks_pool.reshape(-1)  # flattened to (nshot * N,)
self.ref_imgs = imgs
self.ref_masks_pool = ref_masks_pool
self.nshot = nshot
def set_target(self, img, tar_img_ori_size):
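        """Store the target image together with a uint8 numpy copy for SAM.

        `tar_img_ori_size` is the original (h, w); the predicted masks are
        resized back to it at the end of `predict`.
        """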
img_h, img_w = img.shape[-2:]
assert img_h == self.input_size[0] and img_w == self.input_size[1]
        # convert the target image to a uint8 numpy array, the input format SAM expects
img_np = img.mul(255).byte()
img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()
self.tar_img = img
self.tar_img_np = img_np
self.tar_img_ori_size = tar_img_ori_size
def set_rps(self):
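        """Lazily build the RobustPromptSampler once the feature size is known."""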
if self.rps is None:
assert self.encoder_feat_size is not None
self.rps = RobustPromptSampler(
encoder_feat_size=self.encoder_feat_size,
sample_range=self.sample_range,
max_iterations=self.max_sample_iterations
)
def predict(self):
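        """Run the full pipeline: patch-level matching, optional clustering of
        the matched points into prompt centers, SAM mask generation, and
        metric-based filtering/merging. Returns the merged mask and mask list.
        """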
ref_feats, tar_feat = self.extract_img_feats()
all_points, box, S, C, reduced_points_num = self.patch_level_matching(ref_feats=ref_feats, tar_feat=tar_feat)
points = self.clustering(all_points) if not self.use_points_or_centers else all_points
self.set_rps()
mask, mask_list = self.mask_generation(self.tar_img_np, points, box, all_points, self.ref_masks_pool, C)
return mask, mask_list
def extract_img_feats(self):
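        """Extract L2-normalized DINOv2 patch features (class token dropped)."""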
ref_imgs = torch.cat([self.encoder_transform(rimg)[None, ...] for rimg in self.ref_imgs], dim=0)
tar_img = torch.cat([self.encoder_transform(timg)[None, ...] for timg in self.tar_img], dim=0)
ref_feats = self.encoder.forward_features(ref_imgs.to(self.device))["x_prenorm"][:, 1:]
tar_feat = self.encoder.forward_features(tar_img.to(self.device))["x_prenorm"][:, 1:]
# ns, N, c = ref_feats.shape
ref_feats = ref_feats.reshape(-1, self.encoder.embed_dim) # ns*N, c
tar_feat = tar_feat.reshape(-1, self.encoder.embed_dim) # N, c
ref_feats = F.normalize(ref_feats, dim=1, p=2) # normalize for cosine similarity
tar_feat = F.normalize(tar_feat, dim=1, p=2)
return ref_feats, tar_feat
def patch_level_matching(self, ref_feats, tar_feat):
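        """Bidirectional patch-level matching between reference and target.

        Forward: Hungarian assignment (linear_sum_assignment) from reference
        foreground patches to target patches on the similarity S; reverse:
        assignment back from the matched target patches, keeping only matches
        that land inside the reference mask. Surviving target patches are
        converted to pixel coordinates at the patch centers.
        """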
# forward matching
S = ref_feats @ tar_feat.t() # ns*N, N
        C = (1 - S) / 2  # cosine similarity mapped to a distance in [0, 1]
S_forward = S[self.ref_masks_pool.flatten().bool()]
indices_forward = linear_sum_assignment(S_forward.cpu(), maximize=True)
indices_forward = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_forward]
sim_scores_f = S_forward[indices_forward[0], indices_forward[1]]
indices_mask = self.ref_masks_pool.flatten().nonzero()[:, 0]
# reverse matching
S_reverse = S.t()[indices_forward[1]]
indices_reverse = linear_sum_assignment(S_reverse.cpu(), maximize=True)
indices_reverse = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_reverse]
retain_ind = torch.isin(indices_reverse[1], indices_mask)
        if retain_ind.any():
indices_forward = [indices_forward[0][retain_ind], indices_forward[1][retain_ind]]
sim_scores_f = sim_scores_f[retain_ind]
inds_matched, sim_matched = indices_forward, sim_scores_f
        # keep only the better half of the matches when there are many (> 40)
        reduced_points_num = len(sim_matched) // 2 if len(sim_matched) > 40 else len(sim_matched)
sim_sorted, sim_idx_sorted = torch.sort(sim_matched, descending=True)
sim_filter = sim_idx_sorted[:reduced_points_num]
points_matched_inds = indices_forward[1][sim_filter]
        # deduplicate the matched target patches, then convert patch indices
        # to pixel coordinates at the patch centers
        points_matched_inds_set = torch.tensor(list(set(points_matched_inds.cpu().tolist())))
        points_matched_inds_set_w = points_matched_inds_set % self.encoder_feat_size
        points_matched_inds_set_h = points_matched_inds_set // self.encoder_feat_size
        idxs_mask_set_x = (points_matched_inds_set_w * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()
        idxs_mask_set_y = (points_matched_inds_set_h * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()
        points_matched = []
        for x, y in zip(idxs_mask_set_x, idxs_mask_set_y):
            if int(x) < self.input_size[1] and int(y) < self.input_size[0]:
                points_matched.append([int(x), int(y)])
        points = np.array(points_matched)
        if self.use_box:
            box = np.array([
                max(points[:, 0].min(), 0),
                max(points[:, 1].min(), 0),
                min(points[:, 0].max(), self.input_size[1] - 1),
                min(points[:, 1].max(), self.input_size[0] - 1),
            ])
        else:
            box = None
        return points, box, S, C, reduced_points_num
def clustering(self, points):
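        """Reduce the matched points to at most `num_centers` k-means++
        centers, re-running whenever a cluster comes back empty.
        """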
num_centers = min(self.num_centers, len(points))
        flag = True
        while flag:
            centers, cluster_assignment = kmeans_pp(points, num_centers)
            ids, counts = torch.unique(cluster_assignment, return_counts=True)
            if ids.shape[0] == num_centers:
                flag = False
            else:
                print('k-means++ produced an empty cluster, re-running')
centers = np.array(centers).astype(np.int64)
return centers
    def mask_generation(self, tar_img_np, points, box, all_points, ref_masks_pool, C):
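        """Generate candidate masks with SAM from sampled point prompts, score
        them (purity / coverage / EMD), filter them with `score_filter_cfg`,
        and merge the top-ranked masks, resized to the original image size.
        """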
samples_list, label_list = self.rps.sample_points(points)
tar_masks_ori = self.generator.generate(
tar_img_np,
select_point_coords=samples_list,
select_point_labels=label_list,
select_box=[box] if self.use_box else None,
)
tar_masks = torch.cat(
[torch.from_numpy(qmask['segmentation']).float()[None, None, ...].to(self.device) for
qmask in tar_masks_ori], dim=0).cpu().numpy() > 0
        # per-mask score buffers; `samples`/`labels` are dummy tensors kept
        # only so that the index-based filtering below stays shape-consistent
        purity = torch.zeros(tar_masks.shape[0])
        coverage = torch.zeros(tar_masks.shape[0])
        emd = torch.zeros(tar_masks.shape[0])
        samples = samples_list[-1]
        labels = torch.ones(tar_masks.shape[0], samples.shape[1])
        samples = torch.ones(tar_masks.shape[0], samples.shape[1], 2)
# compute scores for each mask
for i in range(len(tar_masks)):
purity_, coverage_, emd_, sample_, label_, mask_ = \
self.rps.get_mask_scores(
points=points,
masks=tar_masks[i],
                all_points=all_points,
emd_cost=C,
ref_masks_pool=ref_masks_pool
)
assert np.all(mask_ == tar_masks[i])
purity[i] = purity_
coverage[i] = coverage_
emd[i] = emd_
pred_masks = tar_masks.squeeze(1)
metric_preds = {
"purity": purity,
"coverage": coverage,
"emd": emd
}
        # combined mask score: alpha * EMD + beta * purity * coverage^exp
        scores = self.alpha * emd + self.beta * purity * coverage ** self.exp
        def check_pred_mask(pred_masks):
            if len(pred_masks.shape) < 3:  # keep a leading mask dimension when only one mask is left
                pred_masks = pred_masks[None, ...]
            return pred_masks
pred_masks = check_pred_mask(pred_masks)
        # filter false-positive mask fragments using the proposed metrics
for metric in ["coverage", "emd", "purity"]:
if self.score_filter_cfg[metric] > 0:
thres = min(self.score_filter_cfg[metric], metric_preds[metric].max())
idx = torch.where(metric_preds[metric] >= thres)[0]
scores = scores[idx]
samples = samples[idx]
labels = labels[idx]
pred_masks = check_pred_mask(pred_masks[idx])
for key in metric_preds.keys():
metric_preds[key] = metric_preds[key][idx]
# score-based masks selection, masks merging
if self.score_filter_cfg["score_filter"]:
            distances = 1 - scores
            distances, rank = torch.sort(distances, descending=False)
            distances_norm = distances - distances.min()
            distances_norm = distances_norm / (distances.max() + 1e-6)
            filter_dis = distances < self.score_filter_cfg["score"]
            filter_dis[..., 0] = True  # always keep the best-scoring mask
            filter_dis_norm = distances_norm < self.score_filter_cfg["score_norm"]
            filter_dis = filter_dis * filter_dis_norm
            pred_masks = check_pred_mask(pred_masks)
            masks = pred_masks[rank[filter_dis][:self.num_merging_mask]]
masks = check_pred_mask(masks)
mask_list = masks
masks = masks.sum(0) > 0
masks = masks[None, ...]
else:
topk = min(self.num_merging_mask, scores.size(0))
topk_idx = scores.topk(topk)[1]
topk_samples = samples[topk_idx].cpu().numpy()
topk_scores = scores[topk_idx].cpu().numpy()
topk_pred_masks = pred_masks[topk_idx]
topk_pred_masks = check_pred_mask(topk_pred_masks)
if self.score_filter_cfg["topk_scores_threshold"] > 0:
# map scores to 0-1
topk_scores = topk_scores / (topk_scores.max())
idx = topk_scores > self.score_filter_cfg["topk_scores_threshold"]
topk_samples = topk_samples[idx]
topk_pred_masks = check_pred_mask(topk_pred_masks)
topk_pred_masks = topk_pred_masks[idx]
mask_list = []
for i in range(len(topk_samples)):
mask = topk_pred_masks[i][None, ...]
mask_list.append(mask)
mask_list = np.concatenate(mask_list, axis=0)
masks = np.sum(mask_list, axis=0) > 0
masks = check_pred_mask(masks)
tar_img_ori_size = self.tar_img_ori_size
mask = torch.tensor(masks, device=self.device)[None, ...]
mask = F.interpolate(mask.float(), tar_img_ori_size, mode="bilinear", align_corners=False) > 0
mask = mask.squeeze(0).cpu().numpy()
if mask_list is not None:
mask_list = torch.tensor(mask_list, device=self.device)[:, None, ...]
mask_list = F.interpolate(mask_list.float(), tar_img_ori_size, mode="bilinear", align_corners=False)
mask_list = mask_list.squeeze(0).cpu().numpy()
return mask, mask_list
def clear(self):
self.tar_img = None
self.tar_img_np = None
self.tar_img_ori_size = None
self.ref_imgs = None
self.ref_masks_pool = None
self.nshot = None
self.encoder_img_size = None
self.encoder_feat_size = None
class RobustPromptSampler:
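    """Samples subsets of the matched points as SAM prompts and scores the
    resulting masks with purity, coverage and EMD.
    """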
def __init__(
self,
encoder_feat_size,
sample_range,
max_iterations
):
self.encoder_feat_size = encoder_feat_size
self.sample_range = sample_range
self.max_iterations = max_iterations
def get_mask_scores(self, points, masks, all_points, emd_cost, ref_masks_pool):
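        """Score one candidate mask against the matched points.

        `emd_cost` is the patch-level cosine-distance matrix C from
        `patch_level_matching`. Returns (purity, coverage, emd_score,
        points, labels, ori_masks).
        """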
def is_in_mask(point, mask):
# input: point: n*2, mask: h*w
# output: n*1
h, w = mask.shape
            point = point.astype(np.int64)  # np.int was removed in NumPy >= 1.24
            point = point[:, ::-1]  # (x, y) -> (y, x) for row-major indexing
            point = np.clip(point, 0, [h - 1, w - 1])
return mask[point[:, 0], point[:, 1]]
ori_masks = masks
masks = cv2.resize(
masks[0].astype(np.float32),
(self.encoder_feat_size, self.encoder_feat_size),
interpolation=cv2.INTER_AREA)
        # INTER_AREA downsampling can leave a mask with no positive values;
        # in that case lower the threshold just below the max so the strongest
        # responses still survive
        if masks.max() <= 0:
            thres = masks.max() - 1e-6
        else:
            thres = 0
        masks = masks > thres
        # 1. EMD: optimal-transport distance between reference-mask patches
        # and candidate-mask patches, under the cosine-distance cost C
emd_cost_pool = emd_cost[ref_masks_pool.flatten().bool(), :][:, masks.flatten()]
emd = ot.emd2(a=[1. / emd_cost_pool.shape[0] for i in range(emd_cost_pool.shape[0])],
b=[1. / emd_cost_pool.shape[1] for i in range(emd_cost_pool.shape[1])],
M=emd_cost_pool.cpu().numpy())
emd_score = 1 - emd
labels = np.ones((points.shape[0],))
# 2. purity and coverage
assert all_points is not None
points_in_mask = is_in_mask(all_points, ori_masks[0])
points_in_mask = all_points[points_in_mask]
        # two metrics for local matching: purity and coverage
        # purity = points_in / mask_area: higher means matched points are denser inside the mask
        # coverage = points_in / all_points: higher means the mask covers more of the matched points
mask_area = max(float(masks.sum()), 1.0)
purity = points_in_mask.shape[0] / mask_area
coverage = points_in_mask.shape[0] / all_points.shape[0]
purity = torch.tensor([purity]) + 1e-6
coverage = torch.tensor([coverage]) + 1e-6
return purity, coverage, emd_score, points, labels, ori_masks
def combinations(self, n, k):
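        """Recursively enumerate all k-element index combinations of range(n),
        e.g. combinations(4, 2) -> [[0, 1], [0, 2], [1, 2], [0, 3], [1, 3], [2, 3]].
        """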
if k > n:
return []
if k == 0:
return [[]]
if k == n:
return [[i for i in range(n)]]
res = []
for i in range(n):
for j in self.combinations(i, k - 1):
res.append(j + [i])
return res
def sample_points(self, points):
        """Sample point-prompt subsets of sizes within `sample_range`.

        Returns two lists of arrays: the sampled point coordinates and the
        matching all-ones (foreground) labels.
        """
sample_list = []
label_list = []
for i in range(min(self.sample_range[0], len(points)), min(self.sample_range[1], len(points)) + 1):
            if len(points) > 8:
                # too many points to enumerate: draw random i-subsets
                index = [random.sample(range(len(points)), i) for _ in range(self.max_iterations)]
                sample = np.take(points, index, axis=0)  # (max_iterations, i, 2)
            else:
                # few points: enumerate every i-combination exhaustively
                index = self.combinations(len(points), i)
                sample = np.take(points, index, axis=0)  # (n_combinations, i, 2)
            # every sampled point is a foreground prompt, label 1
            label = np.ones((sample.shape[0], i))
sample_list.append(sample)
label_list.append(label)
return sample_list, label_list
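
# --------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original pipeline code).
# It assumes `encoder` is a DINOv2 ViT exposing forward_features(), patch_size
# and embed_dim, and `generator` is this repo's modified
# SamAutomaticMaskGenerator whose generate() accepts the select_point_coords /
# select_point_labels / select_box keywords used above. The score_filter_cfg
# keys mirror the lookups in mask_generation; all concrete values below are
# placeholders, not tuned settings.
#
#   score_filter_cfg = {
#       "purity": 0.0, "coverage": 0.0, "emd": 0.0,   # per-metric thresholds
#       "score_filter": True, "score": 0.5, "score_norm": 0.1,
#       "topk_scores_threshold": 0.0,
#   }
#   matcher = Matcher(encoder, generator, score_filter_cfg=score_filter_cfg)
#   matcher.set_reference(ref_imgs, ref_masks)   # 518x518 reference batch + masks
#   matcher.set_target(tar_img, (orig_h, orig_w))
#   mask, mask_list = matcher.predict()
#   matcher.clear()                              # reset state between targets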