curt-park committed on
Commit 2cdd41c · 1 Parent(s): e82cf8b

Init the space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +2 -0
  2. Makefile +8 -0
  3. app.py +83 -2
  4. isegm/data/base.py +99 -0
  5. isegm/data/compose.py +39 -0
  6. isegm/data/datasets/__init__.py +12 -0
  7. isegm/data/datasets/ade20k.py +55 -0
  8. isegm/data/datasets/berkeley.py +6 -0
  9. isegm/data/datasets/coco.py +74 -0
  10. isegm/data/datasets/coco_lvis.py +67 -0
  11. isegm/data/datasets/davis.py +33 -0
  12. isegm/data/datasets/grabcut.py +34 -0
  13. isegm/data/datasets/images_dir.py +59 -0
  14. isegm/data/datasets/lvis.py +97 -0
  15. isegm/data/datasets/openimages.py +58 -0
  16. isegm/data/datasets/pascalvoc.py +48 -0
  17. isegm/data/datasets/sbd.py +111 -0
  18. isegm/data/points_sampler.py +305 -0
  19. isegm/data/sample.py +148 -0
  20. isegm/data/transforms.py +178 -0
  21. isegm/engine/optimizer.py +27 -0
  22. isegm/engine/trainer.py +413 -0
  23. isegm/inference/__init__.py +0 -0
  24. isegm/inference/clicker.py +118 -0
  25. isegm/inference/evaluation.py +56 -0
  26. isegm/inference/predictors/__init__.py +98 -0
  27. isegm/inference/predictors/base.py +126 -0
  28. isegm/inference/predictors/brs.py +307 -0
  29. isegm/inference/predictors/brs_functors.py +109 -0
  30. isegm/inference/predictors/brs_losses.py +58 -0
  31. isegm/inference/transforms/__init__.py +5 -0
  32. isegm/inference/transforms/base.py +38 -0
  33. isegm/inference/transforms/crops.py +97 -0
  34. isegm/inference/transforms/flip.py +37 -0
  35. isegm/inference/transforms/limit_longest_side.py +22 -0
  36. isegm/inference/transforms/zoom_in.py +175 -0
  37. isegm/inference/utils.py +143 -0
  38. isegm/model/initializer.py +105 -0
  39. isegm/model/is_deeplab_model.py +25 -0
  40. isegm/model/is_hrnet_model.py +26 -0
  41. isegm/model/is_model.py +141 -0
  42. isegm/model/losses.py +161 -0
  43. isegm/model/metrics.py +101 -0
  44. isegm/model/modeling/basic_blocks.py +71 -0
  45. isegm/model/modeling/deeplab_v3.py +176 -0
  46. isegm/model/modeling/hrnet_ocr.py +416 -0
  47. isegm/model/modeling/ocr.py +141 -0
  48. isegm/model/modeling/resnet.py +43 -0
  49. isegm/model/modeling/resnetv1b.py +276 -0
  50. isegm/model/modifiers.py +11 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ *.pth
Makefile ADDED
@@ -0,0 +1,8 @@
+ PYTHON=3.9
+ BASENAME=$(shell basename $(CURDIR))
+
+ env:
+ 	conda create -n $(BASENAME) python=$(PYTHON)
+
+ setup:
+ 	pip install -r requirements.txt
app.py CHANGED
@@ -1,4 +1,85 @@
  import streamlit as st
+ import torch
+ import numpy as np
+ import cv2
+ import wget
+ import os

- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ from PIL import Image
+ from streamlit_drawable_canvas import st_canvas
+
+ from isegm.inference import clicker as ck
+ from isegm.inference import utils
+ from isegm.inference.predictors import get_predictor
+
+ # Model Path
+ prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
+ models = {
+     "RITM": "ritm_coco_lvis_h18_itermask.pth",
+ }
+
+ # Items in the sidebar.
+ model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
+ threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
+ marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
+ image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
+
+ # Objects for prediction.
+ clicker = ck.Clicker()
+ device = torch.device("cpu")
+ predictor = None
+ with st.spinner("Wait for downloading a model..."):
+     if not os.path.exists(models[model]):
+         _ = wget.download(f"{prefix}/{models[model]}")
+
+ with st.spinner("Wait for loading a model..."):
+     model = utils.load_is_model(models[model], device, cpu_dist_maps=True)
+     predictor_params = {"brs_mode": "NoBRS"}
+     predictor = get_predictor(model, device=device, **predictor_params)
+
+ # Create a canvas component.
+ image = None
+ if image_path:
+     image = Image.open(image_path)
+ canvas_height, canvas_width = 600, 600
+ pos_color, neg_color = "#3498DB", "#C70039"
+ st.title("Canvas:")
+ canvas_result = st_canvas(
+     fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
+     stroke_width=3,
+     stroke_color=pos_color if marking_type == "positive" else neg_color,
+     background_color="#eee",
+     background_image=image,
+     update_streamlit=True,
+     drawing_mode="point",
+     point_display_radius=3,
+     key="canvas",
+     width=canvas_width,
+     height=canvas_height,
+ )
+
+ # Check the user inputs and execute predictions.
+ st.title("Prediction:")
+ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
+     objects = canvas_result.json_data["objects"]
+     image_width, image_height = image.size
+     ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
+
+     err_x, err_y = 5.5, 1.0
+     pos_clicks, neg_clicks = [], []
+     for click in objects:
+         x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
+         x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
+
+         is_positive = click["stroke"] == pos_color
+         click = ck.Click(is_positive=is_positive, coords=(y, x))
+         clicker.add_click(click)
+
+     # prediction.
+     pred = None
+     predictor.set_input_image(np.array(image))
+     with st.spinner("Wait for prediction..."):
+         pred = predictor.get_prediction(clicker, prev_mask=None)
+         pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
+         pred = np.where(pred > threshold, 1.0, 0)
+     st.image(pred, caption="")
isegm/data/base.py ADDED
@@ -0,0 +1,99 @@
+ import random
+ import pickle
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ from .points_sampler import MultiPointSampler
+ from .sample import DSample
+
+
+ class ISDataset(torch.utils.data.dataset.Dataset):
+     def __init__(self,
+                  augmentator=None,
+                  points_sampler=MultiPointSampler(max_num_points=12),
+                  min_object_area=0,
+                  keep_background_prob=0.0,
+                  with_image_info=False,
+                  samples_scores_path=None,
+                  samples_scores_gamma=1.0,
+                  epoch_len=-1):
+         super(ISDataset, self).__init__()
+         self.epoch_len = epoch_len
+         self.augmentator = augmentator
+         self.min_object_area = min_object_area
+         self.keep_background_prob = keep_background_prob
+         self.points_sampler = points_sampler
+         self.with_image_info = with_image_info
+         self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma)
+         self.to_tensor = transforms.ToTensor()
+
+         self.dataset_samples = None
+
+     def __getitem__(self, index):
+         if self.samples_precomputed_scores is not None:
+             index = np.random.choice(self.samples_precomputed_scores['indices'],
+                                      p=self.samples_precomputed_scores['probs'])
+         else:
+             if self.epoch_len > 0:
+                 index = random.randrange(0, len(self.dataset_samples))
+
+         sample = self.get_sample(index)
+         sample = self.augment_sample(sample)
+         sample.remove_small_objects(self.min_object_area)
+
+         self.points_sampler.sample_object(sample)
+         points = np.array(self.points_sampler.sample_points())
+         mask = self.points_sampler.selected_mask
+
+         output = {
+             'images': self.to_tensor(sample.image),
+             'points': points.astype(np.float32),
+             'instances': mask
+         }
+
+         if self.with_image_info:
+             output['image_info'] = sample.sample_id
+
+         return output
+
+     def augment_sample(self, sample) -> DSample:
+         if self.augmentator is None:
+             return sample
+
+         valid_augmentation = False
+         while not valid_augmentation:
+             sample.augment(self.augmentator)
+             keep_sample = (self.keep_background_prob < 0.0 or
+                            random.random() < self.keep_background_prob)
+             valid_augmentation = len(sample) > 0 or keep_sample
+
+         return sample
+
+     def get_sample(self, index) -> DSample:
+         raise NotImplementedError
+
+     def __len__(self):
+         if self.epoch_len > 0:
+             return self.epoch_len
+         else:
+             return self.get_samples_number()
+
+     def get_samples_number(self):
+         return len(self.dataset_samples)
+
+     @staticmethod
+     def _load_samples_scores(samples_scores_path, samples_scores_gamma):
+         if samples_scores_path is None:
+             return None
+
+         with open(samples_scores_path, 'rb') as f:
+             images_scores = pickle.load(f)
+
+         probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
+         probs /= probs.sum()
+         samples_scores = {
+             'indices': [x[0] for x in images_scores],
+             'probs': probs
+         }
+         print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}')
+         return samples_scores
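New datasets hook into this base class by overriding get_sample; augmentation, point sampling, and epoch handling are inherited. A minimal sketch, not part of this commit and assuming a hypothetical folder layout with images/*.jpg and matching masks/*.png:

import cv2
import numpy as np
from pathlib import Path

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class MyMaskFolderDataset(ISDataset):
    def __init__(self, dataset_path, **kwargs):
        super().__init__(**kwargs)
        self.dataset_path = Path(dataset_path)
        # one image per mask, matched by stem: images/xxx.jpg <-> masks/xxx.png (hypothetical layout)
        self.dataset_samples = sorted((self.dataset_path / 'images').glob('*.jpg'))

    def get_sample(self, index) -> DSample:
        image_path = self.dataset_samples[index]
        mask_path = self.dataset_path / 'masks' / f'{image_path.stem}.png'

        image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
        instances_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE).astype(np.int32)
        objects_ids = [x for x in np.unique(instances_mask) if x != 0]

        return DSample(image, instances_mask, objects_ids=objects_ids, sample_id=index)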
isegm/data/compose.py ADDED
@@ -0,0 +1,39 @@
+ import numpy as np
+ from math import isclose
+ from .base import ISDataset
+
+
+ class ComposeDataset(ISDataset):
+     def __init__(self, datasets, **kwargs):
+         super(ComposeDataset, self).__init__(**kwargs)
+
+         self._datasets = datasets
+         self.dataset_samples = []
+         for dataset_indx, dataset in enumerate(self._datasets):
+             self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
+
+     def get_sample(self, index):
+         dataset_indx, sample_indx = self.dataset_samples[index]
+         return self._datasets[dataset_indx].get_sample(sample_indx)
+
+
+ class ProportionalComposeDataset(ISDataset):
+     def __init__(self, datasets, ratios, **kwargs):
+         super().__init__(**kwargs)
+
+         assert len(ratios) == len(datasets),\
+             "The number of datasets must match the number of ratios"
+         assert isclose(sum(ratios), 1.0),\
+             "The sum of ratios must be equal to 1"
+
+         self._ratios = ratios
+         self._datasets = datasets
+         self.dataset_samples = []
+         for dataset_indx, dataset in enumerate(self._datasets):
+             self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
+
+     def get_sample(self, index):
+         dataset_indx = np.random.choice(len(self._datasets), p=self._ratios)
+         sample_indx = np.random.choice(len(self._datasets[dataset_indx]))
+
+         return self._datasets[dataset_indx].get_sample(sample_indx)
isegm/data/datasets/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from isegm.data.compose import ComposeDataset, ProportionalComposeDataset
+ from .berkeley import BerkeleyDataset
+ from .coco import CocoDataset
+ from .davis import DavisDataset
+ from .grabcut import GrabCutDataset
+ from .coco_lvis import CocoLvisDataset
+ from .lvis import LvisDataset
+ from .openimages import OpenImagesDataset
+ from .sbd import SBDDataset, SBDEvaluationDataset
+ from .images_dir import ImagesDirDataset
+ from .ade20k import ADE20kDataset
+ from .pascalvoc import PascalVocDataset
isegm/data/datasets/ade20k.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import random
+ import pickle as pkl
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+ from isegm.utils.misc import get_labels_with_sizes
+
+
+ class ADE20kDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
+         super().__init__(**kwargs)
+         assert split in {'train', 'val'}
+
+         self.dataset_path = Path(dataset_path)
+         self.dataset_split = split
+         self.dataset_split_folder = 'training' if split == 'train' else 'validation'
+         self.stuff_prob = stuff_prob
+
+         anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl'
+         if os.path.exists(anno_path):
+             with anno_path.open('rb') as f:
+                 annotations = pkl.load(f)
+         else:
+             raise RuntimeError(f"Can't find annotations at {anno_path}")
+         self.annotations = annotations
+         self.dataset_samples = list(annotations.keys())
+
+     def get_sample(self, index) -> DSample:
+         image_id = self.dataset_samples[index]
+         sample_annos = self.annotations[image_id]
+
+         image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg')
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         # select random mask for an image
+         layer = random.choice(sample_annos['layers'])
+         mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name'])
+         instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0]  # the B channel holds instances
+         instances_mask = instances_mask.astype(np.int32)
+         object_ids, _ = get_labels_with_sizes(instances_mask)
+
+         if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob):
+             # remove stuff objects
+             for i, object_id in enumerate(object_ids):
+                 if i in layer['stuff_instances']:
+                     instances_mask[instances_mask == object_id] = 0
+             object_ids, _ = get_labels_with_sizes(instances_mask)
+
+         return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)
isegm/data/datasets/berkeley.py ADDED
@@ -0,0 +1,6 @@
+ from .grabcut import GrabCutDataset
+
+
+ class BerkeleyDataset(GrabCutDataset):
+     def __init__(self, dataset_path, **kwargs):
+         super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs)
isegm/data/datasets/coco.py ADDED
@@ -0,0 +1,74 @@
+ import cv2
+ import json
+ import random
+ import numpy as np
+ from pathlib import Path
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class CocoDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
+         super(CocoDataset, self).__init__(**kwargs)
+         self.split = split
+         self.dataset_path = Path(dataset_path)
+         self.stuff_prob = stuff_prob
+
+         self.load_samples()
+
+     def load_samples(self):
+         annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json'
+         self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}'
+         self.images_path = self.dataset_path / self.split
+
+         with open(annotation_path, 'r') as f:
+             annotation = json.load(f)
+
+         self.dataset_samples = annotation['annotations']
+
+         self._categories = annotation['categories']
+         self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0]
+         self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1]
+         self._things_labels_set = set(self._things_labels)
+         self._stuff_labels_set = set(self._stuff_labels)
+
+     def get_sample(self, index) -> DSample:
+         dataset_sample = self.dataset_samples[index]
+
+         image_path = self.images_path / self.get_image_name(dataset_sample['file_name'])
+         label_path = self.labels_path / dataset_sample['file_name']
+
+         image = cv2.imread(str(image_path))
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32)
+         label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2]
+
+         instance_map = np.full_like(label, 0)
+         things_ids = []
+         stuff_ids = []
+
+         for segment in dataset_sample['segments_info']:
+             class_id = segment['category_id']
+             obj_id = segment['id']
+             if class_id in self._things_labels_set:
+                 if segment['iscrowd'] == 1:
+                     continue
+                 things_ids.append(obj_id)
+             else:
+                 stuff_ids.append(obj_id)
+
+             instance_map[label == obj_id] = obj_id
+
+         if self.stuff_prob > 0 and random.random() < self.stuff_prob:
+             instances_ids = things_ids + stuff_ids
+         else:
+             instances_ids = things_ids
+
+             for stuff_id in stuff_ids:
+                 instance_map[instance_map == stuff_id] = 0
+
+         return DSample(image, instance_map, objects_ids=instances_ids)
+
+     @classmethod
+     def get_image_name(cls, panoptic_name):
+         return panoptic_name.replace('.png', '.jpg')
isegm/data/datasets/coco_lvis.py ADDED
@@ -0,0 +1,67 @@
+ from pathlib import Path
+ import pickle
+ import random
+ import numpy as np
+ import json
+ import cv2
+ from copy import deepcopy
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class CocoLvisDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', stuff_prob=0.0,
+                  allow_list_name=None, anno_file='hannotation.pickle', **kwargs):
+         super(CocoLvisDataset, self).__init__(**kwargs)
+         dataset_path = Path(dataset_path)
+         self._split_path = dataset_path / split
+         self.split = split
+         self._images_path = self._split_path / 'images'
+         self._masks_path = self._split_path / 'masks'
+         self.stuff_prob = stuff_prob
+
+         with open(self._split_path / anno_file, 'rb') as f:
+             self.dataset_samples = sorted(pickle.load(f).items())
+
+         if allow_list_name is not None:
+             allow_list_path = self._split_path / allow_list_name
+             with open(allow_list_path, 'r') as f:
+                 allow_images_ids = json.load(f)
+                 allow_images_ids = set(allow_images_ids)
+
+             self.dataset_samples = [sample for sample in self.dataset_samples
+                                     if sample[0] in allow_images_ids]
+
+     def get_sample(self, index) -> DSample:
+         image_id, sample = self.dataset_samples[index]
+         image_path = self._images_path / f'{image_id}.jpg'
+
+         image = cv2.imread(str(image_path))
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         packed_masks_path = self._masks_path / f'{image_id}.pickle'
+         with open(packed_masks_path, 'rb') as f:
+             encoded_layers, objs_mapping = pickle.load(f)
+         layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
+         layers = np.stack(layers, axis=2)
+
+         instances_info = deepcopy(sample['hierarchy'])
+         for inst_id, inst_info in list(instances_info.items()):
+             if inst_info is None:
+                 inst_info = {'children': [], 'parent': None, 'node_level': 0}
+                 instances_info[inst_id] = inst_info
+             inst_info['mapping'] = objs_mapping[inst_id]
+
+         if self.stuff_prob > 0 and random.random() < self.stuff_prob:
+             for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
+                 instances_info[inst_id] = {
+                     'mapping': objs_mapping[inst_id],
+                     'parent': None,
+                     'children': []
+                 }
+         else:
+             for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
+                 layer_indx, mask_id = objs_mapping[inst_id]
+                 layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0
+
+         return DSample(image, layers, objects=instances_info)
isegm/data/datasets/davis.py ADDED
@@ -0,0 +1,33 @@
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class DavisDataset(ISDataset):
+     def __init__(self, dataset_path,
+                  images_dir_name='img', masks_dir_name='gt',
+                  **kwargs):
+         super(DavisDataset, self).__init__(**kwargs)
+
+         self.dataset_path = Path(dataset_path)
+         self._images_path = self.dataset_path / images_dir_name
+         self._insts_path = self.dataset_path / masks_dir_name
+
+         self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
+         self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
+
+     def get_sample(self, index) -> DSample:
+         image_name = self.dataset_samples[index]
+         image_path = str(self._images_path / image_name)
+         mask_path = str(self._masks_paths[image_name.split('.')[0]])
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2)
+         instances_mask[instances_mask > 0] = 1
+
+         return DSample(image, instances_mask, objects_ids=[1], sample_id=index)
isegm/data/datasets/grabcut.py ADDED
@@ -0,0 +1,34 @@
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class GrabCutDataset(ISDataset):
+     def __init__(self, dataset_path,
+                  images_dir_name='data_GT', masks_dir_name='boundary_GT',
+                  **kwargs):
+         super(GrabCutDataset, self).__init__(**kwargs)
+
+         self.dataset_path = Path(dataset_path)
+         self._images_path = self.dataset_path / images_dir_name
+         self._insts_path = self.dataset_path / masks_dir_name
+
+         self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
+         self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}
+
+     def get_sample(self, index) -> DSample:
+         image_name = self.dataset_samples[index]
+         image_path = str(self._images_path / image_name)
+         mask_path = str(self._masks_paths[image_name.split('.')[0]])
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32)
+         instances_mask[instances_mask == 128] = -1
+         instances_mask[instances_mask > 128] = 1
+
+         return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index)
isegm/data/datasets/images_dir.py ADDED
@@ -0,0 +1,59 @@
+ import cv2
+ import numpy as np
+ from pathlib import Path
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class ImagesDirDataset(ISDataset):
+     def __init__(self, dataset_path,
+                  images_dir_name='images', masks_dir_name='masks',
+                  **kwargs):
+         super(ImagesDirDataset, self).__init__(**kwargs)
+
+         self.dataset_path = Path(dataset_path)
+         self._images_path = self.dataset_path / images_dir_name
+         self._insts_path = self.dataset_path / masks_dir_name
+
+         images_list = [x for x in sorted(self._images_path.glob('*.*'))]
+
+         samples = {x.stem: {'image': x, 'masks': []} for x in images_list}
+         for mask_path in self._insts_path.glob('*.*'):
+             mask_name = mask_path.stem
+             if mask_name in samples:
+                 samples[mask_name]['masks'].append(mask_path)
+                 continue
+
+             mask_name_split = mask_name.split('_')
+             if mask_name_split[-1].isdigit():
+                 mask_name = '_'.join(mask_name_split[:-1])
+                 assert mask_name in samples
+                 samples[mask_name]['masks'].append(mask_path)
+
+         for x in samples.values():
+             assert len(x['masks']) > 0, x['image']
+
+         self.dataset_samples = [v for k, v in sorted(samples.items())]
+
+     def get_sample(self, index) -> DSample:
+         sample = self.dataset_samples[index]
+         image_path = str(sample['image'])
+
+         objects = []
+         ignored_regions = []
+         masks = []
+         for indx, mask_path in enumerate(sample['masks']):
+             gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32)
+             instances_mask = np.zeros_like(gt_mask)
+             instances_mask[gt_mask == 128] = 2
+             instances_mask[gt_mask > 128] = 1
+             masks.append(instances_mask)
+             objects.append((indx, 1))
+             ignored_regions.append((indx, 2))
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         return DSample(image, np.stack(masks, axis=2),
+                        objects_ids=objects, ignore_ids=ignored_regions, sample_id=index)
isegm/data/datasets/lvis.py ADDED
@@ -0,0 +1,97 @@
+ import json
+ import random
+ from collections import defaultdict
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class LvisDataset(ISDataset):
+     def __init__(self, dataset_path, split='train',
+                  max_overlap_ratio=0.5,
+                  **kwargs):
+         super(LvisDataset, self).__init__(**kwargs)
+         dataset_path = Path(dataset_path)
+         train_categories_path = dataset_path / 'train_categories.json'
+         self._train_path = dataset_path / 'train'
+         self._val_path = dataset_path / 'val'
+
+         self.split = split
+         self.max_overlap_ratio = max_overlap_ratio
+
+         with open(dataset_path / split / f'lvis_{self.split}.json', 'r') as f:
+             json_annotation = json.loads(f.read())
+
+         self.annotations = defaultdict(list)
+         for x in json_annotation['annotations']:
+             self.annotations[x['image_id']].append(x)
+
+         if not train_categories_path.exists():
+             self.generate_train_categories(dataset_path, train_categories_path)
+         self.dataset_samples = [x for x in json_annotation['images']
+                                 if len(self.annotations[x['id']]) > 0]
+
+     def get_sample(self, index) -> DSample:
+         image_info = self.dataset_samples[index]
+         image_id, image_url = image_info['id'], image_info['coco_url']
+         image_filename = image_url.split('/')[-1]
+         image_annotations = self.annotations[image_id]
+         random.shuffle(image_annotations)
+
+         # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017)
+         if 'train2017' in image_url:
+             image_path = self._train_path / 'images' / image_filename
+         else:
+             image_path = self._val_path / 'images' / image_filename
+         image = cv2.imread(str(image_path))
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         instances_mask = None
+         instances_area = defaultdict(int)
+         objects_ids = []
+         for indx, obj_annotation in enumerate(image_annotations):
+             mask = self.get_mask_from_polygon(obj_annotation, image)
+             object_mask = mask > 0
+             object_area = object_mask.sum()
+
+             if instances_mask is None:
+                 instances_mask = np.zeros_like(object_mask, dtype=np.int32)
+
+             overlap_ids = np.bincount(instances_mask[object_mask].flatten())
+             overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids)
+                              if overlap_area > 0 and inst_id > 0]
+             overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area
+             if overlap_areas:
+                 overlap_ratio = max(overlap_ratio, max(overlap_areas))
+             if overlap_ratio > self.max_overlap_ratio:
+                 continue
+
+             instance_id = indx + 1
+             instances_mask[object_mask] = instance_id
+             instances_area[instance_id] = object_area
+             objects_ids.append(instance_id)
+
+         return DSample(image, instances_mask, objects_ids=objects_ids)
+
+     @staticmethod
+     def get_mask_from_polygon(annotation, image):
+         mask = np.zeros(image.shape[:2], dtype=np.int32)
+         for contour_points in annotation['segmentation']:
+             contour_points = np.array(contour_points).reshape((-1, 2))
+             contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
+             cv2.fillPoly(mask, contour_points, 1)
+
+         return mask
+
+     @staticmethod
+     def generate_train_categories(dataset_path, train_categories_path):
+         with open(dataset_path / 'train/lvis_train.json', 'r') as f:
+             annotation = json.load(f)
+
+         with open(train_categories_path, 'w') as f:
+             json.dump(annotation['categories'], f, indent=1)
isegm/data/datasets/openimages.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import random
+ import pickle as pkl
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class OpenImagesDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', **kwargs):
+         super().__init__(**kwargs)
+         assert split in {'train', 'val', 'test'}
+
+         self.dataset_path = Path(dataset_path)
+         self._split_path = self.dataset_path / split
+         self._images_path = self._split_path / 'images'
+         self._masks_path = self._split_path / 'masks'
+         self.dataset_split = split
+
+         clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl'
+         if os.path.exists(clean_anno_path):
+             with clean_anno_path.open('rb') as f:
+                 annotations = pkl.load(f)
+         else:
+             raise RuntimeError(f"Can't find annotations at {clean_anno_path}")
+         self.image_id_to_masks = annotations['image_id_to_masks']
+         self.dataset_samples = annotations['dataset_samples']
+
+     def get_sample(self, index) -> DSample:
+         image_id = self.dataset_samples[index]
+
+         image_path = str(self._images_path / f'{image_id}.jpg')
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         mask_paths = self.image_id_to_masks[image_id]
+         # select random mask for an image
+         mask_path = str(self._masks_path / random.choice(mask_paths))
+         instances_mask = cv2.imread(mask_path)
+         instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY)
+         instances_mask[instances_mask > 0] = 1
+         instances_mask = instances_mask.astype(np.int32)
+
+         min_width = min(image.shape[1], instances_mask.shape[1])
+         min_height = min(image.shape[0], instances_mask.shape[0])
+
+         if image.shape[0] != min_height or image.shape[1] != min_width:
+             image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR)
+         if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width:
+             instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST)
+
+         object_ids = [1] if instances_mask.sum() > 0 else []
+
+         return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)
isegm/data/datasets/pascalvoc.py ADDED
@@ -0,0 +1,48 @@
+ import pickle as pkl
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class PascalVocDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', **kwargs):
+         super().__init__(**kwargs)
+         assert split in {'train', 'val', 'trainval', 'test'}
+
+         self.dataset_path = Path(dataset_path)
+         self._images_path = self.dataset_path / "JPEGImages"
+         self._insts_path = self.dataset_path / "SegmentationObject"
+         self.dataset_split = split
+
+         if split == 'test':
+             with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f:
+                 self.dataset_samples, self.instance_ids = pkl.load(f)
+         else:
+             with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f:
+                 self.dataset_samples = [name.strip() for name in f.readlines()]
+
+     def get_sample(self, index) -> DSample:
+         sample_id = self.dataset_samples[index]
+         image_path = str(self._images_path / f'{sample_id}.jpg')
+         mask_path = str(self._insts_path / f'{sample_id}.png')
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         instances_mask = cv2.imread(mask_path)
+         instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
+         if self.dataset_split == 'test':
+             instance_id = self.instance_ids[index]
+             mask = np.zeros_like(instances_mask)
+             mask[instances_mask == 220] = 220  # ignored area
+             mask[instances_mask == instance_id] = 1
+             objects_ids = [1]
+             instances_mask = mask
+         else:
+             objects_ids = np.unique(instances_mask)
+             objects_ids = [x for x in objects_ids if x != 0 and x != 220]
+
+         return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
isegm/data/datasets/sbd.py ADDED
@@ -0,0 +1,111 @@
+ import pickle as pkl
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ from scipy.io import loadmat
+
+ from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
+ from isegm.data.base import ISDataset
+ from isegm.data.sample import DSample
+
+
+ class SBDDataset(ISDataset):
+     def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs):
+         super(SBDDataset, self).__init__(**kwargs)
+         assert split in {'train', 'val'}
+
+         self.dataset_path = Path(dataset_path)
+         self.dataset_split = split
+         self._images_path = self.dataset_path / 'img'
+         self._insts_path = self.dataset_path / 'inst'
+         self._buggy_objects = dict()
+         self._buggy_mask_thresh = buggy_mask_thresh
+
+         with open(self.dataset_path / f'{split}.txt', 'r') as f:
+             self.dataset_samples = [x.strip() for x in f.readlines()]
+
+     def get_sample(self, index):
+         image_name = self.dataset_samples[index]
+         image_path = str(self._images_path / f'{image_name}.jpg')
+         inst_info_path = str(self._insts_path / f'{image_name}.mat')
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
+         instances_mask = self.remove_buggy_masks(index, instances_mask)
+         instances_ids, _ = get_labels_with_sizes(instances_mask)
+
+         return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index)
+
+     def remove_buggy_masks(self, index, instances_mask):
+         if self._buggy_mask_thresh > 0.0:
+             buggy_image_objects = self._buggy_objects.get(index, None)
+             if buggy_image_objects is None:
+                 buggy_image_objects = []
+                 instances_ids, _ = get_labels_with_sizes(instances_mask)
+                 for obj_id in instances_ids:
+                     obj_mask = instances_mask == obj_id
+                     mask_area = obj_mask.sum()
+                     bbox = get_bbox_from_mask(obj_mask)
+                     bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
+                     obj_area_ratio = mask_area / bbox_area
+                     if obj_area_ratio < self._buggy_mask_thresh:
+                         buggy_image_objects.append(obj_id)
+
+                 self._buggy_objects[index] = buggy_image_objects
+             for obj_id in buggy_image_objects:
+                 instances_mask[instances_mask == obj_id] = 0
+
+         return instances_mask
+
+
+ class SBDEvaluationDataset(ISDataset):
+     def __init__(self, dataset_path, split='val', **kwargs):
+         super(SBDEvaluationDataset, self).__init__(**kwargs)
+         assert split in {'train', 'val'}
+
+         self.dataset_path = Path(dataset_path)
+         self.dataset_split = split
+         self._images_path = self.dataset_path / 'img'
+         self._insts_path = self.dataset_path / 'inst'
+
+         with open(self.dataset_path / f'{split}.txt', 'r') as f:
+             self.dataset_samples = [x.strip() for x in f.readlines()]
+
+         self.dataset_samples = self.get_sbd_images_and_ids_list()
+
+     def get_sample(self, index) -> DSample:
+         image_name, instance_id = self.dataset_samples[index]
+         image_path = str(self._images_path / f'{image_name}.jpg')
+         inst_info_path = str(self._insts_path / f'{image_name}.mat')
+
+         image = cv2.imread(image_path)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
+         instances_mask[instances_mask != instance_id] = 0
+         instances_mask[instances_mask > 0] = 1
+
+         return DSample(image, instances_mask, objects_ids=[1], sample_id=index)
+
+     def get_sbd_images_and_ids_list(self):
+         pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl'
+
+         if pkl_path.exists():
+             with open(str(pkl_path), 'rb') as fp:
+                 images_and_ids_list = pkl.load(fp)
+         else:
+             images_and_ids_list = []
+
+             for sample in self.dataset_samples:
+                 inst_info_path = str(self._insts_path / f'{sample}.mat')
+                 instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
+                 instances_ids, _ = get_labels_with_sizes(instances_mask)
+
+                 for instances_id in instances_ids:
+                     images_and_ids_list.append((sample, instances_id))
+
+             with open(str(pkl_path), 'wb') as fp:
+                 pkl.dump(images_and_ids_list, fp)
+
+         return images_and_ids_list
isegm/data/points_sampler.py ADDED
@@ -0,0 +1,305 @@
+ import cv2
+ import math
+ import random
+ import numpy as np
+ from functools import lru_cache
+ from .sample import DSample
+
+
+ class BasePointSampler:
+     def __init__(self):
+         self._selected_mask = None
+         self._selected_masks = None
+
+     def sample_object(self, sample: DSample):
+         raise NotImplementedError
+
+     def sample_points(self):
+         raise NotImplementedError
+
+     @property
+     def selected_mask(self):
+         assert self._selected_mask is not None
+         return self._selected_mask
+
+     @selected_mask.setter
+     def selected_mask(self, mask):
+         self._selected_mask = mask[np.newaxis, :].astype(np.float32)
+
+
+ class MultiPointSampler(BasePointSampler):
+     def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1,
+                  positive_erode_prob=0.9, positive_erode_iters=3,
+                  negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5,
+                  merge_objects_prob=0.0, max_num_merged_objects=2,
+                  use_hierarchy=False, soft_targets=False,
+                  first_click_center=False, only_one_first_click=False,
+                  sfc_inner_k=1.7, sfc_full_inner_prob=0.0):
+         super().__init__()
+         self.max_num_points = max_num_points
+         self.expand_ratio = expand_ratio
+         self.positive_erode_prob = positive_erode_prob
+         self.positive_erode_iters = positive_erode_iters
+         self.merge_objects_prob = merge_objects_prob
+         self.use_hierarchy = use_hierarchy
+         self.soft_targets = soft_targets
+         self.first_click_center = first_click_center
+         self.only_one_first_click = only_one_first_click
+         self.sfc_inner_k = sfc_inner_k
+         self.sfc_full_inner_prob = sfc_full_inner_prob
+
+         if max_num_merged_objects == -1:
+             max_num_merged_objects = max_num_points
+         self.max_num_merged_objects = max_num_merged_objects
+
+         self.neg_strategies = ['bg', 'other', 'border']
+         self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob]
+         assert math.isclose(sum(self.neg_strategies_prob), 1.0)
+
+         self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
+         self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma)
+         self._neg_masks = None
+
+     def sample_object(self, sample: DSample):
+         if len(sample) == 0:
+             bg_mask = sample.get_background_mask()
+             self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
+             self._selected_masks = [[]]
+             self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
+             self._neg_masks['required'] = []
+             return
+
+         gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
+         binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0
+
+         self.selected_mask = gt_mask
+         self._selected_masks = pos_masks
+
+         neg_mask_bg = np.logical_not(binary_gt_mask)
+         neg_mask_border = self._get_border_mask(binary_gt_mask)
+         if len(sample) <= len(self._selected_masks):
+             neg_mask_other = neg_mask_bg
+         else:
+             neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()),
+                                             np.logical_not(binary_gt_mask))
+
+         self._neg_masks = {
+             'bg': neg_mask_bg,
+             'other': neg_mask_other,
+             'border': neg_mask_border,
+             'required': neg_masks
+         }
+
+     def _sample_mask(self, sample: DSample):
+         root_obj_ids = sample.root_objects
+
+         if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob:
+             max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects)
+             num_selected_objects = np.random.randint(2, max_selected_objects + 1)
+             random_ids = random.sample(root_obj_ids, num_selected_objects)
+         else:
+             random_ids = [random.choice(root_obj_ids)]
+
+         gt_mask = None
+         pos_segments = []
+         neg_segments = []
+         for obj_id in random_ids:
+             obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample)
+             if gt_mask is None:
+                 gt_mask = obj_gt_mask
+             else:
+                 gt_mask = np.maximum(gt_mask, obj_gt_mask)
+
+             pos_segments.extend(obj_pos_segments)
+             neg_segments.extend(obj_neg_segments)
+
+         pos_masks = [self._positive_erode(x) for x in pos_segments]
+         neg_masks = [self._positive_erode(x) for x in neg_segments]
+
+         return gt_mask, pos_masks, neg_masks
+
+     def _sample_from_masks_layer(self, obj_id, sample: DSample):
+         objs_tree = sample._objects
+
+         if not self.use_hierarchy:
+             node_mask = sample.get_object_mask(obj_id)
+             gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
+             return gt_mask, [node_mask], []
+
+         def _select_node(node_id):
+             node_info = objs_tree[node_id]
+             if not node_info['children'] or random.random() < 0.5:
+                 return node_id
+             return _select_node(random.choice(node_info['children']))
+
+         selected_node = _select_node(obj_id)
+         node_info = objs_tree[selected_node]
+         node_mask = sample.get_object_mask(selected_node)
+         gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask
+         pos_mask = node_mask.copy()
+
+         negative_segments = []
+         if node_info['parent'] is not None and node_info['parent'] in objs_tree:
+             parent_mask = sample.get_object_mask(node_info['parent'])
+             negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask)))
+
+         for child_id in node_info['children']:
+             if objs_tree[child_id]['area'] / node_info['area'] < 0.10:
+                 child_mask = sample.get_object_mask(child_id)
+                 pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
+
+         if node_info['children']:
+             max_disabled_children = min(len(node_info['children']), 3)
+             num_disabled_children = np.random.randint(0, max_disabled_children + 1)
+             disabled_children = random.sample(node_info['children'], num_disabled_children)
+
+             for child_id in disabled_children:
+                 child_mask = sample.get_object_mask(child_id)
+                 pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
+                 if self.soft_targets:
+                     soft_child_mask = sample.get_soft_object_mask(child_id)
+                     gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask)
+                 else:
+                     gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask))
+                 negative_segments.append(child_mask)
+
+         return gt_mask, [pos_mask], negative_segments
+
+     def sample_points(self):
+         assert self._selected_mask is not None
+         pos_points = self._multi_mask_sample_points(self._selected_masks,
+                                                     is_negative=[False] * len(self._selected_masks),
+                                                     with_first_click=self.first_click_center)
+
+         neg_strategy = [(self._neg_masks[k], prob)
+                         for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)]
+         neg_masks = self._neg_masks['required'] + [neg_strategy]
+         neg_points = self._multi_mask_sample_points(neg_masks,
+                                                     is_negative=[False] * len(self._neg_masks['required']) + [True])
+
+         return pos_points + neg_points
+
+     def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
+         selected_masks = selected_masks[:self.max_num_points]
+
+         each_obj_points = [
+             self._sample_points(mask, is_negative=is_negative[i],
+                                 with_first_click=with_first_click)
+             for i, mask in enumerate(selected_masks)
+         ]
+         each_obj_points = [x for x in each_obj_points if len(x) > 0]
+
+         points = []
+         if len(each_obj_points) == 1:
+             points = each_obj_points[0]
+         elif len(each_obj_points) > 1:
+             if self.only_one_first_click:
+                 each_obj_points = each_obj_points[:1]
+
+             points = [obj_points[0] for obj_points in each_obj_points]
+
+             aggregated_masks_with_prob = []
+             for indx, x in enumerate(selected_masks):
+                 if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
+                     for t, prob in x:
+                         aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
+                 else:
+                     aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
+
+             other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
+             if len(other_points_union) + len(points) <= self.max_num_points:
+                 points.extend(other_points_union)
+             else:
+                 points.extend(random.sample(other_points_union, self.max_num_points - len(points)))
+
+         if len(points) < self.max_num_points:
+             points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
+
+         return points
+
+     def _sample_points(self, mask, is_negative=False, with_first_click=False):
+         if is_negative:
+             num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs)
+         else:
+             num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)
+
+         indices_probs = None
+         if isinstance(mask, (list, tuple)):
+             indices_probs = [x[1] for x in mask]
+             indices = [(np.argwhere(x), prob) for x, prob in mask]
+             if indices_probs:
+                 assert math.isclose(sum(indices_probs), 1.0)
+         else:
+             indices = np.argwhere(mask)
+
+         points = []
+         for j in range(num_points):
+             first_click = with_first_click and j == 0 and indices_probs is None
+
+             if first_click:
+                 point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob)
+             elif indices_probs:
+                 point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs)
+                 point_indices = indices[point_indices_indx][0]
+             else:
+                 point_indices = indices
+
+             num_indices = len(point_indices)
+             if num_indices > 0:
+                 point_indx = 0 if first_click else 100
+                 click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx]
+                 points.append(click)
+
+         return points
+
+     def _positive_erode(self, mask):
+         if random.random() > self.positive_erode_prob:
+             return mask
+
+         kernel = np.ones((3, 3), np.uint8)
+         eroded_mask = cv2.erode(mask.astype(np.uint8),
+                                 kernel, iterations=self.positive_erode_iters).astype(np.bool)
+
+         if eroded_mask.sum() > 10:
+             return eroded_mask
+         else:
+             return mask
+
+     def _get_border_mask(self, mask):
+         expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum())))
+         kernel = np.ones((3, 3), np.uint8)
+         expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r)
+         expanded_mask[mask.astype(np.bool)] = 0
+         return expanded_mask
+
+
+ @lru_cache(maxsize=None)
+ def generate_probs(max_num_points, gamma):
+     probs = []
+     last_value = 1
+     for i in range(max_num_points):
+         probs.append(last_value)
+         last_value *= gamma
+
+     probs = np.array(probs)
+     probs /= probs.sum()
+
+     return probs
+
+
+ def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
+     if full_prob > 0 and random.random() < full_prob:
+         return obj_mask
+
+     padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant')
+
+     dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
+     if k > 0:
+         inner_mask = dt > dt.max() / k
+         return np.argwhere(inner_mask)
+     else:
+         prob_map = dt.flatten()
+         prob_map /= max(prob_map.sum(), 1e-6)
+         click_indx = np.random.choice(len(prob_map), p=prob_map)
+         click_coords = np.unravel_index(click_indx, dt.shape)
+         return np.array([click_coords])
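For reference, a small sketch (not part of the commit) of how MultiPointSampler is driven, mirroring what ISDataset.__getitem__ does above; it assumes a NumPy version where the np.bool alias this file still uses is available:

import numpy as np

from isegm.data.points_sampler import MultiPointSampler
from isegm.data.sample import DSample

# Dummy sample: one 128x128 square object (id 1) in a 256x256 image.
image = np.zeros((256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.int32)
mask[64:192, 64:192] = 1
sample = DSample(image, mask, objects_ids=[1], sample_id=0)

sampler = MultiPointSampler(max_num_points=12, first_click_center=True)
sampler.sample_object(sample)                # picks the object and precomputes negative-click regions
points = np.array(sampler.sample_points())   # (2 * max_num_points, 3): positive then negative (y, x, index),
                                             # padded with (-1, -1, -1)
gt_mask = sampler.selected_mask              # float32 target mask with a leading channel axis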
isegm/data/sample.py ADDED
@@ -0,0 +1,148 @@
+ import numpy as np
+ from copy import deepcopy
+ from isegm.utils.misc import get_labels_with_sizes
+ from isegm.data.transforms import remove_image_only_transforms
+ from albumentations import ReplayCompose
+
+
+ class DSample:
+     def __init__(self, image, encoded_masks, objects=None,
+                  objects_ids=None, ignore_ids=None, sample_id=None):
+         self.image = image
+         self.sample_id = sample_id
+
+         if len(encoded_masks.shape) == 2:
+             encoded_masks = encoded_masks[:, :, np.newaxis]
+         self._encoded_masks = encoded_masks
+         self._ignored_regions = []
+
+         if objects_ids is not None:
+             if not objects_ids or not isinstance(objects_ids[0], tuple):
+                 assert encoded_masks.shape[2] == 1
+                 objects_ids = [(0, obj_id) for obj_id in objects_ids]
+
+             self._objects = dict()
+             for indx, obj_mapping in enumerate(objects_ids):
+                 self._objects[indx] = {
+                     'parent': None,
+                     'mapping': obj_mapping,
+                     'children': []
+                 }
+
+             if ignore_ids:
+                 if isinstance(ignore_ids[0], tuple):
+                     self._ignored_regions = ignore_ids
+                 else:
+                     self._ignored_regions = [(0, region_id) for region_id in ignore_ids]
+         else:
+             self._objects = deepcopy(objects)
+
+         self._augmented = False
+         self._soft_mask_aug = None
+         self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)
+
+     def augment(self, augmentator):
+         self.reset_augmentation()
+         aug_output = augmentator(image=self.image, mask=self._encoded_masks)
+         self.image = aug_output['image']
+         self._encoded_masks = aug_output['mask']
+
+         aug_replay = aug_output.get('replay', None)
+         if aug_replay:
+             assert len(self._ignored_regions) == 0
+             mask_replay = remove_image_only_transforms(aug_replay)
+             self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)
+
+         self._compute_objects_areas()
+         self.remove_small_objects(min_area=1)
+
+         self._augmented = True
+
+     def reset_augmentation(self):
+         if not self._augmented:
+             return
+         orig_image, orig_masks, orig_objects = self._original_data
+         self.image = orig_image
+         self._encoded_masks = orig_masks
+         self._objects = deepcopy(orig_objects)
+         self._augmented = False
+         self._soft_mask_aug = None
+
+     def remove_small_objects(self, min_area):
+         if self._objects and not 'area' in list(self._objects.values())[0]:
+             self._compute_objects_areas()
+
+         for obj_id, obj_info in list(self._objects.items()):
+             if obj_info['area'] < min_area:
+                 self._remove_object(obj_id)
+
+     def get_object_mask(self, obj_id):
+         layer_indx, mask_id = self._objects[obj_id]['mapping']
+         obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
+         if self._ignored_regions:
+             for layer_indx, mask_id in self._ignored_regions:
+                 ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id
+                 obj_mask[ignore_mask] = -1
+
+         return obj_mask
+
+     def get_soft_object_mask(self, obj_id):
+         assert self._soft_mask_aug is not None
+         original_encoded_masks = self._original_data[1]
+         layer_indx, mask_id = self._objects[obj_id]['mapping']
+         obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32)
+         obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image']
+         return np.clip(obj_mask, 0, 1)
+
+     def get_background_mask(self):
+         return np.max(self._encoded_masks, axis=2) == 0
+
+     @property
+     def objects_ids(self):
+         return list(self._objects.keys())
+
+     @property
+     def gt_mask(self):
+         assert len(self._objects) == 1
+         return self.get_object_mask(self.objects_ids[0])
+
+     @property
+     def root_objects(self):
+         return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None]
+
+     def _compute_objects_areas(self):
+         inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()}
+         ignored_regions_keys = set(self._ignored_regions)
+
+         for layer_indx in range(self._encoded_masks.shape[2]):
+             objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx])
+             for obj_id, obj_area in zip(objects_ids, objects_areas):
+                 inv_key = (layer_indx, obj_id)
+                 if inv_key in ignored_regions_keys:
+                     continue
+                 try:
+                     self._objects[inverse_index[inv_key]]['area'] = obj_area
+                     del inverse_index[inv_key]
+                 except KeyError:
+                     layer = self._encoded_masks[:, :, layer_indx]
+                     layer[layer == obj_id] = 0
+                     self._encoded_masks[:, :, layer_indx] = layer
+
+         for obj_id in inverse_index.values():
+             self._objects[obj_id]['area'] = 0
+
+     def _remove_object(self, obj_id):
+         obj_info = self._objects[obj_id]
+         obj_parent = obj_info['parent']
+         for child_id in obj_info['children']:
+             self._objects[child_id]['parent'] = obj_parent
+
+         if obj_parent is not None:
+             parent_children = self._objects[obj_parent]['children']
+             parent_children = [x for x in parent_children if x != obj_id]
+             self._objects[obj_parent]['children'] = parent_children + obj_info['children']
+
+         del self._objects[obj_id]
+
+     def __len__(self):
+         return len(self._objects)
isegm/data/transforms.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import random
3
+ import numpy as np
4
+
5
+ from albumentations.core.serialization import SERIALIZABLE_REGISTRY
6
+ from albumentations import ImageOnlyTransform, DualTransform
7
+ from albumentations.core.transforms_interface import to_tuple
8
+ from albumentations.augmentations import functional as F
9
+ from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes
10
+
11
+
12
+ class UniformRandomResize(DualTransform):
13
+ def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
14
+ super().__init__(always_apply, p)
15
+ self.scale_range = scale_range
16
+ self.interpolation = interpolation
17
+
18
+ def get_params_dependent_on_targets(self, params):
19
+ scale = random.uniform(*self.scale_range)
20
+ height = int(round(params['image'].shape[0] * scale))
21
+ width = int(round(params['image'].shape[1] * scale))
22
+ return {'new_height': height, 'new_width': width}
23
+
24
+ def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params):
25
+ return F.resize(img, height=new_height, width=new_width, interpolation=interpolation)
26
+
27
+ def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
28
+ scale_x = new_width / params["cols"]
29
+ scale_y = new_height / params["rows"]
30
+ return F.keypoint_scale(keypoint, scale_x, scale_y)
31
+
32
+ def get_transform_init_args_names(self):
33
+ return "scale_range", "interpolation"
34
+
35
+ @property
36
+ def targets_as_params(self):
37
+ return ["image"]
38
+
39
+
40
+ class ZoomIn(DualTransform):
41
+ def __init__(
42
+ self,
43
+ height,
44
+ width,
45
+ bbox_jitter=0.1,
46
+ expansion_ratio=1.4,
47
+ min_crop_size=200,
48
+ min_area=100,
49
+ always_resize=False,
50
+ always_apply=False,
51
+ p=0.5,
52
+ ):
53
+ super(ZoomIn, self).__init__(always_apply, p)
54
+ self.height = height
55
+ self.width = width
56
+ self.bbox_jitter = to_tuple(bbox_jitter)
57
+ self.expansion_ratio = expansion_ratio
58
+ self.min_crop_size = min_crop_size
59
+ self.min_area = min_area
60
+ self.always_resize = always_resize
61
+
62
+ def apply(self, img, selected_object, bbox, **params):
63
+ if selected_object is None:
64
+ if self.always_resize:
65
+ img = F.resize(img, height=self.height, width=self.width)
66
+ return img
67
+
68
+ rmin, rmax, cmin, cmax = bbox
69
+ img = img[rmin:rmax + 1, cmin:cmax + 1]
70
+ img = F.resize(img, height=self.height, width=self.width)
71
+
72
+ return img
73
+
74
+ def apply_to_mask(self, mask, selected_object, bbox, **params):
75
+ if selected_object is None:
76
+ if self.always_resize:
77
+ mask = F.resize(mask, height=self.height, width=self.width,
78
+ interpolation=cv2.INTER_NEAREST)
79
+ return mask
80
+
81
+ rmin, rmax, cmin, cmax = bbox
82
+ mask = mask[rmin:rmax + 1, cmin:cmax + 1]
83
+ if isinstance(selected_object, tuple):
84
+ layer_indx, mask_id = selected_object
85
+ obj_mask = mask[:, :, layer_indx] == mask_id
86
+ new_mask = np.zeros_like(mask)
87
+ new_mask[:, :, layer_indx][obj_mask] = mask_id
88
+ else:
89
+ obj_mask = mask == selected_object
90
+ new_mask = mask.copy()
91
+ new_mask[np.logical_not(obj_mask)] = 0
92
+
93
+ new_mask = F.resize(new_mask, height=self.height, width=self.width,
94
+ interpolation=cv2.INTER_NEAREST)
95
+ return new_mask
96
+
97
+ def get_params_dependent_on_targets(self, params):
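+ # Pick a random object whose area exceeds min_area, then build the crop box around it: expand by expansion_ratio, jitter by bbox_jitter, and clamp to the image bounds.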
98
+ instances = params['mask']
99
+
100
+ is_mask_layer = len(instances.shape) > 2
101
+ candidates = []
102
+ if is_mask_layer:
103
+ for layer_indx in range(instances.shape[2]):
104
+ labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
105
+ candidates.extend([(layer_indx, obj_id)
106
+ for obj_id, area in zip(labels, areas)
107
+ if area > self.min_area])
108
+ else:
109
+ labels, areas = get_labels_with_sizes(instances)
110
+ candidates = [obj_id for obj_id, area in zip(labels, areas)
111
+ if area > self.min_area]
112
+
113
+ selected_object = None
114
+ bbox = None
115
+ if candidates:
116
+ selected_object = random.choice(candidates)
117
+ if is_mask_layer:
118
+ layer_indx, mask_id = selected_object
119
+ obj_mask = instances[:, :, layer_indx] == mask_id
120
+ else:
121
+ obj_mask = instances == selected_object
122
+
123
+ bbox = get_bbox_from_mask(obj_mask)
124
+
125
+ if isinstance(self.expansion_ratio, tuple):
126
+ expansion_ratio = random.uniform(*self.expansion_ratio)
127
+ else:
128
+ expansion_ratio = self.expansion_ratio
129
+
130
+ bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size)
131
+ bbox = self._jitter_bbox(bbox)
132
+ bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
133
+
134
+ return {
135
+ 'selected_object': selected_object,
136
+ 'bbox': bbox
137
+ }
138
+
139
+ def _jitter_bbox(self, bbox):
140
+ rmin, rmax, cmin, cmax = bbox
141
+ height = rmax - rmin + 1
142
+ width = cmax - cmin + 1
143
+ rmin = int(rmin + random.uniform(*self.bbox_jitter) * height)
144
+ rmax = int(rmax + random.uniform(*self.bbox_jitter) * height)
145
+ cmin = int(cmin + random.uniform(*self.bbox_jitter) * width)
146
+ cmax = int(cmax + random.uniform(*self.bbox_jitter) * width)
147
+
148
+ return rmin, rmax, cmin, cmax
149
+
150
+ def apply_to_bbox(self, bbox, **params):
151
+ raise NotImplementedError
152
+
153
+ def apply_to_keypoint(self, keypoint, **params):
154
+ raise NotImplementedError
155
+
156
+ @property
157
+ def targets_as_params(self):
158
+ return ["mask"]
159
+
160
+ def get_transform_init_args_names(self):
161
+ return ("height", "width", "bbox_jitter",
162
+ "expansion_ratio", "min_crop_size", "min_area", "always_resize")
163
+
164
+
165
+ def remove_image_only_transforms(sdict):
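+ # Recursively drop ImageOnlyTransform entries from a serialized albumentations pipeline, keeping only transforms that also affect masks.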
166
+ if 'transforms' not in sdict:
167
+ return sdict
168
+
169
+ keep_transforms = []
170
+ for tdict in sdict['transforms']:
171
+ cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']]
172
+ if 'transforms' in tdict:
173
+ keep_transforms.append(remove_image_only_transforms(tdict))
174
+ elif not issubclass(cls, ImageOnlyTransform):
175
+ keep_transforms.append(tdict)
176
+ sdict['transforms'] = keep_transforms
177
+
178
+ return sdict
isegm/engine/optimizer.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import math
3
+ from isegm.utils.log import logger
4
+
5
+
6
+ def get_optimizer(model, opt_name, opt_kwargs):
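+ # Build one parameter group per parameter so that parameters tagged with an lr_mult attribute get a proportionally scaled learning rate.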
7
+ params = []
8
+ base_lr = opt_kwargs['lr']
9
+ for name, param in model.named_parameters():
10
+ param_group = {'params': [param]}
11
+ if not param.requires_grad:
12
+ params.append(param_group)
13
+ continue
14
+
15
+ if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
16
+ logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
17
+ param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult
18
+
19
+ params.append(param_group)
20
+
21
+ optimizer = {
22
+ 'sgd': torch.optim.SGD,
23
+ 'adam': torch.optim.Adam,
24
+ 'adamw': torch.optim.AdamW
25
+ }[opt_name.lower()](params, **opt_kwargs)
26
+
27
+ return optimizer
isegm/engine/trainer.py ADDED
@@ -0,0 +1,413 @@
1
+ import os
2
+ import random
3
+ import logging
4
+ from copy import deepcopy
5
+ from collections import defaultdict
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from torch.utils.data import DataLoader
12
+
13
+ from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg
14
+ from isegm.utils.vis import draw_probmap, draw_points
15
+ from isegm.utils.misc import save_checkpoint
16
+ from isegm.utils.serialization import get_config_repr
17
+ from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict
18
+ from .optimizer import get_optimizer
19
+
20
+
21
+ class ISTrainer(object):
22
+ def __init__(self, model, cfg, model_cfg, loss_cfg,
23
+ trainset, valset,
24
+ optimizer='adam',
25
+ optimizer_params=None,
26
+ image_dump_interval=200,
27
+ checkpoint_interval=10,
28
+ tb_dump_period=25,
29
+ max_interactive_points=0,
30
+ lr_scheduler=None,
31
+ metrics=None,
32
+ additional_val_metrics=None,
33
+ net_inputs=('images', 'points'),
34
+ max_num_next_clicks=0,
35
+ click_models=None,
36
+ prev_mask_drop_prob=0.0,
37
+ ):
38
+ self.cfg = cfg
39
+ self.model_cfg = model_cfg
40
+ self.max_interactive_points = max_interactive_points
41
+ self.loss_cfg = loss_cfg
42
+ self.val_loss_cfg = deepcopy(loss_cfg)
43
+ self.tb_dump_period = tb_dump_period
44
+ self.net_inputs = net_inputs
45
+ self.max_num_next_clicks = max_num_next_clicks
46
+
47
+ self.click_models = click_models
48
+ self.prev_mask_drop_prob = prev_mask_drop_prob
49
+
50
+ if cfg.distributed:
51
+ cfg.batch_size //= cfg.ngpus
52
+ cfg.val_batch_size //= cfg.ngpus
53
+
54
+ if metrics is None:
55
+ metrics = []
56
+ self.train_metrics = metrics
57
+ self.val_metrics = deepcopy(metrics)
58
+ if additional_val_metrics is not None:
59
+ self.val_metrics.extend(additional_val_metrics)
60
+
61
+ self.checkpoint_interval = checkpoint_interval
62
+ self.image_dump_interval = image_dump_interval
63
+ self.task_prefix = ''
64
+ self.sw = None
65
+
66
+ self.trainset = trainset
67
+ self.valset = valset
68
+
69
+ logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.')
70
+ logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.')
71
+
72
+ self.train_data = DataLoader(
73
+ trainset, cfg.batch_size,
74
+ sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
75
+ drop_last=True, pin_memory=True,
76
+ num_workers=cfg.workers
77
+ )
78
+
79
+ self.val_data = DataLoader(
80
+ valset, cfg.val_batch_size,
81
+ sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
82
+ drop_last=True, pin_memory=True,
83
+ num_workers=cfg.workers
84
+ )
85
+
86
+ self.optim = get_optimizer(model, optimizer, optimizer_params)
87
+ model = self._load_weights(model)
88
+
89
+ if cfg.multi_gpu:
90
+ model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids,
91
+ output_device=cfg.gpu_ids[0])
92
+
93
+ if self.is_master:
94
+ logger.info(model)
95
+ logger.info(get_config_repr(model._config))
96
+
97
+ self.device = cfg.device
98
+ self.net = model.to(self.device)
99
+ self.lr = optimizer_params['lr']
100
+
101
+ if lr_scheduler is not None:
102
+ self.lr_scheduler = lr_scheduler(optimizer=self.optim)
103
+ if cfg.start_epoch > 0:
104
+ for _ in range(cfg.start_epoch):
105
+ self.lr_scheduler.step()
106
+
107
+ self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)
108
+
109
+ if self.click_models is not None:
110
+ for click_model in self.click_models:
111
+ for param in click_model.parameters():
112
+ param.requires_grad = False
113
+ click_model.to(self.device)
114
+ click_model.eval()
115
+
116
+ def run(self, num_epochs, start_epoch=None, validation=True):
117
+ if start_epoch is None:
118
+ start_epoch = self.cfg.start_epoch
119
+
120
+ logger.info(f'Starting Epoch: {start_epoch}')
121
+ logger.info(f'Total Epochs: {num_epochs}')
122
+ for epoch in range(start_epoch, num_epochs):
123
+ self.training(epoch)
124
+ if validation:
125
+ self.validation(epoch)
126
+
127
+ def training(self, epoch):
128
+ if self.sw is None and self.is_master:
129
+ self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
130
+ flush_secs=10, dump_period=self.tb_dump_period)
131
+
132
+ if self.cfg.distributed:
133
+ self.train_data.sampler.set_epoch(epoch)
134
+
135
+ log_prefix = 'Train' + self.task_prefix.capitalize()
136
+ tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\
137
+ if self.is_master else self.train_data
138
+
139
+ for metric in self.train_metrics:
140
+ metric.reset_epoch_stats()
141
+
142
+ self.net.train()
143
+ train_loss = 0.0
144
+ for i, batch_data in enumerate(tbar):
145
+ global_step = epoch * len(self.train_data) + i
146
+
147
+ loss, losses_logging, splitted_batch_data, outputs = \
148
+ self.batch_forward(batch_data)
149
+
150
+ self.optim.zero_grad()
151
+ loss.backward()
152
+ self.optim.step()
153
+
154
+ losses_logging['overall'] = loss
155
+ reduce_loss_dict(losses_logging)
156
+
157
+ train_loss += losses_logging['overall'].item()
158
+
159
+ if self.is_master:
160
+ for loss_name, loss_value in losses_logging.items():
161
+ self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
162
+ value=loss_value.item(),
163
+ global_step=global_step)
164
+
165
+ for k, v in self.loss_cfg.items():
166
+ if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0:
167
+ v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step)
168
+
169
+ if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0:
170
+ self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train')
171
+
172
+ self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate',
173
+ value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1],
174
+ global_step=global_step)
175
+
176
+ tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}')
177
+ for metric in self.train_metrics:
178
+ metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
179
+
180
+ if self.is_master:
181
+ for metric in self.train_metrics:
182
+ self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
183
+ value=metric.get_epoch_value(),
184
+ global_step=epoch, disable_avg=True)
185
+
186
+ save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
187
+ epoch=None, multi_gpu=self.cfg.multi_gpu)
188
+
189
+ if isinstance(self.checkpoint_interval, (list, tuple)):
190
+ checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1]
191
+ else:
192
+ checkpoint_interval = self.checkpoint_interval
193
+
194
+ if epoch % checkpoint_interval == 0:
195
+ save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
196
+ epoch=epoch, multi_gpu=self.cfg.multi_gpu)
197
+
198
+ if hasattr(self, 'lr_scheduler'):
199
+ self.lr_scheduler.step()
200
+
201
+ def validation(self, epoch):
202
+ if self.sw is None and self.is_master:
203
+ self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
204
+ flush_secs=10, dump_period=self.tb_dump_period)
205
+
206
+ log_prefix = 'Val' + self.task_prefix.capitalize()
207
+ tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data
208
+
209
+ for metric in self.val_metrics:
210
+ metric.reset_epoch_stats()
211
+
212
+ val_loss = 0
213
+ losses_logging = defaultdict(list)
214
+
215
+ self.net.eval()
216
+ for i, batch_data in enumerate(tbar):
217
+ global_step = epoch * len(self.val_data) + i
218
+ loss, batch_losses_logging, splitted_batch_data, outputs = \
219
+ self.batch_forward(batch_data, validation=True)
220
+
221
+ batch_losses_logging['overall'] = loss
222
+ reduce_loss_dict(batch_losses_logging)
223
+ for loss_name, loss_value in batch_losses_logging.items():
224
+ losses_logging[loss_name].append(loss_value.item())
225
+
226
+ val_loss += batch_losses_logging['overall'].item()
227
+
228
+ if self.is_master:
229
+ tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}')
230
+ for metric in self.val_metrics:
231
+ metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
232
+
233
+ if self.is_master:
234
+ for loss_name, loss_values in losses_logging.items():
235
+ self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(),
236
+ global_step=epoch, disable_avg=True)
237
+
238
+ for metric in self.val_metrics:
239
+ self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(),
240
+ global_step=epoch, disable_avg=True)
241
+
242
+ def batch_forward(self, batch_data, validation=False):
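+ # Simulate up to max_num_next_clicks corrective clicks with the network in eval mode (no gradients), then run the supervised forward pass on the final set of points.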
243
+ metrics = self.val_metrics if validation else self.train_metrics
244
+ losses_logging = dict()
245
+
246
+ with torch.set_grad_enabled(not validation):
247
+ batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
248
+ image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points']
249
+ orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone()
250
+
251
+ prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
252
+
253
+ last_click_indx = None
254
+
255
+ with torch.no_grad():
256
+ num_iters = random.randint(0, self.max_num_next_clicks)
257
+
258
+ for click_indx in range(num_iters):
259
+ last_click_indx = click_indx
260
+
261
+ if not validation:
262
+ self.net.eval()
263
+
264
+ if self.click_models is None or click_indx >= len(self.click_models):
265
+ eval_model = self.net
266
+ else:
267
+ eval_model = self.click_models[click_indx]
268
+
269
+ net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
270
+ prev_output = torch.sigmoid(eval_model(net_input, points)['instances'])
271
+
272
+ points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1)
273
+
274
+ if not validation:
275
+ self.net.train()
276
+
277
+ if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None:
278
+ zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob
279
+ prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
280
+
281
+ batch_data['points'] = points
282
+
283
+ net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
284
+ output = self.net(net_input, points)
285
+
286
+ loss = 0.0
287
+ loss = self.add_loss('instance_loss', loss, losses_logging, validation,
288
+ lambda: (output['instances'], batch_data['instances']))
289
+ loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation,
290
+ lambda: (output['instances_aux'], batch_data['instances']))
291
+
292
+ if self.is_master:
293
+ with torch.no_grad():
294
+ for m in metrics:
295
+ m.update(*(output.get(x) for x in m.pred_outputs),
296
+ *(batch_data[x] for x in m.gt_outputs))
297
+ return loss, losses_logging, batch_data, output
298
+
299
+ def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs):
300
+ loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
301
+ loss_weight = loss_cfg.get(loss_name + '_weight', 0.0)
302
+ if loss_weight > 0.0:
303
+ loss_criterion = loss_cfg.get(loss_name)
304
+ loss = loss_criterion(*lambda_loss_inputs())
305
+ loss = torch.mean(loss)
306
+ losses_logging[loss_name] = loss
307
+ loss = loss_weight * loss
308
+ total_loss = total_loss + loss
309
+
310
+ return total_loss
311
+
312
+ def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
313
+ output_images_path = self.cfg.VIS_PATH / prefix
314
+ if self.task_prefix:
315
+ output_images_path /= self.task_prefix
316
+
317
+ if not output_images_path.exists():
318
+ output_images_path.mkdir(parents=True)
319
+ image_name_prefix = f'{global_step:06d}'
320
+
321
+ def _save_image(suffix, image):
322
+ cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'),
323
+ image, [cv2.IMWRITE_JPEG_QUALITY, 85])
324
+
325
+ images = splitted_batch_data['images']
326
+ points = splitted_batch_data['points']
327
+ instance_masks = splitted_batch_data['instances']
328
+
329
+ gt_instance_masks = instance_masks.cpu().numpy()
330
+ predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy()
331
+ points = points.detach().cpu().numpy()
332
+
333
+ image_blob, points = images[0], points[0]
334
+ gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
335
+ predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)
336
+
337
+ image = image_blob.cpu().numpy() * 255
338
+ image = image.transpose((1, 2, 0))
339
+
340
+ image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0))
341
+ image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255))
342
+
343
+ gt_mask[gt_mask < 0] = 0.25
344
+ gt_mask = draw_probmap(gt_mask)
345
+ predicted_mask = draw_probmap(predicted_mask)
346
+ viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8)
347
+
348
+ _save_image('instance_segmentation', viz_image[:, :, ::-1])
349
+
350
+ def _load_weights(self, net):
351
+ if self.cfg.weights is not None:
352
+ if os.path.isfile(self.cfg.weights):
353
+ load_weights(net, self.cfg.weights)
354
+ self.cfg.weights = None
355
+ else:
356
+ raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
357
+ elif self.cfg.resume_exp is not None:
358
+ checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth'))
359
+ assert len(checkpoints) == 1
360
+
361
+ checkpoint_path = checkpoints[0]
362
+ logger.info(f'Load checkpoint from path: {checkpoint_path}')
363
+ load_weights(net, str(checkpoint_path))
364
+ return net
365
+
366
+ @property
367
+ def is_master(self):
368
+ return self.cfg.local_rank == 0
369
+
370
+
371
+ def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
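+ # Place the next simulated click deep inside the largest error region: positive if false negatives dominate, negative otherwise, using a distance transform to stay away from the region boundary.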
372
+ assert click_indx > 0
373
+ pred = pred.cpu().numpy()[:, 0, :, :]
374
+ gt = gt.cpu().numpy()[:, 0, :, :] > 0.5
375
+
376
+ fn_mask = np.logical_and(gt, pred < pred_thresh)
377
+ fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
378
+
379
+ fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
380
+ fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
381
+ num_points = points.size(1) // 2
382
+ points = points.clone()
383
+
384
+ for bindx in range(fn_mask.shape[0]):
385
+ fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
386
+ fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
387
+
388
+ fn_max_dist = np.max(fn_mask_dt)
389
+ fp_max_dist = np.max(fp_mask_dt)
390
+
391
+ is_positive = fn_max_dist > fp_max_dist
392
+ dt = fn_mask_dt if is_positive else fp_mask_dt
393
+ inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
394
+ indices = np.argwhere(inner_mask)
395
+ if len(indices) > 0:
396
+ coords = indices[np.random.randint(0, len(indices))]
397
+ if is_positive:
398
+ points[bindx, num_points - click_indx, 0] = float(coords[0])
399
+ points[bindx, num_points - click_indx, 1] = float(coords[1])
400
+ points[bindx, num_points - click_indx, 2] = float(click_indx)
401
+ else:
402
+ points[bindx, 2 * num_points - click_indx, 0] = float(coords[0])
403
+ points[bindx, 2 * num_points - click_indx, 1] = float(coords[1])
404
+ points[bindx, 2 * num_points - click_indx, 2] = float(click_indx)
405
+
406
+ return points
407
+
408
+
409
+ def load_weights(model, path_to_weights):
410
+ current_state_dict = model.state_dict()
411
+ new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
412
+ current_state_dict.update(new_state_dict)
413
+ model.load_state_dict(current_state_dict)
isegm/inference/__init__.py ADDED
File without changes
isegm/inference/clicker.py ADDED
@@ -0,0 +1,118 @@
1
+ import numpy as np
2
+ from copy import deepcopy
3
+ import cv2
4
+
5
+
6
+ class Clicker(object):
7
+ def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
8
+ self.click_indx_offset = click_indx_offset
9
+ if gt_mask is not None:
10
+ self.gt_mask = gt_mask == 1
11
+ self.not_ignore_mask = gt_mask != ignore_label
12
+ else:
13
+ self.gt_mask = None
14
+
15
+ self.reset_clicks()
16
+
17
+ if init_clicks is not None:
18
+ for click in init_clicks:
19
+ self.add_click(click)
20
+
21
+ def make_next_click(self, pred_mask):
22
+ assert self.gt_mask is not None
23
+ click = self._get_next_click(pred_mask)
24
+ self.add_click(click)
25
+
26
+ def get_clicks(self, clicks_limit=None):
27
+ return self.clicks_list[:clicks_limit]
28
+
29
+ def _get_next_click(self, pred_mask, padding=True):
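+ # Mirror of the training-time click simulation: choose the not-yet-clicked pixel farthest from the boundary of the larger error region (false negatives -> positive click, false positives -> negative click).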
30
+ fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
31
+ fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
32
+
33
+ if padding:
34
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
35
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
36
+
37
+ fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
38
+ fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
39
+
40
+ if padding:
41
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
42
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
43
+
44
+ fn_mask_dt = fn_mask_dt * self.not_clicked_map
45
+ fp_mask_dt = fp_mask_dt * self.not_clicked_map
46
+
47
+ fn_max_dist = np.max(fn_mask_dt)
48
+ fp_max_dist = np.max(fp_mask_dt)
49
+
50
+ is_positive = fn_max_dist > fp_max_dist
51
+ if is_positive:
52
+ coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
53
+ else:
54
+ coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
55
+
56
+ return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
57
+
58
+ def add_click(self, click):
59
+ coords = click.coords
60
+
61
+ click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
62
+ if click.is_positive:
63
+ self.num_pos_clicks += 1
64
+ else:
65
+ self.num_neg_clicks += 1
66
+
67
+ self.clicks_list.append(click)
68
+ if self.gt_mask is not None:
69
+ self.not_clicked_map[coords[0], coords[1]] = False
70
+
71
+ def _remove_last_click(self):
72
+ click = self.clicks_list.pop()
73
+ coords = click.coords
74
+
75
+ if click.is_positive:
76
+ self.num_pos_clicks -= 1
77
+ else:
78
+ self.num_neg_clicks -= 1
79
+
80
+ if self.gt_mask is not None:
81
+ self.not_clicked_map[coords[0], coords[1]] = True
82
+
83
+ def reset_clicks(self):
84
+ if self.gt_mask is not None:
85
+ self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)
86
+
87
+ self.num_pos_clicks = 0
88
+ self.num_neg_clicks = 0
89
+
90
+ self.clicks_list = []
91
+
92
+ def get_state(self):
93
+ return deepcopy(self.clicks_list)
94
+
95
+ def set_state(self, state):
96
+ self.reset_clicks()
97
+ for click in state:
98
+ self.add_click(click)
99
+
100
+ def __len__(self):
101
+ return len(self.clicks_list)
102
+
103
+
104
+ class Click:
105
+ def __init__(self, is_positive, coords, indx=None):
106
+ self.is_positive = is_positive
107
+ self.coords = coords
108
+ self.indx = indx
109
+
110
+ @property
111
+ def coords_and_indx(self):
112
+ return (*self.coords, self.indx)
113
+
114
+ def copy(self, **kwargs):
115
+ self_copy = deepcopy(self)
116
+ for k, v in kwargs.items():
117
+ setattr(self_copy, k, v)
118
+ return self_copy
isegm/inference/evaluation.py ADDED
@@ -0,0 +1,56 @@
1
+ from time import time
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from isegm.inference import utils
7
+ from isegm.inference.clicker import Clicker
8
+
9
+ try:
10
+ get_ipython()
11
+ from tqdm import tqdm_notebook as tqdm
12
+ except NameError:
13
+ from tqdm import tqdm
14
+
15
+
16
+ def evaluate_dataset(dataset, predictor, **kwargs):
17
+ all_ious = []
18
+
19
+ start_time = time()
20
+ for index in tqdm(range(len(dataset)), leave=False):
21
+ sample = dataset.get_sample(index)
22
+
23
+ _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor,
24
+ sample_id=index, **kwargs)
25
+ all_ious.append(sample_ious)
26
+ end_time = time()
27
+ elapsed_time = end_time - start_time
28
+
29
+ return all_ious, elapsed_time
30
+
31
+
32
+ def evaluate_sample(image, gt_mask, predictor, max_iou_thr,
33
+ pred_thr=0.49, min_clicks=1, max_clicks=20,
34
+ sample_id=None, callback=None):
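+ # Start from an empty mask and add simulated corrective clicks one by one, recording IoU after each, until max_iou_thr is reached (after at least min_clicks) or max_clicks is exhausted.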
35
+ clicker = Clicker(gt_mask=gt_mask)
36
+ pred_mask = np.zeros_like(gt_mask)
37
+ ious_list = []
38
+
39
+ with torch.no_grad():
40
+ predictor.set_input_image(image)
41
+
42
+ for click_indx in range(max_clicks):
43
+ clicker.make_next_click(pred_mask)
44
+ pred_probs = predictor.get_prediction(clicker)
45
+ pred_mask = pred_probs > pred_thr
46
+
47
+ if callback is not None:
48
+ callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)
49
+
50
+ iou = utils.get_iou(gt_mask, pred_mask)
51
+ ious_list.append(iou)
52
+
53
+ if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
54
+ break
55
+
56
+ return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
isegm/inference/predictors/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ from .base import BasePredictor
2
+ from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
3
+ from .brs_functors import InputOptimizer, ScaleBiasOptimizer
4
+ from isegm.inference.transforms import ZoomIn
5
+ from isegm.model.is_hrnet_model import HRNetModel
6
+
7
+
8
+ def get_predictor(net, brs_mode, device,
9
+ prob_thresh=0.49,
10
+ with_flip=True,
11
+ zoom_in_params=dict(),
12
+ predictor_params=None,
13
+ brs_opt_func_params=None,
14
+ lbfgs_params=None):
15
+ lbfgs_params_ = {
16
+ 'm': 20,
17
+ 'factr': 0,
18
+ 'pgtol': 1e-8,
19
+ 'maxfun': 20,
20
+ }
21
+
22
+ predictor_params_ = {
23
+ 'optimize_after_n_clicks': 1
24
+ }
25
+
26
+ if zoom_in_params is not None:
27
+ zoom_in = ZoomIn(**zoom_in_params)
28
+ else:
29
+ zoom_in = None
30
+
31
+ if lbfgs_params is not None:
32
+ lbfgs_params_.update(lbfgs_params)
33
+ lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
34
+
35
+ if brs_opt_func_params is None:
36
+ brs_opt_func_params = dict()
37
+
38
+ if isinstance(net, (list, tuple)):
39
+ assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode."
40
+
41
+ if brs_mode == 'NoBRS':
42
+ if predictor_params is not None:
43
+ predictor_params_.update(predictor_params)
44
+ predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
45
+ elif brs_mode.startswith('f-BRS'):
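+ # The f-BRS variants differ only in where the optimized scale/bias is applied: A after the backbone C4 block, B after ASPP, C after the DeepLab head; for HRNet, A and B share the same insertion point.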
46
+ predictor_params_.update({
47
+ 'net_clicks_limit': 8,
48
+ })
49
+ if predictor_params is not None:
50
+ predictor_params_.update(predictor_params)
51
+
52
+ insertion_mode = {
53
+ 'f-BRS-A': 'after_c4',
54
+ 'f-BRS-B': 'after_aspp',
55
+ 'f-BRS-C': 'after_deeplab'
56
+ }[brs_mode]
57
+
58
+ opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
59
+ with_flip=with_flip,
60
+ optimizer_params=lbfgs_params_,
61
+ **brs_opt_func_params)
62
+
63
+ if isinstance(net, HRNetModel):
64
+ FeaturePredictor = HRNetFeatureBRSPredictor
65
+ insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
66
+ else:
67
+ FeaturePredictor = FeatureBRSPredictor
68
+
69
+ predictor = FeaturePredictor(net, device,
70
+ opt_functor=opt_functor,
71
+ with_flip=with_flip,
72
+ insertion_mode=insertion_mode,
73
+ zoom_in=zoom_in,
74
+ **predictor_params_)
75
+ elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
76
+ use_dmaps = brs_mode == 'DistMap-BRS'
77
+
78
+ predictor_params_.update({
79
+ 'net_clicks_limit': 5,
80
+ })
81
+ if predictor_params is not None:
82
+ predictor_params_.update(predictor_params)
83
+
84
+ opt_functor = InputOptimizer(prob_thresh=prob_thresh,
85
+ with_flip=with_flip,
86
+ optimizer_params=lbfgs_params_,
87
+ **brs_opt_func_params)
88
+
89
+ predictor = InputBRSPredictor(net, device,
90
+ optimize_target='dmaps' if use_dmaps else 'rgb',
91
+ opt_functor=opt_functor,
92
+ with_flip=with_flip,
93
+ zoom_in=zoom_in,
94
+ **predictor_params_)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return predictor
isegm/inference/predictors/base.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision import transforms
4
+ from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
5
+
6
+
7
+ class BasePredictor(object):
8
+ def __init__(self, model, device,
9
+ net_clicks_limit=None,
10
+ with_flip=False,
11
+ zoom_in=None,
12
+ max_size=None,
13
+ **kwargs):
14
+ self.with_flip = with_flip
15
+ self.net_clicks_limit = net_clicks_limit
16
+ self.original_image = None
17
+ self.device = device
18
+ self.zoom_in = zoom_in
19
+ self.prev_prediction = None
20
+ self.model_indx = 0
21
+ self.click_models = None
22
+ self.net_state_dict = None
23
+
24
+ if isinstance(model, tuple):
25
+ self.net, self.click_models = model
26
+ else:
27
+ self.net = model
28
+
29
+ self.to_tensor = transforms.ToTensor()
30
+
31
+ self.transforms = [zoom_in] if zoom_in is not None else []
32
+ if max_size is not None:
33
+ self.transforms.append(LimitLongestSide(max_size=max_size))
34
+ self.transforms.append(SigmoidForPred())
35
+ if with_flip:
36
+ self.transforms.append(AddHorizontalFlip())
37
+
38
+ def set_input_image(self, image):
39
+ image_nd = self.to_tensor(image)
40
+ for transform in self.transforms:
41
+ transform.reset()
42
+ self.original_image = image_nd.to(self.device)
43
+ if len(self.original_image.shape) == 3:
44
+ self.original_image = self.original_image.unsqueeze(0)
45
+ self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])
46
+
47
+ def get_prediction(self, clicker, prev_mask=None):
48
+ clicks_list = clicker.get_clicks()
49
+
50
+ if self.click_models is not None:
51
+ model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1
52
+ if model_indx != self.model_indx:
53
+ self.model_indx = model_indx
54
+ self.net = self.click_models[model_indx]
55
+
56
+ input_image = self.original_image
57
+ if prev_mask is None:
58
+ prev_mask = self.prev_prediction
59
+ if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
60
+ input_image = torch.cat((input_image, prev_mask), dim=1)
61
+ image_nd, clicks_lists, is_image_changed = self.apply_transforms(
62
+ input_image, [clicks_list]
63
+ )
64
+
65
+ pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
66
+ prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
67
+ size=image_nd.size()[2:])
68
+
69
+ for t in reversed(self.transforms):
70
+ prediction = t.inv_transform(prediction)
71
+
72
+ if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
73
+ return self.get_prediction(clicker)
74
+
75
+ self.prev_prediction = prediction
76
+ return prediction.cpu().numpy()[0, 0]
77
+
78
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
79
+ points_nd = self.get_points_nd(clicks_lists)
80
+ return self.net(image_nd, points_nd)['instances']
81
+
82
+ def _get_transform_states(self):
83
+ return [x.get_state() for x in self.transforms]
84
+
85
+ def _set_transform_states(self, states):
86
+ assert len(states) == len(self.transforms)
87
+ for state, transform in zip(states, self.transforms):
88
+ transform.set_state(state)
89
+
90
+ def apply_transforms(self, image_nd, clicks_lists):
91
+ is_image_changed = False
92
+ for t in self.transforms:
93
+ image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
94
+ is_image_changed |= t.image_changed
95
+
96
+ return image_nd, clicks_lists, is_image_changed
97
+
98
+ def get_points_nd(self, clicks_lists):
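+ # Encode clicks as a fixed-size (batch, 2 * num_max_points, 3) tensor of (y, x, index) rows: positive clicks first, then negative ones, padded with (-1, -1, -1).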
99
+ total_clicks = []
100
+ num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
101
+ num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
102
+ num_max_points = max(num_pos_clicks + num_neg_clicks)
103
+ if self.net_clicks_limit is not None:
104
+ num_max_points = min(self.net_clicks_limit, num_max_points)
105
+ num_max_points = max(1, num_max_points)
106
+
107
+ for clicks_list in clicks_lists:
108
+ clicks_list = clicks_list[:self.net_clicks_limit]
109
+ pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
110
+ pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
111
+
112
+ neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
113
+ neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
114
+ total_clicks.append(pos_clicks + neg_clicks)
115
+
116
+ return torch.tensor(total_clicks, device=self.device)
117
+
118
+ def get_states(self):
119
+ return {
120
+ 'transform_states': self._get_transform_states(),
121
+ 'prev_prediction': self.prev_prediction.clone()
122
+ }
123
+
124
+ def set_states(self, states):
125
+ self._set_transform_states(states['transform_states'])
126
+ self.prev_prediction = states['prev_prediction']
isegm/inference/predictors/brs.py ADDED
@@ -0,0 +1,307 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy.optimize import fmin_l_bfgs_b
5
+
6
+ from .base import BasePredictor
7
+
8
+
9
+ class BRSBasePredictor(BasePredictor):
10
+ def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
11
+ super().__init__(model, device, **kwargs)
12
+ self.optimize_after_n_clicks = optimize_after_n_clicks
13
+ self.opt_functor = opt_functor
14
+
15
+ self.opt_data = None
16
+ self.input_data = None
17
+
18
+ def set_input_image(self, image):
19
+ super().set_input_image(image)
20
+ self.opt_data = None
21
+ self.input_data = None
22
+
23
+ def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
24
+ pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
25
+ neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
26
+
27
+ for list_indx, clicks_list in enumerate(clicks_lists):
28
+ for click in clicks_list:
29
+ y, x = click.coords
30
+ y, x = int(round(y)), int(round(x))
31
+ y1, x1 = y - radius, x - radius
32
+ y2, x2 = y + radius + 1, x + radius + 1
33
+
34
+ if click.is_positive:
35
+ pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
36
+ else:
37
+ neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
38
+
39
+ with torch.no_grad():
40
+ pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
41
+ neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)
42
+
43
+ return pos_clicks_map, neg_clicks_map
44
+
45
+ def get_states(self):
46
+ return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}
47
+
48
+ def set_states(self, states):
49
+ self._set_transform_states(states['transform_states'])
50
+ self.opt_data = states['opt_data']
51
+
52
+
53
+ class FeatureBRSPredictor(BRSBasePredictor):
54
+ def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
55
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
56
+ self.insertion_mode = insertion_mode
57
+ self._c1_features = None
58
+
59
+ if self.insertion_mode == 'after_deeplab':
60
+ self.num_channels = model.feature_extractor.ch
61
+ elif self.insertion_mode == 'after_c4':
62
+ self.num_channels = model.feature_extractor.aspp_in_channels
63
+ elif self.insertion_mode == 'after_aspp':
64
+ self.num_channels = model.feature_extractor.ch + 32
65
+ else:
66
+ raise NotImplementedError
67
+
68
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
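+ # Backbone features are computed once and cached; L-BFGS then optimizes a per-channel scale and bias on them so that the head's prediction agrees with all user clicks.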
69
+ points_nd = self.get_points_nd(clicks_lists)
70
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
71
+
72
+ num_clicks = len(clicks_lists[0])
73
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
74
+
75
+ if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
76
+ self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
77
+
78
+ if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
79
+ self.input_data = self._get_head_input(image_nd, points_nd)
80
+
81
+ def get_prediction_logits(scale, bias):
82
+ scale = scale.view(bs, -1, 1, 1)
83
+ bias = bias.view(bs, -1, 1, 1)
84
+ if self.with_flip:
85
+ scale = scale.repeat(2, 1, 1, 1)
86
+ bias = bias.repeat(2, 1, 1, 1)
87
+
88
+ scaled_backbone_features = self.input_data * scale
89
+ scaled_backbone_features = scaled_backbone_features + bias
90
+ if self.insertion_mode == 'after_c4':
91
+ x = self.net.feature_extractor.aspp(scaled_backbone_features)
92
+ x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
93
+ align_corners=True)
94
+ x = torch.cat((x, self._c1_features), dim=1)
95
+ scaled_backbone_features = self.net.feature_extractor.head(x)
96
+ elif self.insertion_mode == 'after_aspp':
97
+ scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)
98
+
99
+ pred_logits = self.net.head(scaled_backbone_features)
100
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
101
+ align_corners=True)
102
+ return pred_logits
103
+
104
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
105
+ if num_clicks > self.optimize_after_n_clicks:
106
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
107
+ **self.opt_functor.optimizer_params)
108
+ self.opt_data = opt_result[0]
109
+
110
+ with torch.no_grad():
111
+ if self.opt_functor.best_prediction is not None:
112
+ opt_pred_logits = self.opt_functor.best_prediction
113
+ else:
114
+ opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
115
+ opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
116
+ opt_pred_logits = get_prediction_logits(*opt_vars)
117
+
118
+ return opt_pred_logits
119
+
120
+ def _get_head_input(self, image_nd, points):
121
+ with torch.no_grad():
122
+ image_nd, prev_mask = self.net.prepare_input(image_nd)
123
+ coord_features = self.net.get_coord_features(image_nd, prev_mask, points)
124
+
125
+ if self.net.rgb_conv is not None:
126
+ x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
127
+ additional_features = None
128
+ elif hasattr(self.net, 'maps_transform'):
129
+ x = image_nd
130
+ additional_features = self.net.maps_transform(coord_features)
131
+
132
+ if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
133
+ c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
134
+ c1 = self.net.feature_extractor.skip_project(c1)
135
+
136
+ if self.insertion_mode == 'after_aspp':
137
+ x = self.net.feature_extractor.aspp(c4)
138
+ x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
139
+ x = torch.cat((x, c1), dim=1)
140
+ backbone_features = x
141
+ else:
142
+ backbone_features = c4
143
+ self._c1_features = c1
144
+ else:
145
+ backbone_features = self.net.feature_extractor(x, additional_features)[0]
146
+
147
+ return backbone_features
148
+
149
+
150
+ class HRNetFeatureBRSPredictor(BRSBasePredictor):
151
+ def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
152
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
153
+ self.insertion_mode = insertion_mode
154
+ self._c1_features = None
155
+
156
+ if self.insertion_mode == 'A':
157
+ self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
158
+ elif self.insertion_mode == 'C':
159
+ self.num_channels = 2 * model.feature_extractor.ocr_width
160
+ else:
161
+ raise NotImplementedError
162
+
163
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
164
+ points_nd = self.get_points_nd(clicks_lists)
165
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
166
+ num_clicks = len(clicks_lists[0])
167
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
168
+
169
+ if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
170
+ self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)
171
+
172
+ if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
173
+ self.input_data = self._get_head_input(image_nd, points_nd)
174
+
175
+ def get_prediction_logits(scale, bias):
176
+ scale = scale.view(bs, -1, 1, 1)
177
+ bias = bias.view(bs, -1, 1, 1)
178
+ if self.with_flip:
179
+ scale = scale.repeat(2, 1, 1, 1)
180
+ bias = bias.repeat(2, 1, 1, 1)
181
+
182
+ scaled_backbone_features = self.input_data * scale
183
+ scaled_backbone_features = scaled_backbone_features + bias
184
+ if self.insertion_mode == 'A':
185
+ if self.net.feature_extractor.ocr_width > 0:
186
+ out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
187
+ feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)
188
+
189
+ context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
190
+ feats = self.net.feature_extractor.ocr_distri_head(feats, context)
191
+ else:
192
+ feats = scaled_backbone_features
193
+ pred_logits = self.net.feature_extractor.cls_head(feats)
194
+ elif self.insertion_mode == 'C':
195
+ pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
196
+ else:
197
+ raise NotImplementedError
198
+
199
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
200
+ align_corners=True)
201
+ return pred_logits
202
+
203
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
204
+ if num_clicks > self.optimize_after_n_clicks:
205
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
206
+ **self.opt_functor.optimizer_params)
207
+ self.opt_data = opt_result[0]
208
+
209
+ with torch.no_grad():
210
+ if self.opt_functor.best_prediction is not None:
211
+ opt_pred_logits = self.opt_functor.best_prediction
212
+ else:
213
+ opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
214
+ opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
215
+ opt_pred_logits = get_prediction_logits(*opt_vars)
216
+
217
+ return opt_pred_logits
218
+
219
+ def _get_head_input(self, image_nd, points):
220
+ with torch.no_grad():
221
+ image_nd, prev_mask = self.net.prepare_input(image_nd)
222
+ coord_features = self.net.get_coord_features(image_nd, prev_mask, points)
223
+
224
+ if self.net.rgb_conv is not None:
225
+ x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
226
+ additional_features = None
227
+ elif hasattr(self.net, 'maps_transform'):
228
+ x = image_nd
229
+ additional_features = self.net.maps_transform(coord_features)
230
+
231
+ feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)
232
+
233
+ if self.insertion_mode == 'A':
234
+ backbone_features = feats
235
+ elif self.insertion_mode == 'C':
236
+ out_aux = self.net.feature_extractor.aux_head(feats)
237
+ feats = self.net.feature_extractor.conv3x3_ocr(feats)
238
+
239
+ context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
240
+ backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
241
+ else:
242
+ raise NotImplementedError
243
+
244
+ return backbone_features
245
+
246
+
247
+ class InputBRSPredictor(BRSBasePredictor):
248
+ def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
249
+ super().__init__(model, device, opt_functor=opt_functor, **kwargs)
250
+ self.optimize_target = optimize_target
251
+
252
+ def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
253
+ points_nd = self.get_points_nd(clicks_lists)
254
+ pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
255
+ num_clicks = len(clicks_lists[0])
256
+
257
+ if self.opt_data is None or is_image_changed:
258
+ if self.optimize_target == 'dmaps':
259
+ opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
260
+ else:
261
+ opt_channels = 3
262
+ bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
263
+ self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
264
+ device=self.device, dtype=torch.float32)
265
+
266
+ def get_prediction_logits(opt_bias):
267
+ input_image, prev_mask = self.net.prepare_input(image_nd)
268
+ dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)
269
+
270
+ if self.optimize_target == 'rgb':
271
+ input_image = input_image + opt_bias
272
+ elif self.optimize_target == 'dmaps':
273
+ if self.net.with_prev_mask:
274
+ dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
275
+ else:
276
+ dmaps = dmaps + opt_bias
277
+
278
+ if self.net.rgb_conv is not None:
279
+ x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
280
+ if self.optimize_target == 'all':
281
+ x = x + opt_bias
282
+ coord_features = None
283
+ elif hasattr(self.net, 'maps_transform'):
284
+ x = input_image
285
+ coord_features = self.net.maps_transform(dmaps)
286
+
287
+ pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
288
+ pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)
289
+
290
+ return pred_logits
291
+
292
+ self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
293
+ shape=self.opt_data.shape)
294
+ if num_clicks > self.optimize_after_n_clicks:
295
+ opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
296
+ **self.opt_functor.optimizer_params)
297
+
298
+ self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)
299
+
300
+ with torch.no_grad():
301
+ if self.opt_functor.best_prediction is not None:
302
+ opt_pred_logits = self.opt_functor.best_prediction
303
+ else:
304
+ opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
305
+ opt_pred_logits = get_prediction_logits(*opt_vars)
306
+
307
+ return opt_pred_logits
isegm/inference/predictors/brs_functors.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from isegm.model.metrics import _compute_iou
5
+ from .brs_losses import BRSMaskLoss
6
+
7
+
8
+ class BaseOptimizer:
9
+ def __init__(self, optimizer_params,
10
+ prob_thresh=0.49,
11
+ reg_weight=1e-3,
12
+ min_iou_diff=0.01,
13
+ brs_loss=BRSMaskLoss(),
14
+ with_flip=False,
15
+ flip_average=False,
16
+ **kwargs):
17
+ self.brs_loss = brs_loss
18
+ self.optimizer_params = optimizer_params
19
+ self.prob_thresh = prob_thresh
20
+ self.reg_weight = reg_weight
21
+ self.min_iou_diff = min_iou_diff
22
+ self.with_flip = with_flip
23
+ self.flip_average = flip_average
24
+
25
+ self.best_prediction = None
26
+ self._get_prediction_logits = None
27
+ self._opt_shape = None
28
+ self._best_loss = None
29
+ self._click_masks = None
30
+ self._last_mask = None
31
+ self.device = None
32
+
33
+ def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
34
+ self.best_prediction = None
35
+ self._get_prediction_logits = get_prediction_logits
36
+ self._click_masks = (pos_mask, neg_mask)
37
+ self._opt_shape = shape
38
+ self._last_mask = None
39
+ self.device = device
40
+
41
+ def __call__(self, x):
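+ # Objective for SciPy's fmin_l_bfgs_b: returns (loss value, flattened gradient); a zero gradient is returned early once every click is satisfied or the mask stops changing.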
42
+ opt_params = torch.from_numpy(x).float().to(self.device)
43
+ opt_params.requires_grad_(True)
44
+
45
+ with torch.enable_grad():
46
+ opt_vars, reg_loss = self.unpack_opt_params(opt_params)
47
+ result_before_sigmoid = self._get_prediction_logits(*opt_vars)
48
+ result = torch.sigmoid(result_before_sigmoid)
49
+
50
+ pos_mask, neg_mask = self._click_masks
51
+ if self.with_flip and self.flip_average:
52
+ result, result_flipped = torch.chunk(result, 2, dim=0)
53
+ result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
54
+ pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]
55
+
56
+ loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
57
+ loss = loss + reg_loss
58
+
59
+ f_val = loss.detach().cpu().numpy()
60
+ if self.best_prediction is None or f_val < self._best_loss:
61
+ self.best_prediction = result_before_sigmoid.detach()
62
+ self._best_loss = f_val
63
+
64
+ if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
65
+ return [f_val, np.zeros_like(x)]
66
+
67
+ current_mask = result > self.prob_thresh
68
+ if self._last_mask is not None and self.min_iou_diff > 0:
69
+ diff_iou = _compute_iou(current_mask, self._last_mask)
70
+ if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
71
+ return [f_val, np.zeros_like(x)]
72
+ self._last_mask = current_mask
73
+
74
+ loss.backward()
75
+ f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float64)
76
+
77
+ return [f_val, f_grad]
78
+
79
+ def unpack_opt_params(self, opt_params):
80
+ raise NotImplementedError
81
+
82
+
83
+ class InputOptimizer(BaseOptimizer):
84
+ def unpack_opt_params(self, opt_params):
85
+ opt_params = opt_params.view(self._opt_shape)
86
+ if self.with_flip:
87
+ opt_params_flipped = torch.flip(opt_params, dims=[3])
88
+ opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
89
+ reg_loss = self.reg_weight * torch.sum(opt_params**2)
90
+
91
+ return (opt_params,), reg_loss
92
+
93
+
94
+ class ScaleBiasOptimizer(BaseOptimizer):
95
+ def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
96
+ super().__init__(*args, **kwargs)
97
+ self.scale_act = scale_act
98
+ self.reg_bias_weight = reg_bias_weight
99
+
100
+ def unpack_opt_params(self, opt_params):
101
+ scale, bias = torch.chunk(opt_params, 2, dim=0)
102
+ reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))
103
+
104
+ if self.scale_act == 'tanh':
105
+ scale = torch.tanh(scale)
106
+ elif self.scale_act == 'sin':
107
+ scale = torch.sin(scale)
108
+
109
+ return (1 + scale, bias), reg_loss
isegm/inference/predictors/brs_losses.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+ from isegm.model.losses import SigmoidBinaryCrossEntropyLoss
4
+
5
+
6
+ class BRSMaskLoss(torch.nn.Module):
7
+ def __init__(self, eps=1e-5):
8
+ super().__init__()
9
+ self._eps = eps
10
+
11
+ def forward(self, result, pos_mask, neg_mask):
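+ # Normalized squared error pushing predictions towards 1 at positive clicks and 0 at negative clicks; also returns the worst per-click violation for the optimizer's early-stop check.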
12
+ pos_diff = (1 - result) * pos_mask
13
+ pos_target = torch.sum(pos_diff ** 2)
14
+ pos_target = pos_target / (torch.sum(pos_mask) + self._eps)
15
+
16
+ neg_diff = result * neg_mask
17
+ neg_target = torch.sum(neg_diff ** 2)
18
+ neg_target = neg_target / (torch.sum(neg_mask) + self._eps)
19
+
20
+ loss = pos_target + neg_target
21
+
22
+ with torch.no_grad():
23
+ f_max_pos = torch.max(torch.abs(pos_diff)).item()
24
+ f_max_neg = torch.max(torch.abs(neg_diff)).item()
25
+
26
+ return loss, f_max_pos, f_max_neg
27
+
28
+
29
+ class OracleMaskLoss(torch.nn.Module):
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.gt_mask = None
33
+ self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
34
+ self.predictor = None
35
+ self.history = []
36
+
37
+ def set_gt_mask(self, gt_mask):
38
+ self.gt_mask = gt_mask
39
+ self.history = []
40
+
41
+ def forward(self, result, pos_mask, neg_mask):
42
+ gt_mask = self.gt_mask.to(result.device)
43
+ if self.predictor.object_roi is not None:
44
+ r1, r2, c1, c2 = self.predictor.object_roi[:4]
45
+ gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
46
+ gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)
47
+
48
+ if result.shape[0] == 2:
49
+ gt_mask_flipped = torch.flip(gt_mask, dims=[3])
50
+ gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)
51
+
52
+ loss = self.loss(result, gt_mask)
53
+ self.history.append(loss.detach().cpu().numpy()[0])
54
+
55
+ if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
56
+ return 0, 0, 0
57
+
58
+ return loss, 1.0, 1.0
isegm/inference/transforms/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .base import SigmoidForPred
2
+ from .flip import AddHorizontalFlip
3
+ from .zoom_in import ZoomIn
4
+ from .limit_longest_side import LimitLongestSide
5
+ from .crops import Crops
isegm/inference/transforms/base.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+
3
+
4
+ class BaseTransform(object):
5
+ def __init__(self):
6
+ self.image_changed = False
7
+
8
+ def transform(self, image_nd, clicks_lists):
9
+ raise NotImplementedError
10
+
11
+ def inv_transform(self, prob_map):
12
+ raise NotImplementedError
13
+
14
+ def reset(self):
15
+ raise NotImplementedError
16
+
17
+ def get_state(self):
18
+ raise NotImplementedError
19
+
20
+ def set_state(self, state):
21
+ raise NotImplementedError
22
+
23
+
24
+ class SigmoidForPred(BaseTransform):
25
+ def transform(self, image_nd, clicks_lists):
26
+ return image_nd, clicks_lists
27
+
28
+ def inv_transform(self, prob_map):
29
+ return torch.sigmoid(prob_map)
30
+
31
+ def reset(self):
32
+ pass
33
+
34
+ def get_state(self):
35
+ return None
36
+
37
+ def set_state(self, state):
38
+ pass
isegm/inference/transforms/crops.py ADDED
@@ -0,0 +1,97 @@
1
+ import math
2
+
3
+ import torch
4
+ import numpy as np
5
+ from typing import List
6
+
7
+ from isegm.inference.clicker import Click
8
+ from .base import BaseTransform
9
+
10
+
11
+ class Crops(BaseTransform):
12
+ def __init__(self, crop_size=(320, 480), min_overlap=0.2):
13
+ super().__init__()
14
+ self.crop_height, self.crop_width = crop_size
15
+ self.min_overlap = min_overlap
16
+
17
+ self.x_offsets = None
18
+ self.y_offsets = None
19
+ self._counts = None
20
+
21
+ def transform(self, image_nd, clicks_lists: List[List[Click]]):
22
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
23
+ image_height, image_width = image_nd.shape[2:4]
24
+ self._counts = None
25
+
26
+ if image_height < self.crop_height or image_width < self.crop_width:
27
+ return image_nd, clicks_lists
28
+
29
+ self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
30
+ self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
31
+ self._counts = np.zeros((image_height, image_width))
32
+
33
+ image_crops = []
34
+ for dy in self.y_offsets:
35
+ for dx in self.x_offsets:
36
+ self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
37
+ image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
38
+ image_crops.append(image_crop)
39
+ image_crops = torch.cat(image_crops, dim=0)
40
+ self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)
41
+
42
+ clicks_list = clicks_lists[0]
43
+ clicks_lists = []
44
+ for dy in self.y_offsets:
45
+ for dx in self.x_offsets:
46
+ crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list]
47
+ clicks_lists.append(crop_clicks)
48
+
49
+ return image_crops, clicks_lists
50
+
51
+ def inv_transform(self, prob_map):
52
+ if self._counts is None:
53
+ return prob_map
54
+
55
+ new_prob_map = torch.zeros((1, 1, *self._counts.shape),
56
+ dtype=prob_map.dtype, device=prob_map.device)
57
+
58
+ crop_indx = 0
59
+ for dy in self.y_offsets:
60
+ for dx in self.x_offsets:
61
+ new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
62
+ crop_indx += 1
63
+ new_prob_map = torch.div(new_prob_map, self._counts)
64
+
65
+ return new_prob_map
66
+
67
+ def get_state(self):
68
+ return self.x_offsets, self.y_offsets, self._counts
69
+
70
+ def set_state(self, state):
71
+ self.x_offsets, self.y_offsets, self._counts = state
72
+
73
+ def reset(self):
74
+ self.x_offsets = None
75
+ self.y_offsets = None
76
+ self._counts = None
77
+
78
+
79
+ def get_offsets(length, crop_size, min_overlap_ratio=0.2):
80
+ if length == crop_size:
81
+ return [0]
82
+
83
+ N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
84
+ N = math.ceil(N)
85
+
86
+ overlap_ratio = (N - length / crop_size) / (N - 1)
87
+ overlap_width = int(crop_size * overlap_ratio)
88
+
89
+ offsets = [0]
90
+ for i in range(1, N):
91
+ new_offset = offsets[-1] + crop_size - overlap_width
92
+ if new_offset + crop_size > length:
93
+ new_offset = length - crop_size
94
+
95
+ offsets.append(new_offset)
96
+
97
+ return offsets
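Note (not part of the diff): get_offsets tiles one image side into overlapping crops. A worked example under the defaults above:

    # 600-px side, 320-px crops, at least 20% overlap:
    # N = ceil((600/320 - 0.2) / 0.8) = 3 crops, overlap_width = int(320 * 0.5625) = 180
    get_offsets(600, 320, min_overlap_ratio=0.2)   # -> [0, 140, 280]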
isegm/inference/transforms/flip.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ from typing import List
4
+ from isegm.inference.clicker import Click
5
+ from .base import BaseTransform
6
+
7
+
8
+ class AddHorizontalFlip(BaseTransform):
9
+ def transform(self, image_nd, clicks_lists: List[List[Click]]):
10
+ assert len(image_nd.shape) == 4
11
+ image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)
12
+
13
+ image_width = image_nd.shape[3]
14
+ clicks_lists_flipped = []
15
+ for clicks_list in clicks_lists:
16
+ clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
17
+ for click in clicks_list]
18
+ clicks_lists_flipped.append(clicks_list_flipped)
19
+ clicks_lists = clicks_lists + clicks_lists_flipped
20
+
21
+ return image_nd, clicks_lists
22
+
23
+ def inv_transform(self, prob_map):
24
+ assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
25
+ num_maps = prob_map.shape[0] // 2
26
+ prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]
27
+
28
+ return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))
29
+
30
+ def get_state(self):
31
+ return None
32
+
33
+ def set_state(self, state):
34
+ pass
35
+
36
+ def reset(self):
37
+ pass
isegm/inference/transforms/limit_longest_side.py ADDED
@@ -0,0 +1,22 @@
1
+ from .zoom_in import ZoomIn, get_roi_image_nd
2
+
3
+
4
+ class LimitLongestSide(ZoomIn):
5
+ def __init__(self, max_size=800):
6
+ super().__init__(target_size=max_size, skip_clicks=0)
7
+
8
+ def transform(self, image_nd, clicks_lists):
9
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
10
+ image_max_size = max(image_nd.shape[2:4])
11
+ self.image_changed = False
12
+
13
+ if image_max_size <= self.target_size:
14
+ return image_nd, clicks_lists
15
+ self._input_image = image_nd
16
+
17
+ self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
18
+ self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
19
+ self.image_changed = True
20
+
21
+ tclicks_lists = [self._transform_clicks(clicks_lists[0])]
22
+ return self._roi_image, tclicks_lists
isegm/inference/transforms/zoom_in.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+
3
+ from typing import List
4
+ from isegm.inference.clicker import Click
5
+ from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
6
+ from .base import BaseTransform
7
+
8
+
9
+ class ZoomIn(BaseTransform):
10
+ def __init__(self,
11
+ target_size=400,
12
+ skip_clicks=1,
13
+ expansion_ratio=1.4,
14
+ min_crop_size=200,
15
+ recompute_thresh_iou=0.5,
16
+ prob_thresh=0.50):
17
+ super().__init__()
18
+ self.target_size = target_size
19
+ self.min_crop_size = min_crop_size
20
+ self.skip_clicks = skip_clicks
21
+ self.expansion_ratio = expansion_ratio
22
+ self.recompute_thresh_iou = recompute_thresh_iou
23
+ self.prob_thresh = prob_thresh
24
+
25
+ self._input_image_shape = None
26
+ self._prev_probs = None
27
+ self._object_roi = None
28
+ self._roi_image = None
29
+
30
+ def transform(self, image_nd, clicks_lists: List[List[Click]]):
31
+ assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
32
+ self.image_changed = False
33
+
34
+ clicks_list = clicks_lists[0]
35
+ if len(clicks_list) <= self.skip_clicks:
36
+ return image_nd, clicks_lists
37
+
38
+ self._input_image_shape = image_nd.shape
39
+
40
+ current_object_roi = None
41
+ if self._prev_probs is not None:
42
+ current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
43
+ if current_pred_mask.sum() > 0:
44
+ current_object_roi = get_object_roi(current_pred_mask, clicks_list,
45
+ self.expansion_ratio, self.min_crop_size)
46
+
47
+ if current_object_roi is None:
48
+ if self.skip_clicks >= 0:
49
+ return image_nd, clicks_lists
50
+ else:
51
+ current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1
52
+
53
+ update_object_roi = False
54
+ if self._object_roi is None:
55
+ update_object_roi = True
56
+ elif not check_object_roi(self._object_roi, clicks_list):
57
+ update_object_roi = True
58
+ elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
59
+ update_object_roi = True
60
+
61
+ if update_object_roi:
62
+ self._object_roi = current_object_roi
63
+ self.image_changed = True
64
+ self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
65
+
66
+ tclicks_lists = [self._transform_clicks(clicks_list)]
67
+ return self._roi_image.to(image_nd.device), tclicks_lists
68
+
69
+ def inv_transform(self, prob_map):
70
+ if self._object_roi is None:
71
+ self._prev_probs = prob_map.cpu().numpy()
72
+ return prob_map
73
+
74
+ assert prob_map.shape[0] == 1
75
+ rmin, rmax, cmin, cmax = self._object_roi
76
+ prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
77
+ mode='bilinear', align_corners=True)
78
+
79
+ if self._prev_probs is not None:
80
+ new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
81
+ new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
82
+ else:
83
+ new_prob_map = prob_map
84
+
85
+ self._prev_probs = new_prob_map.cpu().numpy()
86
+
87
+ return new_prob_map
88
+
89
+ def check_possible_recalculation(self):
90
+ if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
91
+ return False
92
+
93
+ pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
94
+ if pred_mask.sum() > 0:
95
+ possible_object_roi = get_object_roi(pred_mask, [],
96
+ self.expansion_ratio, self.min_crop_size)
97
+ image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
98
+ if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
99
+ return True
100
+ return False
101
+
102
+ def get_state(self):
103
+ roi_image = self._roi_image.cpu() if self._roi_image is not None else None
104
+ return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed
105
+
106
+ def set_state(self, state):
107
+ self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state
108
+
109
+ def reset(self):
110
+ self._input_image_shape = None
111
+ self._object_roi = None
112
+ self._prev_probs = None
113
+ self._roi_image = None
114
+ self.image_changed = False
115
+
116
+ def _transform_clicks(self, clicks_list):
117
+ if self._object_roi is None:
118
+ return clicks_list
119
+
120
+ rmin, rmax, cmin, cmax = self._object_roi
121
+ crop_height, crop_width = self._roi_image.shape[2:]
122
+
123
+ transformed_clicks = []
124
+ for click in clicks_list:
125
+ new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
126
+ new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
127
+ transformed_clicks.append(click.copy(coords=(new_r, new_c)))
128
+ return transformed_clicks
129
+
130
+
131
+ def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
132
+ pred_mask = pred_mask.copy()
133
+
134
+ for click in clicks_list:
135
+ if click.is_positive:
136
+ pred_mask[int(click.coords[0]), int(click.coords[1])] = 1
137
+
138
+ bbox = get_bbox_from_mask(pred_mask)
139
+ bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
140
+ h, w = pred_mask.shape[0], pred_mask.shape[1]
141
+ bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)
142
+
143
+ return bbox
144
+
145
+
146
+ def get_roi_image_nd(image_nd, object_roi, target_size):
147
+ rmin, rmax, cmin, cmax = object_roi
148
+
149
+ height = rmax - rmin + 1
150
+ width = cmax - cmin + 1
151
+
152
+ if isinstance(target_size, tuple):
153
+ new_height, new_width = target_size
154
+ else:
155
+ scale = target_size / max(height, width)
156
+ new_height = int(round(height * scale))
157
+ new_width = int(round(width * scale))
158
+
159
+ with torch.no_grad():
160
+ roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
161
+ roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
162
+ mode='bilinear', align_corners=True)
163
+
164
+ return roi_image_nd
165
+
166
+
167
+ def check_object_roi(object_roi, clicks_list):
168
+ for click in clicks_list:
169
+ if click.is_positive:
170
+ if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
171
+ return False
172
+ if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
173
+ return False
174
+
175
+ return True
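Note (not part of the diff): ZoomIn keeps an inclusive ROI (rmin, rmax, cmin, cmax) and _transform_clicks rescales click coordinates into the resized crop. A worked example with illustrative numbers: for an ROI of (100, 299, 50, 249) resized to a 400x400 crop, a click at (row, col) = (200, 150) maps to (400 * (200 - 100) / 200, 400 * (150 - 50) / 200) = (200.0, 200.0).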
isegm/inference/utils.py ADDED
@@ -0,0 +1,143 @@
1
+ from datetime import timedelta
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset
8
+ from isegm.utils.serialization import load_model
9
+
10
+
11
+ def get_time_metrics(all_ious, elapsed_time):
12
+ n_images = len(all_ious)
13
+ n_clicks = sum(map(len, all_ious))
14
+
15
+ mean_spc = elapsed_time / n_clicks
16
+ mean_spi = elapsed_time / n_images
17
+
18
+ return mean_spc, mean_spi
19
+
20
+
21
+ def load_is_model(checkpoint, device, **kwargs):
22
+ if isinstance(checkpoint, (str, Path)):
23
+ state_dict = torch.load(checkpoint, map_location='cpu')
24
+ else:
25
+ state_dict = checkpoint
26
+
27
+ if isinstance(state_dict, list):
28
+ model = load_single_is_model(state_dict[0], device, **kwargs)
29
+ models = [load_single_is_model(x, device, **kwargs) for x in state_dict]
30
+
31
+ return model, models
32
+ else:
33
+ return load_single_is_model(state_dict, device, **kwargs)
34
+
35
+
36
+ def load_single_is_model(state_dict, device, **kwargs):
37
+ model = load_model(state_dict['config'], **kwargs)
38
+ model.load_state_dict(state_dict['state_dict'], strict=False)
39
+
40
+ for param in model.parameters():
41
+ param.requires_grad = False
42
+ model.to(device)
43
+ model.eval()
44
+
45
+ return model
46
+
47
+
48
+ def get_dataset(dataset_name, cfg):
49
+ if dataset_name == 'GrabCut':
50
+ dataset = GrabCutDataset(cfg.GRABCUT_PATH)
51
+ elif dataset_name == 'Berkeley':
52
+ dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
53
+ elif dataset_name == 'DAVIS':
54
+ dataset = DavisDataset(cfg.DAVIS_PATH)
55
+ elif dataset_name == 'SBD':
56
+ dataset = SBDEvaluationDataset(cfg.SBD_PATH)
57
+ elif dataset_name == 'SBD_Train':
58
+ dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train')
59
+ elif dataset_name == 'PascalVOC':
60
+ dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test')
61
+ elif dataset_name == 'COCO_MVal':
62
+ dataset = DavisDataset(cfg.COCO_MVAL_PATH)
63
+ else:
64
+ dataset = None
65
+
66
+ return dataset
67
+
68
+
69
+ def get_iou(gt_mask, pred_mask, ignore_label=-1):
70
+ ignore_gt_mask_inv = gt_mask != ignore_label
71
+ obj_gt_mask = gt_mask == 1
72
+
73
+ intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
74
+ union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
75
+
76
+ return intersection / union
77
+
78
+
79
+ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
80
+ def _get_noc(iou_arr, iou_thr):
81
+ vals = iou_arr >= iou_thr
82
+ return np.argmax(vals) + 1 if np.any(vals) else max_clicks
83
+
84
+ noc_list = []
85
+ over_max_list = []
86
+ for iou_thr in iou_thrs:
87
+ scores_arr = np.array([_get_noc(iou_arr, iou_thr)
88
+ for iou_arr in all_ious], dtype=int)
89
+
90
+ score = scores_arr.mean()
91
+ over_max = (scores_arr == max_clicks).sum()
92
+
93
+ noc_list.append(score)
94
+ over_max_list.append(over_max)
95
+
96
+ return noc_list, over_max_list
97
+
98
+
99
+ def find_checkpoint(weights_folder, checkpoint_name):
100
+ weights_folder = Path(weights_folder)
101
+ if ':' in checkpoint_name:
102
+ model_name, checkpoint_name = checkpoint_name.split(':')
103
+ models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
104
+ assert len(models_candidates) == 1
105
+ model_folder = models_candidates[0]
106
+ else:
107
+ model_folder = weights_folder
108
+
109
+ if checkpoint_name.endswith('.pth'):
110
+ if Path(checkpoint_name).exists():
111
+ checkpoint_path = checkpoint_name
112
+ else:
113
+ checkpoint_path = weights_folder / checkpoint_name
114
+ else:
115
+ model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
116
+ assert len(model_checkpoints) == 1
117
+ checkpoint_path = model_checkpoints[0]
118
+
119
+ return str(checkpoint_path)
120
+
121
+
122
+ def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
123
+ n_clicks=20, model_name=None):
124
+ table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
125
+ f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
126
+ f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
127
+ f'{"SPC,s":^7}|{"Time":^9}|')
128
+ row_width = len(table_header)
129
+
130
+ header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
131
+ header += '-' * row_width + '\n'
132
+ header += table_header + '\n' + '-' * row_width
133
+
134
+ eval_time = str(timedelta(seconds=int(elapsed_time)))
135
+ table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
136
+ table_row += f'{noc_list[0]:^9.2f}|'
137
+ table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
138
+ table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
139
+ table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
140
+ table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
141
+ table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'
142
+
143
+ return header, table_row
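Note (not part of the diff): compute_noc_metric reports the mean number of clicks needed to reach each IoU threshold, capped at max_clicks. A small worked example with made-up IoU traces:

    all_ious = [np.array([0.50, 0.82, 0.91]), np.array([0.70, 0.75, 0.86])]
    noc, over_max = compute_noc_metric(all_ious, iou_thrs=[0.80, 0.85, 0.90], max_clicks=20)
    # noc == [2.5, 3.0, 11.5]  (the second image never reaches 90% IoU, so it counts as 20 clicks)
    # over_max == [0, 0, 1]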
isegm/model/initializer.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class Initializer(object):
7
+ def __init__(self, local_init=True, gamma=None):
8
+ self.local_init = local_init
9
+ self.gamma = gamma
10
+
11
+ def __call__(self, m):
12
+ if getattr(m, '__initialized', False):
13
+ return
14
+
15
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
16
+ nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
17
+ nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
18
+ if m.weight is not None:
19
+ self._init_gamma(m.weight.data)
20
+ if m.bias is not None:
21
+ self._init_beta(m.bias.data)
22
+ else:
23
+ if getattr(m, 'weight', None) is not None:
24
+ self._init_weight(m.weight.data)
25
+ if getattr(m, 'bias', None) is not None:
26
+ self._init_bias(m.bias.data)
27
+
28
+ if self.local_init:
29
+ object.__setattr__(m, '__initialized', True)
30
+
31
+ def _init_weight(self, data):
32
+ nn.init.uniform_(data, -0.07, 0.07)
33
+
34
+ def _init_bias(self, data):
35
+ nn.init.constant_(data, 0)
36
+
37
+ def _init_gamma(self, data):
38
+ if self.gamma is None:
39
+ nn.init.constant_(data, 1.0)
40
+ else:
41
+ nn.init.normal_(data, 1.0, self.gamma)
42
+
43
+ def _init_beta(self, data):
44
+ nn.init.constant_(data, 0)
45
+
46
+
47
+ class Bilinear(Initializer):
48
+ def __init__(self, scale, groups, in_channels, **kwargs):
49
+ super().__init__(**kwargs)
50
+ self.scale = scale
51
+ self.groups = groups
52
+ self.in_channels = in_channels
53
+
54
+ def _init_weight(self, data):
55
+ """Reset the weight and bias."""
56
+ bilinear_kernel = self.get_bilinear_kernel(self.scale)
57
+ weight = torch.zeros_like(data)
58
+ for i in range(self.in_channels):
59
+ if self.groups == 1:
60
+ j = i
61
+ else:
62
+ j = 0
63
+ weight[i, j] = bilinear_kernel
64
+ data[:] = weight
65
+
66
+ @staticmethod
67
+ def get_bilinear_kernel(scale):
68
+ """Generate a bilinear upsampling kernel."""
69
+ kernel_size = 2 * scale - scale % 2
70
+ scale = (kernel_size + 1) // 2
71
+ center = scale - 0.5 * (1 + kernel_size % 2)
72
+
73
+ og = np.ogrid[:kernel_size, :kernel_size]
74
+ kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
75
+
76
+ return torch.tensor(kernel, dtype=torch.float32)
77
+
78
+
79
+ class XavierGluon(Initializer):
80
+ def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
81
+ super().__init__(**kwargs)
82
+
83
+ self.rnd_type = rnd_type
84
+ self.factor_type = factor_type
85
+ self.magnitude = float(magnitude)
86
+
87
+ def _init_weight(self, arr):
88
+ fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
89
+
90
+ if self.factor_type == 'avg':
91
+ factor = (fan_in + fan_out) / 2.0
92
+ elif self.factor_type == 'in':
93
+ factor = fan_in
94
+ elif self.factor_type == 'out':
95
+ factor = fan_out
96
+ else:
97
+ raise ValueError('Incorrect factor type')
98
+ scale = np.sqrt(self.magnitude / factor)
99
+
100
+ if self.rnd_type == 'uniform':
101
+ nn.init.uniform_(arr, -scale, scale)
102
+ elif self.rnd_type == 'gaussian':
103
+ nn.init.normal_(arr, 0, scale)
104
+ else:
105
+ raise ValueError('Unknown random type')
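Note (not part of the diff): get_bilinear_kernel builds the separable kernel typically used to initialise transposed-convolution weights for upsampling. For scale=2 it returns a 4x4 kernel whose one-dimensional profile is [0.25, 0.75, 0.75, 0.25], i.e. the outer product:

    Bilinear.get_bilinear_kernel(2)
    # tensor([[0.0625, 0.1875, 0.1875, 0.0625],
    #         [0.1875, 0.5625, 0.5625, 0.1875],
    #         [0.1875, 0.5625, 0.5625, 0.1875],
    #         [0.0625, 0.1875, 0.1875, 0.0625]])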
isegm/model/is_deeplab_model.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.deeplab_v3 import DeepLabV3Plus
6
+ from .modeling.basic_blocks import SepConvHead
7
+ from isegm.model.modifiers import LRMult
8
+
9
+
10
+ class DeeplabModel(ISModel):
11
+ @serialize
12
+ def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
13
+ backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs):
14
+ super().__init__(norm_layer=norm_layer, **kwargs)
15
+
16
+ self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout,
17
+ norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer)
18
+ self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
19
+ self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
20
+ num_layers=2, norm_layer=norm_layer)
21
+
22
+ def backbone_forward(self, image, coord_features=None):
23
+ backbone_features = self.feature_extractor(image, coord_features)
24
+
25
+ return {'instances': self.head(backbone_features[0])}
isegm/model/is_hrnet_model.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.utils.serialization import serialize
4
+ from .is_model import ISModel
5
+ from .modeling.hrnet_ocr import HighResolutionNet
6
+ from isegm.model.modifiers import LRMult
7
+
8
+
9
+ class HRNetModel(ISModel):
10
+ @serialize
11
+ def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1,
12
+ norm_layer=nn.BatchNorm2d, **kwargs):
13
+ super().__init__(norm_layer=norm_layer, **kwargs)
14
+
15
+ self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
16
+ num_classes=1, norm_layer=norm_layer)
17
+ self.feature_extractor.apply(LRMult(backbone_lr_mult))
18
+ if ocr_width > 0:
19
+ self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))
20
+ self.feature_extractor.ocr_gather_head.apply(LRMult(1.0))
21
+ self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0))
22
+
23
+ def backbone_forward(self, image, coord_features=None):
24
+ net_outputs = self.feature_extractor(image, coord_features)
25
+
26
+ return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]}
isegm/model/is_model.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize
6
+ from isegm.model.modifiers import LRMult
7
+
8
+
9
+ class ISModel(nn.Module):
10
+ def __init__(self, use_rgb_conv=True, with_aux_output=False,
11
+ norm_radius=260, use_disks=False, cpu_dist_maps=False,
12
+ clicks_groups=None, with_prev_mask=False, use_leaky_relu=False,
13
+ binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d,
14
+ norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
15
+ super().__init__()
16
+ self.with_aux_output = with_aux_output
17
+ self.clicks_groups = clicks_groups
18
+ self.with_prev_mask = with_prev_mask
19
+ self.binary_prev_mask = binary_prev_mask
20
+ self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])
21
+
22
+ self.coord_feature_ch = 2
23
+ if clicks_groups is not None:
24
+ self.coord_feature_ch *= len(clicks_groups)
25
+
26
+ if self.with_prev_mask:
27
+ self.coord_feature_ch += 1
28
+
29
+ if use_rgb_conv:
30
+ rgb_conv_layers = [
31
+ nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1),
32
+ norm_layer(6 + self.coord_feature_ch),
33
+ nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
34
+ nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1)
35
+ ]
36
+ self.rgb_conv = nn.Sequential(*rgb_conv_layers)
37
+ elif conv_extend:
38
+ self.rgb_conv = None
39
+ self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64,
40
+ kernel_size=3, stride=2, padding=1)
41
+ self.maps_transform.apply(LRMult(0.1))
42
+ else:
43
+ self.rgb_conv = None
44
+ mt_layers = [
45
+ nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1),
46
+ nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
47
+ nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
48
+ ScaleLayer(init_value=0.05, lr_mult=1)
49
+ ]
50
+ self.maps_transform = nn.Sequential(*mt_layers)
51
+
52
+ if self.clicks_groups is not None:
53
+ self.dist_maps = nn.ModuleList()
54
+ for click_radius in self.clicks_groups:
55
+ self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0,
56
+ cpu_mode=cpu_dist_maps, use_disks=use_disks))
57
+ else:
58
+ self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
59
+ cpu_mode=cpu_dist_maps, use_disks=use_disks)
60
+
61
+ def forward(self, image, points):
62
+ image, prev_mask = self.prepare_input(image)
63
+ coord_features = self.get_coord_features(image, prev_mask, points)
64
+
65
+ if self.rgb_conv is not None:
66
+ x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
67
+ outputs = self.backbone_forward(x)
68
+ else:
69
+ coord_features = self.maps_transform(coord_features)
70
+ outputs = self.backbone_forward(image, coord_features)
71
+
72
+ outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
73
+ mode='bilinear', align_corners=True)
74
+ if self.with_aux_output:
75
+ outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
76
+ mode='bilinear', align_corners=True)
77
+
78
+ return outputs
79
+
80
+ def prepare_input(self, image):
81
+ prev_mask = None
82
+ if self.with_prev_mask:
83
+ prev_mask = image[:, 3:, :, :]
84
+ image = image[:, :3, :, :]
85
+ if self.binary_prev_mask:
86
+ prev_mask = (prev_mask > 0.5).float()
87
+
88
+ image = self.normalization(image)
89
+ return image, prev_mask
90
+
91
+ def backbone_forward(self, image, coord_features=None):
92
+ raise NotImplementedError
93
+
94
+ def get_coord_features(self, image, prev_mask, points):
95
+ if self.clicks_groups is not None:
96
+ points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,))
97
+ coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)]
98
+ coord_features = torch.cat(coord_features, dim=1)
99
+ else:
100
+ coord_features = self.dist_maps(image, points)
101
+
102
+ if prev_mask is not None:
103
+ coord_features = torch.cat((prev_mask, coord_features), dim=1)
104
+
105
+ return coord_features
106
+
107
+
108
+ def split_points_by_order(tpoints: torch.Tensor, groups):
109
+ points = tpoints.cpu().numpy()
110
+ num_groups = len(groups)
111
+ bs = points.shape[0]
112
+ num_points = points.shape[1] // 2
113
+
114
+ groups = [x if x > 0 else num_points for x in groups]
115
+ group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
116
+ for x in groups]
117
+
118
+ last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int32)
119
+ for group_indx, group_size in enumerate(groups):
120
+ last_point_indx_group[:, group_indx, 1] = group_size
121
+
122
+ for bindx in range(bs):
123
+ for pindx in range(2 * num_points):
124
+ point = points[bindx, pindx, :]
125
+ group_id = int(point[2])
126
+ if group_id < 0:
127
+ continue
128
+
129
+ is_negative = int(pindx >= num_points)
130
+ if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click
131
+ group_id = num_groups - 1
132
+
133
+ new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
134
+ last_point_indx_group[bindx, group_id, is_negative] += 1
135
+
136
+ group_points[group_id][bindx, new_point_indx, :] = point
137
+
138
+ group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
139
+ for x in group_points]
140
+
141
+ return group_points
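Note (not part of the diff): a minimal sketch of how an ISModel subclass is driven, following the conventions in the code above (the fourth input channel is the previous mask when with_prev_mask=True; points keeps positive clicks in the first half of dim 1 and negative clicks in the second half, padded with -1):

    x = torch.cat([rgb, prev_mask], dim=1)        # (B, 4, H, W) when with_prev_mask=True
    # points: (B, 2*N, 3), rows of (row, col, index); unused slots are all -1
    outputs = model(x, points)
    probs = torch.sigmoid(outputs['instances'])   # (B, 1, H, W), upsampled to the input size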
isegm/model/losses.py ADDED
@@ -0,0 +1,161 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from isegm.utils import misc
7
+
8
+
9
+ class NormalizedFocalLossSigmoid(nn.Module):
10
+ def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
11
+ from_sigmoid=False, detach_delimeter=True,
12
+ batch_axis=0, weight=None, size_average=True,
13
+ ignore_label=-1):
14
+ super(NormalizedFocalLossSigmoid, self).__init__()
15
+ self._axis = axis
16
+ self._alpha = alpha
17
+ self._gamma = gamma
18
+ self._ignore_label = ignore_label
19
+ self._weight = weight if weight is not None else 1.0
20
+ self._batch_axis = batch_axis
21
+
22
+ self._from_logits = from_sigmoid
23
+ self._eps = eps
24
+ self._size_average = size_average
25
+ self._detach_delimeter = detach_delimeter
26
+ self._max_mult = max_mult
27
+ self._k_sum = 0
28
+ self._m_max = 0
29
+
30
+ def forward(self, pred, label):
31
+ one_hot = label > 0.5
32
+ sample_weight = label != self._ignore_label
33
+
34
+ if not self._from_logits:
35
+ pred = torch.sigmoid(pred)
36
+
37
+ alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
38
+ pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
39
+
40
+ beta = (1 - pt) ** self._gamma
41
+
42
+ sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
43
+ beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
44
+ mult = sw_sum / (beta_sum + self._eps)
45
+ if self._detach_delimeter:
46
+ mult = mult.detach()
47
+ beta = beta * mult
48
+ if self._max_mult > 0:
49
+ beta = torch.clamp_max(beta, self._max_mult)
50
+
51
+ with torch.no_grad():
52
+ ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
53
+ sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
54
+ if np.any(ignore_area == 0):
55
+ self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
56
+
57
+ beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
58
+ beta_pmax = beta_pmax.mean().item()
59
+ self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax
60
+
61
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
62
+ loss = self._weight * (loss * sample_weight)
63
+
64
+ if self._size_average:
65
+ bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
66
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
67
+ else:
68
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
69
+
70
+ return loss
71
+
72
+ def log_states(self, sw, name, global_step):
73
+ sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
74
+ sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)
75
+
76
+
77
+ class FocalLoss(nn.Module):
78
+ def __init__(self, axis=-1, alpha=0.25, gamma=2,
79
+ from_logits=False, batch_axis=0,
80
+ weight=None, num_class=None,
81
+ eps=1e-9, size_average=True, scale=1.0,
82
+ ignore_label=-1):
83
+ super(FocalLoss, self).__init__()
84
+ self._axis = axis
85
+ self._alpha = alpha
86
+ self._gamma = gamma
87
+ self._ignore_label = ignore_label
88
+ self._weight = weight if weight is not None else 1.0
89
+ self._batch_axis = batch_axis
90
+
91
+ self._scale = scale
92
+ self._num_class = num_class
93
+ self._from_logits = from_logits
94
+ self._eps = eps
95
+ self._size_average = size_average
96
+
97
+ def forward(self, pred, label, sample_weight=None):
98
+ one_hot = label > 0.5
99
+ sample_weight = label != self._ignore_label
100
+
101
+ if not self._from_logits:
102
+ pred = torch.sigmoid(pred)
103
+
104
+ alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
105
+ pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))
106
+
107
+ beta = (1 - pt) ** self._gamma
108
+
109
+ loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
110
+ loss = self._weight * (loss * sample_weight)
111
+
112
+ if self._size_average:
113
+ tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
114
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
115
+ else:
116
+ loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
117
+
118
+ return self._scale * loss
119
+
120
+
121
+ class SoftIoU(nn.Module):
122
+ def __init__(self, from_sigmoid=False, ignore_label=-1):
123
+ super().__init__()
124
+ self._from_sigmoid = from_sigmoid
125
+ self._ignore_label = ignore_label
126
+
127
+ def forward(self, pred, label):
128
+ label = label.view(pred.size())
129
+ sample_weight = label != self._ignore_label
130
+
131
+ if not self._from_sigmoid:
132
+ pred = torch.sigmoid(pred)
133
+
134
+ loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \
135
+ / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8)
136
+
137
+ return loss
138
+
139
+
140
+ class SigmoidBinaryCrossEntropyLoss(nn.Module):
141
+ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
142
+ super(SigmoidBinaryCrossEntropyLoss, self).__init__()
143
+ self._from_sigmoid = from_sigmoid
144
+ self._ignore_label = ignore_label
145
+ self._weight = weight if weight is not None else 1.0
146
+ self._batch_axis = batch_axis
147
+
148
+ def forward(self, pred, label):
149
+ label = label.view(pred.size())
150
+ sample_weight = label != self._ignore_label
151
+ label = torch.where(sample_weight, label, torch.zeros_like(label))
152
+
153
+ if not self._from_sigmoid:
154
+ loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
155
+ else:
156
+ eps = 1e-12
157
+ loss = -(torch.log(pred + eps) * label
158
+ + torch.log(1. - pred + eps) * (1. - label))
159
+
160
+ loss = self._weight * (loss * sample_weight)
161
+ return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
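Note (not part of the diff): the from_sigmoid=False branch of SigmoidBinaryCrossEntropyLoss uses the numerically stable logits form max(x, 0) - x*y + log(1 + exp(-|x|)). A quick illustrative check against PyTorch's built-in:

    pred = torch.tensor([[2.0, -1.0]]); label = torch.tensor([[1.0, 0.0]])
    manual = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
    reference = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
    assert torch.allclose(manual, reference)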
isegm/model/metrics.py ADDED
@@ -0,0 +1,101 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from isegm.utils import misc
5
+
6
+
7
+ class TrainMetric(object):
8
+ def __init__(self, pred_outputs, gt_outputs):
9
+ self.pred_outputs = pred_outputs
10
+ self.gt_outputs = gt_outputs
11
+
12
+ def update(self, *args, **kwargs):
13
+ raise NotImplementedError
14
+
15
+ def get_epoch_value(self):
16
+ raise NotImplementedError
17
+
18
+ def reset_epoch_stats(self):
19
+ raise NotImplementedError
20
+
21
+ def log_states(self, sw, tag_prefix, global_step):
22
+ pass
23
+
24
+ @property
25
+ def name(self):
26
+ return type(self).__name__
27
+
28
+
29
+ class AdaptiveIoU(TrainMetric):
30
+ def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
31
+ ignore_label=-1, from_logits=True,
32
+ pred_output='instances', gt_output='instances'):
33
+ super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
34
+ self._ignore_label = ignore_label
35
+ self._from_logits = from_logits
36
+ self._iou_thresh = init_thresh
37
+ self._thresh_step = thresh_step
38
+ self._thresh_beta = thresh_beta
39
+ self._iou_beta = iou_beta
40
+ self._ema_iou = 0.0
41
+ self._epoch_iou_sum = 0.0
42
+ self._epoch_batch_count = 0
43
+
44
+ def update(self, pred, gt):
45
+ gt_mask = gt > 0.5
46
+ if self._from_logits:
47
+ pred = torch.sigmoid(pred)
48
+
49
+ gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
50
+ if np.all(gt_mask_area == 0):
51
+ return
52
+
53
+ ignore_mask = gt == self._ignore_label
54
+ max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
55
+ best_thresh = self._iou_thresh
56
+ for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
57
+ temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
58
+ if temp_iou > max_iou:
59
+ max_iou = temp_iou
60
+ best_thresh = t
61
+
62
+ self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
63
+ self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
64
+ self._epoch_iou_sum += max_iou
65
+ self._epoch_batch_count += 1
66
+
67
+ def get_epoch_value(self):
68
+ if self._epoch_batch_count > 0:
69
+ return self._epoch_iou_sum / self._epoch_batch_count
70
+ else:
71
+ return 0.0
72
+
73
+ def reset_epoch_stats(self):
74
+ self._epoch_iou_sum = 0.0
75
+ self._epoch_batch_count = 0
76
+
77
+ def log_states(self, sw, tag_prefix, global_step):
78
+ sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
79
+ sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)
80
+
81
+ @property
82
+ def iou_thresh(self):
83
+ return self._iou_thresh
84
+
85
+
86
+ def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
87
+ if ignore_mask is not None:
88
+ pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)
89
+
90
+ reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
91
+ union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
92
+ intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
93
+ nonzero = union > 0
94
+
95
+ iou = intersection[nonzero] / union[nonzero]
96
+ if not keep_ignore:
97
+ return iou
98
+ else:
99
+ result = np.full_like(intersection, -1)
100
+ result[nonzero] = iou
101
+ return result
isegm/model/modeling/basic_blocks.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch.nn as nn
2
+
3
+ from isegm.model import ops
4
+
5
+
6
+ class ConvHead(nn.Module):
7
+ def __init__(self, out_channels, in_channels=32, num_layers=1,
8
+ kernel_size=3, padding=1,
9
+ norm_layer=nn.BatchNorm2d):
10
+ super(ConvHead, self).__init__()
11
+ convhead = []
12
+
13
+ for i in range(num_layers):
14
+ convhead.extend([
15
+ nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
16
+ nn.ReLU(),
17
+ norm_layer(in_channels) if norm_layer is not None else nn.Identity()
18
+ ])
19
+ convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
20
+
21
+ self.convhead = nn.Sequential(*convhead)
22
+
23
+ def forward(self, *inputs):
24
+ return self.convhead(inputs[0])
25
+
26
+
27
+ class SepConvHead(nn.Module):
28
+ def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
29
+ kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
30
+ norm_layer=nn.BatchNorm2d):
31
+ super(SepConvHead, self).__init__()
32
+
33
+ sepconvhead = []
34
+
35
+ for i in range(num_layers):
36
+ sepconvhead.append(
37
+ SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
38
+ out_channels=mid_channels,
39
+ dw_kernel=kernel_size, dw_padding=padding,
40
+ norm_layer=norm_layer, activation='relu')
41
+ )
42
+ if dropout_ratio > 0 and dropout_indx == i:
43
+ sepconvhead.append(nn.Dropout(dropout_ratio))
44
+
45
+ sepconvhead.append(
46
+ nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
47
+ )
48
+
49
+ self.layers = nn.Sequential(*sepconvhead)
50
+
51
+ def forward(self, *inputs):
52
+ x = inputs[0]
53
+
54
+ return self.layers(x)
55
+
56
+
57
+ class SeparableConv2d(nn.Module):
58
+ def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
59
+ activation=None, use_bias=False, norm_layer=None):
60
+ super(SeparableConv2d, self).__init__()
61
+ _activation = ops.select_activation_function(activation)
62
+ self.body = nn.Sequential(
63
+ nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
64
+ padding=dw_padding, bias=use_bias, groups=in_channels),
65
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
66
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
67
+ _activation()
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.body(x)
isegm/model/modeling/deeplab_v3.py ADDED
@@ -0,0 +1,176 @@
1
+ from contextlib import ExitStack
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ from .basic_blocks import SeparableConv2d
8
+ from .resnet import ResNetBackbone
9
+ from isegm.model import ops
10
+
11
+
12
+ class DeepLabV3Plus(nn.Module):
13
+ def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
14
+ backbone_norm_layer=None,
15
+ ch=256,
16
+ project_dropout=0.5,
17
+ inference_mode=False,
18
+ **kwargs):
19
+ super(DeepLabV3Plus, self).__init__()
20
+ if backbone_norm_layer is None:
21
+ backbone_norm_layer = norm_layer
22
+
23
+ self.backbone_name = backbone
24
+ self.norm_layer = norm_layer
25
+ self.backbone_norm_layer = backbone_norm_layer
26
+ self.inference_mode = False
27
+ self.ch = ch
28
+ self.aspp_in_channels = 2048
29
+ self.skip_project_in_channels = 256 # layer 1 out_channels
30
+
31
+ self._kwargs = kwargs
32
+ if backbone == 'resnet34':
33
+ self.aspp_in_channels = 512
34
+ self.skip_project_in_channels = 64
35
+
36
+ self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
37
+ norm_layer=self.backbone_norm_layer, **kwargs)
38
+
39
+ self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
40
+ norm_layer=self.norm_layer)
41
+ self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
42
+ self.aspp = _ASPP(in_channels=self.aspp_in_channels,
43
+ atrous_rates=[12, 24, 36],
44
+ out_channels=ch,
45
+ project_dropout=project_dropout,
46
+ norm_layer=self.norm_layer)
47
+
48
+ if inference_mode:
49
+ self.set_prediction_mode()
50
+
51
+ def load_pretrained_weights(self):
52
+ pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
53
+ norm_layer=self.backbone_norm_layer, **self._kwargs)
54
+ backbone_state_dict = self.backbone.state_dict()
55
+ pretrained_state_dict = pretrained.state_dict()
56
+
57
+ backbone_state_dict.update(pretrained_state_dict)
58
+ self.backbone.load_state_dict(backbone_state_dict)
59
+
60
+ if self.inference_mode:
61
+ for param in self.backbone.parameters():
62
+ param.requires_grad = False
63
+
64
+ def set_prediction_mode(self):
65
+ self.inference_mode = True
66
+ self.eval()
67
+
68
+ def forward(self, x, additional_features=None):
69
+ with ExitStack() as stack:
70
+ if self.inference_mode:
71
+ stack.enter_context(torch.no_grad())
72
+
73
+ c1, _, c3, c4 = self.backbone(x, additional_features)
74
+ c1 = self.skip_project(c1)
75
+
76
+ x = self.aspp(c4)
77
+ x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
78
+ x = torch.cat((x, c1), dim=1)
79
+ x = self.head(x)
80
+
81
+ return x,
82
+
83
+
84
+ class _SkipProject(nn.Module):
85
+ def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
86
+ super(_SkipProject, self).__init__()
87
+ _activation = ops.select_activation_function("relu")
88
+
89
+ self.skip_project = nn.Sequential(
90
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
91
+ norm_layer(out_channels),
92
+ _activation()
93
+ )
94
+
95
+ def forward(self, x):
96
+ return self.skip_project(x)
97
+
98
+
99
+ class _DeepLabHead(nn.Module):
100
+ def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
101
+ super(_DeepLabHead, self).__init__()
102
+
103
+ self.block = nn.Sequential(
104
+ SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
105
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
106
+ SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
107
+ dw_padding=1, activation='relu', norm_layer=norm_layer),
108
+ nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
109
+ )
110
+
111
+ def forward(self, x):
112
+ return self.block(x)
113
+
114
+
115
+ class _ASPP(nn.Module):
116
+ def __init__(self, in_channels, atrous_rates, out_channels=256,
117
+ project_dropout=0.5, norm_layer=nn.BatchNorm2d):
118
+ super(_ASPP, self).__init__()
119
+
120
+ b0 = nn.Sequential(
121
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
122
+ norm_layer(out_channels),
123
+ nn.ReLU()
124
+ )
125
+
126
+ rate1, rate2, rate3 = tuple(atrous_rates)
127
+ b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
128
+ b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
129
+ b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
130
+ b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
131
+
132
+ self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
133
+
134
+ project = [
135
+ nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
136
+ kernel_size=1, bias=False),
137
+ norm_layer(out_channels),
138
+ nn.ReLU()
139
+ ]
140
+ if project_dropout > 0:
141
+ project.append(nn.Dropout(project_dropout))
142
+ self.project = nn.Sequential(*project)
143
+
144
+ def forward(self, x):
145
+ x = torch.cat([block(x) for block in self.concurent], dim=1)
146
+
147
+ return self.project(x)
148
+
149
+
150
+ class _AsppPooling(nn.Module):
151
+ def __init__(self, in_channels, out_channels, norm_layer):
152
+ super(_AsppPooling, self).__init__()
153
+
154
+ self.gap = nn.Sequential(
155
+ nn.AdaptiveAvgPool2d((1, 1)),
156
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
157
+ kernel_size=1, bias=False),
158
+ norm_layer(out_channels),
159
+ nn.ReLU()
160
+ )
161
+
162
+ def forward(self, x):
163
+ pool = self.gap(x)
164
+ return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
165
+
166
+
167
+ def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
168
+ block = nn.Sequential(
169
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
170
+ kernel_size=3, padding=atrous_rate,
171
+ dilation=atrous_rate, bias=False),
172
+ norm_layer(out_channels),
173
+ nn.ReLU()
174
+ )
175
+
176
+ return block
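Note (not part of the diff): _ASPP runs five parallel branches (a 1x1 conv, three atrous 3x3 convs at the given rates, and global-average pooling), concatenates them to 5*out_channels, and projects back down. A shape sketch assuming the classes above are importable:

    aspp = _ASPP(in_channels=2048, atrous_rates=[12, 24, 36], out_channels=256)
    x = torch.randn(1, 2048, 32, 32)
    print(aspp(x).shape)   # torch.Size([1, 256, 32, 32]) -- spatial size is preserved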
isegm/model/modeling/hrnet_ocr.py ADDED
@@ -0,0 +1,416 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch._utils
6
+ import torch.nn.functional as F
7
+ from .ocr import SpatialOCR_Module, SpatialGather_Module
8
+ from .resnetv1b import BasicBlockV1b, BottleneckV1b
9
+
10
+ relu_inplace = True
11
+
12
+
13
+ class HighResolutionModule(nn.Module):
14
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
15
+ num_channels, fuse_method, multi_scale_output=True,
16
+ norm_layer=nn.BatchNorm2d, align_corners=True):
17
+ super(HighResolutionModule, self).__init__()
18
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
19
+
20
+ self.num_inchannels = num_inchannels
21
+ self.fuse_method = fuse_method
22
+ self.num_branches = num_branches
23
+ self.norm_layer = norm_layer
24
+ self.align_corners = align_corners
25
+
26
+ self.multi_scale_output = multi_scale_output
27
+
28
+ self.branches = self._make_branches(
29
+ num_branches, blocks, num_blocks, num_channels)
30
+ self.fuse_layers = self._make_fuse_layers()
31
+ self.relu = nn.ReLU(inplace=relu_inplace)
32
+
33
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
34
+ if num_branches != len(num_blocks):
35
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
36
+ num_branches, len(num_blocks))
37
+ raise ValueError(error_msg)
38
+
39
+ if num_branches != len(num_channels):
40
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
41
+ num_branches, len(num_channels))
42
+ raise ValueError(error_msg)
43
+
44
+ if num_branches != len(num_inchannels):
45
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
46
+ num_branches, len(num_inchannels))
47
+ raise ValueError(error_msg)
48
+
49
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
50
+ stride=1):
51
+ downsample = None
52
+ if stride != 1 or \
53
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
54
+ downsample = nn.Sequential(
55
+ nn.Conv2d(self.num_inchannels[branch_index],
56
+ num_channels[branch_index] * block.expansion,
57
+ kernel_size=1, stride=stride, bias=False),
58
+ self.norm_layer(num_channels[branch_index] * block.expansion),
59
+ )
60
+
61
+ layers = []
62
+ layers.append(block(self.num_inchannels[branch_index],
63
+ num_channels[branch_index], stride,
64
+ downsample=downsample, norm_layer=self.norm_layer))
65
+ self.num_inchannels[branch_index] = \
66
+ num_channels[branch_index] * block.expansion
67
+ for i in range(1, num_blocks[branch_index]):
68
+ layers.append(block(self.num_inchannels[branch_index],
69
+ num_channels[branch_index],
70
+ norm_layer=self.norm_layer))
71
+
72
+ return nn.Sequential(*layers)
73
+
74
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
75
+ branches = []
76
+
77
+ for i in range(num_branches):
78
+ branches.append(
79
+ self._make_one_branch(i, block, num_blocks, num_channels))
80
+
81
+ return nn.ModuleList(branches)
82
+
83
+ def _make_fuse_layers(self):
84
+ if self.num_branches == 1:
85
+ return None
86
+
87
+ num_branches = self.num_branches
88
+ num_inchannels = self.num_inchannels
89
+ fuse_layers = []
90
+ for i in range(num_branches if self.multi_scale_output else 1):
91
+ fuse_layer = []
92
+ for j in range(num_branches):
93
+ if j > i:
94
+ fuse_layer.append(nn.Sequential(
95
+ nn.Conv2d(in_channels=num_inchannels[j],
96
+ out_channels=num_inchannels[i],
97
+ kernel_size=1,
98
+ bias=False),
99
+ self.norm_layer(num_inchannels[i])))
100
+ elif j == i:
101
+ fuse_layer.append(None)
102
+ else:
103
+ conv3x3s = []
104
+ for k in range(i - j):
105
+ if k == i - j - 1:
106
+ num_outchannels_conv3x3 = num_inchannels[i]
107
+ conv3x3s.append(nn.Sequential(
108
+ nn.Conv2d(num_inchannels[j],
109
+ num_outchannels_conv3x3,
110
+ kernel_size=3, stride=2, padding=1, bias=False),
111
+ self.norm_layer(num_outchannels_conv3x3)))
112
+ else:
113
+ num_outchannels_conv3x3 = num_inchannels[j]
114
+ conv3x3s.append(nn.Sequential(
115
+ nn.Conv2d(num_inchannels[j],
116
+ num_outchannels_conv3x3,
117
+ kernel_size=3, stride=2, padding=1, bias=False),
118
+ self.norm_layer(num_outchannels_conv3x3),
119
+ nn.ReLU(inplace=relu_inplace)))
120
+ fuse_layer.append(nn.Sequential(*conv3x3s))
121
+ fuse_layers.append(nn.ModuleList(fuse_layer))
122
+
123
+ return nn.ModuleList(fuse_layers)
124
+
125
+ def get_num_inchannels(self):
126
+ return self.num_inchannels
127
+
128
+ def forward(self, x):
129
+ if self.num_branches == 1:
130
+ return [self.branches[0](x[0])]
131
+
132
+ for i in range(self.num_branches):
133
+ x[i] = self.branches[i](x[i])
134
+
135
+ x_fuse = []
136
+ for i in range(len(self.fuse_layers)):
137
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
138
+ for j in range(1, self.num_branches):
139
+ if i == j:
140
+ y = y + x[j]
141
+ elif j > i:
142
+ width_output = x[i].shape[-1]
143
+ height_output = x[i].shape[-2]
144
+ y = y + F.interpolate(
145
+ self.fuse_layers[i][j](x[j]),
146
+ size=[height_output, width_output],
147
+ mode='bilinear', align_corners=self.align_corners)
148
+ else:
149
+ y = y + self.fuse_layers[i][j](x[j])
150
+ x_fuse.append(self.relu(y))
151
+
152
+ return x_fuse
153
+
154
+
155
+ class HighResolutionNet(nn.Module):
156
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
157
+ norm_layer=nn.BatchNorm2d, align_corners=True):
158
+ super(HighResolutionNet, self).__init__()
159
+ self.norm_layer = norm_layer
160
+ self.width = width
161
+ self.ocr_width = ocr_width
162
+ self.align_corners = align_corners
163
+
164
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
165
+ self.bn1 = norm_layer(64)
166
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
167
+ self.bn2 = norm_layer(64)
168
+ self.relu = nn.ReLU(inplace=relu_inplace)
169
+
170
+ num_blocks = 2 if small else 4
171
+
172
+ stage1_num_channels = 64
173
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
174
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
175
+
176
+ self.stage2_num_branches = 2
177
+ num_channels = [width, 2 * width]
178
+ num_inchannels = [
179
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
180
+ self.transition1 = self._make_transition_layer(
181
+ [stage1_out_channel], num_inchannels)
182
+ self.stage2, pre_stage_channels = self._make_stage(
183
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
184
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
185
+
186
+ self.stage3_num_branches = 3
187
+ num_channels = [width, 2 * width, 4 * width]
188
+ num_inchannels = [
189
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
190
+ self.transition2 = self._make_transition_layer(
191
+ pre_stage_channels, num_inchannels)
192
+ self.stage3, pre_stage_channels = self._make_stage(
193
+ BasicBlockV1b, num_inchannels=num_inchannels,
194
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
195
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
196
+
197
+ self.stage4_num_branches = 4
198
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
199
+ num_inchannels = [
200
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
201
+ self.transition3 = self._make_transition_layer(
202
+ pre_stage_channels, num_inchannels)
203
+ self.stage4, pre_stage_channels = self._make_stage(
204
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
205
+ num_branches=self.stage4_num_branches,
206
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
207
+
208
+ last_inp_channels = int(np.sum(pre_stage_channels))  # np.int was removed in recent NumPy; built-in int is equivalent here
209
+ if self.ocr_width > 0:
210
+ ocr_mid_channels = 2 * self.ocr_width
211
+ ocr_key_channels = self.ocr_width
212
+
213
+ self.conv3x3_ocr = nn.Sequential(
214
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
215
+ kernel_size=3, stride=1, padding=1),
216
+ norm_layer(ocr_mid_channels),
217
+ nn.ReLU(inplace=relu_inplace),
218
+ )
219
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
220
+
221
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
222
+ key_channels=ocr_key_channels,
223
+ out_channels=ocr_mid_channels,
224
+ scale=1,
225
+ dropout=0.05,
226
+ norm_layer=norm_layer,
227
+ align_corners=align_corners)
228
+ self.cls_head = nn.Conv2d(
229
+ ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
230
+
231
+ self.aux_head = nn.Sequential(
232
+ nn.Conv2d(last_inp_channels, last_inp_channels,
233
+ kernel_size=1, stride=1, padding=0),
234
+ norm_layer(last_inp_channels),
235
+ nn.ReLU(inplace=relu_inplace),
236
+ nn.Conv2d(last_inp_channels, num_classes,
237
+ kernel_size=1, stride=1, padding=0, bias=True)
238
+ )
239
+ else:
240
+ self.cls_head = nn.Sequential(
241
+ nn.Conv2d(last_inp_channels, last_inp_channels,
242
+ kernel_size=3, stride=1, padding=1),
243
+ norm_layer(last_inp_channels),
244
+ nn.ReLU(inplace=relu_inplace),
245
+ nn.Conv2d(last_inp_channels, num_classes,
246
+ kernel_size=1, stride=1, padding=0, bias=True)
247
+ )
248
+
249
+ def _make_transition_layer(
250
+ self, num_channels_pre_layer, num_channels_cur_layer):
251
+ num_branches_cur = len(num_channels_cur_layer)
252
+ num_branches_pre = len(num_channels_pre_layer)
253
+
254
+ transition_layers = []
255
+ for i in range(num_branches_cur):
256
+ if i < num_branches_pre:
257
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
258
+ transition_layers.append(nn.Sequential(
259
+ nn.Conv2d(num_channels_pre_layer[i],
260
+ num_channels_cur_layer[i],
261
+ kernel_size=3,
262
+ stride=1,
263
+ padding=1,
264
+ bias=False),
265
+ self.norm_layer(num_channels_cur_layer[i]),
266
+ nn.ReLU(inplace=relu_inplace)))
267
+ else:
268
+ transition_layers.append(None)
269
+ else:
270
+ conv3x3s = []
271
+ for j in range(i + 1 - num_branches_pre):
272
+ inchannels = num_channels_pre_layer[-1]
273
+ outchannels = num_channels_cur_layer[i] \
274
+ if j == i - num_branches_pre else inchannels
275
+ conv3x3s.append(nn.Sequential(
276
+ nn.Conv2d(inchannels, outchannels,
277
+ kernel_size=3, stride=2, padding=1, bias=False),
278
+ self.norm_layer(outchannels),
279
+ nn.ReLU(inplace=relu_inplace)))
280
+ transition_layers.append(nn.Sequential(*conv3x3s))
281
+
282
+ return nn.ModuleList(transition_layers)
283
+
284
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
285
+ downsample = None
286
+ if stride != 1 or inplanes != planes * block.expansion:
287
+ downsample = nn.Sequential(
288
+ nn.Conv2d(inplanes, planes * block.expansion,
289
+ kernel_size=1, stride=stride, bias=False),
290
+ self.norm_layer(planes * block.expansion),
291
+ )
292
+
293
+ layers = []
294
+ layers.append(block(inplanes, planes, stride,
295
+ downsample=downsample, norm_layer=self.norm_layer))
296
+ inplanes = planes * block.expansion
297
+ for i in range(1, blocks):
298
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
299
+
300
+ return nn.Sequential(*layers)
301
+
302
+ def _make_stage(self, block, num_inchannels,
303
+ num_modules, num_branches, num_blocks, num_channels,
304
+ fuse_method='SUM',
305
+ multi_scale_output=True):
306
+ modules = []
307
+ for i in range(num_modules):
308
+ # multi_scale_output is only used by the last module
309
+ if not multi_scale_output and i == num_modules - 1:
310
+ reset_multi_scale_output = False
311
+ else:
312
+ reset_multi_scale_output = True
313
+ modules.append(
314
+ HighResolutionModule(num_branches,
315
+ block,
316
+ num_blocks,
317
+ num_inchannels,
318
+ num_channels,
319
+ fuse_method,
320
+ reset_multi_scale_output,
321
+ norm_layer=self.norm_layer,
322
+ align_corners=self.align_corners)
323
+ )
324
+ num_inchannels = modules[-1].get_num_inchannels()
325
+
326
+ return nn.Sequential(*modules), num_inchannels
327
+
328
+ def forward(self, x, additional_features=None):
329
+ feats = self.compute_hrnet_feats(x, additional_features)
330
+ if self.ocr_width > 0:
331
+ out_aux = self.aux_head(feats)
332
+ feats = self.conv3x3_ocr(feats)
333
+
334
+ context = self.ocr_gather_head(feats, out_aux)
335
+ feats = self.ocr_distri_head(feats, context)
336
+ out = self.cls_head(feats)
337
+ return [out, out_aux]
338
+ else:
339
+ return [self.cls_head(feats), None]
340
+
341
+ def compute_hrnet_feats(self, x, additional_features):
342
+ x = self.compute_pre_stage_features(x, additional_features)
343
+ x = self.layer1(x)
344
+
345
+ x_list = []
346
+ for i in range(self.stage2_num_branches):
347
+ if self.transition1[i] is not None:
348
+ x_list.append(self.transition1[i](x))
349
+ else:
350
+ x_list.append(x)
351
+ y_list = self.stage2(x_list)
352
+
353
+ x_list = []
354
+ for i in range(self.stage3_num_branches):
355
+ if self.transition2[i] is not None:
356
+ if i < self.stage2_num_branches:
357
+ x_list.append(self.transition2[i](y_list[i]))
358
+ else:
359
+ x_list.append(self.transition2[i](y_list[-1]))
360
+ else:
361
+ x_list.append(y_list[i])
362
+ y_list = self.stage3(x_list)
363
+
364
+ x_list = []
365
+ for i in range(self.stage4_num_branches):
366
+ if self.transition3[i] is not None:
367
+ if i < self.stage3_num_branches:
368
+ x_list.append(self.transition3[i](y_list[i]))
369
+ else:
370
+ x_list.append(self.transition3[i](y_list[-1]))
371
+ else:
372
+ x_list.append(y_list[i])
373
+ x = self.stage4(x_list)
374
+
375
+ return self.aggregate_hrnet_features(x)
376
+
377
+ def compute_pre_stage_features(self, x, additional_features):
378
+ x = self.conv1(x)
379
+ x = self.bn1(x)
380
+ x = self.relu(x)
381
+ if additional_features is not None:
382
+ x = x + additional_features
383
+ x = self.conv2(x)
384
+ x = self.bn2(x)
385
+ return self.relu(x)
386
+
387
+ def aggregate_hrnet_features(self, x):
388
+ # Upsampling
389
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
390
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
391
+ mode='bilinear', align_corners=self.align_corners)
392
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
393
+ mode='bilinear', align_corners=self.align_corners)
394
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
395
+ mode='bilinear', align_corners=self.align_corners)
396
+
397
+ return torch.cat([x[0], x1, x2, x3], 1)
398
+
399
+ def load_pretrained_weights(self, pretrained_path=''):
400
+ model_dict = self.state_dict()
401
+
402
+ if not os.path.exists(pretrained_path):
403
+ print(f'\nFile "{pretrained_path}" does not exist.')
404
+ print('You need to specify the correct path to the pre-trained weights.\n'
405
+ 'You can download the weights for HRNet from the repository:\n'
406
+ 'https://github.com/HRNet/HRNet-Image-Classification')
407
+ exit(1)
408
+ pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
409
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
410
+ pretrained_dict.items()}
411
+
412
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
413
+ if k in model_dict.keys()}
414
+
415
+ model_dict.update(pretrained_dict)
416
+ self.load_state_dict(model_dict)
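The width argument sets the channel count of the highest-resolution branch (doubling for each lower-resolution branch), and ocr_width > 0 enables the OCR head, in which case forward returns the OCR prediction together with an auxiliary output. Below is a minimal smoke-test sketch; the width, num_classes, and input size are illustrative assumptions, not the values used elsewhere in this repository.

import torch
from isegm.model.modeling.hrnet_ocr import HighResolutionNet

# Illustrative configuration only: an HRNet-W18-style backbone with a small OCR head.
net = HighResolutionNet(width=18, num_classes=1, ocr_width=64, small=False)
net.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)
    out, out_aux = net(x)            # OCR output and auxiliary (pre-OCR) output
    print(out.shape, out_aux.shape)  # both at 1/4 of the input resolution (1 x 1 x 64 x 64)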
isegm/model/modeling/ocr.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch._utils
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class SpatialGather_Module(nn.Module):
8
+ """
9
+ Aggregate the context features according to the initial
10
+ predicted probability distribution.
11
+ Employ the soft-weighted method to aggregate the context.
12
+ """
13
+
14
+ def __init__(self, cls_num=0, scale=1):
15
+ super(SpatialGather_Module, self).__init__()
16
+ self.cls_num = cls_num
17
+ self.scale = scale
18
+
19
+ def forward(self, feats, probs):
20
+ batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
21
+ probs = probs.view(batch_size, c, -1)
22
+ feats = feats.view(batch_size, feats.size(1), -1)
23
+ feats = feats.permute(0, 2, 1) # batch x hw x c
24
+ probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25
+ ocr_context = torch.matmul(probs, feats) \
26
+ .permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
27
+ return ocr_context
28
+
29
+
30
+ class SpatialOCR_Module(nn.Module):
31
+ """
32
+ Implementation of the OCR module:
33
+ We aggregate the global object representation to update the representation for each pixel.
34
+ """
35
+
36
+ def __init__(self,
37
+ in_channels,
38
+ key_channels,
39
+ out_channels,
40
+ scale=1,
41
+ dropout=0.1,
42
+ norm_layer=nn.BatchNorm2d,
43
+ align_corners=True):
44
+ super(SpatialOCR_Module, self).__init__()
45
+ self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
46
+ norm_layer, align_corners)
47
+ _in_channels = 2 * in_channels
48
+
49
+ self.conv_bn_dropout = nn.Sequential(
50
+ nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
51
+ nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
52
+ nn.Dropout2d(dropout)
53
+ )
54
+
55
+ def forward(self, feats, proxy_feats):
56
+ context = self.object_context_block(feats, proxy_feats)
57
+
58
+ output = self.conv_bn_dropout(torch.cat([context, feats], 1))
59
+
60
+ return output
61
+
62
+
63
+ class ObjectAttentionBlock2D(nn.Module):
64
+ '''
65
+ The basic implementation for object context block
66
+ Input:
67
+ N X C X H X W
68
+ Parameters:
69
+ in_channels : the dimension of the input feature map
70
+ key_channels : the dimension after the key/query transform
71
+ scale : choose the scale to downsample the input feature maps (save memory cost)
72
+ bn_type : specify the bn type
73
+ Return:
74
+ N X C X H X W
75
+ '''
76
+
77
+ def __init__(self,
78
+ in_channels,
79
+ key_channels,
80
+ scale=1,
81
+ norm_layer=nn.BatchNorm2d,
82
+ align_corners=True):
83
+ super(ObjectAttentionBlock2D, self).__init__()
84
+ self.scale = scale
85
+ self.in_channels = in_channels
86
+ self.key_channels = key_channels
87
+ self.align_corners = align_corners
88
+
89
+ self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
90
+ self.f_pixel = nn.Sequential(
91
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
92
+ kernel_size=1, stride=1, padding=0, bias=False),
93
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
94
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
95
+ kernel_size=1, stride=1, padding=0, bias=False),
96
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
97
+ )
98
+ self.f_object = nn.Sequential(
99
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
100
+ kernel_size=1, stride=1, padding=0, bias=False),
101
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
102
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
103
+ kernel_size=1, stride=1, padding=0, bias=False),
104
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
105
+ )
106
+ self.f_down = nn.Sequential(
107
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
108
+ kernel_size=1, stride=1, padding=0, bias=False),
109
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
110
+ )
111
+ self.f_up = nn.Sequential(
112
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
113
+ kernel_size=1, stride=1, padding=0, bias=False),
114
+ nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
115
+ )
116
+
117
+ def forward(self, x, proxy):
118
+ batch_size, h, w = x.size(0), x.size(2), x.size(3)
119
+ if self.scale > 1:
120
+ x = self.pool(x)
121
+
122
+ query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
123
+ query = query.permute(0, 2, 1)
124
+ key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
125
+ value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
126
+ value = value.permute(0, 2, 1)
127
+
128
+ sim_map = torch.matmul(query, key)
129
+ sim_map = (self.key_channels ** -.5) * sim_map
130
+ sim_map = F.softmax(sim_map, dim=-1)
131
+
132
+ # aggregate the object context for every pixel
133
+ context = torch.matmul(sim_map, value)
134
+ context = context.permute(0, 2, 1).contiguous()
135
+ context = context.view(batch_size, self.key_channels, *x.size()[2:])
136
+ context = self.f_up(context)
137
+ if self.scale > 1:
138
+ context = F.interpolate(input=context, size=(h, w),
139
+ mode='bilinear', align_corners=self.align_corners)
140
+
141
+ return context
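To make the data flow concrete: SpatialGather_Module softmax-weights the pixel features by the coarse class probabilities to form one context vector per class, and SpatialOCR_Module lets every pixel attend over those vectors before fusing the result back with the original features. A shape-check sketch follows; all sizes are chosen arbitrarily for illustration.

import torch
from isegm.model.modeling.ocr import SpatialGather_Module, SpatialOCR_Module

# Arbitrary sizes: batch 2, 512-channel features, 8 classes, 32x32 grid.
feats = torch.randn(2, 512, 32, 32)    # pixel features, N x C x H x W
coarse = torch.randn(2, 8, 32, 32)     # coarse per-class logits, N x K x H x W

gather = SpatialGather_Module(cls_num=8)
context = gather(feats, coarse)        # N x C x K x 1 object-region representations

ocr = SpatialOCR_Module(in_channels=512, key_channels=256, out_channels=512)
out = ocr(feats, context)              # N x 512 x 32 x 32 object-contextual features
print(context.shape, out.shape)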
isegm/model/modeling/resnet.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
3
+
4
+
5
+ class ResNetBackbone(torch.nn.Module):
6
+ def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
7
+ super(ResNetBackbone, self).__init__()
8
+
9
+ if backbone == 'resnet34':
10
+ pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
11
+ elif backbone == 'resnet50':
12
+ pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
13
+ elif backbone == 'resnet101':
14
+ pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
15
+ elif backbone == 'resnet152':
16
+ pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
17
+ else:
18
+ raise RuntimeError(f'unknown backbone: {backbone}')
19
+
20
+ self.conv1 = pretrained.conv1
21
+ self.bn1 = pretrained.bn1
22
+ self.relu = pretrained.relu
23
+ self.maxpool = pretrained.maxpool
24
+ self.layer1 = pretrained.layer1
25
+ self.layer2 = pretrained.layer2
26
+ self.layer3 = pretrained.layer3
27
+ self.layer4 = pretrained.layer4
28
+
29
+ def forward(self, x, additional_features=None):
30
+ x = self.conv1(x)
31
+ x = self.bn1(x)
32
+ x = self.relu(x)
33
+ if additional_features is not None:
34
+ x = x + torch.nn.functional.pad(additional_features,
35
+ [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
36
+ mode='constant', value=0)
37
+ x = self.maxpool(x)
38
+ c1 = self.layer1(x)
39
+ c2 = self.layer2(c1)
40
+ c3 = self.layer3(c2)
41
+ c4 = self.layer4(c3)
42
+
43
+ return c1, c2, c3, c4
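ResNetBackbone exposes the four stage outputs of a (by default dilated) ResNet and, if additional_features is given, zero-pads it along the channel dimension and adds it to the stem output. The sketch below uses pretrained_base=False so nothing is downloaded; the input and extra-feature shapes are illustrative assumptions.

import torch
from isegm.model.modeling.resnet import ResNetBackbone

backbone = ResNetBackbone(backbone='resnet50', pretrained_base=False, dilated=True)
backbone.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    extra = torch.randn(1, 1, 112, 112)   # hypothetical 1-channel map at the stem's stride-2 resolution
    c1, c2, c3, c4 = backbone(x, additional_features=extra)
    # With dilated=True the last two stages keep stride 8:
    # c1: 1x256x56x56, c2: 1x512x28x28, c3: 1x1024x28x28, c4: 1x2048x28x28
    print(c1.shape, c2.shape, c3.shape, c4.shape)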
isegm/model/modeling/resnetv1b.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
4
+
5
+
6
+ class BasicBlockV1b(nn.Module):
7
+ expansion = 1
8
+
9
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
10
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
11
+ super(BasicBlockV1b, self).__init__()
12
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
13
+ padding=dilation, dilation=dilation, bias=False)
14
+ self.bn1 = norm_layer(planes)
15
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
16
+ padding=previous_dilation, dilation=previous_dilation, bias=False)
17
+ self.bn2 = norm_layer(planes)
18
+
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.downsample = downsample
21
+ self.stride = stride
22
+
23
+ def forward(self, x):
24
+ residual = x
25
+
26
+ out = self.conv1(x)
27
+ out = self.bn1(out)
28
+ out = self.relu(out)
29
+
30
+ out = self.conv2(out)
31
+ out = self.bn2(out)
32
+
33
+ if self.downsample is not None:
34
+ residual = self.downsample(x)
35
+
36
+ out = out + residual
37
+ out = self.relu(out)
38
+
39
+ return out
40
+
41
+
42
+ class BottleneckV1b(nn.Module):
43
+ expansion = 4
44
+
45
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
46
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
47
+ super(BottleneckV1b, self).__init__()
48
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49
+ self.bn1 = norm_layer(planes)
50
+
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52
+ padding=dilation, dilation=dilation, bias=False)
53
+ self.bn2 = norm_layer(planes)
54
+
55
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
56
+ self.bn3 = norm_layer(planes * self.expansion)
57
+
58
+ self.relu = nn.ReLU(inplace=True)
59
+ self.downsample = downsample
60
+ self.stride = stride
61
+
62
+ def forward(self, x):
63
+ residual = x
64
+
65
+ out = self.conv1(x)
66
+ out = self.bn1(out)
67
+ out = self.relu(out)
68
+
69
+ out = self.conv2(out)
70
+ out = self.bn2(out)
71
+ out = self.relu(out)
72
+
73
+ out = self.conv3(out)
74
+ out = self.bn3(out)
75
+
76
+ if self.downsample is not None:
77
+ residual = self.downsample(x)
78
+
79
+ out = out + residual
80
+ out = self.relu(out)
81
+
82
+ return out
83
+
84
+
85
+ class ResNetV1b(nn.Module):
86
+ """ Pre-trained ResNetV1b model, which produces stride-8 feature maps at conv5 (when dilated).
87
+
88
+ Parameters
89
+ ----------
90
+ block : Block
91
+ Class for the residual block. Options are BasicBlockV1, BottleneckV1.
92
+ layers : list of int
93
+ Number of residual blocks in each of the four stages.
94
+ classes : int, default 1000
95
+ Number of classification classes.
96
+ dilated : bool, default True
97
+ Whether to apply the dilation strategy to the ResNet, yielding a stride-8 model
98
+ typically used in Semantic Segmentation.
99
+ norm_layer : object
100
+ Normalization layer used (default: :class:`nn.BatchNorm2d`)
101
+ deep_stem : bool, default False
102
+ Whether to replace the 7x7 conv1 with three 3x3 convolution layers.
103
+ avg_down : bool, default False
104
+ Whether to use average pooling in the projection shortcut (downsample) between stages.
105
+ final_drop : float, default 0.0
106
+ Dropout ratio before the final classification layer.
107
+
108
+ Reference:
109
+ - He, Kaiming, et al. "Deep residual learning for image recognition."
110
+ Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
111
+
112
+ - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
113
+ """
114
+ def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
115
+ avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
116
+ self.inplanes = stem_width*2 if deep_stem else 64
117
+ super(ResNetV1b, self).__init__()
118
+ if not deep_stem:
119
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
120
+ else:
121
+ self.conv1 = nn.Sequential(
122
+ nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
123
+ norm_layer(stem_width),
124
+ nn.ReLU(True),
125
+ nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
126
+ norm_layer(stem_width),
127
+ nn.ReLU(True),
128
+ nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
129
+ )
130
+ self.bn1 = norm_layer(self.inplanes)
131
+ self.relu = nn.ReLU(True)
132
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
133
+ self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
134
+ norm_layer=norm_layer)
135
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
136
+ norm_layer=norm_layer)
137
+ if dilated:
138
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
139
+ avg_down=avg_down, norm_layer=norm_layer)
140
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
141
+ avg_down=avg_down, norm_layer=norm_layer)
142
+ else:
143
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144
+ avg_down=avg_down, norm_layer=norm_layer)
145
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146
+ avg_down=avg_down, norm_layer=norm_layer)
147
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148
+ self.drop = None
149
+ if final_drop > 0.0:
150
+ self.drop = nn.Dropout(final_drop)
151
+ self.fc = nn.Linear(512 * block.expansion, classes)
152
+
153
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
154
+ avg_down=False, norm_layer=nn.BatchNorm2d):
155
+ downsample = None
156
+ if stride != 1 or self.inplanes != planes * block.expansion:
157
+ downsample = []
158
+ if avg_down:
159
+ if dilation == 1:
160
+ downsample.append(
161
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
162
+ )
163
+ else:
164
+ downsample.append(
165
+ nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
166
+ )
167
+ downsample.extend([
168
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
169
+ kernel_size=1, stride=1, bias=False),
170
+ norm_layer(planes * block.expansion)
171
+ ])
172
+ downsample = nn.Sequential(*downsample)
173
+ else:
174
+ downsample = nn.Sequential(
175
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
176
+ kernel_size=1, stride=stride, bias=False),
177
+ norm_layer(planes * block.expansion)
178
+ )
179
+
180
+ layers = []
181
+ if dilation in (1, 2):
182
+ layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
183
+ previous_dilation=dilation, norm_layer=norm_layer))
184
+ elif dilation == 4:
185
+ layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
186
+ previous_dilation=dilation, norm_layer=norm_layer))
187
+ else:
188
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
189
+
190
+ self.inplanes = planes * block.expansion
191
+ for _ in range(1, blocks):
192
+ layers.append(block(self.inplanes, planes, dilation=dilation,
193
+ previous_dilation=dilation, norm_layer=norm_layer))
194
+
195
+ return nn.Sequential(*layers)
196
+
197
+ def forward(self, x):
198
+ x = self.conv1(x)
199
+ x = self.bn1(x)
200
+ x = self.relu(x)
201
+ x = self.maxpool(x)
202
+
203
+ x = self.layer1(x)
204
+ x = self.layer2(x)
205
+ x = self.layer3(x)
206
+ x = self.layer4(x)
207
+
208
+ x = self.avgpool(x)
209
+ x = x.view(x.size(0), -1)
210
+ if self.drop is not None:
211
+ x = self.drop(x)
212
+ x = self.fc(x)
213
+
214
+ return x
215
+
216
+
217
+ def _safe_state_dict_filtering(orig_dict, model_dict_keys):
218
+ filtered_orig_dict = {}
219
+ for k, v in orig_dict.items():
220
+ if k in model_dict_keys:
221
+ filtered_orig_dict[k] = v
222
+ else:
223
+ print(f"[ERROR] Failed to load <{k}> in backbone")
224
+ return filtered_orig_dict
225
+
226
+
227
+ def resnet34_v1b(pretrained=False, **kwargs):
228
+ model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
229
+ if pretrained:
230
+ model_dict = model.state_dict()
231
+ filtered_orig_dict = _safe_state_dict_filtering(
232
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
233
+ model_dict.keys()
234
+ )
235
+ model_dict.update(filtered_orig_dict)
236
+ model.load_state_dict(model_dict)
237
+ return model
238
+
239
+
240
+ def resnet50_v1s(pretrained=False, **kwargs):
241
+ model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
242
+ if pretrained:
243
+ model_dict = model.state_dict()
244
+ filtered_orig_dict = _safe_state_dict_filtering(
245
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
246
+ model_dict.keys()
247
+ )
248
+ model_dict.update(filtered_orig_dict)
249
+ model.load_state_dict(model_dict)
250
+ return model
251
+
252
+
253
+ def resnet101_v1s(pretrained=False, **kwargs):
254
+ model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
255
+ if pretrained:
256
+ model_dict = model.state_dict()
257
+ filtered_orig_dict = _safe_state_dict_filtering(
258
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
259
+ model_dict.keys()
260
+ )
261
+ model_dict.update(filtered_orig_dict)
262
+ model.load_state_dict(model_dict)
263
+ return model
264
+
265
+
266
+ def resnet152_v1s(pretrained=False, **kwargs):
267
+ model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
268
+ if pretrained:
269
+ model_dict = model.state_dict()
270
+ filtered_orig_dict = _safe_state_dict_filtering(
271
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
272
+ model_dict.keys()
273
+ )
274
+ model_dict.update(filtered_orig_dict)
275
+ model.load_state_dict(model_dict)
276
+ return model
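The resnet*_v1s/v1b helpers optionally pull GluonCV-converted weights via torch.hub and filter out any mismatched keys before loading. A minimal sketch that builds the classifier variant without downloading anything; the class count and input size are arbitrary illustrative choices.

import torch
from isegm.model.modeling.resnetv1b import ResNetV1b, BottleneckV1b

# Randomly initialised ResNet-50-v1s classifier (deep stem, no dilation), 10 classes for illustration.
net = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], classes=10,
                dilated=False, deep_stem=True, stem_width=64)
net.eval()

with torch.no_grad():
    logits = net(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # torch.Size([2, 10])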
isegm/model/modifiers.py ADDED
@@ -0,0 +1,11 @@
1
+
2
+
3
+ class LRMult(object):
4
+ def __init__(self, lr_mult=1.):
5
+ self.lr_mult = lr_mult
6
+
7
+ def __call__(self, m):
8
+ if getattr(m, 'weight', None) is not None:
9
+ m.weight.lr_mult = self.lr_mult
10
+ if getattr(m, 'bias', None) is not None:
11
+ m.bias.lr_mult = self.lr_mult
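LRMult is designed to be passed to nn.Module.apply: it walks the module tree and tags every weight and bias parameter with an lr_mult attribute, which can later be read when building optimizer parameter groups (e.g. to train a pretrained backbone with a smaller learning rate). A short sketch; the toy module and the 0.1 multiplier are illustrative.

import torch.nn as nn
from isegm.model.modifiers import LRMult

backbone = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
backbone.apply(LRMult(0.1))   # tags conv/bn weights and biases with lr_mult = 0.1

for name, param in backbone.named_parameters():
    print(name, getattr(param, 'lr_mult', None))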