Pierre Fernandez commited on
Commit
9e6cbab
·
1 Parent(s): c25dff6

added encoding and decoding

Browse files
Files changed (4) hide show
  1. app.py +77 -63
  2. requirements.txt +1 -0
  3. utils.py +84 -0
  4. utils_img.py +85 -0
app.py CHANGED
@@ -3,80 +3,94 @@ import gradio.inputs as grinputs
3
  import gradio.outputs as groutputs
4
 
5
  import numpy as np
 
6
 
7
  import torch
8
- import torch.nn as nn
9
- from torchvision import models
 
 
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  torch.manual_seed(0)
14
  np.random.seed(0)
15
 
 
 
 
 
 
 
16
  FPR = 1e-6
17
- carrier = np.random.randn(1, 2048)
18
-
19
-
20
- def build_backbone(path, name='resnet50'):
21
- """ Builds a pretrained ResNet-50 backbone. """
22
- model = getattr(models, name)(pretrained=False)
23
- model.head = nn.Identity()
24
- model.fc = nn.Identity()
25
- checkpoint = torch.load(path, map_location=device)
26
- state_dict = checkpoint
27
- for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
28
- if ckpt_key in checkpoint:
29
- state_dict = checkpoint[ckpt_key]
30
- state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
31
- state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
- msg = model.load_state_dict(state_dict, strict=False)
33
- return model
34
-
35
- def get_linear_layer(weight, bias):
36
- """ Creates a layer that performs feature whitening or centering """
37
- dim_out, dim_in = weight.shape
38
- layer = nn.Linear(dim_in, dim_out)
39
- layer.weight = nn.Parameter(weight)
40
- layer.bias = nn.Parameter(bias)
41
- return layer
42
-
43
- def load_normalization_layer(path):
44
- """
45
- Loads the normalization layer from a checkpoint and returns the layer.
46
- """
47
- checkpoint = torch.load(path, map_location=device)
48
- if 'whitening' in path or 'out' in path:
49
- D = checkpoint['weight'].shape[1]
50
- weight = torch.nn.Parameter(D*checkpoint['weight'])
51
- bias = torch.nn.Parameter(D*checkpoint['bias'])
52
- else:
53
- weight = checkpoint['weight']
54
- bias = checkpoint['bias']
55
- return get_linear_layer(weight, bias).to(device, non_blocking=True)
56
-
57
- class NormLayerWrapper(nn.Module):
58
- """
59
- Wraps backbone model and normalization layer
60
- """
61
- def __init__(self, backbone, head):
62
- super(NormLayerWrapper, self).__init__()
63
- backbone.eval(), head.eval()
64
- self.backbone = backbone
65
- self.head = head
66
-
67
- def forward(self, x):
68
- output = self.backbone(x)
69
- return self.head(output)
70
-
71
- backbone = build_backbone(path='dino_r50.pth')
72
- normlayer = load_normalization_layer(path='out2048.pth')
73
- model = NormLayerWrapper(backbone, normlayer)
74
-
75
- def encode(image):
76
- return image
77
 
78
  def decode(image):
79
- return 'decoded'
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  def on_submit(image, mode):
82
  print('{} mode'.format(mode))
 
3
  import gradio.outputs as groutputs
4
 
5
  import numpy as np
6
+ import json
7
 
8
  import torch
9
+ from torchvision import transforms
10
+
11
+ import utils
12
+ import utils_img
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  torch.manual_seed(0)
17
  np.random.seed(0)
18
 
19
+ print('Building backbone and normalization layer...')
20
+ backbone = utils.build_backbone(path='dino_r50.pth')
21
+ normlayer = utils.load_normalization_layer(path='out2048.pth')
22
+ model = utils.NormLayerWrapper(backbone, normlayer)
23
+
24
+ print('Building the hypercone...')
25
  FPR = 1e-6
26
+ angle = 1.462771101178447 # value for FPR=1e-6 and D=2048
27
+ rho = 1 + np.tan(angle)**2
28
+ # angle = utils.pvalue_angle(2048, 1, proba=FPR)
29
+ carrier = torch.randn(1, 2048)
30
+ carrier /= torch.norm(carrier, dim=1, keepdim=True)
31
+
32
+ default_transform = transforms.Compose([
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
35
+ ])
36
+
37
+ def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
38
+ img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
39
+ img = img_orig.clone().to(device, non_blocking=True)
40
+ img.requires_grad = True
41
+ optimizer = torch.optim.Adam([img], lr=1e-2)
42
+
43
+ for iteration in range(epochs):
44
+ x = utils_img.ssim_attenuation(img, img_orig)
45
+ x = utils_img.psnr_clip(x, img_orig, psnr)
46
+
47
+ ft = model(x) # BxCxWxH -> BxD
48
+
49
+ dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
50
+ norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
51
+ cosines = torch.abs(dot_product/norm)
52
+ log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
53
+ loss_R = -(rho * dot_product**2 - norm**2) # B-B -> B
54
+
55
+ loss_l2_img = torch.norm(x - img_orig)**2 # CxWxH -> 1
56
+ loss = lambda_w*loss_R + lambda_i*loss_l2_img
57
+
58
+ optimizer.zero_grad()
59
+ loss.backward()
60
+ optimizer.step()
61
+
62
+ logs = {
63
+ "keyword": "img_optim",
64
+ "iteration": iteration,
65
+ "loss": loss.item(),
66
+ "loss_R": loss_R.item(),
67
+ "loss_l2_img": loss_l2_img.item(),
68
+ "log10_pvalue": log10_pvalue.item(),
69
+ }
70
+ print("__log__:%s" % json.dumps(logs))
71
+
72
+ img = utils_img.ssim_attenuation(img, img_orig)
73
+ img = utils_img.psnr_clip(img, img_orig, psnr)
74
+ img = utils_img.round_pixel(img)
75
+ img = img.squeeze(0).detach().cpu()
76
+ img = transforms.ToPILImage()(utils_img.unnormalize_img(img).squeeze(0))
77
+
78
+ return img
 
 
 
 
 
 
 
79
 
80
  def decode(image):
81
+ img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
82
+ ft = model(img) # BxCxWxH -> BxD
83
+
84
+ dot_product = (ft @ carrier.T) # BxD @ Dx1 -> Bx1
85
+ norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
86
+ cosines = torch.abs(dot_product/norm)
87
+ log10_pvalue = np.log10(utils.cosine_pvalue(cosines.item(), ft.shape[-1]))
88
+ loss_R = -(rho * dot_product**2 - norm**2) # B-B -> B
89
+
90
+ text_marked = "marked" if loss_R < 0 else "unmarked"
91
+ return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=10**log10_pvalue)
92
+
93
+
94
 
95
  def on_submit(image, mode):
96
  print('{} mode'.format(mode))
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch==1.10.1
2
  torchvision==0.11.2
3
  pillow==9.0.0
 
 
1
  torch==1.10.1
2
  torchvision==0.11.2
3
  pillow==9.0.0
4
+ scipy
utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+
7
+ from scipy.optimize import root_scalar
8
+ from scipy.special import betainc
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ def build_backbone(path, name='resnet50'):
13
+ """ Builds a pretrained ResNet-50 backbone. """
14
+ model = getattr(models, name)(pretrained=False)
15
+ model.head = nn.Identity()
16
+ model.fc = nn.Identity()
17
+ checkpoint = torch.load(path, map_location=device)
18
+ state_dict = checkpoint
19
+ for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
20
+ if ckpt_key in checkpoint:
21
+ state_dict = checkpoint[ckpt_key]
22
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
23
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
24
+ msg = model.load_state_dict(state_dict, strict=False)
25
+ return model
26
+
27
+ def get_linear_layer(weight, bias):
28
+ """ Creates a layer that performs feature whitening or centering """
29
+ dim_out, dim_in = weight.shape
30
+ layer = nn.Linear(dim_in, dim_out)
31
+ layer.weight = nn.Parameter(weight)
32
+ layer.bias = nn.Parameter(bias)
33
+ return layer
34
+
35
+ def load_normalization_layer(path):
36
+ """
37
+ Loads the normalization layer from a checkpoint and returns the layer.
38
+ """
39
+ checkpoint = torch.load(path, map_location=device)
40
+ if 'whitening' in path or 'out' in path:
41
+ D = checkpoint['weight'].shape[1]
42
+ weight = torch.nn.Parameter(D*checkpoint['weight'])
43
+ bias = torch.nn.Parameter(D*checkpoint['bias'])
44
+ else:
45
+ weight = checkpoint['weight']
46
+ bias = checkpoint['bias']
47
+ return get_linear_layer(weight, bias).to(device, non_blocking=True)
48
+
49
+ class NormLayerWrapper(nn.Module):
50
+ """
51
+ Wraps backbone model and normalization layer
52
+ """
53
+ def __init__(self, backbone, head):
54
+ super(NormLayerWrapper, self).__init__()
55
+ backbone.eval(), head.eval()
56
+ self.backbone = backbone
57
+ self.head = head
58
+
59
+ def forward(self, x):
60
+ output = self.backbone(x)
61
+ return self.head(output)
62
+
63
+ def cosine_pvalue(c, d, k=1):
64
+ """
65
+ Returns the probability that the absolute value of the projection
66
+ between random unit vectors is higher than c
67
+ Args:
68
+ c: cosine value
69
+ d: dimension of the features
70
+ k: number of dimensions of the projection
71
+ """
72
+ assert k>0
73
+ a = (d - k) / 2.0
74
+ b = k / 2.0
75
+ if c < 0:
76
+ return 1.0
77
+ return betainc(a, b, 1 - c ** 2)
78
+
79
+ def pvalue_angle(dim, k=1, angle=None, proba=None):
80
+ def f(a):
81
+ return cosine_pvalue(np.cos(a), dim, k) - proba
82
+ a = root_scalar(f, x0=0.49*np.pi, bracket=[0, np.pi/2])
83
+ # a = fsolve(f, x0=0.49*np.pi)[0]
84
+ return a.root
utils_img.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ from torchvision import transforms
5
+
6
+ import torch.nn.functional as F
7
+
8
+ from torch.autograd.variable import Variable
9
+
10
+ NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ image_mean = torch.Tensor(NORMALIZE_IMAGENET.mean).view(-1, 1, 1).to(device)
14
+ image_std = torch.Tensor(NORMALIZE_IMAGENET.std).view(-1, 1, 1).to(device)
15
+
16
+ def normalize_img(x):
17
+ return (x.to(device) - image_mean) / image_std
18
+
19
+ def unnormalize_img(x):
20
+ return (x.to(device) * image_std) + image_mean
21
+
22
+ def round_pixel(x):
23
+ x_pixel = 255 * unnormalize_img(x)
24
+ y = torch.round(x_pixel).clamp(0, 255)
25
+ y = normalize_img(y/255.0)
26
+ return y
27
+
28
+ def project_linf(x, y, radius):
29
+ """ Clamp x-y so that Linf(x,y)<=radius """
30
+ delta = x - y
31
+ delta = 255 * (delta * image_std)
32
+ delta = torch.clamp(delta, -radius, radius)
33
+ delta = (delta / 255.0) / image_std
34
+ return y + delta
35
+
36
+ def psnr_clip(x, y, target_psnr):
37
+ """ Clip x-y so that PSNR(x,y)=target_psnr """
38
+ delta = x - y
39
+ delta = 255 * (delta * image_std)
40
+ psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
41
+ if psnr<target_psnr:
42
+ delta = (torch.sqrt(10**((psnr-target_psnr)/10))) * delta
43
+ psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2))
44
+ delta = (delta / 255.0) / image_std
45
+ return y + delta
46
+
47
+ def ssim_heatmap(img1, img2, window_size):
48
+ """ Compute the SSIM heatmap between 2 images """
49
+ _1D_window = torch.Tensor(
50
+ [np.exp(-(x - window_size//2)**2/float(2*1.5**2)) for x in range(window_size)]
51
+ ).to(device, non_blocking=True)
52
+ _1D_window = (_1D_window/_1D_window.sum()).unsqueeze(1)
53
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
54
+ window = Variable(_2D_window.expand(3, 1, window_size, window_size).contiguous())
55
+
56
+ mu1 = F.conv2d(img1, window, padding = window_size//2, groups = 3)
57
+ mu2 = F.conv2d(img2, window, padding = window_size//2, groups = 3)
58
+
59
+ mu1_sq = mu1.pow(2)
60
+ mu2_sq = mu2.pow(2)
61
+ mu1_mu2 = mu1*mu2
62
+
63
+ sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = 3) - mu1_sq
64
+ sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = 3) - mu2_sq
65
+ sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = 3) - mu1_mu2
66
+
67
+ C1 = 0.01**2
68
+ C2 = 0.03**2
69
+
70
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
71
+ return ssim_map
72
+
73
+ def ssim_attenuation(x, y):
74
+ """ attenuate x-y using SSIM heatmap """
75
+ delta = x - y
76
+ ssim_map = ssim_heatmap(x, y, window_size=17) # 1xCxHxW
77
+ ssim_map = torch.sum(ssim_map, dim=1, keepdim=True)
78
+ ssim_map = torch.clamp_min(ssim_map,0)
79
+ # min_v = torch.min(ssim_map)
80
+ # range_v = torch.max(ssim_map) - min_v
81
+ # if range_v < 1e-10:
82
+ # return y + delta
83
+ # ssim_map = (ssim_map - min_v) / range_v
84
+ delta = delta*ssim_map
85
+ return y + delta