ZeqiangLai commited on
Commit
5ae5c5c
·
1 Parent(s): 5febeab
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
drag_gan.py CHANGED
@@ -136,20 +136,21 @@ def bilinear_interpolate_torch(im, y, x):
136
  y : 1,numPoints -- pixel location y float
137
  x : 1,numPOints -- pixel location y float
138
  """
139
-
140
- x0 = torch.floor(x).long()
 
141
  x1 = x0 + 1
142
 
143
- y0 = torch.floor(y).long()
144
  y1 = y0 + 1
145
 
146
- wa = (x1.float() - x) * (y1.float() - y)
147
- wb = (x1.float() - x) * (y - y0.float())
148
- wc = (x - x0.float()) * (y1.float() - y)
149
- wd = (x - x0.float()) * (y - y0.float())
150
  # Instead of clamp
151
- x1 = x1 - torch.floor(x1 / im.shape[3]).int()
152
- y1 = y1 - torch.floor(y1 / im.shape[2]).int()
153
  Ia = im[:, :, y0, x0]
154
  Ib = im[:, :, y1, x0]
155
  Ic = im[:, :, y0, x1]
@@ -194,7 +195,8 @@ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points
194
  f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
195
  loss += FF.l1_loss(f2, f1)
196
 
197
- loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam
 
198
 
199
  loss.backward()
200
  optimizer.step()
 
136
  y : 1,numPoints -- pixel location y float
137
  x : 1,numPOints -- pixel location y float
138
  """
139
+ device = im.device
140
+
141
+ x0 = torch.floor(x).long().to(device)
142
  x1 = x0 + 1
143
 
144
+ y0 = torch.floor(y).long().to(device)
145
  y1 = y0 + 1
146
 
147
+ wa = ((x1.float() - x) * (y1.float() - y)).to(device)
148
+ wb = ((x1.float() - x) * (y - y0.float())).to(device)
149
+ wc = ((x - x0.float()) * (y1.float() - y)).to(device)
150
+ wd = ((x - x0.float()) * (y - y0.float())).to(device)
151
  # Instead of clamp
152
+ x1 = x1 - torch.floor(x1 / im.shape[3]).int().to(device)
153
+ y1 = y1 - torch.floor(y1 / im.shape[2]).int().to(device)
154
  Ia = im[:, :, y0, x0]
155
  Ib = im[:, :, y1, x0]
156
  Ic = im[:, :, y0, x1]
 
195
  f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
196
  loss += FF.l1_loss(f2, f1)
197
 
198
+ if mask is not None:
199
+ loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam
200
 
201
  loss.backward()
202
  optimizer.step()
gradio_app.py CHANGED
@@ -7,12 +7,14 @@ from PIL import Image
7
  import uuid
8
 
9
  from drag_gan import drag_gan, stylegan2
 
10
 
11
- device = 'cuda'
12
 
13
 
14
  SIZE_TO_CLICK_SIZE = {
15
- 1024: 5,
 
16
  256: 2
17
  }
18
 
@@ -21,8 +23,32 @@ CKPT_SIZE = {
21
  'stylegan2-cat-config-f.pt': 256,
22
  'stylegan2-church-config-f.pt': 256,
23
  'stylegan2-horse-config-f.pt': 256,
 
 
 
 
 
 
24
  }
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  class ImageMask(gr.components.Image):
28
  """
@@ -94,11 +120,14 @@ def on_drag(model, points, max_iters, state, size, mask):
94
  handle_points = [torch.tensor(p).float() for p in points['handle']]
95
  target_points = [torch.tensor(p).float() for p in points['target']]
96
 
97
- mask = Image.fromarray(mask['mask']).convert('L')
98
- mask = np.array(mask) == 255
 
99
 
100
- mask = torch.from_numpy(mask).float().to(device)
101
- mask = mask.unsqueeze(0).unsqueeze(0)
 
 
102
 
103
  step = 0
104
  for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
@@ -149,7 +178,7 @@ def on_change_model(selected, model):
149
  'sample': sample,
150
  'history': []
151
  }
152
- return model, state, to_image(sample), size
153
 
154
 
155
  def on_new_image(model):
@@ -187,11 +216,29 @@ def on_show_save():
187
  return gr.update(visible=True)
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def main():
191
  torch.cuda.manual_seed(25)
192
 
193
  with gr.Blocks() as demo:
194
- wrapped_model = ModelWrapper()
195
  model = gr.State(wrapped_model)
196
  sample_z = torch.randn([1, 512], device=device)
197
  latent, noise = wrapped_model.g_ema.prepare([sample_z])
@@ -199,11 +246,11 @@ def main():
199
 
200
  gr.Markdown(
201
  """
202
- # DragGAN (Unofficial)
203
 
204
  Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
205
 
206
- [Github](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
207
 
208
  ## Tutorial
209
 
@@ -211,6 +258,22 @@ def main():
211
  2. Setup a least one pair of handle point and target point.
212
  3. Click "Drag it".
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  """,
215
  )
216
  state = gr.State({
@@ -221,12 +284,12 @@ def main():
221
  'history': []
222
  })
223
  points = gr.State({'target': [], 'handle': []})
224
- size = gr.State(1024)
225
 
226
  with gr.Row():
227
  with gr.Column(scale=0.3):
228
  with gr.Accordion("Model"):
229
- model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value='stylegan2-ffhq-config-f.pt',
230
  label='StyleGAN2 model')
231
  max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
232
  new_btn = gr.Button('New Image')
@@ -252,24 +315,34 @@ def main():
252
  with gr.Column():
253
  with gr.Tabs():
254
  with gr.Tab('Draw a Mask', id='mask'):
255
- mask = gr.ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
256
  with gr.Tab('Setup Handle Points', id='input'):
257
- image = gr.Image(to_image(sample)).style(height=768, width=768)
258
 
259
  image.select(on_click, [image, target_point, points, size], [image, text, target_point])
 
 
260
  btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
261
  on_show_save, outputs=save_panel).then(
262
  on_save_files, inputs=[image, state], outputs=[files]
263
  )
264
  reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
265
  undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
266
- model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, size])
267
  new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
268
  max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
269
  return demo
270
 
271
 
272
  if __name__ == '__main__':
273
- import fire
 
 
 
 
 
 
 
274
  demo = main()
275
- fire.Fire(demo.queue(concurrency_count=1, max_size=20).launch)
 
 
7
  import uuid
8
 
9
  from drag_gan import drag_gan, stylegan2
10
+ from stylegan2.inversion import inverse_image
11
 
12
+ device = 'cpu'
13
 
14
 
15
  SIZE_TO_CLICK_SIZE = {
16
+ 1024: 8,
17
+ 512: 5,
18
  256: 2
19
  }
20
 
 
23
  'stylegan2-cat-config-f.pt': 256,
24
  'stylegan2-church-config-f.pt': 256,
25
  'stylegan2-horse-config-f.pt': 256,
26
+ 'ada/ffhq.pt': 1024,
27
+ 'ada/afhqcat.pt': 512,
28
+ 'ada/afhqdog.pt': 512,
29
+ 'ada/afhqwild.pt': 512,
30
+ 'ada/brecahad.pt': 512,
31
+ 'ada/metfaces.pt': 512,
32
  }
33
 
34
+ DEFAULT_CKPT = 'stylegan2-ffhq-config-f.pt'
35
+
36
+
37
+ class grImage(gr.components.Image):
38
+ is_template = True
39
+
40
+ def preprocess(self, x):
41
+ if x is None:
42
+ return x
43
+ if self.tool == "sketch" and self.source in ["upload", "webcam"]:
44
+ decode_image = gr.processing_utils.decode_base64_to_image(x)
45
+ width, height = decode_image.size
46
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
47
+ mask[..., -1] = 255
48
+ mask = self.postprocess(mask)
49
+ x = {'image': x, 'mask': mask}
50
+ return super().preprocess(x)
51
+
52
 
53
  class ImageMask(gr.components.Image):
54
  """
 
120
  handle_points = [torch.tensor(p).float() for p in points['handle']]
121
  target_points = [torch.tensor(p).float() for p in points['target']]
122
 
123
+ if mask.get('mask') is not None:
124
+ mask = Image.fromarray(mask['mask']).convert('L')
125
+ mask = np.array(mask) == 255
126
 
127
+ mask = torch.from_numpy(mask).float().to(device)
128
+ mask = mask.unsqueeze(0).unsqueeze(0)
129
+ else:
130
+ mask = None
131
 
132
  step = 0
133
  for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,
 
178
  'sample': sample,
179
  'history': []
180
  }
181
+ return model, state, to_image(sample), to_image(sample), size
182
 
183
 
184
  def on_new_image(model):
 
216
  return gr.update(visible=True)
217
 
218
 
219
+ def on_image_change(model, image_size, image):
220
+ image = Image.fromarray(image)
221
+ result = inverse_image(
222
+ model.g_ema,
223
+ image,
224
+ image_size=image_size
225
+ )
226
+ result['history'] = []
227
+ image = to_image(result['sample'])
228
+ points = {'target': [], 'handle': []}
229
+ target_point = False
230
+ return image, image, result, points, target_point
231
+
232
+
233
+ def on_mask_change(mask):
234
+ return mask['image']
235
+
236
+
237
  def main():
238
  torch.cuda.manual_seed(25)
239
 
240
  with gr.Blocks() as demo:
241
+ wrapped_model = ModelWrapper(ckpt=DEFAULT_CKPT, size=CKPT_SIZE[DEFAULT_CKPT])
242
  model = gr.State(wrapped_model)
243
  sample_z = torch.randn([1, 512], device=device)
244
  latent, noise = wrapped_model.g_ema.prepare([sample_z])
 
246
 
247
  gr.Markdown(
248
  """
249
+ # DragGAN
250
 
251
  Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)
252
 
253
+ [Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)
254
 
255
  ## Tutorial
256
 
 
258
  2. Setup a least one pair of handle point and target point.
259
  3. Click "Drag it".
260
 
261
+ ## Hints
262
+
263
+ - Handle points (Blue): the point you want to drag.
264
+ - Target points (Red): the destination you want to drag towards to.
265
+
266
+ ## Primary Support of Custom Image.
267
+
268
+ - We now support dragging user uploaded image by GAN inversion.
269
+ - **Please upload your image at `Setup Handle Points` pannel.** Upload it from `Draw a Mask` would cause errors for now.
270
+ - Due to the limitation of GAN inversion,
271
+ - You might wait roughly 1 minute to see the GAN version of the uploaded image.
272
+ - The shown image might be slightly difference from the uploaded one.
273
+ - It could also fail to invert the uploaded image and generate very poor results.
274
+ - Idealy, you should choose the closest model of the uploaded image. For example, choose `stylegan2-ffhq-config-f.pt` for human face. `stylegan2-cat-config-f.pt` for cat.
275
+
276
+ > Please fire an issue if you have encounted any problem. Also don't forgot to give a star to the [Official Repo](https://github.com/XingangPan/DragGAN), [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
277
  """,
278
  )
279
  state = gr.State({
 
284
  'history': []
285
  })
286
  points = gr.State({'target': [], 'handle': []})
287
+ size = gr.State(CKPT_SIZE[DEFAULT_CKPT])
288
 
289
  with gr.Row():
290
  with gr.Column(scale=0.3):
291
  with gr.Accordion("Model"):
292
+ model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
293
  label='StyleGAN2 model')
294
  max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
295
  new_btn = gr.Button('New Image')
 
315
  with gr.Column():
316
  with gr.Tabs():
317
  with gr.Tab('Draw a Mask', id='mask'):
318
+ mask = ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
319
  with gr.Tab('Setup Handle Points', id='input'):
320
+ image = grImage(to_image(sample)).style(height=768, width=768)
321
 
322
  image.select(on_click, [image, target_point, points, size], [image, text, target_point])
323
+ image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
324
+ mask.upload(on_mask_change, [mask], [image])
325
  btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
326
  on_show_save, outputs=save_panel).then(
327
  on_save_files, inputs=[image, state], outputs=[files]
328
  )
329
  reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
330
  undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
331
+ model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, mask, size])
332
  new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
333
  max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
334
  return demo
335
 
336
 
337
  if __name__ == '__main__':
338
+ import argparse
339
+ parser = argparse.ArgumentParser()
340
+ parser.add_argument('--device', default='cuda')
341
+ parser.add_argument('--share', action='store_true')
342
+ parser.add_argument('-p', '--port', default=None)
343
+ parser.add_argument('--ip', default=None)
344
+ args = parser.parse_args()
345
+ device = args.device
346
  demo = main()
347
+ print('Successfully loaded, starting gradio demo')
348
+ demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ numpy
5
  ninja
6
  fire
7
  imageio
8
- torchvision
 
 
5
  ninja
6
  fire
7
  imageio
8
+ torchvision
9
+ IPython
stylegan2/{_init__.py → __init__.py} RENAMED
File without changes
stylegan2/inversion.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import torch
5
+ from torch import optim
6
+ from torch.nn import functional as FF
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import dataclasses
11
+
12
+ from .lpips import util
13
+
14
+
15
+ def noise_regularize(noises):
16
+ loss = 0
17
+
18
+ for noise in noises:
19
+ size = noise.shape[2]
20
+
21
+ while True:
22
+ loss = (
23
+ loss
24
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
25
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
26
+ )
27
+
28
+ if size <= 8:
29
+ break
30
+
31
+ noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
32
+ noise = noise.mean([3, 5])
33
+ size //= 2
34
+
35
+ return loss
36
+
37
+
38
+ def noise_normalize_(noises):
39
+ for noise in noises:
40
+ mean = noise.mean()
41
+ std = noise.std()
42
+
43
+ noise.data.add_(-mean).div_(std)
44
+
45
+
46
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
47
+ lr_ramp = min(1, (1 - t) / rampdown)
48
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
49
+ lr_ramp = lr_ramp * min(1, t / rampup)
50
+
51
+ return initial_lr * lr_ramp
52
+
53
+
54
+ def latent_noise(latent, strength):
55
+ noise = torch.randn_like(latent) * strength
56
+
57
+ return latent + noise
58
+
59
+
60
+ def make_image(tensor):
61
+ return (
62
+ tensor.detach()
63
+ .clamp_(min=-1, max=1)
64
+ .add(1)
65
+ .div_(2)
66
+ .mul(255)
67
+ .type(torch.uint8)
68
+ .permute(0, 2, 3, 1)
69
+ .to("cpu")
70
+ .numpy()
71
+ )
72
+
73
+
74
+ @dataclasses.dataclass
75
+ class InverseConfig:
76
+ lr_warmup = 0.05
77
+ lr_decay = 0.25
78
+ lr = 0.1
79
+ noise = 0.05
80
+ noise_decay = 0.75
81
+ step = 1000
82
+ noise_regularize = 1e5
83
+ mse = 0
84
+ w_plus = False,
85
+
86
+
87
+ def inverse_image(
88
+ g_ema,
89
+ image,
90
+ image_size=256,
91
+ config=InverseConfig()
92
+ ):
93
+ device = "cuda"
94
+ args = config
95
+
96
+ n_mean_latent = 10000
97
+
98
+ resize = min(image_size, 256)
99
+
100
+ transform = transforms.Compose(
101
+ [
102
+ transforms.Resize(resize),
103
+ transforms.CenterCrop(resize),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
106
+ ]
107
+ )
108
+
109
+ imgs = []
110
+ img = transform(image)
111
+ imgs.append(img)
112
+
113
+ imgs = torch.stack(imgs, 0).to(device)
114
+
115
+ with torch.no_grad():
116
+ noise_sample = torch.randn(n_mean_latent, 512, device=device)
117
+ latent_out = g_ema.style(noise_sample)
118
+
119
+ latent_mean = latent_out.mean(0)
120
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
121
+
122
+ percept = util.PerceptualLoss(
123
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
124
+ )
125
+
126
+ noises_single = g_ema.make_noise()
127
+ noises = []
128
+ for noise in noises_single:
129
+ noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
130
+
131
+ latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
132
+
133
+ if args.w_plus:
134
+ latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
135
+
136
+ latent_in.requires_grad = True
137
+
138
+ for noise in noises:
139
+ noise.requires_grad = True
140
+
141
+ optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
142
+
143
+ pbar = tqdm(range(args.step))
144
+ latent_path = []
145
+
146
+ for i in pbar:
147
+ t = i / args.step
148
+ lr = get_lr(t, args.lr)
149
+ optimizer.param_groups[0]["lr"] = lr
150
+ noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
151
+ latent_n = latent_noise(latent_in, noise_strength.item())
152
+
153
+ latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
154
+ img_gen, F = g_ema.generate(latent, noise)
155
+
156
+ batch, channel, height, width = img_gen.shape
157
+
158
+ if height > 256:
159
+ factor = height // 256
160
+
161
+ img_gen = img_gen.reshape(
162
+ batch, channel, height // factor, factor, width // factor, factor
163
+ )
164
+ img_gen = img_gen.mean([3, 5])
165
+
166
+ p_loss = percept(img_gen, imgs).sum()
167
+ n_loss = noise_regularize(noises)
168
+ mse_loss = FF.mse_loss(img_gen, imgs)
169
+
170
+ loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
171
+
172
+ optimizer.zero_grad()
173
+ loss.backward()
174
+ optimizer.step()
175
+
176
+ noise_normalize_(noises)
177
+
178
+ if (i + 1) % 100 == 0:
179
+ latent_path.append(latent_in.detach().clone())
180
+
181
+ pbar.set_description(
182
+ (
183
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
184
+ f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
185
+ )
186
+ )
187
+
188
+ latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
189
+ img_gen, F = g_ema.generate(latent, noise)
190
+
191
+ img_ar = make_image(img_gen)
192
+
193
+ i = 0
194
+
195
+ noise_single = []
196
+ for noise in noises:
197
+ noise_single.append(noise[i: i + 1])
198
+
199
+ result = {
200
+ "latent": latent,
201
+ "noise": noise_single,
202
+ 'F': F,
203
+ "sample": img_gen,
204
+ }
205
+
206
+ pil_img = Image.fromarray(img_ar[i])
207
+ pil_img.save('project.png')
208
+
209
+ return result
stylegan2/lpips/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
stylegan2/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass;
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
stylegan2/lpips/dist_model.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from .base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+ import urllib
19
+
20
+ from IPython import embed
21
+
22
+ from . import networks_basic as networks
23
+ from . import util
24
+
25
+
26
+ class DownloadProgressBar(tqdm):
27
+ def update_to(self, b=1, bsize=1, tsize=None):
28
+ if tsize is not None:
29
+ self.total = tsize
30
+ self.update(b * bsize - self.n)
31
+
32
+
33
+ def get_path(base_path):
34
+ BASE_DIR = os.path.join('checkpoints')
35
+
36
+ save_path = os.path.join(BASE_DIR, base_path)
37
+ if not os.path.exists(save_path):
38
+ url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
39
+ print(f'{base_path} not found')
40
+ print('Try to download from huggingface: ', url)
41
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
42
+ download_url(url, save_path)
43
+ print('Downloaded to ', save_path)
44
+ return save_path
45
+
46
+
47
+ def download_url(url, output_path):
48
+ with DownloadProgressBar(unit='B', unit_scale=True,
49
+ miniters=1, desc=url.split('/')[-1]) as t:
50
+ urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
51
+
52
+
53
+ class DistModel(BaseModel):
54
+ def name(self):
55
+ return self.model_name
56
+
57
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
58
+ use_gpu=True, printNet=False, spatial=False,
59
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
60
+ '''
61
+ INPUTS
62
+ model - ['net-lin'] for linearly calibrated network
63
+ ['net'] for off-the-shelf network
64
+ ['L2'] for L2 distance in Lab colorspace
65
+ ['SSIM'] for ssim in RGB colorspace
66
+ net - ['squeeze','alex','vgg']
67
+ model_path - if None, will look in weights/[NET_NAME].pth
68
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
69
+ use_gpu - bool - whether or not to use a GPU
70
+ printNet - bool - whether or not to print network architecture out
71
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
72
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
73
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
74
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
75
+ is_train - bool - [True] for training mode
76
+ lr - float - initial learning rate
77
+ beta1 - float - initial momentum term for adam
78
+ version - 0.1 for latest, 0.0 was original (with a bug)
79
+ gpu_ids - int array - [0] by default, gpus to use
80
+ '''
81
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
82
+
83
+ self.model = model
84
+ self.net = net
85
+ self.is_train = is_train
86
+ self.spatial = spatial
87
+ self.gpu_ids = gpu_ids
88
+ self.model_name = '%s [%s]' % (model, net)
89
+
90
+ if(self.model == 'net-lin'): # pretrained net + linear layer
91
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
92
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
93
+ kw = {}
94
+ if not use_gpu:
95
+ kw['map_location'] = 'cpu'
96
+ if(model_path is None):
97
+ model_path = get_path('weights/v%s/%s.pth' % (version, net))
98
+
99
+ if(not is_train):
100
+ print('Loading model from: %s' % model_path)
101
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
102
+
103
+ elif(self.model == 'net'): # pretrained network
104
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
105
+ elif(self.model in ['L2', 'l2']):
106
+ self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
107
+ self.model_name = 'L2'
108
+ elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
109
+ self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
110
+ self.model_name = 'SSIM'
111
+ else:
112
+ raise ValueError("Model [%s] not recognized." % self.model)
113
+
114
+ self.parameters = list(self.net.parameters())
115
+
116
+ if self.is_train: # training mode
117
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
118
+ self.rankLoss = networks.BCERankingLoss()
119
+ self.parameters += list(self.rankLoss.net.parameters())
120
+ self.lr = lr
121
+ self.old_lr = lr
122
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
123
+ else: # test mode
124
+ self.net.eval()
125
+
126
+ if(use_gpu):
127
+ self.net.to(gpu_ids[0])
128
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
129
+ if(self.is_train):
130
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
131
+
132
+ if(printNet):
133
+ print('---------- Networks initialized -------------')
134
+ networks.print_network(self.net)
135
+ print('-----------------------------------------------')
136
+
137
+ def forward(self, in0, in1, retPerLayer=False):
138
+ ''' Function computes the distance between image patches in0 and in1
139
+ INPUTS
140
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
141
+ OUTPUT
142
+ computed distances between in0 and in1
143
+ '''
144
+
145
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
146
+
147
+ # ***** TRAINING FUNCTIONS *****
148
+ def optimize_parameters(self):
149
+ self.forward_train()
150
+ self.optimizer_net.zero_grad()
151
+ self.backward_train()
152
+ self.optimizer_net.step()
153
+ self.clamp_weights()
154
+
155
+ def clamp_weights(self):
156
+ for module in self.net.modules():
157
+ if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
158
+ module.weight.data = torch.clamp(module.weight.data, min=0)
159
+
160
+ def set_input(self, data):
161
+ self.input_ref = data['ref']
162
+ self.input_p0 = data['p0']
163
+ self.input_p1 = data['p1']
164
+ self.input_judge = data['judge']
165
+
166
+ if(self.use_gpu):
167
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
168
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
169
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
170
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
171
+
172
+ self.var_ref = Variable(self.input_ref, requires_grad=True)
173
+ self.var_p0 = Variable(self.input_p0, requires_grad=True)
174
+ self.var_p1 = Variable(self.input_p1, requires_grad=True)
175
+
176
+ def forward_train(self): # run forward pass
177
+ # print(self.net.module.scaling_layer.shift)
178
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
179
+
180
+ self.d0 = self.forward(self.var_ref, self.var_p0)
181
+ self.d1 = self.forward(self.var_ref, self.var_p1)
182
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
183
+
184
+ self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
185
+
186
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
187
+
188
+ return self.loss_total
189
+
190
+ def backward_train(self):
191
+ torch.mean(self.loss_total).backward()
192
+
193
+ def compute_accuracy(self, d0, d1, judge):
194
+ ''' d0, d1 are Variables, judge is a Tensor '''
195
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
196
+ judge_per = judge.cpu().numpy().flatten()
197
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
198
+
199
+ def get_current_errors(self):
200
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
201
+ ('acc_r', self.acc_r)])
202
+
203
+ for key in retDict.keys():
204
+ retDict[key] = np.mean(retDict[key])
205
+
206
+ return retDict
207
+
208
+ def get_current_visuals(self):
209
+ zoom_factor = 256 / self.var_ref.data.size()[2]
210
+
211
+ ref_img = util.tensor2im(self.var_ref.data)
212
+ p0_img = util.tensor2im(self.var_p0.data)
213
+ p1_img = util.tensor2im(self.var_p1.data)
214
+
215
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
216
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
217
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
218
+
219
+ return OrderedDict([('ref', ref_img_vis),
220
+ ('p0', p0_img_vis),
221
+ ('p1', p1_img_vis)])
222
+
223
+ def save(self, path, label):
224
+ if(self.use_gpu):
225
+ self.save_network(self.net.module, path, '', label)
226
+ else:
227
+ self.save_network(self.net, path, '', label)
228
+ self.save_network(self.rankLoss.net, path, 'rank', label)
229
+
230
+ def update_learning_rate(self, nepoch_decay):
231
+ lrd = self.lr / nepoch_decay
232
+ lr = self.old_lr - lrd
233
+
234
+ for param_group in self.optimizer_net.param_groups:
235
+ param_group['lr'] = lr
236
+
237
+ print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
238
+ self.old_lr = lr
239
+
240
+
241
+ def score_2afc_dataset(data_loader, func, name=''):
242
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
243
+ distance function 'func' in dataset 'data_loader'
244
+ INPUTS
245
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
246
+ func - callable distance function - calling d=func(in0,in1) should take 2
247
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
248
+ OUTPUTS
249
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
250
+ [1] - dictionary with following elements
251
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
252
+ gts - N array in [0,1], preferred patch selected by human evaluators
253
+ (closer to "0" for left patch p0, "1" for right patch p1,
254
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
255
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
256
+ CONSTS
257
+ N - number of test triplets in data_loader
258
+ '''
259
+
260
+ d0s = []
261
+ d1s = []
262
+ gts = []
263
+
264
+ for data in tqdm(data_loader.load_data(), desc=name):
265
+ d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
266
+ d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
267
+ gts += data['judge'].cpu().numpy().flatten().tolist()
268
+
269
+ d0s = np.array(d0s)
270
+ d1s = np.array(d1s)
271
+ gts = np.array(gts)
272
+ scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
273
+
274
+ return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
275
+
276
+
277
+ def score_jnd_dataset(data_loader, func, name=''):
278
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
279
+ INPUTS
280
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
281
+ func - callable distance function - calling d=func(in0,in1) should take 2
282
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
283
+ OUTPUTS
284
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
285
+ [1] - dictionary with following elements
286
+ ds - N array containing distances between two patches shown to human evaluator
287
+ sames - N array containing fraction of people who thought the two patches were identical
288
+ CONSTS
289
+ N - number of test triplets in data_loader
290
+ '''
291
+
292
+ ds = []
293
+ gts = []
294
+
295
+ for data in tqdm(data_loader.load_data(), desc=name):
296
+ ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
297
+ gts += data['same'].cpu().numpy().flatten().tolist()
298
+
299
+ sames = np.array(gts)
300
+ ds = np.array(ds)
301
+
302
+ sorted_inds = np.argsort(ds)
303
+ ds_sorted = ds[sorted_inds]
304
+ sames_sorted = sames[sorted_inds]
305
+
306
+ TPs = np.cumsum(sames_sorted)
307
+ FPs = np.cumsum(1 - sames_sorted)
308
+ FNs = np.sum(sames_sorted) - TPs
309
+
310
+ precs = TPs / (TPs + FPs)
311
+ recs = TPs / (TPs + FNs)
312
+ score = util.voc_ap(recs, precs)
313
+
314
+ return(score, dict(ds=ds, sames=sames))
stylegan2/lpips/networks_basic.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from . import pretrained_networks as pn
14
+
15
+ from . import util
16
+
17
+
18
+ def spatial_average(in_tens, keepdim=True):
19
+ return in_tens.mean([2,3],keepdim=keepdim)
20
+
21
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
22
+ in_H = in_tens.shape[2]
23
+ scale_factor = 1.*out_H/in_H
24
+
25
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
26
+
27
+ # Learned perceptual metric
28
+ class PNetLin(nn.Module):
29
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
30
+ super(PNetLin, self).__init__()
31
+
32
+ self.pnet_type = pnet_type
33
+ self.pnet_tune = pnet_tune
34
+ self.pnet_rand = pnet_rand
35
+ self.spatial = spatial
36
+ self.lpips = lpips
37
+ self.version = version
38
+ self.scaling_layer = ScalingLayer()
39
+
40
+ if(self.pnet_type in ['vgg','vgg16']):
41
+ net_type = pn.vgg16
42
+ self.chns = [64,128,256,512,512]
43
+ elif(self.pnet_type=='alex'):
44
+ net_type = pn.alexnet
45
+ self.chns = [64,192,384,256,256]
46
+ elif(self.pnet_type=='squeeze'):
47
+ net_type = pn.squeezenet
48
+ self.chns = [64,128,256,384,384,512,512]
49
+ self.L = len(self.chns)
50
+
51
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
52
+
53
+ if(lpips):
54
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
55
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
56
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
57
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
58
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
59
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
60
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
61
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
62
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
63
+ self.lins+=[self.lin5,self.lin6]
64
+
65
+ def forward(self, in0, in1, retPerLayer=False):
66
+ # v0.0 - original release had a bug, where input was not scaled
67
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
68
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
69
+ feats0, feats1, diffs = {}, {}, {}
70
+
71
+ for kk in range(self.L):
72
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
73
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
74
+
75
+ if(self.lpips):
76
+ if(self.spatial):
77
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
78
+ else:
79
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
80
+ else:
81
+ if(self.spatial):
82
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
83
+ else:
84
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
85
+
86
+ val = res[0]
87
+ for l in range(1,self.L):
88
+ val += res[l]
89
+
90
+ if(retPerLayer):
91
+ return (val, res)
92
+ else:
93
+ return val
94
+
95
+ class ScalingLayer(nn.Module):
96
+ def __init__(self):
97
+ super(ScalingLayer, self).__init__()
98
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
99
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
100
+
101
+ def forward(self, inp):
102
+ return (inp - self.shift) / self.scale
103
+
104
+
105
+ class NetLinLayer(nn.Module):
106
+ ''' A single linear layer which does a 1x1 conv '''
107
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
108
+ super(NetLinLayer, self).__init__()
109
+
110
+ layers = [nn.Dropout(),] if(use_dropout) else []
111
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
112
+ self.model = nn.Sequential(*layers)
113
+
114
+
115
+ class Dist2LogitLayer(nn.Module):
116
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
117
+ def __init__(self, chn_mid=32, use_sigmoid=True):
118
+ super(Dist2LogitLayer, self).__init__()
119
+
120
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
121
+ layers += [nn.LeakyReLU(0.2,True),]
122
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
123
+ layers += [nn.LeakyReLU(0.2,True),]
124
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
125
+ if(use_sigmoid):
126
+ layers += [nn.Sigmoid(),]
127
+ self.model = nn.Sequential(*layers)
128
+
129
+ def forward(self,d0,d1,eps=0.1):
130
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
131
+
132
+ class BCERankingLoss(nn.Module):
133
+ def __init__(self, chn_mid=32):
134
+ super(BCERankingLoss, self).__init__()
135
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
136
+ # self.parameters = list(self.net.parameters())
137
+ self.loss = torch.nn.BCELoss()
138
+
139
+ def forward(self, d0, d1, judge):
140
+ per = (judge+1.)/2.
141
+ self.logit = self.net.forward(d0,d1)
142
+ return self.loss(self.logit, per)
143
+
144
+ # L2, DSSIM metrics
145
+ class FakeNet(nn.Module):
146
+ def __init__(self, use_gpu=True, colorspace='Lab'):
147
+ super(FakeNet, self).__init__()
148
+ self.use_gpu = use_gpu
149
+ self.colorspace=colorspace
150
+
151
+ class L2(FakeNet):
152
+
153
+ def forward(self, in0, in1, retPerLayer=None):
154
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
155
+
156
+ if(self.colorspace=='RGB'):
157
+ (N,C,X,Y) = in0.size()
158
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
159
+ return value
160
+ elif(self.colorspace=='Lab'):
161
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
162
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
163
+ ret_var = Variable( torch.Tensor((value,) ) )
164
+ if(self.use_gpu):
165
+ ret_var = ret_var.cuda()
166
+ return ret_var
167
+
168
+ class DSSIM(FakeNet):
169
+
170
+ def forward(self, in0, in1, retPerLayer=None):
171
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
172
+
173
+ if(self.colorspace=='RGB'):
174
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
175
+ elif(self.colorspace=='Lab'):
176
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
177
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
178
+ ret_var = Variable( torch.Tensor((value,) ) )
179
+ if(self.use_gpu):
180
+ ret_var = ret_var.cuda()
181
+ return ret_var
182
+
183
+ def print_network(net):
184
+ num_params = 0
185
+ for param in net.parameters():
186
+ num_params += param.numel()
187
+ print('Network',net)
188
+ print('Total number of parameters: %d' % num_params)
stylegan2/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
stylegan2/lpips/util.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ from skimage.metrics import structural_similarity
8
+ import torch
9
+
10
+
11
+ from . import dist_model
12
+
13
+ class PerceptualLoss(torch.nn.Module):
14
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
+ super(PerceptualLoss, self).__init__()
17
+ print('Setting up Perceptual loss...')
18
+ self.use_gpu = use_gpu
19
+ self.spatial = spatial
20
+ self.gpu_ids = gpu_ids
21
+ self.model = dist_model.DistModel()
22
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
+ print('...[%s] initialized'%self.model.name())
24
+ print('...Done')
25
+
26
+ def forward(self, pred, target, normalize=False):
27
+ """
28
+ Pred and target are Variables.
29
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
+ If normalize is False, assumes the images are already between [-1,+1]
31
+
32
+ Inputs pred and target are Nx3xHxW
33
+ Output pytorch Variable N long
34
+ """
35
+
36
+ if normalize:
37
+ target = 2 * target - 1
38
+ pred = 2 * pred - 1
39
+
40
+ return self.model.forward(target, pred)
41
+
42
+ def normalize_tensor(in_feat,eps=1e-10):
43
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
+ return in_feat/(norm_factor+eps)
45
+
46
+ def l2(p0, p1, range=255.):
47
+ return .5*np.mean((p0 / range - p1 / range)**2)
48
+
49
+ def psnr(p0, p1, peak=255.):
50
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
+
52
+ def dssim(p0, p1, range=255.):
53
+ return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
54
+
55
+ def rgb2lab(in_img,mean_cent=False):
56
+ from skimage import color
57
+ img_lab = color.rgb2lab(in_img)
58
+ if(mean_cent):
59
+ img_lab[:,:,0] = img_lab[:,:,0]-50
60
+ return img_lab
61
+
62
+ def tensor2np(tensor_obj):
63
+ # change dimension of a tensor object into a numpy array
64
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
+
66
+ def np2tensor(np_obj):
67
+ # change dimenion of np array into tensor array
68
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
+
70
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
+ # image tensor to lab tensor
72
+ from skimage import color
73
+
74
+ img = tensor2im(image_tensor)
75
+ img_lab = color.rgb2lab(img)
76
+ if(mc_only):
77
+ img_lab[:,:,0] = img_lab[:,:,0]-50
78
+ if(to_norm and not mc_only):
79
+ img_lab[:,:,0] = img_lab[:,:,0]-50
80
+ img_lab = img_lab/100.
81
+
82
+ return np2tensor(img_lab)
83
+
84
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
+ from skimage import color
86
+ import warnings
87
+ warnings.filterwarnings("ignore")
88
+
89
+ lab = tensor2np(lab_tensor)*100.
90
+ lab[:,:,0] = lab[:,:,0]+50
91
+
92
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
+ if(return_inbnd):
94
+ # convert back to lab, see if we match
95
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
+ return (im2tensor(rgb_back),mask)
99
+ else:
100
+ return im2tensor(rgb_back)
101
+
102
+ def rgb2lab(input):
103
+ from skimage import color
104
+ return color.rgb2lab(input / 255.)
105
+
106
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
+ image_numpy = image_tensor[0].cpu().float().numpy()
108
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
+ return image_numpy.astype(imtype)
110
+
111
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
+ return torch.Tensor((image / factor - cent)
113
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
+
115
+ def tensor2vec(vector_tensor):
116
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
+
118
+ def voc_ap(rec, prec, use_07_metric=False):
119
+ """ ap = voc_ap(rec, prec, [use_07_metric])
120
+ Compute VOC AP given precision and recall.
121
+ If use_07_metric is true, uses the
122
+ VOC 07 11-point method (default: False).
123
+ """
124
+ if use_07_metric:
125
+ # 11 point metric
126
+ ap = 0.
127
+ for t in np.arange(0., 1.1, 0.1):
128
+ if np.sum(rec >= t) == 0:
129
+ p = 0
130
+ else:
131
+ p = np.max(prec[rec >= t])
132
+ ap = ap + p / 11.
133
+ else:
134
+ # correct AP calculation
135
+ # first append sentinel values at the end
136
+ mrec = np.concatenate(([0.], rec, [1.]))
137
+ mpre = np.concatenate(([0.], prec, [0.]))
138
+
139
+ # compute the precision envelope
140
+ for i in range(mpre.size - 1, 0, -1):
141
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
+
143
+ # to calculate area under PR curve, look for points
144
+ # where X axis (recall) changes value
145
+ i = np.where(mrec[1:] != mrec[:-1])[0]
146
+
147
+ # and sum (\Delta recall) * prec
148
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
+ return ap
150
+
151
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):  # note: duplicate of tensor2im defined above
152
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
+ image_numpy = image_tensor[0].cpu().float().numpy()
154
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
+ return image_numpy.astype(imtype)
156
+
157
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):  # note: duplicate of im2tensor defined above
158
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
+ return torch.Tensor((image / factor - cent)
160
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
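The helpers above belong to an LPIPS-style perceptual-distance module. As a rough usage sketch (hedged: the constructor arguments shown here are assumed from the initialize() call above, and the tensor sizes are illustrative), the PerceptualLoss module is driven like this:

import torch

# Assumed constructor arguments; forward() expects Nx3xHxW tensors.
loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=False)

pred = torch.rand(4, 3, 64, 64)     # images in [0, 1]
target = torch.rand(4, 3, 64, 64)

# normalize=True rescales [0, 1] inputs to [-1, +1] before the distance is computed,
# as documented in the forward() docstring; the result is one distance per image.
dist = loss_fn.forward(pred, target, normalize=True)
print(dist.shape)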
stylegan2/model.py CHANGED
@@ -5,7 +5,25 @@ import torch
5
  from torch import nn
6
  from torch.nn import functional as F
7
 
8
- from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  class PixelNorm(nn.Module):
 
5
  from torch import nn
6
  from torch.nn import functional as F
7
 
8
+ from .op.fused_act import fused
9
+
10
+ if fused is not None:
11
+ from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
12
+ else:
13
+ from .op import FusedLeakyReLU_Native as FusedLeakyReLU
14
+ from .op import fused_leaky_relu_native as fused_leaky_relu
15
+
16
+ from .op.upfirdn2d import upfirdn2d_op
17
+
18
+ if upfirdn2d_op is not None:
19
+ from .op.upfirdn2d import upfirdn2d
20
+ else:
21
+ from .op import upfirdn2d_native as upfirdn2d
22
+
23
+ from .op import conv2d_gradfix
24
+
25
+ # https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py#L152
26
+ # https://github.com/rosinality/stylegan2-pytorch/issues/70
27
 
28
 
29
  class PixelNorm(nn.Module):
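The new import block falls back to the pure-PyTorch implementations whenever the CUDA extensions fail to build (see the linked upstream issue). A minimal sketch of how to confirm at runtime which backend was picked, using only the sentinels that the try/except loaders set to None on failure (assumes the package is importable as stylegan2):

from stylegan2.op.fused_act import fused
from stylegan2.op.upfirdn2d import upfirdn2d_op

# 'fused' and 'upfirdn2d_op' are None when torch.utils.cpp_extension.load() failed,
# in which case model.py aliases the *_native implementations instead.
print("fused_leaky_relu backend:", "CUDA extension" if fused is not None else "native PyTorch")
print("upfirdn2d backend:", "CUDA extension" if upfirdn2d_op is not None else "native PyTorch")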
stylegan2/op/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
- from .upfirdn2d import upfirdn2d
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu, fused_leaky_relu_native, FusedLeakyReLU_Native
2
+ from .upfirdn2d import upfirdn2d, upfirdn2d_native
stylegan2/op/conv2d_gradfix.py CHANGED
@@ -76,6 +76,8 @@ def conv_transpose2d(
76
 
77
 
78
  def could_use_op(input):
 
 
79
  if (not enabled) or (not torch.backends.cudnn.enabled):
80
  return False
81
 
 
76
 
77
 
78
  def could_use_op(input):
79
+ return False  # always fall back to the standard conv ops; skips the cuDNN-specific path below
80
+
81
  if (not enabled) or (not torch.backends.cudnn.enabled):
82
  return False
83
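The early return False in could_use_op() disables the custom gradfix path entirely, so (in the upstream rosinality implementation this file is based on) conv2d and conv_transpose2d fall through to the standard torch.nn.functional ops. A hedged sanity check, assuming the wrapper keeps F.conv2d's signature and the package is importable as stylegan2:

import torch
from torch.nn import functional as F
from stylegan2.op import conv2d_gradfix

x = torch.randn(1, 3, 16, 16)
w = torch.randn(8, 3, 3, 3)

# With could_use_op() always False, the wrapper should match F.conv2d exactly.
print(torch.allclose(conv2d_gradfix.conv2d(x, w, padding=1), F.conv2d(x, w, padding=1)))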
 
stylegan2/op/fused_act.py CHANGED
@@ -6,15 +6,24 @@ from torch.nn import functional as F
6
  from torch.autograd import Function
7
  from torch.utils.cpp_extension import load
8
 
 
9
 
10
- module_path = os.path.dirname(__file__)
11
- fused = load(
12
- "fused",
13
- sources=[
14
- os.path.join(module_path, "fused_bias_act.cpp"),
15
- os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
- ],
17
- )
 
 
 
 
 
 
 
 
18
 
19
 
20
  class FusedLeakyReLUFunctionBackward(Function):
@@ -125,3 +134,24 @@ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
125
  return FusedLeakyReLUFunction.apply(
126
  input.contiguous(), bias, negative_slope, scale
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from torch.autograd import Function
7
  from torch.utils.cpp_extension import load
8
 
9
+ import warnings
10
 
11
+ module_path = os.path.dirname(os.path.abspath(__file__))
12
+
13
+ try:
14
+ fused = load(
15
+ "fused",
16
+ sources=[
17
+ os.path.join(module_path, "fused_bias_act.cpp"),
18
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
19
+ ],
20
+ )
21
+ except Exception:
22
+ warnings.warn(
23
+ f"(This is not error) Switch to native implementation"
24
+ )
25
+
26
+ fused = None
27
 
28
 
29
  class FusedLeakyReLUFunctionBackward(Function):
 
134
  return FusedLeakyReLUFunction.apply(
135
  input.contiguous(), bias, negative_slope, scale
136
  )
137
+
138
+
139
+ class FusedLeakyReLU_Native(nn.Module):
140
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
141
+ super().__init__()
142
+
143
+ if bias:
144
+ self.bias = nn.Parameter(torch.zeros(channel))
145
+
146
+ else:
147
+ self.bias = None
148
+
149
+ self.negative_slope = negative_slope
150
+ self.scale = scale
151
+
152
+ def forward(self, input):
153
+ return fused_leaky_relu_native(input, self.bias, self.negative_slope, self.scale)
154
+
155
+
156
+ def fused_leaky_relu_native(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
157
+ if bias is None:  # matches fused_leaky_relu's signature; FusedLeakyReLU_Native(bias=False) stores None
+ return scale * F.leaky_relu(input, negative_slope=negative_slope)
+ return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope)
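FusedLeakyReLU_Native and fused_leaky_relu_native reproduce the fused CUDA kernel with plain PyTorch ops: add a per-channel bias, apply a leaky ReLU, and rescale by sqrt(2). A small sketch exercising the fallback on CPU (illustrative shapes; no compiled extension required):

import torch

act = FusedLeakyReLU_Native(channel=8)          # learnable per-channel bias, initialised to zero
x = torch.randn(2, 8, 16, 16)

# The module and the functional form agree when given the same (zero) bias.
y_mod = act(x)
y_fn = fused_leaky_relu_native(x, torch.zeros(8), negative_slope=0.2, scale=2 ** 0.5)
print(torch.allclose(y_mod, y_fn))              # expected: True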
stylegan2/op/upfirdn2d.py CHANGED
@@ -5,16 +5,24 @@ import torch
5
  from torch.nn import functional as F
6
  from torch.autograd import Function
7
  from torch.utils.cpp_extension import load
 
8
 
 
9
 
10
- module_path = os.path.dirname(__file__)
11
- upfirdn2d_op = load(
12
- "upfirdn2d",
13
- sources=[
14
- os.path.join(module_path, "upfirdn2d.cpp"),
15
- os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
- ],
17
- )
 
 
 
 
 
 
18
 
19
 
20
  class UpFirDn2dBackward(Function):
@@ -157,7 +165,7 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
157
  pad = (pad[0], pad[1], pad[0], pad[1])
158
 
159
  if input.device.type == "cpu":
160
- out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
 
162
  else:
163
  out = UpFirDn2d.apply(input, kernel, up, down, pad)
@@ -165,7 +173,22 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
165
  return out
166
 
167
 
168
- def upfirdn2d_native(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
  ):
171
  _, channel, in_h, in_w = input.shape
@@ -183,8 +206,8 @@ def upfirdn2d_native(
183
  )
184
  out = out[
185
  :,
186
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
  :,
189
  ]
190
 
 
5
  from torch.nn import functional as F
6
  from torch.autograd import Function
7
  from torch.utils.cpp_extension import load
8
+ import warnings
9
 
10
+ module_path = os.path.dirname(os.path.abspath(__file__))
11
 
12
+ try:
13
+ upfirdn2d_op = load(
14
+ "upfirdn2d",
15
+ sources=[
16
+ os.path.join(module_path, "upfirdn2d.cpp"),
17
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
18
+ ],
19
+ )
20
+ except Exception:
21
+ warnings.warn(
22
+ f"(This is not error) Switch to native implementation"
23
+ )
24
+
25
+ upfirdn2d_op = None
26
 
27
 
28
  class UpFirDn2dBackward(Function):
 
165
  pad = (pad[0], pad[1], pad[0], pad[1])
166
 
167
  if input.device.type == "cpu":
168
+ out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
169
 
170
  else:
171
  out = UpFirDn2d.apply(input, kernel, up, down, pad)
 
173
  return out
174
 
175
 
176
+ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
177
+ if not isinstance(up, abc.Iterable):
178
+ up = (up, up)
179
+
180
+ if not isinstance(down, abc.Iterable):
181
+ down = (down, down)
182
+
183
+ if len(pad) == 2:
184
+ pad = (pad[0], pad[1], pad[0], pad[1])
185
+
186
+ out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
187
+
188
+ return out
189
+
190
+
191
+ def _upfirdn2d_native(
192
  input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
193
  ):
194
  _, channel, in_h, in_w = input.shape
 
206
  )
207
  out = out[
208
  :,
209
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
210
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
211
  :,
212
  ]
213
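The new upfirdn2d_native wrapper mirrors the public upfirdn2d signature (scalar or pair up/down factors, 2- or 4-element pad) and then dispatches to the renamed low-level _upfirdn2d_native routine. A hedged usage sketch, 2x upsampling with a normalized 2x2 box filter (shapes are illustrative):

import torch

x = torch.randn(1, 3, 8, 8)
kernel = torch.ones(2, 2) / 4.0                 # normalized box filter

# pad=(0, 1) keeps the output exactly twice the input resolution for a 2x2 kernel.
out = upfirdn2d_native(x, kernel, up=2, down=1, pad=(0, 1))
print(out.shape)                                # expected: torch.Size([1, 3, 16, 16])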