Refactor code
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- app.py +90 -57
- isegm/data/base.py +32 -25
- isegm/data/compose.py +13 -7
- isegm/data/datasets/__init__.py +5 -4
- isegm/data/datasets/ade20k.py +13 -11
- isegm/data/datasets/berkeley.py +3 -1
- isegm/data/datasets/coco.py +21 -17
- isegm/data/datasets/coco_lvis.py +35 -23
- isegm/data/datasets/davis.py +6 -6
- isegm/data/datasets/grabcut.py +13 -7
- isegm/data/datasets/images_dir.py +22 -16
- isegm/data/datasets/lvis.py +27 -24
- isegm/data/datasets/openimages.py +22 -13
- isegm/data/datasets/pascalvoc.py +22 -10
- isegm/data/datasets/sbd.py +31 -23
- isegm/data/points_sampler.py +120 -57
- isegm/data/sample.py +50 -27
- isegm/data/transforms.py +73 -44
- isegm/engine/optimizer.py +10 -8
- isegm/engine/trainer.py +259 -122
- isegm/inference/clicker.py +15 -6
- isegm/inference/evaluation.py +22 -6
- isegm/inference/predictors/__init__.py +78 -56
- isegm/inference/predictors/base.py +49 -24
- isegm/inference/predictors/brs.py +157 -66
- isegm/inference/predictors/brs_functors.py +22 -13
- isegm/inference/predictors/brs_losses.py +7 -5
- isegm/inference/transforms/__init__.py +2 -2
- isegm/inference/transforms/crops.py +20 -9
- isegm/inference/transforms/flip.py +7 -3
- isegm/inference/transforms/zoom_in.py +69 -25
- isegm/inference/utils.py +59 -41
- isegm/model/initializer.py +32 -17
- isegm/model/is_deeplab_model.py +28 -9
- isegm/model/is_hrnet_model.py +19 -6
- isegm/model/is_model.py +86 -31
- isegm/model/losses.py +101 -31
- isegm/model/metrics.py +35 -9
- isegm/model/modeling/basic_blocks.py +68 -22
- isegm/model/modeling/deeplab_v3.py +106 -45
- isegm/model/modeling/hrnet_ocr.py +292 -132
- isegm/model/modeling/ocr.py +91 -45
- isegm/model/modeling/resnet.py +27 -13
- isegm/model/modeling/resnetv1b.py +227 -61
- isegm/model/modifiers.py +3 -5
- isegm/model/ops.py +38 -15
- isegm/utils/cython/__init__.py +1 -1
- isegm/utils/cython/_get_dist_maps.pyx +2 -1
- isegm/utils/cython/dist_maps.py +4 -2
- isegm/utils/distributed.py +11 -2
app.py
CHANGED
@@ -1,94 +1,127 @@
import os
from typing import Dict, List

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor

###################################
# Global scope objects.
###################################
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
CANVAS_HEIGHT, CANVAS_WIDTH = 600, 600
POS_COLOR, NEG_COLOR = "#3498DB", "#C70039"
ERR_X, ERR_Y = 5.5, 1.0
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clicker = ck.Clicker()
predictor = None
image = None


###################################
# Functions.
###################################
# @st.cache_resource
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    predictor = get_predictor(model, device=device, **predictor_params)
    return predictor


def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
    for click in clicks:
        x, y = (click["left"] + ERR_X) * ratio_w, (click["top"] + ERR_Y) * ratio_h
        x, y = min(image_width, max(0, x)), min(image_height, max(0, y))

        is_positive = click["stroke"] == POS_COLOR
        click = ck.Click(is_positive=is_positive, coords=(y, x))
        clicker.add_click(click)


def predict(
    image: Image, mask: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_HEIGHT, CANVAS_WIDTH),
        interpolation=cv2.INTER_CUBIC,
    )
    pred = np.where(pred > threshold, 1.0, 0)
    return pred


###################################
# Sidebar GUI
###################################
# Items in the sidebar.
model = st.sidebar.selectbox("Select a Model:", tuple(MODELS.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Click Type:", ("Positive", "Negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
if image_path:
    image = Image.open(image_path).convert("RGB")

###################################
# Preparation
###################################
# Model.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(MODELS[model]):
        _ = wget.download(f"{URL_PREFIX}/{MODELS[model]}")
# Predictor.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(MODELS[model], device)

###################################
# GUI
###################################
# Create a canvas component.
st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=POS_COLOR if marking_type == "Positive" else NEG_COLOR,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
)

###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
if canvas_result.json_data and canvas_result.json_data["objects"] and image:
    image_width, image_height = image.size
    feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)

    # Run prediction.
    mask = torch.zeros((1, 1, image_width, image_height), device=device)
    pred = predict(image, mask, threshold)

    # Show the prediction result.
    st.image(pred, caption="")
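The refactor above splits the Streamlit flow into reusable helpers (load_model, feed_clicks, predict). A minimal sketch of exercising the same predictor outside Streamlit, assuming the checkpoint listed in MODELS has already been downloaded and using a placeholder test-image path:

# Hedged sketch, not part of the diff: drive the RITM predictor without the Streamlit UI.
# "example.jpg" is a placeholder; the checkpoint name matches MODELS["RITM"] above.
import numpy as np
import torch
from PIL import Image

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = utils.load_is_model("ritm_coco_lvis_h18_itermask.pth", device, cpu_dist_maps=True)
predictor = get_predictor(model, device=device, brs_mode="NoBRS")

image = Image.open("example.jpg").convert("RGB")
width, height = image.size

# One positive click near the image center; coords are (y, x), as in feed_clicks.
clicker = ck.Clicker()
clicker.add_click(ck.Click(is_positive=True, coords=(height // 2, width // 2)))

predictor.set_input_image(np.array(image))
prev_mask = torch.zeros((1, 1, height, width), device=device)
pred = predictor.get_prediction(clicker, prev_mask=prev_mask)
mask = np.where(pred > 0.5, 1.0, 0.0)
print(mask.shape, mask.sum())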
isegm/data/base.py
CHANGED
@@ -1,22 +1,26 @@
import pickle
import random

import numpy as np
import torch
from torchvision import transforms

from .points_sampler import MultiPointSampler
from .sample import DSample


class ISDataset(torch.utils.data.dataset.Dataset):
    def __init__(
        self,
        augmentator=None,
        points_sampler=MultiPointSampler(max_num_points=12),
        min_object_area=0,
        keep_background_prob=0.0,
        with_image_info=False,
        samples_scores_path=None,
        samples_scores_gamma=1.0,
        epoch_len=-1,
    ):
        super(ISDataset, self).__init__()
        self.epoch_len = epoch_len
        self.augmentator = augmentator
@@ -24,15 +28,19 @@ class ISDataset(torch.utils.data.dataset.Dataset):
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(
            samples_scores_path, samples_scores_gamma
        )
        self.to_tensor = transforms.ToTensor()

        self.dataset_samples = None

    def __getitem__(self, index):
        if self.samples_precomputed_scores is not None:
            index = np.random.choice(
                self.samples_precomputed_scores["indices"],
                p=self.samples_precomputed_scores["probs"],
            )
        else:
            if self.epoch_len > 0:
                index = random.randrange(0, len(self.dataset_samples))
@@ -46,13 +54,13 @@ class ISDataset(torch.utils.data.dataset.Dataset):
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }

        if self.with_image_info:
            output["image_info"] = sample.sample_id

        return output

@@ -63,8 +71,10 @@ class ISDataset(torch.utils.data.dataset.Dataset):
        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            keep_sample = (
                self.keep_background_prob < 0.0
                or random.random() < self.keep_background_prob
            )
            valid_augmentation = len(sample) > 0 or keep_sample

        return sample
@@ -86,14 +96,11 @@ class ISDataset(torch.utils.data.dataset.Dataset):
        if samples_scores_path is None:
            return None

        with open(samples_scores_path, "rb") as f:
            images_scores = pickle.load(f)

        probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores])
        probs /= probs.sum()
        samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs}
        print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
        return samples_scores
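_load_samples_scores above only reads position 0 (a sample index) and position 2 (a score in [0, 1]) from each pickled entry; sampling probability is then proportional to (1 - score) ** samples_scores_gamma. A hedged sketch of writing a compatible pickle (the middle field and the file name are placeholders):

# Hedged sketch: build a samples-scores pickle usable via samples_scores_path.
import pickle
import random

# Each entry: (sample index, unused placeholder, score); only x[0] and x[2] are read.
images_scores = [(idx, None, random.random()) for idx in range(1000)]

with open("samples_scores.pkl", "wb") as f:
    pickle.dump(images_scores, f)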
isegm/data/compose.py
CHANGED
@@ -1,5 +1,7 @@
from math import isclose

import numpy as np

from .base import ISDataset


@@ -10,7 +12,9 @@ class ComposeDataset(ISDataset):
        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend(
                [(dataset_indx, i) for i in range(len(dataset))]
            )

    def get_sample(self, index):
        dataset_indx, sample_indx = self.dataset_samples[index]
@@ -21,16 +25,18 @@ class ProportionalComposeDataset(ISDataset):
    def __init__(self, datasets, ratios, **kwargs):
        super().__init__(**kwargs)

        assert len(ratios) == len(
            datasets
        ), "The number of datasets must match the number of ratios"
        assert isclose(sum(ratios), 1.0), "The sum of ratios must be equal to 1"

        self._ratios = ratios
        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend(
                [(dataset_indx, i) for i in range(len(dataset))]
            )

    def get_sample(self, index):
        dataset_indx = np.random.choice(len(self._datasets), p=self._ratios)
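Per the asserts above, ProportionalComposeDataset takes one ratio per dataset and requires the ratios to sum to 1. A hedged usage sketch (the dataset paths and epoch length are placeholders):

# Hedged sketch: mix two training sets with fixed sampling proportions.
from isegm.data.compose import ProportionalComposeDataset
from isegm.data.datasets import CocoLvisDataset, SBDDataset

trainset = ProportionalComposeDataset(
    datasets=[
        CocoLvisDataset("/data/cocolvis", split="train"),
        SBDDataset("/data/sbd", split="train"),
    ],
    ratios=[0.75, 0.25],
    epoch_len=30000,
)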
isegm/data/datasets/__init__.py
CHANGED
@@ -1,12 +1,13 @@
from isegm.data.compose import ComposeDataset, ProportionalComposeDataset

from .ade20k import ADE20kDataset
from .berkeley import BerkeleyDataset
from .coco import CocoDataset
from .coco_lvis import CocoLvisDataset
from .davis import DavisDataset
from .grabcut import GrabCutDataset
from .images_dir import ImagesDirDataset
from .lvis import LvisDataset
from .openimages import OpenImagesDataset
from .pascalvoc import PascalVocDataset
from .sbd import SBDDataset, SBDEvaluationDataset
isegm/data/datasets/ade20k.py
CHANGED
@@ -1,6 +1,6 @@
import os
import pickle as pkl
import random
from pathlib import Path

import cv2
@@ -12,18 +12,18 @@ from isegm.utils.misc import get_labels_with_sizes


class ADE20kDataset(ISDataset):
    def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
        super().__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self.dataset_split_folder = "training" if split == "train" else "validation"
        self.stuff_prob = stuff_prob

        anno_path = self.dataset_path / f"{split}-annotations-object-segmentation.pkl"
        if os.path.exists(anno_path):
            with anno_path.open("rb") as f:
                annotations = pkl.load(f)
        else:
            raise RuntimeError(f"Can't find annotations at {anno_path}")
@@ -34,21 +34,23 @@ class ADE20kDataset(ISDataset):
        image_id = self.dataset_samples[index]
        sample_annos = self.annotations[image_id]

        image_path = str(self.dataset_path / sample_annos["folder"] / f"{image_id}.jpg")
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # select random mask for an image
        layer = random.choice(sample_annos["layers"])
        mask_path = str(self.dataset_path / sample_annos["folder"] / layer["mask_name"])
        instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[
            :, :, 0
        ]  # the B channel holds instances
        instances_mask = instances_mask.astype(np.int32)
        object_ids, _ = get_labels_with_sizes(instances_mask)

        if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob):
            # remove stuff objects
            for i, object_id in enumerate(object_ids):
                if i in layer["stuff_instances"]:
                    instances_mask[instances_mask == object_id] = 0
            object_ids, _ = get_labels_with_sizes(instances_mask)
isegm/data/datasets/berkeley.py
CHANGED
@@ -3,4 +3,6 @@ from .grabcut import GrabCutDataset

class BerkeleyDataset(GrabCutDataset):
    def __init__(self, dataset_path, **kwargs):
        super().__init__(
            dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs
        )
isegm/data/datasets/coco.py
CHANGED
@@ -1,14 +1,16 @@
import json
import random
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class CocoDataset(ISDataset):
    def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
        super(CocoDataset, self).__init__(**kwargs)
        self.split = split
        self.dataset_path = Path(dataset_path)
@@ -17,26 +19,28 @@ class CocoDataset(ISDataset):
        self.load_samples()

    def load_samples(self):
        annotation_path = (
            self.dataset_path / "annotations" / f"panoptic_{self.split}.json"
        )
        self.labels_path = self.dataset_path / "annotations" / f"panoptic_{self.split}"
        self.images_path = self.dataset_path / self.split

        with open(annotation_path, "r") as f:
            annotation = json.load(f)

        self.dataset_samples = annotation["annotations"]

        self._categories = annotation["categories"]
        self._stuff_labels = [x["id"] for x in self._categories if x["isthing"] == 0]
        self._things_labels = [x["id"] for x in self._categories if x["isthing"] == 1]
        self._things_labels_set = set(self._things_labels)
        self._stuff_labels_set = set(self._stuff_labels)

    def get_sample(self, index) -> DSample:
        dataset_sample = self.dataset_samples[index]

        image_path = self.images_path / self.get_image_name(dataset_sample["file_name"])
        label_path = self.labels_path / dataset_sample["file_name"]

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -47,11 +51,11 @@ class CocoDataset(ISDataset):
        things_ids = []
        stuff_ids = []

        for segment in dataset_sample["segments_info"]:
            class_id = segment["category_id"]
            obj_id = segment["id"]
            if class_id in self._things_labels_set:
                if segment["iscrowd"] == 1:
                    continue
                things_ids.append(obj_id)
            else:
@@ -71,4 +75,4 @@ class CocoDataset(ISDataset):

    @classmethod
    def get_image_name(cls, panoptic_name):
        return panoptic_name.replace(".png", ".jpg")
isegm/data/datasets/coco_lvis.py
CHANGED
@@ -1,66 +1,78 @@
import json
import pickle
import random
from copy import deepcopy
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class CocoLvisDataset(ISDataset):
    def __init__(
        self,
        dataset_path,
        split="train",
        stuff_prob=0.0,
        allow_list_name=None,
        anno_file="hannotation.pickle",
        **kwargs,
    ):
        super(CocoLvisDataset, self).__init__(**kwargs)
        dataset_path = Path(dataset_path)
        self._split_path = dataset_path / split
        self.split = split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.stuff_prob = stuff_prob

        with open(self._split_path / anno_file, "rb") as f:
            self.dataset_samples = sorted(pickle.load(f).items())

        if allow_list_name is not None:
            allow_list_path = self._split_path / allow_list_name
            with open(allow_list_path, "r") as f:
                allow_images_ids = json.load(f)
            allow_images_ids = set(allow_images_ids)

            self.dataset_samples = [
                sample
                for sample in self.dataset_samples
                if sample[0] in allow_images_ids
            ]

    def get_sample(self, index) -> DSample:
        image_id, sample = self.dataset_samples[index]
        image_path = self._images_path / f"{image_id}.jpg"

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        packed_masks_path = self._masks_path / f"{image_id}.pickle"
        with open(packed_masks_path, "rb") as f:
            encoded_layers, objs_mapping = pickle.load(f)
        layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
        layers = np.stack(layers, axis=2)

        instances_info = deepcopy(sample["hierarchy"])
        for inst_id, inst_info in list(instances_info.items()):
            if inst_info is None:
                inst_info = {"children": [], "parent": None, "node_level": 0}
                instances_info[inst_id] = inst_info
            inst_info["mapping"] = objs_mapping[inst_id]

        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                instances_info[inst_id] = {
                    "mapping": objs_mapping[inst_id],
                    "parent": None,
                    "children": [],
                }
        else:
            for inst_id in range(sample["num_instance_masks"], len(objs_mapping)):
                layer_indx, mask_id = objs_mapping[inst_id]
                layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0
isegm/data/datasets/davis.py
CHANGED
@@ -8,22 +8,22 @@ from isegm.data.sample import DSample


class DavisDataset(ISDataset):
    def __init__(
        self, dataset_path, images_dir_name="img", masks_dir_name="gt", **kwargs
    ):
        super(DavisDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")}

    def get_sample(self, index) -> DSample:
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)
        mask_path = str(self._masks_paths[image_name.split(".")[0]])

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
isegm/data/datasets/grabcut.py
CHANGED
@@ -8,22 +8,26 @@ from isegm.data.sample import DSample


class GrabCutDataset(ISDataset):
    def __init__(
        self,
        dataset_path,
        images_dir_name="data_GT",
        masks_dir_name="boundary_GT",
        **kwargs
    ):
        super(GrabCutDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        self.dataset_samples = [x.name for x in sorted(self._images_path.glob("*.*"))]
        self._masks_paths = {x.stem: x for x in self._insts_path.glob("*.*")}

    def get_sample(self, index) -> DSample:
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / image_name)
        mask_path = str(self._masks_paths[image_name.split(".")[0]])

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -31,4 +35,6 @@ class GrabCutDataset(ISDataset):
        instances_mask[instances_mask == 128] = -1
        instances_mask[instances_mask > 128] = 1

        return DSample(
            image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index
        )
isegm/data/datasets/images_dir.py
CHANGED
@@ -1,49 +1,50 @@
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class ImagesDirDataset(ISDataset):
    def __init__(
        self, dataset_path, images_dir_name="images", masks_dir_name="masks", **kwargs
    ):
        super(ImagesDirDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / images_dir_name
        self._insts_path = self.dataset_path / masks_dir_name

        images_list = [x for x in sorted(self._images_path.glob("*.*"))]

        samples = {x.stem: {"image": x, "masks": []} for x in images_list}
        for mask_path in self._insts_path.glob("*.*"):
            mask_name = mask_path.stem
            if mask_name in samples:
                samples[mask_name]["masks"].append(mask_path)
                continue

            mask_name_split = mask_name.split("_")
            if mask_name_split[-1].isdigit():
                mask_name = "_".join(mask_name_split[:-1])
                assert mask_name in samples
                samples[mask_name]["masks"].append(mask_path)

        for x in samples.values():
            assert len(x["masks"]) > 0, x["image"]

        self.dataset_samples = [v for k, v in sorted(samples.items())]

    def get_sample(self, index) -> DSample:
        sample = self.dataset_samples[index]
        image_path = str(sample["image"])

        objects = []
        ignored_regions = []
        masks = []
        for indx, mask_path in enumerate(sample["masks"]):
            gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32)
            instances_mask = np.zeros_like(gt_mask)
            instances_mask[gt_mask == 128] = 2
@@ -55,5 +56,10 @@ class ImagesDirDataset(ISDataset):
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return DSample(
            image,
            np.stack(masks, axis=2),
            objects_ids=objects,
            ignore_ids=ignored_regions,
            sample_id=index,
        )
isegm/data/datasets/lvis.py
CHANGED
@@ -11,42 +11,41 @@ from isegm.data.sample import DSample


class LvisDataset(ISDataset):
    def __init__(self, dataset_path, split="train", max_overlap_ratio=0.5, **kwargs):
        super(LvisDataset, self).__init__(**kwargs)
        dataset_path = Path(dataset_path)
        train_categories_path = dataset_path / "train_categories.json"
        self._train_path = dataset_path / "train"
        self._val_path = dataset_path / "val"

        self.split = split
        self.max_overlap_ratio = max_overlap_ratio

        with open(dataset_path / split / f"lvis_{self.split}.json", "r") as f:
            json_annotation = json.loads(f.read())

        self.annotations = defaultdict(list)
        for x in json_annotation["annotations"]:
            self.annotations[x["image_id"]].append(x)

        if not train_categories_path.exists():
            self.generate_train_categories(dataset_path, train_categories_path)
        self.dataset_samples = [
            x for x in json_annotation["images"] if len(self.annotations[x["id"]]) > 0
        ]

    def get_sample(self, index) -> DSample:
        image_info = self.dataset_samples[index]
        image_id, image_url = image_info["id"], image_info["coco_url"]
        image_filename = image_url.split("/")[-1]
        image_annotations = self.annotations[image_id]
        random.shuffle(image_annotations)

        # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017)
        if "train2017" in image_url:
            image_path = self._train_path / "images" / image_filename
        else:
            image_path = self._val_path / "images" / image_filename
        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

@@ -62,9 +61,14 @@ class LvisDataset(ISDataset):
                instances_mask = np.zeros_like(object_mask, dtype=np.int32)

            overlap_ids = np.bincount(instances_mask[object_mask].flatten())
            overlap_areas = [
                overlap_area / instances_area[inst_id]
                for inst_id, overlap_area in enumerate(overlap_ids)
                if overlap_area > 0 and inst_id > 0
            ]
            overlap_ratio = (
                np.logical_and(object_mask, instances_mask > 0).sum() / object_area
            )
            if overlap_areas:
                overlap_ratio = max(overlap_ratio, max(overlap_areas))
            if overlap_ratio > self.max_overlap_ratio:
@@ -77,11 +81,10 @@ class LvisDataset(ISDataset):

        return DSample(image, instances_mask, objects_ids=objects_ids)

    @staticmethod
    def get_mask_from_polygon(annotation, image):
        mask = np.zeros(image.shape[:2], dtype=np.int32)
        for contour_points in annotation["segmentation"]:
            contour_points = np.array(contour_points).reshape((-1, 2))
            contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :]
            cv2.fillPoly(mask, contour_points, 1)
@@ -90,8 +93,8 @@ class LvisDataset(ISDataset):

    @staticmethod
    def generate_train_categories(dataset_path, train_categories_path):
        with open(dataset_path / "train/lvis_train.json", "r") as f:
            annotation = json.load(f)

        with open(train_categories_path, "w") as f:
            json.dump(annotation["categories"], f, indent=1)
isegm/data/datasets/openimages.py
CHANGED
@@ -1,6 +1,6 @@
import os
import pickle as pkl
import random
from pathlib import Path

import cv2
@@ -11,29 +11,31 @@ from isegm.data.sample import DSample


class OpenImagesDataset(ISDataset):
    def __init__(self, dataset_path, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in {"train", "val", "test"}

        self.dataset_path = Path(dataset_path)
        self._split_path = self.dataset_path / split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.dataset_split = split

        clean_anno_path = (
            self._split_path / f"{split}-annotations-object-segmentation_clean.pkl"
        )
        if os.path.exists(clean_anno_path):
            with clean_anno_path.open("rb") as f:
                annotations = pkl.load(f)
        else:
            raise RuntimeError(f"Can't find annotations at {clean_anno_path}")
        self.image_id_to_masks = annotations["image_id_to_masks"]
        self.dataset_samples = annotations["dataset_samples"]

    def get_sample(self, index) -> DSample:
        image_id = self.dataset_samples[index]

        image_path = str(self._images_path / f"{image_id}.jpg")
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

@@ -49,9 +51,16 @@ class OpenImagesDataset(ISDataset):
        min_height = min(image.shape[0], instances_mask.shape[0])

        if image.shape[0] != min_height or image.shape[1] != min_width:
            image = cv2.resize(
                image, (min_width, min_height), interpolation=cv2.INTER_LINEAR
            )
        if (
            instances_mask.shape[0] != min_height
            or instances_mask.shape[1] != min_width
        ):
            instances_mask = cv2.resize(
                instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST
            )

        object_ids = [1] if instances_mask.sum() > 0 else []
isegm/data/datasets/pascalvoc.py
CHANGED
@@ -9,32 +9,38 @@ from isegm.data.sample import DSample


class PascalVocDataset(ISDataset):
    def __init__(self, dataset_path, split="train", **kwargs):
        super().__init__(**kwargs)
        assert split in {"train", "val", "trainval", "test"}

        self.dataset_path = Path(dataset_path)
        self._images_path = self.dataset_path / "JPEGImages"
        self._insts_path = self.dataset_path / "SegmentationObject"
        self.dataset_split = split

        if split == "test":
            with open(
                self.dataset_path / f"ImageSets/Segmentation/test.pickle", "rb"
            ) as f:
                self.dataset_samples, self.instance_ids = pkl.load(f)
        else:
            with open(
                self.dataset_path / f"ImageSets/Segmentation/{split}.txt", "r"
            ) as f:
                self.dataset_samples = [name.strip() for name in f.readlines()]

    def get_sample(self, index) -> DSample:
        sample_id = self.dataset_samples[index]
        image_path = str(self._images_path / f"{sample_id}.jpg")
        mask_path = str(self._insts_path / f"{sample_id}.png")

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = cv2.imread(mask_path)
        instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(
            np.int32
        )
        if self.dataset_split == "test":
            instance_id = self.instance_ids[index]
            mask = np.zeros_like(instances_mask)
            mask[instances_mask == 220] = 220  # ignored area
@@ -45,4 +51,10 @@ class PascalVocDataset(ISDataset):
        objects_ids = np.unique(instances_mask)
        objects_ids = [x for x in objects_ids if x != 0 and x != 220]

        return DSample(
            image,
            instances_mask,
            objects_ids=objects_ids,
            ignore_ids=[220],
            sample_id=index,
        )
isegm/data/datasets/sbd.py
CHANGED
@@ -5,38 +5,42 @@ import cv2
|
|
5 |
import numpy as np
|
6 |
from scipy.io import loadmat
|
7 |
|
8 |
-
from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes
|
9 |
from isegm.data.base import ISDataset
|
10 |
from isegm.data.sample import DSample
|
|
|
11 |
|
12 |
|
13 |
class SBDDataset(ISDataset):
|
14 |
-
def __init__(self, dataset_path, split=
|
15 |
super(SBDDataset, self).__init__(**kwargs)
|
16 |
-
assert split in {
|
17 |
|
18 |
self.dataset_path = Path(dataset_path)
|
19 |
self.dataset_split = split
|
20 |
-
self._images_path = self.dataset_path /
|
21 |
-
self._insts_path = self.dataset_path /
|
22 |
self._buggy_objects = dict()
|
23 |
self._buggy_mask_thresh = buggy_mask_thresh
|
24 |
|
25 |
-
with open(self.dataset_path / f
|
26 |
self.dataset_samples = [x.strip() for x in f.readlines()]
|
27 |
|
28 |
def get_sample(self, index):
|
29 |
image_name = self.dataset_samples[index]
|
30 |
-
image_path = str(self._images_path / f
|
31 |
-
inst_info_path = str(self._insts_path / f
|
32 |
|
33 |
image = cv2.imread(image_path)
|
34 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
35 |
-
instances_mask = loadmat(str(inst_info_path))[
|
|
|
|
|
36 |
instances_mask = self.remove_buggy_masks(index, instances_mask)
|
37 |
instances_ids, _ = get_labels_with_sizes(instances_mask)
|
38 |
|
39 |
-
return DSample(
|
|
|
|
|
40 |
|
41 |
def remove_buggy_masks(self, index, instances_mask):
|
42 |
if self._buggy_mask_thresh > 0.0:
|
@@ -61,51 +65,55 @@ class SBDDataset(ISDataset):
|
|
61 |
|
62 |
|
63 |
class SBDEvaluationDataset(ISDataset):
|
64 |
-
def __init__(self, dataset_path, split=
|
65 |
super(SBDEvaluationDataset, self).__init__(**kwargs)
|
66 |
-
assert split in {
|
67 |
|
68 |
self.dataset_path = Path(dataset_path)
|
69 |
self.dataset_split = split
|
70 |
-
self._images_path = self.dataset_path /
|
71 |
-
self._insts_path = self.dataset_path /
|
72 |
|
73 |
-
with open(self.dataset_path / f
|
74 |
self.dataset_samples = [x.strip() for x in f.readlines()]
|
75 |
|
76 |
self.dataset_samples = self.get_sbd_images_and_ids_list()
|
77 |
|
78 |
def get_sample(self, index) -> DSample:
|
79 |
image_name, instance_id = self.dataset_samples[index]
|
80 |
-
image_path = str(self._images_path / f
|
81 |
-
inst_info_path = str(self._insts_path / f
|
82 |
|
83 |
image = cv2.imread(image_path)
|
84 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
85 |
-
instances_mask = loadmat(str(inst_info_path))[
|
|
|
|
|
86 |
instances_mask[instances_mask != instance_id] = 0
|
87 |
instances_mask[instances_mask > 0] = 1
|
88 |
|
89 |
return DSample(image, instances_mask, objects_ids=[1], sample_id=index)
|
90 |
|
91 |
def get_sbd_images_and_ids_list(self):
|
92 |
-
pkl_path = self.dataset_path / f
|
93 |
|
94 |
if pkl_path.exists():
|
95 |
-
with open(str(pkl_path),
|
96 |
images_and_ids_list = pkl.load(fp)
|
97 |
else:
|
98 |
images_and_ids_list = []
|
99 |
|
100 |
for sample in self.dataset_samples:
|
101 |
-
inst_info_path = str(self._insts_path / f
|
102 |
-
instances_mask = loadmat(str(inst_info_path))[
|
|
|
|
|
103 |
instances_ids, _ = get_labels_with_sizes(instances_mask)
|
104 |
|
105 |
for instances_id in instances_ids:
|
106 |
images_and_ids_list.append((sample, instances_id))
|
107 |
|
108 |
-
with open(str(pkl_path),
|
109 |
pkl.dump(images_and_ids_list, fp)
|
110 |
|
111 |
return images_and_ids_list
|
|
|
# (file lines 1-4 are outside the diff and not shown)
import numpy as np
from scipy.io import loadmat

from isegm.data.base import ISDataset
from isegm.data.sample import DSample
from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes


class SBDDataset(ISDataset):
    def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs):
        super(SBDDataset, self).__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"
        self._buggy_objects = dict()
        self._buggy_mask_thresh = buggy_mask_thresh

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            self.dataset_samples = [x.strip() for x in f.readlines()]

    def get_sample(self, index):
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")
        inst_info_path = str(self._insts_path / f"{image_name}.mat")

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
            np.int32
        )
        instances_mask = self.remove_buggy_masks(index, instances_mask)
        instances_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(
            image, instances_mask, objects_ids=instances_ids, sample_id=index
        )

    def remove_buggy_masks(self, index, instances_mask):
        if self._buggy_mask_thresh > 0.0:
            # ... (the rest of this method, lines 47-64, is unchanged and hidden by the diff) ...


class SBDEvaluationDataset(ISDataset):
    def __init__(self, dataset_path, split="val", **kwargs):
        super(SBDEvaluationDataset, self).__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            self.dataset_samples = [x.strip() for x in f.readlines()]

        self.dataset_samples = self.get_sbd_images_and_ids_list()

    def get_sample(self, index) -> DSample:
        image_name, instance_id = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")
        inst_info_path = str(self._insts_path / f"{image_name}.mat")

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
            np.int32
        )
        instances_mask[instances_mask != instance_id] = 0
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

    def get_sbd_images_and_ids_list(self):
        pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl"

        if pkl_path.exists():
            with open(str(pkl_path), "rb") as fp:
                images_and_ids_list = pkl.load(fp)
        else:
            images_and_ids_list = []

            for sample in self.dataset_samples:
                inst_info_path = str(self._insts_path / f"{sample}.mat")
                instances_mask = loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(
                    np.int32
                )
                instances_ids, _ = get_labels_with_sizes(instances_mask)

                for instances_id in instances_ids:
                    images_and_ids_list.append((sample, instances_id))

            with open(str(pkl_path), "wb") as fp:
                pkl.dump(images_and_ids_list, fp)

        return images_and_ids_list
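As a quick orientation for the classes above, here is a minimal usage sketch; the local dataset path is a placeholder and the import path simply mirrors the file location, neither is mandated by this diff.

from isegm.data.datasets.sbd import SBDEvaluationDataset

# Placeholder path: an SBD release with img/, inst/, train.txt and val.txt is assumed.
valset = SBDEvaluationDataset("/data/SBD/dataset", split="val")

sample = valset.get_sample(0)          # DSample with one binary instance mask
obj_mask = sample.get_object_mask(0)   # H x W array of 0/1 values
print(sample.image.shape, obj_mask.shape)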
isegm/data/points_sampler.py
CHANGED
(Pre-refactor rendering of this file, truncated in this export; the same formatting-only refactoring applied to MultiPointSampler and its point-sampling helpers. The reformatted rendering follows.)
1 |
import math
|
2 |
import random
|
|
|
3 |
from functools import lru_cache
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
from .sample import DSample
|
9 |
|
10 |
|
|
|
30 |
|
31 |
|
32 |
class MultiPointSampler(BasePointSampler):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
max_num_points,
|
36 |
+
prob_gamma=0.7,
|
37 |
+
expand_ratio=0.1,
|
38 |
+
positive_erode_prob=0.9,
|
39 |
+
positive_erode_iters=3,
|
40 |
+
negative_bg_prob=0.1,
|
41 |
+
negative_other_prob=0.4,
|
42 |
+
negative_border_prob=0.5,
|
43 |
+
merge_objects_prob=0.0,
|
44 |
+
max_num_merged_objects=2,
|
45 |
+
use_hierarchy=False,
|
46 |
+
soft_targets=False,
|
47 |
+
first_click_center=False,
|
48 |
+
only_one_first_click=False,
|
49 |
+
sfc_inner_k=1.7,
|
50 |
+
sfc_full_inner_prob=0.0,
|
51 |
+
):
|
52 |
super().__init__()
|
53 |
self.max_num_points = max_num_points
|
54 |
self.expand_ratio = expand_ratio
|
|
|
66 |
max_num_merged_objects = max_num_points
|
67 |
self.max_num_merged_objects = max_num_merged_objects
|
68 |
|
69 |
+
self.neg_strategies = ["bg", "other", "border"]
|
70 |
+
self.neg_strategies_prob = [
|
71 |
+
negative_bg_prob,
|
72 |
+
negative_other_prob,
|
73 |
+
negative_border_prob,
|
74 |
+
]
|
75 |
assert math.isclose(sum(self.neg_strategies_prob), 1.0)
|
76 |
|
77 |
self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
|
|
|
84 |
self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
|
85 |
self._selected_masks = [[]]
|
86 |
self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
|
87 |
+
self._neg_masks["required"] = []
|
88 |
return
|
89 |
|
90 |
gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
|
|
|
98 |
if len(sample) <= len(self._selected_masks):
|
99 |
neg_mask_other = neg_mask_bg
|
100 |
else:
|
101 |
+
neg_mask_other = np.logical_and(
|
102 |
+
np.logical_not(sample.get_background_mask()),
|
103 |
+
np.logical_not(binary_gt_mask),
|
104 |
+
)
|
105 |
|
106 |
self._neg_masks = {
|
107 |
+
"bg": neg_mask_bg,
|
108 |
+
"other": neg_mask_other,
|
109 |
+
"border": neg_mask_border,
|
110 |
+
"required": neg_masks,
|
111 |
}
|
112 |
|
113 |
def _sample_mask(self, sample: DSample):
|
|
|
124 |
pos_segments = []
|
125 |
neg_segments = []
|
126 |
for obj_id in random_ids:
|
127 |
+
(
|
128 |
+
obj_gt_mask,
|
129 |
+
obj_pos_segments,
|
130 |
+
obj_neg_segments,
|
131 |
+
) = self._sample_from_masks_layer(obj_id, sample)
|
132 |
if gt_mask is None:
|
133 |
gt_mask = obj_gt_mask
|
134 |
else:
|
|
|
147 |
|
148 |
if not self.use_hierarchy:
|
149 |
node_mask = sample.get_object_mask(obj_id)
|
150 |
+
gt_mask = (
|
151 |
+
sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
|
152 |
+
)
|
153 |
return gt_mask, [node_mask], []
|
154 |
|
155 |
def _select_node(node_id):
|
156 |
node_info = objs_tree[node_id]
|
157 |
+
if not node_info["children"] or random.random() < 0.5:
|
158 |
return node_id
|
159 |
+
return _select_node(random.choice(node_info["children"]))
|
160 |
|
161 |
selected_node = _select_node(obj_id)
|
162 |
node_info = objs_tree[selected_node]
|
163 |
node_mask = sample.get_object_mask(selected_node)
|
164 |
+
gt_mask = (
|
165 |
+
sample.get_soft_object_mask(selected_node)
|
166 |
+
if self.soft_targets
|
167 |
+
else node_mask
|
168 |
+
)
|
169 |
pos_mask = node_mask.copy()
|
170 |
|
171 |
negative_segments = []
|
172 |
+
if node_info["parent"] is not None and node_info["parent"] in objs_tree:
|
173 |
+
parent_mask = sample.get_object_mask(node_info["parent"])
|
174 |
+
negative_segments.append(
|
175 |
+
np.logical_and(parent_mask, np.logical_not(node_mask))
|
176 |
+
)
|
177 |
+
|
178 |
+
for child_id in node_info["children"]:
|
179 |
+
if objs_tree[child_id]["area"] / node_info["area"] < 0.10:
|
180 |
child_mask = sample.get_object_mask(child_id)
|
181 |
pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
|
182 |
|
183 |
+
if node_info["children"]:
|
184 |
+
max_disabled_children = min(len(node_info["children"]), 3)
|
185 |
num_disabled_children = np.random.randint(0, max_disabled_children + 1)
|
186 |
+
disabled_children = random.sample(
|
187 |
+
node_info["children"], num_disabled_children
|
188 |
+
)
|
189 |
|
190 |
for child_id in disabled_children:
|
191 |
child_mask = sample.get_object_mask(child_id)
|
|
|
201 |
|
202 |
def sample_points(self):
|
203 |
assert self._selected_mask is not None
|
204 |
+
pos_points = self._multi_mask_sample_points(
|
205 |
+
self._selected_masks,
|
206 |
+
is_negative=[False] * len(self._selected_masks),
|
207 |
+
with_first_click=self.first_click_center,
|
208 |
+
)
|
209 |
+
|
210 |
+
neg_strategy = [
|
211 |
+
(self._neg_masks[k], prob)
|
212 |
+
for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)
|
213 |
+
]
|
214 |
+
neg_masks = self._neg_masks["required"] + [neg_strategy]
|
215 |
+
neg_points = self._multi_mask_sample_points(
|
216 |
+
neg_masks, is_negative=[False] * len(self._neg_masks["required"]) + [True]
|
217 |
+
)
|
218 |
|
219 |
return pos_points + neg_points
|
220 |
|
221 |
+
def _multi_mask_sample_points(
|
222 |
+
self, selected_masks, is_negative, with_first_click=False
|
223 |
+
):
|
224 |
+
selected_masks = selected_masks[: self.max_num_points]
|
225 |
|
226 |
each_obj_points = [
|
227 |
+
self._sample_points(
|
228 |
+
mask, is_negative=is_negative[i], with_first_click=with_first_click
|
229 |
+
)
|
230 |
for i, mask in enumerate(selected_masks)
|
231 |
]
|
232 |
each_obj_points = [x for x in each_obj_points if len(x) > 0]
|
|
|
242 |
|
243 |
aggregated_masks_with_prob = []
|
244 |
for indx, x in enumerate(selected_masks):
|
245 |
+
if (
|
246 |
+
isinstance(x, (list, tuple))
|
247 |
+
and x
|
248 |
+
and isinstance(x[0], (list, tuple))
|
249 |
+
):
|
250 |
for t, prob in x:
|
251 |
+
aggregated_masks_with_prob.append(
|
252 |
+
(t, prob / len(selected_masks))
|
253 |
+
)
|
254 |
else:
|
255 |
aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))
|
256 |
|
257 |
+
other_points_union = self._sample_points(
|
258 |
+
aggregated_masks_with_prob, is_negative=True
|
259 |
+
)
|
260 |
if len(other_points_union) + len(points) <= self.max_num_points:
|
261 |
points.extend(other_points_union)
|
262 |
else:
|
263 |
+
points.extend(
|
264 |
+
random.sample(other_points_union, self.max_num_points - len(points))
|
265 |
+
)
|
266 |
|
267 |
if len(points) < self.max_num_points:
|
268 |
points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))
|
|
|
271 |
|
272 |
def _sample_points(self, mask, is_negative=False, with_first_click=False):
|
273 |
if is_negative:
|
274 |
+
num_points = np.random.choice(
|
275 |
+
np.arange(self.max_num_points + 1), p=self._neg_probs
|
276 |
+
)
|
277 |
else:
|
278 |
+
num_points = 1 + np.random.choice(
|
279 |
+
np.arange(self.max_num_points), p=self._pos_probs
|
280 |
+
)
|
281 |
|
282 |
indices_probs = None
|
283 |
if isinstance(mask, (list, tuple)):
|
|
|
293 |
first_click = with_first_click and j == 0 and indices_probs is None
|
294 |
|
295 |
if first_click:
|
296 |
+
point_indices = get_point_candidates(
|
297 |
+
mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob
|
298 |
+
)
|
299 |
elif indices_probs:
|
300 |
+
point_indices_indx = np.random.choice(
|
301 |
+
np.arange(len(indices)), p=indices_probs
|
302 |
+
)
|
303 |
point_indices = indices[point_indices_indx][0]
|
304 |
else:
|
305 |
point_indices = indices
|
|
|
307 |
num_indices = len(point_indices)
|
308 |
if num_indices > 0:
|
309 |
point_indx = 0 if first_click else 100
|
310 |
+
click = point_indices[np.random.randint(0, num_indices)].tolist() + [
|
311 |
+
point_indx
|
312 |
+
]
|
313 |
points.append(click)
|
314 |
|
315 |
return points
|
|
|
319 |
return mask
|
320 |
|
321 |
kernel = np.ones((3, 3), np.uint8)
|
322 |
+
eroded_mask = cv2.erode(
|
323 |
+
mask.astype(np.uint8), kernel, iterations=self.positive_erode_iters
|
324 |
+
).astype(np.bool)
|
325 |
|
326 |
if eroded_mask.sum() > 10:
|
327 |
return eroded_mask
|
|
|
354 |
if full_prob > 0 and random.random() < full_prob:
|
355 |
return obj_mask
|
356 |
|
357 |
+
padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant")
|
358 |
|
359 |
dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
|
360 |
if k > 0:
|
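For reference, the reformatted constructor above keeps one keyword argument per line; a short construction sketch follows (the numeric values are illustrative, not taken from this diff).

from isegm.data.points_sampler import MultiPointSampler

# Illustrative hyperparameters; the sampler is normally passed to an ISDataset subclass.
points_sampler = MultiPointSampler(
    max_num_points=24,
    prob_gamma=0.80,
    merge_objects_prob=0.15,
    max_num_merged_objects=2,
)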
isegm/data/sample.py
CHANGED
(Pre-refactor rendering of this file, truncated in this export; formatting-only changes to the DSample class. The reformatted rendering follows.)
1 |
from copy import deepcopy
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
from albumentations import ReplayCompose
|
5 |
|
6 |
+
from isegm.data.transforms import remove_image_only_transforms
|
7 |
+
from isegm.utils.misc import get_labels_with_sizes
|
8 |
+
|
9 |
|
10 |
class DSample:
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
image,
|
14 |
+
encoded_masks,
|
15 |
+
objects=None,
|
16 |
+
objects_ids=None,
|
17 |
+
ignore_ids=None,
|
18 |
+
sample_id=None,
|
19 |
+
):
|
20 |
self.image = image
|
21 |
self.sample_id = sample_id
|
22 |
|
|
|
33 |
self._objects = dict()
|
34 |
for indx, obj_mapping in enumerate(objects_ids):
|
35 |
self._objects[indx] = {
|
36 |
+
"parent": None,
|
37 |
+
"mapping": obj_mapping,
|
38 |
+
"children": [],
|
39 |
}
|
40 |
|
41 |
if ignore_ids:
|
|
|
53 |
def augment(self, augmentator):
|
54 |
self.reset_augmentation()
|
55 |
aug_output = augmentator(image=self.image, mask=self._encoded_masks)
|
56 |
+
self.image = aug_output["image"]
|
57 |
+
self._encoded_masks = aug_output["mask"]
|
58 |
|
59 |
+
aug_replay = aug_output.get("replay", None)
|
60 |
if aug_replay:
|
61 |
assert len(self._ignored_regions) == 0
|
62 |
mask_replay = remove_image_only_transforms(aug_replay)
|
|
|
78 |
self._soft_mask_aug = None
|
79 |
|
80 |
def remove_small_objects(self, min_area):
|
81 |
+
if self._objects and not "area" in list(self._objects.values())[0]:
|
82 |
self._compute_objects_areas()
|
83 |
|
84 |
for obj_id, obj_info in list(self._objects.items()):
|
85 |
+
if obj_info["area"] < min_area:
|
86 |
self._remove_object(obj_id)
|
87 |
|
88 |
def get_object_mask(self, obj_id):
|
89 |
+
layer_indx, mask_id = self._objects[obj_id]["mapping"]
|
90 |
obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
|
91 |
if self._ignored_regions:
|
92 |
for layer_indx, mask_id in self._ignored_regions:
|
|
|
98 |
def get_soft_object_mask(self, obj_id):
|
99 |
assert self._soft_mask_aug is not None
|
100 |
original_encoded_masks = self._original_data[1]
|
101 |
+
layer_indx, mask_id = self._objects[obj_id]["mapping"]
|
102 |
+
obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
|
103 |
+
np.float32
|
104 |
+
)
|
105 |
+
obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
|
106 |
+
"image"
|
107 |
+
]
|
108 |
return np.clip(obj_mask, 0, 1)
|
109 |
|
110 |
def get_background_mask(self):
|
|
|
121 |
|
122 |
@property
|
123 |
def root_objects(self):
|
124 |
+
return [
|
125 |
+
obj_id
|
126 |
+
for obj_id, obj_info in self._objects.items()
|
127 |
+
if obj_info["parent"] is None
|
128 |
+
]
|
129 |
|
130 |
def _compute_objects_areas(self):
|
131 |
+
inverse_index = {
|
132 |
+
node["mapping"]: node_id for node_id, node in self._objects.items()
|
133 |
+
}
|
134 |
ignored_regions_keys = set(self._ignored_regions)
|
135 |
|
136 |
for layer_indx in range(self._encoded_masks.shape[2]):
|
137 |
+
objects_ids, objects_areas = get_labels_with_sizes(
|
138 |
+
self._encoded_masks[:, :, layer_indx]
|
139 |
+
)
|
140 |
for obj_id, obj_area in zip(objects_ids, objects_areas):
|
141 |
inv_key = (layer_indx, obj_id)
|
142 |
if inv_key in ignored_regions_keys:
|
143 |
continue
|
144 |
try:
|
145 |
+
self._objects[inverse_index[inv_key]]["area"] = obj_area
|
146 |
del inverse_index[inv_key]
|
147 |
except KeyError:
|
148 |
layer = self._encoded_masks[:, :, layer_indx]
|
|
|
150 |
self._encoded_masks[:, :, layer_indx] = layer
|
151 |
|
152 |
for obj_id in inverse_index.values():
|
153 |
+
self._objects[obj_id]["area"] = 0
|
154 |
|
155 |
def _remove_object(self, obj_id):
|
156 |
obj_info = self._objects[obj_id]
|
157 |
+
obj_parent = obj_info["parent"]
|
158 |
+
for child_id in obj_info["children"]:
|
159 |
+
self._objects[child_id]["parent"] = obj_parent
|
160 |
|
161 |
if obj_parent is not None:
|
162 |
+
parent_children = self._objects[obj_parent]["children"]
|
163 |
parent_children = [x for x in parent_children if x != obj_id]
|
164 |
+
self._objects[obj_parent]["children"] = (
|
165 |
+
parent_children + obj_info["children"]
|
166 |
+
)
|
167 |
|
168 |
del self._objects[obj_id]
|
169 |
|
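To make the string-keyed object records ("parent", "mapping", "children", "area") concrete, here is a small self-contained sketch; the toy image and mask are invented for illustration.

import numpy as np

from isegm.data.sample import DSample

# A dummy 64x64 image with one square object labelled 1 in a single mask layer.
image = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.int32)
mask[16:48, 16:48] = 1

sample = DSample(image, mask, objects_ids=[1], sample_id=0)
obj_mask = sample.get_object_mask(0)     # binary mask of the first (and only) object
bg_mask = sample.get_background_mask()
print(obj_mask.sum(), bg_mask.sum())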
isegm/data/transforms.py
CHANGED
(Pre-refactor rendering of this file, truncated in this export; imports consolidated and UniformRandomResize / ZoomIn rewrapped to black style. The reformatted rendering follows.)
1 |
import random
|
|
|
2 |
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
from albumentations import DualTransform, ImageOnlyTransform
|
6 |
+
from albumentations.augmentations import functional as F
|
7 |
from albumentations.core.serialization import SERIALIZABLE_REGISTRY
|
|
|
8 |
from albumentations.core.transforms_interface import to_tuple
|
9 |
+
|
10 |
+
from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
|
11 |
+
get_labels_with_sizes)
|
12 |
|
13 |
|
14 |
class UniformRandomResize(DualTransform):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
scale_range=(0.9, 1.1),
|
18 |
+
interpolation=cv2.INTER_LINEAR,
|
19 |
+
always_apply=False,
|
20 |
+
p=1,
|
21 |
+
):
|
22 |
super().__init__(always_apply, p)
|
23 |
self.scale_range = scale_range
|
24 |
self.interpolation = interpolation
|
25 |
|
26 |
def get_params_dependent_on_targets(self, params):
|
27 |
scale = random.uniform(*self.scale_range)
|
28 |
+
height = int(round(params["image"].shape[0] * scale))
|
29 |
+
width = int(round(params["image"].shape[1] * scale))
|
30 |
+
return {"new_height": height, "new_width": width}
|
31 |
|
32 |
+
def apply(
|
33 |
+
self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params
|
34 |
+
):
|
35 |
+
return F.resize(
|
36 |
+
img, height=new_height, width=new_width, interpolation=interpolation
|
37 |
+
)
|
38 |
|
39 |
def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params):
|
40 |
scale_x = new_width / params["cols"]
|
|
|
51 |
|
52 |
class ZoomIn(DualTransform):
|
53 |
def __init__(
|
54 |
+
self,
|
55 |
+
height,
|
56 |
+
width,
|
57 |
+
bbox_jitter=0.1,
|
58 |
+
expansion_ratio=1.4,
|
59 |
+
min_crop_size=200,
|
60 |
+
min_area=100,
|
61 |
+
always_resize=False,
|
62 |
+
always_apply=False,
|
63 |
+
p=0.5,
|
64 |
):
|
65 |
super(ZoomIn, self).__init__(always_apply, p)
|
66 |
self.height = height
|
|
|
78 |
return img
|
79 |
|
80 |
rmin, rmax, cmin, cmax = bbox
|
81 |
+
img = img[rmin : rmax + 1, cmin : cmax + 1]
|
82 |
img = F.resize(img, height=self.height, width=self.width)
|
83 |
|
84 |
return img
|
|
|
86 |
def apply_to_mask(self, mask, selected_object, bbox, **params):
|
87 |
if selected_object is None:
|
88 |
if self.always_resize:
|
89 |
+
mask = F.resize(
|
90 |
+
mask,
|
91 |
+
height=self.height,
|
92 |
+
width=self.width,
|
93 |
+
interpolation=cv2.INTER_NEAREST,
|
94 |
+
)
|
95 |
return mask
|
96 |
|
97 |
rmin, rmax, cmin, cmax = bbox
|
98 |
+
mask = mask[rmin : rmax + 1, cmin : cmax + 1]
|
99 |
if isinstance(selected_object, tuple):
|
100 |
layer_indx, mask_id = selected_object
|
101 |
obj_mask = mask[:, :, layer_indx] == mask_id
|
|
|
106 |
new_mask = mask.copy()
|
107 |
new_mask[np.logical_not(obj_mask)] = 0
|
108 |
|
109 |
+
new_mask = F.resize(
|
110 |
+
new_mask,
|
111 |
+
height=self.height,
|
112 |
+
width=self.width,
|
113 |
+
interpolation=cv2.INTER_NEAREST,
|
114 |
+
)
|
115 |
return new_mask
|
116 |
|
117 |
def get_params_dependent_on_targets(self, params):
|
118 |
+
instances = params["mask"]
|
119 |
|
120 |
is_mask_layer = len(instances.shape) > 2
|
121 |
candidates = []
|
122 |
if is_mask_layer:
|
123 |
for layer_indx in range(instances.shape[2]):
|
124 |
labels, areas = get_labels_with_sizes(instances[:, :, layer_indx])
|
125 |
+
candidates.extend(
|
126 |
+
[
|
127 |
+
(layer_indx, obj_id)
|
128 |
+
for obj_id, area in zip(labels, areas)
|
129 |
+
if area > self.min_area
|
130 |
+
]
|
131 |
+
)
|
132 |
else:
|
133 |
labels, areas = get_labels_with_sizes(instances)
|
134 |
+
candidates = [
|
135 |
+
obj_id for obj_id, area in zip(labels, areas) if area > self.min_area
|
136 |
+
]
|
137 |
|
138 |
selected_object = None
|
139 |
bbox = None
|
|
|
156 |
bbox = self._jitter_bbox(bbox)
|
157 |
bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1)
|
158 |
|
159 |
+
return {"selected_object": selected_object, "bbox": bbox}
|
|
|
|
|
|
|
160 |
|
161 |
def _jitter_bbox(self, bbox):
|
162 |
rmin, rmax, cmin, cmax = bbox
|
|
|
180 |
return ["mask"]
|
181 |
|
182 |
def get_transform_init_args_names(self):
|
183 |
+
return (
|
184 |
+
"height",
|
185 |
+
"width",
|
186 |
+
"bbox_jitter",
|
187 |
+
"expansion_ratio",
|
188 |
+
"min_crop_size",
|
189 |
+
"min_area",
|
190 |
+
"always_resize",
|
191 |
+
)
|
192 |
|
193 |
|
194 |
def remove_image_only_transforms(sdict):
|
195 |
+
if not "transforms" in sdict:
|
196 |
return sdict
|
197 |
|
198 |
keep_transforms = []
|
199 |
+
for tdict in sdict["transforms"]:
|
200 |
+
cls = SERIALIZABLE_REGISTRY[tdict["__class_fullname__"]]
|
201 |
+
if "transforms" in tdict:
|
202 |
keep_transforms.append(remove_image_only_transforms(tdict))
|
203 |
elif not issubclass(cls, ImageOnlyTransform):
|
204 |
keep_transforms.append(tdict)
|
205 |
+
sdict["transforms"] = keep_transforms
|
206 |
|
207 |
return sdict
|
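Both transforms are albumentations DualTransforms and compose with the rest of a pipeline; a minimal sketch with illustrative sizes follows (the 320x480 target and the probabilities are assumptions, not values set by this diff).

from albumentations import Compose

from isegm.data.transforms import UniformRandomResize, ZoomIn

# Illustrative augmentation pipeline for training crops.
augmentator = Compose([
    UniformRandomResize(scale_range=(0.75, 1.40)),
    ZoomIn(height=320, width=480, p=0.5),
])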
isegm/engine/optimizer.py
CHANGED
(Pre-refactor rendering of this file, truncated in this export; formatting-only changes to get_optimizer. The reformatted file follows.)
import math

import torch

from isegm.utils.log import logger


def get_optimizer(model, opt_name, opt_kwargs):
    params = []
    base_lr = opt_kwargs["lr"]
    for name, param in model.named_parameters():
        param_group = {"params": [param]}
        if not param.requires_grad:
            params.append(param_group)
            continue

        if not math.isclose(getattr(param, "lr_mult", 1.0), 1.0):
            logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
            param_group["lr"] = param_group.get("lr", base_lr) * param.lr_mult

        params.append(param_group)

    optimizer = {
        "sgd": torch.optim.SGD,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }[opt_name.lower()](params, **opt_kwargs)

    return optimizer
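The per-parameter lr_mult handling above can be exercised directly; a short sketch follows (the toy model and the 0.1 multiplier are illustrative).

import torch.nn as nn

from isegm.engine.optimizer import get_optimizer

model = nn.Linear(8, 2)
# get_optimizer reads an optional lr_mult attribute from each parameter.
model.weight.lr_mult = 0.1
optim = get_optimizer(model, "adam", {"lr": 5e-4, "betas": (0.9, 0.999), "eps": 1e-8})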
isegm/engine/trainer.py
CHANGED
(Pre-refactor rendering of this file, truncated in this export; import reorganisation and black-style rewrapping of ISTrainer, get_next_points, and load_weights. The reformatted rendering follows.)
1 |
+
import logging
|
2 |
import os
|
3 |
import random
|
|
|
|
|
4 |
from collections import defaultdict
|
5 |
+
from copy import deepcopy
|
6 |
|
7 |
import cv2
|
|
|
8 |
import numpy as np
|
9 |
+
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
+
from tqdm import tqdm
|
12 |
|
13 |
+
from isegm.utils.distributed import (get_dp_wrapper, get_sampler,
|
14 |
+
reduce_loss_dict)
|
15 |
+
from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger
|
16 |
from isegm.utils.misc import save_checkpoint
|
17 |
from isegm.utils.serialization import get_config_repr
|
18 |
+
from isegm.utils.vis import draw_points, draw_probmap
|
19 |
+
|
20 |
from .optimizer import get_optimizer
|
21 |
|
22 |
|
23 |
class ISTrainer(object):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
model,
|
27 |
+
cfg,
|
28 |
+
model_cfg,
|
29 |
+
loss_cfg,
|
30 |
+
trainset,
|
31 |
+
valset,
|
32 |
+
optimizer="adam",
|
33 |
+
optimizer_params=None,
|
34 |
+
image_dump_interval=200,
|
35 |
+
checkpoint_interval=10,
|
36 |
+
tb_dump_period=25,
|
37 |
+
max_interactive_points=0,
|
38 |
+
lr_scheduler=None,
|
39 |
+
metrics=None,
|
40 |
+
additional_val_metrics=None,
|
41 |
+
net_inputs=("images", "points"),
|
42 |
+
max_num_next_clicks=0,
|
43 |
+
click_models=None,
|
44 |
+
prev_mask_drop_prob=0.0,
|
45 |
+
):
|
46 |
self.cfg = cfg
|
47 |
self.model_cfg = model_cfg
|
48 |
self.max_interactive_points = max_interactive_points
|
|
|
68 |
|
69 |
self.checkpoint_interval = checkpoint_interval
|
70 |
self.image_dump_interval = image_dump_interval
|
71 |
+
self.task_prefix = ""
|
72 |
self.sw = None
|
73 |
|
74 |
self.trainset = trainset
|
75 |
self.valset = valset
|
76 |
|
77 |
+
logger.info(
|
78 |
+
f"Dataset of {trainset.get_samples_number()} samples was loaded for training."
|
79 |
+
)
|
80 |
+
logger.info(
|
81 |
+
f"Dataset of {valset.get_samples_number()} samples was loaded for validation."
|
82 |
+
)
|
83 |
|
84 |
self.train_data = DataLoader(
|
85 |
+
trainset,
|
86 |
+
cfg.batch_size,
|
87 |
sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
|
88 |
+
drop_last=True,
|
89 |
+
pin_memory=True,
|
90 |
+
num_workers=cfg.workers,
|
91 |
)
|
92 |
|
93 |
self.val_data = DataLoader(
|
94 |
+
valset,
|
95 |
+
cfg.val_batch_size,
|
96 |
sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
|
97 |
+
drop_last=True,
|
98 |
+
pin_memory=True,
|
99 |
+
num_workers=cfg.workers,
|
100 |
)
|
101 |
|
102 |
self.optim = get_optimizer(model, optimizer, optimizer_params)
|
103 |
model = self._load_weights(model)
|
104 |
|
105 |
if cfg.multi_gpu:
|
106 |
+
model = get_dp_wrapper(cfg.distributed)(
|
107 |
+
model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0]
|
108 |
+
)
|
109 |
|
110 |
if self.is_master:
|
111 |
logger.info(model)
|
|
|
113 |
|
114 |
self.device = cfg.device
|
115 |
self.net = model.to(self.device)
|
116 |
+
self.lr = optimizer_params["lr"]
|
117 |
|
118 |
if lr_scheduler is not None:
|
119 |
self.lr_scheduler = lr_scheduler(optimizer=self.optim)
|
|
|
134 |
if start_epoch is None:
|
135 |
start_epoch = self.cfg.start_epoch
|
136 |
|
137 |
+
logger.info(f"Starting Epoch: {start_epoch}")
|
138 |
+
logger.info(f"Total Epochs: {num_epochs}")
|
139 |
for epoch in range(start_epoch, num_epochs):
|
140 |
self.training(epoch)
|
141 |
if validation:
|
|
|
143 |
|
144 |
def training(self, epoch):
|
145 |
if self.sw is None and self.is_master:
|
146 |
+
self.sw = SummaryWriterAvg(
|
147 |
+
log_dir=str(self.cfg.LOGS_PATH),
|
148 |
+
flush_secs=10,
|
149 |
+
dump_period=self.tb_dump_period,
|
150 |
+
)
|
151 |
|
152 |
if self.cfg.distributed:
|
153 |
self.train_data.sampler.set_epoch(epoch)
|
154 |
|
155 |
+
log_prefix = "Train" + self.task_prefix.capitalize()
|
156 |
+
tbar = (
|
157 |
+
tqdm(self.train_data, file=self.tqdm_out, ncols=100)
|
158 |
+
if self.is_master
|
159 |
+
else self.train_data
|
160 |
+
)
|
161 |
|
162 |
for metric in self.train_metrics:
|
163 |
metric.reset_epoch_stats()
|
|
|
167 |
for i, batch_data in enumerate(tbar):
|
168 |
global_step = epoch * len(self.train_data) + i
|
169 |
|
170 |
+
loss, losses_logging, splitted_batch_data, outputs = self.batch_forward(
|
171 |
+
batch_data
|
172 |
+
)
|
173 |
|
174 |
self.optim.zero_grad()
|
175 |
loss.backward()
|
176 |
self.optim.step()
|
177 |
|
178 |
+
losses_logging["overall"] = loss
|
179 |
reduce_loss_dict(losses_logging)
|
180 |
|
181 |
+
train_loss += losses_logging["overall"].item()
|
182 |
|
183 |
if self.is_master:
|
184 |
for loss_name, loss_value in losses_logging.items():
|
185 |
+
self.sw.add_scalar(
|
186 |
+
tag=f"{log_prefix}Losses/{loss_name}",
|
187 |
+
value=loss_value.item(),
|
188 |
+
global_step=global_step,
|
189 |
+
)
|
190 |
|
191 |
for k, v in self.loss_cfg.items():
|
192 |
+
if (
|
193 |
+
"_loss" in k
|
194 |
+
and hasattr(v, "log_states")
|
195 |
+
and self.loss_cfg.get(k + "_weight", 0.0) > 0
|
196 |
+
):
|
197 |
+
v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step)
|
198 |
+
|
199 |
+
if (
|
200 |
+
self.image_dump_interval > 0
|
201 |
+
and global_step % self.image_dump_interval == 0
|
202 |
+
):
|
203 |
+
self.save_visualization(
|
204 |
+
splitted_batch_data, outputs, global_step, prefix="train"
|
205 |
+
)
|
206 |
+
|
207 |
+
self.sw.add_scalar(
|
208 |
+
tag=f"{log_prefix}States/learning_rate",
|
209 |
+
value=self.lr
|
210 |
+
if not hasattr(self, "lr_scheduler")
|
211 |
+
else self.lr_scheduler.get_lr()[-1],
|
212 |
+
global_step=global_step,
|
213 |
+
)
|
214 |
+
|
215 |
+
tbar.set_description(
|
216 |
+
f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}"
|
217 |
+
)
|
218 |
for metric in self.train_metrics:
|
219 |
+
metric.log_states(
|
220 |
+
self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
|
221 |
+
)
|
222 |
|
223 |
if self.is_master:
|
224 |
for metric in self.train_metrics:
|
225 |
+
self.sw.add_scalar(
|
226 |
+
tag=f"{log_prefix}Metrics/{metric.name}",
|
227 |
+
value=metric.get_epoch_value(),
|
228 |
+
global_step=epoch,
|
229 |
+
disable_avg=True,
|
230 |
+
)
|
231 |
+
|
232 |
+
save_checkpoint(
|
233 |
+
self.net,
|
234 |
+
self.cfg.CHECKPOINTS_PATH,
|
235 |
+
prefix=self.task_prefix,
|
236 |
+
epoch=None,
|
237 |
+
multi_gpu=self.cfg.multi_gpu,
|
238 |
+
)
|
239 |
|
240 |
if isinstance(self.checkpoint_interval, (list, tuple)):
|
241 |
+
checkpoint_interval = [
|
242 |
+
x for x in self.checkpoint_interval if x[0] <= epoch
|
243 |
+
][-1][1]
|
244 |
else:
|
245 |
checkpoint_interval = self.checkpoint_interval
|
246 |
|
247 |
if epoch % checkpoint_interval == 0:
|
248 |
+
save_checkpoint(
|
249 |
+
self.net,
|
250 |
+
self.cfg.CHECKPOINTS_PATH,
|
251 |
+
prefix=self.task_prefix,
|
252 |
+
epoch=epoch,
|
253 |
+
multi_gpu=self.cfg.multi_gpu,
|
254 |
+
)
|
255 |
+
|
256 |
+
if hasattr(self, "lr_scheduler"):
|
257 |
self.lr_scheduler.step()
|
258 |
|
259 |
def validation(self, epoch):
|
260 |
if self.sw is None and self.is_master:
|
261 |
+
self.sw = SummaryWriterAvg(
|
262 |
+
log_dir=str(self.cfg.LOGS_PATH),
|
263 |
+
flush_secs=10,
|
264 |
+
dump_period=self.tb_dump_period,
|
265 |
+
)
|
266 |
+
|
267 |
+
log_prefix = "Val" + self.task_prefix.capitalize()
|
268 |
+
tbar = (
|
269 |
+
tqdm(self.val_data, file=self.tqdm_out, ncols=100)
|
270 |
+
if self.is_master
|
271 |
+
else self.val_data
|
272 |
+
)
|
273 |
|
274 |
for metric in self.val_metrics:
|
275 |
metric.reset_epoch_stats()
|
|
|
280 |
self.net.eval()
|
281 |
for i, batch_data in enumerate(tbar):
|
282 |
global_step = epoch * len(self.val_data) + i
|
283 |
+
(
|
284 |
+
loss,
|
285 |
+
batch_losses_logging,
|
286 |
+
splitted_batch_data,
|
287 |
+
outputs,
|
288 |
+
) = self.batch_forward(batch_data, validation=True)
|
289 |
+
|
290 |
+
batch_losses_logging["overall"] = loss
|
291 |
reduce_loss_dict(batch_losses_logging)
|
292 |
for loss_name, loss_value in batch_losses_logging.items():
|
293 |
losses_logging[loss_name].append(loss_value.item())
|
294 |
|
295 |
+
val_loss += batch_losses_logging["overall"].item()
|
296 |
|
297 |
if self.is_master:
|
298 |
+
tbar.set_description(
|
299 |
+
f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}"
|
300 |
+
)
|
301 |
for metric in self.val_metrics:
|
302 |
+
metric.log_states(
|
303 |
+
                        self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
                    )

        if self.is_master:
            for loss_name, loss_values in losses_logging.items():
                self.sw.add_scalar(
                    tag=f"{log_prefix}Losses/{loss_name}",
                    value=np.array(loss_values).mean(),
                    global_step=epoch,
                    disable_avg=True,
                )

            for metric in self.val_metrics:
                self.sw.add_scalar(
                    tag=f"{log_prefix}Metrics/{metric.name}",
                    value=metric.get_epoch_value(),
                    global_step=epoch,
                    disable_avg=True,
                )

    def batch_forward(self, batch_data, validation=False):
        metrics = self.val_metrics if validation else self.train_metrics

        with torch.set_grad_enabled(not validation):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            image, gt_mask, points = (
                batch_data["images"],
                batch_data["instances"],
                batch_data["points"],
            )
            orig_image, orig_gt_mask, orig_points = (
                image.clone(),
                gt_mask.clone(),
                points.clone(),
            )

            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]

                    if not validation:
                        self.net.eval()

                    if self.click_models is None or click_indx >= len(
                        self.click_models
                    ):
                        eval_model = self.net
                    else:
                        eval_model = self.click_models[click_indx]

                    net_input = (
                        torch.cat((image, prev_output), dim=1)
                        if self.net.with_prev_mask
                        else image
                    )
                    prev_output = torch.sigmoid(
                        eval_model(net_input, points)["instances"]
                    )

                    points = get_next_points(
                        prev_output, orig_gt_mask, points, click_indx + 1
                    )

                    if not validation:
                        self.net.train()

                if (
                    self.net.with_prev_mask
                    and self.prev_mask_drop_prob > 0
                    and last_click_indx is not None
                ):
                    zero_mask = (
                        np.random.random(size=prev_output.size(0))
                        < self.prev_mask_drop_prob
                    )
                    prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])

            batch_data["points"] = points

            net_input = (
                torch.cat((image, prev_output), dim=1)
                if self.net.with_prev_mask
                else image
            )
            output = self.net(net_input, points)

            loss = 0.0
            loss = self.add_loss(
                "instance_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances"], batch_data["instances"]),
            )
            loss = self.add_loss(
                "instance_aux_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances_aux"], batch_data["instances"]),
            )

            if self.is_master:
                with torch.no_grad():
                    for m in metrics:
                        m.update(
                            *(output.get(x) for x in m.pred_outputs),
                            *(batch_data[x] for x in m.gt_outputs),
                        )
        return loss, losses_logging, batch_data, output

    def add_loss(
        self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs
    ):
        loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
        loss_weight = loss_cfg.get(loss_name + "_weight", 0.0)
        if loss_weight > 0.0:
            loss_criterion = loss_cfg.get(loss_name)
            loss = loss_criterion(*lambda_loss_inputs())

        if not output_images_path.exists():
            output_images_path.mkdir(parents=True)
        image_name_prefix = f"{global_step:06d}"

        def _save_image(suffix, image):
            cv2.imwrite(
                str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
                image,
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )

        images = splitted_batch_data["images"]
        points = splitted_batch_data["points"]
        instance_masks = splitted_batch_data["instances"]

        gt_instance_masks = instance_masks.cpu().numpy()
        predicted_instance_masks = (
            torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
        )
        points = points.detach().cpu().numpy()

        image_blob, points = images[0], points[0]

        image = image_blob.cpu().numpy() * 255
        image = image.transpose((1, 2, 0))

        image_with_points = draw_points(
            image, points[: self.max_interactive_points], (0, 255, 0)
        )
        image_with_points = draw_points(
            image_with_points, points[self.max_interactive_points :], (0, 0, 255)
        )

        gt_mask[gt_mask < 0] = 0.25
        gt_mask = draw_probmap(gt_mask)
        predicted_mask = draw_probmap(predicted_mask)
        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(
            np.uint8
        )

        _save_image("instance_segmentation", viz_image[:, :, ::-1])

    def _load_weights(self, net):
        if self.cfg.weights is not None:
            else:
                raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
        elif self.cfg.resume_exp is not None:
            checkpoints = list(
                self.cfg.CHECKPOINTS_PATH.glob(f"{self.cfg.resume_prefix}*.pth")
            )
            assert len(checkpoints) == 1

            checkpoint_path = checkpoints[0]
            logger.info(f"Load checkpoint from path: {checkpoint_path}")
            load_weights(net, str(checkpoint_path))
        return net

    fn_mask = np.logical_and(gt, pred < pred_thresh)
    fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)

    fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
    fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)
    num_points = points.size(1) // 2
    points = points.clone()


def load_weights(model, path_to_weights):
    current_state_dict = model.state_dict()
    new_state_dict = torch.load(path_to_weights, map_location="cpu")["state_dict"]
    current_state_dict.update(new_state_dict)
    model.load_state_dict(current_state_dict)
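The click-simulation step in batch_forward relies on get_next_points placing the next click inside the largest error region of the previous prediction. Below is a minimal single-image sketch of that sampling rule using NumPy and OpenCV; the function name sample_next_click and the array arguments are illustrative, not the trainer's own API.

import cv2
import numpy as np

def sample_next_click(pred, gt, pred_thresh=0.49):
    # Error regions of the previous prediction.
    fn_mask = np.logical_and(gt, pred < pred_thresh)
    fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)

    # Pad so the distance transform treats the image border as background.
    fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant").astype(np.uint8)
    fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant").astype(np.uint8)

    fn_dt = cv2.distanceTransform(fn_mask, cv2.DIST_L2, 0)[1:-1, 1:-1]
    fp_dt = cv2.distanceTransform(fp_mask, cv2.DIST_L2, 0)[1:-1, 1:-1]

    # Click deep inside the larger error region: a positive click for a
    # missed region, a negative click for a false-alarm region.
    is_positive = fn_dt.max() > fp_dt.max()
    dt = fn_dt if is_positive else fp_dt
    y, x = np.unravel_index(dt.argmax(), dt.shape)
    return (y, x), is_positive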
isegm/inference/clicker.py
CHANGED
from copy import deepcopy

import cv2
import numpy as np


class Clicker(object):
    def __init__(
        self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0
    ):
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1

        return self.clicks_list[:clicks_limit]

    def _get_next_click(self, pred_mask, padding=True):
        fn_mask = np.logical_and(
            np.logical_and(self.gt_mask, np.logical_not(pred_mask)),
            self.not_ignore_mask,
        )
        fp_mask = np.logical_and(
            np.logical_and(np.logical_not(self.gt_mask), pred_mask),
            self.not_ignore_mask,
        )

        if padding:
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")

        fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
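For reference, the Clicker above can be driven directly to generate simulated clicks against a ground-truth mask. The sketch below assumes the class keeps its usual make_next_click / get_clicks interface (only the constructor and _get_next_click are visible in this diff), so treat it as an illustration rather than a guaranteed API.

import numpy as np
from isegm.inference.clicker import Clicker

# Toy ground truth: a filled square; the prediction starts empty.
gt_mask = np.zeros((100, 100), dtype=np.int32)
gt_mask[30:70, 30:70] = 1

clicker = Clicker(gt_mask=gt_mask)
pred_mask = np.zeros_like(gt_mask, dtype=bool)

# With an empty prediction the first simulated click lands near the centre
# of the object (the false-negative pixel farthest from the boundary).
clicker.make_next_click(pred_mask)
for click in clicker.get_clicks():
    print(click.is_positive, click.coords)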
isegm/inference/evaluation.py
CHANGED
def evaluate_dataset(dataset, predictor, **kwargs):
    for index in tqdm(range(len(dataset)), leave=False):
        sample = dataset.get_sample(index)

        _, sample_ious, _ = evaluate_sample(
            sample.image, sample.gt_mask, predictor, sample_id=index, **kwargs
        )
        all_ious.append(sample_ious)
    end_time = time()
    elapsed_time = end_time - start_time

    return all_ious, elapsed_time


def evaluate_sample(
    image,
    gt_mask,
    predictor,
    max_iou_thr,
    pred_thr=0.49,
    min_clicks=1,
    max_clicks=20,
    sample_id=None,
    callback=None,
):
    clicker = Clicker(gt_mask=gt_mask)
    pred_mask = np.zeros_like(gt_mask)
    ious_list = []

        pred_mask = pred_probs > pred_thr

        if callback is not None:
            callback(
                image,
                gt_mask,
                pred_probs,
                sample_id,
                click_indx,
                clicker.clicks_list,
            )

        iou = utils.get_iou(gt_mask, pred_mask)
        ious_list.append(iou)
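A typical way to exercise evaluate_sample is shown below. The checkpoint filename is a placeholder and the image/gt_mask arrays are assumed to come from one of the datasets above; the three-value unpacking mirrors how evaluate_dataset itself consumes the result.

import torch
from isegm.inference import utils
from isegm.inference.evaluation import evaluate_sample
from isegm.inference.predictors import get_predictor

device = torch.device("cpu")
# "model.pth" is a hypothetical checkpoint path.
model = utils.load_is_model("model.pth", device, cpu_dist_maps=True)
predictor = get_predictor(model, brs_mode="NoBRS", device=device)

# image: HxWx3 uint8 array, gt_mask: HxW array with 1 on the object.
_, ious, _ = evaluate_sample(image, gt_mask, predictor, max_iou_thr=0.95, max_clicks=20)
print("IoU after each simulated click:", ious)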
isegm/inference/predictors/__init__.py
CHANGED
from isegm.inference.transforms import ZoomIn
from isegm.model.is_hrnet_model import HRNetModel

from .base import BasePredictor
from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor,
                  InputBRSPredictor)
from .brs_functors import InputOptimizer, ScaleBiasOptimizer


def get_predictor(
    net,
    brs_mode,
    device,
    prob_thresh=0.49,
    with_flip=True,
    zoom_in_params=dict(),
    predictor_params=None,
    brs_opt_func_params=None,
    lbfgs_params=None,
):
    lbfgs_params_ = {
        "m": 20,
        "factr": 0,
        "pgtol": 1e-8,
        "maxfun": 20,
    }

    predictor_params_ = {"optimize_after_n_clicks": 1}

    if zoom_in_params is not None:
        zoom_in = ZoomIn(**zoom_in_params)

    if lbfgs_params is not None:
        lbfgs_params_.update(lbfgs_params)
    lbfgs_params_["maxiter"] = 2 * lbfgs_params_["maxfun"]

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    if isinstance(net, (list, tuple)):
        assert brs_mode == "NoBRS", "Multi-stage models support only NoBRS mode."

    if brs_mode == "NoBRS":
        if predictor_params is not None:
            predictor_params_.update(predictor_params)
        predictor = BasePredictor(
            net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_
        )
    elif brs_mode.startswith("f-BRS"):
        predictor_params_.update(
            {
                "net_clicks_limit": 8,
            }
        )
        if predictor_params is not None:
            predictor_params_.update(predictor_params)

        insertion_mode = {
            "f-BRS-A": "after_c4",
            "f-BRS-B": "after_aspp",
            "f-BRS-C": "after_deeplab",
        }[brs_mode]

        opt_functor = ScaleBiasOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=lbfgs_params_,
            **brs_opt_func_params
        )

        if isinstance(net, HRNetModel):
            FeaturePredictor = HRNetFeatureBRSPredictor
            insertion_mode = {"after_c4": "A", "after_aspp": "A", "after_deeplab": "C"}[
                insertion_mode
            ]
        else:
            FeaturePredictor = FeatureBRSPredictor

        predictor = FeaturePredictor(
            net,
            device,
            opt_functor=opt_functor,
            with_flip=with_flip,
            insertion_mode=insertion_mode,
            zoom_in=zoom_in,
            **predictor_params_
        )
    elif brs_mode == "RGB-BRS" or brs_mode == "DistMap-BRS":
        use_dmaps = brs_mode == "DistMap-BRS"

        predictor_params_.update(
            {
                "net_clicks_limit": 5,
            }
        )
        if predictor_params is not None:
            predictor_params_.update(predictor_params)

        opt_functor = InputOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=lbfgs_params_,
            **brs_opt_func_params
        )

        predictor = InputBRSPredictor(
            net,
            device,
            optimize_target="dmaps" if use_dmaps else "rgb",
            opt_functor=opt_functor,
            with_flip=with_flip,
            zoom_in=zoom_in,
            **predictor_params_
        )
    else:
        raise NotImplementedError
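get_predictor is the single entry point that selects between plain inference ("NoBRS"), feature-level backpropagating refinement ("f-BRS-A/B/C") and input-level refinement ("RGB-BRS", "DistMap-BRS"). A short usage sketch, with a hypothetical checkpoint path and illustrative parameter values:

import torch
from isegm.inference import utils
from isegm.inference.predictors import get_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = utils.load_is_model("weights.pth", device)  # hypothetical checkpoint

# "f-BRS-B" optimises a per-click scale/bias on intermediate features,
# while "NoBRS" just runs the network; ZoomIn crops around the prediction.
predictor = get_predictor(
    model,
    brs_mode="f-BRS-B",
    device=device,
    prob_thresh=0.49,
    zoom_in_params={"skip_clicks": 1, "target_size": 600},
    lbfgs_params={"maxfun": 20},
)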
isegm/inference/predictors/base.py
CHANGED
import torch
import torch.nn.functional as F
from torchvision import transforms

from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide,
                                         SigmoidForPred)


class BasePredictor(object):
    def __init__(
        self,
        model,
        device,
        net_clicks_limit=None,
        with_flip=False,
        zoom_in=None,
        max_size=None,
        **kwargs
    ):
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None

        clicks_list = clicker.get_clicks()

        if self.click_models is not None:
            model_indx = (
                min(
                    clicker.click_indx_offset + len(clicks_list), len(self.click_models)
                )
                - 1
            )
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
            input_image = torch.cat((input_image, prev_mask), dim=1)
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )

        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
        prediction = F.interpolate(
            pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
        )

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        return self.net(image_nd, points_nd)["instances"]

    def _get_transform_states(self):
        return [x.get_state() for x in self.transforms]

    def get_points_nd(self, clicks_lists):
        total_clicks = []
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[: self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list if click.is_positive
            ]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
                (-1, -1, -1)
            ]

            neg_clicks = [
                click.coords_and_indx for click in clicks_list if not click.is_positive
            ]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
                (-1, -1, -1)
            ]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_states(self):
        return {
            "transform_states": self._get_transform_states(),
            "prev_prediction": self.prev_prediction.clone(),
        }

    def set_states(self, states):
        self._set_transform_states(states["transform_states"])
        self.prev_prediction = states["prev_prediction"]
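get_points_nd packs a variable number of clicks into a fixed-size tensor: positive clicks first, then negative clicks, each group padded with (-1, -1, -1) up to the same length. A standalone sketch of that padding scheme, assuming clicks are (row, col, click_index) tuples (the helper name is illustrative):

def pack_clicks(pos_clicks, neg_clicks, num_max_points):
    # Pad each group to num_max_points with the sentinel (-1, -1, -1),
    # then concatenate positives followed by negatives.
    pad = (-1, -1, -1)
    pos = pos_clicks + (num_max_points - len(pos_clicks)) * [pad]
    neg = neg_clicks + (num_max_points - len(neg_clicks)) * [pad]
    return pos + neg

print(pack_clicks([(10, 12, 0)], [], 2))
# [(10, 12, 0), (-1, -1, -1), (-1, -1, -1), (-1, -1, -1)]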
isegm/inference/predictors/brs.py
CHANGED
import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import fmin_l_bfgs_b

from .base import BasePredictor


class BRSBasePredictor(BasePredictor):
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        pos_clicks_map = np.zeros(
            (len(clicks_lists), 1) + image_shape, dtype=np.float32
        )
        neg_clicks_map = np.zeros(
            (len(clicks_lists), 1) + image_shape, dtype=np.float32
        )

        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        return {
            "transform_states": self._get_transform_states(),
            "opt_data": self.opt_data,
        }

    def set_states(self, states):
        self._set_transform_states(states["transform_states"])
        self.opt_data = states["opt_data"]


class FeatureBRSPredictor(BRSBasePredictor):
    def __init__(
        self, model, device, opt_functor, insertion_mode="after_deeplab", **kwargs
    ):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == "after_deeplab":
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == "after_c4":
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == "after_aspp":
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "after_c4":
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(
                    x,
                    mode="bilinear",
                    size=self._c1_features.size()[2:],
                    align_corners=True,
                )
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == "after_aspp":
                scaled_backbone_features = self.net.feature_extractor.head(
                    scaled_backbone_features
                )

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            self.opt_data = opt_result[0]

        with torch.no_grad():

        if self.net.rgb_conv is not None:
            x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
            additional_features = None
        elif hasattr(self.net, "maps_transform"):
            x = image_nd
            additional_features = self.net.maps_transform(coord_features)

        if self.insertion_mode == "after_c4" or self.insertion_mode == "after_aspp":
            c1, _, c3, c4 = self.net.feature_extractor.backbone(
                x, additional_features
            )
            c1 = self.net.feature_extractor.skip_project(c1)

            if self.insertion_mode == "after_aspp":
                x = self.net.feature_extractor.aspp(c4)
                x = F.interpolate(
                    x, size=c1.size()[2:], mode="bilinear", align_corners=True
                )
                x = torch.cat((x, c1), dim=1)
                backbone_features = x
            else:
                backbone_features = c4
            self._c1_features = c1
        else:
            backbone_features = self.net.feature_extractor(x, additional_features)[
                0
            ]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, insertion_mode="A", **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == "A":
            self.num_channels = sum(
                k * model.feature_extractor.width for k in [1, 2, 4, 8]
            )
        elif self.insertion_mode == "C":
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "A":
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(
                        scaled_backbone_features
                    )
                    feats = self.net.feature_extractor.conv3x3_ocr(
                        scaled_backbone_features
                    )

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == "C":
                pred_logits = self.net.feature_extractor.cls_head(
                    scaled_backbone_features
                )
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            self.opt_data = opt_result[0]

        with torch.no_grad():

        if self.net.rgb_conv is not None:
            x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
            additional_features = None
        elif hasattr(self.net, "maps_transform"):
            x = image_nd
            additional_features = self.net.maps_transform(coord_features)

        feats = self.net.feature_extractor.compute_hrnet_feats(
            x, additional_features
        )

        if self.insertion_mode == "A":
            backbone_features = feats
        elif self.insertion_mode == "C":
            out_aux = self.net.feature_extractor.aux_head(feats)
            feats = self.net.feature_extractor.conv3x3_ocr(feats)

            context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
            backbone_features = self.net.feature_extractor.ocr_distri_head(
                feats, context
            )
        else:
            raise NotImplementedError


class InputBRSPredictor(BRSBasePredictor):
    def __init__(self, model, device, opt_functor, optimize_target="rgb", **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

        num_clicks = len(clicks_lists[0])

        if self.opt_data is None or is_image_changed:
            if self.optimize_target == "dmaps":
                opt_channels = (
                    self.net.coord_feature_ch - 1
                    if self.net.with_prev_mask
                    else self.net.coord_feature_ch
                )
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros(
                (bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                device=self.device,
                dtype=torch.float32,
            )

        def get_prediction_logits(opt_bias):
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == "rgb":
                input_image = input_image + opt_bias
            elif self.optimize_target == "dmaps":
                if self.net.with_prev_mask:
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:

            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == "all":
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, "maps_transform"):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)[
                "instances"
            ]
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )

            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits,
            pos_mask,
            neg_mask,
            self.device,
            shape=self.opt_data.shape,
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data.cpu().numpy().ravel(),
                **self.opt_functor.optimizer_params
            )

            self.opt_data = (
                torch.from_numpy(opt_result[0])
                .view(self.opt_data.shape)
                .to(self.device)
            )

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
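All three BRS predictors share the same optimization pattern: a closure maps a flat parameter vector to prediction logits, the opt_functor turns that into a (loss, gradient) pair, and scipy's fmin_l_bfgs_b minimizes it with the parameters prepared in get_predictor (m, factr, pgtol, maxfun). A toy quadratic stand-in for the functor shows the interface; the objective here is illustrative, not the BRS loss itself.

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def toy_functor(x):
    # fmin_l_bfgs_b expects the callable to return (loss, gradient)
    # for the flat parameter vector x.
    loss = float(np.sum((x - 1.0) ** 2))
    grad = 2.0 * (x - 1.0)
    return loss, grad.astype(np.float64)

x0 = np.zeros(8, dtype=np.float32)
x_opt, f_opt, info = fmin_l_bfgs_b(
    func=toy_functor, x0=x0, m=20, factr=0, pgtol=1e-8, maxfun=20
)
print(x_opt.round(3), f_opt)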
isegm/inference/predictors/brs_functors.py
CHANGED
import numpy as np
import torch

from isegm.model.metrics import _compute_iou

from .brs_losses import BRSMaskLoss


class BaseOptimizer:
    def __init__(
        self,
        optimizer_params,
        prob_thresh=0.49,
        reg_weight=1e-3,
        min_iou_diff=0.01,
        brs_loss=BRSMaskLoss(),
        with_flip=False,
        flip_average=False,
        **kwargs
    ):
        self.brs_loss = brs_loss
        self.optimizer_params = optimizer_params
        self.prob_thresh = prob_thresh

        if self.with_flip and self.flip_average:
            result, result_flipped = torch.chunk(result, 2, dim=0)
            result = 0.5 * (result + torch.flip(result_flipped, dims=[3]))
            pos_mask, neg_mask = (
                pos_mask[: result.shape[0]],
                neg_mask[: result.shape[0]],
            )

        loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask)
        loss = loss + reg_loss


class ScaleBiasOptimizer(BaseOptimizer):
    def unpack_opt_params(self, opt_params):
        scale, bias = torch.chunk(opt_params, 2, dim=0)
        reg_loss = self.reg_weight * (
            torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)
        )

        if self.scale_act == "tanh":
            scale = torch.tanh(scale)
        elif self.scale_act == "sin":
            scale = torch.sin(scale)

        return (1 + scale, bias), reg_loss
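ScaleBiasOptimizer splits the flat vector into a multiplicative scale and an additive bias, regularizes both, and squashes the scale with tanh (or sin) before applying it as (1 + scale). A small torch sketch of that unpacking, with illustrative weights and feature shapes:

import torch

opt_params = torch.zeros(8, requires_grad=True)   # first half scale, second half bias
reg_weight, reg_bias_weight = 1e-3, 10.0          # illustrative values

scale, bias = torch.chunk(opt_params, 2, dim=0)
reg_loss = reg_weight * (torch.sum(scale**2) + reg_bias_weight * torch.sum(bias**2))
scale = torch.tanh(scale)                          # the "tanh" scale activation
features_scaled = (1 + scale) * torch.randn(4) + bias
print(reg_loss.item(), features_scaled.shape)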
isegm/inference/predictors/brs_losses.py
CHANGED
class BRSMaskLoss(torch.nn.Module):
    def forward(self, result, pos_mask, neg_mask):
        pos_diff = (1 - result) * pos_mask
        pos_target = torch.sum(pos_diff**2)
        pos_target = pos_target / (torch.sum(pos_mask) + self._eps)

        neg_diff = result * neg_mask
        neg_target = torch.sum(neg_diff**2)
        neg_target = neg_target / (torch.sum(neg_mask) + self._eps)

        loss = pos_target + neg_target

        with torch.no_grad():


class OracleMaskLoss(torch.nn.Module):
        gt_mask = self.gt_mask.to(result.device)
        if self.predictor.object_roi is not None:
            r1, r2, c1, c2 = self.predictor.object_roi[:4]
            gt_mask = gt_mask[:, :, r1 : r2 + 1, c1 : c2 + 1]
            gt_mask = torch.nn.functional.interpolate(
                gt_mask, result.size()[2:], mode="bilinear", align_corners=True
            )

        if result.shape[0] == 2:
            gt_mask_flipped = torch.flip(gt_mask, dims=[3])
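BRSMaskLoss penalizes probability mass that disagrees with the clicked pixels: positive clicks should reach probability 1, negative clicks probability 0, each term normalized by the number of clicks. A toy call with synthetic tensors (the exact constructor defaults are assumed from the default brs_loss=BRSMaskLoss() seen above):

import torch
from isegm.inference.predictors.brs_losses import BRSMaskLoss

loss_fn = BRSMaskLoss()
result = torch.full((1, 1, 4, 4), 0.2)     # predicted probabilities
pos_mask = torch.zeros_like(result)
neg_mask = torch.zeros_like(result)
pos_mask[0, 0, 1, 1] = 1                    # a positive click: should be ~1
neg_mask[0, 0, 3, 3] = 1                    # a negative click: should be ~0

loss, f_max_pos, f_max_neg = loss_fn(result, pos_mask, neg_mask)
print(loss.item())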
isegm/inference/transforms/__init__.py
CHANGED
from .base import SigmoidForPred
from .crops import Crops
from .flip import AddHorizontalFlip
from .limit_longest_side import LimitLongestSide
from .zoom_in import ZoomIn
isegm/inference/transforms/crops.py
CHANGED
import math
from typing import List

import numpy as np
import torch

from isegm.inference.clicker import Click

from .base import BaseTransform


class Crops(BaseTransform):
        image_crops = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                self._counts[dy : dy + self.crop_height, dx : dx + self.crop_width] += 1
                image_crop = image_nd[
                    :, :, dy : dy + self.crop_height, dx : dx + self.crop_width
                ]
                image_crops.append(image_crop)
        image_crops = torch.cat(image_crops, dim=0)
        self._counts = torch.tensor(
            self._counts, device=image_nd.device, dtype=torch.float32
        )

        clicks_list = clicks_lists[0]
        clicks_lists = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                crop_clicks = [
                    x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx))
                    for x in clicks_list
                ]
                clicks_lists.append(crop_clicks)

        return image_crops, clicks_lists

        if self._counts is None:
            return prob_map

        new_prob_map = torch.zeros(
            (1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device
        )

        crop_indx = 0
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                new_prob_map[
                    0, 0, dy : dy + self.crop_height, dx : dx + self.crop_width
                ] += prob_map[crop_indx, 0]
                crop_indx += 1
        new_prob_map = torch.div(new_prob_map, self._counts)
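The Crops transform tiles the image into overlapping fixed-size windows, counts how many windows cover each pixel, and at inverse-transform time sums the per-crop probability maps and divides by that count. A toy NumPy illustration of the count-based averaging, with hypothetical 2x2 crops over a 3x3 map:

import numpy as np

counts = np.zeros((3, 3))
merged = np.zeros((3, 3))
crop_h = crop_w = 2
offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
crop_probs = [np.full((crop_h, crop_w), p) for p in (0.2, 0.4, 0.6, 0.8)]

for (dy, dx), probs in zip(offsets, crop_probs):
    counts[dy:dy + crop_h, dx:dx + crop_w] += 1     # coverage count per pixel
    merged[dy:dy + crop_h, dx:dx + crop_w] += probs  # accumulate crop predictions

merged /= counts   # every pixel becomes the mean over the crops that saw it
print(merged)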
isegm/inference/transforms/flip.py
CHANGED
from typing import List

import torch

from isegm.inference.clicker import Click

from .base import BaseTransform


class AddHorizontalFlip(BaseTransform):
        image_width = image_nd.shape[3]
        clicks_lists_flipped = []
        for clicks_list in clicks_lists:
            clicks_list_flipped = [
                click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1))
                for click in clicks_list
            ]
            clicks_lists_flipped.append(clicks_list_flipped)
        clicks_lists = clicks_lists + clicks_lists_flipped
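AddHorizontalFlip is a test-time augmentation: it doubles the batch with a mirrored copy, mirrors the click columns accordingly, and the two predictions are later averaged back into one. A minimal torch sketch of the same idea (tensor names are illustrative):

import torch

image_nd = torch.rand(1, 3, 4, 6)
flipped = torch.flip(image_nd, dims=[3])
batch = torch.cat((image_nd, flipped), dim=0)   # the network sees both views

# a click column mirrored for an image of width W
W = image_nd.shape[3]
col = 2
print("flipped column:", W - col - 1)

# merging the two per-view predictions back into one map
pred = torch.rand(2, 1, 4, 6)
merged = 0.5 * (pred[:1] + torch.flip(pred[1:], dims=[3]))
print(merged.shape)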
isegm/inference/transforms/zoom_in.py
CHANGED
from typing import List

import torch

from isegm.inference.clicker import Click
from isegm.utils.misc import (clamp_bbox, expand_bbox, get_bbox_from_mask,
                              get_bbox_iou)

from .base import BaseTransform


class ZoomIn(BaseTransform):
    def __init__(
        self,
        target_size=400,
        skip_clicks=1,
        expansion_ratio=1.4,
        min_crop_size=200,
        recompute_thresh_iou=0.5,
        prob_thresh=0.50,
    ):
        super().__init__()
        self.target_size = target_size
        self.min_crop_size = min_crop_size

        if self._prev_probs is not None:
            current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
            if current_pred_mask.sum() > 0:
                current_object_roi = get_object_roi(
                    current_pred_mask,
                    clicks_list,
                    self.expansion_ratio,
                    self.min_crop_size,
                )

        if current_object_roi is None:
            if self.skip_clicks >= 0:

            update_object_roi = True
        elif not check_object_roi(self._object_roi, clicks_list):
            update_object_roi = True
        elif (
            get_bbox_iou(current_object_roi, self._object_roi)
            < self.recompute_thresh_iou
        ):
            update_object_roi = True

        if update_object_roi:

        assert prob_map.shape[0] == 1
        rmin, rmax, cmin, cmax = self._object_roi
        prob_map = torch.nn.functional.interpolate(
            prob_map,
            size=(rmax - rmin + 1, cmax - cmin + 1),
            mode="bilinear",
            align_corners=True,
        )

        if self._prev_probs is not None:
            new_prob_map = torch.zeros(
                *self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype
            )
            new_prob_map[:, :, rmin : rmax + 1, cmin : cmax + 1] = prob_map
        else:
            new_prob_map = prob_map

        return new_prob_map

    def check_possible_recalculation(self):
        if (
            self._prev_probs is None
            or self._object_roi is not None
            or self.skip_clicks > 0
        ):
            return False

        pred_mask = (self._prev_probs > self.prob_thresh)[0, 0]
        if pred_mask.sum() > 0:
            possible_object_roi = get_object_roi(
                pred_mask, [], self.expansion_ratio, self.min_crop_size
            )
            image_roi = (
                0,
                self._input_image_shape[2] - 1,
                0,
                self._input_image_shape[3] - 1,
            )
            if get_bbox_iou(possible_object_roi, image_roi) < 0.50:
                return True
        return False

    def get_state(self):
        roi_image = self._roi_image.cpu() if self._roi_image is not None else None
        return (
            self._input_image_shape,
            self._object_roi,
            self._prev_probs,
            roi_image,
            self.image_changed,
        )

    def set_state(self, state):
        (
            self._input_image_shape,
            self._object_roi,
            self._prev_probs,
            self._roi_image,
            self.image_changed,
        ) = state

    def reset(self):
        self._input_image_shape = None


def get_roi_image_nd(image_nd, object_roi, target_size):
        new_width = int(round(width * scale))

        with torch.no_grad():
            roi_image_nd = image_nd[:, :, rmin : rmax + 1, cmin : cmax + 1]
            roi_image_nd = torch.nn.functional.interpolate(
                roi_image_nd,
                size=(new_height, new_width),
                mode="bilinear",
                align_corners=True,
            )

        return roi_image_nd
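ZoomIn crops an expanded bounding box around the current prediction once skip_clicks clicks have been placed, resizes it to target_size for the next forward pass, and recomputes the crop only when the new ROI drifts (IoU with the previous ROI below recompute_thresh_iou). The real expansion and clamping live in isegm.utils.misc (expand_bbox, clamp_bbox); the sketch below is a simplified stand-in for that arithmetic, with a bbox given as (rmin, rmax, cmin, cmax).

def expand_roi(bbox, expansion_ratio, min_size, height, width):
    rmin, rmax, cmin, cmax = bbox
    rc, cc = (rmin + rmax) / 2, (cmin + cmax) / 2
    # Grow the box by expansion_ratio but never below min_size.
    h = max(expansion_ratio * (rmax - rmin + 1), min_size)
    w = max(expansion_ratio * (cmax - cmin + 1), min_size)
    rmin, rmax = int(rc - h / 2), int(rc + h / 2)
    cmin, cmax = int(cc - w / 2), int(cc + w / 2)
    # Clamp to the image bounds.
    return max(rmin, 0), min(rmax, height - 1), max(cmin, 0), min(cmax, width - 1)

print(expand_roi((30, 70, 30, 70), 1.4, 200, 480, 640))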
isegm/inference/utils.py
CHANGED
@@ -1,10 +1,11 @@
from datetime import timedelta
from pathlib import Path

import numpy as np
import torch

from isegm.data.datasets import (BerkeleyDataset, DavisDataset, GrabCutDataset,
                                 PascalVocDataset, SBDEvaluationDataset)
from isegm.utils.serialization import load_model

@@ -20,7 +21,7 @@ def get_time_metrics(all_ious, elapsed_time):
def load_is_model(checkpoint, device, **kwargs):
    if isinstance(checkpoint, (str, Path)):
        state_dict = torch.load(checkpoint, map_location="cpu")
    else:
        state_dict = checkpoint

@@ -34,8 +35,8 @@ def load_is_model(checkpoint, device, **kwargs):
def load_single_is_model(state_dict, device, **kwargs):
    model = load_model(state_dict["config"], **kwargs)
    model.load_state_dict(state_dict["state_dict"], strict=False)

    for param in model.parameters():
        param.requires_grad = False

@@ -46,19 +47,19 @@ def load_single_is_model(state_dict, device, **kwargs):
def get_dataset(dataset_name, cfg):
    if dataset_name == "GrabCut":
        dataset = GrabCutDataset(cfg.GRABCUT_PATH)
    elif dataset_name == "Berkeley":
        dataset = BerkeleyDataset(cfg.BERKELEY_PATH)
    elif dataset_name == "DAVIS":
        dataset = DavisDataset(cfg.DAVIS_PATH)
    elif dataset_name == "SBD":
        dataset = SBDEvaluationDataset(cfg.SBD_PATH)
    elif dataset_name == "SBD_Train":
        dataset = SBDEvaluationDataset(cfg.SBD_PATH, split="train")
    elif dataset_name == "PascalVOC":
        dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split="test")
    elif dataset_name == "COCO_MVal":
        dataset = DavisDataset(cfg.COCO_MVAL_PATH)
    else:
        dataset = None

@@ -70,8 +71,12 @@ def get_iou(gt_mask, pred_mask, ignore_label=-1):
    ignore_gt_mask_inv = gt_mask != ignore_label
    obj_gt_mask = gt_mask == 1

    intersection = np.logical_and(
        np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv
    ).sum()
    union = np.logical_and(
        np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv
    ).sum()

    return intersection / union

@@ -84,8 +89,9 @@ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
    noc_list = []
    over_max_list = []
    for iou_thr in iou_thrs:
        scores_arr = np.array(
            [_get_noc(iou_arr, iou_thr) for iou_arr in all_ious], dtype=np.int
        )

        score = scores_arr.mean()
        over_max = (scores_arr == max_clicks).sum()

@@ -98,46 +104,58 @@ def compute_noc_metric(all_ious, iou_thrs, max_clicks=20):
def find_checkpoint(weights_folder, checkpoint_name):
    weights_folder = Path(weights_folder)
    if ":" in checkpoint_name:
        model_name, checkpoint_name = checkpoint_name.split(":")
        models_candidates = [
            x for x in weights_folder.glob(f"{model_name}*") if x.is_dir()
        ]
        assert len(models_candidates) == 1
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith(".pth"):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            checkpoint_path = weights_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f"{checkpoint_name}*.pth"))
        assert len(model_checkpoints) == 1
        checkpoint_path = model_checkpoints[0]

    return str(checkpoint_path)


def get_results_table(
    noc_list,
    over_max_list,
    brs_type,
    dataset_name,
    mean_spc,
    elapsed_time,
    n_clicks=20,
    model_name=None,
):
    table_header = (
        f'|{"BRS Type":^13}|{"Dataset":^11}|'
        f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|'
        f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|'
        f'{"SPC,s":^7}|{"Time":^9}|'
    )
    row_width = len(table_header)

    header = f"Eval results for model: {model_name}\n" if model_name is not None else ""
    header += "-" * row_width + "\n"
    header += table_header + "\n" + "-" * row_width

    eval_time = str(timedelta(seconds=int(elapsed_time)))
    table_row = f"|{brs_type:^13}|{dataset_name:^11}|"
    table_row += f"{noc_list[0]:^9.2f}|"
    table_row += f"{noc_list[1]:^9.2f}|" if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f"{noc_list[2]:^9.2f}|" if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f"{over_max_list[1]:^9}|" if len(noc_list) > 1 else f'{"?":^9}|'
    table_row += f"{over_max_list[2]:^9}|" if len(noc_list) > 2 else f'{"?":^9}|'
    table_row += f"{mean_spc:^7.3f}|{eval_time:^9}|"

    return header, table_row
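As a rough usage sketch of the two evaluation helpers above (assuming compute_noc_metric returns the noc_list and over_max_list it accumulates, and with fabricated per-sample IoU curves standing in for real evaluation results):

import numpy as np

from isegm.inference.utils import compute_noc_metric, get_results_table

# Hypothetical IoU-vs-clicks curves for two samples (made-up numbers).
all_ious = [np.linspace(0.5, 0.95, 20), np.linspace(0.4, 0.9, 20)]
noc_list, over_max_list = compute_noc_metric(
    all_ious, iou_thrs=(0.80, 0.85, 0.90), max_clicks=20
)

header, row = get_results_table(
    noc_list, over_max_list, brs_type="NoBRS", dataset_name="GrabCut",
    mean_spc=0.1, elapsed_time=12.0, n_clicks=20, model_name="hrnet18",
)
print(header)
print(row)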
isegm/model/initializer.py
CHANGED
@@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn


class Initializer(object):

@@ -9,24 +9,37 @@ class Initializer(object):
        self.gamma = gamma

    def __call__(self, m):
        if getattr(m, "__initialized", False):
            return

        if (
            isinstance(
                m,
                (
                    nn.BatchNorm1d,
                    nn.BatchNorm2d,
                    nn.BatchNorm3d,
                    nn.InstanceNorm1d,
                    nn.InstanceNorm2d,
                    nn.InstanceNorm3d,
                    nn.GroupNorm,
                    nn.SyncBatchNorm,
                ),
            )
            or "BatchNorm" in m.__class__.__name__
        ):
            if m.weight is not None:
                self._init_gamma(m.weight.data)
            if m.bias is not None:
                self._init_beta(m.bias.data)
        else:
            if getattr(m, "weight", None) is not None:
                self._init_weight(m.weight.data)
            if getattr(m, "bias", None) is not None:
                self._init_bias(m.bias.data)

        if self.local_init:
            object.__setattr__(m, "__initialized", True)

    def _init_weight(self, data):
        nn.init.uniform_(data, -0.07, 0.07)

@@ -71,13 +84,15 @@ class Bilinear(Initializer):
        center = scale - 0.5 * (1 + kernel_size % 2)

        og = np.ogrid[:kernel_size, :kernel_size]
        kernel = (1 - np.abs(og[0] - center) / scale) * (
            1 - np.abs(og[1] - center) / scale
        )

        return torch.tensor(kernel, dtype=torch.float32)


class XavierGluon(Initializer):
    def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3, **kwargs):
        super().__init__(**kwargs)

        self.rnd_type = rnd_type

@@ -87,19 +102,19 @@ class XavierGluon(Initializer):
    def _init_weight(self, arr):
        fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)

        if self.factor_type == "avg":
            factor = (fan_in + fan_out) / 2.0
        elif self.factor_type == "in":
            factor = fan_in
        elif self.factor_type == "out":
            factor = fan_out
        else:
            raise ValueError("Incorrect factor type")
        scale = np.sqrt(self.magnitude / factor)

        if self.rnd_type == "uniform":
            nn.init.uniform_(arr, -scale, scale)
        elif self.rnd_type == "gaussian":
            nn.init.normal_(arr, 0, scale)
        else:
            raise ValueError("Unknown random type")
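A minimal sketch of how these initializers are meant to be applied, following the __call__(self, m) interface and the XavierGluon signature shown above (the toy network and the rnd_type/magnitude values are illustrative choices, not the repository's defaults):

import torch.nn as nn

from isegm.model.initializer import XavierGluon

# nn.Module.apply walks the module tree and calls the initializer on every
# submodule; norm layers get gamma/beta init, everything else weight/bias init.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net.apply(XavierGluon(rnd_type="gaussian", magnitude=2.0))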
isegm/model/is_deeplab_model.py
CHANGED
@@ -1,25 +1,44 @@
import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.utils.serialization import serialize

from .is_model import ISModel
from .modeling.basic_blocks import SepConvHead
from .modeling.deeplab_v3 import DeepLabV3Plus


class DeeplabModel(ISModel):
    @serialize
    def __init__(
        self,
        backbone="resnet50",
        deeplab_ch=256,
        aspp_dropout=0.5,
        backbone_norm_layer=None,
        backbone_lr_mult=0.1,
        norm_layer=nn.BatchNorm2d,
        **kwargs
    ):
        super().__init__(norm_layer=norm_layer, **kwargs)

        self.feature_extractor = DeepLabV3Plus(
            backbone=backbone,
            ch=deeplab_ch,
            project_dropout=aspp_dropout,
            norm_layer=norm_layer,
            backbone_norm_layer=backbone_norm_layer,
        )
        self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult))
        self.head = SepConvHead(
            1,
            in_channels=deeplab_ch,
            mid_channels=deeplab_ch // 2,
            num_layers=2,
            norm_layer=norm_layer,
        )

    def backbone_forward(self, image, coord_features=None):
        backbone_features = self.feature_extractor(image, coord_features)

        return {"instances": self.head(backbone_features[0])}
isegm/model/is_hrnet_model.py
CHANGED
@@ -1,19 +1,32 @@
import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.utils.serialization import serialize

from .is_model import ISModel
from .modeling.hrnet_ocr import HighResolutionNet


class HRNetModel(ISModel):
    @serialize
    def __init__(
        self,
        width=48,
        ocr_width=256,
        small=False,
        backbone_lr_mult=0.1,
        norm_layer=nn.BatchNorm2d,
        **kwargs
    ):
        super().__init__(norm_layer=norm_layer, **kwargs)

        self.feature_extractor = HighResolutionNet(
            width=width,
            ocr_width=ocr_width,
            small=small,
            num_classes=1,
            norm_layer=norm_layer,
        )
        self.feature_extractor.apply(LRMult(backbone_lr_mult))
        if ocr_width > 0:
            self.feature_extractor.ocr_distri_head.apply(LRMult(1.0))

@@ -23,4 +36,4 @@ class HRNetModel(ISModel):
    def backbone_forward(self, image, coord_features=None):
        net_outputs = self.feature_extractor(image, coord_features)

        return {"instances": net_outputs[0], "instances_aux": net_outputs[1]}
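For orientation, a hedged sketch of constructing the HRNet-based model with the keyword arguments exposed above (the specific width/ocr_width/click-encoding values here are illustrative, roughly in the spirit of an HRNetV2-W18 configuration, and are not taken from this commit):

from isegm.model.is_hrnet_model import HRNetModel

# Constructor arguments mirror HRNetModel.__init__ and ISModel.__init__ above.
model = HRNetModel(
    width=18, ocr_width=64, with_aux_output=True,
    use_rgb_conv=False, use_leaky_relu=True, use_disks=True, norm_radius=5,
)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")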
isegm/model/is_model.py
CHANGED
@@ -1,17 +1,27 @@
import numpy as np
import torch
import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.model.ops import BatchImageNormalize, DistMaps, ScaleLayer


class ISModel(nn.Module):
    def __init__(
        self,
        use_rgb_conv=True,
        with_aux_output=False,
        norm_radius=260,
        use_disks=False,
        cpu_dist_maps=False,
        clicks_groups=None,
        with_prev_mask=False,
        use_leaky_relu=False,
        binary_prev_mask=False,
        conv_extend=False,
        norm_layer=nn.BatchNorm2d,
        norm_mean_std=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups

@@ -28,35 +38,64 @@ class ISModel(nn.Module):
        if use_rgb_conv:
            rgb_conv_layers = [
                nn.Conv2d(
                    in_channels=3 + self.coord_feature_ch,
                    out_channels=6 + self.coord_feature_ch,
                    kernel_size=1,
                ),
                norm_layer(6 + self.coord_feature_ch),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1
                ),
            ]
            self.rgb_conv = nn.Sequential(*rgb_conv_layers)
        elif conv_extend:
            self.rgb_conv = None
            self.maps_transform = nn.Conv2d(
                in_channels=self.coord_feature_ch,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.maps_transform.apply(LRMult(0.1))
        else:
            self.rgb_conv = None
            mt_layers = [
                nn.Conv2d(
                    in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1
                ),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                ScaleLayer(init_value=0.05, lr_mult=1),
            ]
            self.maps_transform = nn.Sequential(*mt_layers)

        if self.clicks_groups is not None:
            self.dist_maps = nn.ModuleList()
            for click_radius in self.clicks_groups:
                self.dist_maps.append(
                    DistMaps(
                        norm_radius=click_radius,
                        spatial_scale=1.0,
                        cpu_mode=cpu_dist_maps,
                        use_disks=use_disks,
                    )
                )
        else:
            self.dist_maps = DistMaps(
                norm_radius=norm_radius,
                spatial_scale=1.0,
                cpu_mode=cpu_dist_maps,
                use_disks=use_disks,
            )

    def forward(self, image, points):
        image, prev_mask = self.prepare_input(image)

@@ -69,11 +108,19 @@ class ISModel(nn.Module):
            coord_features = self.maps_transform(coord_features)
        outputs = self.backbone_forward(image, coord_features)

        outputs["instances"] = nn.functional.interpolate(
            outputs["instances"],
            size=image.size()[2:],
            mode="bilinear",
            align_corners=True,
        )
        if self.with_aux_output:
            outputs["instances_aux"] = nn.functional.interpolate(
                outputs["instances_aux"],
                size=image.size()[2:],
                mode="bilinear",
                align_corners=True,
            )

        return outputs

@@ -93,8 +140,13 @@ class ISModel(nn.Module):
    def get_coord_features(self, image, prev_mask, points):
        if self.clicks_groups is not None:
            points_groups = split_points_by_order(
                points, groups=(2,) + (1,) * (len(self.clicks_groups) - 2) + (-1,)
            )
            coord_features = [
                dist_map(image, pg)
                for dist_map, pg in zip(self.dist_maps, points_groups)
            ]
            coord_features = torch.cat(coord_features, dim=1)
        else:
            coord_features = self.dist_maps(image, points)

@@ -112,8 +164,7 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
    num_points = points.shape[1] // 2

    groups = [x if x > 0 else num_points for x in groups]
    group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) for x in groups]

    last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int)
    for group_indx, group_size in enumerate(groups):

@@ -127,7 +178,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
                continue

            is_negative = int(pindx >= num_points)
            if group_id >= num_groups or (
                group_id == 0 and is_negative
            ):  # disable negative first click
                group_id = num_groups - 1

            new_point_indx = last_point_indx_group[bindx, group_id, is_negative]

@@ -135,7 +188,9 @@ def split_points_by_order(tpoints: torch.Tensor, groups):
            group_points[group_id][bindx, new_point_indx, :] = point

    group_points = [
        torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
        for x in group_points
    ]

    return group_points
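A hedged sketch of the click-tensor layout this code implies: points arrive as (batch, 2 * max_clicks, 3), with the first half positive clicks and the second half negative ones (hence is_negative = pindx >= num_points), each row being (row, col, click_index) and unused slots filled with -1. The coordinates and group split below are invented for illustration, and the call only runs on a NumPy version where np.int (used above) still exists:

import torch

from isegm.model.is_model import split_points_by_order

points = torch.full((1, 4, 3), -1.0)          # 1 sample, up to 2 pos + 2 neg clicks
points[0, 0] = torch.tensor([120.0, 200.0, 0.0])   # first positive click
points[0, 2] = torch.tensor([60.0, 90.0, 1.0])     # first negative click

# Route clicks into two radius groups, as get_coord_features does above.
groups = split_points_by_order(points, groups=(2, -1))
print([g.shape for g in groups])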
isegm/model/losses.py
CHANGED
@@ -7,10 +7,20 @@ from isegm.utils import misc
class NormalizedFocalLossSigmoid(nn.Module):
    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        max_mult=-1,
        eps=1e-12,
        from_sigmoid=False,
        detach_delimeter=True,
        batch_axis=0,
        weight=None,
        size_average=True,
        ignore_label=-1,
    ):
        super(NormalizedFocalLossSigmoid, self).__init__()
        self._axis = axis
        self._alpha = alpha

@@ -34,8 +44,12 @@ class NormalizedFocalLossSigmoid(nn.Module):
        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

@@ -49,37 +63,69 @@ class NormalizedFocalLossSigmoid(nn.Module):
            beta = torch.clamp_max(beta, self._max_mult)

        with torch.no_grad():
            ignore_area = (
                torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim())))
                .cpu()
                .numpy()
            )
            sample_mult = (
                torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy()
            )
            if np.any(ignore_area == 0):
                self._k_sum = (
                    0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()
                )

                beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
                beta_pmax = beta_pmax.mean().item()
                self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (bsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return loss

    def log_states(self, sw, name, global_step):
        sw.add_scalar(tag=name + "_k", value=self._k_sum, global_step=global_step)
        sw.add_scalar(tag=name + "_m", value=self._m_max, global_step=global_step)


class FocalLoss(nn.Module):
    def __init__(
        self,
        axis=-1,
        alpha=0.25,
        gamma=2,
        from_logits=False,
        batch_axis=0,
        weight=None,
        num_class=None,
        eps=1e-9,
        size_average=True,
        scale=1.0,
        ignore_label=-1,
    ):
        super(FocalLoss, self).__init__()
        self._axis = axis
        self._alpha = alpha

@@ -101,19 +147,38 @@ class FocalLoss(nn.Module):
        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(
            one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight
        )
        pt = torch.where(
            sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)
        )

        beta = (1 - pt) ** self._gamma

        loss = (
            -alpha
            * beta
            * torch.log(
                torch.min(
                    pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)
                )
            )
        )
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            tsum = torch.sum(
                sample_weight,
                dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis),
            )
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            ) / (tsum + self._eps)
        else:
            loss = torch.sum(
                loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
            )

        return self._scale * loss

@@ -131,8 +196,9 @@ class SoftIoU(nn.Module):
        if not self._from_sigmoid:
            pred = torch.sigmoid(pred)

        loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) / (
            torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8
        )

        return loss

@@ -154,8 +220,12 @@ class SigmoidBinaryCrossEntropyLoss(nn.Module):
            loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
        else:
            eps = 1e-12
            loss = -(
                torch.log(pred + eps) * label
                + torch.log(1.0 - pred + eps) * (1.0 - label)
            )

        loss = self._weight * (loss * sample_weight)
        return torch.mean(
            loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)
        )
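A hedged usage sketch of the normalized focal loss above, assuming the forward(pred, label) signature used elsewhere in this repository (logits in, binary mask as target, -1 reserved for ignored pixels); the alpha/gamma values and tensor shapes are illustrative only:

import torch

from isegm.model.losses import NormalizedFocalLossSigmoid

loss_fn = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
pred = torch.randn(2, 1, 64, 64)                       # raw logits
label = (torch.rand(2, 1, 64, 64) > 0.5).float()       # binary ground-truth mask
print(loss_fn(pred, label).mean().item())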
isegm/model/metrics.py
CHANGED
@@ -1,5 +1,5 @@
import numpy as np
import torch

from isegm.utils import misc

@@ -27,9 +27,17 @@ class TrainMetric(object):
class AdaptiveIoU(TrainMetric):
    def __init__(
        self,
        init_thresh=0.4,
        thresh_step=0.025,
        thresh_beta=0.99,
        iou_beta=0.9,
        ignore_label=-1,
        from_logits=True,
        pred_output="instances",
        gt_output="instances",
    ):
        super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,))
        self._ignore_label = ignore_label
        self._from_logits = from_logits

@@ -59,7 +67,9 @@
                max_iou = temp_iou
                best_thresh = t

        self._iou_thresh = (
            self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh
        )
        self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou
        self._epoch_iou_sum += max_iou
        self._epoch_batch_count += 1

@@ -75,8 +85,14 @@
        self._epoch_batch_count = 0

    def log_states(self, sw, tag_prefix, global_step):
        sw.add_scalar(
            tag=tag_prefix + "_ema_iou", value=self._ema_iou, global_step=global_step
        )
        sw.add_scalar(
            tag=tag_prefix + "_iou_thresh",
            value=self._iou_thresh,
            global_step=global_step,
        )

    @property
    def iou_thresh(self):

@@ -88,8 +104,18 @@ def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
        pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)

    reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)
    union = (
        torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims)
        .detach()
        .cpu()
        .numpy()
    )
    intersection = (
        torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims)
        .detach()
        .cpu()
        .numpy()
    )
    nonzero = union > 0

    iou = intersection[nonzero] / union[nonzero]
isegm/model/modeling/basic_blocks.py
CHANGED
@@ -4,18 +4,28 @@ from isegm.model import ops
class ConvHead(nn.Module):
    def __init__(
        self,
        out_channels,
        in_channels=32,
        num_layers=1,
        kernel_size=3,
        padding=1,
        norm_layer=nn.BatchNorm2d,
    ):
        super(ConvHead, self).__init__()
        convhead = []

        for i in range(num_layers):
            convhead.extend(
                [
                    nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
                    nn.ReLU(),
                    norm_layer(in_channels)
                    if norm_layer is not None
                    else nn.Identity(),
                ]
            )
        convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))

        self.convhead = nn.Sequential(*convhead)

@@ -25,25 +35,43 @@
class SepConvHead(nn.Module):
    def __init__(
        self,
        num_outputs,
        in_channels,
        mid_channels,
        num_layers=1,
        kernel_size=3,
        padding=1,
        dropout_ratio=0.0,
        dropout_indx=0,
        norm_layer=nn.BatchNorm2d,
    ):
        super(SepConvHead, self).__init__()

        sepconvhead = []

        for i in range(num_layers):
            sepconvhead.append(
                SeparableConv2d(
                    in_channels=in_channels if i == 0 else mid_channels,
                    out_channels=mid_channels,
                    dw_kernel=kernel_size,
                    dw_padding=padding,
                    norm_layer=norm_layer,
                    activation="relu",
                )
            )
            if dropout_ratio > 0 and dropout_indx == i:
                sepconvhead.append(nn.Dropout(dropout_ratio))

        sepconvhead.append(
            nn.Conv2d(
                in_channels=mid_channels,
                out_channels=num_outputs,
                kernel_size=1,
                padding=0,
            )
        )

        self.layers = nn.Sequential(*sepconvhead)

@@ -55,16 +83,34 @@
class SeparableConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dw_kernel,
        dw_padding,
        dw_stride=1,
        activation=None,
        use_bias=False,
        norm_layer=None,
    ):
        super(SeparableConv2d, self).__init__()
        _activation = ops.select_activation_function(activation)
        self.body = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=dw_kernel,
                stride=dw_stride,
                padding=dw_padding,
                bias=use_bias,
                groups=in_channels,
            ),
            nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias
            ),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            _activation(),
        )

    def forward(self, x):
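A hedged shape-check sketch for the head used by DeeplabModel above, assuming SepConvHead's (not shown) forward simply applies self.layers to its first input; the channel counts and feature-map size are arbitrary:

import torch

from isegm.model.modeling.basic_blocks import SepConvHead

# Two depthwise-separable conv layers followed by a 1x1 projection to one channel.
head = SepConvHead(1, in_channels=256, mid_channels=128, num_layers=2)
features = torch.rand(2, 256, 56, 56)
print(head(features).shape)  # expected: torch.Size([2, 1, 56, 56])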
isegm/model/modeling/deeplab_v3.py
CHANGED
@@ -1,21 +1,26 @@
from contextlib import ExitStack

import torch
import torch.nn.functional as F
from torch import nn

from isegm.model import ops

from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone


class DeepLabV3Plus(nn.Module):
    def __init__(
        self,
        backbone="resnet50",
        norm_layer=nn.BatchNorm2d,
        backbone_norm_layer=None,
        ch=256,
        project_dropout=0.5,
        inference_mode=False,
        **kwargs
    ):
        super(DeepLabV3Plus, self).__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

@@ -29,28 +34,44 @@
        self.skip_project_in_channels = 256  # layer 1 out_channels

        self._kwargs = kwargs
        if backbone == "resnet34":
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64

        self.backbone = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=False,
            norm_layer=self.backbone_norm_layer,
            **kwargs
        )

        self.head = _DeepLabHead(
            in_channels=ch + 32,
            mid_channels=ch,
            out_channels=ch,
            norm_layer=self.norm_layer,
        )
        self.skip_project = _SkipProject(
            self.skip_project_in_channels, 32, norm_layer=self.norm_layer
        )
        self.aspp = _ASPP(
            in_channels=self.aspp_in_channels,
            atrous_rates=[12, 24, 36],
            out_channels=ch,
            project_dropout=project_dropout,
            norm_layer=self.norm_layer,
        )

        if inference_mode:
            self.set_prediction_mode()

    def load_pretrained_weights(self):
        pretrained = ResNetBackbone(
            backbone=self.backbone_name,
            pretrained_base=True,
            norm_layer=self.backbone_norm_layer,
            **self._kwargs
        )
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

@@ -74,11 +95,11 @@
        c1 = self.skip_project(c1)

        x = self.aspp(c4)
        x = F.interpolate(x, c1.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat((x, c1), dim=1)
        x = self.head(x)

        return (x,)

@@ -89,7 +110,7 @@
        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation(),
        )

    def forward(self, x):

@@ -97,15 +118,31 @@
class _DeepLabHead(nn.Module):
    def __init__(
        self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d
    ):
        super(_DeepLabHead, self).__init__()

        self.block = nn.Sequential(
            SeparableConv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            SeparableConv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                dw_kernel=3,
                dw_padding=1,
                activation="relu",
                norm_layer=norm_layer,
            ),
            nn.Conv2d(
                in_channels=mid_channels, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):

@@ -113,14 +150,25 @@
class _ASPP(nn.Module):
    def __init__(
        self,
        in_channels,
        atrous_rates,
        out_channels=256,
        project_dropout=0.5,
        norm_layer=nn.BatchNorm2d,
    ):
        super(_ASPP, self).__init__()

        b0 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

        rate1, rate2, rate3 = tuple(atrous_rates)

@@ -132,10 +180,14 @@
        self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])

        project = [
            nn.Conv2d(
                in_channels=5 * out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))

@@ -153,24 +205,33 @@ class _AsppPooling(nn.Module):
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ),
            norm_layer(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode="bilinear", align_corners=True)


def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
    block = nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=atrous_rate,
            dilation=atrous_rate,
            bias=False,
        ),
        norm_layer(out_channels),
        nn.ReLU(),
    )

    return block
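A small standalone check of the design choice behind the ASPP branches above: because each 3x3 atrous convolution uses padding equal to its dilation rate, every branch preserves the spatial size of its input, so the five branch outputs can be concatenated channel-wise (the tensor sizes here are arbitrary):

import torch
import torch.nn as nn

x = torch.rand(1, 512, 32, 32)
for rate in (12, 24, 36):
    conv = nn.Conv2d(512, 256, kernel_size=3, padding=rate, dilation=rate, bias=False)
    print(rate, conv(x).shape)  # each branch keeps the 32x32 spatial size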
isegm/model/modeling/hrnet_ocr.py
CHANGED
@@ -1,19 +1,30 @@
import os

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F

from .ocr import SpatialGather_Module, SpatialOCR_Module
from .resnetv1b import BasicBlockV1b, BottleneckV1b

relu_inplace = True


class HighResolutionModule(nn.Module):
    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        fuse_method,
        multi_scale_output=True,
        norm_layer=nn.BatchNorm2d,
        align_corners=True,
    ):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

@@ -26,48 +37,67 @@
        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels
        )
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=relu_inplace)

    def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
                num_branches, len(num_blocks)
            )
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
                num_branches, len(num_inchannels)
            )
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        downsample = None
        if (
            stride != 1
            or self.num_inchannels[branch_index]
            != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                self.norm_layer(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.num_inchannels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                norm_layer=self.norm_layer,
            )
        )
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index],
                    norm_layer=self.norm_layer,
                )
            )

        return nn.Sequential(*layers)

@@ -75,8 +105,7 @@
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

@@ -91,12 +120,17 @@
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                in_channels=num_inchannels[j],
                                out_channels=num_inchannels[i],
                                kernel_size=1,
                                bias=False,
                            ),
                            self.norm_layer(num_inchannels[i]),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:

@@ -104,19 +138,35 @@
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
-                           conv3x3s.append(
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
-                           conv3x3s.append(
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

@@ -144,7 +194,9 @@
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
-                       mode=
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

@@ -153,8 +205,15 @@
class HighResolutionNet(nn.Module):
-   def __init__(
        super(HighResolutionNet, self).__init__()
        self.norm_layer = norm_layer
        self.width = width

@@ -170,40 +229,61 @@
        num_blocks = 2 if small else 4

        stage1_num_channels = 64
-       self.layer1 = self._make_layer(
        stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels

        self.stage2_num_branches = 2
        num_channels = [width, 2 * width]
        num_inchannels = [
-           num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
        self.transition1 = self._make_transition_layer(
-           [stage1_out_channel], num_inchannels
        self.stage2, pre_stage_channels = self._make_stage(
-           BasicBlockV1b,

        self.stage3_num_branches = 3
        num_channels = [width, 2 * width, 4 * width]
        num_inchannels = [
-           num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
        self.transition2 = self._make_transition_layer(
-           pre_stage_channels, num_inchannels
        self.stage3, pre_stage_channels = self._make_stage(
-           BasicBlockV1b,

        self.stage4_num_branches = 4
        num_channels = [width, 2 * width, 4 * width, 8 * width]
        num_inchannels = [
-           num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
        self.transition3 = self._make_transition_layer(
-           pre_stage_channels, num_inchannels
        self.stage4, pre_stage_channels = self._make_stage(
-           BasicBlockV1b,
            num_branches=self.stage4_num_branches,
-           num_blocks=4 * [num_blocks],

        last_inp_channels = np.int(np.sum(pre_stage_channels))
        if self.ocr_width > 0:

@@ -211,43 +291,77 @@
            ocr_key_channels = self.ocr_width

            self.conv3x3_ocr = nn.Sequential(
-               nn.Conv2d(
                norm_layer(ocr_mid_channels),
                nn.ReLU(inplace=relu_inplace),
            )
            self.ocr_gather_head = SpatialGather_Module(num_classes)

-           self.ocr_distri_head = SpatialOCR_Module(
            self.cls_head = nn.Conv2d(
-               ocr_mid_channels,

            self.aux_head = nn.Sequential(
-               nn.Conv2d(
                norm_layer(last_inp_channels),
                nn.ReLU(inplace=relu_inplace),
-               nn.Conv2d(
            )
        else:
            self.cls_head = nn.Sequential(
-               nn.Conv2d(
                norm_layer(last_inp_channels),
                nn.ReLU(inplace=relu_inplace),
-               nn.Conv2d(
            )

-   def _make_transition_layer(
-           self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

@@ -255,28 +369,45 @@
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
-                   transition_layers.append(
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
-                   outchannels =
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

@@ -285,24 +416,43 @@
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
-               nn.Conv2d(
                self.norm_layer(planes * block.expansion),
            )

        layers = []
-       layers.append(
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes, norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

-   def _make_stage(
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module

@@ -311,15 +461,17 @@
            else:
                reset_multi_scale_output = True
            modules.append(
-               HighResolutionModule(
            )
            num_inchannels = modules[-1].get_num_inchannels()

@@ -387,30 +539,38 @@
    def aggregate_hrnet_features(self, x):
        # Upsampling
        x0_h, x0_w = x[0].size(2), x[0].size(3)
-       x1 = F.interpolate(

        return torch.cat([x[0], x1, x2, x3], 1)

-   def load_pretrained_weights(self, pretrained_path=
        model_dict = self.state_dict()

        if not os.path.exists(pretrained_path):
            print(f'\nFile "{pretrained_path}" does not exist.')
-           print(
            exit(1)
-       pretrained_dict = torch.load(pretrained_path, map_location={
-       pretrained_dict = {

        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
conv3x3s.append(
|
142 |
+
nn.Sequential(
|
143 |
+
nn.Conv2d(
|
144 |
+
num_inchannels[j],
|
145 |
+
num_outchannels_conv3x3,
|
146 |
+
kernel_size=3,
|
147 |
+
stride=2,
|
148 |
+
padding=1,
|
149 |
+
bias=False,
|
150 |
+
),
|
151 |
+
self.norm_layer(num_outchannels_conv3x3),
|
152 |
+
)
|
153 |
+
)
|
154 |
else:
|
155 |
num_outchannels_conv3x3 = num_inchannels[j]
|
156 |
+
conv3x3s.append(
|
157 |
+
nn.Sequential(
|
158 |
+
nn.Conv2d(
|
159 |
+
num_inchannels[j],
|
160 |
+
num_outchannels_conv3x3,
|
161 |
+
kernel_size=3,
|
162 |
+
stride=2,
|
163 |
+
padding=1,
|
164 |
+
bias=False,
|
165 |
+
),
|
166 |
+
self.norm_layer(num_outchannels_conv3x3),
|
167 |
+
nn.ReLU(inplace=relu_inplace),
|
168 |
+
)
|
169 |
+
)
|
170 |
fuse_layer.append(nn.Sequential(*conv3x3s))
|
171 |
fuse_layers.append(nn.ModuleList(fuse_layer))
|
172 |
|
|
|
194 |
y = y + F.interpolate(
|
195 |
self.fuse_layers[i][j](x[j]),
|
196 |
size=[height_output, width_output],
|
197 |
+
mode="bilinear",
|
198 |
+
align_corners=self.align_corners,
|
199 |
+
)
|
200 |
else:
|
201 |
y = y + self.fuse_layers[i][j](x[j])
|
202 |
x_fuse.append(self.relu(y))
|
|
|
205 |
|
206 |
|
207 |
class HighResolutionNet(nn.Module):
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
width,
|
211 |
+
num_classes,
|
212 |
+
ocr_width=256,
|
213 |
+
small=False,
|
214 |
+
norm_layer=nn.BatchNorm2d,
|
215 |
+
align_corners=True,
|
216 |
+
):
|
217 |
super(HighResolutionNet, self).__init__()
|
218 |
self.norm_layer = norm_layer
|
219 |
self.width = width
|
|
|
229 |
num_blocks = 2 if small else 4
|
230 |
|
231 |
stage1_num_channels = 64
|
232 |
+
self.layer1 = self._make_layer(
|
233 |
+
BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks
|
234 |
+
)
|
235 |
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
|
236 |
|
237 |
self.stage2_num_branches = 2
|
238 |
num_channels = [width, 2 * width]
|
239 |
num_inchannels = [
|
240 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
|
241 |
+
]
|
242 |
self.transition1 = self._make_transition_layer(
|
243 |
+
[stage1_out_channel], num_inchannels
|
244 |
+
)
|
245 |
self.stage2, pre_stage_channels = self._make_stage(
|
246 |
+
BasicBlockV1b,
|
247 |
+
num_inchannels=num_inchannels,
|
248 |
+
num_modules=1,
|
249 |
+
num_branches=self.stage2_num_branches,
|
250 |
+
num_blocks=2 * [num_blocks],
|
251 |
+
num_channels=num_channels,
|
252 |
+
)
|
253 |
|
254 |
self.stage3_num_branches = 3
|
255 |
num_channels = [width, 2 * width, 4 * width]
|
256 |
num_inchannels = [
|
257 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
|
258 |
+
]
|
259 |
self.transition2 = self._make_transition_layer(
|
260 |
+
pre_stage_channels, num_inchannels
|
261 |
+
)
|
262 |
self.stage3, pre_stage_channels = self._make_stage(
|
263 |
+
BasicBlockV1b,
|
264 |
+
num_inchannels=num_inchannels,
|
265 |
+
num_modules=3 if small else 4,
|
266 |
+
num_branches=self.stage3_num_branches,
|
267 |
+
num_blocks=3 * [num_blocks],
|
268 |
+
num_channels=num_channels,
|
269 |
+
)
|
270 |
|
271 |
self.stage4_num_branches = 4
|
272 |
num_channels = [width, 2 * width, 4 * width, 8 * width]
|
273 |
num_inchannels = [
|
274 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))
|
275 |
+
]
|
276 |
self.transition3 = self._make_transition_layer(
|
277 |
+
pre_stage_channels, num_inchannels
|
278 |
+
)
|
279 |
self.stage4, pre_stage_channels = self._make_stage(
|
280 |
+
BasicBlockV1b,
|
281 |
+
num_inchannels=num_inchannels,
|
282 |
+
num_modules=2 if small else 3,
|
283 |
num_branches=self.stage4_num_branches,
|
284 |
+
num_blocks=4 * [num_blocks],
|
285 |
+
num_channels=num_channels,
|
286 |
+
)
|
287 |
|
288 |
        last_inp_channels = int(np.sum(pre_stage_channels))
|
289 |
if self.ocr_width > 0:
|
|
|
291 |
ocr_key_channels = self.ocr_width
|
292 |
|
293 |
self.conv3x3_ocr = nn.Sequential(
|
294 |
+
nn.Conv2d(
|
295 |
+
last_inp_channels,
|
296 |
+
ocr_mid_channels,
|
297 |
+
kernel_size=3,
|
298 |
+
stride=1,
|
299 |
+
padding=1,
|
300 |
+
),
|
301 |
norm_layer(ocr_mid_channels),
|
302 |
nn.ReLU(inplace=relu_inplace),
|
303 |
)
|
304 |
self.ocr_gather_head = SpatialGather_Module(num_classes)
|
305 |
|
306 |
+
self.ocr_distri_head = SpatialOCR_Module(
|
307 |
+
in_channels=ocr_mid_channels,
|
308 |
+
key_channels=ocr_key_channels,
|
309 |
+
out_channels=ocr_mid_channels,
|
310 |
+
scale=1,
|
311 |
+
dropout=0.05,
|
312 |
+
norm_layer=norm_layer,
|
313 |
+
align_corners=align_corners,
|
314 |
+
)
|
315 |
self.cls_head = nn.Conv2d(
|
316 |
+
ocr_mid_channels,
|
317 |
+
num_classes,
|
318 |
+
kernel_size=1,
|
319 |
+
stride=1,
|
320 |
+
padding=0,
|
321 |
+
bias=True,
|
322 |
+
)
|
323 |
|
324 |
self.aux_head = nn.Sequential(
|
325 |
+
nn.Conv2d(
|
326 |
+
last_inp_channels,
|
327 |
+
last_inp_channels,
|
328 |
+
kernel_size=1,
|
329 |
+
stride=1,
|
330 |
+
padding=0,
|
331 |
+
),
|
332 |
norm_layer(last_inp_channels),
|
333 |
nn.ReLU(inplace=relu_inplace),
|
334 |
+
nn.Conv2d(
|
335 |
+
last_inp_channels,
|
336 |
+
num_classes,
|
337 |
+
kernel_size=1,
|
338 |
+
stride=1,
|
339 |
+
padding=0,
|
340 |
+
bias=True,
|
341 |
+
),
|
342 |
)
|
343 |
else:
|
344 |
self.cls_head = nn.Sequential(
|
345 |
+
nn.Conv2d(
|
346 |
+
last_inp_channels,
|
347 |
+
last_inp_channels,
|
348 |
+
kernel_size=3,
|
349 |
+
stride=1,
|
350 |
+
padding=1,
|
351 |
+
),
|
352 |
norm_layer(last_inp_channels),
|
353 |
nn.ReLU(inplace=relu_inplace),
|
354 |
+
nn.Conv2d(
|
355 |
+
last_inp_channels,
|
356 |
+
num_classes,
|
357 |
+
kernel_size=1,
|
358 |
+
stride=1,
|
359 |
+
padding=0,
|
360 |
+
bias=True,
|
361 |
+
),
|
362 |
)
|
363 |
|
364 |
+
def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
|
|
|
365 |
num_branches_cur = len(num_channels_cur_layer)
|
366 |
num_branches_pre = len(num_channels_pre_layer)
|
367 |
|
|
|
369 |
for i in range(num_branches_cur):
|
370 |
if i < num_branches_pre:
|
371 |
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
372 |
+
transition_layers.append(
|
373 |
+
nn.Sequential(
|
374 |
+
nn.Conv2d(
|
375 |
+
num_channels_pre_layer[i],
|
376 |
+
num_channels_cur_layer[i],
|
377 |
+
kernel_size=3,
|
378 |
+
stride=1,
|
379 |
+
padding=1,
|
380 |
+
bias=False,
|
381 |
+
),
|
382 |
+
self.norm_layer(num_channels_cur_layer[i]),
|
383 |
+
nn.ReLU(inplace=relu_inplace),
|
384 |
+
)
|
385 |
+
)
|
386 |
else:
|
387 |
transition_layers.append(None)
|
388 |
else:
|
389 |
conv3x3s = []
|
390 |
for j in range(i + 1 - num_branches_pre):
|
391 |
inchannels = num_channels_pre_layer[-1]
|
392 |
+
outchannels = (
|
393 |
+
num_channels_cur_layer[i]
|
394 |
+
if j == i - num_branches_pre
|
395 |
+
else inchannels
|
396 |
+
)
|
397 |
+
conv3x3s.append(
|
398 |
+
nn.Sequential(
|
399 |
+
nn.Conv2d(
|
400 |
+
inchannels,
|
401 |
+
outchannels,
|
402 |
+
kernel_size=3,
|
403 |
+
stride=2,
|
404 |
+
padding=1,
|
405 |
+
bias=False,
|
406 |
+
),
|
407 |
+
self.norm_layer(outchannels),
|
408 |
+
nn.ReLU(inplace=relu_inplace),
|
409 |
+
)
|
410 |
+
)
|
411 |
transition_layers.append(nn.Sequential(*conv3x3s))
|
412 |
|
413 |
return nn.ModuleList(transition_layers)
|
|
|
416 |
downsample = None
|
417 |
if stride != 1 or inplanes != planes * block.expansion:
|
418 |
downsample = nn.Sequential(
|
419 |
+
nn.Conv2d(
|
420 |
+
inplanes,
|
421 |
+
planes * block.expansion,
|
422 |
+
kernel_size=1,
|
423 |
+
stride=stride,
|
424 |
+
bias=False,
|
425 |
+
),
|
426 |
self.norm_layer(planes * block.expansion),
|
427 |
)
|
428 |
|
429 |
layers = []
|
430 |
+
layers.append(
|
431 |
+
block(
|
432 |
+
inplanes,
|
433 |
+
planes,
|
434 |
+
stride,
|
435 |
+
downsample=downsample,
|
436 |
+
norm_layer=self.norm_layer,
|
437 |
+
)
|
438 |
+
)
|
439 |
inplanes = planes * block.expansion
|
440 |
for i in range(1, blocks):
|
441 |
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
|
442 |
|
443 |
return nn.Sequential(*layers)
|
444 |
|
445 |
+
def _make_stage(
|
446 |
+
self,
|
447 |
+
block,
|
448 |
+
num_inchannels,
|
449 |
+
num_modules,
|
450 |
+
num_branches,
|
451 |
+
num_blocks,
|
452 |
+
num_channels,
|
453 |
+
fuse_method="SUM",
|
454 |
+
multi_scale_output=True,
|
455 |
+
):
|
456 |
modules = []
|
457 |
for i in range(num_modules):
|
458 |
# multi_scale_output is only used last module
|
|
|
461 |
else:
|
462 |
reset_multi_scale_output = True
|
463 |
modules.append(
|
464 |
+
HighResolutionModule(
|
465 |
+
num_branches,
|
466 |
+
block,
|
467 |
+
num_blocks,
|
468 |
+
num_inchannels,
|
469 |
+
num_channels,
|
470 |
+
fuse_method,
|
471 |
+
reset_multi_scale_output,
|
472 |
+
norm_layer=self.norm_layer,
|
473 |
+
align_corners=self.align_corners,
|
474 |
+
)
|
475 |
)
|
476 |
num_inchannels = modules[-1].get_num_inchannels()
|
477 |
|
|
|
539 |
def aggregate_hrnet_features(self, x):
|
540 |
# Upsampling
|
541 |
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
542 |
+
x1 = F.interpolate(
|
543 |
+
x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
|
544 |
+
)
|
545 |
+
x2 = F.interpolate(
|
546 |
+
x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
|
547 |
+
)
|
548 |
+
x3 = F.interpolate(
|
549 |
+
x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=self.align_corners
|
550 |
+
)
|
551 |
|
552 |
return torch.cat([x[0], x1, x2, x3], 1)
|
553 |
|
554 |
+
def load_pretrained_weights(self, pretrained_path=""):
|
555 |
model_dict = self.state_dict()
|
556 |
|
557 |
if not os.path.exists(pretrained_path):
|
558 |
print(f'\nFile "{pretrained_path}" does not exist.')
|
559 |
+
print(
|
560 |
+
"You need to specify the correct path to the pre-trained weights.\n"
|
561 |
+
"You can download the weights for HRNet from the repository:\n"
|
562 |
+
"https://github.com/HRNet/HRNet-Image-Classification"
|
563 |
+
)
|
564 |
exit(1)
|
565 |
+
pretrained_dict = torch.load(pretrained_path, map_location={"cuda:0": "cpu"})
|
566 |
+
pretrained_dict = {
|
567 |
+
k.replace("last_layer", "aux_head").replace("model.", ""): v
|
568 |
+
for k, v in pretrained_dict.items()
|
569 |
+
}
|
570 |
+
|
571 |
+
pretrained_dict = {
|
572 |
+
k: v for k, v in pretrained_dict.items() if k in model_dict.keys()
|
573 |
+
}
|
574 |
|
575 |
model_dict.update(pretrained_dict)
|
576 |
self.load_state_dict(model_dict)
|
isegm/model/modeling/ocr.py
CHANGED
@@ -1,14 +1,14 @@
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=0, scale=1):
@@ -22,8 +22,9 @@ class SpatialGather_Module(nn.Module):
        feats = feats.view(batch_size, feats.size(1), -1)
        feats = feats.permute(0, 2, 1)  # batch x hw x c
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = (
            torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3)
        )  # batch x k x c
        return ocr_context

@@ -33,23 +34,26 @@ class SpatialOCR_Module(nn.Module):
    We aggregate the global object representation to update the representation for each pixel.
    """

    def __init__(
        self,
        in_channels,
        key_channels,
        out_channels,
        scale=1,
        dropout=0.1,
        norm_layer=nn.BatchNorm2d,
        align_corners=True,
    ):
        super(SpatialOCR_Module, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale, norm_layer, align_corners
        )
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
            nn.Dropout2d(dropout),
        )

    def forward(self, feats, proxy_feats):
@@ -61,7 +65,7 @@ class SpatialOCR_Module(nn.Module):


class ObjectAttentionBlock2D(nn.Module):
    """
    The basic implementation for object context block
    Input:
        N X C X H X W
@@ -72,14 +76,16 @@ class ObjectAttentionBlock2D(nn.Module):
        bn_type : specify the bn type
    Return:
        N X C X H X W
    """

    def __init__(
        self,
        in_channels,
        key_channels,
        scale=1,
        norm_layer=nn.BatchNorm2d,
        align_corners=True,
    ):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
@@ -88,30 +94,66 @@ class ObjectAttentionBlock2D(nn.Module):

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)),
        )

    def forward(self, x, proxy):
@@ -126,7 +168,7 @@ class ObjectAttentionBlock2D(nn.Module):
        value = value.permute(0, 2, 1)

        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels**-0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
@@ -135,7 +177,11 @@ class ObjectAttentionBlock2D(nn.Module):
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            context = F.interpolate(
                input=context,
                size=(h, w),
                mode="bilinear",
                align_corners=self.align_corners,
            )

        return context
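A rough shape-check sketch for the OCR blocks above; the tensor sizes and channel counts are made-up assumptions, not values used by this repository.

import torch
from isegm.model.modeling.ocr import SpatialGather_Module, SpatialOCR_Module

feats = torch.randn(2, 512, 32, 32)   # N x C x H x W pixel features (assumed sizes)
probs = torch.randn(2, 1, 32, 32)     # N x K x H x W coarse predictions

gather = SpatialGather_Module(cls_num=1)
ocr = SpatialOCR_Module(in_channels=512, key_channels=256, out_channels=512)

context = gather(feats, probs)        # soft-weighted object-region representations
out = ocr(feats, context)             # N x 512 x 32 x 32 refined features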
isegm/model/modeling/resnet.py
CHANGED
@@ -1,21 +1,32 @@
import torch

from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s


class ResNetBackbone(torch.nn.Module):
    def __init__(
        self, backbone="resnet50", pretrained_base=True, dilated=True, **kwargs
    ):
        super(ResNetBackbone, self).__init__()

        if backbone == "resnet34":
            pretrained = resnet34_v1b(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet50":
            pretrained = resnet50_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet101":
            pretrained = resnet101_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet152":
            pretrained = resnet152_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        else:
            raise RuntimeError(f"unknown backbone: {backbone}")

        self.conv1 = pretrained.conv1
        self.bn1 = pretrained.bn1
@@ -31,9 +42,12 @@ class ResNetBackbone(torch.nn.Module):
        x = self.bn1(x)
        x = self.relu(x)
        if additional_features is not None:
            x = x + torch.nn.functional.pad(
                additional_features,
                [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
                mode="constant",
                value=0,
            )
        x = self.maxpool(x)
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
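A minimal usage sketch for the backbone wrapper above; the forward signature with additional_features is assumed from the hunk shown, and pretrained_base=False avoids any weight download.

import torch
from isegm.model.modeling.resnet import ResNetBackbone

backbone = ResNetBackbone(backbone="resnet34", pretrained_base=False)
image = torch.randn(1, 3, 320, 480)
clicks = torch.randn(1, 64, 160, 240)   # extra features at conv1 resolution (assumed layout)
features = backbone(image, additional_features=clicks)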
isegm/model/modeling/resnetv1b.py
CHANGED
@@ -1,19 +1,42 @@
@@ -42,17 +65,34 @@ class BasicBlockV1b(nn.Module):
@@ -83,7 +123,7 @@ class BottleneckV1b(nn.Module):
@@ -111,86 +151,198 @@ class ResNetV1b(nn.Module):
@@ -229,8 +381,10 @@ def resnet34_v1b(pretrained=False, **kwargs):
@@ -238,12 +392,16 @@ def resnet34_v1b(pretrained=False, **kwargs):
@@ -251,12 +409,16 @@ def resnet50_v1s(pretrained=False, **kwargs):
@@ -264,12 +426,16 @@ def resnet101_v1s(pretrained=False, **kwargs):
|
275 |
model.load_state_dict(model_dict)
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
|
4 |
+
GLUON_RESNET_TORCH_HUB = "rwightman/pytorch-pretrained-gluonresnet"
|
5 |
|
6 |
|
7 |
class BasicBlockV1b(nn.Module):
|
8 |
expansion = 1
|
9 |
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
inplanes,
|
13 |
+
planes,
|
14 |
+
stride=1,
|
15 |
+
dilation=1,
|
16 |
+
downsample=None,
|
17 |
+
previous_dilation=1,
|
18 |
+
norm_layer=nn.BatchNorm2d,
|
19 |
+
):
|
20 |
super(BasicBlockV1b, self).__init__()
|
21 |
+
self.conv1 = nn.Conv2d(
|
22 |
+
inplanes,
|
23 |
+
planes,
|
24 |
+
kernel_size=3,
|
25 |
+
stride=stride,
|
26 |
+
padding=dilation,
|
27 |
+
dilation=dilation,
|
28 |
+
bias=False,
|
29 |
+
)
|
30 |
self.bn1 = norm_layer(planes)
|
31 |
+
self.conv2 = nn.Conv2d(
|
32 |
+
planes,
|
33 |
+
planes,
|
34 |
+
kernel_size=3,
|
35 |
+
stride=1,
|
36 |
+
padding=previous_dilation,
|
37 |
+
dilation=previous_dilation,
|
38 |
+
bias=False,
|
39 |
+
)
|
40 |
self.bn2 = norm_layer(planes)
|
41 |
|
42 |
self.relu = nn.ReLU(inplace=True)
|
|
|
65 |
class BottleneckV1b(nn.Module):
|
66 |
expansion = 4
|
67 |
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
inplanes,
|
71 |
+
planes,
|
72 |
+
stride=1,
|
73 |
+
dilation=1,
|
74 |
+
downsample=None,
|
75 |
+
previous_dilation=1,
|
76 |
+
norm_layer=nn.BatchNorm2d,
|
77 |
+
):
|
78 |
super(BottleneckV1b, self).__init__()
|
79 |
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
80 |
self.bn1 = norm_layer(planes)
|
81 |
|
82 |
+
self.conv2 = nn.Conv2d(
|
83 |
+
planes,
|
84 |
+
planes,
|
85 |
+
kernel_size=3,
|
86 |
+
stride=stride,
|
87 |
+
padding=dilation,
|
88 |
+
dilation=dilation,
|
89 |
+
bias=False,
|
90 |
+
)
|
91 |
self.bn2 = norm_layer(planes)
|
92 |
|
93 |
+
self.conv3 = nn.Conv2d(
|
94 |
+
planes, planes * self.expansion, kernel_size=1, bias=False
|
95 |
+
)
|
96 |
self.bn3 = norm_layer(planes * self.expansion)
|
97 |
|
98 |
self.relu = nn.ReLU(inplace=True)
|
|
|
123 |
|
124 |
|
125 |
class ResNetV1b(nn.Module):
|
126 |
+
"""Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
|
127 |
|
128 |
Parameters
|
129 |
----------
|
|
|
151 |
|
152 |
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
|
153 |
"""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
block,
|
158 |
+
layers,
|
159 |
+
classes=1000,
|
160 |
+
dilated=True,
|
161 |
+
deep_stem=False,
|
162 |
+
stem_width=32,
|
163 |
+
avg_down=False,
|
164 |
+
final_drop=0.0,
|
165 |
+
norm_layer=nn.BatchNorm2d,
|
166 |
+
):
|
167 |
+
self.inplanes = stem_width * 2 if deep_stem else 64
|
168 |
super(ResNetV1b, self).__init__()
|
169 |
if not deep_stem:
|
170 |
+
self.conv1 = nn.Conv2d(
|
171 |
+
3, 64, kernel_size=7, stride=2, padding=3, bias=False
|
172 |
+
)
|
173 |
else:
|
174 |
self.conv1 = nn.Sequential(
|
175 |
+
nn.Conv2d(
|
176 |
+
3, stem_width, kernel_size=3, stride=2, padding=1, bias=False
|
177 |
+
),
|
178 |
norm_layer(stem_width),
|
179 |
nn.ReLU(True),
|
180 |
+
nn.Conv2d(
|
181 |
+
stem_width,
|
182 |
+
stem_width,
|
183 |
+
kernel_size=3,
|
184 |
+
stride=1,
|
185 |
+
padding=1,
|
186 |
+
bias=False,
|
187 |
+
),
|
188 |
norm_layer(stem_width),
|
189 |
nn.ReLU(True),
|
190 |
+
nn.Conv2d(
|
191 |
+
stem_width,
|
192 |
+
2 * stem_width,
|
193 |
+
kernel_size=3,
|
194 |
+
stride=1,
|
195 |
+
padding=1,
|
196 |
+
bias=False,
|
197 |
+
),
|
198 |
)
|
199 |
self.bn1 = norm_layer(self.inplanes)
|
200 |
self.relu = nn.ReLU(True)
|
201 |
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
202 |
+
self.layer1 = self._make_layer(
|
203 |
+
block, 64, layers[0], avg_down=avg_down, norm_layer=norm_layer
|
204 |
+
)
|
205 |
+
self.layer2 = self._make_layer(
|
206 |
+
block, 128, layers[1], stride=2, avg_down=avg_down, norm_layer=norm_layer
|
207 |
+
)
|
208 |
if dilated:
|
209 |
+
self.layer3 = self._make_layer(
|
210 |
+
block,
|
211 |
+
256,
|
212 |
+
layers[2],
|
213 |
+
stride=1,
|
214 |
+
dilation=2,
|
215 |
+
avg_down=avg_down,
|
216 |
+
norm_layer=norm_layer,
|
217 |
+
)
|
218 |
+
self.layer4 = self._make_layer(
|
219 |
+
block,
|
220 |
+
512,
|
221 |
+
layers[3],
|
222 |
+
stride=1,
|
223 |
+
dilation=4,
|
224 |
+
avg_down=avg_down,
|
225 |
+
norm_layer=norm_layer,
|
226 |
+
)
|
227 |
else:
|
228 |
+
self.layer3 = self._make_layer(
|
229 |
+
block,
|
230 |
+
256,
|
231 |
+
layers[2],
|
232 |
+
stride=2,
|
233 |
+
avg_down=avg_down,
|
234 |
+
norm_layer=norm_layer,
|
235 |
+
)
|
236 |
+
self.layer4 = self._make_layer(
|
237 |
+
block,
|
238 |
+
512,
|
239 |
+
layers[3],
|
240 |
+
stride=2,
|
241 |
+
avg_down=avg_down,
|
242 |
+
norm_layer=norm_layer,
|
243 |
+
)
|
244 |
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
245 |
self.drop = None
|
246 |
if final_drop > 0.0:
|
247 |
self.drop = nn.Dropout(final_drop)
|
248 |
self.fc = nn.Linear(512 * block.expansion, classes)
|
249 |
|
250 |
+
def _make_layer(
|
251 |
+
self,
|
252 |
+
block,
|
253 |
+
planes,
|
254 |
+
blocks,
|
255 |
+
stride=1,
|
256 |
+
dilation=1,
|
257 |
+
avg_down=False,
|
258 |
+
norm_layer=nn.BatchNorm2d,
|
259 |
+
):
|
260 |
downsample = None
|
261 |
if stride != 1 or self.inplanes != planes * block.expansion:
|
262 |
downsample = []
|
263 |
if avg_down:
|
264 |
if dilation == 1:
|
265 |
downsample.append(
|
266 |
+
nn.AvgPool2d(
|
267 |
+
kernel_size=stride,
|
268 |
+
stride=stride,
|
269 |
+
ceil_mode=True,
|
270 |
+
count_include_pad=False,
|
271 |
+
)
|
272 |
)
|
273 |
else:
|
274 |
downsample.append(
|
275 |
+
nn.AvgPool2d(
|
276 |
+
kernel_size=1,
|
277 |
+
stride=1,
|
278 |
+
ceil_mode=True,
|
279 |
+
count_include_pad=False,
|
280 |
+
)
|
281 |
)
|
282 |
+
downsample.extend(
|
283 |
+
[
|
284 |
+
nn.Conv2d(
|
285 |
+
self.inplanes,
|
286 |
+
out_channels=planes * block.expansion,
|
287 |
+
kernel_size=1,
|
288 |
+
stride=1,
|
289 |
+
bias=False,
|
290 |
+
),
|
291 |
+
norm_layer(planes * block.expansion),
|
292 |
+
]
|
293 |
+
)
|
294 |
downsample = nn.Sequential(*downsample)
|
295 |
else:
|
296 |
downsample = nn.Sequential(
|
297 |
+
nn.Conv2d(
|
298 |
+
self.inplanes,
|
299 |
+
out_channels=planes * block.expansion,
|
300 |
+
kernel_size=1,
|
301 |
+
stride=stride,
|
302 |
+
bias=False,
|
303 |
+
),
|
304 |
+
norm_layer(planes * block.expansion),
|
305 |
)
|
306 |
|
307 |
layers = []
|
308 |
if dilation in (1, 2):
|
309 |
+
layers.append(
|
310 |
+
block(
|
311 |
+
self.inplanes,
|
312 |
+
planes,
|
313 |
+
stride,
|
314 |
+
dilation=1,
|
315 |
+
downsample=downsample,
|
316 |
+
previous_dilation=dilation,
|
317 |
+
norm_layer=norm_layer,
|
318 |
+
)
|
319 |
+
)
|
320 |
elif dilation == 4:
|
321 |
+
layers.append(
|
322 |
+
block(
|
323 |
+
self.inplanes,
|
324 |
+
planes,
|
325 |
+
stride,
|
326 |
+
dilation=2,
|
327 |
+
downsample=downsample,
|
328 |
+
previous_dilation=dilation,
|
329 |
+
norm_layer=norm_layer,
|
330 |
+
)
|
331 |
+
)
|
332 |
else:
|
333 |
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
|
334 |
|
335 |
self.inplanes = planes * block.expansion
|
336 |
for _ in range(1, blocks):
|
337 |
+
layers.append(
|
338 |
+
block(
|
339 |
+
self.inplanes,
|
340 |
+
planes,
|
341 |
+
dilation=dilation,
|
342 |
+
previous_dilation=dilation,
|
343 |
+
norm_layer=norm_layer,
|
344 |
+
)
|
345 |
+
)
|
346 |
|
347 |
return nn.Sequential(*layers)
|
348 |
|
|
|
381 |
if pretrained:
|
382 |
model_dict = model.state_dict()
|
383 |
filtered_orig_dict = _safe_state_dict_filtering(
|
384 |
+
torch.hub.load(
|
385 |
+
GLUON_RESNET_TORCH_HUB, "gluon_resnet34_v1b", pretrained=True
|
386 |
+
).state_dict(),
|
387 |
+
model_dict.keys(),
|
388 |
)
|
389 |
model_dict.update(filtered_orig_dict)
|
390 |
model.load_state_dict(model_dict)
|
|
|
392 |
|
393 |
|
394 |
def resnet50_v1s(pretrained=False, **kwargs):
|
395 |
+
model = ResNetV1b(
|
396 |
+
BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs
|
397 |
+
)
|
398 |
if pretrained:
|
399 |
model_dict = model.state_dict()
|
400 |
filtered_orig_dict = _safe_state_dict_filtering(
|
401 |
+
torch.hub.load(
|
402 |
+
GLUON_RESNET_TORCH_HUB, "gluon_resnet50_v1s", pretrained=True
|
403 |
+
).state_dict(),
|
404 |
+
model_dict.keys(),
|
405 |
)
|
406 |
model_dict.update(filtered_orig_dict)
|
407 |
model.load_state_dict(model_dict)
|
|
|
409 |
|
410 |
|
411 |
def resnet101_v1s(pretrained=False, **kwargs):
|
412 |
+
model = ResNetV1b(
|
413 |
+
BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs
|
414 |
+
)
|
415 |
if pretrained:
|
416 |
model_dict = model.state_dict()
|
417 |
filtered_orig_dict = _safe_state_dict_filtering(
|
418 |
+
torch.hub.load(
|
419 |
+
GLUON_RESNET_TORCH_HUB, "gluon_resnet101_v1s", pretrained=True
|
420 |
+
).state_dict(),
|
421 |
+
model_dict.keys(),
|
422 |
)
|
423 |
model_dict.update(filtered_orig_dict)
|
424 |
model.load_state_dict(model_dict)
|
|
|
426 |
|
427 |
|
428 |
def resnet152_v1s(pretrained=False, **kwargs):
|
429 |
+
model = ResNetV1b(
|
430 |
+
BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs
|
431 |
+
)
|
432 |
if pretrained:
|
433 |
model_dict = model.state_dict()
|
434 |
filtered_orig_dict = _safe_state_dict_filtering(
|
435 |
+
torch.hub.load(
|
436 |
+
GLUON_RESNET_TORCH_HUB, "gluon_resnet152_v1s", pretrained=True
|
437 |
+
).state_dict(),
|
438 |
+
model_dict.keys(),
|
439 |
)
|
440 |
model_dict.update(filtered_orig_dict)
|
441 |
model.load_state_dict(model_dict)
|
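A quick sanity check for the V1s constructors above; with pretrained=True the weights would come from the rwightman/pytorch-pretrained-gluonresnet torch.hub repo referenced in the code.

import torch
from isegm.model.modeling.resnetv1b import resnet50_v1s

model = resnet50_v1s(pretrained=False, dilated=True)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # 1000-way classification head by default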
isegm/model/modifiers.py
CHANGED
@@ -1,11 +1,9 @@
class LRMult(object):
    def __init__(self, lr_mult=1.0):
        self.lr_mult = lr_mult

    def __call__(self, m):
        if getattr(m, "weight", None) is not None:
            m.weight.lr_mult = self.lr_mult
        if getattr(m, "bias", None) is not None:
            m.bias.lr_mult = self.lr_mult
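LRMult only tags parameters with an lr_mult attribute; an optimizer builder can later read that attribute when forming parameter groups. A small sketch:

import torch.nn as nn
from isegm.model.modifiers import LRMult

backbone = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
backbone.apply(LRMult(0.1))          # every module with weight/bias gets .lr_mult = 0.1
print(backbone[0].weight.lr_mult)    # 0.1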
isegm/model/ops.py
CHANGED
@@ -1,14 +1,15 @@
import numpy as np
import torch
from torch import nn as nn

import isegm.model.initializer as initializer


def select_activation_function(activation):
    if isinstance(activation, str):
        if activation.lower() == "relu":
            return nn.ReLU
        elif activation.lower() == "softplus":
            return nn.Softplus
        else:
            raise ValueError(f"Unknown activation type {activation}")
@@ -24,14 +25,18 @@ class BilinearConvTranspose2d(nn.ConvTranspose2d):
        self.scale = scale

        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale,
            padding=1,
            groups=groups,
            bias=False,
        )

        self.apply(
            initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)
        )


class DistMaps(nn.Module):
@@ -43,29 +48,47 @@ class DistMaps(nn.Module):
        self.use_disks = use_disks
        if self.cpu_mode:
            from isegm.utils.cython import get_dist_maps

            self._get_dist_maps = get_dist_maps

    def get_coord_features(self, points, batchsize, rows, cols):
        if self.cpu_mode:
            coords = []
            for i in range(batchsize):
                norm_delimeter = (
                    1.0 if self.use_disks else self.spatial_scale * self.norm_radius
                )
                coords.append(
                    self._get_dist_maps(
                        points[i].cpu().float().numpy(), rows, cols, norm_delimeter
                    )
                )
            coords = (
                torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
            )
        else:
            num_points = points.shape[1] // 2
            points = points.view(-1, points.size(2))
            points, points_order = torch.split(points, [2, 1], dim=1)

            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0
            row_array = torch.arange(
                start=0, end=rows, step=1, dtype=torch.float32, device=points.device
            )
            col_array = torch.arange(
                start=0, end=cols, step=1, dtype=torch.float32, device=points.device
            )

            coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
            coords = (
                torch.stack((coord_rows, coord_cols), dim=0)
                .unsqueeze(0)
                .repeat(points.size(0), 1, 1, 1)
            )

            add_xy = (points * self.spatial_scale).view(
                points.size(0), points.size(1), 1, 1
            )
            coords.add_(-add_xy)
            if not self.use_disks:
                coords.div_(self.norm_radius * self.spatial_scale)
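An illustrative call of DistMaps above; the constructor arguments and the (row, col, order) click layout are assumptions based on the code shown, not documented API.

import torch
from isegm.model.ops import DistMaps, select_activation_function

act_cls = select_activation_function("relu")          # -> nn.ReLU

dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0, cpu_mode=False, use_disks=True)
# One positive click at (row=10, col=15); the second slot is an unused placeholder (negative coords).
points = torch.tensor([[[10.0, 15.0, 0.0], [-1.0, -1.0, -1.0]]])
coord_features = dist_maps.get_coord_features(points, batchsize=1, rows=64, cols=64)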
isegm/utils/cython/__init__.py
CHANGED
@@ -1,2 +1,2 @@
# noinspection PyUnresolvedReferences
from .dist_maps import get_dist_maps
isegm/utils/cython/_get_dist_maps.pyx
CHANGED
@@ -1,7 +1,8 @@
import numpy as np

cimport cython
cimport numpy as np
from libc.stdlib cimport free, malloc

ctypedef struct qnode:
    int row
isegm/utils/cython/dist_maps.py
CHANGED
@@ -1,3 +1,5 @@
import pyximport

pyximport.install(pyximport=True, language_level=3)
# noinspection PyUnresolvedReferences
from ._get_dist_maps import get_dist_maps
isegm/utils/distributed.py
CHANGED
@@ -10,7 +10,11 @@ def get_rank():


def synchronize():
    if (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_world_size() == 1
    ):
        return
    dist.barrier()

@@ -58,10 +62,15 @@ def get_sampler(dataset, shuffle, distributed):


def get_dp_wrapper(distributed):
    class DPWrapper(
        torch.nn.parallel.DistributedDataParallel
        if distributed
        else torch.nn.DataParallel
    ):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    return DPWrapper
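A small sketch of the helpers above: the wrapper class is picked at call time, so single-process DataParallel and multi-process DistributedDataParallel share one code path.

import torch.nn as nn
from isegm.utils.distributed import get_dp_wrapper, synchronize

model = nn.Linear(8, 2)
Wrapper = get_dp_wrapper(distributed=False)   # subclass of torch.nn.DataParallel here
model = Wrapper(model)
synchronize()   # no-op unless torch.distributed is initialized with world_size > 1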