MaxwellMeyer
commited on
Upload 12 files
Browse files- .gitattributes +5 -0
- BEN2.py +1377 -0
- BEN2_Base.pth +3 -0
- BEN2_demo_pictures/grid_example1.png +3 -0
- BEN2_demo_pictures/grid_example2.png +3 -0
- BEN2_demo_pictures/grid_example3.png +3 -0
- BEN2_demo_pictures/grid_example6.png +3 -0
- BEN2_demo_pictures/grid_example7.png +3 -0
- BEN2_demo_pictures/model_comparison.png +0 -0
- README.md +151 -3
- config.json +6 -0
- inference.py +17 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
BEN2_demo_pictures/grid_example1.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
BEN2_demo_pictures/grid_example2.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
BEN2_demo_pictures/grid_example3.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
BEN2_demo_pictures/grid_example6.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
BEN2_demo_pictures/grid_example7.png filter=lfs diff=lfs merge=lfs -text
|
BEN2.py
ADDED
@@ -0,0 +1,1377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
import torch.utils.checkpoint as checkpoint
|
8 |
+
import numpy as np
|
9 |
+
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
10 |
+
from PIL import Image, ImageOps
|
11 |
+
from torchvision import transforms
|
12 |
+
import numpy as np
|
13 |
+
import random
|
14 |
+
import cv2
|
15 |
+
import os
|
16 |
+
import subprocess
|
17 |
+
import time
|
18 |
+
import tempfile
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
def set_random_seed(seed):
|
24 |
+
random.seed(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed(seed)
|
28 |
+
torch.cuda.manual_seed_all(seed)
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
torch.backends.cudnn.benchmark = False
|
31 |
+
set_random_seed(9)
|
32 |
+
|
33 |
+
|
34 |
+
torch.set_float32_matmul_precision('highest')
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
class Mlp(nn.Module):
|
39 |
+
""" Multilayer perceptron."""
|
40 |
+
|
41 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
42 |
+
super().__init__()
|
43 |
+
out_features = out_features or in_features
|
44 |
+
hidden_features = hidden_features or in_features
|
45 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
46 |
+
self.act = act_layer()
|
47 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
48 |
+
self.drop = nn.Dropout(drop)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.fc1(x)
|
52 |
+
x = self.act(x)
|
53 |
+
x = self.drop(x)
|
54 |
+
x = self.fc2(x)
|
55 |
+
x = self.drop(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
def window_partition(x, window_size):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
x: (B, H, W, C)
|
63 |
+
window_size (int): window size
|
64 |
+
Returns:
|
65 |
+
windows: (num_windows*B, window_size, window_size, C)
|
66 |
+
"""
|
67 |
+
B, H, W, C = x.shape
|
68 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
69 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
70 |
+
return windows
|
71 |
+
|
72 |
+
|
73 |
+
def window_reverse(windows, window_size, H, W):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
windows: (num_windows*B, window_size, window_size, C)
|
77 |
+
window_size (int): Window size
|
78 |
+
H (int): Height of image
|
79 |
+
W (int): Width of image
|
80 |
+
Returns:
|
81 |
+
x: (B, H, W, C)
|
82 |
+
"""
|
83 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
84 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
85 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
86 |
+
return x
|
87 |
+
|
88 |
+
|
89 |
+
class WindowAttention(nn.Module):
|
90 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
91 |
+
It supports both of shifted and non-shifted window.
|
92 |
+
Args:
|
93 |
+
dim (int): Number of input channels.
|
94 |
+
window_size (tuple[int]): The height and width of the window.
|
95 |
+
num_heads (int): Number of attention heads.
|
96 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
97 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
98 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
99 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
103 |
+
|
104 |
+
super().__init__()
|
105 |
+
self.dim = dim
|
106 |
+
self.window_size = window_size # Wh, Ww
|
107 |
+
self.num_heads = num_heads
|
108 |
+
head_dim = dim // num_heads
|
109 |
+
self.scale = qk_scale or head_dim ** -0.5
|
110 |
+
|
111 |
+
# define a parameter table of relative position bias
|
112 |
+
self.relative_position_bias_table = nn.Parameter(
|
113 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
114 |
+
|
115 |
+
# get pair-wise relative position index for each token inside the window
|
116 |
+
coords_h = torch.arange(self.window_size[0])
|
117 |
+
coords_w = torch.arange(self.window_size[1])
|
118 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
119 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
120 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
121 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
122 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
123 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
124 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
125 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
126 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
127 |
+
|
128 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
129 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
130 |
+
self.proj = nn.Linear(dim, dim)
|
131 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
132 |
+
|
133 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
134 |
+
self.softmax = nn.Softmax(dim=-1)
|
135 |
+
|
136 |
+
def forward(self, x, mask=None):
|
137 |
+
""" Forward function.
|
138 |
+
Args:
|
139 |
+
x: input features with shape of (num_windows*B, N, C)
|
140 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
141 |
+
"""
|
142 |
+
B_, N, C = x.shape
|
143 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
144 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
145 |
+
|
146 |
+
q = q * self.scale
|
147 |
+
attn = (q @ k.transpose(-2, -1))
|
148 |
+
|
149 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
150 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
151 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
152 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
153 |
+
|
154 |
+
if mask is not None:
|
155 |
+
nW = mask.shape[0]
|
156 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
157 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
158 |
+
attn = self.softmax(attn)
|
159 |
+
else:
|
160 |
+
attn = self.softmax(attn)
|
161 |
+
|
162 |
+
attn = self.attn_drop(attn)
|
163 |
+
|
164 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
165 |
+
x = self.proj(x)
|
166 |
+
x = self.proj_drop(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class SwinTransformerBlock(nn.Module):
|
171 |
+
""" Swin Transformer Block.
|
172 |
+
Args:
|
173 |
+
dim (int): Number of input channels.
|
174 |
+
num_heads (int): Number of attention heads.
|
175 |
+
window_size (int): Window size.
|
176 |
+
shift_size (int): Shift size for SW-MSA.
|
177 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
178 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
179 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
180 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
181 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
182 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
183 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
184 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
188 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
189 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
190 |
+
super().__init__()
|
191 |
+
self.dim = dim
|
192 |
+
self.num_heads = num_heads
|
193 |
+
self.window_size = window_size
|
194 |
+
self.shift_size = shift_size
|
195 |
+
self.mlp_ratio = mlp_ratio
|
196 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
197 |
+
|
198 |
+
self.norm1 = norm_layer(dim)
|
199 |
+
self.attn = WindowAttention(
|
200 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
201 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
202 |
+
|
203 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
204 |
+
self.norm2 = norm_layer(dim)
|
205 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
206 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
207 |
+
|
208 |
+
self.H = None
|
209 |
+
self.W = None
|
210 |
+
|
211 |
+
def forward(self, x, mask_matrix):
|
212 |
+
""" Forward function.
|
213 |
+
Args:
|
214 |
+
x: Input feature, tensor size (B, H*W, C).
|
215 |
+
H, W: Spatial resolution of the input feature.
|
216 |
+
mask_matrix: Attention mask for cyclic shift.
|
217 |
+
"""
|
218 |
+
B, L, C = x.shape
|
219 |
+
H, W = self.H, self.W
|
220 |
+
assert L == H * W, "input feature has wrong size"
|
221 |
+
|
222 |
+
shortcut = x
|
223 |
+
x = self.norm1(x)
|
224 |
+
x = x.view(B, H, W, C)
|
225 |
+
|
226 |
+
# pad feature maps to multiples of window size
|
227 |
+
pad_l = pad_t = 0
|
228 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
229 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
230 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
231 |
+
_, Hp, Wp, _ = x.shape
|
232 |
+
|
233 |
+
# cyclic shift
|
234 |
+
if self.shift_size > 0:
|
235 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
236 |
+
attn_mask = mask_matrix
|
237 |
+
else:
|
238 |
+
shifted_x = x
|
239 |
+
attn_mask = None
|
240 |
+
|
241 |
+
# partition windows
|
242 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
243 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
244 |
+
|
245 |
+
# W-MSA/SW-MSA
|
246 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
247 |
+
|
248 |
+
# merge windows
|
249 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
250 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
251 |
+
|
252 |
+
# reverse cyclic shift
|
253 |
+
if self.shift_size > 0:
|
254 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
255 |
+
else:
|
256 |
+
x = shifted_x
|
257 |
+
|
258 |
+
if pad_r > 0 or pad_b > 0:
|
259 |
+
x = x[:, :H, :W, :].contiguous()
|
260 |
+
|
261 |
+
x = x.view(B, H * W, C)
|
262 |
+
|
263 |
+
# FFN
|
264 |
+
x = shortcut + self.drop_path(x)
|
265 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
266 |
+
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class PatchMerging(nn.Module):
|
271 |
+
""" Patch Merging Layer
|
272 |
+
Args:
|
273 |
+
dim (int): Number of input channels.
|
274 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
275 |
+
"""
|
276 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
277 |
+
super().__init__()
|
278 |
+
self.dim = dim
|
279 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
280 |
+
self.norm = norm_layer(4 * dim)
|
281 |
+
|
282 |
+
def forward(self, x, H, W):
|
283 |
+
""" Forward function.
|
284 |
+
Args:
|
285 |
+
x: Input feature, tensor size (B, H*W, C).
|
286 |
+
H, W: Spatial resolution of the input feature.
|
287 |
+
"""
|
288 |
+
B, L, C = x.shape
|
289 |
+
assert L == H * W, "input feature has wrong size"
|
290 |
+
|
291 |
+
x = x.view(B, H, W, C)
|
292 |
+
|
293 |
+
# padding
|
294 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
295 |
+
if pad_input:
|
296 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
297 |
+
|
298 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
299 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
300 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
301 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
302 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
303 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
304 |
+
|
305 |
+
x = self.norm(x)
|
306 |
+
x = self.reduction(x)
|
307 |
+
|
308 |
+
return x
|
309 |
+
|
310 |
+
|
311 |
+
class BasicLayer(nn.Module):
|
312 |
+
""" A basic Swin Transformer layer for one stage.
|
313 |
+
Args:
|
314 |
+
dim (int): Number of feature channels
|
315 |
+
depth (int): Depths of this stage.
|
316 |
+
num_heads (int): Number of attention head.
|
317 |
+
window_size (int): Local window size. Default: 7.
|
318 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
319 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
320 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
321 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
322 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
323 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
324 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
325 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
326 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
327 |
+
"""
|
328 |
+
|
329 |
+
def __init__(self,
|
330 |
+
dim,
|
331 |
+
depth,
|
332 |
+
num_heads,
|
333 |
+
window_size=7,
|
334 |
+
mlp_ratio=4.,
|
335 |
+
qkv_bias=True,
|
336 |
+
qk_scale=None,
|
337 |
+
drop=0.,
|
338 |
+
attn_drop=0.,
|
339 |
+
drop_path=0.,
|
340 |
+
norm_layer=nn.LayerNorm,
|
341 |
+
downsample=None,
|
342 |
+
use_checkpoint=False):
|
343 |
+
super().__init__()
|
344 |
+
self.window_size = window_size
|
345 |
+
self.shift_size = window_size // 2
|
346 |
+
self.depth = depth
|
347 |
+
self.use_checkpoint = use_checkpoint
|
348 |
+
|
349 |
+
# build blocks
|
350 |
+
self.blocks = nn.ModuleList([
|
351 |
+
SwinTransformerBlock(
|
352 |
+
dim=dim,
|
353 |
+
num_heads=num_heads,
|
354 |
+
window_size=window_size,
|
355 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
356 |
+
mlp_ratio=mlp_ratio,
|
357 |
+
qkv_bias=qkv_bias,
|
358 |
+
qk_scale=qk_scale,
|
359 |
+
drop=drop,
|
360 |
+
attn_drop=attn_drop,
|
361 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
362 |
+
norm_layer=norm_layer)
|
363 |
+
for i in range(depth)])
|
364 |
+
|
365 |
+
# patch merging layer
|
366 |
+
if downsample is not None:
|
367 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
368 |
+
else:
|
369 |
+
self.downsample = None
|
370 |
+
|
371 |
+
def forward(self, x, H, W):
|
372 |
+
""" Forward function.
|
373 |
+
Args:
|
374 |
+
x: Input feature, tensor size (B, H*W, C).
|
375 |
+
H, W: Spatial resolution of the input feature.
|
376 |
+
"""
|
377 |
+
|
378 |
+
# calculate attention mask for SW-MSA
|
379 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
380 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
381 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
382 |
+
h_slices = (slice(0, -self.window_size),
|
383 |
+
slice(-self.window_size, -self.shift_size),
|
384 |
+
slice(-self.shift_size, None))
|
385 |
+
w_slices = (slice(0, -self.window_size),
|
386 |
+
slice(-self.window_size, -self.shift_size),
|
387 |
+
slice(-self.shift_size, None))
|
388 |
+
cnt = 0
|
389 |
+
for h in h_slices:
|
390 |
+
for w in w_slices:
|
391 |
+
img_mask[:, h, w, :] = cnt
|
392 |
+
cnt += 1
|
393 |
+
|
394 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
395 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
396 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
397 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
398 |
+
|
399 |
+
for blk in self.blocks:
|
400 |
+
blk.H, blk.W = H, W
|
401 |
+
if self.use_checkpoint:
|
402 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
403 |
+
else:
|
404 |
+
x = blk(x, attn_mask)
|
405 |
+
if self.downsample is not None:
|
406 |
+
x_down = self.downsample(x, H, W)
|
407 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
408 |
+
return x, H, W, x_down, Wh, Ww
|
409 |
+
else:
|
410 |
+
return x, H, W, x, H, W
|
411 |
+
|
412 |
+
|
413 |
+
class PatchEmbed(nn.Module):
|
414 |
+
""" Image to Patch Embedding
|
415 |
+
Args:
|
416 |
+
patch_size (int): Patch token size. Default: 4.
|
417 |
+
in_chans (int): Number of input image channels. Default: 3.
|
418 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
419 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
420 |
+
"""
|
421 |
+
|
422 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
423 |
+
super().__init__()
|
424 |
+
patch_size = to_2tuple(patch_size)
|
425 |
+
self.patch_size = patch_size
|
426 |
+
|
427 |
+
self.in_chans = in_chans
|
428 |
+
self.embed_dim = embed_dim
|
429 |
+
|
430 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
431 |
+
if norm_layer is not None:
|
432 |
+
self.norm = norm_layer(embed_dim)
|
433 |
+
else:
|
434 |
+
self.norm = None
|
435 |
+
|
436 |
+
def forward(self, x):
|
437 |
+
"""Forward function."""
|
438 |
+
# padding
|
439 |
+
_, _, H, W = x.size()
|
440 |
+
if W % self.patch_size[1] != 0:
|
441 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
442 |
+
if H % self.patch_size[0] != 0:
|
443 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
444 |
+
|
445 |
+
x = self.proj(x) # B C Wh Ww
|
446 |
+
if self.norm is not None:
|
447 |
+
Wh, Ww = x.size(2), x.size(3)
|
448 |
+
x = x.flatten(2).transpose(1, 2)
|
449 |
+
x = self.norm(x)
|
450 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
451 |
+
|
452 |
+
return x
|
453 |
+
|
454 |
+
|
455 |
+
class SwinTransformer(nn.Module):
|
456 |
+
""" Swin Transformer backbone.
|
457 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
458 |
+
https://arxiv.org/pdf/2103.14030
|
459 |
+
Args:
|
460 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
461 |
+
used in absolute postion embedding. Default 224.
|
462 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
463 |
+
in_chans (int): Number of input image channels. Default: 3.
|
464 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
465 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
466 |
+
num_heads (tuple[int]): Number of attention head of each stage.
|
467 |
+
window_size (int): Window size. Default: 7.
|
468 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
469 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
470 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
471 |
+
drop_rate (float): Dropout rate.
|
472 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
473 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
474 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
475 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
476 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
477 |
+
out_indices (Sequence[int]): Output from which stages.
|
478 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
479 |
+
-1 means not freezing any parameters.
|
480 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
481 |
+
"""
|
482 |
+
|
483 |
+
def __init__(self,
|
484 |
+
pretrain_img_size=224,
|
485 |
+
patch_size=4,
|
486 |
+
in_chans=3,
|
487 |
+
embed_dim=96,
|
488 |
+
depths=[2, 2, 6, 2],
|
489 |
+
num_heads=[3, 6, 12, 24],
|
490 |
+
window_size=7,
|
491 |
+
mlp_ratio=4.,
|
492 |
+
qkv_bias=True,
|
493 |
+
qk_scale=None,
|
494 |
+
drop_rate=0.,
|
495 |
+
attn_drop_rate=0.,
|
496 |
+
drop_path_rate=0.2,
|
497 |
+
norm_layer=nn.LayerNorm,
|
498 |
+
ape=False,
|
499 |
+
patch_norm=True,
|
500 |
+
out_indices=(0, 1, 2, 3),
|
501 |
+
frozen_stages=-1,
|
502 |
+
use_checkpoint=False):
|
503 |
+
super().__init__()
|
504 |
+
|
505 |
+
self.pretrain_img_size = pretrain_img_size
|
506 |
+
self.num_layers = len(depths)
|
507 |
+
self.embed_dim = embed_dim
|
508 |
+
self.ape = ape
|
509 |
+
self.patch_norm = patch_norm
|
510 |
+
self.out_indices = out_indices
|
511 |
+
self.frozen_stages = frozen_stages
|
512 |
+
|
513 |
+
# split image into non-overlapping patches
|
514 |
+
self.patch_embed = PatchEmbed(
|
515 |
+
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
516 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
517 |
+
|
518 |
+
# absolute position embedding
|
519 |
+
if self.ape:
|
520 |
+
pretrain_img_size = to_2tuple(pretrain_img_size)
|
521 |
+
patch_size = to_2tuple(patch_size)
|
522 |
+
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
|
523 |
+
|
524 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
|
525 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
526 |
+
|
527 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
528 |
+
|
529 |
+
# stochastic depth
|
530 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
531 |
+
|
532 |
+
# build layers
|
533 |
+
self.layers = nn.ModuleList()
|
534 |
+
for i_layer in range(self.num_layers):
|
535 |
+
layer = BasicLayer(
|
536 |
+
dim=int(embed_dim * 2 ** i_layer),
|
537 |
+
depth=depths[i_layer],
|
538 |
+
num_heads=num_heads[i_layer],
|
539 |
+
window_size=window_size,
|
540 |
+
mlp_ratio=mlp_ratio,
|
541 |
+
qkv_bias=qkv_bias,
|
542 |
+
qk_scale=qk_scale,
|
543 |
+
drop=drop_rate,
|
544 |
+
attn_drop=attn_drop_rate,
|
545 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
546 |
+
norm_layer=norm_layer,
|
547 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
548 |
+
use_checkpoint=use_checkpoint)
|
549 |
+
self.layers.append(layer)
|
550 |
+
|
551 |
+
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
552 |
+
self.num_features = num_features
|
553 |
+
|
554 |
+
# add a norm layer for each output
|
555 |
+
for i_layer in out_indices:
|
556 |
+
layer = norm_layer(num_features[i_layer])
|
557 |
+
layer_name = f'norm{i_layer}'
|
558 |
+
self.add_module(layer_name, layer)
|
559 |
+
|
560 |
+
self._freeze_stages()
|
561 |
+
|
562 |
+
def _freeze_stages(self):
|
563 |
+
if self.frozen_stages >= 0:
|
564 |
+
self.patch_embed.eval()
|
565 |
+
for param in self.patch_embed.parameters():
|
566 |
+
param.requires_grad = False
|
567 |
+
|
568 |
+
if self.frozen_stages >= 1 and self.ape:
|
569 |
+
self.absolute_pos_embed.requires_grad = False
|
570 |
+
|
571 |
+
if self.frozen_stages >= 2:
|
572 |
+
self.pos_drop.eval()
|
573 |
+
for i in range(0, self.frozen_stages - 1):
|
574 |
+
m = self.layers[i]
|
575 |
+
m.eval()
|
576 |
+
for param in m.parameters():
|
577 |
+
param.requires_grad = False
|
578 |
+
|
579 |
+
|
580 |
+
def forward(self, x):
|
581 |
+
|
582 |
+
x = self.patch_embed(x)
|
583 |
+
|
584 |
+
Wh, Ww = x.size(2), x.size(3)
|
585 |
+
if self.ape:
|
586 |
+
# interpolate the position embedding to the corresponding size
|
587 |
+
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
588 |
+
x = (x + absolute_pos_embed) # B Wh*Ww C
|
589 |
+
|
590 |
+
outs = [x.contiguous()]
|
591 |
+
x = x.flatten(2).transpose(1, 2)
|
592 |
+
x = self.pos_drop(x)
|
593 |
+
|
594 |
+
|
595 |
+
for i in range(self.num_layers):
|
596 |
+
layer = self.layers[i]
|
597 |
+
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
598 |
+
|
599 |
+
|
600 |
+
if i in self.out_indices:
|
601 |
+
norm_layer = getattr(self, f'norm{i}')
|
602 |
+
x_out = norm_layer(x_out)
|
603 |
+
|
604 |
+
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
605 |
+
outs.append(out)
|
606 |
+
|
607 |
+
|
608 |
+
|
609 |
+
return tuple(outs)
|
610 |
+
|
611 |
+
|
612 |
+
|
613 |
+
|
614 |
+
|
615 |
+
|
616 |
+
|
617 |
+
|
618 |
+
def get_activation_fn(activation):
|
619 |
+
"""Return an activation function given a string"""
|
620 |
+
if activation == "gelu":
|
621 |
+
return F.gelu
|
622 |
+
|
623 |
+
raise RuntimeError(F"activation should be gelu, not {activation}.")
|
624 |
+
|
625 |
+
|
626 |
+
def make_cbr(in_dim, out_dim):
|
627 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
628 |
+
|
629 |
+
|
630 |
+
def make_cbg(in_dim, out_dim):
|
631 |
+
return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())
|
632 |
+
|
633 |
+
|
634 |
+
def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
|
635 |
+
return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)
|
636 |
+
|
637 |
+
|
638 |
+
def resize_as(x, y, interpolation='bilinear'):
|
639 |
+
return F.interpolate(x, size=y.shape[-2:], mode=interpolation)
|
640 |
+
|
641 |
+
|
642 |
+
def image2patches(x):
|
643 |
+
"""b c (hg h) (wg w) -> (hg wg b) c h w"""
|
644 |
+
x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2 )
|
645 |
+
return x
|
646 |
+
|
647 |
+
|
648 |
+
def patches2image(x):
|
649 |
+
"""(hg wg b) c h w -> b c (hg h) (wg w)"""
|
650 |
+
x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
651 |
+
return x
|
652 |
+
|
653 |
+
|
654 |
+
|
655 |
+
class PositionEmbeddingSine:
|
656 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
657 |
+
super().__init__()
|
658 |
+
self.num_pos_feats = num_pos_feats
|
659 |
+
self.temperature = temperature
|
660 |
+
self.normalize = normalize
|
661 |
+
if scale is not None and normalize is False:
|
662 |
+
raise ValueError("normalize should be True if scale is passed")
|
663 |
+
if scale is None:
|
664 |
+
scale = 2 * math.pi
|
665 |
+
self.scale = scale
|
666 |
+
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
667 |
+
|
668 |
+
def __call__(self, b, h, w):
|
669 |
+
device = self.dim_t.device
|
670 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
671 |
+
assert mask is not None
|
672 |
+
not_mask = ~mask
|
673 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
|
674 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
|
675 |
+
if self.normalize:
|
676 |
+
eps = 1e-6
|
677 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
678 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
679 |
+
|
680 |
+
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
681 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
682 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
683 |
+
|
684 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
685 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
686 |
+
|
687 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
688 |
+
|
689 |
+
|
690 |
+
|
691 |
+
class PositionEmbeddingSine:
|
692 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
693 |
+
super().__init__()
|
694 |
+
self.num_pos_feats = num_pos_feats
|
695 |
+
self.temperature = temperature
|
696 |
+
self.normalize = normalize
|
697 |
+
if scale is not None and normalize is False:
|
698 |
+
raise ValueError("normalize should be True if scale is passed")
|
699 |
+
if scale is None:
|
700 |
+
scale = 2 * math.pi
|
701 |
+
self.scale = scale
|
702 |
+
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
703 |
+
|
704 |
+
def __call__(self, b, h, w):
|
705 |
+
device = self.dim_t.device
|
706 |
+
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
707 |
+
assert mask is not None
|
708 |
+
not_mask = ~mask
|
709 |
+
y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
|
710 |
+
x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
|
711 |
+
if self.normalize:
|
712 |
+
eps = 1e-6
|
713 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
714 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
715 |
+
|
716 |
+
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
717 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
718 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
719 |
+
|
720 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
721 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
722 |
+
|
723 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
724 |
+
|
725 |
+
|
726 |
+
class MCLM(nn.Module):
|
727 |
+
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
728 |
+
super(MCLM, self).__init__()
|
729 |
+
self.attention = nn.ModuleList([
|
730 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
731 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
732 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
733 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
734 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
735 |
+
])
|
736 |
+
|
737 |
+
self.linear1 = nn.Linear(d_model, d_model * 2)
|
738 |
+
self.linear2 = nn.Linear(d_model * 2, d_model)
|
739 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
740 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
741 |
+
self.norm1 = nn.LayerNorm(d_model)
|
742 |
+
self.norm2 = nn.LayerNorm(d_model)
|
743 |
+
self.dropout = nn.Dropout(0.1)
|
744 |
+
self.dropout1 = nn.Dropout(0.1)
|
745 |
+
self.dropout2 = nn.Dropout(0.1)
|
746 |
+
self.activation = get_activation_fn('gelu')
|
747 |
+
self.pool_ratios = pool_ratios
|
748 |
+
self.p_poses = []
|
749 |
+
self.g_pos = None
|
750 |
+
self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
|
751 |
+
|
752 |
+
def forward(self, l, g):
|
753 |
+
"""
|
754 |
+
l: 4,c,h,w
|
755 |
+
g: 1,c,h,w
|
756 |
+
"""
|
757 |
+
self.p_poses = []
|
758 |
+
self.g_pos = None
|
759 |
+
b, c, h, w = l.size()
|
760 |
+
# 4,c,h,w -> 1,c,2h,2w
|
761 |
+
concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
762 |
+
|
763 |
+
pools = []
|
764 |
+
for pool_ratio in self.pool_ratios:
|
765 |
+
# b,c,h,w
|
766 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
767 |
+
pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
|
768 |
+
pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
|
769 |
+
if self.g_pos is None:
|
770 |
+
pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
|
771 |
+
pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
772 |
+
self.p_poses.append(pos_emb)
|
773 |
+
pools = torch.cat(pools, 0)
|
774 |
+
if self.g_pos is None:
|
775 |
+
self.p_poses = torch.cat(self.p_poses, dim=0)
|
776 |
+
pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
|
777 |
+
self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')
|
778 |
+
|
779 |
+
device = pools.device
|
780 |
+
self.p_poses = self.p_poses.to(device)
|
781 |
+
self.g_pos = self.g_pos.to(device)
|
782 |
+
|
783 |
+
|
784 |
+
# attention between glb (q) & multisensory concated-locs (k,v)
|
785 |
+
g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')
|
786 |
+
|
787 |
+
|
788 |
+
g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
|
789 |
+
g_hw_b_c = self.norm1(g_hw_b_c)
|
790 |
+
g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
|
791 |
+
g_hw_b_c = self.norm2(g_hw_b_c)
|
792 |
+
|
793 |
+
# attention between origin locs (q) & freashed glb (k,v)
|
794 |
+
l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
|
795 |
+
_g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
|
796 |
+
_g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
|
797 |
+
outputs_re = []
|
798 |
+
for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
|
799 |
+
outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c
|
800 |
+
outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c
|
801 |
+
|
802 |
+
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
803 |
+
l_hw_b_c = self.norm1(l_hw_b_c)
|
804 |
+
l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
805 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
806 |
+
|
807 |
+
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
808 |
+
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
809 |
+
|
810 |
+
|
811 |
+
|
812 |
+
|
813 |
+
|
814 |
+
|
815 |
+
|
816 |
+
|
817 |
+
|
818 |
+
class MCRM(nn.Module):
|
819 |
+
def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
|
820 |
+
super(MCRM, self).__init__()
|
821 |
+
self.attention = nn.ModuleList([
|
822 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
823 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
824 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
|
825 |
+
nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
|
826 |
+
])
|
827 |
+
self.linear3 = nn.Linear(d_model, d_model * 2)
|
828 |
+
self.linear4 = nn.Linear(d_model * 2, d_model)
|
829 |
+
self.norm1 = nn.LayerNorm(d_model)
|
830 |
+
self.norm2 = nn.LayerNorm(d_model)
|
831 |
+
self.dropout = nn.Dropout(0.1)
|
832 |
+
self.dropout1 = nn.Dropout(0.1)
|
833 |
+
self.dropout2 = nn.Dropout(0.1)
|
834 |
+
self.sigmoid = nn.Sigmoid()
|
835 |
+
self.activation = get_activation_fn('gelu')
|
836 |
+
self.sal_conv = nn.Conv2d(d_model, 1, 1)
|
837 |
+
self.pool_ratios = pool_ratios
|
838 |
+
|
839 |
+
def forward(self, x):
|
840 |
+
device = x.device
|
841 |
+
b, c, h, w = x.size()
|
842 |
+
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
843 |
+
|
844 |
+
patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
845 |
+
|
846 |
+
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
847 |
+
token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
|
848 |
+
loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
849 |
+
|
850 |
+
pools = []
|
851 |
+
for pool_ratio in self.pool_ratios:
|
852 |
+
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
853 |
+
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
854 |
+
pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
855 |
+
|
856 |
+
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
857 |
+
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
858 |
+
|
859 |
+
outputs = []
|
860 |
+
for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
|
861 |
+
v = pools[i]
|
862 |
+
k = v
|
863 |
+
outputs.append(self.attention[i](q, k, v)[0])
|
864 |
+
|
865 |
+
outputs = torch.cat(outputs, 1)
|
866 |
+
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
867 |
+
src = self.norm1(src)
|
868 |
+
src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
|
869 |
+
src = self.norm2(src)
|
870 |
+
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
871 |
+
glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
|
872 |
+
|
873 |
+
return torch.cat((src, glb), 0), token_attention_map
|
874 |
+
|
875 |
+
|
876 |
+
|
877 |
+
class BEN_Base(nn.Module):
|
878 |
+
def __init__(self):
|
879 |
+
super().__init__()
|
880 |
+
|
881 |
+
self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
|
882 |
+
emb_dim = 128
|
883 |
+
self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
884 |
+
self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
885 |
+
self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
886 |
+
self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
887 |
+
self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
888 |
+
|
889 |
+
self.output5 = make_cbr(1024, emb_dim)
|
890 |
+
self.output4 = make_cbr(512, emb_dim)
|
891 |
+
self.output3 = make_cbr(256, emb_dim)
|
892 |
+
self.output2 = make_cbr(128, emb_dim)
|
893 |
+
self.output1 = make_cbr(128, emb_dim)
|
894 |
+
|
895 |
+
self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
|
896 |
+
self.conv1 = make_cbr(emb_dim, emb_dim)
|
897 |
+
self.conv2 = make_cbr(emb_dim, emb_dim)
|
898 |
+
self.conv3 = make_cbr(emb_dim, emb_dim)
|
899 |
+
self.conv4 = make_cbr(emb_dim, emb_dim)
|
900 |
+
self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
|
901 |
+
self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
|
902 |
+
self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
|
903 |
+
self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])
|
904 |
+
|
905 |
+
self.insmask_head = nn.Sequential(
|
906 |
+
nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
|
907 |
+
nn.InstanceNorm2d(384),
|
908 |
+
nn.GELU(),
|
909 |
+
nn.Conv2d(384, 384, kernel_size=3, padding=1),
|
910 |
+
nn.InstanceNorm2d(384),
|
911 |
+
nn.GELU(),
|
912 |
+
nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
|
913 |
+
)
|
914 |
+
|
915 |
+
self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
|
916 |
+
self.upsample1 = make_cbg(emb_dim, emb_dim)
|
917 |
+
self.upsample2 = make_cbg(emb_dim, emb_dim)
|
918 |
+
self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
|
919 |
+
|
920 |
+
for m in self.modules():
|
921 |
+
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
922 |
+
m.inplace = True
|
923 |
+
|
924 |
+
@torch.inference_mode()
|
925 |
+
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
926 |
+
def forward(self, x):
|
927 |
+
real_batch = x.size(0)
|
928 |
+
|
929 |
+
shallow_batch = self.shallow(x)
|
930 |
+
glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')
|
931 |
+
|
932 |
+
|
933 |
+
|
934 |
+
final_input = None
|
935 |
+
for i in range(real_batch):
|
936 |
+
start = i * 4
|
937 |
+
end = (i + 1) * 4
|
938 |
+
loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0))
|
939 |
+
input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0)
|
940 |
+
|
941 |
+
|
942 |
+
if final_input == None:
|
943 |
+
final_input= input_
|
944 |
+
else: final_input = torch.cat((final_input, input_), dim=0)
|
945 |
+
|
946 |
+
features = self.backbone(final_input)
|
947 |
+
outputs = []
|
948 |
+
|
949 |
+
for i in range(real_batch):
|
950 |
+
|
951 |
+
start = i * 5
|
952 |
+
end = (i + 1) * 5
|
953 |
+
|
954 |
+
f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W]
|
955 |
+
f3 = features[3][start:end, :, :, :]
|
956 |
+
f2 = features[2][start:end, :, :, :]
|
957 |
+
f1 = features[1][start:end, :, :, :]
|
958 |
+
f0 = features[0][start:end, :, :, :]
|
959 |
+
e5 = self.output5(f4)
|
960 |
+
e4 = self.output4(f3)
|
961 |
+
e3 = self.output3(f2)
|
962 |
+
e2 = self.output2(f1)
|
963 |
+
e1 = self.output1(f0)
|
964 |
+
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
|
965 |
+
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
966 |
+
|
967 |
+
|
968 |
+
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
969 |
+
e4 = self.conv4(e4)
|
970 |
+
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
971 |
+
e3 = self.conv3(e3)
|
972 |
+
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
973 |
+
e2 = self.conv2(e2)
|
974 |
+
e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
|
975 |
+
e1 = self.conv1(e1)
|
976 |
+
|
977 |
+
loc_e1, glb_e1 = e1.split([4, 1], dim=0)
|
978 |
+
|
979 |
+
output1_cat = patches2image(loc_e1) # (1,128,256,256)
|
980 |
+
|
981 |
+
# add glb feat in
|
982 |
+
output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
|
983 |
+
# merge
|
984 |
+
final_output = self.insmask_head(output1_cat) # (1,128,256,256)
|
985 |
+
# shallow feature merge
|
986 |
+
shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0)
|
987 |
+
final_output = final_output + resize_as(shallow, final_output)
|
988 |
+
final_output = self.upsample1(rescale_to(final_output))
|
989 |
+
final_output = rescale_to(final_output + resize_as(shallow, final_output))
|
990 |
+
final_output = self.upsample2(final_output)
|
991 |
+
final_output = self.output(final_output)
|
992 |
+
mask = final_output.sigmoid()
|
993 |
+
outputs.append(mask)
|
994 |
+
|
995 |
+
return torch.cat(outputs, dim=0)
|
996 |
+
|
997 |
+
|
998 |
+
|
999 |
+
|
1000 |
+
def loadcheckpoints(self,model_path):
|
1001 |
+
model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
1002 |
+
self.load_state_dict(model_dict['model_state_dict'], strict=True)
|
1003 |
+
del model_path
|
1004 |
+
|
1005 |
+
def inference(self,image,refine_foreground=False):
|
1006 |
+
|
1007 |
+
set_random_seed(9)
|
1008 |
+
# image = ImageOps.exif_transpose(image)
|
1009 |
+
if isinstance(image, Image.Image):
|
1010 |
+
image, h, w,original_image = rgb_loader_refiner(image)
|
1011 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
1012 |
+
with torch.no_grad():
|
1013 |
+
res = self.forward(img_tensor)
|
1014 |
+
|
1015 |
+
# Show Results
|
1016 |
+
if refine_foreground == True:
|
1017 |
+
|
1018 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
1019 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
1020 |
+
|
1021 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1022 |
+
return image_masked
|
1023 |
+
|
1024 |
+
else:
|
1025 |
+
alpha = postprocess_image(res, im_size=[w,h])
|
1026 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1027 |
+
mask = pred_pil.resize(original_image.size)
|
1028 |
+
foreground = original_image.putalpha(mask)
|
1029 |
+
# mask = Image.fromarray(alpha)
|
1030 |
+
|
1031 |
+
return foreground
|
1032 |
+
|
1033 |
+
|
1034 |
+
else:
|
1035 |
+
foregrounds = []
|
1036 |
+
for batch in image:
|
1037 |
+
image, h, w,original_image = rgb_loader_refiner(batch)
|
1038 |
+
img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
|
1039 |
+
|
1040 |
+
with torch.no_grad():
|
1041 |
+
res = self.forward(img_tensor)
|
1042 |
+
|
1043 |
+
if refine_foreground == True:
|
1044 |
+
|
1045 |
+
pred_pil = transforms.ToPILImage()(res.squeeze())
|
1046 |
+
image_masked = refine_foreground_process(original_image, pred_pil)
|
1047 |
+
|
1048 |
+
image_masked.putalpha(pred_pil.resize(original_image.size))
|
1049 |
+
|
1050 |
+
foregrounds.append(image_masked)
|
1051 |
+
else:
|
1052 |
+
alpha = postprocess_image(res, im_size=[w,h])
|
1053 |
+
pred_pil = transforms.ToPILImage()(alpha)
|
1054 |
+
mask = pred_pil.resize(original_image.size)
|
1055 |
+
original_image.putalpha(mask)
|
1056 |
+
# mask = Image.fromarray(alpha)
|
1057 |
+
foregrounds.append(original_image)
|
1058 |
+
|
1059 |
+
return foregrounds
|
1060 |
+
|
1061 |
+
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
1062 |
+
|
1063 |
+
"""
|
1064 |
+
Segments the given video to extract the foreground (with alpha) from each frame
|
1065 |
+
and saves the result as either a WebM video (with alpha channel) or MP4 (with a
|
1066 |
+
color background).
|
1067 |
+
|
1068 |
+
Args:
|
1069 |
+
video_path (str):
|
1070 |
+
Path to the input video file.
|
1071 |
+
|
1072 |
+
output_path (str, optional):
|
1073 |
+
Directory (or full path) where the output video and/or files will be saved.
|
1074 |
+
Defaults to "./".
|
1075 |
+
|
1076 |
+
fps (int, optional):
|
1077 |
+
The frames per second (FPS) to use for the output video. If 0 (default), the
|
1078 |
+
original FPS of the input video is used. Otherwise, overrides it.
|
1079 |
+
|
1080 |
+
refine_foreground (bool, optional):
|
1081 |
+
Whether to run an additional “refine foreground” process on each frame.
|
1082 |
+
Defaults to False.
|
1083 |
+
|
1084 |
+
batch (int, optional):
|
1085 |
+
Number of frames to process at once (inference batch size). Large batch sizes
|
1086 |
+
may require more GPU memory. Defaults to 1.
|
1087 |
+
|
1088 |
+
print_frames_processed (bool, optional):
|
1089 |
+
If True (default), prints progress (how many frames have been processed) to
|
1090 |
+
the console.
|
1091 |
+
|
1092 |
+
webm (bool, optional):
|
1093 |
+
If True (default), exports a WebM video with alpha channel (VP9 / yuva420p).
|
1094 |
+
If False, exports an MP4 video composited over a solid color background.
|
1095 |
+
|
1096 |
+
rgb_value (tuple, optional):
|
1097 |
+
The RGB background color (e.g., green screen) used to composite frames when
|
1098 |
+
saving to MP4. Defaults to (0, 255, 0).
|
1099 |
+
|
1100 |
+
Returns:
|
1101 |
+
None. Writes the output video(s) to disk in the specified format.
|
1102 |
+
"""
|
1103 |
+
|
1104 |
+
|
1105 |
+
cap = cv2.VideoCapture(video_path)
|
1106 |
+
if not cap.isOpened():
|
1107 |
+
raise IOError(f"Cannot open video: {video_path}")
|
1108 |
+
|
1109 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
1110 |
+
original_fps = 30 if original_fps == 0 else original_fps
|
1111 |
+
fps = original_fps if fps == 0 else fps
|
1112 |
+
|
1113 |
+
ret, first_frame = cap.read()
|
1114 |
+
if not ret:
|
1115 |
+
raise ValueError("No frames found in the video.")
|
1116 |
+
height, width = first_frame.shape[:2]
|
1117 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
1118 |
+
|
1119 |
+
foregrounds = []
|
1120 |
+
frame_idx = 0
|
1121 |
+
processed_count = 0
|
1122 |
+
batch_frames = []
|
1123 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
1124 |
+
|
1125 |
+
while True:
|
1126 |
+
ret, frame = cap.read()
|
1127 |
+
if not ret:
|
1128 |
+
if batch_frames:
|
1129 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1130 |
+
if isinstance(batch_results, Image.Image):
|
1131 |
+
foregrounds.append(batch_results)
|
1132 |
+
else:
|
1133 |
+
foregrounds.extend(batch_results)
|
1134 |
+
if print_frames_processed:
|
1135 |
+
print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}")
|
1136 |
+
break
|
1137 |
+
|
1138 |
+
# Process every frame instead of using intervals
|
1139 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
1140 |
+
pil_frame = Image.fromarray(frame_rgb)
|
1141 |
+
batch_frames.append(pil_frame)
|
1142 |
+
|
1143 |
+
if len(batch_frames) == batch:
|
1144 |
+
batch_results = self.inference(batch_frames, refine_foreground)
|
1145 |
+
if isinstance(batch_results, Image.Image):
|
1146 |
+
foregrounds.append(batch_results)
|
1147 |
+
else:
|
1148 |
+
foregrounds.extend(batch_results)
|
1149 |
+
if print_frames_processed:
|
1150 |
+
print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}")
|
1151 |
+
batch_frames = []
|
1152 |
+
processed_count += batch
|
1153 |
+
|
1154 |
+
frame_idx += 1
|
1155 |
+
|
1156 |
+
|
1157 |
+
if webm:
|
1158 |
+
alpha_webm_path = os.path.join(output_path, "foreground.webm")
|
1159 |
+
pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps)
|
1160 |
+
|
1161 |
+
else:
|
1162 |
+
cap.release()
|
1163 |
+
fg_output = os.path.join(output_path, 'foreground.mp4')
|
1164 |
+
|
1165 |
+
pil_images_to_mp4(foregrounds, fg_output, fps=original_fps,rgb_value=rgb_value)
|
1166 |
+
cv2.destroyAllWindows()
|
1167 |
+
|
1168 |
+
try:
|
1169 |
+
fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
|
1170 |
+
add_audio_to_video(fg_output, video_path, fg_audio_output)
|
1171 |
+
except Exception as e:
|
1172 |
+
print("No audio found in the original video")
|
1173 |
+
print(e)
|
1174 |
+
|
1175 |
+
|
1176 |
+
|
1177 |
+
|
1178 |
+
|
1179 |
+
def rgb_loader_refiner( original_image):
|
1180 |
+
h, w = original_image.size
|
1181 |
+
|
1182 |
+
image = original_image
|
1183 |
+
# Convert to RGB if necessary
|
1184 |
+
if image.mode != 'RGB':
|
1185 |
+
image = image.convert('RGB')
|
1186 |
+
|
1187 |
+
# Resize the image
|
1188 |
+
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
1189 |
+
|
1190 |
+
return image.convert('RGB'), h, w,original_image
|
1191 |
+
|
1192 |
+
# Define the image transformation
|
1193 |
+
img_transform = transforms.Compose([
|
1194 |
+
transforms.ToTensor(),
|
1195 |
+
transforms.ConvertImageDtype(torch.float16),
|
1196 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1197 |
+
])
|
1198 |
+
|
1199 |
+
|
1200 |
+
|
1201 |
+
|
1202 |
+
def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
|
1203 |
+
"""
|
1204 |
+
Converts an array of PIL images to an MP4 video.
|
1205 |
+
|
1206 |
+
Args:
|
1207 |
+
images: List of PIL images
|
1208 |
+
output_path: Path to save the MP4 file
|
1209 |
+
fps: Frames per second (default: 24)
|
1210 |
+
rgb_value: Background RGB color tuple (default: green (0, 255, 0))
|
1211 |
+
"""
|
1212 |
+
if not images:
|
1213 |
+
raise ValueError("No images provided to convert to MP4.")
|
1214 |
+
|
1215 |
+
width, height = images[0].size
|
1216 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
1217 |
+
video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
1218 |
+
|
1219 |
+
for image in images:
|
1220 |
+
# If image has alpha channel, composite onto the specified background color
|
1221 |
+
if image.mode == 'RGBA':
|
1222 |
+
# Create background image with specified RGB color
|
1223 |
+
background = Image.new('RGB', image.size, rgb_value)
|
1224 |
+
background = background.convert('RGBA')
|
1225 |
+
# Composite the image onto the background
|
1226 |
+
image = Image.alpha_composite(background, image)
|
1227 |
+
image = image.convert('RGB')
|
1228 |
+
else:
|
1229 |
+
# Ensure RGB format for non-alpha images
|
1230 |
+
image = image.convert('RGB')
|
1231 |
+
|
1232 |
+
# Convert to OpenCV format and write
|
1233 |
+
open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
1234 |
+
video_writer.write(open_cv_image)
|
1235 |
+
|
1236 |
+
video_writer.release()
|
1237 |
+
|
1238 |
+
def pil_images_to_webm_alpha(images, output_path, fps=30):
|
1239 |
+
"""
|
1240 |
+
Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.
|
1241 |
+
|
1242 |
+
NOTE: Not all players will display alpha in WebM.
|
1243 |
+
Browsers like Chrome/Firefox typically do support VP9 alpha.
|
1244 |
+
"""
|
1245 |
+
if not images:
|
1246 |
+
raise ValueError("No images provided for WebM with alpha.")
|
1247 |
+
|
1248 |
+
# Ensure output directory exists
|
1249 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
1250 |
+
|
1251 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
1252 |
+
# Save frames as PNG (with alpha)
|
1253 |
+
for idx, img in enumerate(images):
|
1254 |
+
if img.mode != "RGBA":
|
1255 |
+
img = img.convert("RGBA")
|
1256 |
+
out_path = os.path.join(tmpdir, f"{idx:06d}.png")
|
1257 |
+
img.save(out_path, "PNG")
|
1258 |
+
|
1259 |
+
# Construct ffmpeg command
|
1260 |
+
# -c:v libvpx-vp9 => VP9 encoder
|
1261 |
+
# -pix_fmt yuva420p => alpha-enabled pixel format
|
1262 |
+
# -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
|
1263 |
+
ffmpeg_cmd = [
|
1264 |
+
"ffmpeg", "-y",
|
1265 |
+
"-framerate", str(fps),
|
1266 |
+
"-i", os.path.join(tmpdir, "%06d.png"),
|
1267 |
+
"-c:v", "libvpx-vp9",
|
1268 |
+
"-pix_fmt", "yuva420p",
|
1269 |
+
"-auto-alt-ref", "0",
|
1270 |
+
output_path
|
1271 |
+
]
|
1272 |
+
|
1273 |
+
subprocess.run(ffmpeg_cmd, check=True)
|
1274 |
+
|
1275 |
+
print(f"WebM with alpha saved to {output_path}")
|
1276 |
+
|
1277 |
+
def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
|
1278 |
+
"""
|
1279 |
+
Check if the original video has an audio stream. If yes, add it. If not, skip.
|
1280 |
+
"""
|
1281 |
+
# 1) Probe original video for audio streams
|
1282 |
+
probe_command = [
|
1283 |
+
'ffprobe', '-v', 'error',
|
1284 |
+
'-select_streams', 'a:0',
|
1285 |
+
'-show_entries', 'stream=index',
|
1286 |
+
'-of', 'csv=p=0',
|
1287 |
+
original_video_path
|
1288 |
+
]
|
1289 |
+
result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
1290 |
+
|
1291 |
+
# result.stdout is empty if no audio stream found
|
1292 |
+
if not result.stdout.strip():
|
1293 |
+
print("No audio track found in original video, skipping audio addition.")
|
1294 |
+
return
|
1295 |
+
|
1296 |
+
print("Audio track detected; proceeding to mux audio.")
|
1297 |
+
# 2) If audio found, run ffmpeg to add it
|
1298 |
+
command = [
|
1299 |
+
'ffmpeg', '-y',
|
1300 |
+
'-i', video_without_audio_path,
|
1301 |
+
'-i', original_video_path,
|
1302 |
+
'-c', 'copy',
|
1303 |
+
'-map', '0:v:0',
|
1304 |
+
'-map', '1:a:0', # we know there's an audio track now
|
1305 |
+
output_path
|
1306 |
+
]
|
1307 |
+
subprocess.run(command, check=True)
|
1308 |
+
print(f"Audio added successfully => {output_path}")
|
1309 |
+
|
1310 |
+
|
1311 |
+
|
1312 |
+
|
1313 |
+
|
1314 |
+
### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
|
1315 |
+
def refine_foreground_process(image, mask, r=90):
|
1316 |
+
if mask.size != image.size:
|
1317 |
+
mask = mask.resize(image.size)
|
1318 |
+
image = np.array(image) / 255.0
|
1319 |
+
mask = np.array(mask) / 255.0
|
1320 |
+
estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
|
1321 |
+
image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
|
1322 |
+
return image_masked
|
1323 |
+
|
1324 |
+
|
1325 |
+
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
1326 |
+
# Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
|
1327 |
+
alpha = alpha[:, :, None]
|
1328 |
+
F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
|
1329 |
+
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
1330 |
+
|
1331 |
+
|
1332 |
+
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
1333 |
+
if isinstance(image, Image.Image):
|
1334 |
+
image = np.array(image) / 255.0
|
1335 |
+
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
1336 |
+
|
1337 |
+
blurred_FA = cv2.blur(F * alpha, (r, r))
|
1338 |
+
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
1339 |
+
|
1340 |
+
blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
|
1341 |
+
blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
|
1342 |
+
F = blurred_F + alpha * \
|
1343 |
+
(image - alpha * blurred_F - (1 - alpha) * blurred_B)
|
1344 |
+
F = np.clip(F, 0, 1)
|
1345 |
+
return F, blurred_B
|
1346 |
+
|
1347 |
+
|
1348 |
+
|
1349 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
1350 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
|
1351 |
+
ma = torch.max(result)
|
1352 |
+
mi = torch.min(result)
|
1353 |
+
result = (result - mi) / (ma - mi)
|
1354 |
+
im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
1355 |
+
im_array = np.squeeze(im_array)
|
1356 |
+
return im_array
|
1357 |
+
|
1358 |
+
|
1359 |
+
|
1360 |
+
|
1361 |
+
def rgb_loader_refiner( original_image):
|
1362 |
+
h, w = original_image.size
|
1363 |
+
# # Apply EXIF orientation
|
1364 |
+
|
1365 |
+
if original_image.mode != 'RGB':
|
1366 |
+
original_image = original_image.convert('RGB')
|
1367 |
+
|
1368 |
+
image = original_image
|
1369 |
+
# Convert to RGB if necessary
|
1370 |
+
|
1371 |
+
# Resize the image
|
1372 |
+
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
1373 |
+
|
1374 |
+
return image, h, w,original_image
|
1375 |
+
|
1376 |
+
|
1377 |
+
|
BEN2_Base.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:926144a876bda06f125555b4f5a239ece89dc6eb838a863700ca9bf192161a1c
|
3 |
+
size 1134584206
|
BEN2_demo_pictures/grid_example1.png
ADDED
Git LFS Details
|
BEN2_demo_pictures/grid_example2.png
ADDED
Git LFS Details
|
BEN2_demo_pictures/grid_example3.png
ADDED
Git LFS Details
|
BEN2_demo_pictures/grid_example6.png
ADDED
Git LFS Details
|
BEN2_demo_pictures/grid_example7.png
ADDED
Git LFS Details
|
BEN2_demo_pictures/model_comparison.png
ADDED
README.md
CHANGED
@@ -1,3 +1,151 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
pipeline_tag: image-segmentation
|
4 |
+
tags:
|
5 |
+
- BEN2
|
6 |
+
- background-remove
|
7 |
+
- mask-generation
|
8 |
+
- Dichotomous image segmentation
|
9 |
+
- background remove
|
10 |
+
- foreground
|
11 |
+
- background
|
12 |
+
- remove background
|
13 |
+
- pytorch
|
14 |
+
---
|
15 |
+
|
16 |
+
# BEN2: Background Erase Network
|
17 |
+
|
18 |
+
[![arXiv](https://img.shields.io/badge/arXiv-2501.06230-b31b1b.svg)](https://arxiv.org/abs/2501.06230)
|
19 |
+
[![GitHub](https://img.shields.io/badge/GitHub-BEN2-black.svg)](https://github.com/PramaLLC/BEN2/)
|
20 |
+
[![Website](https://img.shields.io/badge/Website-backgrounderase.net-104233)](https://backgrounderase.net)
|
21 |
+
|
22 |
+
## Overview
|
23 |
+
BEN2 (Background Erase Network) introduces a novel approach to foreground segmentation through its innovative Confidence Guided Matting (CGM) pipeline. The architecture employs a refiner network that targets and processes pixels where the base model exhibits lower confidence levels, resulting in more precise and reliable matting results. This model is built on BEN:
|
24 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ben-using-confidence-guided-matting-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=ben-using-confidence-guided-matting-for)
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
## BEN2 access
|
30 |
+
BEN2 was trained on the DIS5k and our 22K proprietary segmentation dataset. Our enhanced model delivers superior performance in hair matting, 4K processing, object segmentation, and edge refinement. Our Base model is open source. To try the full model through our free web demo or integrate BEN2 into your project with our API:
|
31 |
+
- 🌐 [backgrounderase.net](https://backgrounderase.net)
|
32 |
+
|
33 |
+
|
34 |
+
## Contact us
|
35 |
+
- For access to our commercial model email us at [email protected]
|
36 |
+
- Our website: https://prama.llc/
|
37 |
+
- Follow us on X: https://x.com/PramaResearch/
|
38 |
+
|
39 |
+
|
40 |
+
## Quick start code (inside cloned repo)
|
41 |
+
|
42 |
+
```python
|
43 |
+
import model
|
44 |
+
from PIL import Image
|
45 |
+
import torch
|
46 |
+
|
47 |
+
|
48 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
49 |
+
|
50 |
+
file = "./image.png" # input image
|
51 |
+
|
52 |
+
model = model.BEN_Base().to(device).eval() #init pipeline
|
53 |
+
|
54 |
+
model.loadcheckpoints("./BEN_Base2.pth")
|
55 |
+
image = Image.open(file)
|
56 |
+
foreground = model.inference(image, refine_foreground=False,) #Refine foreground is an extract postprocessing step that increases inference time but can improve matting edges. The default value is False.
|
57 |
+
|
58 |
+
foreground.save("./foreground.png")
|
59 |
+
|
60 |
+
```
|
61 |
+
|
62 |
+
|
63 |
+
## Batch image processing
|
64 |
+
|
65 |
+
```python
|
66 |
+
import model
|
67 |
+
from PIL import Image
|
68 |
+
import torch
|
69 |
+
|
70 |
+
|
71 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
model = model.BEN_Base().to(device).eval() #init pipeline
|
76 |
+
|
77 |
+
model.loadcheckpoints("./BEN_Base2.pth")
|
78 |
+
|
79 |
+
file1 = "./image1.png" # input image1
|
80 |
+
file2 = "./image2.png" # input image2
|
81 |
+
image1 = Image.open(file1)
|
82 |
+
image2 = Image.open(file2)
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
foregrounds = model.inference([image1, image2]) # We recommended that batch size not exceed 3 for consumer GPUs as there are minimal inference gains. Due to our custom batch processing for the MVANet decoding steps.
|
87 |
+
foregrounds[0].save("./foreground1.png")
|
88 |
+
foregrounds[1].save("./foreground2.png")
|
89 |
+
|
90 |
+
```
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
# BEN2 video segmentation
|
95 |
+
[![BEN2 Demo](https://img.youtube.com/vi/skEXiIHQcys/0.jpg)](https://www.youtube.com/watch?v=skEXiIHQcys)
|
96 |
+
|
97 |
+
## Video Segmentation
|
98 |
+
|
99 |
+
```bash
|
100 |
+
sudo apt update
|
101 |
+
sudo apt install ffmpeg
|
102 |
+
```
|
103 |
+
|
104 |
+
```python
|
105 |
+
import model
|
106 |
+
from PIL import Image
|
107 |
+
import torch
|
108 |
+
|
109 |
+
|
110 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
111 |
+
|
112 |
+
file = "./image.png" # input image
|
113 |
+
|
114 |
+
model = model.BEN_Base().to(device).eval() #init pipeline
|
115 |
+
|
116 |
+
model.loadcheckpoints("./BEN_Base2.pth")
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
model.segment_video(
|
122 |
+
video_path="/path_to_your_video.mp4",
|
123 |
+
output_path="./", # Outputs will be saved as foreground.webm or foreground.mp4. The default value is "./"
|
124 |
+
fps=0, # If this is set to 0 CV2 will detect the fps in the original video. The default value is 0.
|
125 |
+
refine_foreground=False, #refine foreground is an extract postprocessing step that increases inference time but can improve matting edges. The default value is False.
|
126 |
+
batch=1, # We recommended that batch size not exceed 3 for consumer GPUs as there are minimal inference gains. The default value is 1.
|
127 |
+
print_frames_processed=True, #Informs you what frame is being processed. The default value is True.
|
128 |
+
webm = False, # This will output an alpha layer video but this defaults to mp4 when webm is false. The default value is False.
|
129 |
+
rgb_value= (0, 255, 0) # If you do not use webm this will be the RGB value of the resulting background only when webm is False. The default value is a green background (0,255,0).
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
```
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
**# BEN2 evaluation**
|
138 |
+
![Model Comparison](BEN2_demo_pictures/model_comparison.png)
|
139 |
+
|
140 |
+
RMBG 2.0 did not preserve the DIS 5k validation dataset
|
141 |
+
|
142 |
+
![Example 1](BEN2_demo_pictures/grid_example1.png)
|
143 |
+
![Example 2](BEN2_demo_pictures/grid_example2.png)
|
144 |
+
![Example 3](BEN2_demo_pictures/grid_example3.png)
|
145 |
+
![Example 6](BEN2_demo_pictures/grid_example6.png)
|
146 |
+
![Example 7](BEN2_demo_pictures/grid_example7.png)
|
147 |
+
|
148 |
+
|
149 |
+
## Installation
|
150 |
+
1. Clone Repo
|
151 |
+
2. Install requirements.txt
|
config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "PramaLLC/BEN2",
|
3 |
+
"architectures": ["PramaBEN_Base"],
|
4 |
+
"version": "1.0",
|
5 |
+
"torch_dtype": "float32"
|
6 |
+
}
|
inference.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import BEN2
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
+
|
8 |
+
file = "./image.png" # input image
|
9 |
+
|
10 |
+
model = BEN2.BEN_Base().to(device).eval() #init pipeline
|
11 |
+
|
12 |
+
model.loadcheckpoints("./BEN_Base2.pth")
|
13 |
+
image = Image.open(file)
|
14 |
+
mask, foreground = model.inference(image)
|
15 |
+
|
16 |
+
mask.save("./mask.png")
|
17 |
+
foreground.save("./foreground.png")
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy>=1.21.0
|
2 |
+
torch>=1.9.0
|
3 |
+
einops>=0.6.0
|
4 |
+
Pillow>=9.0.0
|
5 |
+
timm>=0.6.0
|
6 |
+
torchvision>=0.10.0
|