import random

import cv2
import numpy as np
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations import functional as F
from albumentations.core.serialization import SERIALIZABLE_REGISTRY
from albumentations.core.transforms_interface import to_tuple

from isegm.utils.misc import (
    clamp_bbox,
    expand_bbox,
    get_bbox_from_mask,
    get_labels_with_sizes,
)


class UniformRandomResize(DualTransform):
    """Rescale a sample by a single random factor drawn from ``scale_range``.

    The same factor multiplies both the height and the width of the input
    image, so the aspect ratio is preserved up to rounding of the output size.

    Args:
        scale_range: ``(low, high)`` bounds passed to ``random.uniform``.
        interpolation: OpenCV interpolation flag stored on the transform.
        always_apply: Forwarded to ``DualTransform``.
        p: Probability of applying the transform.
    """

    def __init__(
        self,
        scale_range=(0.9, 1.1),
        interpolation=cv2.INTER_LINEAR,
        always_apply=False,
        p=1,
    ):
        super().__init__(always_apply, p)
        self.scale_range = scale_range
        self.interpolation = interpolation

    def get_params_dependent_on_targets(self, params):
        """Draw a scale factor and return the rounded output size."""
        scale = random.uniform(*self.scale_range)
        height = int(round(params["image"].shape[0] * scale))
        width = int(round(params["image"].shape[1] * scale))
        return {"new_height": height, "new_width": width}

    def apply(
        self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params
    ):
        """Resize ``img`` to ``(new_height, new_width)``.

        NOTE(review): ``interpolation`` is not produced by
        ``get_params_dependent_on_targets``; it is presumably injected by
        albumentations from ``self.interpolation`` (and swapped for
        nearest-neighbour on masks by ``DualTransform``). Confirm this for
        the albumentations version in use.
        """
        return F.resize(
            img, height=new_height, width=new_width, interpolation=interpolation
        )

    def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
        """Scale keypoint coordinates by the same factors as the image."""
        # "cols"/"rows" are presumably the input image width/height supplied
        # by albumentations; the ratios map old coordinates onto the new size.
        scale_x = new_width / params["cols"]
        scale_y = new_height / params["rows"]
        return F.keypoint_scale(keypoint, scale_x, scale_y)

    def get_transform_init_args_names(self):
        return "scale_range", "interpolation"

    @property
    def targets_as_params(self):
        # The image is needed to compute the output size from its shape.
        return ["image"]


class ZoomIn(DualTransform):
    """Crop around one randomly chosen object and resize the crop.

    The mask is scanned for objects whose pixel count is strictly greater
    than ``min_area``. One is picked at random and its bounding box is
    expanded with ``expand_bbox``, randomly jittered, and clamped to the
    image. Image and mask are then cropped to that box and resized to
    ``(height, width)``. In the returned mask every other object is erased.
    If no object qualifies, the sample is returned unchanged, or only
    resized when ``always_resize`` is set.

    The mask may be 2-D (one label map) or 3-D (a stack of label layers along
    the last axis); in the latter case an object is identified by
    ``(layer_index, label)``.

    Args:
        height: Output height.
        width: Output width.
        bbox_jitter: Jitter range for each bbox side, as a fraction of the bbox
            height/width. Passed through ``to_tuple`` (so a scalar ``a`` is
            presumably expanded to ``(-a, a)``).
        expansion_ratio: A fixed ratio, or a ``(low, high)`` tuple/list from
            which a ratio is drawn uniformly on every call.
        min_crop_size: Forwarded to ``expand_bbox``.
        min_area: Objects must contain strictly more pixels than this.
        always_resize: Resize to ``(height, width)`` even when no object was
            selected.
        always_apply: Forwarded to ``DualTransform``.
        p: Probability of applying the transform.
    """

    def __init__(
        self,
        height,
        width,
        bbox_jitter=0.1,
        expansion_ratio=1.4,
        min_crop_size=200,
        min_area=100,
        always_resize=False,
        always_apply=False,
        p=0.5,
    ):
        super().__init__(always_apply, p)
        self.height = height
        self.width = width
        self.bbox_jitter = to_tuple(bbox_jitter)
        self.expansion_ratio = expansion_ratio
        self.min_crop_size = min_crop_size
        self.min_area = min_area
        self.always_resize = always_resize

    def apply(self, img, selected_object, bbox, **params):
        """Crop ``img`` to ``bbox`` (inclusive bounds) and resize it.

        Args:
            img: Input image.
            selected_object: The chosen object, or ``None`` if none qualified.
            bbox: ``(rmin, rmax, cmin, cmax)`` with inclusive max bounds.
        """
        if selected_object is None:
            if self.always_resize:
                img = F.resize(img, height=self.height, width=self.width)
            return img

        rmin, rmax, cmin, cmax = bbox
        img = img[rmin : rmax + 1, cmin : cmax + 1]
        img = F.resize(img, height=self.height, width=self.width)
        return img

    def apply_to_mask(self, mask, selected_object, bbox, **params):
        """Crop the mask to ``bbox``, keep only the selected object, resize.

        Nearest-neighbour interpolation is used so label values are never
        blended into new, invalid labels.
        """
        if selected_object is None:
            if self.always_resize:
                mask = F.resize(
                    mask,
                    height=self.height,
                    width=self.width,
                    interpolation=cv2.INTER_NEAREST,
                )
            return mask

        rmin, rmax, cmin, cmax = bbox
        mask = mask[rmin : rmax + 1, cmin : cmax + 1]
        if isinstance(selected_object, tuple):
            # Layered mask: keep only the selected label in its own layer and
            # zero every other layer.
            layer_indx, mask_id = selected_object
            obj_mask = mask[:, :, layer_indx] == mask_id
            new_mask = np.zeros_like(mask)
            new_mask[:, :, layer_indx][obj_mask] = mask_id
        else:
            obj_mask = mask == selected_object
            new_mask = mask.copy()
            new_mask[np.logical_not(obj_mask)] = 0

        new_mask = F.resize(
            new_mask,
            height=self.height,
            width=self.width,
            interpolation=cv2.INTER_NEAREST,
        )
        return new_mask

    def get_params_dependent_on_targets(self, params):
        """Choose an object from the mask and compute the crop box for it.

        Returns:
            ``{"selected_object": obj, "bbox": bbox}``, where ``obj`` is a label
            (2-D mask), a ``(layer_index, label)`` tuple (3-D mask), or ``None``.
            ``bbox`` is ``None`` when no object is selected.
        """
        instances = params["mask"]

        is_mask_layer = len(instances.shape) > 2
        candidates = []
        if is_mask_layer:
            for layer_indx in range(instances.shape[2]):
                labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
                candidates.extend(
                    [
                        (layer_indx, obj_id)
                        for obj_id, area in zip(labels, areas)
                        if area > self.min_area
                    ]
                )
        else:
            labels, areas = get_labels_with_sizes(instances)
            candidates = [
                obj_id for obj_id, area in zip(labels, areas) if area > self.min_area
            ]

        selected_object = None
        bbox = None
        if candidates:
            selected_object = random.choice(candidates)
            if is_mask_layer:
                layer_indx, mask_id = selected_object
                obj_mask = instances[:, :, layer_indx] == mask_id
            else:
                obj_mask = instances == selected_object

            bbox = get_bbox_from_mask(obj_mask)

            # A range may arrive as a list (e.g. from a JSON/YAML config),
            # not only as a tuple; both mean "sample uniformly".
            if isinstance(self.expansion_ratio, (tuple, list)):
                expansion_ratio = random.uniform(*self.expansion_ratio)
            else:
                expansion_ratio = self.expansion_ratio

            bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size)
            bbox = self._jitter_bbox(bbox)
            bbox = clamp_bbox(
                bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1
            )

        return {"selected_object": selected_object, "bbox": bbox}

    def _jitter_bbox(self, bbox):
        """Shift each side of ``bbox`` by a random fraction of its extent."""
        rmin, rmax, cmin, cmax = bbox
        height = rmax - rmin + 1
        width = cmax - cmin + 1
        rmin = int(rmin + random.uniform(*self.bbox_jitter) * height)
        rmax = int(rmax + random.uniform(*self.bbox_jitter) * height)
        cmin = int(cmin + random.uniform(*self.bbox_jitter) * width)
        cmax = int(cmax + random.uniform(*self.bbox_jitter) * width)
        # Opposite sides are jittered independently (and truncated by int()),
        # so for small boxes or wide jitter ranges min can overtake max. That
        # would yield an empty crop and a failing resize; keep >= 1 row/column.
        rmax = max(rmax, rmin)
        cmax = max(cmax, cmin)
        return rmin, rmax, cmin, cmax

    def apply_to_bbox(self, bbox, **params):
        raise NotImplementedError

    def apply_to_keypoint(self, keypoint, **params):
        raise NotImplementedError

    @property
    def targets_as_params(self):
        # The mask is needed to choose the object and its bounding box.
        return ["mask"]

    def get_transform_init_args_names(self):
        return (
            "height",
            "width",
            "bbox_jitter",
            "expansion_ratio",
            "min_crop_size",
            "min_area",
            "always_resize",
        )


def remove_image_only_transforms(sdict):
    """Strip ``ImageOnlyTransform`` entries from a serialized pipeline dict.

    Entries that themselves contain ``"transforms"`` (nested composes) are
    processed recursively and always kept. Every other entry is kept only if
    its registered class is not a subclass of ``ImageOnlyTransform``.

    Note: ``sdict`` (and any nested dicts) are modified in place; the same
    object is returned for convenience.

    Args:
        sdict: A serialized transform dict; returned untouched if it has no
            ``"transforms"`` key.

    Raises:
        KeyError: If an entry's ``"__class_fullname__"`` is missing or not in
            ``SERIALIZABLE_REGISTRY``.
    """
    if "transforms" not in sdict:
        return sdict

    keep_transforms = []
    for tdict in sdict["transforms"]:
        cls = SERIALIZABLE_REGISTRY[tdict["__class_fullname__"]]
        if "transforms" in tdict:
            keep_transforms.append(remove_image_only_transforms(tdict))
        elif not issubclass(cls, ImageOnlyTransform):
            keep_transforms.append(tdict)
    sdict["transforms"] = keep_transforms

    return sdict