Init the space
Note: this view is limited to 50 files because the commit contains too many changes.
- .gitignore +2 -0
- Makefile +8 -0
- app.py +83 -2
- isegm/data/base.py +99 -0
- isegm/data/compose.py +39 -0
- isegm/data/datasets/__init__.py +12 -0
- isegm/data/datasets/ade20k.py +55 -0
- isegm/data/datasets/berkeley.py +6 -0
- isegm/data/datasets/coco.py +74 -0
- isegm/data/datasets/coco_lvis.py +67 -0
- isegm/data/datasets/davis.py +33 -0
- isegm/data/datasets/grabcut.py +34 -0
- isegm/data/datasets/images_dir.py +59 -0
- isegm/data/datasets/lvis.py +97 -0
- isegm/data/datasets/openimages.py +58 -0
- isegm/data/datasets/pascalvoc.py +48 -0
- isegm/data/datasets/sbd.py +111 -0
- isegm/data/points_sampler.py +305 -0
- isegm/data/sample.py +148 -0
- isegm/data/transforms.py +178 -0
- isegm/engine/optimizer.py +27 -0
- isegm/engine/trainer.py +413 -0
- isegm/inference/__init__.py +0 -0
- isegm/inference/clicker.py +118 -0
- isegm/inference/evaluation.py +56 -0
- isegm/inference/predictors/__init__.py +98 -0
- isegm/inference/predictors/base.py +126 -0
- isegm/inference/predictors/brs.py +307 -0
- isegm/inference/predictors/brs_functors.py +109 -0
- isegm/inference/predictors/brs_losses.py +58 -0
- isegm/inference/transforms/__init__.py +5 -0
- isegm/inference/transforms/base.py +38 -0
- isegm/inference/transforms/crops.py +97 -0
- isegm/inference/transforms/flip.py +37 -0
- isegm/inference/transforms/limit_longest_side.py +22 -0
- isegm/inference/transforms/zoom_in.py +175 -0
- isegm/inference/utils.py +143 -0
- isegm/model/initializer.py +105 -0
- isegm/model/is_deeplab_model.py +25 -0
- isegm/model/is_hrnet_model.py +26 -0
- isegm/model/is_model.py +141 -0
- isegm/model/losses.py +161 -0
- isegm/model/metrics.py +101 -0
- isegm/model/modeling/basic_blocks.py +71 -0
- isegm/model/modeling/deeplab_v3.py +176 -0
- isegm/model/modeling/hrnet_ocr.py +416 -0
- isegm/model/modeling/ocr.py +141 -0
- isegm/model/modeling/resnet.py +43 -0
- isegm/model/modeling/resnetv1b.py +276 -0
- isegm/model/modifiers.py +11 -0
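
Taken together, these files add the RITM interactive-segmentation backend under isegm/ and a Streamlit front end in app.py. The front end only needs a handful of calls from isegm.inference; the sketch below shows that flow outside Streamlit. It is a minimal sketch, assuming the ritm_coco_lvis_h18_itermask.pth checkpoint has already been downloaded next to the script and that test.jpg is a placeholder name for any RGB image.

import numpy as np
import torch
from PIL import Image

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

# Load the checkpoint on CPU and build a predictor without BRS refinement,
# mirroring what app.py does.
device = torch.device("cpu")
model = utils.load_is_model("ritm_coco_lvis_h18_itermask.pth", device, cpu_dist_maps=True)
predictor = get_predictor(model, device=device, brs_mode="NoBRS")

# One positive click at (row, col) in image coordinates.
clicker = ck.Clicker()
clicker.add_click(ck.Click(is_positive=True, coords=(120, 200)))

image = np.array(Image.open("test.jpg"))  # placeholder input image
predictor.set_input_image(image)
prob_map = predictor.get_prediction(clicker, prev_mask=None)
mask = prob_map > 0.5  # same default threshold as the app's slider
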
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__
*.pth
Makefile
ADDED
@@ -0,0 +1,8 @@
PYTHON=3.9
BASENAME=$(shell basename $(CURDIR))

env:
	conda create -n $(BASENAME) python=$(PYTHON)

setup:
	pip install -r requirements.txt
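
The two targets are meant to be run in order: make env creates a conda environment named after the repository directory, and make setup installs the dependencies from requirements.txt into whichever environment is currently active, so the environment should be activated (conda activate <directory name>) between the two steps.
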
app.py
CHANGED
@@ -1,4 +1,85 @@
import streamlit as st
import torch
import numpy as np
import cv2
import wget
import os

from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

# Model Path
prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
models = {
    "RITM": "ritm_coco_lvis_h18_itermask.pth",
}

# Items in the sidebar.
model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])

# Objects for prediction.
clicker = ck.Clicker()
device = torch.device("cpu")
predictor = None
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(models[model]):
        _ = wget.download(f"{prefix}/{models[model]}")

with st.spinner("Wait for loading a model..."):
    model = utils.load_is_model(models[model], device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)

# Create a canvas component.
image = None
if image_path:
    image = Image.open(image_path)
canvas_height, canvas_width = 600, 600
pos_color, neg_color = "#3498DB", "#C70039"
st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=pos_color if marking_type == "positive" else neg_color,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=canvas_width,
    height=canvas_height,
)

# Check the user inputs and execute predictions.
st.title("Prediction:")
if canvas_result.json_data and canvas_result.json_data["objects"] and image:
    objects = canvas_result.json_data["objects"]
    image_width, image_height = image.size
    ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width

    err_x, err_y = 5.5, 1.0
    pos_clicks, neg_clicks = [], []
    for click in objects:
        x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
        x, y = min(image_width, max(0, x)), min(image_height, max(0, y))

        is_positive = click["stroke"] == pos_color
        click = ck.Click(is_positive=is_positive, coords=(y, x))
        clicker.add_click(click)

    # prediction.
    pred = None
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=None)
        pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
        pred = np.where(pred > threshold, 1.0, 0)
    st.image(pred, caption="")
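
With the checkpoint cached next to app.py, the demo can be run locally with streamlit run app.py: the sidebar selects the model, threshold, and click polarity, each point drawn on the canvas is mapped back to image coordinates and fed to the predictor as a positive or negative click, and the thresholded mask is displayed under "Prediction:".
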
isegm/data/base.py
ADDED
@@ -0,0 +1,99 @@
import random
import pickle
import numpy as np
import torch
from torchvision import transforms
from .points_sampler import MultiPointSampler
from .sample import DSample


class ISDataset(torch.utils.data.dataset.Dataset):
    def __init__(self,
                 augmentator=None,
                 points_sampler=MultiPointSampler(max_num_points=12),
                 min_object_area=0,
                 keep_background_prob=0.0,
                 with_image_info=False,
                 samples_scores_path=None,
                 samples_scores_gamma=1.0,
                 epoch_len=-1):
        super(ISDataset, self).__init__()
        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma)
        self.to_tensor = transforms.ToTensor()

        self.dataset_samples = None

    def __getitem__(self, index):
        if self.samples_precomputed_scores is not None:
            index = np.random.choice(self.samples_precomputed_scores['indices'],
                                     p=self.samples_precomputed_scores['probs'])
        else:
            if self.epoch_len > 0:
                index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            'images': self.to_tensor(sample.image),
            'points': points.astype(np.float32),
            'instances': mask
        }

        if self.with_image_info:
            output['image_info'] = sample.sample_id

        return output

    def augment_sample(self, sample) -> DSample:
        if self.augmentator is None:
            return sample

        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            keep_sample = (self.keep_background_prob < 0.0 or
                           random.random() < self.keep_background_prob)
            valid_augmentation = len(sample) > 0 or keep_sample

        return sample

    def get_sample(self, index) -> DSample:
        raise NotImplementedError

    def __len__(self):
        if self.epoch_len > 0:
            return self.epoch_len
        else:
            return self.get_samples_number()

    def get_samples_number(self):
        return len(self.dataset_samples)

    @staticmethod
    def _load_samples_scores(samples_scores_path, samples_scores_gamma):
        if samples_scores_path is None:
            return None

        with open(samples_scores_path, 'rb') as f:
            images_scores = pickle.load(f)

        probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
        probs /= probs.sum()
        samples_scores = {
            'indices': [x[0] for x in images_scores],
            'probs': probs
        }
        print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}')
        return samples_scores
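
ISDataset is a regular torch Dataset, so every concrete reader added below plugs straight into a DataLoader; __getitem__ returns the image tensor, the sampled clicks, and the target mask. A minimal sketch, assuming a GrabCut-style dataset unpacked at ./datasets/GrabCut (a hypothetical path); batch_size stays at 1 because no resizing augmentator is configured here, so image sizes differ between samples.

from torch.utils.data import DataLoader

from isegm.data.datasets import GrabCutDataset
from isegm.data.points_sampler import MultiPointSampler

dataset = GrabCutDataset(
    "./datasets/GrabCut",  # hypothetical dataset location
    points_sampler=MultiPointSampler(max_num_points=24),
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
# 'images': (1, 3, H, W) image tensor, 'points': (1, 48, 3) clicks
# (24 positive + 24 negative slots, padded with (-1, -1, -1)), 'instances': (1, 1, H, W) mask.
print(batch['images'].shape, batch['points'].shape, batch['instances'].shape)
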
isegm/data/compose.py
ADDED
@@ -0,0 +1,39 @@
import numpy as np
from math import isclose
from .base import ISDataset


class ComposeDataset(ISDataset):
    def __init__(self, datasets, **kwargs):
        super(ComposeDataset, self).__init__(**kwargs)

        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])

    def get_sample(self, index):
        dataset_indx, sample_indx = self.dataset_samples[index]
        return self._datasets[dataset_indx].get_sample(sample_indx)


class ProportionalComposeDataset(ISDataset):
    def __init__(self, datasets, ratios, **kwargs):
        super().__init__(**kwargs)

        assert len(ratios) == len(datasets),\
            "The number of datasets must match the number of ratios"
        assert isclose(sum(ratios), 1.0),\
            "The sum of ratios must be equal to 1"

        self._ratios = ratios
        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])

    def get_sample(self, index):
        dataset_indx = np.random.choice(len(self._datasets), p=self._ratios)
        sample_indx = np.random.choice(len(self._datasets[dataset_indx]))

        return self._datasets[dataset_indx].get_sample(sample_indx)
isegm/data/datasets/__init__.py
ADDED
@@ -0,0 +1,12 @@
from isegm.data.compose import ComposeDataset, ProportionalComposeDataset
from .berkeley import BerkeleyDataset
from .coco import CocoDataset
from .davis import DavisDataset
from .grabcut import GrabCutDataset
from .coco_lvis import CocoLvisDataset
from .lvis import LvisDataset
from .openimages import OpenImagesDataset
from .sbd import SBDDataset, SBDEvaluationDataset
from .images_dir import ImagesDirDataset
from .ade20k import ADE20kDataset
from .pascalvoc import PascalVocDataset
isegm/data/datasets/ade20k.py
ADDED
@@ -0,0 +1,55 @@
import os
import random
import pickle as pkl
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample
from isegm.utils.misc import get_labels_with_sizes


class ADE20kDataset(ISDataset):
    def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
        super().__init__(**kwargs)
        assert split in {'train', 'val'}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self.dataset_split_folder = 'training' if split == 'train' else 'validation'
        self.stuff_prob = stuff_prob

        anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl'
        if os.path.exists(anno_path):
            with anno_path.open('rb') as f:
                annotations = pkl.load(f)
        else:
            raise RuntimeError(f"Can't find annotations at {anno_path}")
        self.annotations = annotations
        self.dataset_samples = list(annotations.keys())

    def get_sample(self, index) -> DSample:
        image_id = self.dataset_samples[index]
        sample_annos = self.annotations[image_id]

        image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg')
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # select random mask for an image
        layer = random.choice(sample_annos['layers'])
        mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name'])
        instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0]  # the B channel holds instances
        instances_mask = instances_mask.astype(np.int32)
        object_ids, _ = get_labels_with_sizes(instances_mask)

        if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob):
            # remove stuff objects
            for i, object_id in enumerate(object_ids):
                if i in layer['stuff_instances']:
                    instances_mask[instances_mask == object_id] = 0
            object_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)
isegm/data/datasets/berkeley.py
ADDED
@@ -0,0 +1,6 @@
from .grabcut import GrabCutDataset


class BerkeleyDataset(GrabCutDataset):
    def __init__(self, dataset_path, **kwargs):
        super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs)
isegm/data/datasets/coco.py
ADDED
@@ -0,0 +1,74 @@
import cv2
import json
import random
import numpy as np
from pathlib import Path
from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class CocoDataset(ISDataset):
    def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
        super(CocoDataset, self).__init__(**kwargs)
        self.split = split
        self.dataset_path = Path(dataset_path)
        self.stuff_prob = stuff_prob

        self.load_samples()

    def load_samples(self):
        annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json'
        self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}'
        self.images_path = self.dataset_path / self.split

        with open(annotation_path, 'r') as f:
            annotation = json.load(f)

        self.dataset_samples = annotation['annotations']

        self._categories = annotation['categories']
        self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0]
        self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1]
        self._things_labels_set = set(self._things_labels)
        self._stuff_labels_set = set(self._stuff_labels)

    def get_sample(self, index) -> DSample:
        dataset_sample = self.dataset_samples[index]

        image_path = self.images_path / self.get_image_name(dataset_sample['file_name'])
        label_path = self.labels_path / dataset_sample['file_name']

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32)
        label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2]

        instance_map = np.full_like(label, 0)
        things_ids = []
        stuff_ids = []

        for segment in dataset_sample['segments_info']:
            class_id = segment['category_id']
            obj_id = segment['id']
            if class_id in self._things_labels_set:
                if segment['iscrowd'] == 1:
                    continue
                things_ids.append(obj_id)
            else:
                stuff_ids.append(obj_id)

            instance_map[label == obj_id] = obj_id

        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            instances_ids = things_ids + stuff_ids
        else:
            instances_ids = things_ids

            for stuff_id in stuff_ids:
                instance_map[instance_map == stuff_id] = 0

        return DSample(image, instance_map, objects_ids=instances_ids)

    @classmethod
    def get_image_name(cls, panoptic_name):
        return panoptic_name.replace('.png', '.jpg')
isegm/data/datasets/coco_lvis.py
ADDED
@@ -0,0 +1,67 @@
from pathlib import Path
import pickle
import random
import numpy as np
import json
import cv2
from copy import deepcopy
from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class CocoLvisDataset(ISDataset):
    def __init__(self, dataset_path, split='train', stuff_prob=0.0,
                 allow_list_name=None, anno_file='hannotation.pickle', **kwargs):
        super(CocoLvisDataset, self).__init__(**kwargs)
        dataset_path = Path(dataset_path)
        self._split_path = dataset_path / split
        self.split = split
        self._images_path = self._split_path / 'images'
        self._masks_path = self._split_path / 'masks'
        self.stuff_prob = stuff_prob

        with open(self._split_path / anno_file, 'rb') as f:
            self.dataset_samples = sorted(pickle.load(f).items())

        if allow_list_name is not None:
            allow_list_path = self._split_path / allow_list_name
            with open(allow_list_path, 'r') as f:
                allow_images_ids = json.load(f)
                allow_images_ids = set(allow_images_ids)

            self.dataset_samples = [sample for sample in self.dataset_samples
                                    if sample[0] in allow_images_ids]

    def get_sample(self, index) -> DSample:
        image_id, sample = self.dataset_samples[index]
        image_path = self._images_path / f'{image_id}.jpg'

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        packed_masks_path = self._masks_path / f'{image_id}.pickle'
        with open(packed_masks_path, 'rb') as f:
            encoded_layers, objs_mapping = pickle.load(f)
        layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
        layers = np.stack(layers, axis=2)

        instances_info = deepcopy(sample['hierarchy'])
        for inst_id, inst_info in list(instances_info.items()):
            if inst_info is None:
                inst_info = {'children': [], 'parent': None, 'node_level': 0}
                instances_info[inst_id] = inst_info
            inst_info['mapping'] = objs_mapping[inst_id]

        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
                instances_info[inst_id] = {
                    'mapping': objs_mapping[inst_id],
                    'parent': None,
                    'children': []
                }
        else:
            for inst_id in range(sample['num_instance_masks'], len(objs_mapping)):
                layer_indx, mask_id = objs_mapping[inst_id]
                layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0

        return DSample(image, layers, objects=instances_info)
isegm/data/datasets/davis.py
ADDED
@@ -0,0 +1,33 @@
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class DavisDataset(ISDataset):
    def __init__(self, dataset_path,
                 images_dir_name='img', masks_dir_name='gt',
                 **kwargs):
        super(DavisDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}

    def get_sample(self, index) -> DSample:
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)
        mask_path = str(self._masks_paths[image_name.split('.')[0]])

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2)
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)
isegm/data/datasets/grabcut.py
ADDED
@@ -0,0 +1,34 @@
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class GrabCutDataset(ISDataset):
    def __init__(self, dataset_path,
                 images_dir_name='data_GT', masks_dir_name='boundary_GT',
                 **kwargs):
        super(GrabCutDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')}

    def get_sample(self, index) -> DSample:
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)
        mask_path = str(self._masks_paths[image_name.split('.')[0]])

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32)
        instances_mask[instances_mask == 128] = -1
        instances_mask[instances_mask > 128] = 1

        return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index)
isegm/data/datasets/images_dir.py
ADDED
@@ -0,0 +1,59 @@
import cv2
import numpy as np
from pathlib import Path

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class ImagesDirDataset(ISDataset):
    def __init__(self, dataset_path,
                 images_dir_name='images', masks_dir_name='masks',
                 **kwargs):
        super(ImagesDirDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        images_list = [x for x in sorted(self._images_path.glob('*.*'))]

        samples = {x.stem: {'image': x, 'masks': []} for x in images_list}
        for mask_path in self._insts_path.glob('*.*'):
            mask_name = mask_path.stem
            if mask_name in samples:
                samples[mask_name]['masks'].append(mask_path)
                continue

            mask_name_split = mask_name.split('_')
            if mask_name_split[-1].isdigit():
                mask_name = '_'.join(mask_name_split[:-1])
                assert mask_name in samples
                samples[mask_name]['masks'].append(mask_path)

        for x in samples.values():
            assert len(x['masks']) > 0, x['image']

        self.dataset_samples = [v for k, v in sorted(samples.items())]

    def get_sample(self, index) -> DSample:
        sample = self.dataset_samples[index]
        image_path = str(sample['image'])

        objects = []
        ignored_regions = []
        masks = []
        for indx, mask_path in enumerate(sample['masks']):
            gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32)
            instances_mask = np.zeros_like(gt_mask)
            instances_mask[gt_mask == 128] = 2
            instances_mask[gt_mask > 128] = 1
            masks.append(instances_mask)
            objects.append((indx, 1))
            ignored_regions.append((indx, 2))

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return DSample(image, np.stack(masks, axis=2),
                       objects_ids=objects, ignore_ids=ignored_regions, sample_id=index)
isegm/data/datasets/lvis.py
ADDED
@@ -0,0 +1,97 @@
import json
import random
from collections import defaultdict
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class LvisDataset(ISDataset):
    def __init__(self, dataset_path, split='train',
                 max_overlap_ratio=0.5,
                 **kwargs):
        super(LvisDataset, self).__init__(**kwargs)
        dataset_path = Path(dataset_path)
        train_categories_path = dataset_path / 'train_categories.json'
        self._train_path = dataset_path / 'train'
        self._val_path = dataset_path / 'val'

        self.split = split
        self.max_overlap_ratio = max_overlap_ratio

        with open(dataset_path / split / f'lvis_{self.split}.json', 'r') as f:
            json_annotation = json.loads(f.read())

        self.annotations = defaultdict(list)
        for x in json_annotation['annotations']:
            self.annotations[x['image_id']].append(x)

        if not train_categories_path.exists():
            self.generate_train_categories(dataset_path, train_categories_path)
        self.dataset_samples = [x for x in json_annotation['images']
                                if len(self.annotations[x['id']]) > 0]

    def get_sample(self, index) -> DSample:
        image_info = self.dataset_samples[index]
        image_id, image_url = image_info['id'], image_info['coco_url']
        image_filename = image_url.split('/')[-1]
        image_annotations = self.annotations[image_id]
        random.shuffle(image_annotations)

        # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017)
        if 'train2017' in image_url:
            image_path = self._train_path / 'images' / image_filename
        else:
            image_path = self._val_path / 'images' / image_filename
        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        instances_mask = None
        instances_area = defaultdict(int)
        objects_ids = []
        for indx, obj_annotation in enumerate(image_annotations):
            mask = self.get_mask_from_polygon(obj_annotation, image)
            object_mask = mask > 0
            object_area = object_mask.sum()

            if instances_mask is None:
                instances_mask = np.zeros_like(object_mask, dtype=np.int32)

            overlap_ids = np.bincount(instances_mask[object_mask].flatten())
            overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids)
                             if overlap_area > 0 and inst_id > 0]
            overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area
            if overlap_areas:
                overlap_ratio = max(overlap_ratio, max(overlap_areas))
            if overlap_ratio > self.max_overlap_ratio:
                continue

            instance_id = indx + 1
            instances_mask[object_mask] = instance_id
            instances_area[instance_id] = object_area
            objects_ids.append(instance_id)

        return DSample(image, instances_mask, objects_ids=objects_ids)

    @staticmethod
    def get_mask_from_polygon(annotation, image):
        mask = np.zeros(image.shape[:2], dtype=np.int32)
        for contour_points in annotation['segmentation']:
            contour_points = np.array(contour_points).reshape((-1, 2))
            contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
            cv2.fillPoly(mask, contour_points, 1)

        return mask

    @staticmethod
    def generate_train_categories(dataset_path, train_categories_path):
        with open(dataset_path / 'train/lvis_train.json', 'r') as f:
            annotation = json.load(f)

        with open(train_categories_path, 'w') as f:
            json.dump(annotation['categories'], f, indent=1)
isegm/data/datasets/openimages.py
ADDED
@@ -0,0 +1,58 @@
import os
import random
import pickle as pkl
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class OpenImagesDataset(ISDataset):
    def __init__(self, dataset_path, split='train', **kwargs):
        super().__init__(**kwargs)
        assert split in {'train', 'val', 'test'}

        self.dataset_path = Path(dataset_path)
        self._split_path = self.dataset_path / split
        self._images_path = self._split_path / 'images'
        self._masks_path = self._split_path / 'masks'
        self.dataset_split = split

        clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl'
        if os.path.exists(clean_anno_path):
            with clean_anno_path.open('rb') as f:
                annotations = pkl.load(f)
        else:
            raise RuntimeError(f"Can't find annotations at {clean_anno_path}")
        self.image_id_to_masks = annotations['image_id_to_masks']
        self.dataset_samples = annotations['dataset_samples']

    def get_sample(self, index) -> DSample:
        image_id = self.dataset_samples[index]

        image_path = str(self._images_path / f'{image_id}.jpg')
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_paths = self.image_id_to_masks[image_id]
        # select random mask for an image
        mask_path = str(self._masks_path / random.choice(mask_paths))
        instances_mask = cv2.imread(mask_path)
        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY)
        instances_mask[instances_mask > 0] = 1
        instances_mask = instances_mask.astype(np.int32)

        min_width = min(image.shape[1], instances_mask.shape[1])
        min_height = min(image.shape[0], instances_mask.shape[0])

        if image.shape[0] != min_height or image.shape[1] != min_width:
            image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR)
        if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width:
            instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST)

        object_ids = [1] if instances_mask.sum() > 0 else []

        return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)
isegm/data/datasets/pascalvoc.py
ADDED
@@ -0,0 +1,48 @@
import pickle as pkl
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class PascalVocDataset(ISDataset):
    def __init__(self, dataset_path, split='train', **kwargs):
        super().__init__(**kwargs)
        assert split in {'train', 'val', 'trainval', 'test'}

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / "JPEGImages"
        self._insts_path = self.dataset_path / "SegmentationObject"
        self.dataset_split = split

        if split == 'test':
            with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f:
                self.dataset_samples, self.instance_ids = pkl.load(f)
        else:
            with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f:
                self.dataset_samples = [name.strip() for name in f.readlines()]

    def get_sample(self, index) -> DSample:
        sample_id = self.dataset_samples[index]
        image_path = str(self._images_path / f'{sample_id}.jpg')
        mask_path = str(self._insts_path / f'{sample_id}.png')

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = cv2.imread(mask_path)
        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32)
        if self.dataset_split == 'test':
            instance_id = self.instance_ids[index]
            mask = np.zeros_like(instances_mask)
            mask[instances_mask == 220] = 220  # ignored area
            mask[instances_mask == instance_id] = 1
            objects_ids = [1]
            instances_mask = mask
        else:
            objects_ids = np.unique(instances_mask)
            objects_ids = [x for x in objects_ids if x != 0 and x != 220]

        return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index)
isegm/data/datasets/sbd.py
ADDED
@@ -0,0 +1,111 @@
import pickle as pkl
from pathlib import Path

import cv2
import numpy as np
from scipy.io import loadmat

from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class SBDDataset(ISDataset):
    def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs):
        super(SBDDataset, self).__init__(**kwargs)
        assert split in {'train', 'val'}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / 'img'
        self._insts_path = self.dataset_path / 'inst'
        self._buggy_objects = dict()
        self._buggy_mask_thresh = buggy_mask_thresh

        with open(self.dataset_path / f'{split}.txt', 'r') as f:
            self.dataset_samples = [x.strip() for x in f.readlines()]

    def get_sample(self, index):
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / f'{image_name}.jpg')
        inst_info_path = str(self._insts_path / f'{image_name}.mat')

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
        instances_mask = self.remove_buggy_masks(index, instances_mask)
        instances_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index)

    def remove_buggy_masks(self, index, instances_mask):
        if self._buggy_mask_thresh > 0.0:
            buggy_image_objects = self._buggy_objects.get(index, None)
            if buggy_image_objects is None:
                buggy_image_objects = []
                instances_ids, _ = get_labels_with_sizes(instances_mask)
                for obj_id in instances_ids:
                    obj_mask = instances_mask == obj_id
                    mask_area = obj_mask.sum()
                    bbox = get_bbox_from_mask(obj_mask)
                    bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
                    obj_area_ratio = mask_area / bbox_area
                    if obj_area_ratio < self._buggy_mask_thresh:
                        buggy_image_objects.append(obj_id)

                self._buggy_objects[index] = buggy_image_objects
            for obj_id in buggy_image_objects:
                instances_mask[instances_mask == obj_id] = 0

        return instances_mask


class SBDEvaluationDataset(ISDataset):
    def __init__(self, dataset_path, split='val', **kwargs):
        super(SBDEvaluationDataset, self).__init__(**kwargs)
        assert split in {'train', 'val'}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / 'img'
        self._insts_path = self.dataset_path / 'inst'

        with open(self.dataset_path / f'{split}.txt', 'r') as f:
            self.dataset_samples = [x.strip() for x in f.readlines()]

        self.dataset_samples = self.get_sbd_images_and_ids_list()

    def get_sample(self, index) -> DSample:
        image_name, instance_id = self.dataset_samples[index]
        image_path = str(self._images_path / f'{image_name}.jpg')
        inst_info_path = str(self._insts_path / f'{image_name}.mat')

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
        instances_mask[instances_mask != instance_id] = 0
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

    def get_sbd_images_and_ids_list(self):
        pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl'

        if pkl_path.exists():
            with open(str(pkl_path), 'rb') as fp:
                images_and_ids_list = pkl.load(fp)
        else:
            images_and_ids_list = []

            for sample in self.dataset_samples:
                inst_info_path = str(self._insts_path / f'{sample}.mat')
                instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32)
                instances_ids, _ = get_labels_with_sizes(instances_mask)

                for instances_id in instances_ids:
                    images_and_ids_list.append((sample, instances_id))

            with open(str(pkl_path), 'wb') as fp:
                pkl.dump(images_and_ids_list, fp)

        return images_and_ids_list
isegm/data/points_sampler.py
ADDED
@@ -0,0 +1,305 @@
import cv2
import math
import random
import numpy as np
from functools import lru_cache
from .sample import DSample


class BasePointSampler:
    def __init__(self):
        self._selected_mask = None
        self._selected_masks = None

    def sample_object(self, sample: DSample):
        raise NotImplementedError

    def sample_points(self):
        raise NotImplementedError

    @property
    def selected_mask(self):
        assert self._selected_mask is not None
        return self._selected_mask

    @selected_mask.setter
    def selected_mask(self, mask):
        self._selected_mask = mask[np.newaxis, :].astype(np.float32)


class MultiPointSampler(BasePointSampler):
    def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1,
                 positive_erode_prob=0.9, positive_erode_iters=3,
                 negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5,
                 merge_objects_prob=0.0, max_num_merged_objects=2,
                 use_hierarchy=False, soft_targets=False,
                 first_click_center=False, only_one_first_click=False,
                 sfc_inner_k=1.7, sfc_full_inner_prob=0.0):
        super().__init__()
        self.max_num_points = max_num_points
        self.expand_ratio = expand_ratio
        self.positive_erode_prob = positive_erode_prob
        self.positive_erode_iters = positive_erode_iters
        self.merge_objects_prob = merge_objects_prob
        self.use_hierarchy = use_hierarchy
        self.soft_targets = soft_targets
        self.first_click_center = first_click_center
        self.only_one_first_click = only_one_first_click
        self.sfc_inner_k = sfc_inner_k
        self.sfc_full_inner_prob = sfc_full_inner_prob

        if max_num_merged_objects == -1:
            max_num_merged_objects = max_num_points
        self.max_num_merged_objects = max_num_merged_objects

        self.neg_strategies = ['bg', 'other', 'border']
        self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob]
        assert math.isclose(sum(self.neg_strategies_prob), 1.0)

        self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
        self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma)
        self._neg_masks = None

    def sample_object(self, sample: DSample):
        if len(sample) == 0:
            bg_mask = sample.get_background_mask()
            self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
            self._selected_masks = [[]]
            self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
            self._neg_masks['required'] = []
            return

        gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
        binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0

        self.selected_mask = gt_mask
        self._selected_masks = pos_masks

        neg_mask_bg = np.logical_not(binary_gt_mask)
        neg_mask_border = self._get_border_mask(binary_gt_mask)
        if len(sample) <= len(self._selected_masks):
            neg_mask_other = neg_mask_bg
        else:
            neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()),
                                            np.logical_not(binary_gt_mask))

        self._neg_masks = {
            'bg': neg_mask_bg,
            'other': neg_mask_other,
            'border': neg_mask_border,
            'required': neg_masks
        }

    def _sample_mask(self, sample: DSample):
        root_obj_ids = sample.root_objects

        if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob:
            max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects)
            num_selected_objects = np.random.randint(2, max_selected_objects + 1)
            random_ids = random.sample(root_obj_ids, num_selected_objects)
        else:
            random_ids = [random.choice(root_obj_ids)]

        gt_mask = None
        pos_segments = []
        neg_segments = []
        for obj_id in random_ids:
            obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample)
            if gt_mask is None:
                gt_mask = obj_gt_mask
            else:
                gt_mask = np.maximum(gt_mask, obj_gt_mask)

            pos_segments.extend(obj_pos_segments)
            neg_segments.extend(obj_neg_segments)

        pos_masks = [self._positive_erode(x) for x in pos_segments]
        neg_masks = [self._positive_erode(x) for x in neg_segments]

        return gt_mask, pos_masks, neg_masks

    def _sample_from_masks_layer(self, obj_id, sample: DSample):
        objs_tree = sample._objects

        if not self.use_hierarchy:
            node_mask = sample.get_object_mask(obj_id)
            gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
            return gt_mask, [node_mask], []

        def _select_node(node_id):
            node_info = objs_tree[node_id]
            if not node_info['children'] or random.random() < 0.5:
                return node_id
            return _select_node(random.choice(node_info['children']))

        selected_node = _select_node(obj_id)
        node_info = objs_tree[selected_node]
        node_mask = sample.get_object_mask(selected_node)
        gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask
        pos_mask = node_mask.copy()

        negative_segments = []
        if node_info['parent'] is not None and node_info['parent'] in objs_tree:
            parent_mask = sample.get_object_mask(node_info['parent'])
            negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask)))

        for child_id in node_info['children']:
            if objs_tree[child_id]['area'] / node_info['area'] < 0.10:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))

        if node_info['children']:
            max_disabled_children = min(len(node_info['children']), 3)
            num_disabled_children = np.random.randint(0, max_disabled_children + 1)
            disabled_children = random.sample(node_info['children'], num_disabled_children)

            for child_id in disabled_children:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
                if self.soft_targets:
                    soft_child_mask = sample.get_soft_object_mask(child_id)
                    gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask)
                else:
                    gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask))
                negative_segments.append(child_mask)

        return gt_mask, [pos_mask], negative_segments

    def sample_points(self):
        assert self._selected_mask is not None
        pos_points = self._multi_mask_sample_points(self._selected_masks,
                                                    is_negative=[False] * len(self._selected_masks),
                                                    with_first_click=self.first_click_center)

        neg_strategy = [(self._neg_masks[k], prob)
                        for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)]
        neg_masks = self._neg_masks['required'] + [neg_strategy]
        neg_points = self._multi_mask_sample_points(neg_masks,
                                                    is_negative=[False] * len(self._neg_masks['required']) + [True])

        return pos_points + neg_points

    def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False):
        selected_masks = selected_masks[:self.max_num_points]

        each_obj_points = [
            self._sample_points(mask, is_negative=is_negative[i],
                                with_first_click=with_first_click)
            for i, mask in enumerate(selected_masks)
        ]
        each_obj_points = [x for x in each_obj_points if len(x) > 0]

        points = []
        if len(each_obj_points) == 1:
            points = each_obj_points[0]
        elif len(each_obj_points) > 1:
            if self.only_one_first_click:
                each_obj_points = each_obj_points[:1]

            points = [obj_points[0] for obj_points in each_obj_points]

            aggregated_masks_with_prob = []
            for indx, x in enumerate(selected_masks):
                if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)):
                    for t, prob in x:
                        aggregated_masks_with_prob.append((t, prob / len(selected_masks)))
                else:
                    aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))

            other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True)
            if len(other_points_union) + len(points) <= self.max_num_points:
                points.extend(other_points_union)
            else:
                points.extend(random.sample(other_points_union, self.max_num_points - len(points)))

        if len(points) < self.max_num_points:
            points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))

        return points

    def _sample_points(self, mask, is_negative=False, with_first_click=False):
        if is_negative:
            num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs)
        else:
            num_points = 1 + np.random.choice(np.arange(self.max_num_points), p=self._pos_probs)

        indices_probs = None
        if isinstance(mask, (list, tuple)):
            indices_probs = [x[1] for x in mask]
            indices = [(np.argwhere(x), prob) for x, prob in mask]
            if indices_probs:
                assert math.isclose(sum(indices_probs), 1.0)
        else:
            indices = np.argwhere(mask)

        points = []
        for j in range(num_points):
            first_click = with_first_click and j == 0 and indices_probs is None

            if first_click:
                point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob)
            elif indices_probs:
                point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs)
                point_indices = indices[point_indices_indx][0]
            else:
                point_indices = indices

            num_indices = len(point_indices)
            if num_indices > 0:
                point_indx = 0 if first_click else 100
                click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx]
                points.append(click)

        return points

    def _positive_erode(self, mask):
        if random.random() > self.positive_erode_prob:
            return mask

        kernel = np.ones((3, 3), np.uint8)
        eroded_mask = cv2.erode(mask.astype(np.uint8),
                                kernel, iterations=self.positive_erode_iters).astype(np.bool)

        if eroded_mask.sum() > 10:
            return eroded_mask
        else:
            return mask

    def _get_border_mask(self, mask):
        expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum())))
        kernel = np.ones((3, 3), np.uint8)
        expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r)
        expanded_mask[mask.astype(np.bool)] = 0
        return expanded_mask


@lru_cache(maxsize=None)
def generate_probs(max_num_points, gamma):
    probs = []
    last_value = 1
    for i in range(max_num_points):
        probs.append(last_value)
        last_value *= gamma

    probs = np.array(probs)
    probs /= probs.sum()

    return probs


def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
    if full_prob > 0 and random.random() < full_prob:
        return obj_mask

    padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant')

    dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
    if k > 0:
        inner_mask = dt > dt.max() / k
        return np.argwhere(inner_mask)
    else:
        prob_map = dt.flatten()
        prob_map /= max(prob_map.sum(), 1e-6)
        click_indx = np.random.choice(len(prob_map), p=prob_map)
        click_coords = np.unravel_index(click_indx, dt.shape)
        return np.array([click_coords])
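
generate_probs gives a geometrically decaying distribution over how many clicks are sampled, so small click counts dominate. For example, with max_num_points=3 and gamma=0.7 the unnormalized weights are 1, 0.7, and 0.49, which normalize to roughly 0.457, 0.320, and 0.224; a quick check:

import numpy as np

# Same construction as generate_probs(3, gamma=0.7): weights 1, 0.7, 0.49.
w = np.array([1.0, 0.7, 0.49])
print(w / w.sum())  # -> approximately [0.4566 0.3196 0.2237]
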
isegm/data/sample.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
import numpy as np
|
2 |
+
from copy import deepcopy
|
3 |
+
from isegm.utils.misc import get_labels_with_sizes
|
4 |
+
from isegm.data.transforms import remove_image_only_transforms
|
5 |
+
from albumentations import ReplayCompose
|
6 |
+
|
7 |
+
|
8 |
+
class DSample:
|
9 |
+
def __init__(self, image, encoded_masks, objects=None,
|
10 |
+
objects_ids=None, ignore_ids=None, sample_id=None):
|
11 |
+
self.image = image
|
12 |
+
self.sample_id = sample_id
|
13 |
+
|
14 |
+
if len(encoded_masks.shape) == 2:
|
15 |
+
encoded_masks = encoded_masks[:, :, np.newaxis]
|
16 |
+
self._encoded_masks = encoded_masks
|
17 |
+
self._ignored_regions = []
|
18 |
+
|
19 |
+
if objects_ids is not None:
|
20 |
+
if not objects_ids or not isinstance(objects_ids[0], tuple):
|
21 |
+
assert encoded_masks.shape[2] == 1
|
22 |
+
objects_ids = [(0, obj_id) for obj_id in objects_ids]
|
23 |
+
|
24 |
+
self._objects = dict()
|
25 |
+
for indx, obj_mapping in enumerate(objects_ids):
|
26 |
+
self._objects[indx] = {
|
27 |
+
'parent': None,
|
28 |
+
'mapping': obj_mapping,
|
29 |
+
'children': []
|
30 |
+
}
|
31 |
+
|
32 |
+
if ignore_ids:
|
33 |
+
if isinstance(ignore_ids[0], tuple):
|
34 |
+
self._ignored_regions = ignore_ids
|
35 |
+
else:
|
36 |
+
self._ignored_regions = [(0, region_id) for region_id in ignore_ids]
|
37 |
+
else:
|
38 |
+
self._objects = deepcopy(objects)
|
39 |
+
|
40 |
+
self._augmented = False
|
41 |
+
self._soft_mask_aug = None
|
42 |
+
self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)
|
43 |
+
|
44 |
+
def augment(self, augmentator):
|
45 |
+
self.reset_augmentation()
|
46 |
+
aug_output = augmentator(image=self.image, mask=self._encoded_masks)
|
47 |
+
self.image = aug_output['image']
|
48 |
+
self._encoded_masks = aug_output['mask']
|
49 |
+
|
50 |
+
aug_replay = aug_output.get('replay', None)
|
51 |
+
if aug_replay:
|
52 |
+
assert len(self._ignored_regions) == 0
|
53 |
+
mask_replay = remove_image_only_transforms(aug_replay)
|
54 |
+
self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)
|
55 |
+
|
56 |
+
self._compute_objects_areas()
|
57 |
+
self.remove_small_objects(min_area=1)
|
58 |
+
|
59 |
+
self._augmented = True
|
60 |
+
|
61 |
+
def reset_augmentation(self):
|
62 |
+
if not self._augmented:
|
63 |
+
return
|
64 |
+
orig_image, orig_masks, orig_objects = self._original_data
|
65 |
+
self.image = orig_image
|
66 |
+
self._encoded_masks = orig_masks
|
67 |
+
self._objects = deepcopy(orig_objects)
|
68 |
+
self._augmented = False
|
69 |
+
self._soft_mask_aug = None
|
70 |
+
|
71 |
+
def remove_small_objects(self, min_area):
|
72 |
+
if self._objects and not 'area' in list(self._objects.values())[0]:
|
73 |
+
self._compute_objects_areas()
|
74 |
+
|
75 |
+
for obj_id, obj_info in list(self._objects.items()):
|
76 |
+
if obj_info['area'] < min_area:
|
77 |
+
self._remove_object(obj_id)
|
78 |
+
|
79 |
+
def get_object_mask(self, obj_id):
|
80 |
+
layer_indx, mask_id = self._objects[obj_id]['mapping']
|
81 |
+
obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
|
82 |
+
if self._ignored_regions:
|
83 |
+
for layer_indx, mask_id in self._ignored_regions:
|
84 |
+
ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id
|
85 |
+
obj_mask[ignore_mask] = -1
|
86 |
+
|
87 |
+
return obj_mask
|
88 |
+
|
89 |
+
def get_soft_object_mask(self, obj_id):
|
90 |
+
assert self._soft_mask_aug is not None
|
91 |
+
original_encoded_masks = self._original_data[1]
|
92 |
+
layer_indx, mask_id = self._objects[obj_id]['mapping']
|
93 |
+
obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32)
|
94 |
+
obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image']
|
95 |
+
return np.clip(obj_mask, 0, 1)
|
96 |
+
|
97 |
+
def get_background_mask(self):
|
98 |
+
return np.max(self._encoded_masks, axis=2) == 0
|
99 |
+
|
100 |
+
@property
|
101 |
+
def objects_ids(self):
|
102 |
+
return list(self._objects.keys())
|
103 |
+
|
104 |
+
@property
|
105 |
+
def gt_mask(self):
|
106 |
+
assert len(self._objects) == 1
|
107 |
+
return self.get_object_mask(self.objects_ids[0])
|
108 |
+
|
109 |
+
@property
|
110 |
+
def root_objects(self):
|
111 |
+
return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None]
|
112 |
+
|
113 |
+
def _compute_objects_areas(self):
|
114 |
+
inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()}
|
115 |
+
ignored_regions_keys = set(self._ignored_regions)
|
116 |
+
|
117 |
+
for layer_indx in range(self._encoded_masks.shape[2]):
|
118 |
+
objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx])
|
119 |
+
for obj_id, obj_area in zip(objects_ids, objects_areas):
|
120 |
+
inv_key = (layer_indx, obj_id)
|
121 |
+
if inv_key in ignored_regions_keys:
|
122 |
+
continue
|
123 |
+
try:
|
124 |
+
self._objects[inverse_index[inv_key]]['area'] = obj_area
|
125 |
+
del inverse_index[inv_key]
|
126 |
+
except KeyError:
|
127 |
+
layer = self._encoded_masks[:, :, layer_indx]
|
128 |
+
layer[layer == obj_id] = 0
|
129 |
+
self._encoded_masks[:, :, layer_indx] = layer
|
130 |
+
|
131 |
+
for obj_id in inverse_index.values():
|
132 |
+
self._objects[obj_id]['area'] = 0
|
133 |
+
|
134 |
+
def _remove_object(self, obj_id):
|
135 |
+
obj_info = self._objects[obj_id]
|
136 |
+
obj_parent = obj_info['parent']
|
137 |
+
for child_id in obj_info['children']:
|
138 |
+
self._objects[child_id]['parent'] = obj_parent
|
139 |
+
|
140 |
+
if obj_parent is not None:
|
141 |
+
parent_children = self._objects[obj_parent]['children']
|
142 |
+
parent_children = [x for x in parent_children if x != obj_id]
|
143 |
+
self._objects[obj_parent]['children'] = parent_children + obj_info['children']
|
144 |
+
|
145 |
+
del self._objects[obj_id]
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return len(self._objects)
|
isegm/data/transforms.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
import cv2
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from albumentations.core.serialization import SERIALIZABLE_REGISTRY
|
6 |
+
from albumentations import ImageOnlyTransform, DualTransform
|
7 |
+
from albumentations.core.transforms_interface import to_tuple
|
8 |
+
from albumentations.augmentations import functional as F
|
9 |
+
from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes
|
10 |
+
|
11 |
+
|
12 |
+
class UniformRandomResize(DualTransform):
|
13 |
+
def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
|
14 |
+
super().__init__(always_apply, p)
|
15 |
+
self.scale_range = scale_range
|
16 |
+
self.interpolation = interpolation
|
17 |
+
|
18 |
+
def get_params_dependent_on_targets(self, params):
|
19 |
+
scale = random.uniform(*self.scale_range)
|
20 |
+
height = int(round(params['image'].shape[0] * scale))
|
21 |
+
width = int(round(params['image'].shape[1] * scale))
|
22 |
+
return {'new_height': height, 'new_width': width}
|
23 |
+
|
24 |
+
def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params):
|
25 |
+
return F.resize(img, height=new_height, width=new_width, interpolation=interpolation)
|
26 |
+
|
27 |
+
def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
|
28 |
+
scale_x = new_width / params["cols"]
|
29 |
+
scale_y = new_height / params["rows"]
|
30 |
+
return F.keypoint_scale(keypoint, scale_x, scale_y)
|
31 |
+
|
32 |
+
def get_transform_init_args_names(self):
|
33 |
+
return "scale_range", "interpolation"
|
34 |
+
|
35 |
+
@property
|
36 |
+
def targets_as_params(self):
|
37 |
+
return ["image"]
|
38 |
+
|
39 |
+
|
40 |
+
class ZoomIn(DualTransform):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
height,
|
44 |
+
width,
|
45 |
+
bbox_jitter=0.1,
|
46 |
+
expansion_ratio=1.4,
|
47 |
+
min_crop_size=200,
|
48 |
+
min_area=100,
|
49 |
+
always_resize=False,
|
50 |
+
always_apply=False,
|
51 |
+
p=0.5,
|
52 |
+
):
|
53 |
+
super(ZoomIn, self).__init__(always_apply, p)
|
54 |
+
self.height = height
|
55 |
+
self.width = width
|
56 |
+
self.bbox_jitter = to_tuple(bbox_jitter)
|
57 |
+
self.expansion_ratio = expansion_ratio
|
58 |
+
self.min_crop_size = min_crop_size
|
59 |
+
self.min_area = min_area
|
60 |
+
self.always_resize = always_resize
|
61 |
+
|
62 |
+
def apply(self, img, selected_object, bbox, **params):
|
63 |
+
if selected_object is None:
|
64 |
+
if self.always_resize:
|
65 |
+
img = F.resize(img, height=self.height, width=self.width)
|
66 |
+
return img
|
67 |
+
|
68 |
+
rmin, rmax, cmin, cmax = bbox
|
69 |
+
img = img[rmin:rmax + 1, cmin:cmax + 1]
|
70 |
+
img = F.resize(img, height=self.height, width=self.width)
|
71 |
+
|
72 |
+
return img
|
73 |
+
|
74 |
+
def apply_to_mask(self, mask, selected_object, bbox, **params):
|
75 |
+
if selected_object is None:
|
76 |
+
if self.always_resize:
|
77 |
+
mask = F.resize(mask, height=self.height, width=self.width,
|
78 |
+
interpolation=cv2.INTER_NEAREST)
|
79 |
+
return mask
|
80 |
+
|
81 |
+
rmin, rmax, cmin, cmax = bbox
|
82 |
+
mask = mask[rmin:rmax + 1, cmin:cmax + 1]
|
83 |
+
if isinstance(selected_object, tuple):
|
84 |
+
layer_indx, mask_id = selected_object
|
85 |
+
obj_mask = mask[:, :, layer_indx] == mask_id
|
86 |
+
new_mask = np.zeros_like(mask)
|
87 |
+
new_mask[:, :, layer_indx][obj_mask] = mask_id
|
88 |
+
else:
|
89 |
+
obj_mask = mask == selected_object
|
90 |
+
new_mask = mask.copy()
|
91 |
+
new_mask[np.logical_not(obj_mask)] = 0
|
92 |
+
|
93 |
+
new_mask = F.resize(new_mask, height=self.height, width=self.width,
|
94 |
+
interpolation=cv2.INTER_NEAREST)
|
95 |
+
return new_mask
|
96 |
+
|
97 |
+
def get_params_dependent_on_targets(self, params):
|
98 |
+
instances = params['mask']
|
99 |
+
|
100 |
+
is_mask_layer = len(instances.shape) > 2
|
101 |
+
candidates = []
|
102 |
+
if is_mask_layer:
|
103 |
+
for layer_indx in range(instances.shape[2]):
|
104 |
+
labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
|
105 |
+
candidates.extend([(layer_indx, obj_id)
|
106 |
+
for obj_id, area in zip(labels, areas)
|
107 |
+
if area > self.min_area])
|
108 |
+
else:
|
109 |
+
labels, areas = get_labels_with_sizes(instances)
|
110 |
+
candidates = [obj_id for obj_id, area in zip(labels, areas)
|
111 |
+
if area > self.min_area]
|
112 |
+
|
113 |
+
selected_object = None
|
114 |
+
bbox = None
|
115 |
+
if candidates:
|
116 |
+
selected_object = random.choice(candidates)
|
117 |
+
if is_mask_layer:
|
118 |
+
layer_indx, mask_id = selected_object
|
119 |
+
obj_mask = instances[:, :, layer_indx] == mask_id
|
120 |
+
else:
|
121 |
+
obj_mask = instances == selected_object
|
122 |
+
|
123 |
+
bbox = get_bbox_from_mask(obj_mask)
|
124 |
+
|
125 |
+
if isinstance(self.expansion_ratio, tuple):
|
126 |
+
expansion_ratio = random.uniform(*self.expansion_ratio)
|
127 |
+
else:
|
128 |
+
expansion_ratio = self.expansion_ratio
|
129 |
+
|
130 |
+
bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size)
|
131 |
+
bbox = self._jitter_bbox(bbox)
|
132 |
+
bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
|
133 |
+
|
134 |
+
return {
|
135 |
+
'selected_object': selected_object,
|
136 |
+
'bbox': bbox
|
137 |
+
}
|
138 |
+
|
139 |
+
def _jitter_bbox(self, bbox):
|
140 |
+
rmin, rmax, cmin, cmax = bbox
|
141 |
+
height = rmax - rmin + 1
|
142 |
+
width = cmax - cmin + 1
|
143 |
+
rmin = int(rmin + random.uniform(*self.bbox_jitter) * height)
|
144 |
+
rmax = int(rmax + random.uniform(*self.bbox_jitter) * height)
|
145 |
+
cmin = int(cmin + random.uniform(*self.bbox_jitter) * width)
|
146 |
+
cmax = int(cmax + random.uniform(*self.bbox_jitter) * width)
|
147 |
+
|
148 |
+
return rmin, rmax, cmin, cmax
|
149 |
+
|
150 |
+
def apply_to_bbox(self, bbox, **params):
|
151 |
+
raise NotImplementedError
|
152 |
+
|
153 |
+
def apply_to_keypoint(self, keypoint, **params):
|
154 |
+
raise NotImplementedError
|
155 |
+
|
156 |
+
@property
|
157 |
+
def targets_as_params(self):
|
158 |
+
return ["mask"]
|
159 |
+
|
160 |
+
def get_transform_init_args_names(self):
|
161 |
+
return ("height", "width", "bbox_jitter",
|
162 |
+
"expansion_ratio", "min_crop_size", "min_area", "always_resize")
|
163 |
+
|
164 |
+
|
165 |
+
def remove_image_only_transforms(sdict):
|
166 |
+
if not 'transforms' in sdict:
|
167 |
+
return sdict
|
168 |
+
|
169 |
+
keep_transforms = []
|
170 |
+
for tdict in sdict['transforms']:
|
171 |
+
cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']]
|
172 |
+
if 'transforms' in tdict:
|
173 |
+
keep_transforms.append(remove_image_only_transforms(tdict))
|
174 |
+
elif not issubclass(cls, ImageOnlyTransform):
|
175 |
+
keep_transforms.append(tdict)
|
176 |
+
sdict['transforms'] = keep_transforms
|
177 |
+
|
178 |
+
return sdict
|
isegm/engine/optimizer.py
ADDED
@@ -0,0 +1,27 @@
+import torch
+import math
+from isegm.utils.log import logger
+
+
+def get_optimizer(model, opt_name, opt_kwargs):
+    params = []
+    base_lr = opt_kwargs['lr']
+    for name, param in model.named_parameters():
+        param_group = {'params': [param]}
+        if not param.requires_grad:
+            params.append(param_group)
+            continue
+
+        if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
+            logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
+            param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult
+
+        params.append(param_group)
+
+    optimizer = {
+        'sgd': torch.optim.SGD,
+        'adam': torch.optim.Adam,
+        'adamw': torch.optim.AdamW
+    }[opt_name.lower()](params, **opt_kwargs)
+
+    return optimizer
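A short usage sketch for get_optimizer may help. It is only an illustration: the toy model, the 'adam' choice, and the learning-rate values below are assumptions, not configuration taken from this repository. The per-parameter lr_mult attribute is what triggers the scaled-learning-rate branch above.

import torch
from isegm.engine.optimizer import get_optimizer

# toy two-layer module standing in for a real interactive-segmentation network
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

# optional attribute read by get_optimizer; these parameters train 10x slower
for p in model[0].parameters():
    p.lr_mult = 0.1

optim = get_optimizer(model, 'adam', {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8})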
isegm/engine/trainer.py
ADDED
@@ -0,0 +1,413 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import logging
|
4 |
+
from copy import deepcopy
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg
|
14 |
+
from isegm.utils.vis import draw_probmap, draw_points
|
15 |
+
from isegm.utils.misc import save_checkpoint
|
16 |
+
from isegm.utils.serialization import get_config_repr
|
17 |
+
from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict
|
18 |
+
from .optimizer import get_optimizer
|
19 |
+
|
20 |
+
|
21 |
+
class ISTrainer(object):
|
22 |
+
def __init__(self, model, cfg, model_cfg, loss_cfg,
|
23 |
+
trainset, valset,
|
24 |
+
optimizer='adam',
|
25 |
+
optimizer_params=None,
|
26 |
+
image_dump_interval=200,
|
27 |
+
checkpoint_interval=10,
|
28 |
+
tb_dump_period=25,
|
29 |
+
max_interactive_points=0,
|
30 |
+
lr_scheduler=None,
|
31 |
+
metrics=None,
|
32 |
+
additional_val_metrics=None,
|
33 |
+
net_inputs=('images', 'points'),
|
34 |
+
max_num_next_clicks=0,
|
35 |
+
click_models=None,
|
36 |
+
prev_mask_drop_prob=0.0,
|
37 |
+
):
|
38 |
+
self.cfg = cfg
|
39 |
+
self.model_cfg = model_cfg
|
40 |
+
self.max_interactive_points = max_interactive_points
|
41 |
+
self.loss_cfg = loss_cfg
|
42 |
+
self.val_loss_cfg = deepcopy(loss_cfg)
|
43 |
+
self.tb_dump_period = tb_dump_period
|
44 |
+
self.net_inputs = net_inputs
|
45 |
+
self.max_num_next_clicks = max_num_next_clicks
|
46 |
+
|
47 |
+
self.click_models = click_models
|
48 |
+
self.prev_mask_drop_prob = prev_mask_drop_prob
|
49 |
+
|
50 |
+
if cfg.distributed:
|
51 |
+
cfg.batch_size //= cfg.ngpus
|
52 |
+
cfg.val_batch_size //= cfg.ngpus
|
53 |
+
|
54 |
+
if metrics is None:
|
55 |
+
metrics = []
|
56 |
+
self.train_metrics = metrics
|
57 |
+
self.val_metrics = deepcopy(metrics)
|
58 |
+
if additional_val_metrics is not None:
|
59 |
+
self.val_metrics.extend(additional_val_metrics)
|
60 |
+
|
61 |
+
self.checkpoint_interval = checkpoint_interval
|
62 |
+
self.image_dump_interval = image_dump_interval
|
63 |
+
self.task_prefix = ''
|
64 |
+
self.sw = None
|
65 |
+
|
66 |
+
self.trainset = trainset
|
67 |
+
self.valset = valset
|
68 |
+
|
69 |
+
logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.')
|
70 |
+
logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.')
|
71 |
+
|
72 |
+
self.train_data = DataLoader(
|
73 |
+
trainset, cfg.batch_size,
|
74 |
+
sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
|
75 |
+
drop_last=True, pin_memory=True,
|
76 |
+
num_workers=cfg.workers
|
77 |
+
)
|
78 |
+
|
79 |
+
self.val_data = DataLoader(
|
80 |
+
valset, cfg.val_batch_size,
|
81 |
+
sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
|
82 |
+
drop_last=True, pin_memory=True,
|
83 |
+
num_workers=cfg.workers
|
84 |
+
)
|
85 |
+
|
86 |
+
self.optim = get_optimizer(model, optimizer, optimizer_params)
|
87 |
+
model = self._load_weights(model)
|
88 |
+
|
89 |
+
if cfg.multi_gpu:
|
90 |
+
model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids,
|
91 |
+
output_device=cfg.gpu_ids[0])
|
92 |
+
|
93 |
+
if self.is_master:
|
94 |
+
logger.info(model)
|
95 |
+
logger.info(get_config_repr(model._config))
|
96 |
+
|
97 |
+
self.device = cfg.device
|
98 |
+
self.net = model.to(self.device)
|
99 |
+
self.lr = optimizer_params['lr']
|
100 |
+
|
101 |
+
if lr_scheduler is not None:
|
102 |
+
self.lr_scheduler = lr_scheduler(optimizer=self.optim)
|
103 |
+
if cfg.start_epoch > 0:
|
104 |
+
for _ in range(cfg.start_epoch):
|
105 |
+
self.lr_scheduler.step()
|
106 |
+
|
107 |
+
self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)
|
108 |
+
|
109 |
+
if self.click_models is not None:
|
110 |
+
for click_model in self.click_models:
|
111 |
+
for param in click_model.parameters():
|
112 |
+
param.requires_grad = False
|
113 |
+
click_model.to(self.device)
|
114 |
+
click_model.eval()
|
115 |
+
|
116 |
+
def run(self, num_epochs, start_epoch=None, validation=True):
|
117 |
+
if start_epoch is None:
|
118 |
+
start_epoch = self.cfg.start_epoch
|
119 |
+
|
120 |
+
logger.info(f'Starting Epoch: {start_epoch}')
|
121 |
+
logger.info(f'Total Epochs: {num_epochs}')
|
122 |
+
for epoch in range(start_epoch, num_epochs):
|
123 |
+
self.training(epoch)
|
124 |
+
if validation:
|
125 |
+
self.validation(epoch)
|
126 |
+
|
127 |
+
def training(self, epoch):
|
128 |
+
if self.sw is None and self.is_master:
|
129 |
+
self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
|
130 |
+
flush_secs=10, dump_period=self.tb_dump_period)
|
131 |
+
|
132 |
+
if self.cfg.distributed:
|
133 |
+
self.train_data.sampler.set_epoch(epoch)
|
134 |
+
|
135 |
+
log_prefix = 'Train' + self.task_prefix.capitalize()
|
136 |
+
tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\
|
137 |
+
if self.is_master else self.train_data
|
138 |
+
|
139 |
+
for metric in self.train_metrics:
|
140 |
+
metric.reset_epoch_stats()
|
141 |
+
|
142 |
+
self.net.train()
|
143 |
+
train_loss = 0.0
|
144 |
+
for i, batch_data in enumerate(tbar):
|
145 |
+
global_step = epoch * len(self.train_data) + i
|
146 |
+
|
147 |
+
loss, losses_logging, splitted_batch_data, outputs = \
|
148 |
+
self.batch_forward(batch_data)
|
149 |
+
|
150 |
+
self.optim.zero_grad()
|
151 |
+
loss.backward()
|
152 |
+
self.optim.step()
|
153 |
+
|
154 |
+
losses_logging['overall'] = loss
|
155 |
+
reduce_loss_dict(losses_logging)
|
156 |
+
|
157 |
+
train_loss += losses_logging['overall'].item()
|
158 |
+
|
159 |
+
if self.is_master:
|
160 |
+
for loss_name, loss_value in losses_logging.items():
|
161 |
+
self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}',
|
162 |
+
value=loss_value.item(),
|
163 |
+
global_step=global_step)
|
164 |
+
|
165 |
+
for k, v in self.loss_cfg.items():
|
166 |
+
if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0:
|
167 |
+
v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step)
|
168 |
+
|
169 |
+
if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0:
|
170 |
+
self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train')
|
171 |
+
|
172 |
+
self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate',
|
173 |
+
value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1],
|
174 |
+
global_step=global_step)
|
175 |
+
|
176 |
+
tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}')
|
177 |
+
for metric in self.train_metrics:
|
178 |
+
metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
|
179 |
+
|
180 |
+
if self.is_master:
|
181 |
+
for metric in self.train_metrics:
|
182 |
+
self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}',
|
183 |
+
value=metric.get_epoch_value(),
|
184 |
+
global_step=epoch, disable_avg=True)
|
185 |
+
|
186 |
+
save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
|
187 |
+
epoch=None, multi_gpu=self.cfg.multi_gpu)
|
188 |
+
|
189 |
+
if isinstance(self.checkpoint_interval, (list, tuple)):
|
190 |
+
checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1]
|
191 |
+
else:
|
192 |
+
checkpoint_interval = self.checkpoint_interval
|
193 |
+
|
194 |
+
if epoch % checkpoint_interval == 0:
|
195 |
+
save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix,
|
196 |
+
epoch=epoch, multi_gpu=self.cfg.multi_gpu)
|
197 |
+
|
198 |
+
if hasattr(self, 'lr_scheduler'):
|
199 |
+
self.lr_scheduler.step()
|
200 |
+
|
201 |
+
def validation(self, epoch):
|
202 |
+
if self.sw is None and self.is_master:
|
203 |
+
self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH),
|
204 |
+
flush_secs=10, dump_period=self.tb_dump_period)
|
205 |
+
|
206 |
+
log_prefix = 'Val' + self.task_prefix.capitalize()
|
207 |
+
tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data
|
208 |
+
|
209 |
+
for metric in self.val_metrics:
|
210 |
+
metric.reset_epoch_stats()
|
211 |
+
|
212 |
+
val_loss = 0
|
213 |
+
losses_logging = defaultdict(list)
|
214 |
+
|
215 |
+
self.net.eval()
|
216 |
+
for i, batch_data in enumerate(tbar):
|
217 |
+
global_step = epoch * len(self.val_data) + i
|
218 |
+
loss, batch_losses_logging, splitted_batch_data, outputs = \
|
219 |
+
self.batch_forward(batch_data, validation=True)
|
220 |
+
|
221 |
+
batch_losses_logging['overall'] = loss
|
222 |
+
reduce_loss_dict(batch_losses_logging)
|
223 |
+
for loss_name, loss_value in batch_losses_logging.items():
|
224 |
+
losses_logging[loss_name].append(loss_value.item())
|
225 |
+
|
226 |
+
val_loss += batch_losses_logging['overall'].item()
|
227 |
+
|
228 |
+
if self.is_master:
|
229 |
+
tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}')
|
230 |
+
for metric in self.val_metrics:
|
231 |
+
metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step)
|
232 |
+
|
233 |
+
if self.is_master:
|
234 |
+
for loss_name, loss_values in losses_logging.items():
|
235 |
+
self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(),
|
236 |
+
global_step=epoch, disable_avg=True)
|
237 |
+
|
238 |
+
for metric in self.val_metrics:
|
239 |
+
self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(),
|
240 |
+
global_step=epoch, disable_avg=True)
|
241 |
+
|
242 |
+
def batch_forward(self, batch_data, validation=False):
|
243 |
+
metrics = self.val_metrics if validation else self.train_metrics
|
244 |
+
losses_logging = dict()
|
245 |
+
|
246 |
+
with torch.set_grad_enabled(not validation):
|
247 |
+
batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
|
248 |
+
image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points']
|
249 |
+
orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone()
|
250 |
+
|
251 |
+
prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]
|
252 |
+
|
253 |
+
last_click_indx = None
|
254 |
+
|
255 |
+
with torch.no_grad():
|
256 |
+
num_iters = random.randint(0, self.max_num_next_clicks)
|
257 |
+
|
258 |
+
for click_indx in range(num_iters):
|
259 |
+
last_click_indx = click_indx
|
260 |
+
|
261 |
+
if not validation:
|
262 |
+
self.net.eval()
|
263 |
+
|
264 |
+
if self.click_models is None or click_indx >= len(self.click_models):
|
265 |
+
eval_model = self.net
|
266 |
+
else:
|
267 |
+
eval_model = self.click_models[click_indx]
|
268 |
+
|
269 |
+
net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
|
270 |
+
prev_output = torch.sigmoid(eval_model(net_input, points)['instances'])
|
271 |
+
|
272 |
+
points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1)
|
273 |
+
|
274 |
+
if not validation:
|
275 |
+
self.net.train()
|
276 |
+
|
277 |
+
if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None:
|
278 |
+
zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob
|
279 |
+
prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])
|
280 |
+
|
281 |
+
batch_data['points'] = points
|
282 |
+
|
283 |
+
net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image
|
284 |
+
output = self.net(net_input, points)
|
285 |
+
|
286 |
+
loss = 0.0
|
287 |
+
loss = self.add_loss('instance_loss', loss, losses_logging, validation,
|
288 |
+
lambda: (output['instances'], batch_data['instances']))
|
289 |
+
loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation,
|
290 |
+
lambda: (output['instances_aux'], batch_data['instances']))
|
291 |
+
|
292 |
+
if self.is_master:
|
293 |
+
with torch.no_grad():
|
294 |
+
for m in metrics:
|
295 |
+
m.update(*(output.get(x) for x in m.pred_outputs),
|
296 |
+
*(batch_data[x] for x in m.gt_outputs))
|
297 |
+
return loss, losses_logging, batch_data, output
|
298 |
+
|
299 |
+
def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs):
|
300 |
+
loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
|
301 |
+
loss_weight = loss_cfg.get(loss_name + '_weight', 0.0)
|
302 |
+
if loss_weight > 0.0:
|
303 |
+
loss_criterion = loss_cfg.get(loss_name)
|
304 |
+
loss = loss_criterion(*lambda_loss_inputs())
|
305 |
+
loss = torch.mean(loss)
|
306 |
+
losses_logging[loss_name] = loss
|
307 |
+
loss = loss_weight * loss
|
308 |
+
total_loss = total_loss + loss
|
309 |
+
|
310 |
+
return total_loss
|
311 |
+
|
312 |
+
def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
|
313 |
+
output_images_path = self.cfg.VIS_PATH / prefix
|
314 |
+
if self.task_prefix:
|
315 |
+
output_images_path /= self.task_prefix
|
316 |
+
|
317 |
+
if not output_images_path.exists():
|
318 |
+
output_images_path.mkdir(parents=True)
|
319 |
+
image_name_prefix = f'{global_step:06d}'
|
320 |
+
|
321 |
+
def _save_image(suffix, image):
|
322 |
+
cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'),
|
323 |
+
image, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
324 |
+
|
325 |
+
images = splitted_batch_data['images']
|
326 |
+
points = splitted_batch_data['points']
|
327 |
+
instance_masks = splitted_batch_data['instances']
|
328 |
+
|
329 |
+
gt_instance_masks = instance_masks.cpu().numpy()
|
330 |
+
predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy()
|
331 |
+
points = points.detach().cpu().numpy()
|
332 |
+
|
333 |
+
image_blob, points = images[0], points[0]
|
334 |
+
gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
|
335 |
+
predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)
|
336 |
+
|
337 |
+
image = image_blob.cpu().numpy() * 255
|
338 |
+
image = image.transpose((1, 2, 0))
|
339 |
+
|
340 |
+
image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0))
|
341 |
+
image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255))
|
342 |
+
|
343 |
+
gt_mask[gt_mask < 0] = 0.25
|
344 |
+
gt_mask = draw_probmap(gt_mask)
|
345 |
+
predicted_mask = draw_probmap(predicted_mask)
|
346 |
+
viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8)
|
347 |
+
|
348 |
+
_save_image('instance_segmentation', viz_image[:, :, ::-1])
|
349 |
+
|
350 |
+
def _load_weights(self, net):
|
351 |
+
if self.cfg.weights is not None:
|
352 |
+
if os.path.isfile(self.cfg.weights):
|
353 |
+
load_weights(net, self.cfg.weights)
|
354 |
+
self.cfg.weights = None
|
355 |
+
else:
|
356 |
+
raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
|
357 |
+
elif self.cfg.resume_exp is not None:
|
358 |
+
checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth'))
|
359 |
+
assert len(checkpoints) == 1
|
360 |
+
|
361 |
+
checkpoint_path = checkpoints[0]
|
362 |
+
logger.info(f'Load checkpoint from path: {checkpoint_path}')
|
363 |
+
load_weights(net, str(checkpoint_path))
|
364 |
+
return net
|
365 |
+
|
366 |
+
@property
|
367 |
+
def is_master(self):
|
368 |
+
return self.cfg.local_rank == 0
|
369 |
+
|
370 |
+
|
371 |
+
def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
|
372 |
+
assert click_indx > 0
|
373 |
+
pred = pred.cpu().numpy()[:, 0, :, :]
|
374 |
+
gt = gt.cpu().numpy()[:, 0, :, :] > 0.5
|
375 |
+
|
376 |
+
fn_mask = np.logical_and(gt, pred < pred_thresh)
|
377 |
+
fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
|
378 |
+
|
379 |
+
fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
|
380 |
+
fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8)
|
381 |
+
num_points = points.size(1) // 2
|
382 |
+
points = points.clone()
|
383 |
+
|
384 |
+
for bindx in range(fn_mask.shape[0]):
|
385 |
+
fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
|
386 |
+
fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
|
387 |
+
|
388 |
+
fn_max_dist = np.max(fn_mask_dt)
|
389 |
+
fp_max_dist = np.max(fp_mask_dt)
|
390 |
+
|
391 |
+
is_positive = fn_max_dist > fp_max_dist
|
392 |
+
dt = fn_mask_dt if is_positive else fp_mask_dt
|
393 |
+
inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
|
394 |
+
indices = np.argwhere(inner_mask)
|
395 |
+
if len(indices) > 0:
|
396 |
+
coords = indices[np.random.randint(0, len(indices))]
|
397 |
+
if is_positive:
|
398 |
+
points[bindx, num_points - click_indx, 0] = float(coords[0])
|
399 |
+
points[bindx, num_points - click_indx, 1] = float(coords[1])
|
400 |
+
points[bindx, num_points - click_indx, 2] = float(click_indx)
|
401 |
+
else:
|
402 |
+
points[bindx, 2 * num_points - click_indx, 0] = float(coords[0])
|
403 |
+
points[bindx, 2 * num_points - click_indx, 1] = float(coords[1])
|
404 |
+
points[bindx, 2 * num_points - click_indx, 2] = float(click_indx)
|
405 |
+
|
406 |
+
return points
|
407 |
+
|
408 |
+
|
409 |
+
def load_weights(model, path_to_weights):
|
410 |
+
current_state_dict = model.state_dict()
|
411 |
+
new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict']
|
412 |
+
current_state_dict.update(new_state_dict)
|
413 |
+
model.load_state_dict(current_state_dict)
|
isegm/inference/__init__.py
ADDED
File without changes
|
isegm/inference/clicker.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
import numpy as np
|
2 |
+
from copy import deepcopy
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
class Clicker(object):
|
7 |
+
def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0):
|
8 |
+
self.click_indx_offset = click_indx_offset
|
9 |
+
if gt_mask is not None:
|
10 |
+
self.gt_mask = gt_mask == 1
|
11 |
+
self.not_ignore_mask = gt_mask != ignore_label
|
12 |
+
else:
|
13 |
+
self.gt_mask = None
|
14 |
+
|
15 |
+
self.reset_clicks()
|
16 |
+
|
17 |
+
if init_clicks is not None:
|
18 |
+
for click in init_clicks:
|
19 |
+
self.add_click(click)
|
20 |
+
|
21 |
+
def make_next_click(self, pred_mask):
|
22 |
+
assert self.gt_mask is not None
|
23 |
+
click = self._get_next_click(pred_mask)
|
24 |
+
self.add_click(click)
|
25 |
+
|
26 |
+
def get_clicks(self, clicks_limit=None):
|
27 |
+
return self.clicks_list[:clicks_limit]
|
28 |
+
|
29 |
+
def _get_next_click(self, pred_mask, padding=True):
|
30 |
+
fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask)
|
31 |
+
fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask)
|
32 |
+
|
33 |
+
if padding:
|
34 |
+
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant')
|
35 |
+
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant')
|
36 |
+
|
37 |
+
fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
38 |
+
fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
39 |
+
|
40 |
+
if padding:
|
41 |
+
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
|
42 |
+
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
|
43 |
+
|
44 |
+
fn_mask_dt = fn_mask_dt * self.not_clicked_map
|
45 |
+
fp_mask_dt = fp_mask_dt * self.not_clicked_map
|
46 |
+
|
47 |
+
fn_max_dist = np.max(fn_mask_dt)
|
48 |
+
fp_max_dist = np.max(fp_mask_dt)
|
49 |
+
|
50 |
+
is_positive = fn_max_dist > fp_max_dist
|
51 |
+
if is_positive:
|
52 |
+
coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x]
|
53 |
+
else:
|
54 |
+
coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x]
|
55 |
+
|
56 |
+
return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0]))
|
57 |
+
|
58 |
+
def add_click(self, click):
|
59 |
+
coords = click.coords
|
60 |
+
|
61 |
+
click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks
|
62 |
+
if click.is_positive:
|
63 |
+
self.num_pos_clicks += 1
|
64 |
+
else:
|
65 |
+
self.num_neg_clicks += 1
|
66 |
+
|
67 |
+
self.clicks_list.append(click)
|
68 |
+
if self.gt_mask is not None:
|
69 |
+
self.not_clicked_map[coords[0], coords[1]] = False
|
70 |
+
|
71 |
+
def _remove_last_click(self):
|
72 |
+
click = self.clicks_list.pop()
|
73 |
+
coords = click.coords
|
74 |
+
|
75 |
+
if click.is_positive:
|
76 |
+
self.num_pos_clicks -= 1
|
77 |
+
else:
|
78 |
+
self.num_neg_clicks -= 1
|
79 |
+
|
80 |
+
if self.gt_mask is not None:
|
81 |
+
self.not_clicked_map[coords[0], coords[1]] = True
|
82 |
+
|
83 |
+
def reset_clicks(self):
|
84 |
+
if self.gt_mask is not None:
|
85 |
+
self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)
|
86 |
+
|
87 |
+
self.num_pos_clicks = 0
|
88 |
+
self.num_neg_clicks = 0
|
89 |
+
|
90 |
+
self.clicks_list = []
|
91 |
+
|
92 |
+
def get_state(self):
|
93 |
+
return deepcopy(self.clicks_list)
|
94 |
+
|
95 |
+
def set_state(self, state):
|
96 |
+
self.reset_clicks()
|
97 |
+
for click in state:
|
98 |
+
self.add_click(click)
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.clicks_list)
|
102 |
+
|
103 |
+
|
104 |
+
class Click:
|
105 |
+
def __init__(self, is_positive, coords, indx=None):
|
106 |
+
self.is_positive = is_positive
|
107 |
+
self.coords = coords
|
108 |
+
self.indx = indx
|
109 |
+
|
110 |
+
@property
|
111 |
+
def coords_and_indx(self):
|
112 |
+
return (*self.coords, self.indx)
|
113 |
+
|
114 |
+
def copy(self, **kwargs):
|
115 |
+
self_copy = deepcopy(self)
|
116 |
+
for k, v in kwargs.items():
|
117 |
+
setattr(self_copy, k, v)
|
118 |
+
return self_copy
|
isegm/inference/evaluation.py
ADDED
@@ -0,0 +1,56 @@
+from time import time
+
+import numpy as np
+import torch
+
+from isegm.inference import utils
+from isegm.inference.clicker import Clicker
+
+try:
+    get_ipython()
+    from tqdm import tqdm_notebook as tqdm
+except NameError:
+    from tqdm import tqdm
+
+
+def evaluate_dataset(dataset, predictor, **kwargs):
+    all_ious = []
+
+    start_time = time()
+    for index in tqdm(range(len(dataset)), leave=False):
+        sample = dataset.get_sample(index)
+
+        _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor,
+                                            sample_id=index, **kwargs)
+        all_ious.append(sample_ious)
+    end_time = time()
+    elapsed_time = end_time - start_time
+
+    return all_ious, elapsed_time
+
+
+def evaluate_sample(image, gt_mask, predictor, max_iou_thr,
+                    pred_thr=0.49, min_clicks=1, max_clicks=20,
+                    sample_id=None, callback=None):
+    clicker = Clicker(gt_mask=gt_mask)
+    pred_mask = np.zeros_like(gt_mask)
+    ious_list = []
+
+    with torch.no_grad():
+        predictor.set_input_image(image)
+
+        for click_indx in range(max_clicks):
+            clicker.make_next_click(pred_mask)
+            pred_probs = predictor.get_prediction(clicker)
+            pred_mask = pred_probs > pred_thr
+
+            if callback is not None:
+                callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list)
+
+            iou = utils.get_iou(gt_mask, pred_mask)
+            ious_list.append(iou)
+
+            if iou >= max_iou_thr and click_indx + 1 >= min_clicks:
+                break
+
+        return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs
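The two functions above are normally driven together with a predictor; a rough sketch follows. It assumes a NoBRS predictor built by get_predictor from isegm.inference.predictors (added below); load_interactive_model, the image, and the ground-truth mask are placeholders rather than objects defined in this diff, and in practice samples come from the dataset classes under isegm/data/datasets.

import torch
import numpy as np
from isegm.inference.predictors import get_predictor
from isegm.inference.evaluation import evaluate_sample

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = load_interactive_model().to(device)        # hypothetical loader returning a trained model
image = np.zeros((480, 640, 3), dtype=np.uint8)  # HxWx3 RGB image (placeholder)
gt_mask = np.zeros((480, 640), dtype=np.int32)   # binary object mask (placeholder)
gt_mask[100:300, 150:400] = 1

predictor = get_predictor(net, brs_mode='NoBRS', device=device)
clicks, ious, last_probs = evaluate_sample(image, gt_mask, predictor,
                                           max_iou_thr=0.9, max_clicks=20)
print(f'{len(clicks)} clicks, final IoU {ious[-1]:.3f}')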
isegm/inference/predictors/__init__.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
from .base import BasePredictor
|
2 |
+
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
|
3 |
+
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
|
4 |
+
from isegm.inference.transforms import ZoomIn
|
5 |
+
from isegm.model.is_hrnet_model import HRNetModel
|
6 |
+
|
7 |
+
|
8 |
+
def get_predictor(net, brs_mode, device,
|
9 |
+
prob_thresh=0.49,
|
10 |
+
with_flip=True,
|
11 |
+
zoom_in_params=dict(),
|
12 |
+
predictor_params=None,
|
13 |
+
brs_opt_func_params=None,
|
14 |
+
lbfgs_params=None):
|
15 |
+
lbfgs_params_ = {
|
16 |
+
'm': 20,
|
17 |
+
'factr': 0,
|
18 |
+
'pgtol': 1e-8,
|
19 |
+
'maxfun': 20,
|
20 |
+
}
|
21 |
+
|
22 |
+
predictor_params_ = {
|
23 |
+
'optimize_after_n_clicks': 1
|
24 |
+
}
|
25 |
+
|
26 |
+
if zoom_in_params is not None:
|
27 |
+
zoom_in = ZoomIn(**zoom_in_params)
|
28 |
+
else:
|
29 |
+
zoom_in = None
|
30 |
+
|
31 |
+
if lbfgs_params is not None:
|
32 |
+
lbfgs_params_.update(lbfgs_params)
|
33 |
+
lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun']
|
34 |
+
|
35 |
+
if brs_opt_func_params is None:
|
36 |
+
brs_opt_func_params = dict()
|
37 |
+
|
38 |
+
if isinstance(net, (list, tuple)):
|
39 |
+
assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode."
|
40 |
+
|
41 |
+
if brs_mode == 'NoBRS':
|
42 |
+
if predictor_params is not None:
|
43 |
+
predictor_params_.update(predictor_params)
|
44 |
+
predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_)
|
45 |
+
elif brs_mode.startswith('f-BRS'):
|
46 |
+
predictor_params_.update({
|
47 |
+
'net_clicks_limit': 8,
|
48 |
+
})
|
49 |
+
if predictor_params is not None:
|
50 |
+
predictor_params_.update(predictor_params)
|
51 |
+
|
52 |
+
insertion_mode = {
|
53 |
+
'f-BRS-A': 'after_c4',
|
54 |
+
'f-BRS-B': 'after_aspp',
|
55 |
+
'f-BRS-C': 'after_deeplab'
|
56 |
+
}[brs_mode]
|
57 |
+
|
58 |
+
opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
|
59 |
+
with_flip=with_flip,
|
60 |
+
optimizer_params=lbfgs_params_,
|
61 |
+
**brs_opt_func_params)
|
62 |
+
|
63 |
+
if isinstance(net, HRNetModel):
|
64 |
+
FeaturePredictor = HRNetFeatureBRSPredictor
|
65 |
+
insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode]
|
66 |
+
else:
|
67 |
+
FeaturePredictor = FeatureBRSPredictor
|
68 |
+
|
69 |
+
predictor = FeaturePredictor(net, device,
|
70 |
+
opt_functor=opt_functor,
|
71 |
+
with_flip=with_flip,
|
72 |
+
insertion_mode=insertion_mode,
|
73 |
+
zoom_in=zoom_in,
|
74 |
+
**predictor_params_)
|
75 |
+
elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
|
76 |
+
use_dmaps = brs_mode == 'DistMap-BRS'
|
77 |
+
|
78 |
+
predictor_params_.update({
|
79 |
+
'net_clicks_limit': 5,
|
80 |
+
})
|
81 |
+
if predictor_params is not None:
|
82 |
+
predictor_params_.update(predictor_params)
|
83 |
+
|
84 |
+
opt_functor = InputOptimizer(prob_thresh=prob_thresh,
|
85 |
+
with_flip=with_flip,
|
86 |
+
optimizer_params=lbfgs_params_,
|
87 |
+
**brs_opt_func_params)
|
88 |
+
|
89 |
+
predictor = InputBRSPredictor(net, device,
|
90 |
+
optimize_target='dmaps' if use_dmaps else 'rgb',
|
91 |
+
opt_functor=opt_functor,
|
92 |
+
with_flip=with_flip,
|
93 |
+
zoom_in=zoom_in,
|
94 |
+
**predictor_params_)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
return predictor
|
isegm/inference/predictors/base.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torchvision import transforms
|
4 |
+
from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide
|
5 |
+
|
6 |
+
|
7 |
+
class BasePredictor(object):
|
8 |
+
def __init__(self, model, device,
|
9 |
+
net_clicks_limit=None,
|
10 |
+
with_flip=False,
|
11 |
+
zoom_in=None,
|
12 |
+
max_size=None,
|
13 |
+
**kwargs):
|
14 |
+
self.with_flip = with_flip
|
15 |
+
self.net_clicks_limit = net_clicks_limit
|
16 |
+
self.original_image = None
|
17 |
+
self.device = device
|
18 |
+
self.zoom_in = zoom_in
|
19 |
+
self.prev_prediction = None
|
20 |
+
self.model_indx = 0
|
21 |
+
self.click_models = None
|
22 |
+
self.net_state_dict = None
|
23 |
+
|
24 |
+
if isinstance(model, tuple):
|
25 |
+
self.net, self.click_models = model
|
26 |
+
else:
|
27 |
+
self.net = model
|
28 |
+
|
29 |
+
self.to_tensor = transforms.ToTensor()
|
30 |
+
|
31 |
+
self.transforms = [zoom_in] if zoom_in is not None else []
|
32 |
+
if max_size is not None:
|
33 |
+
self.transforms.append(LimitLongestSide(max_size=max_size))
|
34 |
+
self.transforms.append(SigmoidForPred())
|
35 |
+
if with_flip:
|
36 |
+
self.transforms.append(AddHorizontalFlip())
|
37 |
+
|
38 |
+
def set_input_image(self, image):
|
39 |
+
image_nd = self.to_tensor(image)
|
40 |
+
for transform in self.transforms:
|
41 |
+
transform.reset()
|
42 |
+
self.original_image = image_nd.to(self.device)
|
43 |
+
if len(self.original_image.shape) == 3:
|
44 |
+
self.original_image = self.original_image.unsqueeze(0)
|
45 |
+
self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])
|
46 |
+
|
47 |
+
def get_prediction(self, clicker, prev_mask=None):
|
48 |
+
clicks_list = clicker.get_clicks()
|
49 |
+
|
50 |
+
if self.click_models is not None:
|
51 |
+
model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1
|
52 |
+
if model_indx != self.model_indx:
|
53 |
+
self.model_indx = model_indx
|
54 |
+
self.net = self.click_models[model_indx]
|
55 |
+
|
56 |
+
input_image = self.original_image
|
57 |
+
if prev_mask is None:
|
58 |
+
prev_mask = self.prev_prediction
|
59 |
+
if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask:
|
60 |
+
input_image = torch.cat((input_image, prev_mask), dim=1)
|
61 |
+
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
|
62 |
+
input_image, [clicks_list]
|
63 |
+
)
|
64 |
+
|
65 |
+
pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
|
66 |
+
prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True,
|
67 |
+
size=image_nd.size()[2:])
|
68 |
+
|
69 |
+
for t in reversed(self.transforms):
|
70 |
+
prediction = t.inv_transform(prediction)
|
71 |
+
|
72 |
+
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
|
73 |
+
return self.get_prediction(clicker)
|
74 |
+
|
75 |
+
self.prev_prediction = prediction
|
76 |
+
return prediction.cpu().numpy()[0, 0]
|
77 |
+
|
78 |
+
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
|
79 |
+
points_nd = self.get_points_nd(clicks_lists)
|
80 |
+
return self.net(image_nd, points_nd)['instances']
|
81 |
+
|
82 |
+
def _get_transform_states(self):
|
83 |
+
return [x.get_state() for x in self.transforms]
|
84 |
+
|
85 |
+
def _set_transform_states(self, states):
|
86 |
+
assert len(states) == len(self.transforms)
|
87 |
+
for state, transform in zip(states, self.transforms):
|
88 |
+
transform.set_state(state)
|
89 |
+
|
90 |
+
def apply_transforms(self, image_nd, clicks_lists):
|
91 |
+
is_image_changed = False
|
92 |
+
for t in self.transforms:
|
93 |
+
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
|
94 |
+
is_image_changed |= t.image_changed
|
95 |
+
|
96 |
+
return image_nd, clicks_lists, is_image_changed
|
97 |
+
|
98 |
+
def get_points_nd(self, clicks_lists):
|
99 |
+
total_clicks = []
|
100 |
+
num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists]
|
101 |
+
num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)]
|
102 |
+
num_max_points = max(num_pos_clicks + num_neg_clicks)
|
103 |
+
if self.net_clicks_limit is not None:
|
104 |
+
num_max_points = min(self.net_clicks_limit, num_max_points)
|
105 |
+
num_max_points = max(1, num_max_points)
|
106 |
+
|
107 |
+
for clicks_list in clicks_lists:
|
108 |
+
clicks_list = clicks_list[:self.net_clicks_limit]
|
109 |
+
pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
|
110 |
+
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)]
|
111 |
+
|
112 |
+
neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
|
113 |
+
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)]
|
114 |
+
total_clicks.append(pos_clicks + neg_clicks)
|
115 |
+
|
116 |
+
return torch.tensor(total_clicks, device=self.device)
|
117 |
+
|
118 |
+
def get_states(self):
|
119 |
+
return {
|
120 |
+
'transform_states': self._get_transform_states(),
|
121 |
+
'prev_prediction': self.prev_prediction.clone()
|
122 |
+
}
|
123 |
+
|
124 |
+
def set_states(self, states):
|
125 |
+
self._set_transform_states(states['transform_states'])
|
126 |
+
self.prev_prediction = states['prev_prediction']
|
isegm/inference/predictors/brs.py
ADDED
@@ -0,0 +1,307 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy.optimize import fmin_l_bfgs_b
|
5 |
+
|
6 |
+
from .base import BasePredictor
|
7 |
+
|
8 |
+
|
9 |
+
class BRSBasePredictor(BasePredictor):
|
10 |
+
def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
|
11 |
+
super().__init__(model, device, **kwargs)
|
12 |
+
self.optimize_after_n_clicks = optimize_after_n_clicks
|
13 |
+
self.opt_functor = opt_functor
|
14 |
+
|
15 |
+
self.opt_data = None
|
16 |
+
self.input_data = None
|
17 |
+
|
18 |
+
def set_input_image(self, image):
|
19 |
+
super().set_input_image(image)
|
20 |
+
self.opt_data = None
|
21 |
+
self.input_data = None
|
22 |
+
|
23 |
+
def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
|
24 |
+
pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
|
25 |
+
neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32)
|
26 |
+
|
27 |
+
for list_indx, clicks_list in enumerate(clicks_lists):
|
28 |
+
for click in clicks_list:
|
29 |
+
y, x = click.coords
|
30 |
+
y, x = int(round(y)), int(round(x))
|
31 |
+
                y1, x1 = y - radius, x - radius
                y2, x2 = y + radius + 1, x + radius + 1

                if click.is_positive:
                    pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True
                else:
                    neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True

        with torch.no_grad():
            pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
            neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data}

    def set_states(self, states):
        self._set_transform_states(states['transform_states'])
        self.opt_data = states['opt_data']


class FeatureBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == 'after_deeplab':
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == 'after_c4':
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == 'after_aspp':
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'after_c4':
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:],
                                  align_corners=True)
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == 'after_aspp':
                scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features)

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp':
                c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features)
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == 'after_aspp':
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True)
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[0]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == 'A':
            self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8])
        elif self.insertion_mode == 'C':
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs:
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None:
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == 'A':
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features)
                    feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features)

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == 'C':
                pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features)
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear',
                                        align_corners=True)
            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data,
                                       **self.opt_functor.optimizer_params)
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features)

            if self.insertion_mode == 'A':
                backbone_features = feats
            elif self.insertion_mode == 'C':
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context)
            else:
                raise NotImplementedError

        return backbone_features


class InputBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        if self.opt_data is None or is_image_changed:
            if self.optimize_target == 'dmaps':
                opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                                        device=self.device, dtype=torch.float32)

        def get_prediction_logits(opt_bias):
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == 'rgb':
                input_image = input_image + opt_bias
            elif self.optimize_target == 'dmaps':
                if self.net.with_prev_mask:
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == 'all':
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, 'maps_transform'):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances']
            pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True)

            return pred_logits

        self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device,
                                    shape=self.opt_data.shape)
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(),
                                       **self.opt_functor.optimizer_params)

            self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device)

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits
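A minimal usage sketch (not part of this diff) of how these predictor classes are typically obtained, assuming the factory in isegm/inference/predictors/__init__.py keeps the usual RITM signature get_predictor(net, brs_mode, device, ...); the checkpoint path below is a placeholder.

    import torch
    from isegm.inference.utils import load_is_model
    from isegm.inference.predictors import get_predictor

    device = torch.device('cpu')
    model = load_is_model('weights/checkpoint.pth', device)  # placeholder path
    predictor = get_predictor(model, 'NoBRS', device)        # 'f-BRS-B' etc. would select the BRS predictors above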
isegm/inference/predictors/brs_functors.py
ADDED
@@ -0,0 +1,109 @@
import torch
import numpy as np

from isegm.model.metrics import _compute_iou
from .brs_losses import BRSMaskLoss


class BaseOptimizer:
    def __init__(self, optimizer_params,
                 prob_thresh=0.49,
                 reg_weight=1e-3,
                 min_iou_diff=0.01,
                 brs_loss=BRSMaskLoss(),
                 with_flip=False,
                 flip_average=False,
                 **kwargs):
        self.brs_loss = brs_loss
        self.optimizer_params = optimizer_params
        self.prob_thresh = prob_thresh
        self.reg_weight = reg_weight
        self.min_iou_diff = min_iou_diff
        self.with_flip = with_flip
        self.flip_average = flip_average

        self.best_prediction = None
        self._get_prediction_logits = None
        self._opt_shape = None
        self._best_loss = None
        self._click_masks = None
        self._last_mask = None
        self.device = None

    def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None):
        self.best_prediction = None
        self._get_prediction_logits = get_prediction_logits
        self._click_masks = (pos_mask, neg_mask)
        self._opt_shape = shape
        self._last_mask = None
        self.device = device

    def __call__(self, x):
        opt_params = torch.from_numpy(x).float().to(self.device)
        opt_params.requires_grad_(True)

        with torch.enable_grad():
            opt_vars, reg_loss = self.unpack_opt_params(opt_params)
            result_before_sigmoid = self._get_prediction_logits(*opt_vars)
            result = torch.sigmoid(result_before_sigmoid)

            pos_mask, neg_mask = self._click_masks
            if self.with_flip and self.flip_average:
                result, result_flipped = torch.chunk(result, 2, dim=0)
                result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
                pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]]

            loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
            loss = loss + reg_loss

        f_val = loss.detach().cpu().numpy()
        if self.best_prediction is None or f_val < self._best_loss:
            self.best_prediction = result_before_sigmoid.detach()
            self._best_loss = f_val

        if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh:
            return [f_val, np.zeros_like(x)]

        current_mask = result > self.prob_thresh
        if self._last_mask is not None and self.min_iou_diff > 0:
            diff_iou = _compute_iou(current_mask, self._last_mask)
            if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff:
                return [f_val, np.zeros_like(x)]
        self._last_mask = current_mask

        loss.backward()
        f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float64)

        return [f_val, f_grad]

    def unpack_opt_params(self, opt_params):
        raise NotImplementedError


class InputOptimizer(BaseOptimizer):
    def unpack_opt_params(self, opt_params):
        opt_params = opt_params.view(self._opt_shape)
        if self.with_flip:
            opt_params_flipped = torch.flip(opt_params, dims=[3])
            opt_params = torch.cat([opt_params, opt_params_flipped], dim=0)
        reg_loss = self.reg_weight * torch.sum(opt_params**2)

        return (opt_params,), reg_loss


class ScaleBiasOptimizer(BaseOptimizer):
    def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale_act = scale_act
        self.reg_bias_weight = reg_bias_weight

    def unpack_opt_params(self, opt_params):
        scale, bias = torch.chunk(opt_params, 2, dim=0)
        reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2))

        if self.scale_act == 'tanh':
            scale = torch.tanh(scale)
        elif self.scale_act == 'sin':
            scale = torch.sin(scale)

        return (1 + scale, bias), reg_loss
isegm/inference/predictors/brs_losses.py
ADDED
@@ -0,0 +1,58 @@
import torch

from isegm.model.losses import SigmoidBinaryCrossEntropyLoss


class BRSMaskLoss(torch.nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self._eps = eps

    def forward(self, result, pos_mask, neg_mask):
        pos_diff = (1 - result) * pos_mask
        pos_target = torch.sum(pos_diff ** 2)
        pos_target = pos_target / (torch.sum(pos_mask) + self._eps)

        neg_diff = result * neg_mask
        neg_target = torch.sum(neg_diff ** 2)
        neg_target = neg_target / (torch.sum(neg_mask) + self._eps)

        loss = pos_target + neg_target

        with torch.no_grad():
            f_max_pos = torch.max(torch.abs(pos_diff)).item()
            f_max_neg = torch.max(torch.abs(neg_diff)).item()

        return loss, f_max_pos, f_max_neg


class OracleMaskLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gt_mask = None
        self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True)
        self.predictor = None
        self.history = []

    def set_gt_mask(self, gt_mask):
        self.gt_mask = gt_mask
        self.history = []

    def forward(self, result, pos_mask, neg_mask):
        gt_mask = self.gt_mask.to(result.device)
        if self.predictor.object_roi is not None:
            r1, r2, c1, c2 = self.predictor.object_roi[:4]
            gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1]
            gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True)

        if result.shape[0] == 2:
            gt_mask_flipped = torch.flip(gt_mask, dims=[3])
            gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0)

        loss = self.loss(result, gt_mask)
        self.history.append(loss.detach().cpu().numpy()[0])

        if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5:
            return 0, 0, 0

        return loss, 1.0, 1.0
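A quick sanity check of BRSMaskLoss on dummy tensors (a sketch, not part of the diff): the loss penalizes low probabilities under positive clicks and high probabilities under negative clicks, and also reports the largest per-pixel violation for each click type.

    import torch
    from isegm.inference.predictors.brs_losses import BRSMaskLoss

    loss_fn = BRSMaskLoss()
    probs = torch.rand(1, 1, 64, 64)      # sigmoid outputs in [0, 1]
    pos_mask = torch.zeros(1, 1, 64, 64)
    neg_mask = torch.zeros(1, 1, 64, 64)
    pos_mask[0, 0, 30:33, 30:33] = 1      # 3x3 disk around a positive click
    neg_mask[0, 0, 5:8, 5:8] = 1          # 3x3 disk around a negative click
    loss, f_max_pos, f_max_neg = loss_fn(probs, pos_mask, neg_mask)
    print(float(loss), f_max_pos, f_max_neg)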
isegm/inference/transforms/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .base import SigmoidForPred
from .flip import AddHorizontalFlip
from .zoom_in import ZoomIn
from .limit_longest_side import LimitLongestSide
from .crops import Crops
isegm/inference/transforms/base.py
ADDED
@@ -0,0 +1,38 @@
import torch


class BaseTransform(object):
    def __init__(self):
        self.image_changed = False

    def transform(self, image_nd, clicks_lists):
        raise NotImplementedError

    def inv_transform(self, prob_map):
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def get_state(self):
        raise NotImplementedError

    def set_state(self, state):
        raise NotImplementedError


class SigmoidForPred(BaseTransform):
    def transform(self, image_nd, clicks_lists):
        return image_nd, clicks_lists

    def inv_transform(self, prob_map):
        return torch.sigmoid(prob_map)

    def reset(self):
        pass

    def get_state(self):
        return None

    def set_state(self, state):
        pass
isegm/inference/transforms/crops.py
ADDED
@@ -0,0 +1,97 @@
import math

import torch
import numpy as np
from typing import List

from isegm.inference.clicker import Click
from .base import BaseTransform


class Crops(BaseTransform):
    def __init__(self, crop_size=(320, 480), min_overlap=0.2):
        super().__init__()
        self.crop_height, self.crop_width = crop_size
        self.min_overlap = min_overlap

        self.x_offsets = None
        self.y_offsets = None
        self._counts = None

    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        image_height, image_width = image_nd.shape[2:4]
        self._counts = None

        if image_height < self.crop_height or image_width < self.crop_width:
            return image_nd, clicks_lists

        self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
        self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
        self._counts = np.zeros((image_height, image_width))

        image_crops = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1
                image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width]
                image_crops.append(image_crop)
        image_crops = torch.cat(image_crops, dim=0)
        self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)

        clicks_list = clicks_lists[0]
        clicks_lists = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list]
                clicks_lists.append(crop_clicks)

        return image_crops, clicks_lists

    def inv_transform(self, prob_map):
        if self._counts is None:
            return prob_map

        new_prob_map = torch.zeros((1, 1, *self._counts.shape),
                                   dtype=prob_map.dtype, device=prob_map.device)

        crop_indx = 0
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0]
                crop_indx += 1
        new_prob_map = torch.div(new_prob_map, self._counts)

        return new_prob_map

    def get_state(self):
        return self.x_offsets, self.y_offsets, self._counts

    def set_state(self, state):
        self.x_offsets, self.y_offsets, self._counts = state

    def reset(self):
        self.x_offsets = None
        self.y_offsets = None
        self._counts = None


def get_offsets(length, crop_size, min_overlap_ratio=0.2):
    if length == crop_size:
        return [0]

    N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
    N = math.ceil(N)

    overlap_ratio = (N - length / crop_size) / (N - 1)
    overlap_width = int(crop_size * overlap_ratio)

    offsets = [0]
    for i in range(1, N):
        new_offset = offsets[-1] + crop_size - overlap_width
        if new_offset + crop_size > length:
            new_offset = length - crop_size

        offsets.append(new_offset)

    return offsets
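For reference, get_offsets tiles one side of the image with overlapping crops; a quick check (a sketch, exact offsets depend on integer rounding):

    from isegm.inference.transforms.crops import get_offsets

    # Tile a 1000-px side with 480-px crops and at least 20% overlap.
    print(get_offsets(1000, 480, min_overlap_ratio=0.2))  # three offsets covering the full side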
isegm/inference/transforms/flip.py
ADDED
@@ -0,0 +1,37 @@
import torch

from typing import List
from isegm.inference.clicker import Click
from .base import BaseTransform


class AddHorizontalFlip(BaseTransform):
    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert len(image_nd.shape) == 4
        image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0)

        image_width = image_nd.shape[3]
        clicks_lists_flipped = []
        for clicks_list in clicks_lists:
            clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
                                   for click in clicks_list]
            clicks_lists_flipped.append(clicks_list_flipped)
        clicks_lists = clicks_lists + clicks_lists_flipped

        return image_nd, clicks_lists

    def inv_transform(self, prob_map):
        assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
        num_maps = prob_map.shape[0] // 2
        prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:]

        return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3]))

    def get_state(self):
        return None

    def set_state(self, state):
        pass

    def reset(self):
        pass
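A short sketch of this test-time flip augmentation (not part of the diff): transform doubles the batch with a mirrored copy, and inv_transform fuses each prediction with its re-flipped counterpart.

    import torch
    from isegm.inference.transforms import AddHorizontalFlip

    tta = AddHorizontalFlip()
    image = torch.rand(1, 3, 64, 64)
    image_aug, clicks = tta.transform(image, [[]])  # batch becomes (2, 3, 64, 64); clicks are duplicated and mirrored
    probs = torch.rand(2, 1, 64, 64)                # predictions for original + flipped inputs
    fused = tta.inv_transform(probs)                # (1, 1, 64, 64) average of direct and re-flipped maps
    print(image_aug.shape, fused.shape)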
isegm/inference/transforms/limit_longest_side.py
ADDED
@@ -0,0 +1,22 @@
from .zoom_in import ZoomIn, get_roi_image_nd


class LimitLongestSide(ZoomIn):
    def __init__(self, max_size=800):
        super().__init__(target_size=max_size, skip_clicks=0)

    def transform(self, image_nd, clicks_lists):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        image_max_size = max(image_nd.shape[2:4])
        self.image_changed = False

        if image_max_size <= self.target_size:
            return image_nd, clicks_lists
        self._input_image = image_nd

        self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1)
        self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)
        self.image_changed = True

        tclicks_lists = [self._transform_clicks(clicks_lists[0])]
        return self._roi_image, tclicks_lists
isegm/inference/transforms/zoom_in.py
ADDED
@@ -0,0 +1,175 @@
import torch

from typing import List
from isegm.inference.clicker import Click
from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox
from .base import BaseTransform


class ZoomIn(BaseTransform):
    def __init__(self,
                 target_size=400,
                 skip_clicks=1,
                 expansion_ratio=1.4,
                 min_crop_size=200,
                 recompute_thresh_iou=0.5,
                 prob_thresh=0.50):
        super().__init__()
        self.target_size = target_size
        self.min_crop_size = min_crop_size
        self.skip_clicks = skip_clicks
        self.expansion_ratio = expansion_ratio
        self.recompute_thresh_iou = recompute_thresh_iou
        self.prob_thresh = prob_thresh

        self._input_image_shape = None
        self._prev_probs = None
        self._object_roi = None
        self._roi_image = None

    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        self.image_changed = False

        clicks_list = clicks_lists[0]
        if len(clicks_list) <= self.skip_clicks:
            return image_nd, clicks_lists

        self._input_image_shape = image_nd.shape

        current_object_roi = None
        if self._prev_probs is not None:
            current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
            if current_pred_mask.sum() > 0:
                current_object_roi = get_object_roi(current_pred_mask, clicks_list,
                                                    self.expansion_ratio, self.min_crop_size)

        if current_object_roi is None:
            if self.skip_clicks >= 0:
                return image_nd, clicks_lists
            else:
                current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1

        update_object_roi = False
        if self._object_roi is None:
            update_object_roi = True
        elif not check_object_roi(self._object_roi, clicks_list):
            update_object_roi = True
        elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou:
            update_object_roi = True

        if update_object_roi:
            self._object_roi = current_object_roi
            self.image_changed = True
        self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size)

        tclicks_lists = [self._transform_clicks(clicks_list)]
        return self._roi_image.to(image_nd.device), tclicks_lists

    def inv_transform(self, prob_map):
        if self._object_roi is None:
            self._prev_probs = prob_map.cpu().numpy()
            return prob_map

        assert prob_map.shape[0] == 1
        rmin, rmax, cmin, cmax = self._object_roi
        prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1),
                                                   mode='bilinear', align_corners=True)

        if self._prev_probs is not None:
            new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype)
            new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map
        else:
            new_prob_map = prob_map

        self._prev_probs = new_prob_map.cpu().numpy()

        return new_prob_map

    def check_possible_recalculation(self):
        if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0:
            return False

        pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
        if pred_mask.sum() > 0:
            possible_object_roi = get_object_roi(pred_mask, [],
                                                 self.expansion_ratio, self.min_crop_size)
            image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1)
            if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
                return True
        return False

    def get_state(self):
        roi_image = self._roi_image.cpu() if self._roi_image is not None else None
        return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed

    def set_state(self, state):
        self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state

    def reset(self):
        self._input_image_shape = None
        self._object_roi = None
        self._prev_probs = None
        self._roi_image = None
        self.image_changed = False

    def _transform_clicks(self, clicks_list):
        if self._object_roi is None:
            return clicks_list

        rmin, rmax, cmin, cmax = self._object_roi
        crop_height, crop_width = self._roi_image.shape[2:]

        transformed_clicks = []
        for click in clicks_list:
            new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1)
            new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1)
            transformed_clicks.append(click.copy(coords=(new_r, new_c)))
        return transformed_clicks


def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size):
    pred_mask = pred_mask.copy()

    for click in clicks_list:
        if click.is_positive:
            pred_mask[int(click.coords[0]), int(click.coords[1])] = 1

    bbox = get_bbox_from_mask(pred_mask)
    bbox = expand_bbox(bbox, expansion_ratio, min_crop_size)
    h, w = pred_mask.shape[0], pred_mask.shape[1]
    bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1)

    return bbox


def get_roi_image_nd(image_nd, object_roi, target_size):
    rmin, rmax, cmin, cmax = object_roi

    height = rmax - rmin + 1
    width = cmax - cmin + 1

    if isinstance(target_size, tuple):
        new_height, new_width = target_size
    else:
        scale = target_size / max(height, width)
        new_height = int(round(height * scale))
        new_width = int(round(width * scale))

    with torch.no_grad():
        roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1]
        roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width),
                                                       mode='bilinear', align_corners=True)

    return roi_image_nd


def check_object_roi(object_roi, clicks_list):
    for click in clicks_list:
        if click.is_positive:
            if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]:
                return False
            if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]:
                return False

    return True
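A quick sketch of the ROI helper used by ZoomIn (not part of the diff): the region is cut out and rescaled so that its longest side matches target_size.

    import torch
    from isegm.inference.transforms.zoom_in import get_roi_image_nd

    image = torch.rand(1, 3, 480, 640)
    roi = (100, 299, 200, 399)                      # rmin, rmax, cmin, cmax -> a 200x200 region
    crop = get_roi_image_nd(image, roi, target_size=400)
    print(crop.shape)                               # torch.Size([1, 3, 400, 400])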
isegm/inference/utils.py
ADDED
@@ -0,0 +1,143 @@
from datetime import timedelta
from pathlib import Path

import torch
import numpy as np

from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset
from isegm.utils.serialization import load_model


def get_time_metrics(all_ious, elapsed_time):
    n_images = len(all_ious)
    n_clicks = sum(map(len, all_ious))

    mean_spc = elapsed_time / n_clicks
    mean_spi = elapsed_time / n_images

    return mean_spc, mean_spi


def load_is_model(checkpoint, device, **kwargs):
    if isinstance(checkpoint, (str, Path)):
        state_dict = torch.load(checkpoint, map_location='cpu')
    else:
        state_dict = checkpoint

    if isinstance(state_dict, list):
        model = load_single_is_model(state_dict[0], device, **kwargs)
        models = [load_single_is_model(x, device, **kwargs) for x in state_dict]

        return model, models
    else:
        return load_single_is_model(state_dict, device, **kwargs)


def load_single_is_model(state_dict, device, **kwargs):
    model = load_model(state_dict['config'], **kwargs)
    model.load_state_dict(state_dict['state_dict'], strict=False)

    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    model.eval()

    return model


def get_dataset(dataset_name, cfg):
    if dataset_name == 'GrabCut':
        dataset = GrabCutDataset(cfg.GRABCUT_PATH)
    elif dataset_name == 'Berkeley':
        dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
    elif dataset_name == 'DAVIS':
        dataset = DavisDataset(cfg.DAVIS_PATH)
    elif dataset_name == 'SBD':
        dataset = SBDEvaluationDataset(cfg.SBD_PATH)
    elif dataset_name == 'SBD_Train':
        dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train')
    elif dataset_name == 'PascalVOC':
        dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test')
    elif dataset_name == 'COCO_MVal':
        dataset = DavisDataset(cfg.COCO_MVAL_PATH)
    else:
        dataset = None

    return dataset


def get_iou(gt_mask, pred_mask, ignore_label=-1):
    ignore_gt_mask_inv = gt_mask != ignore_label
    obj_gt_mask = gt_mask == 1

    intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()
    union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum()

    return intersection / union


def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    def _get_noc(iou_arr, iou_thr):
        vals = iou_arr >= iou_thr
        return np.argmax(vals) + 1 if np.any(vals) else max_clicks

    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        scores_arr = np.array([_get_noc(iou_arr, iou_thr)
                               for iou_arr in all_ious], dtype=int)

        score = scores_arr.mean()
        over_max = (scores_arr == max_clicks).sum()

        noc_list.append(score)
        over_max_list.append(over_max)

    return noc_list, over_max_list


def find_checkpoint(weights_folder, checkpoint_name):
    weights_folder = Path(weights_folder)
    if ':' in checkpoint_name:
        model_name, checkpoint_name = checkpoint_name.split(':')
        models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
        assert len(models_candidates) == 1
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith('.pth'):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            checkpoint_path = weights_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
        assert len(model_checkpoints) == 1
        checkpoint_path = model_checkpoints[0]

    return str(checkpoint_path)


def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time,
                      n_clicks=20, model_name=None):
    table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|'
                    f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
                    f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
                    f'{"SPC,s":^7}|{"Time":^9}|')
    row_width = len(table_header)

    header = f'Eval results for model: {model_name}\n' if model_name is not None else ''
    header += '-' * row_width + '\n'
    header += table_header + '\n' + '-' * row_width

    eval_time = str(timedelta(seconds=int(elapsed_time)))
    table_row = f'|{brs_type:^13}|{dataset_name:^11}|'
    table_row += f'{noc_list[0]:^9.2f}|'
    table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|'

    return header, table_row
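A toy example of the NoC (number-of-clicks) metric (a sketch, not part of the diff): each array holds one image's IoU after successive clicks, and NoC@thr is the first click index reaching the threshold, capped at max_clicks.

    import numpy as np
    from isegm.inference.utils import compute_noc_metric

    all_ious = [np.array([0.55, 0.72, 0.84, 0.91]),
                np.array([0.40, 0.65, 0.80, 0.88])]
    noc, over_max = compute_noc_metric(all_ious, iou_thrs=[0.80, 0.85, 0.90], max_clicks=4)
    print(noc, over_max)    # [3.0, 4.0, 4.0] [0, 2, 2]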
isegm/model/initializer.py
ADDED
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import numpy as np


class Initializer(object):
    def __init__(self, local_init=True, gamma=None):
        self.local_init = local_init
        self.gamma = gamma

    def __call__(self, m):
        if getattr(m, '__initialized', False):
            return

        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                          nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                          nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
            if m.weight is not None:
                self._init_gamma(m.weight.data)
            if m.bias is not None:
                self._init_beta(m.bias.data)
        else:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias.data)

        if self.local_init:
            object.__setattr__(m, '__initialized', True)

    def _init_weight(self, data):
        nn.init.uniform_(data, -0.07, 0.07)

    def _init_bias(self, data):
        nn.init.constant_(data, 0)

    def _init_gamma(self, data):
        if self.gamma is None:
            nn.init.constant_(data, 1.0)
        else:
            nn.init.normal_(data, 1.0, self.gamma)

    def _init_beta(self, data):
        nn.init.constant_(data, 0)


class Bilinear(Initializer):
    def __init__(self, scale, groups, in_channels, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.groups = groups
        self.in_channels = in_channels

    def _init_weight(self, data):
        """Reset the weight and bias."""
        bilinear_kernel = self.get_bilinear_kernel(self.scale)
        weight = torch.zeros_like(data)
        for i in range(self.in_channels):
            if self.groups == 1:
                j = i
            else:
                j = 0
            weight[i, j] = bilinear_kernel
        data[:] = weight

    @staticmethod
    def get_bilinear_kernel(scale):
        """Generate a bilinear upsampling kernel."""
        kernel_size = 2 * scale - scale % 2
        scale = (kernel_size + 1) // 2
        center = scale - 0.5 * (1 + kernel_size % 2)

        og = np.ogrid[:kernel_size, :kernel_size]
        kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)

        return torch.tensor(kernel, dtype=torch.float32)


class XavierGluon(Initializer):
    def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
        super().__init__(**kwargs)

        self.rnd_type = rnd_type
        self.factor_type = factor_type
        self.magnitude = float(magnitude)

    def _init_weight(self, arr):
        fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)

        if self.factor_type == 'avg':
            factor = (fan_in + fan_out) / 2.0
        elif self.factor_type == 'in':
            factor = fan_in
        elif self.factor_type == 'out':
            factor = fan_out
        else:
            raise ValueError('Incorrect factor type')
        scale = np.sqrt(self.magnitude / factor)

        if self.rnd_type == 'uniform':
            nn.init.uniform_(arr, -scale, scale)
        elif self.rnd_type == 'gaussian':
            nn.init.normal_(arr, 0, scale)
        else:
            raise ValueError('Unknown random type')
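These initializers are applied module-by-module through nn.Module.apply; a minimal sketch (not part of the diff):

    import torch.nn as nn
    from isegm.model.initializer import XavierGluon

    net = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, 1, 1),
    )
    net.apply(XavierGluon(rnd_type='gaussian', factor_type='in', magnitude=2.0))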
isegm/model/is_deeplab_model.py
ADDED
@@ -0,0 +1,25 @@
import torch.nn as nn

from isegm.utils.serialization import serialize
from .is_model import ISModel
from .modeling.deeplab_v3 import DeepLabV3Plus
from .modeling.basic_blocks import SepConvHead
from isegm.model.modifiers import LRMult


class DeeplabModel(ISModel):
    @serialize
    def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
                 backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__(norm_layer=norm_layer, **kwargs)

        self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout,
                                               norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer)
        self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
        self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2,
                                num_layers=2, norm_layer=norm_layer)

    def backbone_forward(self, image, coord_features=None):
        backbone_features = self.feature_extractor(image, coord_features)

        return {'instances': self.head(backbone_features[0])}
isegm/model/is_hrnet_model.py
ADDED
@@ -0,0 +1,26 @@
import torch.nn as nn

from isegm.utils.serialization import serialize
from .is_model import ISModel
from .modeling.hrnet_ocr import HighResolutionNet
from isegm.model.modifiers import LRMult


class HRNetModel(ISModel):
    @serialize
    def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__(norm_layer=norm_layer, **kwargs)

        self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small,
                                                   num_classes=1, norm_layer=norm_layer)
        self.feature_extractor.apply(LRMult(backbone_lr_mult))
        if ocr_width > 0:
            self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))
            self.feature_extractor.ocr_gather_head.apply(LRMult(1.0))
            self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0))

    def backbone_forward(self, image, coord_features=None):
        net_outputs = self.feature_extractor(image, coord_features)

        return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]}
isegm/model/is_model.py
ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import numpy as np

from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize
from isegm.model.modifiers import LRMult


class ISModel(nn.Module):
    def __init__(self, use_rgb_conv=True, with_aux_output=False,
                 norm_radius=260, use_disks=False, cpu_dist_maps=False,
                 clicks_groups=None, with_prev_mask=False, use_leaky_relu=False,
                 binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d,
                 norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups
        self.with_prev_mask = with_prev_mask
        self.binary_prev_mask = binary_prev_mask
        self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])

        self.coord_feature_ch = 2
        if clicks_groups is not None:
            self.coord_feature_ch *= len(clicks_groups)

        if self.with_prev_mask:
            self.coord_feature_ch += 1

        if use_rgb_conv:
            rgb_conv_layers = [
                nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1),
                norm_layer(6 + self.coord_feature_ch),
                nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1)
            ]
            self.rgb_conv = nn.Sequential(*rgb_conv_layers)
        elif conv_extend:
            self.rgb_conv = None
            self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64,
                                            kernel_size=3, stride=2, padding=1)
            self.maps_transform.apply(LRMult(0.1))
        else:
            self.rgb_conv = None
            mt_layers = [
                nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1),
                nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
                ScaleLayer(init_value=0.05, lr_mult=1)
            ]
            self.maps_transform = nn.Sequential(*mt_layers)

        if self.clicks_groups is not None:
            self.dist_maps = nn.ModuleList()
            for click_radius in self.clicks_groups:
                self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0,
                                               cpu_mode=cpu_dist_maps, use_disks=use_disks))
        else:
            self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
                                      cpu_mode=cpu_dist_maps, use_disks=use_disks)

    def forward(self, image, points):
        image, prev_mask = self.prepare_input(image)
        coord_features = self.get_coord_features(image, prev_mask, points)

        if self.rgb_conv is not None:
            x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
            outputs = self.backbone_forward(x)
        else:
            coord_features = self.maps_transform(coord_features)
            outputs = self.backbone_forward(image, coord_features)

        outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
                                                         mode='bilinear', align_corners=True)
        if self.with_aux_output:
            outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
                                                                 mode='bilinear', align_corners=True)

        return outputs

    def prepare_input(self, image):
        prev_mask = None
        if self.with_prev_mask:
            prev_mask = image[:, 3:, :, :]
            image = image[:, :3, :, :]
            if self.binary_prev_mask:
                prev_mask = (prev_mask > 0.5).float()

        image = self.normalization(image)
        return image, prev_mask

    def backbone_forward(self, image, coord_features=None):
        raise NotImplementedError

    def get_coord_features(self, image, prev_mask, points):
        if self.clicks_groups is not None:
            points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,))
            coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)]
            coord_features = torch.cat(coord_features, dim=1)
        else:
            coord_features = self.dist_maps(image, points)

        if prev_mask is not None:
            coord_features = torch.cat((prev_mask, coord_features), dim=1)

        return coord_features


def split_points_by_order(tpoints: torch.Tensor, groups):
    points = tpoints.cpu().numpy()
    num_groups = len(groups)
    bs = points.shape[0]
    num_points = points.shape[1] // 2

    groups = [x if x > 0 else num_points for x in groups]
    group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
                    for x in groups]

    last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=int)
    for group_indx, group_size in enumerate(groups):
        last_point_indx_group[:, group_indx, 1] = group_size

    for bindx in range(bs):
        for pindx in range(2 * num_points):
            point = points[bindx, pindx, :]
            group_id = int(point[2])
            if group_id < 0:
                continue

            is_negative = int(pindx >= num_points)
            if group_id >= num_groups or (group_id == 0 and is_negative):  # disable negative first click
                group_id = num_groups - 1

            new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
            last_point_indx_group[bindx, group_id, is_negative] += 1

            group_points[group_id][bindx, new_point_indx, :] = point

    group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
                    for x in group_points]

    return group_points
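A small mechanical illustration of split_points_by_order (a sketch, not part of the diff): each row is (y, x, group index), rows of -1 are padding, and per the function's logic the first half of the rows are treated as positive clicks and the second half as negative ones.

    import torch
    from isegm.model.is_model import split_points_by_order

    points = torch.tensor([[[10., 20., 0.], [30., 40., 1.], [-1., -1., -1.],
                            [50., 60., 1.], [-1., -1., -1.], [-1., -1., -1.]]])
    group_points = split_points_by_order(points, groups=(2, 1, -1))
    print([g.shape for g in group_points])   # one padded tensor per click group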
isegm/model/losses.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from isegm.utils import misc
|
7 |
+
|
8 |
+
|
9 |
+
class NormalizedFocalLossSigmoid(nn.Module):
|
10 |
+
def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12,
|
11 |
+
from_sigmoid=False, detach_delimeter=True,
|
12 |
+
batch_axis=0, weight=None, size_average=True,
|
13 |
+
ignore_label=-1):
|
14 |
+
super(NormalizedFocalLossSigmoid, self).__init__()
|
15 |
+
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._from_logits = from_sigmoid
        self._eps = eps
        self._size_average = size_average
        self._detach_delimeter = detach_delimeter
        self._max_mult = max_mult
        self._k_sum = 0
        self._m_max = 0

    def forward(self, pred, label):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
        pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))

        beta = (1 - pt) ** self._gamma

        sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
        beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult
        if self._max_mult > 0:
            beta = torch.clamp_max(beta, self._max_mult)

        with torch.no_grad():
            ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy()
            sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
            if np.any(ignore_area == 0):
                self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()

                beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
                beta_pmax = beta_pmax.mean().item()
                self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

        loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis))
            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps)
        else:
            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))

        return loss

    def log_states(self, sw, name, global_step):
        sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step)
        sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step)


class FocalLoss(nn.Module):
    def __init__(self, axis=-1, alpha=0.25, gamma=2,
                 from_logits=False, batch_axis=0,
                 weight=None, num_class=None,
                 eps=1e-9, size_average=True, scale=1.0,
                 ignore_label=-1):
        super(FocalLoss, self).__init__()
        self._axis = axis
        self._alpha = alpha
        self._gamma = gamma
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

        self._scale = scale
        self._num_class = num_class
        self._from_logits = from_logits
        self._eps = eps
        self._size_average = size_average

    def forward(self, pred, label, sample_weight=None):
        one_hot = label > 0.5
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight)
        pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred))

        beta = (1 - pt) ** self._gamma

        loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps)
        else:
            loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))

        return self._scale * loss


class SoftIoU(nn.Module):
    def __init__(self, from_sigmoid=False, ignore_label=-1):
        super().__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label

        if not self._from_sigmoid:
            pred = torch.sigmoid(pred)

        loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \
            / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8)

        return loss


class SigmoidBinaryCrossEntropyLoss(nn.Module):
    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()
        self._from_sigmoid = from_sigmoid
        self._ignore_label = ignore_label
        self._weight = weight if weight is not None else 1.0
        self._batch_axis = batch_axis

    def forward(self, pred, label):
        label = label.view(pred.size())
        sample_weight = label != self._ignore_label
        label = torch.where(sample_weight, label, torch.zeros_like(label))

        if not self._from_sigmoid:
            loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
        else:
            eps = 1e-12
            loss = -(torch.log(pred + eps) * label
                     + torch.log(1. - pred + eps) * (1. - label))

        loss = self._weight * (loss * sample_weight)
        return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
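For reference, a minimal sketch of how these losses can be exercised on dummy tensors (assumptions: the script runs from the Space root so `isegm` is importable, and the constructor keywords `alpha`/`gamma` match the attributes assigned above):

import torch
from isegm.model.losses import NormalizedFocalLossSigmoid, SigmoidBinaryCrossEntropyLoss

logits = torch.randn(2, 1, 64, 64)                  # raw model outputs (pre-sigmoid)
labels = (torch.rand(2, 1, 64, 64) > 0.5).float()   # binary instance masks
labels[0, 0, :4, :4] = -1                           # -1 marks pixels to ignore

nfl = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
bce = SigmoidBinaryCrossEntropyLoss()

print(nfl(logits, labels).shape)  # one loss value per sample: torch.Size([2])
print(bce(logits, labels).shape)  # torch.Size([2])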
isegm/model/metrics.py
ADDED
@@ -0,0 +1,101 @@
import torch
import numpy as np

from isegm.utils import misc


class TrainMetric(object):
    def __init__(self, pred_outputs, gt_outputs):
        self.pred_outputs = pred_outputs
        self.gt_outputs = gt_outputs

    def update(self, *args, **kwargs):
        raise NotImplementedError

    def get_epoch_value(self):
        raise NotImplementedError

    def reset_epoch_stats(self):
        raise NotImplementedError

    def log_states(self, sw, tag_prefix, global_step):
        pass

    @property
    def name(self):
        return type(self).__name__


class AdaptiveIoU(TrainMetric):
    def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9,
                 ignore_label=-1, from_logits=True,
                 pred_output='instances', gt_output='instances'):
        super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
        self._ignore_label = ignore_label
        self._from_logits = from_logits
        self._iou_thresh = init_thresh
        self._thresh_step = thresh_step
        self._thresh_beta = thresh_beta
        self._iou_beta = iou_beta
        self._ema_iou = 0.0
        self._epoch_iou_sum = 0.0
        self._epoch_batch_count = 0

    def update(self, pred, gt):
        gt_mask = gt > 0.5
        if self._from_logits:
            pred = torch.sigmoid(pred)

        gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy()
        if np.all(gt_mask_area == 0):
            return

        ignore_mask = gt == self._ignore_label
        max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean()
        best_thresh = self._iou_thresh
        for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]:
            temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean()
            if temp_iou > max_iou:
                max_iou = temp_iou
                best_thresh = t

        self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
        self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
        self._epoch_iou_sum += max_iou
        self._epoch_batch_count += 1

    def get_epoch_value(self):
        if self._epoch_batch_count > 0:
            return self._epoch_iou_sum / self._epoch_batch_count
        else:
            return 0.0

    def reset_epoch_stats(self):
        self._epoch_iou_sum = 0.0
        self._epoch_batch_count = 0

    def log_states(self, sw, tag_prefix, global_step):
        sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step)
        sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step)

    @property
    def iou_thresh(self):
        return self._iou_thresh


def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
    if ignore_mask is not None:
        pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)

    reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
    union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
    intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy()
    nonzero = union > 0

    iou = intersection[nonzero] / union[nonzero]
    if not keep_ignore:
        return iou
    else:
        result = np.full_like(intersection, -1)
        result[nonzero] = iou
        return result
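A short sketch of the adaptive-threshold IoU metric on dummy tensors (the B x H x W shapes below are an illustrative assumption):

import torch
from isegm.model.metrics import AdaptiveIoU

metric = AdaptiveIoU(init_thresh=0.4, thresh_step=0.025)
logits = torch.randn(4, 256, 256)              # per-pixel logits
gt = (torch.rand(4, 256, 256) > 0.5).float()   # binary ground-truth masks

metric.update(logits, gt)                      # nudges the threshold towards the better neighbour
print(metric.iou_thresh, metric.get_epoch_value())
metric.reset_epoch_stats()                     # called once per epoch by the trainer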
isegm/model/modeling/basic_blocks.py
ADDED
@@ -0,0 +1,71 @@
import torch.nn as nn

from isegm.model import ops


class ConvHead(nn.Module):
    def __init__(self, out_channels, in_channels=32, num_layers=1,
                 kernel_size=3, padding=1,
                 norm_layer=nn.BatchNorm2d):
        super(ConvHead, self).__init__()
        convhead = []

        for i in range(num_layers):
            convhead.extend([
                nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
                nn.ReLU(),
                norm_layer(in_channels) if norm_layer is not None else nn.Identity()
            ])
        convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))

        self.convhead = nn.Sequential(*convhead)

    def forward(self, *inputs):
        return self.convhead(inputs[0])


class SepConvHead(nn.Module):
    def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
                 kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
                 norm_layer=nn.BatchNorm2d):
        super(SepConvHead, self).__init__()

        sepconvhead = []

        for i in range(num_layers):
            sepconvhead.append(
                SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
                                out_channels=mid_channels,
                                dw_kernel=kernel_size, dw_padding=padding,
                                norm_layer=norm_layer, activation='relu')
            )
            if dropout_ratio > 0 and dropout_indx == i:
                sepconvhead.append(nn.Dropout(dropout_ratio))

        sepconvhead.append(
            nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
        )

        self.layers = nn.Sequential(*sepconvhead)

    def forward(self, *inputs):
        x = inputs[0]

        return self.layers(x)


class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
                 activation=None, use_bias=False, norm_layer=None):
        super(SeparableConv2d, self).__init__()
        _activation = ops.select_activation_function(activation)
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
                      padding=dw_padding, bias=use_bias, groups=in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            _activation()
        )

    def forward(self, x):
        return self.body(x)
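A quick sanity check of the depthwise-separable factorisation above (a sketch; passing activation='relu' to ops.select_activation_function is assumed to be valid, as it is elsewhere in this diff):

import torch
import torch.nn as nn
from isegm.model.modeling.basic_blocks import SeparableConv2d

sep = SeparableConv2d(64, 128, dw_kernel=3, dw_padding=1, activation='relu', norm_layer=nn.BatchNorm2d)
std = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)

n_sep = sum(p.numel() for p in sep.body[:2].parameters())  # depthwise 3x3 + pointwise 1x1
n_std = sum(p.numel() for p in std.parameters())
print(n_sep, n_std)  # 64*9 + 64*128 = 8768 vs. 64*128*9 = 73728 weights

print(sep(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 32, 32])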
isegm/model/modeling/deeplab_v3.py
ADDED
@@ -0,0 +1,176 @@
from contextlib import ExitStack

import torch
from torch import nn
import torch.nn.functional as F

from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone
from isegm.model import ops


class DeepLabV3Plus(nn.Module):
    def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
                 backbone_norm_layer=None,
                 ch=256,
                 project_dropout=0.5,
                 inference_mode=False,
                 **kwargs):
        super(DeepLabV3Plus, self).__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

        self.backbone_name = backbone
        self.norm_layer = norm_layer
        self.backbone_norm_layer = backbone_norm_layer
        self.inference_mode = False
        self.ch = ch
        self.aspp_in_channels = 2048
        self.skip_project_in_channels = 256  # layer 1 out_channels

        self._kwargs = kwargs
        if backbone == 'resnet34':
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64

        self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
                                       norm_layer=self.backbone_norm_layer, **kwargs)

        self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
                                 norm_layer=self.norm_layer)
        self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
        self.aspp = _ASPP(in_channels=self.aspp_in_channels,
                          atrous_rates=[12, 24, 36],
                          out_channels=ch,
                          project_dropout=project_dropout,
                          norm_layer=self.norm_layer)

        if inference_mode:
            self.set_prediction_mode()

    def load_pretrained_weights(self):
        pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
                                    norm_layer=self.backbone_norm_layer, **self._kwargs)
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

        backbone_state_dict.update(pretrained_state_dict)
        self.backbone.load_state_dict(backbone_state_dict)

        if self.inference_mode:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def set_prediction_mode(self):
        self.inference_mode = True
        self.eval()

    def forward(self, x, additional_features=None):
        with ExitStack() as stack:
            if self.inference_mode:
                stack.enter_context(torch.no_grad())

            c1, _, c3, c4 = self.backbone(x, additional_features)
            c1 = self.skip_project(c1)

            x = self.aspp(c4)
            x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
            x = torch.cat((x, c1), dim=1)
            x = self.head(x)

        return x,


class _SkipProject(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super(_SkipProject, self).__init__()
        _activation = ops.select_activation_function("relu")

        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation()
        )

    def forward(self, x):
        return self.skip_project(x)


class _DeepLabHead(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
        super(_DeepLabHead, self).__init__()

        self.block = nn.Sequential(
            SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
                            dw_padding=1, activation='relu', norm_layer=norm_layer),
            SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
                            dw_padding=1, activation='relu', norm_layer=norm_layer),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
        )

    def forward(self, x):
        return self.block(x)


class _ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256,
                 project_dropout=0.5, norm_layer=nn.BatchNorm2d):
        super(_ASPP, self).__init__()

        b0 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
        b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
        b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
        b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])

        project = [
            nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
                      kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))
        self.project = nn.Sequential(*project)

    def forward(self, x):
        x = torch.cat([block(x) for block in self.concurent], dim=1)

        return self.project(x)


class _AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(_AsppPooling, self).__init__()

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)


def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
    block = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, padding=atrous_rate,
                  dilation=atrous_rate, bias=False),
        norm_layer(out_channels),
        nn.ReLU()
    )

    return block
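A resolution sketch for this decoder (assumptions: random weights, since the backbone is built with pretrained_base=False inside the constructor, and an illustrative 256x256 input):

import torch
from isegm.model.modeling.deeplab_v3 import DeepLabV3Plus

model = DeepLabV3Plus(backbone='resnet50', ch=128).eval()
with torch.no_grad():
    out, = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 128, 64, 64]): ASPP at stride 8, upsampled and fused with the stride-4 skip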
isegm/model/modeling/hrnet_ocr.py
ADDED
@@ -0,0 +1,416 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch._utils
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from .ocr import SpatialOCR_Module, SpatialGather_Module
|
8 |
+
from .resnetv1b import BasicBlockV1b, BottleneckV1b
|
9 |
+
|
10 |
+
relu_inplace = True
|
11 |
+
|
12 |
+
|
13 |
+
class HighResolutionModule(nn.Module):
|
14 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
15 |
+
num_channels, fuse_method,multi_scale_output=True,
|
16 |
+
norm_layer=nn.BatchNorm2d, align_corners=True):
|
17 |
+
super(HighResolutionModule, self).__init__()
|
18 |
+
self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
|
19 |
+
|
20 |
+
self.num_inchannels = num_inchannels
|
21 |
+
self.fuse_method = fuse_method
|
22 |
+
self.num_branches = num_branches
|
23 |
+
self.norm_layer = norm_layer
|
24 |
+
self.align_corners = align_corners
|
25 |
+
|
26 |
+
self.multi_scale_output = multi_scale_output
|
27 |
+
|
28 |
+
self.branches = self._make_branches(
|
29 |
+
num_branches, blocks, num_blocks, num_channels)
|
30 |
+
self.fuse_layers = self._make_fuse_layers()
|
31 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
32 |
+
|
33 |
+
def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
|
34 |
+
if num_branches != len(num_blocks):
|
35 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
36 |
+
num_branches, len(num_blocks))
|
37 |
+
raise ValueError(error_msg)
|
38 |
+
|
39 |
+
if num_branches != len(num_channels):
|
40 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
41 |
+
num_branches, len(num_channels))
|
42 |
+
raise ValueError(error_msg)
|
43 |
+
|
44 |
+
if num_branches != len(num_inchannels):
|
45 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
46 |
+
num_branches, len(num_inchannels))
|
47 |
+
raise ValueError(error_msg)
|
48 |
+
|
49 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
50 |
+
stride=1):
|
51 |
+
downsample = None
|
52 |
+
if stride != 1 or \
|
53 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
54 |
+
downsample = nn.Sequential(
|
55 |
+
nn.Conv2d(self.num_inchannels[branch_index],
|
56 |
+
num_channels[branch_index] * block.expansion,
|
57 |
+
kernel_size=1, stride=stride, bias=False),
|
58 |
+
self.norm_layer(num_channels[branch_index] * block.expansion),
|
59 |
+
)
|
60 |
+
|
61 |
+
layers = []
|
62 |
+
layers.append(block(self.num_inchannels[branch_index],
|
63 |
+
num_channels[branch_index], stride,
|
64 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
65 |
+
self.num_inchannels[branch_index] = \
|
66 |
+
num_channels[branch_index] * block.expansion
|
67 |
+
for i in range(1, num_blocks[branch_index]):
|
68 |
+
layers.append(block(self.num_inchannels[branch_index],
|
69 |
+
num_channels[branch_index],
|
70 |
+
norm_layer=self.norm_layer))
|
71 |
+
|
72 |
+
return nn.Sequential(*layers)
|
73 |
+
|
74 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
75 |
+
branches = []
|
76 |
+
|
77 |
+
for i in range(num_branches):
|
78 |
+
branches.append(
|
79 |
+
self._make_one_branch(i, block, num_blocks, num_channels))
|
80 |
+
|
81 |
+
return nn.ModuleList(branches)
|
82 |
+
|
83 |
+
def _make_fuse_layers(self):
|
84 |
+
if self.num_branches == 1:
|
85 |
+
return None
|
86 |
+
|
87 |
+
num_branches = self.num_branches
|
88 |
+
num_inchannels = self.num_inchannels
|
89 |
+
fuse_layers = []
|
90 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
91 |
+
fuse_layer = []
|
92 |
+
for j in range(num_branches):
|
93 |
+
if j > i:
|
94 |
+
fuse_layer.append(nn.Sequential(
|
95 |
+
nn.Conv2d(in_channels=num_inchannels[j],
|
96 |
+
out_channels=num_inchannels[i],
|
97 |
+
kernel_size=1,
|
98 |
+
bias=False),
|
99 |
+
self.norm_layer(num_inchannels[i])))
|
100 |
+
elif j == i:
|
101 |
+
fuse_layer.append(None)
|
102 |
+
else:
|
103 |
+
conv3x3s = []
|
104 |
+
for k in range(i - j):
|
105 |
+
if k == i - j - 1:
|
106 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
107 |
+
conv3x3s.append(nn.Sequential(
|
108 |
+
nn.Conv2d(num_inchannels[j],
|
109 |
+
num_outchannels_conv3x3,
|
110 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
111 |
+
self.norm_layer(num_outchannels_conv3x3)))
|
112 |
+
else:
|
113 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
114 |
+
conv3x3s.append(nn.Sequential(
|
115 |
+
nn.Conv2d(num_inchannels[j],
|
116 |
+
num_outchannels_conv3x3,
|
117 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
118 |
+
self.norm_layer(num_outchannels_conv3x3),
|
119 |
+
nn.ReLU(inplace=relu_inplace)))
|
120 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
121 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
122 |
+
|
123 |
+
return nn.ModuleList(fuse_layers)
|
124 |
+
|
125 |
+
def get_num_inchannels(self):
|
126 |
+
return self.num_inchannels
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
if self.num_branches == 1:
|
130 |
+
return [self.branches[0](x[0])]
|
131 |
+
|
132 |
+
for i in range(self.num_branches):
|
133 |
+
x[i] = self.branches[i](x[i])
|
134 |
+
|
135 |
+
x_fuse = []
|
136 |
+
for i in range(len(self.fuse_layers)):
|
137 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
138 |
+
for j in range(1, self.num_branches):
|
139 |
+
if i == j:
|
140 |
+
y = y + x[j]
|
141 |
+
elif j > i:
|
142 |
+
width_output = x[i].shape[-1]
|
143 |
+
height_output = x[i].shape[-2]
|
144 |
+
y = y + F.interpolate(
|
145 |
+
self.fuse_layers[i][j](x[j]),
|
146 |
+
size=[height_output, width_output],
|
147 |
+
mode='bilinear', align_corners=self.align_corners)
|
148 |
+
else:
|
149 |
+
y = y + self.fuse_layers[i][j](x[j])
|
150 |
+
x_fuse.append(self.relu(y))
|
151 |
+
|
152 |
+
return x_fuse
|
153 |
+
|
154 |
+
|
155 |
+
class HighResolutionNet(nn.Module):
|
156 |
+
def __init__(self, width, num_classes, ocr_width=256, small=False,
|
157 |
+
norm_layer=nn.BatchNorm2d, align_corners=True):
|
158 |
+
super(HighResolutionNet, self).__init__()
|
159 |
+
self.norm_layer = norm_layer
|
160 |
+
self.width = width
|
161 |
+
self.ocr_width = ocr_width
|
162 |
+
self.align_corners = align_corners
|
163 |
+
|
164 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
165 |
+
self.bn1 = norm_layer(64)
|
166 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
167 |
+
self.bn2 = norm_layer(64)
|
168 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
169 |
+
|
170 |
+
num_blocks = 2 if small else 4
|
171 |
+
|
172 |
+
stage1_num_channels = 64
|
173 |
+
self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
|
174 |
+
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
|
175 |
+
|
176 |
+
self.stage2_num_branches = 2
|
177 |
+
num_channels = [width, 2 * width]
|
178 |
+
num_inchannels = [
|
179 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
180 |
+
self.transition1 = self._make_transition_layer(
|
181 |
+
[stage1_out_channel], num_inchannels)
|
182 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
183 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
|
184 |
+
num_blocks=2 * [num_blocks], num_channels=num_channels)
|
185 |
+
|
186 |
+
self.stage3_num_branches = 3
|
187 |
+
num_channels = [width, 2 * width, 4 * width]
|
188 |
+
num_inchannels = [
|
189 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
190 |
+
self.transition2 = self._make_transition_layer(
|
191 |
+
pre_stage_channels, num_inchannels)
|
192 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
193 |
+
BasicBlockV1b, num_inchannels=num_inchannels,
|
194 |
+
num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
|
195 |
+
num_blocks=3 * [num_blocks], num_channels=num_channels)
|
196 |
+
|
197 |
+
self.stage4_num_branches = 4
|
198 |
+
num_channels = [width, 2 * width, 4 * width, 8 * width]
|
199 |
+
num_inchannels = [
|
200 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
201 |
+
self.transition3 = self._make_transition_layer(
|
202 |
+
pre_stage_channels, num_inchannels)
|
203 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
204 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
|
205 |
+
num_branches=self.stage4_num_branches,
|
206 |
+
num_blocks=4 * [num_blocks], num_channels=num_channels)
|
207 |
+
|
208 |
+
last_inp_channels = np.int(np.sum(pre_stage_channels))
|
209 |
+
if self.ocr_width > 0:
|
210 |
+
ocr_mid_channels = 2 * self.ocr_width
|
211 |
+
ocr_key_channels = self.ocr_width
|
212 |
+
|
213 |
+
self.conv3x3_ocr = nn.Sequential(
|
214 |
+
nn.Conv2d(last_inp_channels, ocr_mid_channels,
|
215 |
+
kernel_size=3, stride=1, padding=1),
|
216 |
+
norm_layer(ocr_mid_channels),
|
217 |
+
nn.ReLU(inplace=relu_inplace),
|
218 |
+
)
|
219 |
+
self.ocr_gather_head = SpatialGather_Module(num_classes)
|
220 |
+
|
221 |
+
self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
|
222 |
+
key_channels=ocr_key_channels,
|
223 |
+
out_channels=ocr_mid_channels,
|
224 |
+
scale=1,
|
225 |
+
dropout=0.05,
|
226 |
+
norm_layer=norm_layer,
|
227 |
+
align_corners=align_corners)
|
228 |
+
self.cls_head = nn.Conv2d(
|
229 |
+
ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
|
230 |
+
|
231 |
+
self.aux_head = nn.Sequential(
|
232 |
+
nn.Conv2d(last_inp_channels, last_inp_channels,
|
233 |
+
kernel_size=1, stride=1, padding=0),
|
234 |
+
norm_layer(last_inp_channels),
|
235 |
+
nn.ReLU(inplace=relu_inplace),
|
236 |
+
nn.Conv2d(last_inp_channels, num_classes,
|
237 |
+
kernel_size=1, stride=1, padding=0, bias=True)
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
self.cls_head = nn.Sequential(
|
241 |
+
nn.Conv2d(last_inp_channels, last_inp_channels,
|
242 |
+
kernel_size=3, stride=1, padding=1),
|
243 |
+
norm_layer(last_inp_channels),
|
244 |
+
nn.ReLU(inplace=relu_inplace),
|
245 |
+
nn.Conv2d(last_inp_channels, num_classes,
|
246 |
+
kernel_size=1, stride=1, padding=0, bias=True)
|
247 |
+
)
|
248 |
+
|
249 |
+
def _make_transition_layer(
|
250 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
251 |
+
num_branches_cur = len(num_channels_cur_layer)
|
252 |
+
num_branches_pre = len(num_channels_pre_layer)
|
253 |
+
|
254 |
+
transition_layers = []
|
255 |
+
for i in range(num_branches_cur):
|
256 |
+
if i < num_branches_pre:
|
257 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
258 |
+
transition_layers.append(nn.Sequential(
|
259 |
+
nn.Conv2d(num_channels_pre_layer[i],
|
260 |
+
num_channels_cur_layer[i],
|
261 |
+
kernel_size=3,
|
262 |
+
stride=1,
|
263 |
+
padding=1,
|
264 |
+
bias=False),
|
265 |
+
self.norm_layer(num_channels_cur_layer[i]),
|
266 |
+
nn.ReLU(inplace=relu_inplace)))
|
267 |
+
else:
|
268 |
+
transition_layers.append(None)
|
269 |
+
else:
|
270 |
+
conv3x3s = []
|
271 |
+
for j in range(i + 1 - num_branches_pre):
|
272 |
+
inchannels = num_channels_pre_layer[-1]
|
273 |
+
outchannels = num_channels_cur_layer[i] \
|
274 |
+
if j == i - num_branches_pre else inchannels
|
275 |
+
conv3x3s.append(nn.Sequential(
|
276 |
+
nn.Conv2d(inchannels, outchannels,
|
277 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
278 |
+
self.norm_layer(outchannels),
|
279 |
+
nn.ReLU(inplace=relu_inplace)))
|
280 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
281 |
+
|
282 |
+
return nn.ModuleList(transition_layers)
|
283 |
+
|
284 |
+
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
285 |
+
downsample = None
|
286 |
+
if stride != 1 or inplanes != planes * block.expansion:
|
287 |
+
downsample = nn.Sequential(
|
288 |
+
nn.Conv2d(inplanes, planes * block.expansion,
|
289 |
+
kernel_size=1, stride=stride, bias=False),
|
290 |
+
self.norm_layer(planes * block.expansion),
|
291 |
+
)
|
292 |
+
|
293 |
+
layers = []
|
294 |
+
layers.append(block(inplanes, planes, stride,
|
295 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
296 |
+
inplanes = planes * block.expansion
|
297 |
+
for i in range(1, blocks):
|
298 |
+
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
|
299 |
+
|
300 |
+
return nn.Sequential(*layers)
|
301 |
+
|
302 |
+
def _make_stage(self, block, num_inchannels,
|
303 |
+
num_modules, num_branches, num_blocks, num_channels,
|
304 |
+
fuse_method='SUM',
|
305 |
+
multi_scale_output=True):
|
306 |
+
modules = []
|
307 |
+
for i in range(num_modules):
|
308 |
+
# multi_scale_output is only used last module
|
309 |
+
if not multi_scale_output and i == num_modules - 1:
|
310 |
+
reset_multi_scale_output = False
|
311 |
+
else:
|
312 |
+
reset_multi_scale_output = True
|
313 |
+
modules.append(
|
314 |
+
HighResolutionModule(num_branches,
|
315 |
+
block,
|
316 |
+
num_blocks,
|
317 |
+
num_inchannels,
|
318 |
+
num_channels,
|
319 |
+
fuse_method,
|
320 |
+
reset_multi_scale_output,
|
321 |
+
norm_layer=self.norm_layer,
|
322 |
+
align_corners=self.align_corners)
|
323 |
+
)
|
324 |
+
num_inchannels = modules[-1].get_num_inchannels()
|
325 |
+
|
326 |
+
return nn.Sequential(*modules), num_inchannels
|
327 |
+
|
328 |
+
def forward(self, x, additional_features=None):
|
329 |
+
feats = self.compute_hrnet_feats(x, additional_features)
|
330 |
+
if self.ocr_width > 0:
|
331 |
+
out_aux = self.aux_head(feats)
|
332 |
+
feats = self.conv3x3_ocr(feats)
|
333 |
+
|
334 |
+
context = self.ocr_gather_head(feats, out_aux)
|
335 |
+
feats = self.ocr_distri_head(feats, context)
|
336 |
+
out = self.cls_head(feats)
|
337 |
+
return [out, out_aux]
|
338 |
+
else:
|
339 |
+
return [self.cls_head(feats), None]
|
340 |
+
|
341 |
+
def compute_hrnet_feats(self, x, additional_features):
|
342 |
+
x = self.compute_pre_stage_features(x, additional_features)
|
343 |
+
x = self.layer1(x)
|
344 |
+
|
345 |
+
x_list = []
|
346 |
+
for i in range(self.stage2_num_branches):
|
347 |
+
if self.transition1[i] is not None:
|
348 |
+
x_list.append(self.transition1[i](x))
|
349 |
+
else:
|
350 |
+
x_list.append(x)
|
351 |
+
y_list = self.stage2(x_list)
|
352 |
+
|
353 |
+
x_list = []
|
354 |
+
for i in range(self.stage3_num_branches):
|
355 |
+
if self.transition2[i] is not None:
|
356 |
+
if i < self.stage2_num_branches:
|
357 |
+
x_list.append(self.transition2[i](y_list[i]))
|
358 |
+
else:
|
359 |
+
x_list.append(self.transition2[i](y_list[-1]))
|
360 |
+
else:
|
361 |
+
x_list.append(y_list[i])
|
362 |
+
y_list = self.stage3(x_list)
|
363 |
+
|
364 |
+
x_list = []
|
365 |
+
for i in range(self.stage4_num_branches):
|
366 |
+
if self.transition3[i] is not None:
|
367 |
+
if i < self.stage3_num_branches:
|
368 |
+
x_list.append(self.transition3[i](y_list[i]))
|
369 |
+
else:
|
370 |
+
x_list.append(self.transition3[i](y_list[-1]))
|
371 |
+
else:
|
372 |
+
x_list.append(y_list[i])
|
373 |
+
x = self.stage4(x_list)
|
374 |
+
|
375 |
+
return self.aggregate_hrnet_features(x)
|
376 |
+
|
377 |
+
def compute_pre_stage_features(self, x, additional_features):
|
378 |
+
x = self.conv1(x)
|
379 |
+
x = self.bn1(x)
|
380 |
+
x = self.relu(x)
|
381 |
+
if additional_features is not None:
|
382 |
+
x = x + additional_features
|
383 |
+
x = self.conv2(x)
|
384 |
+
x = self.bn2(x)
|
385 |
+
return self.relu(x)
|
386 |
+
|
387 |
+
def aggregate_hrnet_features(self, x):
|
388 |
+
# Upsampling
|
389 |
+
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
390 |
+
x1 = F.interpolate(x[1], size=(x0_h, x0_w),
|
391 |
+
mode='bilinear', align_corners=self.align_corners)
|
392 |
+
x2 = F.interpolate(x[2], size=(x0_h, x0_w),
|
393 |
+
mode='bilinear', align_corners=self.align_corners)
|
394 |
+
x3 = F.interpolate(x[3], size=(x0_h, x0_w),
|
395 |
+
mode='bilinear', align_corners=self.align_corners)
|
396 |
+
|
397 |
+
return torch.cat([x[0], x1, x2, x3], 1)
|
398 |
+
|
399 |
+
def load_pretrained_weights(self, pretrained_path=''):
|
400 |
+
model_dict = self.state_dict()
|
401 |
+
|
402 |
+
if not os.path.exists(pretrained_path):
|
403 |
+
print(f'\nFile "{pretrained_path}" does not exist.')
|
404 |
+
print('You need to specify the correct path to the pre-trained weights.\n'
|
405 |
+
'You can download the weights for HRNet from the repository:\n'
|
406 |
+
'https://github.com/HRNet/HRNet-Image-Classification')
|
407 |
+
exit(1)
|
408 |
+
pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
|
409 |
+
pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
|
410 |
+
pretrained_dict.items()}
|
411 |
+
|
412 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items()
|
413 |
+
if k in model_dict.keys()}
|
414 |
+
|
415 |
+
model_dict.update(pretrained_dict)
|
416 |
+
self.load_state_dict(model_dict)
|
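A sketch of the two HRNet+OCR outputs (the width/num_classes values below are illustrative; note the constructor above calls np.int, so a NumPy version that still provides it is assumed):

import torch
from isegm.model.modeling.hrnet_ocr import HighResolutionNet

net = HighResolutionNet(width=18, num_classes=1, ocr_width=64, small=True).eval()
with torch.no_grad():
    out, out_aux = net(torch.randn(1, 3, 256, 256))
print(out.shape, out_aux.shape)  # both at 1/4 resolution: torch.Size([1, 1, 64, 64])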
isegm/model/modeling/ocr.py
ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
        probs = probs.view(batch_size, c, -1)
        feats = feats.view(batch_size, feats.size(1), -1)
        feats = feats.permute(0, 2, 1)  # batch x hw x c
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = torch.matmul(probs, feats) \
            .permute(0, 2, 1).unsqueeze(3)  # batch x k x c
        return ocr_context


class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(SpatialOCR_Module, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
                                                           norm_layer, align_corners)
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
            nn.Dropout2d(dropout)
        )

    def forward(self, feats, proxy_feats):
        context = self.object_context_block(feats, proxy_feats)

        output = self.conv_bn_dropout(torch.cat([context, feats], 1))

        return output


class ObjectAttentionBlock2D(nn.Module):
    '''
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature maps (save memory cost)
        bn_type : specify the bn type
    Return:
        N X C X H X W
    '''

    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.align_corners = align_corners

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
        )

    def forward(self, x, proxy):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w),
                                    mode='bilinear', align_corners=self.align_corners)

        return context
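A shape sketch for the soft class-context gathering and the OCR distribution step (dummy tensors; a single-class probability map, as used by the interactive model, is an assumption):

import torch
from isegm.model.modeling.ocr import SpatialGather_Module, SpatialOCR_Module

feats = torch.randn(2, 128, 32, 32)   # pixel features: B x C x H x W
probs = torch.randn(2, 1, 32, 32)     # coarse per-class logits: B x K x H x W

gather = SpatialGather_Module(cls_num=1)
context = gather(feats, probs)        # soft-weighted mean feature per class
print(context.shape)                  # torch.Size([2, 128, 1, 1]): B x C x K x 1

ocr = SpatialOCR_Module(in_channels=128, key_channels=64, out_channels=128).eval()
print(ocr(feats, context).shape)      # torch.Size([2, 128, 32, 32])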
isegm/model/modeling/resnet.py
ADDED
@@ -0,0 +1,43 @@
import torch
from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s


class ResNetBackbone(torch.nn.Module):
    def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
        super(ResNetBackbone, self).__init__()

        if backbone == 'resnet34':
            pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
        elif backbone == 'resnet50':
            pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
        elif backbone == 'resnet101':
            pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
        elif backbone == 'resnet152':
            pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
        else:
            raise RuntimeError(f'unknown backbone: {backbone}')

        self.conv1 = pretrained.conv1
        self.bn1 = pretrained.bn1
        self.relu = pretrained.relu
        self.maxpool = pretrained.maxpool
        self.layer1 = pretrained.layer1
        self.layer2 = pretrained.layer2
        self.layer3 = pretrained.layer3
        self.layer4 = pretrained.layer4

    def forward(self, x, additional_features=None):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if additional_features is not None:
            x = x + torch.nn.functional.pad(additional_features,
                                            [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
                                            mode='constant', value=0)
        x = self.maxpool(x)
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        return c1, c2, c3, c4
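A feature-pyramid sketch for this backbone (pretrained_base=False is used here only to avoid the torch.hub download; shapes are for an illustrative 256x256 input):

import torch
from isegm.model.modeling.resnet import ResNetBackbone

backbone = ResNetBackbone(backbone='resnet50', pretrained_base=False, dilated=True).eval()
with torch.no_grad():
    c1, c2, c3, c4 = backbone(torch.randn(1, 3, 256, 256))
for c in (c1, c2, c3, c4):
    print(c.shape)  # with dilation the strides are 4, 8, 8, 8 and channels 256, 512, 1024, 2048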
isegm/model/modeling/resnetv1b.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
|
4 |
+
|
5 |
+
|
6 |
+
class BasicBlockV1b(nn.Module):
|
7 |
+
expansion = 1
|
8 |
+
|
9 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
10 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
11 |
+
super(BasicBlockV1b, self).__init__()
|
12 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
13 |
+
padding=dilation, dilation=dilation, bias=False)
|
14 |
+
self.bn1 = norm_layer(planes)
|
15 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
|
16 |
+
padding=previous_dilation, dilation=previous_dilation, bias=False)
|
17 |
+
self.bn2 = norm_layer(planes)
|
18 |
+
|
19 |
+
self.relu = nn.ReLU(inplace=True)
|
20 |
+
self.downsample = downsample
|
21 |
+
self.stride = stride
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
residual = x
|
25 |
+
|
26 |
+
out = self.conv1(x)
|
27 |
+
out = self.bn1(out)
|
28 |
+
out = self.relu(out)
|
29 |
+
|
30 |
+
out = self.conv2(out)
|
31 |
+
out = self.bn2(out)
|
32 |
+
|
33 |
+
if self.downsample is not None:
|
34 |
+
residual = self.downsample(x)
|
35 |
+
|
36 |
+
out = out + residual
|
37 |
+
out = self.relu(out)
|
38 |
+
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
class BottleneckV1b(nn.Module):
|
43 |
+
expansion = 4
|
44 |
+
|
45 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
46 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
47 |
+
super(BottleneckV1b, self).__init__()
|
48 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
49 |
+
self.bn1 = norm_layer(planes)
|
50 |
+
|
51 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
52 |
+
padding=dilation, dilation=dilation, bias=False)
|
53 |
+
self.bn2 = norm_layer(planes)
|
54 |
+
|
55 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
56 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
57 |
+
|
58 |
+
self.relu = nn.ReLU(inplace=True)
|
59 |
+
self.downsample = downsample
|
60 |
+
self.stride = stride
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
residual = x
|
64 |
+
|
65 |
+
out = self.conv1(x)
|
66 |
+
out = self.bn1(out)
|
67 |
+
out = self.relu(out)
|
68 |
+
|
69 |
+
out = self.conv2(out)
|
70 |
+
out = self.bn2(out)
|
71 |
+
out = self.relu(out)
|
72 |
+
|
73 |
+
out = self.conv3(out)
|
74 |
+
out = self.bn3(out)
|
75 |
+
|
76 |
+
if self.downsample is not None:
|
77 |
+
residual = self.downsample(x)
|
78 |
+
|
79 |
+
out = out + residual
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
class ResNetV1b(nn.Module):
|
86 |
+
""" Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
|
87 |
+
|
88 |
+
Parameters
|
89 |
+
----------
|
90 |
+
block : Block
|
91 |
+
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
|
92 |
+
layers : list of int
|
93 |
+
Numbers of layers in each block
|
94 |
+
classes : int, default 1000
|
95 |
+
Number of classification classes.
|
96 |
+
dilated : bool, default False
|
97 |
+
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
|
98 |
+
typically used in Semantic Segmentation.
|
99 |
+
norm_layer : object
|
100 |
+
Normalization layer used (default: :class:`nn.BatchNorm2d`)
|
101 |
+
deep_stem : bool, default False
|
102 |
+
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
|
103 |
+
avg_down : bool, default False
|
104 |
+
Whether to use average pooling for projection skip connection between stages/downsample.
|
105 |
+
final_drop : float, default 0.0
|
106 |
+
Dropout ratio before the final classification layer.
|
107 |
+
|
108 |
+
Reference:
|
109 |
+
- He, Kaiming, et al. "Deep residual learning for image recognition."
|
110 |
+
Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
|
111 |
+
|
112 |
+
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
|
113 |
+
"""
|
114 |
+
def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
|
115 |
+
avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
|
116 |
+
self.inplanes = stem_width*2 if deep_stem else 64
|
117 |
+
super(ResNetV1b, self).__init__()
|
118 |
+
if not deep_stem:
|
119 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
120 |
+
else:
|
121 |
+
self.conv1 = nn.Sequential(
|
122 |
+
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
|
123 |
+
norm_layer(stem_width),
|
124 |
+
nn.ReLU(True),
|
125 |
+
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
|
126 |
+
norm_layer(stem_width),
|
127 |
+
nn.ReLU(True),
|
128 |
+
nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
|
129 |
+
)
|
130 |
+
self.bn1 = norm_layer(self.inplanes)
|
131 |
+
self.relu = nn.ReLU(True)
|
132 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
133 |
+
self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
|
134 |
+
norm_layer=norm_layer)
|
135 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
|
136 |
+
norm_layer=norm_layer)
|
137 |
+
if dilated:
|
138 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
|
139 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
140 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
|
141 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
142 |
+
else:
|
143 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
144 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
145 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
146 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
147 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
148 |
+
self.drop = None
|
149 |
+
if final_drop > 0.0:
|
150 |
+
self.drop = nn.Dropout(final_drop)
|
151 |
+
self.fc = nn.Linear(512 * block.expansion, classes)
|
152 |
+
|
153 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
|
154 |
+
avg_down=False, norm_layer=nn.BatchNorm2d):
|
155 |
+
downsample = None
|
156 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
157 |
+
downsample = []
|
158 |
+
if avg_down:
|
159 |
+
if dilation == 1:
|
160 |
+
downsample.append(
|
161 |
+
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
downsample.append(
|
165 |
+
nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
|
166 |
+
)
|
167 |
+
downsample.extend([
|
168 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
169 |
+
kernel_size=1, stride=1, bias=False),
|
170 |
+
norm_layer(planes * block.expansion)
|
171 |
+
])
|
172 |
+
downsample = nn.Sequential(*downsample)
|
173 |
+
else:
|
174 |
+
downsample = nn.Sequential(
|
175 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
176 |
+
kernel_size=1, stride=stride, bias=False),
|
177 |
+
norm_layer(planes * block.expansion)
|
178 |
+
)
|
179 |
+
|
180 |
+
layers = []
|
181 |
+
if dilation in (1, 2):
|
182 |
+
layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
|
183 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
184 |
+
elif dilation == 4:
|
185 |
+
layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
|
186 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
187 |
+
else:
|
188 |
+
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
|
189 |
+
|
190 |
+
self.inplanes = planes * block.expansion
|
191 |
+
for _ in range(1, blocks):
|
192 |
+
layers.append(block(self.inplanes, planes, dilation=dilation,
|
193 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
194 |
+
|
195 |
+
return nn.Sequential(*layers)
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
x = self.conv1(x)
|
199 |
+
x = self.bn1(x)
|
200 |
+
x = self.relu(x)
|
201 |
+
x = self.maxpool(x)
|
202 |
+
|
203 |
+
x = self.layer1(x)
|
204 |
+
x = self.layer2(x)
|
205 |
+
x = self.layer3(x)
|
206 |
+
x = self.layer4(x)
|
207 |
+
|
208 |
+
x = self.avgpool(x)
|
209 |
+
x = x.view(x.size(0), -1)
|
210 |
+
if self.drop is not None:
|
211 |
+
x = self.drop(x)
|
212 |
+
x = self.fc(x)
|
213 |
+
|
214 |
+
return x
|
215 |
+
|
216 |
+
|
217 |
+
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
|
218 |
+
filtered_orig_dict = {}
|
219 |
+
for k, v in orig_dict.items():
|
220 |
+
if k in model_dict_keys:
|
221 |
+
filtered_orig_dict[k] = v
|
222 |
+
else:
|
223 |
+
print(f"[ERROR] Failed to load <{k}> in backbone")
|
224 |
+
return filtered_orig_dict
|
225 |
+
|
226 |
+
|
227 |
+
def resnet34_v1b(pretrained=False, **kwargs):
|
228 |
+
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
|
229 |
+
if pretrained:
|
230 |
+
model_dict = model.state_dict()
|
231 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
232 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
|
233 |
+
model_dict.keys()
|
234 |
+
)
|
235 |
+
model_dict.update(filtered_orig_dict)
|
236 |
+
model.load_state_dict(model_dict)
|
237 |
+
return model
|
238 |
+
|
239 |
+
|
240 |
+
def resnet50_v1s(pretrained=False, **kwargs):
|
241 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
|
242 |
+
if pretrained:
|
243 |
+
model_dict = model.state_dict()
|
244 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
245 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
|
246 |
+
model_dict.keys()
|
247 |
+
)
|
248 |
+
model_dict.update(filtered_orig_dict)
|
249 |
+
model.load_state_dict(model_dict)
|
250 |
+
return model
|
251 |
+
|
252 |
+
|
253 |
+
def resnet101_v1s(pretrained=False, **kwargs):
|
254 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
|
255 |
+
if pretrained:
|
256 |
+
model_dict = model.state_dict()
|
257 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
258 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
|
259 |
+
model_dict.keys()
|
260 |
+
)
|
261 |
+
model_dict.update(filtered_orig_dict)
|
262 |
+
model.load_state_dict(model_dict)
|
263 |
+
return model
|
264 |
+
|
265 |
+
|
266 |
+
def resnet152_v1s(pretrained=False, **kwargs):
|
267 |
+
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
|
268 |
+
if pretrained:
|
269 |
+
model_dict = model.state_dict()
|
270 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
271 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
|
272 |
+
model_dict.keys()
|
273 |
+
)
|
274 |
+
model_dict.update(filtered_orig_dict)
|
275 |
+
model.load_state_dict(model_dict)
|
276 |
+
return model
|
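A sketch of the classification-style forward of ResNetV1b (pretrained=False avoids the torch.hub download; with dilated=True the last two stages keep stride 8 before global pooling):

import torch
from isegm.model.modeling.resnetv1b import resnet34_v1b

net = resnet34_v1b(pretrained=False, dilated=True).eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])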
isegm/model/modifiers.py
ADDED
@@ -0,0 +1,11 @@
class LRMult(object):
    def __init__(self, lr_mult=1.):
        self.lr_mult = lr_mult

    def __call__(self, m):
        if getattr(m, 'weight', None) is not None:
            m.weight.lr_mult = self.lr_mult
        if getattr(m, 'bias', None) is not None:
            m.bias.lr_mult = self.lr_mult
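A sketch of one way the lr_mult attributes written by LRMult can be consumed when building an optimizer; the parameter-group logic below is illustrative, not a fixed API of this repo:

import torch
from isegm.model.modifiers import LRMult

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 1, 1))
model[0].apply(LRMult(0.1))  # e.g. give a pretrained backbone a 10x smaller learning rate

base_lr = 5e-4
param_groups = [
    {'params': [p], 'lr': base_lr * getattr(p, 'lr_mult', 1.0)}
    for p in model.parameters()
]
optimizer = torch.optim.Adam(param_groups)
print([g['lr'] for g in optimizer.param_groups])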