Upload 19 files
- .gitattributes +10 -0
- Matcher.py +452 -0
- app_running.py +173 -0
- images/dinosaur1.png +3 -0
- images/dinosaur2.png +3 -0
- images/dinosaur3.png +3 -0
- images/earth1.png +0 -0
- images/earth2.png +0 -0
- images/earth3.png +0 -0
- images/elephant1.png +3 -0
- images/elephant2.png +3 -0
- images/elephant3.png +3 -0
- images/hmbb1.png +0 -0
- images/hmbb2.png +0 -0
- images/hmbb3.png +3 -0
- images/horse1.png +3 -0
- images/horse2.png +3 -0
- images/horse3.png +3 -0
- oss_ops_inference.py +288 -0
- runner.py +156 -0
.gitattributes
CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/dinosaur1.png filter=lfs diff=lfs merge=lfs -text
+images/dinosaur2.png filter=lfs diff=lfs merge=lfs -text
+images/dinosaur3.png filter=lfs diff=lfs merge=lfs -text
+images/elephant1.png filter=lfs diff=lfs merge=lfs -text
+images/elephant2.png filter=lfs diff=lfs merge=lfs -text
+images/elephant3.png filter=lfs diff=lfs merge=lfs -text
+images/hmbb3.png filter=lfs diff=lfs merge=lfs -text
+images/horse1.png filter=lfs diff=lfs merge=lfs -text
+images/horse2.png filter=lfs diff=lfs merge=lfs -text
+images/horse3.png filter=lfs diff=lfs merge=lfs -text
Matcher.py
ADDED
@@ -0,0 +1,452 @@
import os
from os import path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment

from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform

from matcher.k_means import kmeans_pp

import random

class Matcher:
    def __init__(
        self,
        encoder,
        generator=None,
        input_size=518,
        num_centers=8,
        use_box=False,
        use_points_or_centers=True,
        sample_range=(4, 6),
        max_sample_iterations=30,
        alpha=1.,
        beta=0.,
        exp=0.,
        score_filter_cfg=None,
        num_merging_mask=10,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ):
        # models
        self.encoder = encoder
        self.generator = generator
        self.rps = None

        if not isinstance(input_size, tuple):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # transforms for image encoder
        self.encoder_transform = transforms.Compose([
            MaybeToTensor(),
            make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None

        self.num_centers = num_centers
        self.use_box = use_box
        self.use_points_or_centers = use_points_or_centers
        self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations

        self.alpha, self.beta, self.exp = alpha, beta, exp
        assert score_filter_cfg is not None
        self.score_filter_cfg = score_filter_cfg
        self.num_merging_mask = num_merging_mask

        self.device = device

    def set_reference(self, imgs, masks):

        def reference_masks_verification(masks):
            # fall back to a small central square if the reference mask is empty
            if masks.sum() == 0:
                _, _, sh, sw = masks.shape
                masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
            return masks

        imgs = imgs.flatten(0, 1)  # bs, 3, h, w
        img_size = imgs.shape[-1]
        assert img_size == self.input_size[-1]
        feat_size = img_size // self.encoder.patch_size

        self.encoder_img_size = img_size
        self.encoder_feat_size = feat_size

        # process reference masks
        masks = reference_masks_verification(masks)
        masks = masks.permute(1, 0, 2, 3)  # ns, 1, h, w
        ref_masks_pool = F.avg_pool2d(masks.float(), (self.encoder.patch_size, self.encoder.patch_size))
        nshot = ref_masks_pool.shape[0]
        ref_masks_pool = (ref_masks_pool > self.generator.predictor.model.mask_threshold).float()
        ref_masks_pool = ref_masks_pool.reshape(-1)  # nshot, N

        self.ref_imgs = imgs
        self.ref_masks_pool = ref_masks_pool
        self.nshot = nshot

    def set_target(self, img, tar_img_ori_size):

        img_h, img_w = img.shape[-2:]
        assert img_h == self.input_size[0] and img_w == self.input_size[1]

        # convert the query image to numpy as input for SAM
        img_np = img.mul(255).byte()
        img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()

        self.tar_img = img
        self.tar_img_np = img_np
        self.tar_img_ori_size = tar_img_ori_size

    def set_rps(self):
        if self.rps is None:
            assert self.encoder_feat_size is not None
            self.rps = RobustPromptSampler(
                encoder_feat_size=self.encoder_feat_size,
                sample_range=self.sample_range,
                max_iterations=self.max_sample_iterations
            )

    def predict(self):

        ref_feats, tar_feat = self.extract_img_feats()
        all_points, box, S, C, reduced_points_num = self.patch_level_matching(ref_feats=ref_feats, tar_feat=tar_feat)
        points = self.clustering(all_points) if not self.use_points_or_centers else all_points
        self.set_rps()
        mask, mask_list = self.mask_generation(self.tar_img_np, points, box, all_points, self.ref_masks_pool, C)
        return mask, mask_list

    def extract_img_feats(self):

        ref_imgs = torch.cat([self.encoder_transform(rimg)[None, ...] for rimg in self.ref_imgs], dim=0)
        tar_img = torch.cat([self.encoder_transform(timg)[None, ...] for timg in self.tar_img], dim=0)

        ref_feats = self.encoder.forward_features(ref_imgs.to(self.device))["x_prenorm"][:, 1:]
        tar_feat = self.encoder.forward_features(tar_img.to(self.device))["x_prenorm"][:, 1:]
        # ns, N, c = ref_feats.shape
        ref_feats = ref_feats.reshape(-1, self.encoder.embed_dim)  # ns*N, c
        tar_feat = tar_feat.reshape(-1, self.encoder.embed_dim)  # N, c

        ref_feats = F.normalize(ref_feats, dim=1, p=2)  # normalize for cosine similarity
        tar_feat = F.normalize(tar_feat, dim=1, p=2)

        return ref_feats, tar_feat

    def patch_level_matching(self, ref_feats, tar_feat):

        # forward matching
        S = ref_feats @ tar_feat.t()  # ns*N, N
        C = (1 - S) / 2  # distance

        S_forward = S[self.ref_masks_pool.flatten().bool()]

        indices_forward = linear_sum_assignment(S_forward.cpu(), maximize=True)
        indices_forward = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_forward]
        sim_scores_f = S_forward[indices_forward[0], indices_forward[1]]
        indices_mask = self.ref_masks_pool.flatten().nonzero()[:, 0]

        # reverse matching
        S_reverse = S.t()[indices_forward[1]]
        indices_reverse = linear_sum_assignment(S_reverse.cpu(), maximize=True)
        indices_reverse = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_reverse]
        retain_ind = torch.isin(indices_reverse[1], indices_mask)
        if not (retain_ind == False).all().item():
            indices_forward = [indices_forward[0][retain_ind], indices_forward[1][retain_ind]]
            sim_scores_f = sim_scores_f[retain_ind]
        inds_matched, sim_matched = indices_forward, sim_scores_f

        reduced_points_num = len(sim_matched) // 2 if len(sim_matched) > 40 else len(sim_matched)
        sim_sorted, sim_idx_sorted = torch.sort(sim_matched, descending=True)
        sim_filter = sim_idx_sorted[:reduced_points_num]
        points_matched_inds = indices_forward[1][sim_filter]

        points_matched_inds_set = torch.tensor(list(set(points_matched_inds.cpu().tolist())))
        points_matched_inds_set_w = points_matched_inds_set % self.encoder_feat_size
        points_matched_inds_set_h = points_matched_inds_set // self.encoder_feat_size
        idxs_mask_set_x = (points_matched_inds_set_w * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()
        idxs_mask_set_y = (points_matched_inds_set_h * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()

        points_matched = []
        for x, y in zip(idxs_mask_set_x, idxs_mask_set_y):
            if int(x) < self.input_size[1] and int(y) < self.input_size[0]:
                points_matched.append([int(x), int(y)])
        points = np.array(points_matched)

        if self.use_box:
            box = np.array([
                max(points[:, 0].min(), 0),
                max(points[:, 1].min(), 0),
                min(points[:, 0].max(), self.input_size[1] - 1),
                min(points[:, 1].max(), self.input_size[0] - 1),
            ])
        else:
            box = None

        return points, box, S, C, reduced_points_num

    def clustering(self, points):

        num_centers = min(self.num_centers, len(points))
        flag = True
        while flag:
            centers, cluster_assignment = kmeans_pp(points, num_centers)
            ids, counts = torch.unique(cluster_assignment, return_counts=True)
            if ids.shape[0] == num_centers:
                flag = False
            else:
                print('Kmeans++ failed, re-run')
        centers = np.array(centers).astype(np.int64)
        return centers

    def mask_generation(self, tar_img_np, points, box, all_points, ref_masks_pool, C):
        samples_list, label_list = self.rps.sample_points(points)
        tar_masks_ori = self.generator.generate(
            tar_img_np,
            select_point_coords=samples_list,
            select_point_labels=label_list,
            select_box=[box] if self.use_box else None,
        )
        tar_masks = torch.cat(
            [torch.from_numpy(qmask['segmentation']).float()[None, None, ...].to(self.device) for
             qmask in tar_masks_ori], dim=0).cpu().numpy() > 0

        # append to original results
        purity = torch.zeros(tar_masks.shape[0])
        coverage = torch.zeros(tar_masks.shape[0])
        emd = torch.zeros(tar_masks.shape[0])

        samples = samples_list[-1]
        # placeholders of matching shape; only their leading dimension is used below
        labels = torch.ones(tar_masks.shape[0], samples.shape[1])
        samples = torch.ones(tar_masks.shape[0], samples.shape[1], 2)

        # compute scores for each mask
        for i in range(len(tar_masks)):
            purity_, coverage_, emd_, sample_, label_, mask_ = \
                self.rps.get_mask_scores(
                    points=points,
                    masks=tar_masks[i],
                    all_points=all_points,
                    emd_cost=C,
                    ref_masks_pool=ref_masks_pool
                )
            assert np.all(mask_ == tar_masks[i])
            purity[i] = purity_
            coverage[i] = coverage_
            emd[i] = emd_

        pred_masks = tar_masks.squeeze(1)
        metric_preds = {
            "purity": purity,
            "coverage": coverage,
            "emd": emd
        }

        scores = self.alpha * emd + self.beta * purity * coverage ** self.exp

        def check_pred_mask(pred_masks):
            if len(pred_masks.shape) < 3:  # avoid a single mask losing its batch dim
                pred_masks = pred_masks[None, ...]
            return pred_masks

        pred_masks = check_pred_mask(pred_masks)

        # filter false-positive mask fragments using the proposed metrics
        for metric in ["coverage", "emd", "purity"]:
            if self.score_filter_cfg[metric] > 0:
                thres = min(self.score_filter_cfg[metric], metric_preds[metric].max())
                idx = torch.where(metric_preds[metric] >= thres)[0]
                scores = scores[idx]
                samples = samples[idx]
                labels = labels[idx]
                pred_masks = check_pred_mask(pred_masks[idx])

                for key in metric_preds.keys():
                    metric_preds[key] = metric_preds[key][idx]

        # score-based mask selection and merging
        if self.score_filter_cfg["score_filter"]:

            distances = 1 - scores
            distances, rank = torch.sort(distances, descending=False)
            distances_norm = distances - distances.min()
            distances_norm = distances_norm / (distances.max() + 1e-6)
            filter_dis = distances < self.score_filter_cfg["score"]
            filter_dis[..., 0] = True
            filter_dis_norm = distances_norm < self.score_filter_cfg["score_norm"]
            filter_dis = filter_dis * filter_dis_norm

            pred_masks = check_pred_mask(pred_masks)
            masks = pred_masks[rank[filter_dis][:self.num_merging_mask]]
            masks = check_pred_mask(masks)
            mask_list = masks
            masks = masks.sum(0) > 0
            masks = masks[None, ...]

        else:

            topk = min(self.num_merging_mask, scores.size(0))
            topk_idx = scores.topk(topk)[1]
            topk_samples = samples[topk_idx].cpu().numpy()
            topk_scores = scores[topk_idx].cpu().numpy()
            topk_pred_masks = pred_masks[topk_idx]
            topk_pred_masks = check_pred_mask(topk_pred_masks)

            if self.score_filter_cfg["topk_scores_threshold"] > 0:
                # map scores to [0, 1]
                topk_scores = topk_scores / (topk_scores.max())

            idx = topk_scores > self.score_filter_cfg["topk_scores_threshold"]
            topk_samples = topk_samples[idx]

            topk_pred_masks = check_pred_mask(topk_pred_masks)
            topk_pred_masks = topk_pred_masks[idx]
            mask_list = []
            for i in range(len(topk_samples)):
                mask = topk_pred_masks[i][None, ...]
                mask_list.append(mask)
            mask_list = np.concatenate(mask_list, axis=0)
            masks = np.sum(mask_list, axis=0) > 0
            masks = check_pred_mask(masks)

        tar_img_ori_size = self.tar_img_ori_size
        mask = torch.tensor(masks, device=self.device)[None, ...]
        mask = F.interpolate(mask.float(), tar_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()
        if mask_list is not None:
            mask_list = torch.tensor(mask_list, device=self.device)[:, None, ...]
            mask_list = F.interpolate(mask_list.float(), tar_img_ori_size, mode="bilinear", align_corners=False)
            mask_list = mask_list.squeeze(0).cpu().numpy()

        return mask, mask_list

    def clear(self):

        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None


class RobustPromptSampler:

    def __init__(
        self,
        encoder_feat_size,
        sample_range,
        max_iterations
    ):
        self.encoder_feat_size = encoder_feat_size
        self.sample_range = sample_range
        self.max_iterations = max_iterations

    def get_mask_scores(self, points, masks, all_points, emd_cost, ref_masks_pool):

        def is_in_mask(point, mask):
            # input: point: n*2, mask: h*w
            # output: n*1
            h, w = mask.shape
            point = point.astype(np.int64)
            point = point[:, ::-1]  # y, x
            point = np.clip(point, 0, [h - 1, w - 1])
            return mask[point[:, 0], point[:, 1]]

        ori_masks = masks
        masks = cv2.resize(
            masks[0].astype(np.float32),
            (self.encoder_feat_size, self.encoder_feat_size),
            interpolation=cv2.INTER_AREA)
        if masks.max() <= 0:
            thres = masks.max() - 1e-6
        else:
            thres = 0
        masks = masks > thres

        # 1. emd
        emd_cost_pool = emd_cost[ref_masks_pool.flatten().bool(), :][:, masks.flatten()]
        emd = ot.emd2(a=[1. / emd_cost_pool.shape[0] for i in range(emd_cost_pool.shape[0])],
                      b=[1. / emd_cost_pool.shape[1] for i in range(emd_cost_pool.shape[1])],
                      M=emd_cost_pool.cpu().numpy())
        emd_score = 1 - emd

        labels = np.ones((points.shape[0],))

        # 2. purity and coverage
        assert all_points is not None
        points_in_mask = is_in_mask(all_points, ori_masks[0])
        points_in_mask = all_points[points_in_mask]
        # two metrics for local matching: purity and coverage
        # purity: points_in / mask_area; higher means denser matched points inside the mask
        # coverage: points_in / all_points; higher means the mask is more complete
        mask_area = max(float(masks.sum()), 1.0)
        purity = points_in_mask.shape[0] / mask_area
        coverage = points_in_mask.shape[0] / all_points.shape[0]
        purity = torch.tensor([purity]) + 1e-6
        coverage = torch.tensor([coverage]) + 1e-6
        return purity, coverage, emd_score, points, labels, ori_masks

    def combinations(self, n, k):
        if k > n:
            return []
        if k == 0:
            return [[]]
        if k == n:
            return [[i for i in range(n)]]
        res = []
        for i in range(n):
            for j in self.combinations(i, k - 1):
                res.append(j + [i])
        return res

    def sample_points(self, points):
        # returns lists of arrays

        sample_list = []
        label_list = []
        for i in range(min(self.sample_range[0], len(points)), min(self.sample_range[1], len(points)) + 1):
            if len(points) > 8:
                index = [random.sample(range(len(points)), i) for j in range(self.max_iterations)]
                sample = np.take(points, index, axis=0)  # (max_iterations * i) * 2
            else:
                index = self.combinations(len(points), i)
                sample = np.take(points, index, axis=0)  # i * n * 2

            # generate labels, max_iterations * i
            label = np.ones((sample.shape[0], i))
            sample_list.append(sample)
            label_list.append(label)
        return sample_list, label_list
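For orientation, `Matcher` is stateful: references and the target must be set before `predict`, and `clear` resets per-query state. Below is a minimal usage sketch, not part of the commit, assuming the checkpoints used elsewhere in this commit exist locally and that the modified `segment_anything` package bundled with the Matcher repo is installed (its `generate` accepts the `select_*` keyword arguments used above; the commit's own code also passes extra selection kwargs to the generator).

# Usage sketch (assumptions noted above; shapes follow the assertions in
# set_reference/set_target).
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sam = sam_model_registry["default"](checkpoint="models/sam_vit_h_4b8939.pth").to(device)
generator = SamAutomaticMaskGenerator(sam)  # the commit's code passes extra selection kwargs
encoder = vits.__dict__["vit_large"](img_size=518, patch_size=14, init_values=1e-5,
                                     ffn_layer='mlp', block_chunks=0)
dinov2_utils.load_pretrained_weights(encoder, "models/dinov2_vitl14_pretrain.pth", "teacher")
encoder.eval().to(device)

matcher = Matcher(
    encoder=encoder,
    generator=generator,
    score_filter_cfg={"emd": 0.0, "purity": 0.0, "coverage": 0.0, "score_filter": False,
                      "score": 0.33, "score_norm": 0.1, "topk_scores_threshold": 0.7},
    device=device,
)

ref_imgs = torch.rand(1, 1, 3, 518, 518)   # (bs, nshot, 3, H, W), H == W == input_size
ref_masks = torch.zeros(1, 1, 518, 518)
ref_masks[..., 200:300, 200:300] = 1       # one-shot reference mask
tar_img = torch.rand(1, 3, 518, 518)

matcher.set_reference(ref_imgs, ref_masks)
matcher.set_target(tar_img, (480, 640))    # original (H, W) of the query image
mask, mask_list = matcher.predict()        # mask is resized back to (480, 640)
matcher.clear()                            # reset per-query state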
app_running.py
ADDED
@@ -0,0 +1,173 @@
import os
import numpy as np
import random
import time
import gradio as gr
from gradio_demo.runner import Runner
import matplotlib.pyplot as plt

def show_mask(mask, ax, color='blue'):
    if color == 'blue':
        # reference, blue
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        # target, green
        color = np.array([78 / 255, 238 / 255, 148 / 255, 0.6])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_img_point_box_mask(img, input_point=None, input_label=None, box=None, masks=None, save_path=None, mode='mask', color='blue'):

    if mode == 'point':
        # point
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_points(input_point, input_label, plt.gca())
        plt.axis('on')
        plt.savefig(save_path, bbox_inches='tight')
    elif mode == 'box':
        # box
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_box(box, plt.gca())
        plt.axis('on')
        plt.savefig(save_path, bbox_inches='tight')
    else:
        # mask
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_mask(masks, plt.gca(), color=color)
        plt.axis('off')
        plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def create_oss_demo(
    runner: Runner,
    pipe: None = None
) -> gr.Blocks:

    examples = [
        ['./gradio_demo/images/horse1.png', './gradio_demo/images/horse2.png', './gradio_demo/images/horse3.png'],
        ['./gradio_demo/images/hmbb1.png', './gradio_demo/images/hmbb2.png', './gradio_demo/images/hmbb3.png'],
        ['./gradio_demo/images/earth1.png', './gradio_demo/images/earth2.png', './gradio_demo/images/earth3.png'],
        ['./gradio_demo/images/elephant1.png', './gradio_demo/images/elephant2.png', './gradio_demo/images/elephant3.png'],
        ['./gradio_demo/images/dinosaur1.png', './gradio_demo/images/dinosaur2.png', './gradio_demo/images/dinosaur3.png'],
    ]

    with gr.Blocks() as oss_demo:
        with gr.Column():

            # inputs
            with gr.Row():
                img_input_prompt = gr.ImageMask(label='Prompt (ζη€ΊεΎ)')
                img_input_target1 = gr.Image(label='Target 1 (ζ΅θ―εΎ1)')
                img_input_target2 = gr.Image(label='Target 2 (ζ΅θ―εΎ2)')

            version = gr.inputs.Radio(['version 1 (πΊ multiple instances π» whole, π» part)',
                                       'version 2 (π» multiple instances πΊ whole, π» part)',
                                       'version 3 (π» multiple instances π» whole, πΊ part)'],
                                      type="value", default='version 1 (πΊ multiple instances π» whole, π» part)',
                                      label='Multiple Instances (version 1), Single Instance (version 2), Part of an object (version 3)')

            with gr.Row():
                submit1 = gr.Button("ζδΊ€ (Submit)")
                clear = gr.Button("ζΈ…ι€ (Clear)")
            info = gr.Text(label="Processing result: ", interactive=False)

            # decision
            K = gr.Slider(0, 10, 10, step=1, label="Controllable mask output", interactive=True)
            submit2 = gr.Button("ζδΊ€ (Submit)")

            # outputs
            with gr.Row():
                img_output_pmt = gr.Image(label='Prompt (ζη€ΊεΎ)')
                img_output_tar1 = gr.Image(label='Output 1 (θΎεΊεΎ1)')
                img_output_tar2 = gr.Image(label='Output 2 (θΎεΊεΎ2)')

            # images
            gr.Examples(
                examples=examples,
                fn=runner.inference_oss_ops,
                inputs=[img_input_prompt, img_input_target1, img_input_target2],
                outputs=info
            )

            submit1.click(
                fn=runner.inference_oss_ops,
                inputs=[img_input_prompt, img_input_target1, img_input_target2, version],
                outputs=info
            )
            submit2.click(
                fn=runner.controllable_mask_output,
                inputs=K,
                outputs=[img_output_pmt, img_output_tar1, img_output_tar2]
            )

            clear.click(
                fn=runner.clear_fn,
                inputs=None,
                outputs=[img_input_prompt, img_input_target1, img_input_target2, info, img_output_pmt, img_output_tar1, img_output_tar2],
                queue=False
            )

    return oss_demo


def create_vos_demo(
    runner: Runner,
    pipe: None = None
) -> gr.Interface:

    raise NotImplementedError

def create_demo(
    runner: Runner,
    pipe: None = None
) -> gr.TabbedInterface:

    title = "Matcherπ―: Segment Anything with One Shot Using All-Purpose Feature Matching<br> \
            <div align='center'> \
            <h2><a href='https://arxiv.org/abs/2305.13310' target='_blank' rel='noopener'>[paper]</a> \
            <a href='https://github.com/aim-uofa/Matcher' target='_blank' rel='noopener'>[code]</a></h2> \
            <h2>Matcher can segment anything with one shot by integrating an all-purpose feature extraction model and a class-agnostic segmentation model.</h2> \
            <br> \
            </div> \
            "

    oss_demo = create_oss_demo(runner=runner, pipe=pipe)
    # vos_demo = create_vos_demo(runner=runner, pipe=pipe)
    demo = gr.TabbedInterface(
        [oss_demo,],
        ['OSS+OPS',], title=title)
    return demo


if __name__ == '__main__':
    pipe = None
    HF_TOKEN = os.getenv('HF_TOKEN')
    runner = Runner(HF_TOKEN)
    demo = create_demo(runner, pipe)
    demo.launch(enable_queue=False)
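A quick sanity check for the plotting helpers above (a sketch, not part of the commit): it overlays a synthetic mask and one positive point on a dummy image, writing to a hypothetical output path.

# Sanity-check sketch for show_mask/show_points (synthetic data only).
import numpy as np
import matplotlib.pyplot as plt

img = np.zeros((128, 128, 3), dtype=np.uint8)
mask = np.zeros((128, 128), dtype=bool)
mask[32:96, 32:96] = True                       # a square "object"

plt.figure(figsize=(5, 5))
plt.imshow(img)
show_mask(mask, plt.gca(), color='blue')        # blue reference overlay
show_points(np.array([[64, 64]]), np.array([1]), plt.gca())  # one positive point
plt.axis('off')
plt.savefig('sanity_check.png', bbox_inches='tight')  # hypothetical path
plt.close()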
images/dinosaur1.png ADDED (Git LFS)
images/dinosaur2.png ADDED (Git LFS)
images/dinosaur3.png ADDED (Git LFS)
images/earth1.png ADDED
images/earth2.png ADDED
images/earth3.png ADDED
images/elephant1.png ADDED (Git LFS)
images/elephant2.png ADDED (Git LFS)
images/elephant3.png ADDED (Git LFS)
images/hmbb1.png ADDED
images/hmbb2.png ADDED
images/hmbb3.png ADDED (Git LFS)
images/horse1.png ADDED (Git LFS)
images/horse2.png ADDED (Git LFS)
images/horse3.png ADDED (Git LFS)
oss_ops_inference.py
ADDED
@@ -0,0 +1,288 @@
r""" Hypercorrelation Squeeze testing code """

import argparse
import sys
import os
from os.path import join

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from segment_anything import SamPredictor, SamAutomaticMaskGenerator

from gradio_demo.Matcher import Matcher
from matcher.common import utils

import random
random.seed(0)


def default_argument_parser():

    # Argument parsing
    parser = argparse.ArgumentParser(description='Matcher Pytorch Implementation for One-shot Segmentation')

    # Dataset parameters
    parser.add_argument('--datapath', type=str, default='datasets')
    parser.add_argument('--benchmark', type=str, default='coco',
                        choices=['fss', 'coco', 'lvis', 'paco_part', 'pascal_part'])
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--nworker', type=int, default=0)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--nshot', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=518)
    parser.add_argument('--use_original_imgsize', action='store_true')
    parser.add_argument('--log-root', type=str, default='output/coco/fold0')
    parser.add_argument('--visualize', type=int, default=0)

    # DINOv2 and SAM parameters
    parser.add_argument('--dinov2-weights', type=str, default="models/dinov2_vitl14_pretrain.pth")
    parser.add_argument('--sam-weights', type=str, default="models/sam_vit_h_4b8939.pth")
    parser.add_argument('--points_per_side', type=int, default=64)
    parser.add_argument('--pred_iou_thresh', type=float, default=0.88)
    parser.add_argument('--sel_stability_score_thresh', type=float, default=0.0)
    parser.add_argument('--stability_score_thresh', type=float, default=0.95)
    parser.add_argument('--iou_filter', type=float, default=0.0)
    parser.add_argument('--box_nms_thresh', type=float, default=1.0)
    parser.add_argument('--output_layer', type=int, default=3)
    parser.add_argument('--dense_multimask_output', type=int, default=0)
    parser.add_argument('--use_dense_mask', type=int, default=0)
    parser.add_argument('--multimask_output', type=int, default=0)

    # Matcher parameters
    parser.add_argument('--num_centers', type=int, default=8, help='K centers for kmeans')
    parser.add_argument('--use_box', action='store_true', help='use box as an extra prompt for sam')
    parser.add_argument('--use_points_or_centers', action='store_true', help='points: T, centers: F')
    parser.add_argument('--sample-range', type=tuple, default=(4, 6), help='sample points number range')
    parser.add_argument('--max_sample_iterations', type=int, default=30)
    parser.add_argument('--alpha', type=float, default=1.)
    parser.add_argument('--beta', type=float, default=0.)
    parser.add_argument('--exp', type=float, default=0.)
    parser.add_argument('--emd_filter', type=float, default=0.0, help='use emd_filter')
    parser.add_argument('--purity_filter', type=float, default=0.0, help='use purity_filter')
    parser.add_argument('--coverage_filter', type=float, default=0.0, help='use coverage_filter')
    parser.add_argument('--use_score_filter', action='store_true')
    parser.add_argument('--deep_score_norm_filter', type=float, default=0.1)
    parser.add_argument('--deep_score_filter', type=float, default=0.33)
    parser.add_argument('--topk_scores_threshold', type=float, default=0.7)
    parser.add_argument('--num_merging_mask', type=int, default=10, help='topk masks for merging')

    args = parser.parse_args()
    return args

def definite_argument_parser(args, version=1):

    if version == 1:
        args.max_sample_iterations = 64
        args.box_nms_thresh = 0.65
        args.sample_range = (1, 6)
        args.topk_scores_threshold = 0.0
        args.use_dense_mask = 1
        args.use_points_or_centers = True
        args.purity_filter = 0.02
        args.iou_filter = 0.85
        args.multimask_output = 1
        args.sel_stability_score_thresh = 0.90
        args.use_score_filter = True
        args.alpha = 1.0
        args.beta = 0.
        args.exp = 0.
        args.num_merging_mask = 9
    elif version == 2:
        args.max_sample_iterations = 30
        args.sample_range = (4, 6)
        args.multimask_output = 0
        args.alpha = 0.8
        args.beta = 0.2
        args.exp = 1.
        args.num_merging_mask = 10
    elif version == 3:
        args.max_sample_iterations = 128
        args.sample_range = (3, 6)
        args.use_box = True
        args.use_points_or_centers = True
        args.coverage_filter = 0.3
        args.alpha = 0.5
        args.beta = 0.5
        args.exp = 0.
        args.num_merging_mask = 5

    return args

def preprocess_data(kwargs, args=None):

    img_size = args.img_size
    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor()
    ])

    support_img = Image.fromarray(kwargs.get("support_img"))
    query_img_1 = Image.fromarray(kwargs.get("query_img_1"))
    query_img_2 = Image.fromarray(kwargs.get("query_img_2"))

    support_img_ori_size = (support_img.size[1], support_img.size[0])  # H, W
    query_img_1_ori_size = (query_img_1.size[1], query_img_1.size[0])
    query_img_2_ori_size = (query_img_2.size[1], query_img_2.size[0])

    support_img = transform(support_img)
    query_img_1 = transform(query_img_1)
    query_img_2 = transform(query_img_2)

    support_mask = torch.tensor(kwargs.get("support_mask"))
    support_mask = F.interpolate(support_mask.unsqueeze(0).float(), support_img.size()[-2:],
                                 mode='nearest') > 0
    query_imgs = torch.stack([query_img_1, query_img_2], dim=0)

    data = {
        "support_img": support_img[None, ...],
        "support_mask": support_mask,
        "query_imgs": query_imgs,
        "support_img_ori_size": support_img_ori_size,
        "query_imgs_ori_size": (query_img_1_ori_size, query_img_2_ori_size),
    }

    return data

def preprocess_support_mask(data, predictor, version=1):

    if version == 3:
        return data

    sup_mask = data['support_mask'].squeeze()
    H, W = sup_mask.shape[-2:]
    input_points = sup_mask.nonzero().numpy()[:1, ::-1]  # one positive point, (x, y) order
    input_label = np.array([1] * len(input_points))

    support_img_np = data['support_img'].mul(255).byte()
    support_img_np = support_img_np.squeeze().permute(1, 2, 0).cpu().numpy()

    # forward the encoder to obtain the image feature
    predictor.reset_image()
    predictor.set_image(support_img_np)

    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=True
    )
    predictor.reset_image()

    # keep the last of SAM's multimask outputs as the support mask
    data['support_mask'] = torch.tensor(mask[-1:])[None, ...]

    return data

def main_oss_ops(**kwargs):

    args = default_argument_parser()
    args = definite_argument_parser(args, kwargs.get("version"))

    # Model initialization
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    # create sam
    sam = kwargs.get("sam")
    predictor = SamPredictor(sam)
    generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=args.points_per_side,
        points_per_batch=64,
        pred_iou_thresh=args.pred_iou_thresh,
        stability_score_thresh=args.stability_score_thresh,
        stability_score_offset=1.0,
        sel_stability_score_thresh=args.sel_stability_score_thresh,
        sel_pred_iou_thresh=args.iou_filter,
        box_nms_thresh=args.box_nms_thresh,
        sel_output_layer=args.output_layer,
        output_layer=args.dense_multimask_output,
        dense_pred=args.use_dense_mask,
        multimask_output=args.dense_multimask_output > 0,
        sel_multimask_output=args.multimask_output > 0,
    )

    # create dinov2, large
    dinov2 = kwargs.get("dinov2")

    # create matcher
    score_filter_cfg = {
        "emd": args.emd_filter,
        "purity": args.purity_filter,
        "coverage": args.coverage_filter,
        "score_filter": args.use_score_filter,
        "score": args.deep_score_filter,
        "score_norm": args.deep_score_norm_filter,
        "topk_scores_threshold": args.topk_scores_threshold
    }

    matcher = Matcher(
        encoder=dinov2,
        generator=generator,
        num_centers=args.num_centers,
        use_box=args.use_box,
        use_points_or_centers=args.use_points_or_centers,
        sample_range=args.sample_range,
        max_sample_iterations=args.max_sample_iterations,
        alpha=args.alpha,
        beta=args.beta,
        exp=args.exp,
        score_filter_cfg=score_filter_cfg,
        num_merging_mask=args.num_merging_mask,
        device=args.device
    )

    # process data
    data = preprocess_data(kwargs, args=args)
    data = preprocess_support_mask(data, predictor, version=kwargs.get("version"))

    # inference
    with torch.no_grad():
        utils.fix_randseed(0)
        pred_masks, pred_mask_lists = [], []

        # support mask
        support_img_ori_size = data['support_img_ori_size']
        mask = data['support_mask'].to(predictor.model.device).float()
        mask = F.interpolate(mask, support_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()
        pred_masks.append(mask)
        pred_mask_lists.append(None)

        for query_img, query_img_ori_size in zip(data['query_imgs'], data['query_imgs_ori_size']):
            data['query_img'], data['query_img_ori_size'] = query_img[None, ...], query_img_ori_size

            support_imgs, support_masks = data["support_img"].to(matcher.device)[None, ...], data["support_mask"].to(matcher.device)  # (1, 1, 3, H, W), (1, 1, H, W)
            query_img, query_img_ori_size = data['query_img'].to(matcher.device), data['query_img_ori_size']  # (1, 3, H, W), img_size

            # 1. Matcher prepares references and target
            matcher.set_reference(support_imgs, support_masks)
            matcher.set_target(query_img, query_img_ori_size)

            # 2. Predict mask of target
            pred_mask, pred_mask_list = matcher.predict()
            matcher.clear()

            pred_masks.append(pred_mask)
            pred_mask_lists.append(pred_mask_list)

    return pred_masks, pred_mask_lists
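For reference, `main_oss_ops` is driven entirely through keyword arguments. A hedged call sketch follows, with array shapes inferred from `preprocess_data` (images as HxWx3 uint8 arrays, the support mask as a 1xHxW boolean array) and `sam`/`dinov2` assumed to be loaded as in runner.py:

# Call sketch (not part of the commit); synthetic arrays stand in for real images.
import numpy as np

support_img = np.zeros((480, 640, 3), dtype=np.uint8)   # H x W x 3 uint8
support_mask = np.zeros((1, 480, 640), dtype=bool)      # 1 x H x W bool
support_mask[:, 200:280, 300:380] = True
query_img_1 = np.zeros((480, 640, 3), dtype=np.uint8)
query_img_2 = np.zeros((480, 640, 3), dtype=np.uint8)

pred_masks, pred_mask_lists = main_oss_ops(
    sam=sam, dinov2=dinov2,                 # assumed loaded as in runner.py
    support_img=support_img, support_mask=support_mask,
    query_img_1=query_img_1, query_img_2=query_img_2,
    version=2,                              # single-whole-instance preset, see definite_argument_parser
)
# pred_masks[0] is the refined support mask; pred_masks[1:] are the query predictions.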
runner.py
ADDED
@@ -0,0 +1,156 @@
from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys

import gradio as gr
import slugify
import torch
import numpy as np
import huggingface_hub
from huggingface_hub import HfApi
from omegaconf import OmegaConf

from segment_anything import sam_model_registry
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils

from gradio_demo.oss_ops_inference import main_oss_ops


ORIGINAL_SPACE_ID = ''
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)


class Runner:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token

        # oss, ops
        self.prompt_res_g = None
        self.prompt_mask_g = None
        self.tar1_res_g = None
        self.tar2_res_g = None
        self.version = 1

        self.pred_masks = None
        self.pred_mask_lists = None

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        sam_checkpoint = "models/sam_vit_h_4b8939.pth"
        model_type = "default"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)

        dinov2_kwargs = dict(
            img_size=518,
            patch_size=14,
            init_values=1e-5,
            ffn_layer='mlp',
            block_chunks=0,
            qkv_bias=True,
            proj_bias=True,
            ffn_bias=True,
        )
        dinov2 = vits.__dict__["vit_large"](**dinov2_kwargs)

        dinov2_utils.load_pretrained_weights(dinov2, "models/dinov2_vitl14_pretrain.pth", "teacher")
        dinov2.eval()
        dinov2.to(device=device)
        self.dinov2 = dinov2

    def inference_oss_ops(self, prompt, target1, target2, version):

        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g = prompt['image'], target1, target2
        self.prompt_mask_g = (prompt['mask'][..., 0] != 0)[None, ...]  # 1, H, W

        if version == 'version 1 (πΊ multiple instances π» whole, π» part)':
            self.version = 1
        elif version == 'version 2 (π» multiple instances πΊ whole, π» part)':
            self.version = 2
        else:
            self.version = 3

        self.pred_masks, self.pred_mask_lists = main_oss_ops(
            sam=self.sam,
            dinov2=self.dinov2,
            support_img=self.prompt_res_g,
            support_mask=self.prompt_mask_g,
            query_img_1=self.tar1_res_g,
            query_img_2=self.tar2_res_g,
            version=self.version
        )

        text = "Process Successful!"

        return text

    def clear_fn(self):

        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g, self.prompt_mask_g = None, None, None, None
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None

        return [None] * 7

    def controllable_mask_output(self, k):

        color = np.array([30, 144, 255])

        if self.version != 1:
            prompt_mask_res, tar1_mask_res, tar2_mask_res = self.pred_masks
        else:
            prompt_mask_res = self.pred_masks[0]
            tar1_mask_res, tar2_mask_res = self.pred_mask_lists[1:]

            # keep only the top-k candidate masks before merging
            tar1_mask_res = tar1_mask_res[:min(k, len(tar1_mask_res))].sum(0) > 0
            tar2_mask_res = tar2_mask_res[:min(k, len(tar2_mask_res))].sum(0) > 0

        h, w = prompt_mask_res.shape[-2:]
        prompt_mask_img = prompt_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        prompt_mask_res = self.prompt_res_g * 0.5 + prompt_mask_img * 0.5

        h, w = tar1_mask_res.shape[-2:]
        tar1_mask_img = tar1_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        tar1_mask_res = self.tar1_res_g * 0.5 + tar1_mask_img * 0.5

        h, w = tar2_mask_res.shape[-2:]
        tar2_mask_img = tar2_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
        tar2_mask_res = self.tar2_res_g * 0.5 + tar2_mask_img * 0.5

        return prompt_mask_res / 255, tar1_mask_res / 255, tar2_mask_res / 255

    def inference_vos(self, prompt_vid, vid):

        raise NotImplementedError
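A local-launch sketch, not part of the commit: it assumes the two checkpoints hard-coded in `Runner.__init__` (`models/sam_vit_h_4b8939.pth` and `models/dinov2_vitl14_pretrain.pth`) were downloaded beforehand, and that the module paths mirror the imports used across this commit's files.

# Local-launch sketch (assumptions noted above).
import os
from gradio_demo.runner import Runner
from gradio_demo.app_running import create_demo

runner = Runner(os.getenv('HF_TOKEN'))  # the token is optional for local use
demo = create_demo(runner, pipe=None)
demo.launch()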