Spaces:
Runtime error
Runtime error
Pierre Fernandez
committed on
Commit
·
9e6cbab
1
Parent(s):
c25dff6
added encoding and decoding
Browse files- app.py +77 -63
- requirements.txt +1 -0
- utils.py +84 -0
- utils_img.py +85 -0
app.py
CHANGED
@@ -3,80 +3,94 @@ import gradio.inputs as grinputs
|
|
3 |
import gradio.outputs as groutputs
|
4 |
|
5 |
import numpy as np
|
|
|
6 |
|
7 |
import torch
|
8 |
-
|
9 |
-
|
|
|
|
|
10 |
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
torch.manual_seed(0)
|
14 |
np.random.seed(0)
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
FPR = 1e-6
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
backbone = build_backbone(path='dino_r50.pth')
|
72 |
-
normlayer = load_normalization_layer(path='out2048.pth')
|
73 |
-
model = NormLayerWrapper(backbone, normlayer)
|
74 |
-
|
75 |
-
def encode(image):
|
76 |
-
return image
|
77 |
|
78 |
def decode(image):
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
def on_submit(image, mode):
|
82 |
print('{} mode'.format(mode))
|
|
|
3 |
import gradio.outputs as groutputs
|
4 |
|
5 |
import numpy as np
|
6 |
+
import json
|
7 |
|
8 |
import torch
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
import utils
|
12 |
+
import utils_img
|
13 |
|
14 |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seeds: the carrier drawn below must be the same on every start,
# because encode() and decode() both project features onto it.
torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = utils.build_backbone(path='dino_r50.pth')
normlayer = utils.load_normalization_layer(path='out2048.pth')
# encode()/decode() send their inputs to `device`, so the whole wrapped model
# has to live there as well (a no-op when running on the CPU).
model = utils.NormLayerWrapper(backbone, normlayer).to(device)

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447  # value for FPR=1e-6 and D=2048
# Presumably obtained with: utils.pvalue_angle(2048, 1, proba=FPR)
rho = 1 + np.tan(angle)**2

# Random unit-norm direction in feature space; drawn on the CPU so the values
# do not depend on the device, then moved next to the model.
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)

# Converts an input image to a tensor and applies the mean/std normalization
# used throughout utils_img.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
|
36 |
+
|
37 |
+
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """
    Optimize a copy of `image` so that its features move into the hypercone
    around `carrier`, and return the result as a PIL image.

    Args:
        image: input accepted by `default_transform` (its first step is
            `transforms.ToTensor()`)
        epochs: number of Adam steps
        psnr: target PSNR forwarded to `utils_img.psnr_clip`
        lambda_w: weight of the hypercone loss
        lambda_i: weight of the squared L2 distance to the original image
    """
    reference = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    # Leaf copy of the input that the optimizer updates in place.
    marked = reference.clone().to(device, non_blocking=True)
    marked.requires_grad = True
    adam = torch.optim.Adam([marked], lr=1e-2)

    for step in range(epochs):
        # Constrain the perturbation before it is fed to the feature extractor.
        candidate = utils_img.psnr_clip(
            utils_img.ssim_attenuation(marked, reference), reference, psnr
        )

        features = model(candidate)  # BxCxWxH -> BxD

        projection = features @ carrier.T  # BxD @ Dx1 -> Bx1
        feature_norm = torch.norm(features, dim=-1, keepdim=True)  # Bx1
        cosine = torch.abs(projection / feature_norm)
        log10_pvalue = np.log10(utils.cosine_pvalue(cosine.item(), features.shape[-1]))

        # Negative once the features lie inside the cone defined by rho.
        loss_R = -(rho * projection**2 - feature_norm**2)
        loss_l2_img = torch.norm(candidate - reference)**2
        loss = lambda_w*loss_R + lambda_i*loss_l2_img

        adam.zero_grad()
        loss.backward()
        adam.step()

        record = {
            "keyword": "img_optim",
            "iteration": step,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(record))

    # Apply the same constraints to the final tensor, then quantize to pixels.
    marked = utils_img.ssim_attenuation(marked, reference)
    marked = utils_img.psnr_clip(marked, reference, psnr)
    marked = utils_img.round_pixel(marked)
    marked = marked.squeeze(0).detach().cpu()
    pixels = utils_img.unnormalize_img(marked).squeeze(0)

    return transforms.ToPILImage()(pixels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
def decode(image):
    """
    Test whether `image` carries the watermark and report the p-value.

    Args:
        image: input accepted by `default_transform`

    Returns:
        A sentence stating "marked" or "unmarked" together with the p-value
        of the cosine between the image features and `carrier`.
    """
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)

    # Detection is pure inference: disabling autograd avoids recording a graph
    # through the whole backbone for a result that is never differentiated.
    with torch.no_grad():
        ft = model(img)  # BxCxWxH -> BxD

        dot_product = (ft @ carrier.T)  # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True)  # Bx1
        cosines = torch.abs(dot_product/norm)
        # Same quantity as the loss minimized in encode(): negative inside the cone.
        loss_R = -(rho * dot_product**2 - norm**2)

    log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))

    text_marked = "marked" if loss_R < 0 else "unmarked"
    return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=10**log10_pvalue)
|
92 |
+
|
93 |
+
|
94 |
|
95 |
def on_submit(image, mode):
|
96 |
print('{} mode'.format(mode))
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
torch==1.10.1
|
2 |
torchvision==0.11.2
|
3 |
pillow==9.0.0
|
|
|
|
1 |
torch==1.10.1
|
2 |
torchvision==0.11.2
|
3 |
pillow==9.0.0
|
4 |
+
scipy
|
utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision import models
|
6 |
+
|
7 |
+
from scipy.optimize import root_scalar
|
8 |
+
from scipy.special import betainc
|
9 |
+
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
def build_backbone(path, name='resnet50'):
    """
    Builds a torchvision model (ResNet-50 by default) without its classifier
    and loads the weights stored at `path` into it.

    Args:
        path: checkpoint file read with `torch.load`
        name: name of the constructor looked up in `torchvision.models`
    """
    constructor = getattr(models, name)
    net = constructor(pretrained=False)
    # Identity in place of the classification layers: the network then
    # returns features instead of logits.
    net.head = nn.Identity()
    net.fc = nn.Identity()

    checkpoint = torch.load(path, map_location=device)
    # The weights may be nested under one of these keys; when several are
    # present the last one in the tuple wins.
    weights = checkpoint
    for wrapper_key in ('state_dict', 'model_state_dict', 'teacher'):
        if wrapper_key in checkpoint:
            weights = checkpoint[wrapper_key]

    # Strip the wrapper prefixes from the parameter names, one pass per prefix.
    for prefix in ("module.", "backbone."):
        weights = {key.replace(prefix, ""): tensor for key, tensor in weights.items()}

    # strict=False: missing or unexpected keys are tolerated instead of raising.
    net.load_state_dict(weights, strict=False)
    return net
|
26 |
+
|
27 |
+
def get_linear_layer(weight, bias):
    """
    Creates a linear layer that performs feature whitening or centering.

    Args:
        weight: 2-D tensor of shape (dim_out, dim_in)
        bias: tensor used as the layer bias

    Returns:
        An `nn.Linear` whose weight and bias are the given tensors.
    """
    out_features, in_features = weight.shape
    linear = nn.Linear(in_features, out_features)
    # Overwrite the freshly initialized parameters with the provided ones.
    linear.weight, linear.bias = nn.Parameter(weight), nn.Parameter(bias)
    return linear
|
34 |
+
|
35 |
+
def load_normalization_layer(path):
    """
    Loads the normalization layer from a checkpoint and returns it on `device`.

    The checkpoint must provide 'weight' and 'bias'. When `path` contains
    'whitening' or 'out', both are multiplied by the input dimension of the
    weight before the layer is built.
    """
    checkpoint = torch.load(path, map_location=device)
    weight, bias = checkpoint['weight'], checkpoint['bias']
    if 'whitening' in path or 'out' in path:
        in_dim = weight.shape[1]
        weight = torch.nn.Parameter(in_dim*weight)
        bias = torch.nn.Parameter(in_dim*bias)
    layer = get_linear_layer(weight, bias)
    return layer.to(device, non_blocking=True)
|
48 |
+
|
49 |
+
class NormLayerWrapper(nn.Module):
    """
    Chains a backbone model with a normalization layer.
    Both sub-modules are switched to evaluation mode on construction.
    """
    def __init__(self, backbone, head):
        super().__init__()
        # Module.eval() returns the module itself, so it can be stored directly.
        self.backbone = backbone.eval()
        self.head = head.eval()

    def forward(self, x):
        """ Applies the backbone, then the normalization layer. """
        return self.head(self.backbone(x))
|
62 |
+
|
63 |
+
def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute value of the projection
    between random unit vectors is higher than c
    Args:
        c: cosine value (scalar)
        d: dimension of the features
        k: number of dimensions of the projection
    Returns:
        1.0 when c is negative, otherwise the regularized incomplete beta
        function I_x((d-k)/2, k/2) evaluated at x = 1 - c^2.
    """
    assert k>0
    if c < 0:
        return 1.0
    a = (d - k) / 2.0
    b = k / 2.0
    # A cosine computed in floating point can end up marginally above 1, which
    # would make 1 - c^2 negative and fall outside betainc's [0, 1] domain.
    x = max(0.0, 1 - c ** 2)
    return betainc(a, b, x)
|
78 |
+
|
79 |
+
def pvalue_angle(dim, k=1, angle=None, proba=None):
    """
    Returns the angle in [0, pi/2] whose cosine has a p-value equal to proba.
    Args:
        dim: dimension of the features
        k: number of dimensions of the projection
        angle: accepted but not used by this function
        proba: target p-value; must be given, since it is subtracted from
            the p-value inside the root search
    """
    def pvalue_gap(theta):
        return cosine_pvalue(np.cos(theta), dim, k) - proba

    solution = root_scalar(pvalue_gap, x0=0.49*np.pi, bracket=[0, np.pi/2])
    return solution.root
|
utils_img.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from torch.autograd.variable import Variable
|
9 |
+
|
10 |
+
NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
11 |
+
|
12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
|
14 |
+
image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)
|
15 |
+
|
16 |
+
def normalize_img(x):
    """ Move x to `device`, subtract the ImageNet mean and divide by the ImageNet std """
    centered = x.to(device) - image_mean
    return centered / image_std
|
18 |
+
|
19 |
+
def unnormalize_img(x):
    """ Inverse of normalize_img: move x to `device`, multiply by the std and add the mean """
    scaled = x.to(device) * image_std
    return scaled + image_mean
|
21 |
+
|
22 |
+
def round_pixel(x):
    """ Round a normalized image to integer pixel values in [0, 255], then normalize it again """
    pixels = 255 * unnormalize_img(x)
    quantized = torch.round(pixels).clamp(0, 255)
    return normalize_img(quantized/255.0)
|
27 |
+
|
28 |
+
def project_linf(x, y, radius):
    """ Clamp x-y so that Linf(x,y)<=radius, with radius expressed on the 0-255 pixel scale """
    # Express the difference in pixel units before clamping.
    pixel_delta = 255 * ((x - y) * image_std)
    clamped = torch.clamp(pixel_delta, -radius, radius)
    # Convert back to the normalized scale.
    return y + (clamped / 255.0) / image_std
|
35 |
+
|
36 |
+
def psnr_clip(x, y, target_psnr):
    """
    Rescale x-y so that PSNR(x,y)>=target_psnr.

    The difference is only scaled down when the current PSNR is below the
    target; otherwise x is returned unchanged up to rounding.
    Args:
        x: normalized image
        y: normalized reference image
        target_psnr: minimum PSNR in dB, measured on the 0-255 pixel scale
    """
    delta = x - y
    # Express the difference in pixel units, where the PSNR is defined.
    delta = 255 * (delta * image_std)
    psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
    if psnr<target_psnr:
        # Scaling delta by 10^((psnr-target)/20) raises the PSNR to the target.
        delta = (torch.sqrt(10**((psnr-target_psnr)/10))) * delta
    # Convert back to the normalized scale.
    delta = (delta / 255.0) / image_std
    return y + delta
|
46 |
+
|
47 |
+
def ssim_heatmap(img1, img2, window_size):
    """
    Compute the SSIM heatmap between 2 images.

    Both inputs must have 3 channels, since the Gaussian window is applied
    channel by channel with groups=3.
    """
    half = window_size//2

    # Gaussian window (sigma = 1.5), normalized to sum to one.
    taps = [np.exp(-(i - half)**2/float(2*1.5**2)) for i in range(window_size)]
    kernel_1d = torch.Tensor(taps).to(device, non_blocking=True)
    kernel_1d = (kernel_1d/kernel_1d.sum()).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(kernel_2d.expand(3, 1, window_size, window_size).contiguous())

    def smooth(tensor):
        # Local weighted average of each channel.
        return F.conv2d(tensor, window, padding=half, groups=3)

    # Local means.
    mu1 = smooth(img1)
    mu2 = smooth(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1*mu2

    # Local variances and covariance.
    sigma1_sq = smooth(img1*img1) - mu1_sq
    sigma2_sq = smooth(img2*img2) - mu2_sq
    sigma12 = smooth(img1*img2) - mu1_mu2

    # Stabilizing constants of the SSIM formula.
    C1 = 0.01**2
    C2 = 0.03**2

    numerator = (2*mu1_mu2 + C1)*(2*sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)
    return numerator/denominator
|
72 |
+
|
73 |
+
def ssim_attenuation(x, y):
    """
    Attenuate x-y using the SSIM heatmap.

    The per-channel heatmap is summed over the channel dimension and negative
    values are set to zero; the difference x-y is then multiplied by it.
    """
    heatmap = ssim_heatmap(x, y, window_size=17)  # 1xCxHxW
    weights = torch.clamp_min(torch.sum(heatmap, dim=1, keepdim=True), 0)
    # NOTE: the weights are used as they are; no min-max normalization of the
    # heatmap is applied.
    return y + (x - y)*weights
|