Spaces:
Runtime error
Runtime error
import pickle | |
import random | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from .points_sampler import MultiPointSampler | |
from .sample import DSample | |
class ISDataset(torch.utils.data.dataset.Dataset): | |
def __init__( | |
self, | |
augmentator=None, | |
points_sampler=MultiPointSampler(max_num_points=12), | |
min_object_area=0, | |
keep_background_prob=0.0, | |
with_image_info=False, | |
samples_scores_path=None, | |
samples_scores_gamma=1.0, | |
epoch_len=-1, | |
): | |
super(ISDataset, self).__init__() | |
self.epoch_len = epoch_len | |
self.augmentator = augmentator | |
self.min_object_area = min_object_area | |
self.keep_background_prob = keep_background_prob | |
self.points_sampler = points_sampler | |
self.with_image_info = with_image_info | |
self.samples_precomputed_scores = self._load_samples_scores( | |
samples_scores_path, samples_scores_gamma | |
) | |
self.to_tensor = transforms.ToTensor() | |
self.dataset_samples = None | |
def __getitem__(self, index): | |
if self.samples_precomputed_scores is not None: | |
index = np.random.choice( | |
self.samples_precomputed_scores["indices"], | |
p=self.samples_precomputed_scores["probs"], | |
) | |
else: | |
if self.epoch_len > 0: | |
index = random.randrange(0, len(self.dataset_samples)) | |
sample = self.get_sample(index) | |
sample = self.augment_sample(sample) | |
sample.remove_small_objects(self.min_object_area) | |
self.points_sampler.sample_object(sample) | |
points = np.array(self.points_sampler.sample_points()) | |
mask = self.points_sampler.selected_mask | |
output = { | |
"images": self.to_tensor(sample.image), | |
"points": points.astype(np.float32), | |
"instances": mask, | |
} | |
if self.with_image_info: | |
output["image_info"] = sample.sample_id | |
return output | |
def augment_sample(self, sample) -> DSample: | |
if self.augmentator is None: | |
return sample | |
valid_augmentation = False | |
while not valid_augmentation: | |
sample.augment(self.augmentator) | |
keep_sample = ( | |
self.keep_background_prob < 0.0 | |
or random.random() < self.keep_background_prob | |
) | |
valid_augmentation = len(sample) > 0 or keep_sample | |
return sample | |
def get_sample(self, index) -> DSample: | |
raise NotImplementedError | |
def __len__(self): | |
if self.epoch_len > 0: | |
return self.epoch_len | |
else: | |
return self.get_samples_number() | |
def get_samples_number(self): | |
return len(self.dataset_samples) | |
def _load_samples_scores(samples_scores_path, samples_scores_gamma): | |
if samples_scores_path is None: | |
return None | |
with open(samples_scores_path, "rb") as f: | |
images_scores = pickle.load(f) | |
probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) | |
probs /= probs.sum() | |
samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs} | |
print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}") | |
return samples_scores | |