curt-park's picture
Refactor code
1615d09
raw
history blame
3.36 kB
import pickle
import random
import numpy as np
import torch
from torchvision import transforms
from .points_sampler import MultiPointSampler
from .sample import DSample
class ISDataset(torch.utils.data.dataset.Dataset):
    """Base dataset that turns samples into image / click-points / mask items.

    Subclasses are expected to implement :meth:`get_sample` and to assign
    ``self.dataset_samples`` (it is initialised to ``None`` here and is read
    by :meth:`get_samples_number` and :meth:`__getitem__`).
    """

    def __init__(
        self,
        augmentator=None,
        points_sampler=None,
        min_object_area=0,
        keep_background_prob=0.0,
        with_image_info=False,
        samples_scores_path=None,
        samples_scores_gamma=1.0,
        epoch_len=-1,
    ):
        """Store the sampling / augmentation configuration.

        Args:
            augmentator: Object handed to ``sample.augment``; ``None``
                disables augmentation entirely.
            points_sampler: Object providing ``sample_object(sample)``,
                ``sample_points()`` and ``selected_mask``. When ``None``
                a new ``MultiPointSampler(max_num_points=12)`` is created
                for this dataset.
            min_object_area: Threshold passed to
                ``sample.remove_small_objects``.
            keep_background_prob: Probability of accepting an augmented
                sample for which ``len(sample)`` is 0. A negative value
                accepts such samples unconditionally.
            with_image_info: When true, each item also carries
                ``sample.sample_id`` under the ``"image_info"`` key.
            samples_scores_path: Optional path to a pickle file with
                per-sample scores (see :meth:`_load_samples_scores`).
            samples_scores_gamma: Exponent applied to ``1 - score`` when
                turning the scores into sampling probabilities.
            epoch_len: When positive, the reported dataset length; items
                are then drawn at random instead of by the given index.
        """
        super().__init__()

        # Build the default sampler per instance. A sampler constructed in
        # the signature would be created once and shared by every dataset,
        # although it keeps per-sample state between calls.
        if points_sampler is None:
            points_sampler = MultiPointSampler(max_num_points=12)

        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(
            samples_scores_path, samples_scores_gamma
        )
        self.to_tensor = transforms.ToTensor()
        self.dataset_samples = None

    def __getitem__(self, index):
        """Return one training item as a dict.

        The requested ``index`` is only honoured when no precomputed scores
        were loaded and ``epoch_len`` is not positive; otherwise it is
        replaced by a randomly drawn one.

        Returns:
            dict with keys ``"images"`` (``ToTensor`` applied to
            ``sample.image``), ``"points"`` (sampled points as a float32
            array), ``"instances"`` (the sampler's ``selected_mask``) and,
            if ``with_image_info`` is set, ``"image_info"``.
        """
        if self.samples_precomputed_scores is not None:
            # Score-weighted sampling takes precedence over everything else.
            index = np.random.choice(
                self.samples_precomputed_scores["indices"],
                p=self.samples_precomputed_scores["probs"],
            )
        elif self.epoch_len > 0:
            # Fixed-length epochs: draw uniformly from the stored samples.
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        # The sampler must be given the sample before points / mask are read.
        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }
        if self.with_image_info:
            output["image_info"] = sample.sample_id
        return output

    def augment_sample(self, sample) -> "DSample":
        """Augment ``sample`` until the result is acceptable and return it.

        Without an augmentator the sample is returned untouched. Otherwise
        ``sample.augment`` is re-applied until ``len(sample) > 0`` or the
        object-less result is kept according to ``keep_background_prob``.

        NOTE(review): with ``keep_background_prob == 0.0`` this loops for as
        long as augmentation keeps producing ``len(sample) == 0``.
        """
        if self.augmentator is None:
            return sample

        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            # A negative probability short-circuits: no random draw is made.
            keep_sample = (
                self.keep_background_prob < 0.0
                or random.random() < self.keep_background_prob
            )
            valid_augmentation = len(sample) > 0 or keep_sample
        return sample

    def get_sample(self, index) -> "DSample":
        """Return the sample for ``index``; must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return ``epoch_len`` when positive, else the number of samples."""
        if self.epoch_len > 0:
            return self.epoch_len
        return self.get_samples_number()

    def get_samples_number(self):
        """Return the number of entries in ``self.dataset_samples``."""
        return len(self.dataset_samples)

    @staticmethod
    def _load_samples_scores(samples_scores_path, samples_scores_gamma):
        """Load per-sample scores and convert them to sampling probabilities.

        Each entry of the pickled sequence is indexed as ``x[0]`` (sample
        index) and ``x[2]`` (score). The weight of an entry is
        ``(1 - score) ** samples_scores_gamma``, normalised to sum to 1.

        Returns:
            ``None`` when no path is given, otherwise a dict with
            ``"indices"`` (list) and ``"probs"`` (numpy array).
        """
        if samples_scores_path is None:
            return None

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at score files from a trusted source.
        with open(samples_scores_path, "rb") as f:
            images_scores = pickle.load(f)

        probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
        probs /= probs.sum()
        samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs}
        print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
        return samples_scores