Spaces:
Runtime error
Runtime error
import math | |
import random | |
from functools import lru_cache | |
import cv2 | |
import numpy as np | |
from .sample import DSample | |
class BasePointSampler: | |
def __init__(self): | |
self._selected_mask = None | |
self._selected_masks = None | |
def sample_object(self, sample: DSample): | |
raise NotImplementedError | |
def sample_points(self): | |
raise NotImplementedError | |
def selected_mask(self): | |
assert self._selected_mask is not None | |
return self._selected_mask | |
def selected_mask(self, mask): | |
self._selected_mask = mask[np.newaxis, :].astype(np.float32) | |
class MultiPointSampler(BasePointSampler): | |
def __init__( | |
self, | |
max_num_points, | |
prob_gamma=0.7, | |
expand_ratio=0.1, | |
positive_erode_prob=0.9, | |
positive_erode_iters=3, | |
negative_bg_prob=0.1, | |
negative_other_prob=0.4, | |
negative_border_prob=0.5, | |
merge_objects_prob=0.0, | |
max_num_merged_objects=2, | |
use_hierarchy=False, | |
soft_targets=False, | |
first_click_center=False, | |
only_one_first_click=False, | |
sfc_inner_k=1.7, | |
sfc_full_inner_prob=0.0, | |
): | |
super().__init__() | |
self.max_num_points = max_num_points | |
self.expand_ratio = expand_ratio | |
self.positive_erode_prob = positive_erode_prob | |
self.positive_erode_iters = positive_erode_iters | |
self.merge_objects_prob = merge_objects_prob | |
self.use_hierarchy = use_hierarchy | |
self.soft_targets = soft_targets | |
self.first_click_center = first_click_center | |
self.only_one_first_click = only_one_first_click | |
self.sfc_inner_k = sfc_inner_k | |
self.sfc_full_inner_prob = sfc_full_inner_prob | |
if max_num_merged_objects == -1: | |
max_num_merged_objects = max_num_points | |
self.max_num_merged_objects = max_num_merged_objects | |
self.neg_strategies = ["bg", "other", "border"] | |
self.neg_strategies_prob = [ | |
negative_bg_prob, | |
negative_other_prob, | |
negative_border_prob, | |
] | |
assert math.isclose(sum(self.neg_strategies_prob), 1.0) | |
self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma) | |
self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma) | |
self._neg_masks = None | |
def sample_object(self, sample: DSample): | |
if len(sample) == 0: | |
bg_mask = sample.get_background_mask() | |
self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32) | |
self._selected_masks = [[]] | |
self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies} | |
self._neg_masks["required"] = [] | |
return | |
gt_mask, pos_masks, neg_masks = self._sample_mask(sample) | |
binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0 | |
self.selected_mask = gt_mask | |
self._selected_masks = pos_masks | |
neg_mask_bg = np.logical_not(binary_gt_mask) | |
neg_mask_border = self._get_border_mask(binary_gt_mask) | |
if len(sample) <= len(self._selected_masks): | |
neg_mask_other = neg_mask_bg | |
else: | |
neg_mask_other = np.logical_and( | |
np.logical_not(sample.get_background_mask()), | |
np.logical_not(binary_gt_mask), | |
) | |
self._neg_masks = { | |
"bg": neg_mask_bg, | |
"other": neg_mask_other, | |
"border": neg_mask_border, | |
"required": neg_masks, | |
} | |
def _sample_mask(self, sample: DSample): | |
root_obj_ids = sample.root_objects | |
if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob: | |
max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects) | |
num_selected_objects = np.random.randint(2, max_selected_objects + 1) | |
random_ids = random.sample(root_obj_ids, num_selected_objects) | |
else: | |
random_ids = [random.choice(root_obj_ids)] | |
gt_mask = None | |
pos_segments = [] | |
neg_segments = [] | |
for obj_id in random_ids: | |
( | |
obj_gt_mask, | |
obj_pos_segments, | |
obj_neg_segments, | |
) = self._sample_from_masks_layer(obj_id, sample) | |
if gt_mask is None: | |
gt_mask = obj_gt_mask | |
else: | |
gt_mask = np.maximum(gt_mask, obj_gt_mask) | |
pos_segments.extend(obj_pos_segments) | |
neg_segments.extend(obj_neg_segments) | |
pos_masks = [self._positive_erode(x) for x in pos_segments] | |
neg_masks = [self._positive_erode(x) for x in neg_segments] | |
return gt_mask, pos_masks, neg_masks | |
def _sample_from_masks_layer(self, obj_id, sample: DSample): | |
objs_tree = sample._objects | |
if not self.use_hierarchy: | |
node_mask = sample.get_object_mask(obj_id) | |
gt_mask = ( | |
sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask | |
) | |
return gt_mask, [node_mask], [] | |
def _select_node(node_id): | |
node_info = objs_tree[node_id] | |
if not node_info["children"] or random.random() < 0.5: | |
return node_id | |
return _select_node(random.choice(node_info["children"])) | |
selected_node = _select_node(obj_id) | |
node_info = objs_tree[selected_node] | |
node_mask = sample.get_object_mask(selected_node) | |
gt_mask = ( | |
sample.get_soft_object_mask(selected_node) | |
if self.soft_targets | |
else node_mask | |
) | |
pos_mask = node_mask.copy() | |
negative_segments = [] | |
if node_info["parent"] is not None and node_info["parent"] in objs_tree: | |
parent_mask = sample.get_object_mask(node_info["parent"]) | |
negative_segments.append( | |
np.logical_and(parent_mask, np.logical_not(node_mask)) | |
) | |
for child_id in node_info["children"]: | |
if objs_tree[child_id]["area"] / node_info["area"] < 0.10: | |
child_mask = sample.get_object_mask(child_id) | |
pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) | |
if node_info["children"]: | |
max_disabled_children = min(len(node_info["children"]), 3) | |
num_disabled_children = np.random.randint(0, max_disabled_children + 1) | |
disabled_children = random.sample( | |
node_info["children"], num_disabled_children | |
) | |
for child_id in disabled_children: | |
child_mask = sample.get_object_mask(child_id) | |
pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) | |
if self.soft_targets: | |
soft_child_mask = sample.get_soft_object_mask(child_id) | |
gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask) | |
else: | |
gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask)) | |
negative_segments.append(child_mask) | |
return gt_mask, [pos_mask], negative_segments | |
def sample_points(self): | |
assert self._selected_mask is not None | |
pos_points = self._multi_mask_sample_points( | |
self._selected_masks, | |
is_negative=[False] * len(self._selected_masks), | |
with_first_click=self.first_click_center, | |
) | |
neg_strategy = [ | |
(self._neg_masks[k], prob) | |
for k, prob in zip(self.neg_strategies, self.neg_strategies_prob) | |
] | |
neg_masks = self._neg_masks["required"] + [neg_strategy] | |
neg_points = self._multi_mask_sample_points( | |
neg_masks, is_negative=[False] * len(self._neg_masks["required"]) + [True] | |
) | |
return pos_points + neg_points | |
def _multi_mask_sample_points( | |
self, selected_masks, is_negative, with_first_click=False | |
): | |
selected_masks = selected_masks[: self.max_num_points] | |
each_obj_points = [ | |
self._sample_points( | |
mask, is_negative=is_negative[i], with_first_click=with_first_click | |
) | |
for i, mask in enumerate(selected_masks) | |
] | |
each_obj_points = [x for x in each_obj_points if len(x) > 0] | |
points = [] | |
if len(each_obj_points) == 1: | |
points = each_obj_points[0] | |
elif len(each_obj_points) > 1: | |
if self.only_one_first_click: | |
each_obj_points = each_obj_points[:1] | |
points = [obj_points[0] for obj_points in each_obj_points] | |
aggregated_masks_with_prob = [] | |
for indx, x in enumerate(selected_masks): | |
if ( | |
isinstance(x, (list, tuple)) | |
and x | |
and isinstance(x[0], (list, tuple)) | |
): | |
for t, prob in x: | |
aggregated_masks_with_prob.append( | |
(t, prob / len(selected_masks)) | |
) | |
else: | |
aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks))) | |
other_points_union = self._sample_points( | |
aggregated_masks_with_prob, is_negative=True | |
) | |
if len(other_points_union) + len(points) <= self.max_num_points: | |
points.extend(other_points_union) | |
else: | |
points.extend( | |
random.sample(other_points_union, self.max_num_points - len(points)) | |
) | |
if len(points) < self.max_num_points: | |
points.extend([(-1, -1, -1)] * (self.max_num_points - len(points))) | |
return points | |
def _sample_points(self, mask, is_negative=False, with_first_click=False): | |
if is_negative: | |
num_points = np.random.choice( | |
np.arange(self.max_num_points + 1), p=self._neg_probs | |
) | |
else: | |
num_points = 1 + np.random.choice( | |
np.arange(self.max_num_points), p=self._pos_probs | |
) | |
indices_probs = None | |
if isinstance(mask, (list, tuple)): | |
indices_probs = [x[1] for x in mask] | |
indices = [(np.argwhere(x), prob) for x, prob in mask] | |
if indices_probs: | |
assert math.isclose(sum(indices_probs), 1.0) | |
else: | |
indices = np.argwhere(mask) | |
points = [] | |
for j in range(num_points): | |
first_click = with_first_click and j == 0 and indices_probs is None | |
if first_click: | |
point_indices = get_point_candidates( | |
mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob | |
) | |
elif indices_probs: | |
point_indices_indx = np.random.choice( | |
np.arange(len(indices)), p=indices_probs | |
) | |
point_indices = indices[point_indices_indx][0] | |
else: | |
point_indices = indices | |
num_indices = len(point_indices) | |
if num_indices > 0: | |
point_indx = 0 if first_click else 100 | |
click = point_indices[np.random.randint(0, num_indices)].tolist() + [ | |
point_indx | |
] | |
points.append(click) | |
return points | |
def _positive_erode(self, mask): | |
if random.random() > self.positive_erode_prob: | |
return mask | |
kernel = np.ones((3, 3), np.uint8) | |
eroded_mask = cv2.erode( | |
mask.astype(np.uint8), kernel, iterations=self.positive_erode_iters | |
).astype(np.bool) | |
if eroded_mask.sum() > 10: | |
return eroded_mask | |
else: | |
return mask | |
def _get_border_mask(self, mask): | |
expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum()))) | |
kernel = np.ones((3, 3), np.uint8) | |
expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r) | |
expanded_mask[mask.astype(np.bool)] = 0 | |
return expanded_mask | |
def generate_probs(max_num_points, gamma): | |
probs = [] | |
last_value = 1 | |
for i in range(max_num_points): | |
probs.append(last_value) | |
last_value *= gamma | |
probs = np.array(probs) | |
probs /= probs.sum() | |
return probs | |
def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): | |
if full_prob > 0 and random.random() < full_prob: | |
return obj_mask | |
padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant") | |
dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1] | |
if k > 0: | |
inner_mask = dt > dt.max() / k | |
return np.argwhere(inner_mask) | |
else: | |
prob_map = dt.flatten() | |
prob_map /= max(prob_map.sum(), 1e-6) | |
click_indx = np.random.choice(len(prob_map), p=prob_map) | |
click_coords = np.unravel_index(click_indx, dt.shape) | |
return np.array([click_coords]) | |