MaxwellMeyer committed on
Commit cefc11a · verified · 1 Parent(s): c83f50d

Upload 12 files

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ BEN2_demo_pictures/grid_example1.png filter=lfs diff=lfs merge=lfs -text
37
+ BEN2_demo_pictures/grid_example2.png filter=lfs diff=lfs merge=lfs -text
38
+ BEN2_demo_pictures/grid_example3.png filter=lfs diff=lfs merge=lfs -text
39
+ BEN2_demo_pictures/grid_example6.png filter=lfs diff=lfs merge=lfs -text
40
+ BEN2_demo_pictures/grid_example7.png filter=lfs diff=lfs merge=lfs -text
BEN2.py ADDED
@@ -0,0 +1,1377 @@
1
+
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ import torch.utils.checkpoint as checkpoint
8
+ import numpy as np
9
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
10
+ from PIL import Image, ImageOps
11
+ from torchvision import transforms
13
+ import random
14
+ import cv2
15
+ import os
16
+ import subprocess
17
+ import time
18
+ import tempfile
19
+
20
+
21
+
22
+
23
+ def set_random_seed(seed):
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.cuda.manual_seed_all(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+ set_random_seed(9)
32
+
33
+
34
+ torch.set_float32_matmul_precision('highest')
35
+
36
+
37
+
38
+ class Mlp(nn.Module):
39
+ """ Multilayer perceptron."""
40
+
41
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42
+ super().__init__()
43
+ out_features = out_features or in_features
44
+ hidden_features = hidden_features or in_features
45
+ self.fc1 = nn.Linear(in_features, hidden_features)
46
+ self.act = act_layer()
47
+ self.fc2 = nn.Linear(hidden_features, out_features)
48
+ self.drop = nn.Dropout(drop)
49
+
50
+ def forward(self, x):
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.drop(x)
54
+ x = self.fc2(x)
55
+ x = self.drop(x)
56
+ return x
57
+
58
+
59
+ def window_partition(x, window_size):
60
+ """
61
+ Args:
62
+ x: (B, H, W, C)
63
+ window_size (int): window size
64
+ Returns:
65
+ windows: (num_windows*B, window_size, window_size, C)
66
+ """
67
+ B, H, W, C = x.shape
68
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
69
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
70
+ return windows
71
+
72
+
73
+ def window_reverse(windows, window_size, H, W):
74
+ """
75
+ Args:
76
+ windows: (num_windows*B, window_size, window_size, C)
77
+ window_size (int): Window size
78
+ H (int): Height of image
79
+ W (int): Width of image
80
+ Returns:
81
+ x: (B, H, W, C)
82
+ """
83
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
84
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
85
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
86
+ return x
87
+
88
+
89
+ class WindowAttention(nn.Module):
90
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
91
+ It supports both of shifted and non-shifted window.
92
+ Args:
93
+ dim (int): Number of input channels.
94
+ window_size (tuple[int]): The height and width of the window.
95
+ num_heads (int): Number of attention heads.
96
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
97
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
98
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
99
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
100
+ """
101
+
102
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
103
+
104
+ super().__init__()
105
+ self.dim = dim
106
+ self.window_size = window_size # Wh, Ww
107
+ self.num_heads = num_heads
108
+ head_dim = dim // num_heads
109
+ self.scale = qk_scale or head_dim ** -0.5
110
+
111
+ # define a parameter table of relative position bias
112
+ self.relative_position_bias_table = nn.Parameter(
113
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
114
+
115
+ # get pair-wise relative position index for each token inside the window
116
+ coords_h = torch.arange(self.window_size[0])
117
+ coords_w = torch.arange(self.window_size[1])
118
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
119
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
120
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
121
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
122
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
123
+ relative_coords[:, :, 1] += self.window_size[1] - 1
124
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
125
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
126
+ self.register_buffer("relative_position_index", relative_position_index)
127
+
128
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ self.proj = nn.Linear(dim, dim)
131
+ self.proj_drop = nn.Dropout(proj_drop)
132
+
133
+ trunc_normal_(self.relative_position_bias_table, std=.02)
134
+ self.softmax = nn.Softmax(dim=-1)
135
+
136
+ def forward(self, x, mask=None):
137
+ """ Forward function.
138
+ Args:
139
+ x: input features with shape of (num_windows*B, N, C)
140
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
141
+ """
142
+ B_, N, C = x.shape
143
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
144
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
145
+
146
+ q = q * self.scale
147
+ attn = (q @ k.transpose(-2, -1))
148
+
149
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
150
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
151
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
152
+ attn = attn + relative_position_bias.unsqueeze(0)
153
+
154
+ if mask is not None:
155
+ nW = mask.shape[0]
156
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
157
+ attn = attn.view(-1, self.num_heads, N, N)
158
+ attn = self.softmax(attn)
159
+ else:
160
+ attn = self.softmax(attn)
161
+
162
+ attn = self.attn_drop(attn)
163
+
164
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
165
+ x = self.proj(x)
166
+ x = self.proj_drop(x)
167
+ return x
168
+
169
+
170
+ class SwinTransformerBlock(nn.Module):
171
+ """ Swin Transformer Block.
172
+ Args:
173
+ dim (int): Number of input channels.
174
+ num_heads (int): Number of attention heads.
175
+ window_size (int): Window size.
176
+ shift_size (int): Shift size for SW-MSA.
177
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
178
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
179
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
180
+ drop (float, optional): Dropout rate. Default: 0.0
181
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
182
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
183
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
184
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
185
+ """
186
+
187
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
188
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
189
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
190
+ super().__init__()
191
+ self.dim = dim
192
+ self.num_heads = num_heads
193
+ self.window_size = window_size
194
+ self.shift_size = shift_size
195
+ self.mlp_ratio = mlp_ratio
196
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
197
+
198
+ self.norm1 = norm_layer(dim)
199
+ self.attn = WindowAttention(
200
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
201
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202
+
203
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204
+ self.norm2 = norm_layer(dim)
205
+ mlp_hidden_dim = int(dim * mlp_ratio)
206
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
207
+
208
+ self.H = None
209
+ self.W = None
210
+
211
+ def forward(self, x, mask_matrix):
212
+ """ Forward function.
213
+ Args:
214
+ x: Input feature, tensor size (B, H*W, C).
215
+ H, W: Spatial resolution of the input feature.
216
+ mask_matrix: Attention mask for cyclic shift.
217
+ """
218
+ B, L, C = x.shape
219
+ H, W = self.H, self.W
220
+ assert L == H * W, "input feature has wrong size"
221
+
222
+ shortcut = x
223
+ x = self.norm1(x)
224
+ x = x.view(B, H, W, C)
225
+
226
+ # pad feature maps to multiples of window size
227
+ pad_l = pad_t = 0
228
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
229
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
230
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
231
+ _, Hp, Wp, _ = x.shape
232
+
233
+ # cyclic shift
234
+ if self.shift_size > 0:
235
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
236
+ attn_mask = mask_matrix
237
+ else:
238
+ shifted_x = x
239
+ attn_mask = None
240
+
241
+ # partition windows
242
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
243
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
244
+
245
+ # W-MSA/SW-MSA
246
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
247
+
248
+ # merge windows
249
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
250
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
251
+
252
+ # reverse cyclic shift
253
+ if self.shift_size > 0:
254
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
255
+ else:
256
+ x = shifted_x
257
+
258
+ if pad_r > 0 or pad_b > 0:
259
+ x = x[:, :H, :W, :].contiguous()
260
+
261
+ x = x.view(B, H * W, C)
262
+
263
+ # FFN
264
+ x = shortcut + self.drop_path(x)
265
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
266
+
267
+ return x
268
+
269
+
270
+ class PatchMerging(nn.Module):
271
+ """ Patch Merging Layer
272
+ Args:
273
+ dim (int): Number of input channels.
274
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275
+ """
276
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
277
+ super().__init__()
278
+ self.dim = dim
279
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
280
+ self.norm = norm_layer(4 * dim)
281
+
282
+ def forward(self, x, H, W):
283
+ """ Forward function.
284
+ Args:
285
+ x: Input feature, tensor size (B, H*W, C).
286
+ H, W: Spatial resolution of the input feature.
287
+ """
288
+ B, L, C = x.shape
289
+ assert L == H * W, "input feature has wrong size"
290
+
291
+ x = x.view(B, H, W, C)
292
+
293
+ # padding
294
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
295
+ if pad_input:
296
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
297
+
298
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
299
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
300
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
301
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
302
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
303
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
304
+
305
+ x = self.norm(x)
306
+ x = self.reduction(x)
307
+
308
+ return x
309
+
310
+
311
+ class BasicLayer(nn.Module):
312
+ """ A basic Swin Transformer layer for one stage.
313
+ Args:
314
+ dim (int): Number of feature channels
315
+ depth (int): Depths of this stage.
316
+ num_heads (int): Number of attention head.
317
+ window_size (int): Local window size. Default: 7.
318
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
319
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
320
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
321
+ drop (float, optional): Dropout rate. Default: 0.0
322
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
323
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
324
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
325
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
326
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
327
+ """
328
+
329
+ def __init__(self,
330
+ dim,
331
+ depth,
332
+ num_heads,
333
+ window_size=7,
334
+ mlp_ratio=4.,
335
+ qkv_bias=True,
336
+ qk_scale=None,
337
+ drop=0.,
338
+ attn_drop=0.,
339
+ drop_path=0.,
340
+ norm_layer=nn.LayerNorm,
341
+ downsample=None,
342
+ use_checkpoint=False):
343
+ super().__init__()
344
+ self.window_size = window_size
345
+ self.shift_size = window_size // 2
346
+ self.depth = depth
347
+ self.use_checkpoint = use_checkpoint
348
+
349
+ # build blocks
350
+ self.blocks = nn.ModuleList([
351
+ SwinTransformerBlock(
352
+ dim=dim,
353
+ num_heads=num_heads,
354
+ window_size=window_size,
355
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
356
+ mlp_ratio=mlp_ratio,
357
+ qkv_bias=qkv_bias,
358
+ qk_scale=qk_scale,
359
+ drop=drop,
360
+ attn_drop=attn_drop,
361
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
362
+ norm_layer=norm_layer)
363
+ for i in range(depth)])
364
+
365
+ # patch merging layer
366
+ if downsample is not None:
367
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
368
+ else:
369
+ self.downsample = None
370
+
371
+ def forward(self, x, H, W):
372
+ """ Forward function.
373
+ Args:
374
+ x: Input feature, tensor size (B, H*W, C).
375
+ H, W: Spatial resolution of the input feature.
376
+ """
377
+
378
+ # calculate attention mask for SW-MSA
379
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
380
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
381
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
382
+ h_slices = (slice(0, -self.window_size),
383
+ slice(-self.window_size, -self.shift_size),
384
+ slice(-self.shift_size, None))
385
+ w_slices = (slice(0, -self.window_size),
386
+ slice(-self.window_size, -self.shift_size),
387
+ slice(-self.shift_size, None))
388
+ cnt = 0
389
+ for h in h_slices:
390
+ for w in w_slices:
391
+ img_mask[:, h, w, :] = cnt
392
+ cnt += 1
393
+
394
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
395
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
396
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
397
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
398
+
399
+ for blk in self.blocks:
400
+ blk.H, blk.W = H, W
401
+ if self.use_checkpoint:
402
+ x = checkpoint.checkpoint(blk, x, attn_mask)
403
+ else:
404
+ x = blk(x, attn_mask)
405
+ if self.downsample is not None:
406
+ x_down = self.downsample(x, H, W)
407
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
408
+ return x, H, W, x_down, Wh, Ww
409
+ else:
410
+ return x, H, W, x, H, W
411
+
412
+
413
+ class PatchEmbed(nn.Module):
414
+ """ Image to Patch Embedding
415
+ Args:
416
+ patch_size (int): Patch token size. Default: 4.
417
+ in_chans (int): Number of input image channels. Default: 3.
418
+ embed_dim (int): Number of linear projection output channels. Default: 96.
419
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
420
+ """
421
+
422
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
423
+ super().__init__()
424
+ patch_size = to_2tuple(patch_size)
425
+ self.patch_size = patch_size
426
+
427
+ self.in_chans = in_chans
428
+ self.embed_dim = embed_dim
429
+
430
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
431
+ if norm_layer is not None:
432
+ self.norm = norm_layer(embed_dim)
433
+ else:
434
+ self.norm = None
435
+
436
+ def forward(self, x):
437
+ """Forward function."""
438
+ # padding
439
+ _, _, H, W = x.size()
440
+ if W % self.patch_size[1] != 0:
441
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
442
+ if H % self.patch_size[0] != 0:
443
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
444
+
445
+ x = self.proj(x) # B C Wh Ww
446
+ if self.norm is not None:
447
+ Wh, Ww = x.size(2), x.size(3)
448
+ x = x.flatten(2).transpose(1, 2)
449
+ x = self.norm(x)
450
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
451
+
452
+ return x
453
+
454
+
455
+ class SwinTransformer(nn.Module):
456
+ """ Swin Transformer backbone.
457
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
458
+ https://arxiv.org/pdf/2103.14030
459
+ Args:
460
+ pretrain_img_size (int): Input image size for training the pretrained model,
461
+ used in absolute postion embedding. Default 224.
462
+ patch_size (int | tuple(int)): Patch size. Default: 4.
463
+ in_chans (int): Number of input image channels. Default: 3.
464
+ embed_dim (int): Number of linear projection output channels. Default: 96.
465
+ depths (tuple[int]): Depths of each Swin Transformer stage.
466
+ num_heads (tuple[int]): Number of attention head of each stage.
467
+ window_size (int): Window size. Default: 7.
468
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
469
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
470
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
471
+ drop_rate (float): Dropout rate.
472
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
473
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
474
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
475
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
476
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
477
+ out_indices (Sequence[int]): Output from which stages.
478
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
479
+ -1 means not freezing any parameters.
480
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
481
+ """
482
+
483
+ def __init__(self,
484
+ pretrain_img_size=224,
485
+ patch_size=4,
486
+ in_chans=3,
487
+ embed_dim=96,
488
+ depths=[2, 2, 6, 2],
489
+ num_heads=[3, 6, 12, 24],
490
+ window_size=7,
491
+ mlp_ratio=4.,
492
+ qkv_bias=True,
493
+ qk_scale=None,
494
+ drop_rate=0.,
495
+ attn_drop_rate=0.,
496
+ drop_path_rate=0.2,
497
+ norm_layer=nn.LayerNorm,
498
+ ape=False,
499
+ patch_norm=True,
500
+ out_indices=(0, 1, 2, 3),
501
+ frozen_stages=-1,
502
+ use_checkpoint=False):
503
+ super().__init__()
504
+
505
+ self.pretrain_img_size = pretrain_img_size
506
+ self.num_layers = len(depths)
507
+ self.embed_dim = embed_dim
508
+ self.ape = ape
509
+ self.patch_norm = patch_norm
510
+ self.out_indices = out_indices
511
+ self.frozen_stages = frozen_stages
512
+
513
+ # split image into non-overlapping patches
514
+ self.patch_embed = PatchEmbed(
515
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
516
+ norm_layer=norm_layer if self.patch_norm else None)
517
+
518
+ # absolute position embedding
519
+ if self.ape:
520
+ pretrain_img_size = to_2tuple(pretrain_img_size)
521
+ patch_size = to_2tuple(patch_size)
522
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
523
+
524
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
525
+ trunc_normal_(self.absolute_pos_embed, std=.02)
526
+
527
+ self.pos_drop = nn.Dropout(p=drop_rate)
528
+
529
+ # stochastic depth
530
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
531
+
532
+ # build layers
533
+ self.layers = nn.ModuleList()
534
+ for i_layer in range(self.num_layers):
535
+ layer = BasicLayer(
536
+ dim=int(embed_dim * 2 ** i_layer),
537
+ depth=depths[i_layer],
538
+ num_heads=num_heads[i_layer],
539
+ window_size=window_size,
540
+ mlp_ratio=mlp_ratio,
541
+ qkv_bias=qkv_bias,
542
+ qk_scale=qk_scale,
543
+ drop=drop_rate,
544
+ attn_drop=attn_drop_rate,
545
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
546
+ norm_layer=norm_layer,
547
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
548
+ use_checkpoint=use_checkpoint)
549
+ self.layers.append(layer)
550
+
551
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
552
+ self.num_features = num_features
553
+
554
+ # add a norm layer for each output
555
+ for i_layer in out_indices:
556
+ layer = norm_layer(num_features[i_layer])
557
+ layer_name = f'norm{i_layer}'
558
+ self.add_module(layer_name, layer)
559
+
560
+ self._freeze_stages()
561
+
562
+ def _freeze_stages(self):
563
+ if self.frozen_stages >= 0:
564
+ self.patch_embed.eval()
565
+ for param in self.patch_embed.parameters():
566
+ param.requires_grad = False
567
+
568
+ if self.frozen_stages >= 1 and self.ape:
569
+ self.absolute_pos_embed.requires_grad = False
570
+
571
+ if self.frozen_stages >= 2:
572
+ self.pos_drop.eval()
573
+ for i in range(0, self.frozen_stages - 1):
574
+ m = self.layers[i]
575
+ m.eval()
576
+ for param in m.parameters():
577
+ param.requires_grad = False
578
+
579
+
580
+ def forward(self, x):
581
+
582
+ x = self.patch_embed(x)
583
+
584
+ Wh, Ww = x.size(2), x.size(3)
585
+ if self.ape:
586
+ # interpolate the position embedding to the corresponding size
587
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
588
+ x = (x + absolute_pos_embed) # B Wh*Ww C
589
+
590
+ outs = [x.contiguous()]
591
+ x = x.flatten(2).transpose(1, 2)
592
+ x = self.pos_drop(x)
593
+
594
+
595
+ for i in range(self.num_layers):
596
+ layer = self.layers[i]
597
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
598
+
599
+
600
+ if i in self.out_indices:
601
+ norm_layer = getattr(self, f'norm{i}')
602
+ x_out = norm_layer(x_out)
603
+
604
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
605
+ outs.append(out)
606
+
607
+
608
+
609
+ return tuple(outs)
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
+
618
+ def get_activation_fn(activation):
619
+ """Return an activation function given a string"""
620
+ if activation == "gelu":
621
+ return F.gelu
622
+
623
+ raise RuntimeError(F"activation should be gelu, not {activation}.")
624
+
625
+
626
+ def make_cbr(in_dim, out_dim):
627
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
628
+
629
+
630
+ def make_cbg(in_dim, out_dim):
631
+ return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
632
+
633
+
634
+ def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
635
+ return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
636
+
637
+
638
+ def resize_as(x, y, interpolation='bilinear'):
639
+ return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
640
+
641
+
642
+ def image2patches(x):
643
+ """b c (hg h) (wg w) -> (hg wg b) c h w"""
644
+ x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2 )
645
+ return x
646
+
647
+
648
+ def patches2image(x):
649
+ """(hg wg b) c h w -> b c (hg h) (wg w)"""
650
+ x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
651
+ return x
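+ # Note: both helpers assume a fixed 2x2 grid of patches (hg = wg = 2).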
652
+
653
+
654
+
655
+ class PositionEmbeddingSine:
656
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
657
+ super().__init__()
658
+ self.num_pos_feats = num_pos_feats
659
+ self.temperature = temperature
660
+ self.normalize = normalize
661
+ if scale is not None and normalize is False:
662
+ raise ValueError("normalize should be True if scale is passed")
663
+ if scale is None:
664
+ scale = 2 * math.pi
665
+ self.scale = scale
666
+ self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
667
+
668
+ def __call__(self, b, h, w):
669
+ device = self.dim_t.device
670
+ mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
671
+ assert mask is not None
672
+ not_mask = ~mask
673
+ y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
674
+ x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
675
+ if self.normalize:
676
+ eps = 1e-6
677
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
678
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
679
+
680
+ dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
681
+ pos_x = x_embed[:, :, :, None] / dim_t
682
+ pos_y = y_embed[:, :, :, None] / dim_t
683
+
684
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
685
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
686
+
687
+ return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
688
+
689
+
724
+
725
+
726
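+ # MCLM: cross-attention in which the global view (query) attends to multi-scale
+ # pooled local patches (key/value), followed by a pass that refreshes each local
+ # patch from the updated global features.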
+ class MCLM(nn.Module):
727
+ def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
728
+ super(MCLM, self).__init__()
729
+ self.attention = nn.ModuleList([
730
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
731
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
732
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
733
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
734
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
735
+ ])
736
+
737
+ self.linear1 = nn.Linear(d_model, d_model * 2)
738
+ self.linear2 = nn.Linear(d_model * 2, d_model)
739
+ self.linear3 = nn.Linear(d_model, d_model * 2)
740
+ self.linear4 = nn.Linear(d_model * 2, d_model)
741
+ self.norm1 = nn.LayerNorm(d_model)
742
+ self.norm2 = nn.LayerNorm(d_model)
743
+ self.dropout = nn.Dropout(0.1)
744
+ self.dropout1 = nn.Dropout(0.1)
745
+ self.dropout2 = nn.Dropout(0.1)
746
+ self.activation = get_activation_fn('gelu')
747
+ self.pool_ratios = pool_ratios
748
+ self.p_poses = []
749
+ self.g_pos = None
750
+ self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
751
+
752
+ def forward(self, l, g):
753
+ """
754
+ l: 4,c,h,w
755
+ g: 1,c,h,w
756
+ """
757
+ self.p_poses = []
758
+ self.g_pos = None
759
+ b, c, h, w = l.size()
760
+ # 4,c,h,w -> 1,c,2h,2w
761
+ concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
762
+
763
+ pools = []
764
+ for pool_ratio in self.pool_ratios:
765
+ # b,c,h,w
766
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
767
+ pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
768
+ pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
769
+ if self.g_pos is None:
770
+ pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
771
+ pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
772
+ self.p_poses.append(pos_emb)
773
+ pools = torch.cat(pools, 0)
774
+ if self.g_pos is None:
775
+ self.p_poses = torch.cat(self.p_poses, dim=0)
776
+ pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
777
+ self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
778
+
779
+ device = pools.device
780
+ self.p_poses = self.p_poses.to(device)
781
+ self.g_pos = self.g_pos.to(device)
782
+
783
+
784
+ # attention between glb (q) & multi-scale concatenated locs (k, v)
785
+ g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
786
+
787
+
788
+ g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
789
+ g_hw_b_c = self.norm1(g_hw_b_c)
790
+ g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
791
+ g_hw_b_c = self.norm2(g_hw_b_c)
792
+
793
+ # attention between original locs (q) & refreshed glb (k, v)
794
+ l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
795
+ _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
796
+ _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
797
+ outputs_re = []
798
+ for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
799
+ outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
800
+ outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
801
+
802
+ l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
803
+ l_hw_b_c = self.norm1(l_hw_b_c)
804
+ l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
805
+ l_hw_b_c = self.norm2(l_hw_b_c)
806
+
807
+ l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
808
+ return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
809
+
810
+
811
+
812
+
813
+
814
+
815
+
816
+
817
+
818
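+ # MCRM: gates the local patches with a saliency map predicted from the global view,
+ # refines each patch via attention over multi-scale pooled global features, and
+ # feeds the refreshed patches back into the global view.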
+ class MCRM(nn.Module):
819
+ def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
820
+ super(MCRM, self).__init__()
821
+ self.attention = nn.ModuleList([
822
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
823
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
824
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
825
+ nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
826
+ ])
827
+ self.linear3 = nn.Linear(d_model, d_model * 2)
828
+ self.linear4 = nn.Linear(d_model * 2, d_model)
829
+ self.norm1 = nn.LayerNorm(d_model)
830
+ self.norm2 = nn.LayerNorm(d_model)
831
+ self.dropout = nn.Dropout(0.1)
832
+ self.dropout1 = nn.Dropout(0.1)
833
+ self.dropout2 = nn.Dropout(0.1)
834
+ self.sigmoid = nn.Sigmoid()
835
+ self.activation = get_activation_fn('gelu')
836
+ self.sal_conv = nn.Conv2d(d_model, 1, 1)
837
+ self.pool_ratios = pool_ratios
838
+
839
+ def forward(self, x):
840
+ device = x.device
841
+ b, c, h, w = x.size()
842
+ loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
843
+
844
+ patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
845
+
846
+ token_attention_map = self.sigmoid(self.sal_conv(glb))
847
+ token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
848
+ loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
849
+
850
+ pools = []
851
+ for pool_ratio in self.pool_ratios:
852
+ tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
853
+ pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
854
+ pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
855
+
856
+ pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
857
+ loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
858
+
859
+ outputs = []
860
+ for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
861
+ v = pools[i]
862
+ k = v
863
+ outputs.append(self.attention[i](q, k, v)[0])
864
+
865
+ outputs = torch.cat(outputs, 1)
866
+ src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
867
+ src = self.norm1(src)
868
+ src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
869
+ src = self.norm2(src)
870
+ src = src.permute(1, 2, 0).reshape(4, c, h, w) # refreshed loc
+ glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # refreshed glb
872
+
873
+ return torch.cat((src, glb), 0), token_attention_map
874
+
875
+
876
+
877
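+ # BEN_Base: Swin Transformer backbone (embed_dim=128, depths [2, 2, 18, 2], window 12)
+ # with a multi-view decoder (MCLM + MCRM blocks) that predicts a single-channel alpha matte.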
+ class BEN_Base(nn.Module):
878
+ def __init__(self):
879
+ super().__init__()
880
+
881
+ self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
882
+ emb_dim = 128
883
+ self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
884
+ self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
885
+ self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
886
+ self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
887
+ self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
888
+
889
+ self.output5 = make_cbr(1024, emb_dim)
890
+ self.output4 = make_cbr(512, emb_dim)
891
+ self.output3 = make_cbr(256, emb_dim)
892
+ self.output2 = make_cbr(128, emb_dim)
893
+ self.output1 = make_cbr(128, emb_dim)
894
+
895
+ self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
896
+ self.conv1 = make_cbr(emb_dim, emb_dim)
897
+ self.conv2 = make_cbr(emb_dim, emb_dim)
898
+ self.conv3 = make_cbr(emb_dim, emb_dim)
899
+ self.conv4 = make_cbr(emb_dim, emb_dim)
900
+ self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
901
+ self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
902
+ self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
903
+ self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
904
+
905
+ self.insmask_head = nn.Sequential(
906
+ nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
907
+ nn.InstanceNorm2d(384),
908
+ nn.GELU(),
909
+ nn.Conv2d(384, 384, kernel_size=3, padding=1),
910
+ nn.InstanceNorm2d(384),
911
+ nn.GELU(),
912
+ nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
913
+ )
914
+
915
+ self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
916
+ self.upsample1 = make_cbg(emb_dim, emb_dim)
917
+ self.upsample2 = make_cbg(emb_dim, emb_dim)
918
+ self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
919
+
920
+ for m in self.modules():
921
+ if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
922
+ m.inplace = True
923
+
924
+ @torch.inference_mode()
925
+ @torch.autocast(device_type="cuda",dtype=torch.float16)
926
+ def forward(self, x):
927
+ real_batch = x.size(0)
928
+
929
+ shallow_batch = self.shallow(x)
930
+ glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
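+ # Each image is split into a 2x2 grid of local patches and paired with a
+ # half-resolution global view, giving 5 views per image for the backbone.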
931
+
932
+
933
+
934
+ final_input = None
935
+ for i in range(real_batch):
936
+ start = i * 4
937
+ end = (i + 1) * 4
938
+ loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0))
939
+ input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0)
940
+
941
+
942
+ if final_input is None:
+ final_input = input_
+ else:
+ final_input = torch.cat((final_input, input_), dim=0)
945
+
946
+ features = self.backbone(final_input)
947
+ outputs = []
948
+
949
+ for i in range(real_batch):
950
+
951
+ start = i * 5
952
+ end = (i + 1) * 5
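+ # The backbone processed 5 views per image (4 local patches + 1 global view),
+ # so slice each feature level in groups of 5.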
953
+
954
+ f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
955
+ f3 = features[3][start:end, :, :, :]
956
+ f2 = features[2][start:end, :, :, :]
957
+ f1 = features[1][start:end, :, :, :]
958
+ f0 = features[0][start:end, :, :, :]
959
+ e5 = self.output5(f4)
960
+ e4 = self.output4(f3)
961
+ e3 = self.output3(f2)
962
+ e2 = self.output2(f1)
963
+ e1 = self.output1(f0)
964
+ loc_e5, glb_e5 = e5.split([4, 1], dim=0)
965
+ e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
966
+
967
+
968
+ e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
969
+ e4 = self.conv4(e4)
970
+ e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
971
+ e3 = self.conv3(e3)
972
+ e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
973
+ e2 = self.conv2(e2)
974
+ e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
975
+ e1 = self.conv1(e1)
976
+
977
+ loc_e1, glb_e1 = e1.split([4, 1], dim=0)
978
+
979
+ output1_cat = patches2image(loc_e1) # (1,128,256,256)
980
+
981
+ # add glb feat in
982
+ output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
983
+ # merge
984
+ final_output = self.insmask_head(output1_cat) # (1,128,256,256)
985
+ # shallow feature merge
986
+ shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0)
987
+ final_output = final_output + resize_as(shallow, final_output)
988
+ final_output = self.upsample1(rescale_to(final_output))
989
+ final_output = rescale_to(final_output + resize_as(shallow, final_output))
990
+ final_output = self.upsample2(final_output)
991
+ final_output = self.output(final_output)
992
+ mask = final_output.sigmoid()
993
+ outputs.append(mask)
994
+
995
+ return torch.cat(outputs, dim=0)
996
+
997
+
998
+
999
+
1000
+ def loadcheckpoints(self,model_path):
1001
+ model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
1002
+ self.load_state_dict(model_dict['model_state_dict'], strict=True)
1003
+ del model_path
1004
+
1005
+ def inference(self,image,refine_foreground=False):
1006
+
1007
+ set_random_seed(9)
1008
+ # image = ImageOps.exif_transpose(image)
1009
+ if isinstance(image, Image.Image):
1010
+ image, h, w,original_image = rgb_loader_refiner(image)
1011
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1012
+ with torch.no_grad():
1013
+ res = self.forward(img_tensor)
1014
+
1015
+ # Show Results
1016
+ if refine_foreground:
1017
+
1018
+ pred_pil = transforms.ToPILImage()(res.squeeze())
1019
+ image_masked = refine_foreground_process(original_image, pred_pil)
1020
+
1021
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1022
+ return image_masked
1023
+
1024
+ else:
1025
+ alpha = postprocess_image(res, im_size=[w,h])
1026
+ pred_pil = transforms.ToPILImage()(alpha)
1027
+ mask = pred_pil.resize(original_image.size)
1028
+ original_image.putalpha(mask)  # putalpha() edits the image in place and returns None
+ # mask = Image.fromarray(alpha)
+
+ return original_image
1032
+
1033
+
1034
+ else:
1035
+ foregrounds = []
1036
+ for batch in image:
1037
+ image, h, w,original_image = rgb_loader_refiner(batch)
1038
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1039
+
1040
+ with torch.no_grad():
1041
+ res = self.forward(img_tensor)
1042
+
1043
+ if refine_foreground:
1044
+
1045
+ pred_pil = transforms.ToPILImage()(res.squeeze())
1046
+ image_masked = refine_foreground_process(original_image, pred_pil)
1047
+
1048
+ image_masked.putalpha(pred_pil.resize(original_image.size))
1049
+
1050
+ foregrounds.append(image_masked)
1051
+ else:
1052
+ alpha = postprocess_image(res, im_size=[w,h])
1053
+ pred_pil = transforms.ToPILImage()(alpha)
1054
+ mask = pred_pil.resize(original_image.size)
1055
+ original_image.putalpha(mask)
1056
+ # mask = Image.fromarray(alpha)
1057
+ foregrounds.append(original_image)
1058
+
1059
+ return foregrounds
1060
+
1061
+ def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
1062
+
1063
+ """
1064
+ Segments the given video to extract the foreground (with alpha) from each frame
1065
+ and saves the result as either a WebM video (with alpha channel) or MP4 (with a
1066
+ color background).
1067
+
1068
+ Args:
1069
+ video_path (str):
1070
+ Path to the input video file.
1071
+
1072
+ output_path (str, optional):
1073
+ Directory (or full path) where the output video and/or files will be saved.
1074
+ Defaults to "./".
1075
+
1076
+ fps (int, optional):
1077
+ The frames per second (FPS) to use for the output video. If 0 (default), the
1078
+ original FPS of the input video is used. Otherwise, overrides it.
1079
+
1080
+ refine_foreground (bool, optional):
1081
+ Whether to run an additional “refine foreground” process on each frame.
1082
+ Defaults to False.
1083
+
1084
+ batch (int, optional):
1085
+ Number of frames to process at once (inference batch size). Large batch sizes
1086
+ may require more GPU memory. Defaults to 1.
1087
+
1088
+ print_frames_processed (bool, optional):
1089
+ If True (default), prints progress (how many frames have been processed) to
1090
+ the console.
1091
+
1092
+ webm (bool, optional):
1093
+ If True (default), exports a WebM video with alpha channel (VP9 / yuva420p).
1094
+ If False, exports an MP4 video composited over a solid color background.
1095
+
1096
+ rgb_value (tuple, optional):
1097
+ The RGB background color (e.g., green screen) used to composite frames when
1098
+ saving to MP4. Defaults to (0, 255, 0).
1099
+
1100
+ Returns:
1101
+ None. Writes the output video(s) to disk in the specified format.
1102
+ """
1103
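+ # Example usage (sketch): with the model on GPU in eval mode,
+ # model.segment_video("input.mp4", output_path="./", webm=True)
+ # writes ./foreground.webm with an alpha channel.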
+
1104
+
1105
+ cap = cv2.VideoCapture(video_path)
1106
+ if not cap.isOpened():
1107
+ raise IOError(f"Cannot open video: {video_path}")
1108
+
1109
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
1110
+ original_fps = 30 if original_fps == 0 else original_fps
1111
+ fps = original_fps if fps == 0 else fps
1112
+
1113
+ ret, first_frame = cap.read()
1114
+ if not ret:
1115
+ raise ValueError("No frames found in the video.")
1116
+ height, width = first_frame.shape[:2]
1117
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
1118
+
1119
+ foregrounds = []
1120
+ frame_idx = 0
1121
+ processed_count = 0
1122
+ batch_frames = []
1123
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1124
+
1125
+ while True:
1126
+ ret, frame = cap.read()
1127
+ if not ret:
1128
+ if batch_frames:
1129
+ batch_results = self.inference(batch_frames, refine_foreground)
1130
+ if isinstance(batch_results, Image.Image):
1131
+ foregrounds.append(batch_results)
1132
+ else:
1133
+ foregrounds.extend(batch_results)
1134
+ if print_frames_processed:
1135
+ print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}")
1136
+ break
1137
+
1138
+ # Process every frame instead of using intervals
1139
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
1140
+ pil_frame = Image.fromarray(frame_rgb)
1141
+ batch_frames.append(pil_frame)
1142
+
1143
+ if len(batch_frames) == batch:
1144
+ batch_results = self.inference(batch_frames, refine_foreground)
1145
+ if isinstance(batch_results, Image.Image):
1146
+ foregrounds.append(batch_results)
1147
+ else:
1148
+ foregrounds.extend(batch_results)
1149
+ if print_frames_processed:
1150
+ print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}")
1151
+ batch_frames = []
1152
+ processed_count += batch
1153
+
1154
+ frame_idx += 1
1155
+
1156
+
1157
+ cap.release()  # release the capture in both output paths
+
+ if webm:
+ alpha_webm_path = os.path.join(output_path, "foreground.webm")
+ pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=fps)  # fps already falls back to original_fps when 0
+
+ else:
+ fg_output = os.path.join(output_path, 'foreground.mp4')
+
+ pil_images_to_mp4(foregrounds, fg_output, fps=fps, rgb_value=rgb_value)
+ cv2.destroyAllWindows()
1167
+
1168
+ try:
1169
+ fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
1170
+ add_audio_to_video(fg_output, video_path, fg_audio_output)
1171
+ except Exception as e:
1172
+ print("No audio found in the original video")
1173
+ print(e)
1174
+
1175
+
1176
+
1177
+
1178
+
1191
+
1192
+ # Define the image transformation
1193
+ img_transform = transforms.Compose([
1194
+ transforms.ToTensor(),
1195
+ transforms.ConvertImageDtype(torch.float16),
1196
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1197
+ ])
1198
+
1199
+
1200
+
1201
+
1202
+ def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
1203
+ """
1204
+ Converts an array of PIL images to an MP4 video.
1205
+
1206
+ Args:
1207
+ images: List of PIL images
1208
+ output_path: Path to save the MP4 file
1209
+ fps: Frames per second (default: 24)
1210
+ rgb_value: Background RGB color tuple (default: green (0, 255, 0))
1211
+ """
1212
+ if not images:
1213
+ raise ValueError("No images provided to convert to MP4.")
1214
+
1215
+ width, height = images[0].size
1216
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
1217
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
1218
+
1219
+ for image in images:
1220
+ # If image has alpha channel, composite onto the specified background color
1221
+ if image.mode == 'RGBA':
1222
+ # Create background image with specified RGB color
1223
+ background = Image.new('RGB', image.size, rgb_value)
1224
+ background = background.convert('RGBA')
1225
+ # Composite the image onto the background
1226
+ image = Image.alpha_composite(background, image)
1227
+ image = image.convert('RGB')
1228
+ else:
1229
+ # Ensure RGB format for non-alpha images
1230
+ image = image.convert('RGB')
1231
+
1232
+ # Convert to OpenCV format and write
1233
+ open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
1234
+ video_writer.write(open_cv_image)
1235
+
1236
+ video_writer.release()
1237
+
1238
+ def pil_images_to_webm_alpha(images, output_path, fps=30):
1239
+ """
1240
+ Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
1241
+
1242
+ NOTE: Not all players will display alpha in WebM.
1243
+ Browsers like Chrome/Firefox typically do support VP9 alpha.
1244
+ """
1245
+ if not images:
1246
+ raise ValueError("No images provided for WebM with alpha.")
1247
+
1248
+ # Ensure output directory exists
1249
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1250
+
1251
+ with tempfile.TemporaryDirectory() as tmpdir:
1252
+ # Save frames as PNG (with alpha)
1253
+ for idx, img in enumerate(images):
1254
+ if img.mode != "RGBA":
1255
+ img = img.convert("RGBA")
1256
+ out_path = os.path.join(tmpdir, f"{idx:06d}.png")
1257
+ img.save(out_path, "PNG")
1258
+
1259
+ # Construct ffmpeg command
1260
+ # -c:v libvpx-vp9 => VP9 encoder
1261
+ # -pix_fmt yuva420p => alpha-enabled pixel format
1262
+ # -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
1263
+ ffmpeg_cmd = [
1264
+ "ffmpeg", "-y",
1265
+ "-framerate", str(fps),
1266
+ "-i", os.path.join(tmpdir, "%06d.png"),
1267
+ "-c:v", "libvpx-vp9",
1268
+ "-pix_fmt", "yuva420p",
1269
+ "-auto-alt-ref", "0",
1270
+ output_path
1271
+ ]
1272
+
1273
+ subprocess.run(ffmpeg_cmd, check=True)
1274
+
1275
+ print(f"WebM with alpha saved to {output_path}")
1276
+
1277
+ def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
1278
+ """
1279
+ Check if the original video has an audio stream. If yes, add it. If not, skip.
1280
+ """
1281
+ # 1) Probe original video for audio streams
1282
+ probe_command = [
1283
+ 'ffprobe', '-v', 'error',
1284
+ '-select_streams', 'a:0',
1285
+ '-show_entries', 'stream=index',
1286
+ '-of', 'csv=p=0',
1287
+ original_video_path
1288
+ ]
1289
+ result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
1290
+
1291
+ # result.stdout is empty if no audio stream found
1292
+ if not result.stdout.strip():
1293
+ print("No audio track found in original video, skipping audio addition.")
1294
+ return
1295
+
1296
+ print("Audio track detected; proceeding to mux audio.")
1297
+ # 2) If audio found, run ffmpeg to add it
1298
+ command = [
1299
+ 'ffmpeg', '-y',
1300
+ '-i', video_without_audio_path,
1301
+ '-i', original_video_path,
1302
+ '-c', 'copy',
1303
+ '-map', '0:v:0',
1304
+ '-map', '1:a:0', # we know there's an audio track now
1305
+ output_path
1306
+ ]
1307
+ subprocess.run(command, check=True)
1308
+ print(f"Audio added successfully => {output_path}")
1309
+
1310
+
1311
+
1312
+
1313
+
1314
+ ### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
1315
+ def refine_foreground_process(image, mask, r=90):
1316
+ if mask.size != image.size:
1317
+ mask = mask.resize(image.size)
1318
+ image = np.array(image) / 255.0
1319
+ mask = np.array(mask) / 255.0
1320
+ estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
1321
+ image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
1322
+ return image_masked
1323
+
1324
+
1325
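+ # Two-pass blur-fusion foreground estimation (coarse blur r=90, then fine r=6),
+ # following the fast foreground estimation approach linked below.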
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
1326
+ # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
1327
+ alpha = alpha[:, :, None]
1328
+ F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
1329
+ return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
1330
+
1331
+
1332
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
1333
+ if isinstance(image, Image.Image):
1334
+ image = np.array(image) / 255.0
1335
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
1336
+
1337
+ blurred_FA = cv2.blur(F * alpha, (r, r))
1338
+ blurred_F = blurred_FA / (blurred_alpha + 1e-5)
1339
+
1340
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
1341
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
1342
+ F = blurred_F + alpha * \
1343
+ (image - alpha * blurred_F - (1 - alpha) * blurred_B)
1344
+ F = np.clip(F, 0, 1)
1345
+ return F, blurred_B
1346
+
1347
+
1348
+
1349
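+ # Resize the raw prediction back to the original image size and min-max normalize it
+ # to a uint8 alpha map.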
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
1350
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
1351
+ ma = torch.max(result)
1352
+ mi = torch.min(result)
1353
+ result = (result - mi) / (ma - mi)
1354
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
1355
+ im_array = np.squeeze(im_array)
1356
+ return im_array
1357
+
1358
+
1359
+
1360
+
1361
+ def rgb_loader_refiner(original_image):
+ h, w = original_image.size
+
+ if original_image.mode != 'RGB':
+ original_image = original_image.convert('RGB')
+
+ image = original_image
+
+ # Resize the image for the network input
+ image = image.resize((1024, 1024), resample=Image.LANCZOS)
+
+ return image, h, w, original_image
1375
+
1376
+
1377
+
BEN2_Base.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:926144a876bda06f125555b4f5a239ece89dc6eb838a863700ca9bf192161a1c
3
+ size 1134584206
BEN2_demo_pictures/grid_example1.png ADDED

Git LFS Details

  • SHA256: 49df5808df57c1db87f1bdf94ff0687ba436f2c377378799dd3ce49be85e0973
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
BEN2_demo_pictures/grid_example2.png ADDED

Git LFS Details

  • SHA256: 0899c57bdb592ccf3b04ed0316d3f8d4b23f337ff00f6807b18b46b74b8e91bf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
BEN2_demo_pictures/grid_example3.png ADDED

Git LFS Details

  • SHA256: f0e2cb53afd4ad04daa223525f688cad835826890eb4ababb1e0bf0e629800e5
  • Pointer size: 132 Bytes
  • Size of remote file: 8.59 MB
BEN2_demo_pictures/grid_example6.png ADDED

Git LFS Details

  • SHA256: 327eca743beef0cd452e40015b0695d67806cd6b59bd3a7759cfd2be260c5cae
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
BEN2_demo_pictures/grid_example7.png ADDED

Git LFS Details

  • SHA256: 0f758d617b3266d2fb540bd5088c58a3a69d1cb5e54dcdf0a4d5ea6b74c6e7e2
  • Pointer size: 132 Bytes
  • Size of remote file: 5.27 MB
BEN2_demo_pictures/model_comparison.png ADDED
README.md CHANGED
@@ -1,3 +1,151 @@
1
- ---
2
- license: mit
3
- ---
1
+ ---
2
+ license: mit
3
+ pipeline_tag: image-segmentation
4
+ tags:
5
+ - BEN2
6
+ - background-remove
7
+ - mask-generation
8
+ - Dichotomous image segmentation
9
+ - background remove
10
+ - foreground
11
+ - background
12
+ - remove background
13
+ - pytorch
14
+ ---
15
+
16
+ # BEN2: Background Erase Network
17
+
18
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.06230-b31b1b.svg)](https://arxiv.org/abs/2501.06230)
19
+ [![GitHub](https://img.shields.io/badge/GitHub-BEN2-black.svg)](https://github.com/PramaLLC/BEN2/)
20
+ [![Website](https://img.shields.io/badge/Website-backgrounderase.net-104233)](https://backgrounderase.net)
21
+
22
+ ## Overview
23
+ BEN2 (Background Erase Network) introduces a novel approach to foreground segmentation through its innovative Confidence Guided Matting (CGM) pipeline. The architecture employs a refiner network that targets and processes pixels where the base model exhibits lower confidence levels, resulting in more precise and reliable matting results. This model is built on BEN:
24
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ben-using-confidence-guided-matting-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=ben-using-confidence-guided-matting-for)
25
+
26
+
27
+
28
+
29
+ ## BEN2 access
30
+ BEN2 was trained on DIS5k and our 22K-image proprietary segmentation dataset. Our enhanced model delivers superior performance in hair matting, 4K processing, object segmentation, and edge refinement. Our Base model is open source. To try the full model, use our free web demo or integrate BEN2 into your project with our API:
31
+ - 🌐 [backgrounderase.net](https://backgrounderase.net)
32
+
33
+
34
+ ## Contact us
35
+ - For access to our commercial model, email us at [email protected]
36
+ - Our website: https://prama.llc/
37
+ - Follow us on X: https://x.com/PramaResearch/
38
+
39
+
40
+ ## Quick start code (inside cloned repo)
41
+
42
+ ```python
43
+ import BEN2
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ file = "./image.png" # input image
+
+ model = BEN2.BEN_Base().to(device).eval() # init pipeline
+
+ model.loadcheckpoints("./BEN2_Base.pth")
+ image = Image.open(file)
+ foreground = model.inference(image, refine_foreground=False) # refine_foreground is an extra post-processing step that increases inference time but can improve matting edges. The default value is False.
+
+ foreground.save("./foreground.png")
59
+
60
+ ```
61
+
62
+
63
+ ## Batch image processing
64
+
65
+ ```python
66
+ import BEN2
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model = BEN2.BEN_Base().to(device).eval() # init pipeline
+
+ model.loadcheckpoints("./BEN2_Base.pth")
+
+ file1 = "./image1.png" # input image1
+ file2 = "./image2.png" # input image2
+ image1 = Image.open(file1)
+ image2 = Image.open(file2)
+
+ foregrounds = model.inference([image1, image2]) # We recommend a batch size of no more than 3 on consumer GPUs, as there are minimal inference gains due to our custom batch processing for the MVANet decoding steps.
+ foregrounds[0].save("./foreground1.png")
+ foregrounds[1].save("./foreground2.png")
89
+
90
+ ```
91
+
92
+
93
+
94
+ # BEN2 video segmentation
95
+ [![BEN2 Demo](https://img.youtube.com/vi/skEXiIHQcys/0.jpg)](https://www.youtube.com/watch?v=skEXiIHQcys)
96
+
97
+ ## Video Segmentation
98
+
99
+ ```bash
100
+ sudo apt update
101
+ sudo apt install ffmpeg
102
+ ```
103
+
104
+ ```python
105
+ import BEN2
+ from PIL import Image
+ import torch
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model = BEN2.BEN_Base().to(device).eval() # init pipeline
+
+ model.loadcheckpoints("./BEN2_Base.pth")
+
+ model.segment_video(
+ video_path="/path_to_your_video.mp4",
+ output_path="./", # Outputs are saved as foreground.webm or foreground.mp4. The default value is "./".
+ fps=0, # If set to 0, CV2 detects the fps of the original video. The default value is 0.
+ refine_foreground=False, # refine_foreground is an extra post-processing step that increases inference time but can improve matting edges. The default value is False.
+ batch=1, # We recommend a batch size of no more than 3 on consumer GPUs, as there are minimal inference gains. The default value is 1.
+ print_frames_processed=True, # Reports which frame is being processed. The default value is True.
+ webm=False, # If True, outputs a WebM video with an alpha layer; otherwise an MP4 is written. The default value is False.
+ rgb_value=(0, 255, 0) # Background RGB color used for the MP4 output when webm is False. The default value is a green background (0, 255, 0).
+ )
131
+
132
+
133
+ ```
134
+
135
+
136
+
137
+ ## BEN2 evaluation
138
+ ![Model Comparison](BEN2_demo_pictures/model_comparison.png)
139
+
140
+ RMBG 2.0 did not preserve the DIS 5k validation dataset
141
+
142
+ ![Example 1](BEN2_demo_pictures/grid_example1.png)
143
+ ![Example 2](BEN2_demo_pictures/grid_example2.png)
144
+ ![Example 3](BEN2_demo_pictures/grid_example3.png)
145
+ ![Example 6](BEN2_demo_pictures/grid_example6.png)
146
+ ![Example 7](BEN2_demo_pictures/grid_example7.png)
147
+
148
+
149
+ ## Installation
150
+ 1. Clone the repo
+ 2. Install the requirements: `pip install -r requirements.txt`
config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_name_or_path": "PramaLLC/BEN2",
3
+ "architectures": ["PramaBEN_Base"],
4
+ "version": "1.0",
5
+ "torch_dtype": "float32"
6
+ }
inference.py ADDED
@@ -0,0 +1,17 @@
1
+ import BEN2
2
+ from PIL import Image
3
+ import torch
4
+
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+ file = "./image.png" # input image
9
+
10
+ model = BEN2.BEN_Base().to(device).eval() #init pipeline
11
+
12
+ model.loadcheckpoints("./BEN2_Base.pth")
+ image = Image.open(file)
+ foreground = model.inference(image) # inference() returns a single PIL image with alpha
+
+ foreground.save("./foreground.png")
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ numpy>=1.21.0
2
+ torch>=1.9.0
3
+ einops>=0.6.0
4
+ Pillow>=9.0.0
5
+ timm>=0.6.0
6
+ torchvision>=0.10.0