import pickle
import random

import numpy as np
import torch
from torchvision import transforms

from .points_sampler import MultiPointSampler
from .sample import DSample


class ISDataset(torch.utils.data.dataset.Dataset):
    """Base dataset that turns a ``DSample`` into an image/points/mask item.

    Subclasses must implement :meth:`get_sample` and assign
    ``self.dataset_samples`` (it is initialised to ``None`` here and its
    length is used for random index selection and for ``__len__``).

    Args:
        augmentator: Object handed to ``DSample.augment``. ``None`` disables
            augmentation.
        points_sampler: Sampler exposing ``sample_object``, ``sample_points``
            and ``selected_mask``. When ``None`` (the default) a fresh
            ``MultiPointSampler(max_num_points=12)`` is created for this
            dataset, so datasets never share sampler state.
        min_object_area: Threshold forwarded to
            ``DSample.remove_small_objects``.
        keep_background_prob: Chance of accepting an augmented sample for
            which ``len(sample) == 0``. A negative value always accepts such
            samples; ``0.0`` never does, so augmentation is retried.
        with_image_info: If true, each item also carries ``"image_info"``
            (the sample's ``sample_id``).
        samples_scores_path: Path of a pickle file with per-sample scores
            used for weighted sampling. ``None`` disables weighted sampling.
        samples_scores_gamma: Exponent applied to ``1 - score`` when turning
            scores into sampling weights.
        epoch_len: When positive, ``__len__`` reports this value and items
            are drawn at random instead of by the requested index.
    """

    def __init__(
        self,
        augmentator=None,
        points_sampler=None,
        min_object_area=0,
        keep_background_prob=0.0,
        with_image_info=False,
        samples_scores_path=None,
        samples_scores_gamma=1.0,
        epoch_len=-1,
    ):
        super().__init__()
        # Build the default sampler per instance: the sampler is mutated on
        # every __getitem__ call, so a definition-time default object would
        # be shared by all datasets that rely on the default.
        if points_sampler is None:
            points_sampler = MultiPointSampler(max_num_points=12)

        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(
            samples_scores_path, samples_scores_gamma
        )
        self.to_tensor = transforms.ToTensor()
        # Must be populated by the subclass.
        self.dataset_samples = None

    def __getitem__(self, index):
        """Build one training item.

        The requested ``index`` is only honoured when neither precomputed
        scores nor a positive ``epoch_len`` are configured; otherwise it is
        replaced by a randomly drawn index.

        Returns:
            dict with keys ``"images"`` (output of ``ToTensor``),
            ``"points"`` (float32 array), ``"instances"`` (the sampler's
            ``selected_mask``) and, if ``with_image_info`` is set,
            ``"image_info"``.
        """
        scores = self.samples_precomputed_scores
        if scores is not None:
            # Weighted draw over the indices stored in the scores file.
            index = np.random.choice(scores["indices"], p=scores["probs"])
        elif self.epoch_len > 0:
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        # sample_object() must run before sample_points()/selected_mask,
        # which read the state it sets up.
        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }
        if self.with_image_info:
            output["image_info"] = sample.sample_id

        return output

    def augment_sample(self, sample) -> DSample:
        """Augment ``sample`` until the result is acceptable.

        The sample is augmented at least once (when an augmentator is set).
        A result is accepted when ``len(sample) > 0`` or when the
        ``keep_background_prob`` draw allows an empty result; otherwise the
        augmentation is repeated.
        """
        if self.augmentator is None:
            return sample

        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            keep_sample = (
                self.keep_background_prob < 0.0
                or random.random() < self.keep_background_prob
            )
            valid_augmentation = len(sample) > 0 or keep_sample

        return sample

    def get_sample(self, index) -> DSample:
        """Return the sample at ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return ``epoch_len`` when positive, else the number of samples."""
        if self.epoch_len > 0:
            return self.epoch_len
        return self.get_samples_number()

    def get_samples_number(self):
        """Return the number of entries in ``dataset_samples``."""
        return len(self.dataset_samples)

    @staticmethod
    def _load_samples_scores(samples_scores_path, samples_scores_gamma):
        """Load per-sample scores and convert them to sampling probabilities.

        Each record ``x`` in the pickled collection contributes index
        ``x[0]`` and weight ``(1 - x[2]) ** samples_scores_gamma``; weights
        are normalised to sum to one.

        Returns:
            ``None`` if ``samples_scores_path`` is ``None``, otherwise a dict
            with ``"indices"`` (list) and ``"probs"`` (numpy array).

        Raises:
            ValueError: If the weights cannot be normalised (empty file, or
                a total weight that is zero, negative or not finite).
        """
        if samples_scores_path is None:
            return None

        # SECURITY: pickle.load executes arbitrary code embedded in the
        # file; only point this at score files from a trusted source.
        with open(samples_scores_path, "rb") as f:
            images_scores = pickle.load(f)

        probs = np.array(
            [(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]
        )
        total = probs.sum()
        # Fail here rather than with an opaque error from np.random.choice
        # on the first __getitem__ call.
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                f"Cannot normalise sample scores from {samples_scores_path!r}: "
                f"total weight is {total}"
            )
        probs /= total

        samples_scores = {
            "indices": [x[0] for x in images_scores],
            "probs": probs,
        }
        print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
        return samples_scores