Spaces:
Runtime error
Runtime error
from copy import deepcopy | |
import numpy as np | |
from albumentations import ReplayCompose | |
from isegm.data.transforms import remove_image_only_transforms | |
from isegm.utils.misc import get_labels_with_sizes | |
class DSample:
    """One image together with its integer-encoded masks and object table.

    ``encoded_masks`` is stored as an ``(H, W, layers)`` array (a 2-D input is
    given a trailing layer axis). Every object is addressed by a
    ``(layer_index, mask_id)`` pair kept under the ``"mapping"`` key of its
    entry in ``self._objects``; each entry also has ``"parent"`` and
    ``"children"`` links and, once computed, an ``"area"``.

    The sample remembers the data it was built with so that an augmentation
    applied by :meth:`augment` can be undone by :meth:`reset_augmentation`.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
    ):
        """Build the sample from either ``objects_ids`` or a ready ``objects`` dict.

        Args:
            image: Image data; stored as-is.
            encoded_masks: 2-D or 3-D integer-encoded mask array.
            objects: Pre-built object table (deep-copied). Used only when
                ``objects_ids`` is ``None``. When both are ``None`` the sample
                starts with an empty object table.
            objects_ids: Either plain mask ids (single-layer masks only) or
                ``(layer_index, mask_id)`` tuples. Each becomes a root object
                keyed by its position in the sequence.
            ignore_ids: Plain ids or ``(layer_index, mask_id)`` tuples of
                regions to ignore. Only honoured when ``objects_ids`` is given.
            sample_id: Optional identifier; stored as-is.
        """
        self.image = image
        self.sample_id = sample_id

        if len(encoded_masks.shape) == 2:
            encoded_masks = encoded_masks[:, :, np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is not None:
            # Plain ids are only unambiguous when there is a single layer.
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                index: {"parent": None, "mapping": mapping, "children": []}
                for index, mapping in enumerate(objects_ids)
            }

            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    self._ignored_regions = ignore_ids
                else:
                    self._ignored_regions = [
                        (0, region_id) for region_id in ignore_ids
                    ]
        else:
            # An absent object table means "no objects", not ``None``: every
            # other method iterates or measures ``self._objects``.
            self._objects = {} if objects is None else deepcopy(objects)

        self._augmented = False
        self._soft_mask_aug = None
        # Image and masks are kept by reference; only the table is copied.
        self._original_data = (
            self.image,
            self._encoded_masks,
            deepcopy(self._objects),
        )

    def augment(self, augmentator):
        """Apply ``augmentator`` to the original image and masks.

        Any previous augmentation is reset first. ``augmentator`` is called
        with ``image=`` and ``mask=`` keyword arguments and must return a
        mapping with ``"image"`` and ``"mask"`` entries. If the result carries
        a truthy ``"replay"`` entry, a mask-only replay of the augmentation is
        prepared for :meth:`get_soft_object_mask` (not supported together with
        ignored regions). Afterwards object areas are recomputed and objects
        with an area below 1 are removed.
        """
        self.reset_augmentation()

        aug_output = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = aug_output["image"]
        self._encoded_masks = aug_output["mask"]

        aug_replay = aug_output.get("replay", None)
        if aug_replay:
            assert len(self._ignored_regions) == 0
            mask_replay = remove_image_only_transforms(aug_replay)
            # NOTE(review): `_restore_for_replay` is a private albumentations
            # helper; confirm it exists in the pinned library version.
            self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)

        self._compute_objects_areas()
        self.remove_small_objects(min_area=1)
        self._augmented = True

    def reset_augmentation(self):
        """Restore the image, masks and object table saved at construction.

        Does nothing when the sample is not currently augmented.
        """
        if not self._augmented:
            return

        orig_image, orig_masks, orig_objects = self._original_data
        self.image = orig_image
        self._encoded_masks = orig_masks
        self._objects = deepcopy(orig_objects)
        self._augmented = False
        self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Remove every object whose ``"area"`` is smaller than ``min_area``.

        Areas are computed first if the first object has none recorded yet.
        """
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        # Iterate over a snapshot: `_remove_object` mutates the table.
        for obj_id, obj_info in list(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an ``int32`` mask of object ``obj_id``.

        Pixels belonging to the object are 1, all others 0, and pixels of any
        ignored region are overwritten with -1.
        """
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.int32
        )

        for ignore_layer, ignore_id in self._ignored_regions:
            ignore_mask = self._encoded_masks[:, :, ignore_layer] == ignore_id
            obj_mask[ignore_mask] = -1

        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return a float mask of ``obj_id`` produced by replaying the augmentation.

        The binary mask is taken from the *original* encoded masks, passed as
        the ``image`` through the replayed mask-only augmentation, and clipped
        to ``[0, 1]``. Requires that :meth:`augment` stored a replay.
        """
        assert self._soft_mask_aug is not None

        original_encoded_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.float32
        )
        aug_result = self._soft_mask_aug(
            image=obj_mask, mask=original_encoded_masks
        )
        return np.clip(aug_result["image"], 0, 1)

    def get_background_mask(self):
        """Return a boolean mask of pixels whose maximum over all layers is 0."""
        return np.max(self._encoded_masks, axis=2) == 0

    def objects_ids(self):
        """Return the ids (keys of the object table) of all current objects."""
        return list(self._objects.keys())

    def gt_mask(self):
        """Return the object mask of the only object in the sample.

        Asserts that the sample holds exactly one object.
        """
        assert len(self._objects) == 1
        # `objects_ids` is a method here, so it must be called before being
        # indexed (subscripting the bound method raised TypeError).
        # NOTE(review): if callers expect `objects_ids`/`gt_mask`/
        # `root_objects` to be properties, restore `@property` on all three.
        return self.get_object_mask(self.objects_ids()[0])

    def root_objects(self):
        """Return the ids of objects that have no parent."""
        return [
            obj_id
            for obj_id, obj_info in self._objects.items()
            if obj_info["parent"] is None
        ]

    def _compute_objects_areas(self):
        """Record the ``"area"`` of every object from the current masks.

        Labels reported by ``get_labels_with_sizes`` that belong to neither a
        known object nor an ignored region are zeroed out in the masks.
        Objects whose mapping is not reported in any layer get an area of 0.
        """
        # (layer, mask_id) -> object id; entries are consumed as they match.
        inverse_index = {
            node["mapping"]: node_id for node_id, node in self._objects.items()
        }
        ignored_regions_keys = set(self._ignored_regions)

        for layer_indx in range(self._encoded_masks.shape[2]):
            # Basic slicing yields a view, so edits to `layer` land directly
            # in `self._encoded_masks`.
            layer = self._encoded_masks[:, :, layer_indx]
            labels, areas = get_labels_with_sizes(layer)

            for label, area in zip(labels, areas):
                inv_key = (layer_indx, label)
                if inv_key in ignored_regions_keys:
                    continue

                try:
                    node_id = inverse_index.pop(inv_key)
                except KeyError:
                    # Unknown label: erase it from the masks.
                    layer[layer == label] = 0
                else:
                    self._objects[node_id]["area"] = area

        # Whatever is left was not reported in the masks at all.
        for node_id in inverse_index.values():
            self._objects[node_id]["area"] = 0

    def _remove_object(self, obj_id):
        """Delete ``obj_id`` and re-attach its children to its parent."""
        obj_info = self._objects[obj_id]
        obj_parent = obj_info["parent"]

        for child_id in obj_info["children"]:
            self._objects[child_id]["parent"] = obj_parent

        if obj_parent is not None:
            siblings = [
                child
                for child in self._objects[obj_parent]["children"]
                if child != obj_id
            ]
            self._objects[obj_parent]["children"] = siblings + obj_info["children"]

        del self._objects[obj_id]

    def __len__(self):
        """Return the number of objects currently in the sample."""
        return len(self._objects)