Xeraphinite committed
Commit ebb9c75 · verified · 1 Parent(s): 4edfb60

Upload 19 files

.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/dinosaur1.png filter=lfs diff=lfs merge=lfs -text
+ images/dinosaur2.png filter=lfs diff=lfs merge=lfs -text
+ images/dinosaur3.png filter=lfs diff=lfs merge=lfs -text
+ images/elephant1.png filter=lfs diff=lfs merge=lfs -text
+ images/elephant2.png filter=lfs diff=lfs merge=lfs -text
+ images/elephant3.png filter=lfs diff=lfs merge=lfs -text
+ images/hmbb3.png filter=lfs diff=lfs merge=lfs -text
+ images/horse1.png filter=lfs diff=lfs merge=lfs -text
+ images/horse2.png filter=lfs diff=lfs merge=lfs -text
+ images/horse3.png filter=lfs diff=lfs merge=lfs -text
Matcher.py ADDED
@@ -0,0 +1,452 @@
import os
from os import path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import numpy as np
import cv2
import ot
import math
from scipy.optimize import linear_sum_assignment

from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils
from dinov2.data.transforms import MaybeToTensor, make_normalize_transform

from matcher.k_means import kmeans_pp

import random


class Matcher:
    def __init__(
        self,
        encoder,
        generator=None,
        input_size=518,
        num_centers=8,
        use_box=False,
        use_points_or_centers=True,
        sample_range=(4, 6),
        max_sample_iterations=30,
        alpha=1.,
        beta=0.,
        exp=0.,
        score_filter_cfg=None,
        num_merging_mask=10,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ):
        # models
        self.encoder = encoder
        self.generator = generator
        self.rps = None

        if not isinstance(input_size, tuple):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # transforms for image encoder
        self.encoder_transform = transforms.Compose([
            MaybeToTensor(),
            make_normalize_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None

        self.num_centers = num_centers
        self.use_box = use_box
        self.use_points_or_centers = use_points_or_centers
        self.sample_range = sample_range
        self.max_sample_iterations = max_sample_iterations

        self.alpha, self.beta, self.exp = alpha, beta, exp
        assert score_filter_cfg is not None
        self.score_filter_cfg = score_filter_cfg
        self.num_merging_mask = num_merging_mask

        self.device = device

    def set_reference(self, imgs, masks):

        def reference_masks_verification(masks):
            # fall back to a small central square if the reference mask is empty
            if masks.sum() == 0:
                _, _, sh, sw = masks.shape
                masks[..., (sh // 2 - 7):(sh // 2 + 7), (sw // 2 - 7):(sw // 2 + 7)] = 1
            return masks

        imgs = imgs.flatten(0, 1)  # bs, 3, h, w
        img_size = imgs.shape[-1]
        assert img_size == self.input_size[-1]
        feat_size = img_size // self.encoder.patch_size

        self.encoder_img_size = img_size
        self.encoder_feat_size = feat_size

        # process reference masks: pool them down to the encoder's patch grid
        masks = reference_masks_verification(masks)
        masks = masks.permute(1, 0, 2, 3)  # ns, 1, h, w
        ref_masks_pool = F.avg_pool2d(masks.float(), (self.encoder.patch_size, self.encoder.patch_size))
        nshot = ref_masks_pool.shape[0]
        ref_masks_pool = (ref_masks_pool > self.generator.predictor.model.mask_threshold).float()
        ref_masks_pool = ref_masks_pool.reshape(-1)  # nshot * N

        self.ref_imgs = imgs
        self.ref_masks_pool = ref_masks_pool
        self.nshot = nshot

    def set_target(self, img, tar_img_ori_size):

        img_h, img_w = img.shape[-2:]
        assert img_h == self.input_size[0] and img_w == self.input_size[1]

        # convert the query image to a uint8 numpy array as input for SAM
        img_np = img.mul(255).byte()
        img_np = img_np.squeeze(0).permute(1, 2, 0).cpu().numpy()

        self.tar_img = img
        self.tar_img_np = img_np
        self.tar_img_ori_size = tar_img_ori_size

    def set_rps(self):
        if self.rps is None:
            assert self.encoder_feat_size is not None
            self.rps = RobustPromptSampler(
                encoder_feat_size=self.encoder_feat_size,
                sample_range=self.sample_range,
                max_iterations=self.max_sample_iterations,
            )

    def predict(self):
        ref_feats, tar_feat = self.extract_img_feats()
        all_points, box, S, C, reduced_points_num = self.patch_level_matching(ref_feats=ref_feats, tar_feat=tar_feat)
        points = self.clustering(all_points) if not self.use_points_or_centers else all_points
        self.set_rps()
        mask, mask_list = self.mask_generation(self.tar_img_np, points, box, all_points, self.ref_masks_pool, C)
        return mask, mask_list

    def extract_img_feats(self):

        ref_imgs = torch.cat([self.encoder_transform(rimg)[None, ...] for rimg in self.ref_imgs], dim=0)
        tar_img = torch.cat([self.encoder_transform(timg)[None, ...] for timg in self.tar_img], dim=0)

        ref_feats = self.encoder.forward_features(ref_imgs.to(self.device))["x_prenorm"][:, 1:]
        tar_feat = self.encoder.forward_features(tar_img.to(self.device))["x_prenorm"][:, 1:]
        # ns, N, c = ref_feats.shape
        ref_feats = ref_feats.reshape(-1, self.encoder.embed_dim)  # ns*N, c
        tar_feat = tar_feat.reshape(-1, self.encoder.embed_dim)  # N, c

        ref_feats = F.normalize(ref_feats, dim=1, p=2)  # normalize for cosine similarity
        tar_feat = F.normalize(tar_feat, dim=1, p=2)

        return ref_feats, tar_feat

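Since both feature sets are L2-normalized here, the plain matrix product in patch_level_matching below is exactly a cosine-similarity matrix. A minimal, self-contained sketch of that identity, using random toy tensors rather than the DINOv2 features:

    import torch
    import torch.nn.functional as F

    # toy stand-ins for patch features: 5 reference patches, 7 target patches, dim 16
    ref = torch.randn(5, 16)
    tar = torch.randn(7, 16)

    # normalize rows to unit length; a plain matmul then yields cosine similarities
    S = F.normalize(ref, dim=1, p=2) @ F.normalize(tar, dim=1, p=2).t()

    # cross-check one entry against the explicit cosine definition
    expected = torch.dot(ref[0], tar[0]) / (ref[0].norm() * tar[0].norm())
    assert torch.allclose(S[0, 0], expected, atol=1e-6)
    print(S.shape)  # torch.Size([5, 7]); every value lies in [-1, 1]
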
    def patch_level_matching(self, ref_feats, tar_feat):

        # forward matching
        S = ref_feats @ tar_feat.t()  # ns*N, N
        C = (1 - S) / 2  # cosine distance

        S_forward = S[self.ref_masks_pool.flatten().bool()]

        indices_forward = linear_sum_assignment(S_forward.cpu(), maximize=True)
        indices_forward = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_forward]
        sim_scores_f = S_forward[indices_forward[0], indices_forward[1]]
        indices_mask = self.ref_masks_pool.flatten().nonzero()[:, 0]

        # reverse matching: keep only forward matches confirmed by the reverse assignment
        S_reverse = S.t()[indices_forward[1]]
        indices_reverse = linear_sum_assignment(S_reverse.cpu(), maximize=True)
        indices_reverse = [torch.as_tensor(index, dtype=torch.int64, device=self.device) for index in indices_reverse]
        retain_ind = torch.isin(indices_reverse[1], indices_mask)
        if retain_ind.any():
            indices_forward = [indices_forward[0][retain_ind], indices_forward[1][retain_ind]]
            sim_scores_f = sim_scores_f[retain_ind]
        inds_matched, sim_matched = indices_forward, sim_scores_f

        # when there are many matches, keep only the top half by similarity
        reduced_points_num = len(sim_matched) // 2 if len(sim_matched) > 40 else len(sim_matched)
        sim_sorted, sim_idx_sorted = torch.sort(sim_matched, descending=True)
        sim_filter = sim_idx_sorted[:reduced_points_num]
        points_matched_inds = indices_forward[1][sim_filter]

        # map matched patch indices back to pixel coordinates (patch centers)
        points_matched_inds_set = torch.tensor(list(set(points_matched_inds.cpu().tolist())))
        points_matched_inds_set_w = points_matched_inds_set % self.encoder_feat_size
        points_matched_inds_set_h = points_matched_inds_set // self.encoder_feat_size
        idxs_mask_set_x = (points_matched_inds_set_w * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()
        idxs_mask_set_y = (points_matched_inds_set_h * self.encoder.patch_size + self.encoder.patch_size // 2).tolist()

        points_matched = []
        for x, y in zip(idxs_mask_set_x, idxs_mask_set_y):
            if int(x) < self.input_size[1] and int(y) < self.input_size[0]:
                points_matched.append([int(x), int(y)])
        points = np.array(points_matched)

        if self.use_box:
            box = np.array([
                max(points[:, 0].min(), 0),
                max(points[:, 1].min(), 0),
                min(points[:, 0].max(), self.input_size[1] - 1),
                min(points[:, 1].max(), self.input_size[0] - 1),
            ])
        else:
            box = None

        return points, box, S, C, reduced_points_num

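The forward-reverse matching above is a mutual-consistency check built on the Hungarian algorithm: a forward match survives only if the reverse assignment sends its target patch back to a matched reference patch. A minimal sketch of the idea on a toy similarity matrix (toy data, not the real DINOv2 similarities):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # toy similarity between 3 reference patches (rows) and 4 target patches (columns)
    S = np.array([
        [0.9, 0.1, 0.2, 0.1],
        [0.2, 0.8, 0.1, 0.3],
        [0.1, 0.2, 0.1, 0.7],
    ])

    # forward: one-to-one assignment of reference patches to their best targets
    ref_idx, tar_idx = linear_sum_assignment(S, maximize=True)

    # reverse: match the selected target columns back against the reference patches
    rev_ref, rev_tar = linear_sum_assignment(S.T[tar_idx], maximize=True)

    # keep forward matches whose targets map back to a matched reference (mutual match)
    mutual = np.isin(rev_tar, ref_idx)
    print([(int(r), int(t)) for r, t in zip(ref_idx[mutual], tar_idx[mutual])])
    # [(0, 0), (1, 1), (2, 3)]
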
    def clustering(self, points):

        num_centers = min(self.num_centers, len(points))
        flag = True
        while flag:
            centers, cluster_assignment = kmeans_pp(points, num_centers)
            id, fre = torch.unique(cluster_assignment, return_counts=True)
            # accept the clustering only if no cluster came back empty
            if id.shape[0] == num_centers:
                flag = False
            else:
                print('Kmeans++ failed, re-run')
        centers = np.array(centers).astype(np.int64)
        return centers

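The retry loop above accepts a clustering only when every requested center received at least one point; torch.unique on the assignment vector exposes empty clusters directly. A tiny sketch of that check in isolation (toy assignment tensor, not the kmeans_pp output):

    import torch

    num_centers = 4
    # toy cluster assignment for 6 points; cluster 3 never appears
    cluster_assignment = torch.tensor([0, 0, 1, 2, 2, 2])

    ids, counts = torch.unique(cluster_assignment, return_counts=True)
    print(ids.shape[0] == num_centers)  # False -> the loop above would re-run kmeans++
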
    def mask_generation(self, tar_img_np, points, box, all_points, ref_masks_pool, C):
        samples_list, label_list = self.rps.sample_points(points)
        tar_masks_ori = self.generator.generate(
            tar_img_np,
            select_point_coords=samples_list,
            select_point_labels=label_list,
            select_box=[box] if self.use_box else None,
        )
        tar_masks = torch.cat(
            [torch.from_numpy(qmask['segmentation']).float()[None, None, ...].to(self.device) for
             qmask in tar_masks_ori], dim=0).cpu().numpy() > 0

        # append to original results
        purity = torch.zeros(tar_masks.shape[0])
        coverage = torch.zeros(tar_masks.shape[0])
        emd = torch.zeros(tar_masks.shape[0])

        samples = samples_list[-1]
        labels = torch.ones(tar_masks.shape[0], samples.shape[1])
        samples = torch.ones(tar_masks.shape[0], samples.shape[1], 2)

        # compute scores for each mask
        for i in range(len(tar_masks)):
            purity_, coverage_, emd_, sample_, label_, mask_ = \
                self.rps.get_mask_scores(
                    points=points,
                    masks=tar_masks[i],
                    all_points=all_points,
                    emd_cost=C,
                    ref_masks_pool=ref_masks_pool,
                )
            assert np.all(mask_ == tar_masks[i])
            purity[i] = purity_
            coverage[i] = coverage_
            emd[i] = emd_

        pred_masks = tar_masks.squeeze(1)
        metric_preds = {
            "purity": purity,
            "coverage": coverage,
            "emd": emd,
        }

        scores = self.alpha * emd + self.beta * purity * coverage ** self.exp

        def check_pred_mask(pred_masks):
            if len(pred_masks.shape) < 3:  # avoid collapsing to a single 2-D mask
                pred_masks = pred_masks[None, ...]
            return pred_masks

        pred_masks = check_pred_mask(pred_masks)

        # filter false-positive mask fragments using the proposed metrics
        for metric in ["coverage", "emd", "purity"]:
            if self.score_filter_cfg[metric] > 0:
                thres = min(self.score_filter_cfg[metric], metric_preds[metric].max())
                idx = torch.where(metric_preds[metric] >= thres)[0]
                scores = scores[idx]
                samples = samples[idx]
                labels = labels[idx]
                pred_masks = check_pred_mask(pred_masks[idx])

                for key in metric_preds.keys():
                    metric_preds[key] = metric_preds[key][idx]

        # score-based mask selection and merging
        if self.score_filter_cfg["score_filter"]:

            distances = 1 - scores
            distances, rank = torch.sort(distances, descending=False)
            distances_norm = distances - distances.min()
            distances_norm = distances_norm / (distances.max() + 1e-6)
            filter_dis = distances < self.score_filter_cfg["score"]
            filter_dis[..., 0] = True
            filter_dis_norm = distances_norm < self.score_filter_cfg["score_norm"]
            filter_dis = filter_dis * filter_dis_norm

            pred_masks = check_pred_mask(pred_masks)
            masks = pred_masks[rank[filter_dis][:self.num_merging_mask]]
            masks = check_pred_mask(masks)
            mask_list = masks
            masks = masks.sum(0) > 0
            masks = masks[None, ...]

        else:

            topk = min(self.num_merging_mask, scores.size(0))
            topk_idx = scores.topk(topk)[1]
            topk_samples = samples[topk_idx].cpu().numpy()
            topk_scores = scores[topk_idx].cpu().numpy()
            topk_pred_masks = pred_masks[topk_idx]
            topk_pred_masks = check_pred_mask(topk_pred_masks)

            if self.score_filter_cfg["topk_scores_threshold"] > 0:
                # map scores to [0, 1]
                topk_scores = topk_scores / (topk_scores.max())

            idx = topk_scores > self.score_filter_cfg["topk_scores_threshold"]
            topk_samples = topk_samples[idx]

            topk_pred_masks = check_pred_mask(topk_pred_masks)
            topk_pred_masks = topk_pred_masks[idx]
            mask_list = []
            for i in range(len(topk_samples)):
                mask = topk_pred_masks[i][None, ...]
                mask_list.append(mask)
            mask_list = np.concatenate(mask_list, axis=0)
            masks = np.sum(mask_list, axis=0) > 0
            masks = check_pred_mask(masks)

        # upsample the merged mask back to the original target image size
        tar_img_ori_size = self.tar_img_ori_size
        mask = torch.tensor(masks, device=self.device)[None, ...]
        mask = F.interpolate(mask.float(), tar_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()
        if mask_list is not None:
            mask_list = torch.tensor(mask_list, device=self.device)[:, None, ...]
            mask_list = F.interpolate(mask_list.float(), tar_img_ori_size, mode="bilinear", align_corners=False)
            mask_list = mask_list.squeeze(0).cpu().numpy()

        return mask, mask_list

    def clear(self):

        self.tar_img = None
        self.tar_img_np = None
        self.tar_img_ori_size = None

        self.ref_imgs = None
        self.ref_masks_pool = None
        self.nshot = None

        self.encoder_img_size = None
        self.encoder_feat_size = None


class RobustPromptSampler:

    def __init__(
        self,
        encoder_feat_size,
        sample_range,
        max_iterations,
    ):
        self.encoder_feat_size = encoder_feat_size
        self.sample_range = sample_range
        self.max_iterations = max_iterations

    def get_mask_scores(self, points, masks, all_points, emd_cost, ref_masks_pool):

        def is_in_mask(point, mask):
            # input: point: n*2, mask: h*w; output: n boolean values
            h, w = mask.shape
            point = point.astype(np.int64)  # np.int was removed from NumPy; use a concrete dtype
            point = point[:, ::-1]  # y, x
            point = np.clip(point, 0, [h - 1, w - 1])
            return mask[point[:, 0], point[:, 1]]

        ori_masks = masks
        masks = cv2.resize(
            masks[0].astype(np.float32),
            (self.encoder_feat_size, self.encoder_feat_size),
            interpolation=cv2.INTER_AREA)
        if masks.max() <= 0:
            thres = masks.max() - 1e-6
        else:
            thres = 0
        masks = masks > thres

        # 1. emd
        emd_cost_pool = emd_cost[ref_masks_pool.flatten().bool(), :][:, masks.flatten()]
        emd = ot.emd2(a=[1. / emd_cost_pool.shape[0] for i in range(emd_cost_pool.shape[0])],
                      b=[1. / emd_cost_pool.shape[1] for i in range(emd_cost_pool.shape[1])],
                      M=emd_cost_pool.cpu().numpy())
        emd_score = 1 - emd

        labels = np.ones((points.shape[0],))

        # 2. purity and coverage
        assert all_points is not None
        points_in_mask = is_in_mask(all_points, ori_masks[0])
        points_in_mask = all_points[points_in_mask]
        # two metrics for local matching: purity and coverage
        # purity: points_in / mask_area; higher means the matched points are denser inside the mask
        # coverage: points_in / all_points; higher means the mask covers the matched points more completely
        mask_area = max(float(masks.sum()), 1.0)
        purity = points_in_mask.shape[0] / mask_area
        coverage = points_in_mask.shape[0] / all_points.shape[0]
        purity = torch.tensor([purity]) + 1e-6
        coverage = torch.tensor([coverage]) + 1e-6
        return purity, coverage, emd_score, points, labels, ori_masks

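As the comments above describe, purity is the density of matched points inside a candidate mask and coverage is the fraction of all matched points the mask captures, while the EMD term compares uniform distributions over reference and candidate patches under the cosine-distance cost. A self-contained toy sketch of all three (random toy cost matrix, not the real patch-level cost):

    import numpy as np
    import ot

    # toy counts: 10 matched points, a candidate mask holding 6 of them, mask area 8
    points_in, mask_area, all_points = 6, 8.0, 10
    purity = points_in / mask_area      # 0.75: matched points are dense inside the mask
    coverage = points_in / all_points   # 0.6: the mask captures most matched points

    # EMD between uniform weights over 4 reference and 5 candidate patches,
    # under a toy cost matrix in [0, 1] (the code above uses (1 - cos sim) / 2)
    rng = np.random.default_rng(0)
    M = rng.uniform(0.0, 1.0, size=(4, 5))
    emd = ot.emd2(a=np.full(4, 1 / 4), b=np.full(5, 1 / 5), M=M)
    emd_score = 1 - emd                 # higher is better, as in get_mask_scores
    print(purity, coverage, round(float(emd_score), 3))
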
    def combinations(self, n, k):
        # recursively enumerate all k-element index subsets of range(n)
        if k > n:
            return []
        if k == 0:
            return [[]]
        if k == n:
            return [[i for i in range(n)]]
        res = []
        for i in range(n):
            for j in self.combinations(i, k - 1):
                res.append(j + [i])
        return res

    def sample_points(self, points):
        # returns lists of arrays: one batch of point samples (and labels) per sample size

        sample_list = []
        label_list = []
        for i in range(min(self.sample_range[0], len(points)), min(self.sample_range[1], len(points)) + 1):
            if len(points) > 8:
                # too many points to enumerate: draw random subsets instead
                index = [random.sample(range(len(points)), i) for j in range(self.max_iterations)]
                sample = np.take(points, index, axis=0)  # max_iterations * i * 2
            else:
                index = self.combinations(len(points), i)
                sample = np.take(points, index, axis=0)  # n_combinations * i * 2
            # generate labels: one positive label per sampled point
            label = np.ones((sample.shape[0], i))
            sample_list.append(sample)
            label_list.append(label)
        return sample_list, label_list
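
Taken together, a caller drives Matcher in four steps, exactly as oss_ops_inference.py below does: set the reference pair, set the target, predict, then clear the per-pair state. A minimal driver sketch, assuming `encoder` (DINOv2), `generator` (the SAM automatic mask generator), `score_filter_cfg`, and the preprocessed tensors already exist:

    # hypothetical names: encoder, generator, score_filter_cfg, support_imgs,
    # support_masks, query_img, query_img_ori_size are assumed to be prepared
    matcher = Matcher(encoder=encoder, generator=generator, score_filter_cfg=score_filter_cfg)

    matcher.set_reference(support_imgs, support_masks)  # (1, 1, 3, H, W), (1, 1, H, W)
    matcher.set_target(query_img, query_img_ori_size)   # (1, 3, H, W), (H_ori, W_ori)
    pred_mask, pred_mask_list = matcher.predict()       # merged mask + per-proposal masks
    matcher.clear()                                     # reset state before the next query
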
app_running.py ADDED
@@ -0,0 +1,173 @@
import os
import numpy as np
import random
import time
import gradio as gr
from gradio_demo.runner import Runner
import matplotlib.pyplot as plt


def show_mask(mask, ax, color='blue'):
    if color == 'blue':
        # reference, blue
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        # target, green
        color = np.array([78 / 255, 238 / 255, 148 / 255, 0.6])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_img_point_box_mask(img, input_point=None, input_label=None, box=None, masks=None, save_path=None, mode='mask', color='blue'):

    if mode == 'point':
        # point
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_points(input_point, input_label, plt.gca())
        plt.axis('on')
        plt.savefig(save_path, bbox_inches='tight')
    elif mode == 'box':
        # box
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_box(box, plt.gca())
        plt.axis('on')
        plt.savefig(save_path, bbox_inches='tight')
    else:
        # mask
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        show_mask(masks, plt.gca(), color=color)
        plt.axis('off')
        plt.savefig(save_path, bbox_inches='tight')
    plt.close()


def create_oss_demo(
    runner: Runner,
    pipe: None = None
) -> gr.Blocks:

    examples = [
        ['./gradio_demo/images/horse1.png', './gradio_demo/images/horse2.png', './gradio_demo/images/horse3.png'],
        ['./gradio_demo/images/hmbb1.png', './gradio_demo/images/hmbb2.png', './gradio_demo/images/hmbb3.png'],
        ['./gradio_demo/images/earth1.png', './gradio_demo/images/earth2.png', './gradio_demo/images/earth3.png'],
        ['./gradio_demo/images/elephant1.png', './gradio_demo/images/elephant2.png', './gradio_demo/images/elephant3.png'],
        ['./gradio_demo/images/dinosaur1.png', './gradio_demo/images/dinosaur2.png', './gradio_demo/images/dinosaur3.png'],
    ]

    with gr.Blocks() as oss_demo:
        with gr.Column():

            # inputs
            with gr.Row():
                img_input_prompt = gr.ImageMask(label='Prompt')
                img_input_target1 = gr.Image(label='Target 1')
                img_input_target2 = gr.Image(label='Target 2')

            version = gr.inputs.Radio(['version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
                                       'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)',
                                       'version 3 (🔻 multiple instances 🔻 whole, 🔺 part)'],
                                      type="value", default='version 1 (🔺 multiple instances 🔻 whole, 🔻 part)',
                                      label='Multiple Instances (version 1), Single Instance (version 2), Part of an Object (version 3)')

            with gr.Row():
                submit1 = gr.Button("Submit")
                clear = gr.Button("Clear")
                info = gr.Text(label="Processing result: ", interactive=False)

            # decision
            K = gr.Slider(0, 10, 10, step=1, label="Controllable mask output", interactive=True)
            submit2 = gr.Button("Submit")

            # outputs
            with gr.Row():
                img_output_pmt = gr.Image(label='Prompt')
                img_output_tar1 = gr.Image(label='Output 1')
                img_output_tar2 = gr.Image(label='Output 2')

            # images
            gr.Examples(
                examples=examples,
                fn=runner.inference_oss_ops,
                inputs=[img_input_prompt, img_input_target1, img_input_target2],
                outputs=info
            )

            submit1.click(
                fn=runner.inference_oss_ops,
                inputs=[img_input_prompt, img_input_target1, img_input_target2, version],
                outputs=info
            )
            submit2.click(
                fn=runner.controllable_mask_output,
                inputs=K,
                outputs=[img_output_pmt, img_output_tar1, img_output_tar2]
            )

            clear.click(
                fn=runner.clear_fn,
                inputs=None,
                outputs=[img_input_prompt, img_input_target1, img_input_target2, info, img_output_pmt, img_output_tar1, img_output_tar2],
                queue=False
            )

    return oss_demo


def create_vos_demo(
    runner: Runner,
    pipe: None = None
) -> gr.Interface:

    raise NotImplementedError


def create_demo(
    runner: Runner,
    pipe: None = None
) -> gr.TabbedInterface:

    title = "Matcher 🎯: Segment Anything with One Shot Using All-Purpose Feature Matching<br> \
        <div align='center'> \
        <h2><a href='https://arxiv.org/abs/2305.13310' target='_blank' rel='noopener'>[paper]</a> \
        <a href='https://github.com/aim-uofa/Matcher' target='_blank' rel='noopener'>[code]</a></h2> \
        <h2>Matcher can segment anything with one shot by integrating an all-purpose feature extraction model and a class-agnostic segmentation model.</h2> \
        <br> \
        </div> \
        "

    oss_demo = create_oss_demo(runner=runner, pipe=pipe)
    # vos_demo = create_vos_demo(runner=runner, pipe=pipe)
    demo = gr.TabbedInterface(
        [oss_demo, ],
        ['OSS+OPS', ], title=title)
    return demo


if __name__ == '__main__':
    pipe = None
    HF_TOKEN = os.getenv('HF_TOKEN')
    runner = Runner(HF_TOKEN)
    demo = create_demo(runner, pipe)
    demo.launch(enable_queue=False)
images/dinosaur1.png ADDED

Git LFS Details

  • SHA256: fae170e17e9064e7c91b2a20693a40a96564ae49eacf1c5170a2e4fed17a75fd
  • Pointer size: 131 Bytes
  • Size of remote file: 379 kB
images/dinosaur2.png ADDED

Git LFS Details

  • SHA256: 4884a4962da74ba8fe012f9a607ad0bd67b5dc395e369e1775152b60190b0ff0
  • Pointer size: 131 Bytes
  • Size of remote file: 224 kB
images/dinosaur3.png ADDED

Git LFS Details

  • SHA256: 5538c2af5fa121e233d76bce57bf8f04c0e0caa31442c92d6c341d005c4cd926
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB
images/earth1.png ADDED
images/earth2.png ADDED
images/earth3.png ADDED
images/elephant1.png ADDED

Git LFS Details

  • SHA256: 3571da1c7c7d2ed747ec8ff5e321d0ce0a8541da385e04292cadf180867238ed
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
images/elephant2.png ADDED

Git LFS Details

  • SHA256: 490f36a2331358b07acd5fc527eb39a4c9511503e6dc0507da766cdfa2d40167
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
images/elephant3.png ADDED

Git LFS Details

  • SHA256: 6884671ef41321f20341a11f4f129ea01e74ed70c7a49639cc283a1528c14ca1
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
images/hmbb1.png ADDED
images/hmbb2.png ADDED
images/hmbb3.png ADDED

Git LFS Details

  • SHA256: c99894c8f76116940e60b16751fd3954754e6f2e166e46c1db5e00e9d7f0c190
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
images/horse1.png ADDED

Git LFS Details

  • SHA256: 3452a98123fc18f1d0cf4336d13ed89c264ad093bdc126e242583d3a0c8b581d
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
images/horse2.png ADDED

Git LFS Details

  • SHA256: 36e8fb7b8f92f1215706c58691dcd53621965062b9a1f71cb6db5b0454e2660a
  • Pointer size: 131 Bytes
  • Size of remote file: 274 kB
images/horse3.png ADDED

Git LFS Details

  • SHA256: b942fe3cf55d4ae103173c852700851597f175abfd43a763a8c5873d61a24f91
  • Pointer size: 131 Bytes
  • Size of remote file: 419 kB
oss_ops_inference.py ADDED
@@ -0,0 +1,288 @@
r""" Matcher one-shot segmentation (OSS & OPS) inference code """

import argparse
import sys
import os
from os.path import join

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from segment_anything import SamPredictor, SamAutomaticMaskGenerator

from gradio_demo.Matcher import Matcher
from matcher.common import utils

import random
random.seed(0)


def default_argument_parser():

    # Argument parsing
    parser = argparse.ArgumentParser(description='Matcher Pytorch Implementation for One-shot Segmentation')

    # Dataset parameters
    parser.add_argument('--datapath', type=str, default='datasets')
    parser.add_argument('--benchmark', type=str, default='coco',
                        choices=['fss', 'coco', 'lvis', 'paco_part', 'pascal_part'])
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--nworker', type=int, default=0)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--nshot', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=518)
    parser.add_argument('--use_original_imgsize', action='store_true')
    parser.add_argument('--log-root', type=str, default='output/coco/fold0')
    parser.add_argument('--visualize', type=int, default=0)

    # DINOv2 and SAM parameters
    parser.add_argument('--dinov2-weights', type=str, default="models/dinov2_vitl14_pretrain.pth")
    parser.add_argument('--sam-weights', type=str, default="models/sam_vit_h_4b8939.pth")
    parser.add_argument('--points_per_side', type=int, default=64)
    parser.add_argument('--pred_iou_thresh', type=float, default=0.88)
    parser.add_argument('--sel_stability_score_thresh', type=float, default=0.0)
    parser.add_argument('--stability_score_thresh', type=float, default=0.95)
    parser.add_argument('--iou_filter', type=float, default=0.0)
    parser.add_argument('--box_nms_thresh', type=float, default=1.0)
    parser.add_argument('--output_layer', type=int, default=3)
    parser.add_argument('--dense_multimask_output', type=int, default=0)
    parser.add_argument('--use_dense_mask', type=int, default=0)
    parser.add_argument('--multimask_output', type=int, default=0)

    # Matcher parameters
    parser.add_argument('--num_centers', type=int, default=8, help='K centers for kmeans')
    parser.add_argument('--use_box', action='store_true', help='use box as an extra prompt for sam')
    parser.add_argument('--use_points_or_centers', action='store_true', help='points: T, centers: F')
    parser.add_argument('--sample-range', type=tuple, default=(4, 6), help='sample points number range')
    parser.add_argument('--max_sample_iterations', type=int, default=30)
    parser.add_argument('--alpha', type=float, default=1.)
    parser.add_argument('--beta', type=float, default=0.)
    parser.add_argument('--exp', type=float, default=0.)
    parser.add_argument('--emd_filter', type=float, default=0.0, help='use emd_filter')
    parser.add_argument('--purity_filter', type=float, default=0.0, help='use purity_filter')
    parser.add_argument('--coverage_filter', type=float, default=0.0, help='use coverage_filter')
    parser.add_argument('--use_score_filter', action='store_true')
    parser.add_argument('--deep_score_norm_filter', type=float, default=0.1)
    parser.add_argument('--deep_score_filter', type=float, default=0.33)
    parser.add_argument('--topk_scores_threshold', type=float, default=0.7)
    parser.add_argument('--num_merging_mask', type=int, default=10, help='topk masks for merging')

    args = parser.parse_args()
    return args


def definite_argument_parser(args, version=1):
    # override the defaults with the preset for the selected demo version

    if version == 1:
        args.max_sample_iterations = 64
        args.box_nms_thresh = 0.65
        args.sample_range = (1, 6)
        args.topk_scores_threshold = 0.0
        args.use_dense_mask = 1
        args.use_points_or_centers = True
        args.purity_filter = 0.02
        args.iou_filter = 0.85
        args.multimask_output = 1
        args.sel_stability_score_thresh = 0.90
        args.use_score_filter = True
        args.alpha = 1.0
        args.beta = 0.
        args.exp = 0.
        args.num_merging_mask = 9
    elif version == 2:
        args.max_sample_iterations = 30
        args.sample_range = (4, 6)
        args.multimask_output = 0
        args.alpha = 0.8
        args.beta = 0.2
        args.exp = 1.
        args.num_merging_mask = 10
    elif version == 3:
        args.max_sample_iterations = 128
        args.sample_range = (3, 6)
        args.use_box = True
        args.use_points_or_centers = True
        args.coverage_filter = 0.3
        args.alpha = 0.5
        args.beta = 0.5
        args.exp = 0.
        args.num_merging_mask = 5

    return args


def preprocess_data(kwargs, args=None):

    img_size = args.img_size
    transform = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor()
    ])

    support_img = Image.fromarray(kwargs.get("support_img"))
    query_img_1 = Image.fromarray(kwargs.get("query_img_1"))
    query_img_2 = Image.fromarray(kwargs.get("query_img_2"))

    support_img_ori_size = (support_img.size[1], support_img.size[0])  # H, W
    query_img_1_ori_size = (query_img_1.size[1], query_img_1.size[0])
    query_img_2_ori_size = (query_img_2.size[1], query_img_2.size[0])

    support_img = transform(support_img)
    query_img_1 = transform(query_img_1)
    query_img_2 = transform(query_img_2)

    support_mask = torch.tensor(kwargs.get("support_mask"))
    support_mask = F.interpolate(support_mask.unsqueeze(0).float(), support_img.size()[-2:],
                                 mode='nearest') > 0
    query_imgs = torch.stack([query_img_1, query_img_2], dim=0)

    data = {
        "support_img": support_img[None, ...],
        "support_mask": support_mask,
        "query_imgs": query_imgs,
        "support_img_ori_size": support_img_ori_size,
        "query_imgs_ori_size": (query_img_1_ori_size, query_img_2_ori_size),
    }

    return data


def preprocess_support_mask(data, predictor, version=1):

    # version 3 (part-level) uses the drawn mask directly
    if version == 3:
        return data

    sup_mask = data['support_mask'].squeeze()
    H, W = sup_mask.shape[-2:]
    input_points = sup_mask.nonzero().numpy()[:1, ::-1]  # one foreground point, (x, y) order
    input_label = np.array([1] * len(input_points))

    support_img_np = data['support_img'].mul(255).byte()
    support_img_np = support_img_np.squeeze().permute(1, 2, 0).cpu().numpy()

    # forward the encoder to obtain the image feature
    predictor.reset_image()
    predictor.set_image(support_img_np)

    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=True
    )
    predictor.reset_image()

    # use the last of SAM's multi-mask outputs as the support mask
    data['support_mask'] = torch.tensor(mask[-1:])[None, ...]

    return data


def main_oss_ops(**kwargs):

    args = default_argument_parser()
    args = definite_argument_parser(args, kwargs.get("version"))

    # Model initialization
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device

    # create sam
    sam = kwargs.get("sam")
    predictor = SamPredictor(sam)
    generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=args.points_per_side,
        points_per_batch=64,
        pred_iou_thresh=args.pred_iou_thresh,
        stability_score_thresh=args.stability_score_thresh,
        stability_score_offset=1.0,
        sel_stability_score_thresh=args.sel_stability_score_thresh,
        sel_pred_iou_thresh=args.iou_filter,
        box_nms_thresh=args.box_nms_thresh,
        sel_output_layer=args.output_layer,
        output_layer=args.dense_multimask_output,
        dense_pred=args.use_dense_mask,
        multimask_output=args.dense_multimask_output > 0,
        sel_multimask_output=args.multimask_output > 0,
    )

    # create dinov2, large
    dinov2 = kwargs.get("dinov2")

    # create matcher
    score_filter_cfg = {
        "emd": args.emd_filter,
        "purity": args.purity_filter,
        "coverage": args.coverage_filter,
        "score_filter": args.use_score_filter,
        "score": args.deep_score_filter,
        "score_norm": args.deep_score_norm_filter,
        "topk_scores_threshold": args.topk_scores_threshold
    }

    matcher = Matcher(
        encoder=dinov2,
        generator=generator,
        num_centers=args.num_centers,
        use_box=args.use_box,
        use_points_or_centers=args.use_points_or_centers,
        sample_range=args.sample_range,
        max_sample_iterations=args.max_sample_iterations,
        alpha=args.alpha,
        beta=args.beta,
        exp=args.exp,
        score_filter_cfg=score_filter_cfg,
        num_merging_mask=args.num_merging_mask,
        device=args.device
    )

    # process data
    data = preprocess_data(kwargs, args=args)
    data = preprocess_support_mask(data, predictor, version=kwargs.get("version"))

    # inference
    with torch.no_grad():
        utils.fix_randseed(0)
        pred_masks, pred_mask_lists = [], []

        # support mask, upsampled back to the support image's original size
        support_img_ori_size = data['support_img_ori_size']
        mask = data['support_mask'].to(predictor.model.device).float()
        mask = F.interpolate(mask, support_img_ori_size, mode="bilinear", align_corners=False) > 0
        mask = mask.squeeze(0).cpu().numpy()
        pred_masks.append(mask)
        pred_mask_lists.append(None)

        for query_img, query_img_ori_size in zip(data['query_imgs'], data['query_imgs_ori_size']):
            data['query_img'], data['query_img_ori_size'] = query_img[None, ...], query_img_ori_size

            support_imgs, support_masks = data["support_img"].to(matcher.device)[None, ...], data["support_mask"].to(matcher.device)  # (1, 1, 3, H, W), (1, 1, H, W)
            query_img, query_img_ori_size = data['query_img'].to(matcher.device), data['query_img_ori_size']  # (1, 3, H, W), img_size

            # 1. Matcher prepares references and target
            matcher.set_reference(support_imgs, support_masks)
            matcher.set_target(query_img, query_img_ori_size)

            # 2. Predict mask of target
            pred_mask, pred_mask_list = matcher.predict()
            matcher.clear()

            pred_masks.append(pred_mask)
            pred_mask_lists.append(pred_mask_list)

    return pred_masks, pred_mask_lists
runner.py ADDED
@@ -0,0 +1,156 @@
from __future__ import annotations

import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys

import gradio as gr
import slugify
import torch
import numpy as np
import huggingface_hub
from huggingface_hub import HfApi
from omegaconf import OmegaConf

from segment_anything import sam_model_registry
from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils

from gradio_demo.oss_ops_inference import main_oss_ops


ORIGINAL_SPACE_ID = ''
SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)


class Runner:
    def __init__(self, hf_token: str | None = None):
        self.hf_token = hf_token

        # self.checkpoint_dir = pathlib.Path('checkpoints')
        # self.checkpoint_dir.mkdir(exist_ok=True)

        # oss, ops state
        self.prompt_res_g = None
        self.prompt_mask_g = None
        self.tar1_res_g = None
        self.tar2_res_g = None
        self.version = 1

        self.pred_masks = None
        self.pred_mask_lists = None

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        sam_checkpoint = "models/sam_vit_h_4b8939.pth"
        model_type = "default"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)

        dinov2_kwargs = dict(
            img_size=518,
            patch_size=14,
            init_values=1e-5,
            ffn_layer='mlp',
            block_chunks=0,
            qkv_bias=True,
            proj_bias=True,
            ffn_bias=True,
        )
        dinov2 = vits.__dict__["vit_large"](**dinov2_kwargs)

        dinov2_utils.load_pretrained_weights(dinov2, "models/dinov2_vitl14_pretrain.pth", "teacher")
        dinov2.eval()
        dinov2.to(device=device)
        self.dinov2 = dinov2

    def inference_oss_ops(self, prompt, target1, target2, version):

        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g = prompt['image'], target1, target2
        self.prompt_mask_g = (prompt['mask'][..., 0] != 0)[None, ...]  # 1, H, W
        if version == 'version 1 (🔺 multiple instances 🔻 whole, 🔻 part)':
            self.version = 1
        elif version == 'version 2 (🔻 multiple instances 🔺 whole, 🔻 part)':
            self.version = 2
        else:
            self.version = 3

        self.pred_masks, self.pred_mask_lists = main_oss_ops(
            sam=self.sam,
            dinov2=self.dinov2,
            support_img=self.prompt_res_g,
            support_mask=self.prompt_mask_g,
            query_img_1=self.tar1_res_g,
            query_img_2=self.tar2_res_g,
            version=self.version
        )

        text = "Process Successful!"

        return text

    def clear_fn(self):

        self.prompt_res_g, self.tar1_res_g, self.tar2_res_g, self.prompt_mask_g = None, None, None, None
        self.version = 1
        self.pred_masks = None
        self.pred_mask_lists = None

        return [None] * 7

    def controllable_mask_output(self, k):

        color = np.array([30, 144, 255])

        if self.version != 1:

            prompt_mask_res, tar1_mask_res, tar2_mask_res = self.pred_masks

            h, w = prompt_mask_res.shape[-2:]
            prompt_mask_img = prompt_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            prompt_mask_res = self.prompt_res_g * 0.5 + prompt_mask_img * 0.5

            h, w = tar1_mask_res.shape[-2:]
            tar1_mask_img = tar1_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            tar1_mask_res = self.tar1_res_g * 0.5 + tar1_mask_img * 0.5

            h, w = tar2_mask_res.shape[-2:]
            tar2_mask_img = tar2_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            tar2_mask_res = self.tar2_res_g * 0.5 + tar2_mask_img * 0.5

        else:
            prompt_mask_res = self.pred_masks[0]
            tar1_mask_res, tar2_mask_res = self.pred_mask_lists[1:]

            # merge only the top-k proposal masks so the slider controls how many instances show
            tar1_mask_res = tar1_mask_res[:min(k, len(tar1_mask_res))].sum(0) > 0
            tar2_mask_res = tar2_mask_res[:min(k, len(tar2_mask_res))].sum(0) > 0

            h, w = prompt_mask_res.shape[-2:]
            prompt_mask_img = prompt_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            prompt_mask_res = self.prompt_res_g * 0.5 + prompt_mask_img * 0.5

            h, w = tar1_mask_res.shape[-2:]
            tar1_mask_img = tar1_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            tar1_mask_res = self.tar1_res_g * 0.5 + tar1_mask_img * 0.5

            h, w = tar2_mask_res.shape[-2:]
            tar2_mask_img = tar2_mask_res.reshape(h, w, 1) * color.reshape(1, 1, -1)
            tar2_mask_res = self.tar2_res_g * 0.5 + tar2_mask_img * 0.5

        return prompt_mask_res / 255, tar1_mask_res / 255, tar2_mask_res / 255

    def inference_vos(self, prompt_vid, vid):

        raise NotImplementedError