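# NOTE: the lines below are file-viewer page chrome captured together with the
# source; they are kept as comments so the module remains importable.
# curt-park's picture
# Refactor code
# 1615d09
# raw
# history blame
# 5.97 kB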
from copy import deepcopy
import numpy as np
from albumentations import ReplayCompose
from isegm.data.transforms import remove_image_only_transforms
from isegm.utils.misc import get_labels_with_sizes
class DSample:
    """One image together with its label-encoded masks and object records.

    Masks are kept as an ``(H, W, layers)`` array in which every object is
    identified by a ``(layer index, label id)`` pair. Objects live in
    ``self._objects``: a dict mapping an object id to a node dict with the keys
    ``"parent"``, ``"mapping"`` (the ``(layer, label)`` pair), ``"children"``
    and, once it has been computed, ``"area"``.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
    ):
        """Build a sample from an image and its encoded masks.

        Args:
            image: Image data; stored as-is and handed to the augmentator.
            encoded_masks: 2-D or 3-D label array. A 2-D array gets a trailing
                single-layer axis.
            objects: Ready-made ``{object_id: node}`` dict (deep-copied). Used
                only when ``objects_ids`` is None.
            objects_ids: Label ids (single-layer masks only) or
                ``(layer, label)`` tuples. Objects are numbered by position
                and created without parents or children.
            ignore_ids: Label ids or ``(layer, label)`` tuples of regions that
                are reported as ``-1`` in object masks.
            sample_id: Opaque identifier stored on the sample.

        NOTE(review): ``ignore_ids`` is only honoured together with
        ``objects_ids``; with ``objects`` it is silently dropped. Confirm this
        is intended before relying on it.
        """
        self.image = image
        self.sample_id = sample_id

        if len(encoded_masks.shape) == 2:
            encoded_masks = encoded_masks[:, :, np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is not None:
            # Bare label ids are only unambiguous when there is a single layer.
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                indx: {"parent": None, "mapping": obj_mapping, "children": []}
                for indx, obj_mapping in enumerate(objects_ids)
            }

            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    self._ignored_regions = ignore_ids
                else:
                    self._ignored_regions = [
                        (0, region_id) for region_id in ignore_ids
                    ]
        else:
            # Fall back to an empty dict so that a sample created without any
            # object description still supports len(), iteration and removal
            # instead of carrying ``None`` around.
            self._objects = deepcopy(objects) if objects is not None else {}

        self._augmented = False
        self._soft_mask_aug = None
        # Image and masks are kept by reference; only the objects are copied.
        self._original_data = (
            self.image,
            self._encoded_masks,
            deepcopy(self._objects),
        )

    def augment(self, augmentator):
        """Apply ``augmentator`` to the original image and masks.

        Any previous augmentation is undone first. ``augmentator`` is called
        with ``image=`` and ``mask=`` keywords and must return a mapping with
        ``"image"`` and ``"mask"`` entries. If the result also carries a
        truthy ``"replay"`` entry, a mask-only replay transform is prepared
        for ``get_soft_object_mask`` (not allowed together with ignored
        regions). Objects whose area drops below 1 are removed afterwards.
        """
        self.reset_augmentation()
        aug_output = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = aug_output["image"]
        self._encoded_masks = aug_output["mask"]

        aug_replay = aug_output.get("replay", None)
        if aug_replay:
            assert len(self._ignored_regions) == 0
            mask_replay = remove_image_only_transforms(aug_replay)
            self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)

        self._compute_objects_areas()
        self.remove_small_objects(min_area=1)

        self._augmented = True

    def reset_augmentation(self):
        """Restore image, masks and objects saved at construction time.

        Does nothing when the sample is not currently augmented.
        """
        if not self._augmented:
            return
        orig_image, orig_masks, orig_objects = self._original_data
        self.image = orig_image
        self._encoded_masks = orig_masks
        # Copy again so later removals never touch the saved hierarchy.
        self._objects = deepcopy(orig_objects)
        self._augmented = False
        self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Remove every object whose ``"area"`` is below ``min_area``.

        Areas are computed first when the first object has none recorded.
        """
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        # Iterate over a snapshot: _remove_object deletes from self._objects.
        for obj_id, obj_info in list(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an int32 mask of object ``obj_id``.

        Pixels of the object are 1, pixels of ignored regions are -1 (ignored
        regions take precedence), everything else is 0.
        """
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)

        for ignored_layer, ignored_label in self._ignored_regions:
            ignore_mask = self._encoded_masks[:, :, ignored_layer] == ignored_label
            obj_mask[ignore_mask] = -1

        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return the object's float mask pushed through the replayed transform.

        The binary float32 mask is taken from the *original* masks, passed as
        ``image=`` to the replay transform prepared by ``augment`` and clipped
        to ``[0, 1]``. Requires that the last ``augment`` call produced a
        ``"replay"`` entry.
        """
        assert self._soft_mask_aug is not None
        original_encoded_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.float32
        )
        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
            "image"
        ]
        return np.clip(obj_mask, 0, 1)

    def get_background_mask(self):
        """Return a boolean mask of pixels whose maximum over all layers is 0."""
        return np.max(self._encoded_masks, axis=2) == 0

    @property
    def objects_ids(self):
        """Ids of all objects currently present, as a list."""
        return list(self._objects.keys())

    @property
    def gt_mask(self):
        """Mask of the only object; asserts that exactly one object exists."""
        assert len(self._objects) == 1
        return self.get_object_mask(self.objects_ids[0])

    @property
    def root_objects(self):
        """Ids of the objects that have no parent."""
        return [
            obj_id
            for obj_id, obj_info in self._objects.items()
            if obj_info["parent"] is None
        ]

    def _compute_objects_areas(self):
        """Store an ``"area"`` for every object and erase unknown labels.

        Each label reported by ``get_labels_with_sizes`` for a layer is either
        skipped (ignored region), recorded as the area of the object mapped to
        it, or - when no object maps to it - set to 0 in the working masks.
        Objects whose label was not reported get area 0.
        """
        inverse_index = {
            node["mapping"]: node_id for node_id, node in self._objects.items()
        }
        ignored_regions_keys = set(self._ignored_regions)

        for layer_indx in range(self._encoded_masks.shape[2]):
            labels, areas = get_labels_with_sizes(
                self._encoded_masks[:, :, layer_indx]
            )
            for label, area in zip(labels, areas):
                inv_key = (layer_indx, label)
                if inv_key in ignored_regions_keys:
                    continue
                if inv_key in inverse_index:
                    # Each mapping is consumed once; leftovers get area 0 below.
                    self._objects[inverse_index.pop(inv_key)]["area"] = area
                else:
                    self._erase_label(layer_indx, label)

        for node_id in inverse_index.values():
            self._objects[node_id]["area"] = 0

    def _erase_label(self, layer_indx, label):
        """Set ``label`` to 0 in one layer of the working masks.

        The working masks may still be the very buffer saved in
        ``_original_data`` (no augmentation yet, or an augmentator that
        returned its input). In that case they are copied first so that the
        saved originals, and the array the caller passed in, stay intact.
        """
        if np.may_share_memory(self._encoded_masks, self._original_data[1]):
            self._encoded_masks = self._encoded_masks.copy()
        layer = self._encoded_masks[:, :, layer_indx]
        layer[layer == label] = 0

    def _remove_object(self, obj_id):
        """Delete ``obj_id`` and re-attach its children to its parent."""
        obj_info = self._objects[obj_id]
        obj_parent = obj_info["parent"]
        for child_id in obj_info["children"]:
            self._objects[child_id]["parent"] = obj_parent

        if obj_parent is not None:
            # The parent loses this object and adopts its children instead.
            siblings = [
                x for x in self._objects[obj_parent]["children"] if x != obj_id
            ]
            self._objects[obj_parent]["children"] = siblings + obj_info["children"]

        del self._objects[obj_id]

    def __len__(self):
        """Number of objects currently present."""
        return len(self._objects)