aryanxxvii commited on
Commit
330b9b4
·
1 Parent(s): 87e60c4
Files changed (4) hide show
  1. data_transforms.py +266 -0
  2. push_to_hf.py +9 -0
  3. requirements.txt +7 -0
  4. u2net.py +610 -0
data_transforms.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data loader
2
+ from __future__ import print_function, division
3
+ import glob
4
+ import torch
5
+ from skimage import io, transform, color
6
+ import numpy as np
7
+ import random
8
+ import math
9
+ import matplotlib.pyplot as plt
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms, utils
12
+ from PIL import Image
13
+
14
+ #==========================dataset load==========================
15
+ class RescaleT(object):
16
+
17
+ def __init__(self,output_size):
18
+ assert isinstance(output_size,(int,tuple))
19
+ self.output_size = output_size
20
+
21
+ def __call__(self,sample):
22
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
23
+
24
+ h, w = image.shape[:2]
25
+
26
+ if isinstance(self.output_size,int):
27
+ if h > w:
28
+ new_h, new_w = self.output_size*h/w,self.output_size
29
+ else:
30
+ new_h, new_w = self.output_size,self.output_size*w/h
31
+ else:
32
+ new_h, new_w = self.output_size
33
+
34
+ new_h, new_w = int(new_h), int(new_w)
35
+
36
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
37
+ # img = transform.resize(image,(new_h,new_w),mode='constant')
38
+ # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
39
+
40
+ img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
41
+ lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
42
+
43
+ return {'imidx':imidx, 'image':img,'label':lbl}
44
+
45
+ class Rescale(object):
46
+
47
+ def __init__(self,output_size):
48
+ assert isinstance(output_size,(int,tuple))
49
+ self.output_size = output_size
50
+
51
+ def __call__(self,sample):
52
+ imidx, image, label = sample['imidx'], sample['image'],sample['label']
53
+
54
+ if random.random() >= 0.5:
55
+ image = image[::-1]
56
+ label = label[::-1]
57
+
58
+ h, w = image.shape[:2]
59
+
60
+ if isinstance(self.output_size,int):
61
+ if h > w:
62
+ new_h, new_w = self.output_size*h/w,self.output_size
63
+ else:
64
+ new_h, new_w = self.output_size,self.output_size*w/h
65
+ else:
66
+ new_h, new_w = self.output_size
67
+
68
+ new_h, new_w = int(new_h), int(new_w)
69
+
70
+ # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
71
+ img = transform.resize(image,(new_h,new_w),mode='constant')
72
+ lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
73
+
74
+ return {'imidx':imidx, 'image':img,'label':lbl}
75
+
76
+ class RandomCrop(object):
77
+
78
+ def __init__(self,output_size):
79
+ assert isinstance(output_size, (int, tuple))
80
+ if isinstance(output_size, int):
81
+ self.output_size = (output_size, output_size)
82
+ else:
83
+ assert len(output_size) == 2
84
+ self.output_size = output_size
85
+ def __call__(self,sample):
86
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
87
+
88
+ if random.random() >= 0.5:
89
+ image = image[::-1]
90
+ label = label[::-1]
91
+
92
+ h, w = image.shape[:2]
93
+ new_h, new_w = self.output_size
94
+
95
+ top = np.random.randint(0, h - new_h)
96
+ left = np.random.randint(0, w - new_w)
97
+
98
+ image = image[top: top + new_h, left: left + new_w]
99
+ label = label[top: top + new_h, left: left + new_w]
100
+
101
+ return {'imidx':imidx,'image':image, 'label':label}
102
+
103
+ class ToTensor(object):
104
+ """Convert ndarrays in sample to Tensors."""
105
+
106
+ def __call__(self, sample):
107
+
108
+ imidx, image, label = sample['imidx'], sample['image'], sample['label']
109
+
110
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
111
+ tmpLbl = np.zeros(label.shape)
112
+
113
+ image = image/np.max(image)
114
+ if(np.max(label)<1e-6):
115
+ label = label
116
+ else:
117
+ label = label/np.max(label)
118
+
119
+ if image.shape[2]==1:
120
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
121
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
122
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
123
+ else:
124
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
125
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
126
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
127
+
128
+ tmpLbl[:,:,0] = label[:,:,0]
129
+
130
+
131
+ tmpImg = tmpImg.transpose((2, 0, 1))
132
+ tmpLbl = label.transpose((2, 0, 1))
133
+
134
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
135
+
136
+ class ToTensorLab(object):
137
+ """Convert ndarrays in sample to Tensors."""
138
+ def __init__(self,flag=0):
139
+ self.flag = flag
140
+
141
+ def __call__(self, sample):
142
+
143
+ imidx, image, label =sample['imidx'], sample['image'], sample['label']
144
+
145
+ tmpLbl = np.zeros(label.shape)
146
+
147
+ if(np.max(label)<1e-6):
148
+ label = label
149
+ else:
150
+ label = label/np.max(label)
151
+
152
+ # change the color space
153
+ if self.flag == 2: # with rgb and Lab colors
154
+ tmpImg = np.zeros((image.shape[0],image.shape[1],6))
155
+ tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
156
+ if image.shape[2]==1:
157
+ tmpImgt[:,:,0] = image[:,:,0]
158
+ tmpImgt[:,:,1] = image[:,:,0]
159
+ tmpImgt[:,:,2] = image[:,:,0]
160
+ else:
161
+ tmpImgt = image
162
+ tmpImgtl = color.rgb2lab(tmpImgt)
163
+
164
+ # nomalize image to range [0,1]
165
+ tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
166
+ tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
167
+ tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
168
+ tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
169
+ tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
170
+ tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
171
+
172
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
173
+
174
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
175
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
176
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
177
+ tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
178
+ tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
179
+ tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
180
+
181
+ elif self.flag == 1: #with Lab color
182
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
183
+
184
+ if image.shape[2]==1:
185
+ tmpImg[:,:,0] = image[:,:,0]
186
+ tmpImg[:,:,1] = image[:,:,0]
187
+ tmpImg[:,:,2] = image[:,:,0]
188
+ else:
189
+ tmpImg = image
190
+
191
+ tmpImg = color.rgb2lab(tmpImg)
192
+
193
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
194
+
195
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
196
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
197
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
198
+
199
+ tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
200
+ tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
201
+ tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
202
+
203
+ else: # with rgb color
204
+ tmpImg = np.zeros((image.shape[0],image.shape[1],3))
205
+ image = image/np.max(image)
206
+ if image.shape[2]==1:
207
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
208
+ tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
209
+ tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
210
+ else:
211
+ tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
212
+ tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
213
+ tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
214
+
215
+ tmpLbl[:,:,0] = label[:,:,0]
216
+
217
+
218
+ tmpImg = tmpImg.transpose((2, 0, 1))
219
+ tmpLbl = label.transpose((2, 0, 1))
220
+
221
+ return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
222
+
223
+ class SalObjDataset(Dataset):
224
+ def __init__(self,img_name_list,lbl_name_list,transform=None):
225
+ # self.root_dir = root_dir
226
+ # self.image_name_list = glob.glob(image_dir+'*.png')
227
+ # self.label_name_list = glob.glob(label_dir+'*.png')
228
+ self.image_name_list = img_name_list
229
+ self.label_name_list = lbl_name_list
230
+ self.transform = transform
231
+
232
+ def __len__(self):
233
+ return len(self.image_name_list)
234
+
235
+ def __getitem__(self,idx):
236
+
237
+ # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
238
+ # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
239
+
240
+ image = io.imread(self.image_name_list[idx])
241
+ imname = self.image_name_list[idx]
242
+ imidx = np.array([idx])
243
+
244
+ if(0==len(self.label_name_list)):
245
+ label_3 = np.zeros(image.shape)
246
+ else:
247
+ label_3 = io.imread(self.label_name_list[idx])
248
+
249
+ label = np.zeros(label_3.shape[0:2])
250
+ if(3==len(label_3.shape)):
251
+ label = label_3[:,:,0]
252
+ elif(2==len(label_3.shape)):
253
+ label = label_3
254
+
255
+ if(3==len(image.shape) and 2==len(label.shape)):
256
+ label = label[:,:,np.newaxis]
257
+ elif(2==len(image.shape) and 2==len(label.shape)):
258
+ image = image[:,:,np.newaxis]
259
+ label = label[:,:,np.newaxis]
260
+
261
+ sample = {'imidx':imidx, 'image':image, 'label':label}
262
+
263
+ if self.transform:
264
+ sample = self.transform(sample)
265
+
266
+ return sample
push_to_hf.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import Repository
2
+
3
+ # Replace with your Hugging Face username and the repository name you created
4
+ repo = Repository(local_dir="clearbg", clone_from="totoshi/clearbg")
5
+
6
+ # Add all files and push to Hugging Face
7
+ # repo.git_add()
8
+ # repo.git_commit("Upload U2NET model for background removal")
9
+ # repo.git_push()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ pillow
5
+ matplotlib
6
+ scikit-image
7
+ huggingface-hub
u2net.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+
6
+
7
+ bce_loss = nn.BCELoss(size_average=True)
8
+ def muti_loss_fusion(preds, target):
9
+ loss0 = 0.0
10
+ loss = 0.0
11
+
12
+ for i in range(0,len(preds)):
13
+ # print("i: ", i, preds[i].shape)
14
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
15
+ # tmp_target = _upsample_like(target,preds[i])
16
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
17
+ loss = loss + bce_loss(preds[i],tmp_target)
18
+ else:
19
+ loss = loss + bce_loss(preds[i],target)
20
+ if(i==0):
21
+ loss0 = loss
22
+ return loss0, loss
23
+
24
+ fea_loss = nn.MSELoss(size_average=True)
25
+ kl_loss = nn.KLDivLoss(size_average=True)
26
+ l1_loss = nn.L1Loss(size_average=True)
27
+ smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
28
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
29
+ loss0 = 0.0
30
+ loss = 0.0
31
+
32
+ for i in range(0,len(preds)):
33
+ # print("i: ", i, preds[i].shape)
34
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
35
+ # tmp_target = _upsample_like(target,preds[i])
36
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
37
+ loss = loss + bce_loss(preds[i],tmp_target)
38
+ else:
39
+ loss = loss + bce_loss(preds[i],target)
40
+ if(i==0):
41
+ loss0 = loss
42
+
43
+ for i in range(0,len(dfs)):
44
+ if(mode=='MSE'):
45
+ loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
46
+ # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
47
+ elif(mode=='KL'):
48
+ loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
49
+ # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
50
+ elif(mode=='MAE'):
51
+ loss = loss + l1_loss(dfs[i],fs[i])
52
+ # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
53
+ elif(mode=='SmoothL1'):
54
+ loss = loss + smooth_l1_loss(dfs[i],fs[i])
55
+ # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
56
+
57
+ return loss0, loss
58
+
59
+ class REBNCONV(nn.Module):
60
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
61
+ super(REBNCONV,self).__init__()
62
+
63
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
64
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
65
+ self.relu_s1 = nn.ReLU(inplace=True)
66
+
67
+ def forward(self,x):
68
+
69
+ hx = x
70
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
71
+
72
+ return xout
73
+
74
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
75
+ def _upsample_like(src,tar):
76
+
77
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
78
+
79
+ return src
80
+
81
+
82
+ ### RSU-7 ###
83
+ class RSU7(nn.Module):
84
+
85
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
86
+ super(RSU7,self).__init__()
87
+
88
+ self.in_ch = in_ch
89
+ self.mid_ch = mid_ch
90
+ self.out_ch = out_ch
91
+
92
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
93
+
94
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
95
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
96
+
97
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
98
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
99
+
100
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
101
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
102
+
103
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
104
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
105
+
106
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
107
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
108
+
109
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
110
+
111
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
112
+
113
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
114
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
115
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
116
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
117
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
118
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
119
+
120
+ def forward(self,x):
121
+ b, c, h, w = x.shape
122
+
123
+ hx = x
124
+ hxin = self.rebnconvin(hx)
125
+
126
+ hx1 = self.rebnconv1(hxin)
127
+ hx = self.pool1(hx1)
128
+
129
+ hx2 = self.rebnconv2(hx)
130
+ hx = self.pool2(hx2)
131
+
132
+ hx3 = self.rebnconv3(hx)
133
+ hx = self.pool3(hx3)
134
+
135
+ hx4 = self.rebnconv4(hx)
136
+ hx = self.pool4(hx4)
137
+
138
+ hx5 = self.rebnconv5(hx)
139
+ hx = self.pool5(hx5)
140
+
141
+ hx6 = self.rebnconv6(hx)
142
+
143
+ hx7 = self.rebnconv7(hx6)
144
+
145
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
146
+ hx6dup = _upsample_like(hx6d,hx5)
147
+
148
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
149
+ hx5dup = _upsample_like(hx5d,hx4)
150
+
151
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
152
+ hx4dup = _upsample_like(hx4d,hx3)
153
+
154
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
155
+ hx3dup = _upsample_like(hx3d,hx2)
156
+
157
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
158
+ hx2dup = _upsample_like(hx2d,hx1)
159
+
160
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
161
+
162
+ return hx1d + hxin
163
+
164
+
165
+ ### RSU-6 ###
166
+ class RSU6(nn.Module):
167
+
168
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
169
+ super(RSU6,self).__init__()
170
+
171
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
172
+
173
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
174
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
175
+
176
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
177
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
178
+
179
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
180
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
181
+
182
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
183
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+
187
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
188
+
189
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
190
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
191
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
192
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
193
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
194
+
195
+ def forward(self,x):
196
+
197
+ hx = x
198
+
199
+ hxin = self.rebnconvin(hx)
200
+
201
+ hx1 = self.rebnconv1(hxin)
202
+ hx = self.pool1(hx1)
203
+
204
+ hx2 = self.rebnconv2(hx)
205
+ hx = self.pool2(hx2)
206
+
207
+ hx3 = self.rebnconv3(hx)
208
+ hx = self.pool3(hx3)
209
+
210
+ hx4 = self.rebnconv4(hx)
211
+ hx = self.pool4(hx4)
212
+
213
+ hx5 = self.rebnconv5(hx)
214
+
215
+ hx6 = self.rebnconv6(hx5)
216
+
217
+
218
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
219
+ hx5dup = _upsample_like(hx5d,hx4)
220
+
221
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
222
+ hx4dup = _upsample_like(hx4d,hx3)
223
+
224
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
225
+ hx3dup = _upsample_like(hx3d,hx2)
226
+
227
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
228
+ hx2dup = _upsample_like(hx2d,hx1)
229
+
230
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
231
+
232
+ return hx1d + hxin
233
+
234
+ ### RSU-5 ###
235
+ class RSU5(nn.Module):
236
+
237
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
238
+ super(RSU5,self).__init__()
239
+
240
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
241
+
242
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
243
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
244
+
245
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
246
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
247
+
248
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
249
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
250
+
251
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
252
+
253
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
254
+
255
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
256
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
+
260
+ def forward(self,x):
261
+
262
+ hx = x
263
+
264
+ hxin = self.rebnconvin(hx)
265
+
266
+ hx1 = self.rebnconv1(hxin)
267
+ hx = self.pool1(hx1)
268
+
269
+ hx2 = self.rebnconv2(hx)
270
+ hx = self.pool2(hx2)
271
+
272
+ hx3 = self.rebnconv3(hx)
273
+ hx = self.pool3(hx3)
274
+
275
+ hx4 = self.rebnconv4(hx)
276
+
277
+ hx5 = self.rebnconv5(hx4)
278
+
279
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
280
+ hx4dup = _upsample_like(hx4d,hx3)
281
+
282
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
283
+ hx3dup = _upsample_like(hx3d,hx2)
284
+
285
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
286
+ hx2dup = _upsample_like(hx2d,hx1)
287
+
288
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
289
+
290
+ return hx1d + hxin
291
+
292
+ ### RSU-4 ###
293
+ class RSU4(nn.Module):
294
+
295
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
296
+ super(RSU4,self).__init__()
297
+
298
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
299
+
300
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
301
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
302
+
303
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
304
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
305
+
306
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
307
+
308
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
309
+
310
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
311
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
312
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
313
+
314
+ def forward(self,x):
315
+
316
+ hx = x
317
+
318
+ hxin = self.rebnconvin(hx)
319
+
320
+ hx1 = self.rebnconv1(hxin)
321
+ hx = self.pool1(hx1)
322
+
323
+ hx2 = self.rebnconv2(hx)
324
+ hx = self.pool2(hx2)
325
+
326
+ hx3 = self.rebnconv3(hx)
327
+
328
+ hx4 = self.rebnconv4(hx3)
329
+
330
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
331
+ hx3dup = _upsample_like(hx3d,hx2)
332
+
333
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
334
+ hx2dup = _upsample_like(hx2d,hx1)
335
+
336
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
337
+
338
+ return hx1d + hxin
339
+
340
+ ### RSU-4F ###
341
+ class RSU4F(nn.Module):
342
+
343
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
344
+ super(RSU4F,self).__init__()
345
+
346
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
347
+
348
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
349
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
350
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
351
+
352
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
353
+
354
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
355
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
356
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
357
+
358
+ def forward(self,x):
359
+
360
+ hx = x
361
+
362
+ hxin = self.rebnconvin(hx)
363
+
364
+ hx1 = self.rebnconv1(hxin)
365
+ hx2 = self.rebnconv2(hx1)
366
+ hx3 = self.rebnconv3(hx2)
367
+
368
+ hx4 = self.rebnconv4(hx3)
369
+
370
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
371
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
372
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
373
+
374
+ return hx1d + hxin
375
+
376
+
377
+ class myrebnconv(nn.Module):
378
+ def __init__(self, in_ch=3,
379
+ out_ch=1,
380
+ kernel_size=3,
381
+ stride=1,
382
+ padding=1,
383
+ dilation=1,
384
+ groups=1):
385
+ super(myrebnconv,self).__init__()
386
+
387
+ self.conv = nn.Conv2d(in_ch,
388
+ out_ch,
389
+ kernel_size=kernel_size,
390
+ stride=stride,
391
+ padding=padding,
392
+ dilation=dilation,
393
+ groups=groups)
394
+ self.bn = nn.BatchNorm2d(out_ch)
395
+ self.rl = nn.ReLU(inplace=True)
396
+
397
+ def forward(self,x):
398
+ return self.rl(self.bn(self.conv(x)))
399
+
400
+
401
+ class U2NetGTEncoder(nn.Module):
402
+
403
+ def __init__(self,in_ch=1,out_ch=1):
404
+ super(U2NetGTEncoder,self).__init__()
405
+
406
+ self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
407
+
408
+ self.stage1 = RSU7(16,16,64)
409
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
410
+
411
+ self.stage2 = RSU6(64,16,64)
412
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
413
+
414
+ self.stage3 = RSU5(64,32,128)
415
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
416
+
417
+ self.stage4 = RSU4(128,32,256)
418
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
419
+
420
+ self.stage5 = RSU4F(256,64,512)
421
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
422
+
423
+ self.stage6 = RSU4F(512,64,512)
424
+
425
+
426
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
427
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
428
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
429
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
430
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
431
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
432
+
433
+ def compute_loss(self, preds, targets):
434
+
435
+ return muti_loss_fusion(preds,targets)
436
+
437
+ def forward(self,x):
438
+
439
+ hx = x
440
+
441
+ hxin = self.conv_in(hx)
442
+ # hx = self.pool_in(hxin)
443
+
444
+ #stage 1
445
+ hx1 = self.stage1(hxin)
446
+ hx = self.pool12(hx1)
447
+
448
+ #stage 2
449
+ hx2 = self.stage2(hx)
450
+ hx = self.pool23(hx2)
451
+
452
+ #stage 3
453
+ hx3 = self.stage3(hx)
454
+ hx = self.pool34(hx3)
455
+
456
+ #stage 4
457
+ hx4 = self.stage4(hx)
458
+ hx = self.pool45(hx4)
459
+
460
+ #stage 5
461
+ hx5 = self.stage5(hx)
462
+ hx = self.pool56(hx5)
463
+
464
+ #stage 6
465
+ hx6 = self.stage6(hx)
466
+
467
+
468
+ #side output
469
+ d1 = self.side1(hx1)
470
+ d1 = _upsample_like(d1,x)
471
+
472
+ d2 = self.side2(hx2)
473
+ d2 = _upsample_like(d2,x)
474
+
475
+ d3 = self.side3(hx3)
476
+ d3 = _upsample_like(d3,x)
477
+
478
+ d4 = self.side4(hx4)
479
+ d4 = _upsample_like(d4,x)
480
+
481
+ d5 = self.side5(hx5)
482
+ d5 = _upsample_like(d5,x)
483
+
484
+ d6 = self.side6(hx6)
485
+ d6 = _upsample_like(d6,x)
486
+
487
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
488
+
489
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
490
+
491
+ class U2NET(nn.Module):
492
+
493
+ def __init__(self,in_ch=3,out_ch=1):
494
+ super(U2NET,self).__init__()
495
+
496
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
497
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
498
+
499
+ self.stage1 = RSU7(64,32,64)
500
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
501
+
502
+ self.stage2 = RSU6(64,32,128)
503
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
504
+
505
+ self.stage3 = RSU5(128,64,256)
506
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
507
+
508
+ self.stage4 = RSU4(256,128,512)
509
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
510
+
511
+ self.stage5 = RSU4F(512,256,512)
512
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
513
+
514
+ self.stage6 = RSU4F(512,256,512)
515
+
516
+ # decoder
517
+ self.stage5d = RSU4F(1024,256,512)
518
+ self.stage4d = RSU4(1024,128,256)
519
+ self.stage3d = RSU5(512,64,128)
520
+ self.stage2d = RSU6(256,32,64)
521
+ self.stage1d = RSU7(128,16,64)
522
+
523
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
524
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
525
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
526
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
527
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
528
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
529
+
530
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
531
+
532
+ def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
533
+
534
+ # return muti_loss_fusion(preds,targets)
535
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
536
+
537
+ def compute_loss(self, preds, targets):
538
+
539
+ # return muti_loss_fusion(preds,targets)
540
+ return muti_loss_fusion(preds, targets)
541
+
542
+ def forward(self,x):
543
+
544
+ hx = x
545
+
546
+ hxin = self.conv_in(hx)
547
+ #hx = self.pool_in(hxin)
548
+
549
+ #stage 1
550
+ hx1 = self.stage1(hxin)
551
+ hx = self.pool12(hx1)
552
+
553
+ #stage 2
554
+ hx2 = self.stage2(hx)
555
+ hx = self.pool23(hx2)
556
+
557
+ #stage 3
558
+ hx3 = self.stage3(hx)
559
+ hx = self.pool34(hx3)
560
+
561
+ #stage 4
562
+ hx4 = self.stage4(hx)
563
+ hx = self.pool45(hx4)
564
+
565
+ #stage 5
566
+ hx5 = self.stage5(hx)
567
+ hx = self.pool56(hx5)
568
+
569
+ #stage 6
570
+ hx6 = self.stage6(hx)
571
+ hx6up = _upsample_like(hx6,hx5)
572
+
573
+ #-------------------- decoder --------------------
574
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
575
+ hx5dup = _upsample_like(hx5d,hx4)
576
+
577
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
578
+ hx4dup = _upsample_like(hx4d,hx3)
579
+
580
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
581
+ hx3dup = _upsample_like(hx3d,hx2)
582
+
583
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
584
+ hx2dup = _upsample_like(hx2d,hx1)
585
+
586
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
587
+
588
+
589
+ #side output
590
+ d1 = self.side1(hx1d)
591
+ d1 = _upsample_like(d1,x)
592
+
593
+ d2 = self.side2(hx2d)
594
+ d2 = _upsample_like(d2,x)
595
+
596
+ d3 = self.side3(hx3d)
597
+ d3 = _upsample_like(d3,x)
598
+
599
+ d4 = self.side4(hx4d)
600
+ d4 = _upsample_like(d4,x)
601
+
602
+ d5 = self.side5(hx5d)
603
+ d5 = _upsample_like(d5,x)
604
+
605
+ d6 = self.side6(hx6)
606
+ d6 = _upsample_like(d6,x)
607
+
608
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
609
+
610
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]