Spaces · Runtime error

ZeqiangLai committed · Commit 5ae5c5c
1 Parent(s): 5febeab

update
Browse files
- .gitignore +160 -0
- drag_gan.py +12 -10
- gradio_app.py +90 -17
- requirements.txt +2 -1
- stylegan2/{_init__.py → __init__.py} +0 -0
- stylegan2/inversion.py +209 -0
- stylegan2/lpips/__init__.py +5 -0
- stylegan2/lpips/base_model.py +58 -0
- stylegan2/lpips/dist_model.py +314 -0
- stylegan2/lpips/networks_basic.py +188 -0
- stylegan2/lpips/pretrained_networks.py +181 -0
- stylegan2/lpips/util.py +160 -0
- stylegan2/model.py +19 -1
- stylegan2/op/__init__.py +2 -2
- stylegan2/op/conv2d_gradfix.py +2 -0
- stylegan2/op/fused_act.py +38 -8
- stylegan2/op/upfirdn2d.py +35 -12
.gitignore ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
drag_gan.py CHANGED

@@ -136,20 +136,21 @@ def bilinear_interpolate_torch(im, y, x):
     y : 1,numPoints -- pixel location y float
     x : 1,numPOints -- pixel location y float
     """
+    device = im.device
+
+    x0 = torch.floor(x).long().to(device)
     x1 = x0 + 1

-    y0 = torch.floor(y).long()
+    y0 = torch.floor(y).long().to(device)
     y1 = y0 + 1

-    wa = (x1.float() - x) * (y1.float() - y)
-    wb = (x1.float() - x) * (y - y0.float())
-    wc = (x - x0.float()) * (y1.float() - y)
-    wd = (x - x0.float()) * (y - y0.float())
+    wa = ((x1.float() - x) * (y1.float() - y)).to(device)
+    wb = ((x1.float() - x) * (y - y0.float())).to(device)
+    wc = ((x - x0.float()) * (y1.float() - y)).to(device)
+    wd = ((x - x0.float()) * (y - y0.float())).to(device)
     # Instead of clamp
-    x1 = x1 - torch.floor(x1 / im.shape[3]).int()
-    y1 = y1 - torch.floor(y1 / im.shape[2]).int()
+    x1 = x1 - torch.floor(x1 / im.shape[3]).int().to(device)
+    y1 = y1 - torch.floor(y1 / im.shape[2]).int().to(device)
     Ia = im[:, :, y0, x0]
     Ib = im[:, :, y1, x0]
     Ic = im[:, :, y0, x1]

@@ -194,7 +195,8 @@ def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points
     f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
     loss += FF.l1_loss(f2, f1)

+    if mask is not None:
+        loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam

     loss.backward()
     optimizer.step()
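Editor's note: the new term regularizes feature drift outside the user-drawn mask, so unmasked regions stay close to the original image. Below is a minimal, standalone sketch of how that term behaves on dummy tensors; the shapes and the `lam` weight are illustrative assumptions (the actual value of `lam` is not shown in this diff).

import torch

# Dummy feature maps standing in for F0 (original features) and F2 (current features).
F0 = torch.randn(1, 4, 8, 8)
F2 = (F0 + 0.1 * torch.randn_like(F0)).requires_grad_()

# Binary mask: 1 where dragging is allowed, 0 where the image should stay fixed.
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4, :4] = 1.0

lam = 10.0  # assumed regularization weight for illustration only
# Penalize feature changes only in the region the user did NOT mark as editable.
loss = ((F2 - F0) * (1 - mask)).abs().mean() * lam
loss.backward()
print(loss.item())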
gradio_app.py CHANGED

@@ -7,12 +7,14 @@ from PIL import Image
 import uuid

 from drag_gan import drag_gan, stylegan2
+from stylegan2.inversion import inverse_image

+device = 'cpu'


 SIZE_TO_CLICK_SIZE = {
+    1024: 8,
+    512: 5,
     256: 2
 }

@@ -21,8 +23,32 @@ CKPT_SIZE = {
     'stylegan2-cat-config-f.pt': 256,
     'stylegan2-church-config-f.pt': 256,
     'stylegan2-horse-config-f.pt': 256,
+    'ada/ffhq.pt': 1024,
+    'ada/afhqcat.pt': 512,
+    'ada/afhqdog.pt': 512,
+    'ada/afhqwild.pt': 512,
+    'ada/brecahad.pt': 512,
+    'ada/metfaces.pt': 512,
 }

+DEFAULT_CKPT = 'stylegan2-ffhq-config-f.pt'
+
+
+class grImage(gr.components.Image):
+    is_template = True
+
+    def preprocess(self, x):
+        if x is None:
+            return x
+        if self.tool == "sketch" and self.source in ["upload", "webcam"]:
+            decode_image = gr.processing_utils.decode_base64_to_image(x)
+            width, height = decode_image.size
+            mask = np.zeros((height, width, 4), dtype=np.uint8)
+            mask[..., -1] = 255
+            mask = self.postprocess(mask)
+            x = {'image': x, 'mask': mask}
+        return super().preprocess(x)
+

 class ImageMask(gr.components.Image):
     """

@@ -94,11 +120,14 @@ def on_drag(model, points, max_iters, state, size, mask):
     handle_points = [torch.tensor(p).float() for p in points['handle']]
     target_points = [torch.tensor(p).float() for p in points['target']]

+    if mask.get('mask') is not None:
+        mask = Image.fromarray(mask['mask']).convert('L')
+        mask = np.array(mask) == 255

+        mask = torch.from_numpy(mask).float().to(device)
+        mask = mask.unsqueeze(0).unsqueeze(0)
+    else:
+        mask = None

     step = 0
     for sample2, latent, F, handle_points in drag_gan(model.g_ema, latent, noise, F,

@@ -149,7 +178,7 @@ def on_change_model(selected, model):
         'sample': sample,
         'history': []
     }
-    return model, state, to_image(sample), size
+    return model, state, to_image(sample), to_image(sample), size


 def on_new_image(model):

@@ -187,11 +216,29 @@ def on_show_save():
     return gr.update(visible=True)


+def on_image_change(model, image_size, image):
+    image = Image.fromarray(image)
+    result = inverse_image(
+        model.g_ema,
+        image,
+        image_size=image_size
+    )
+    result['history'] = []
+    image = to_image(result['sample'])
+    points = {'target': [], 'handle': []}
+    target_point = False
+    return image, image, result, points, target_point
+
+
+def on_mask_change(mask):
+    return mask['image']
+
+
 def main():
     torch.cuda.manual_seed(25)

     with gr.Blocks() as demo:
-        wrapped_model = ModelWrapper()
+        wrapped_model = ModelWrapper(ckpt=DEFAULT_CKPT, size=CKPT_SIZE[DEFAULT_CKPT])
         model = gr.State(wrapped_model)
         sample_z = torch.randn([1, 512], device=device)
         latent, noise = wrapped_model.g_ema.prepare([sample_z])

@@ -199,11 +246,11 @@ def main():

         gr.Markdown(
             """
+            # DragGAN

            Unofficial implementation of [Drag Your GAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://vcai.mpi-inf.mpg.de/projects/DragGAN/)

+            [Our Implementation](https://github.com/Zeqiang-Lai/DragGAN) | [Official Implementation](https://github.com/XingangPan/DragGAN) (Not released yet)

            ## Tutorial

@@ -211,6 +258,22 @@ def main():
            2. Setup a least one pair of handle point and target point.
            3. Click "Drag it".

+            ## Hints
+
+            - Handle points (Blue): the point you want to drag.
+            - Target points (Red): the destination you want to drag towards to.
+
+            ## Primary Support of Custom Image.
+
+            - We now support dragging user uploaded image by GAN inversion.
+            - **Please upload your image at `Setup Handle Points` pannel.** Upload it from `Draw a Mask` would cause errors for now.
+            - Due to the limitation of GAN inversion,
+                - You might wait roughly 1 minute to see the GAN version of the uploaded image.
+                - The shown image might be slightly difference from the uploaded one.
+                - It could also fail to invert the uploaded image and generate very poor results.
+            - Idealy, you should choose the closest model of the uploaded image. For example, choose `stylegan2-ffhq-config-f.pt` for human face. `stylegan2-cat-config-f.pt` for cat.
+
+            > Please fire an issue if you have encounted any problem. Also don't forgot to give a star to the [Official Repo](https://github.com/XingangPan/DragGAN), [our project](https://github.com/Zeqiang-Lai/DragGAN) could not exist without it.
             """,
         )
         state = gr.State({

@@ -221,12 +284,12 @@ def main():
             'history': []
         })
         points = gr.State({'target': [], 'handle': []})
+        size = gr.State(CKPT_SIZE[DEFAULT_CKPT])

         with gr.Row():
             with gr.Column(scale=0.3):
                 with gr.Accordion("Model"):
+                    model_dropdown = gr.Dropdown(choices=list(CKPT_SIZE.keys()), value=DEFAULT_CKPT,
                                                  label='StyleGAN2 model')
                     max_iters = gr.Slider(1, 500, 20, step=1, label='Max Iterations')
                     new_btn = gr.Button('New Image')

@@ -252,24 +315,34 @@ def main():
             with gr.Column():
                 with gr.Tabs():
                     with gr.Tab('Draw a Mask', id='mask'):
+                        mask = ImageMask(value=to_image(sample), label='Mask').style(height=768, width=768)
                     with gr.Tab('Setup Handle Points', id='input'):
+                        image = grImage(to_image(sample)).style(height=768, width=768)

         image.select(on_click, [image, target_point, points, size], [image, text, target_point])
+        image.upload(on_image_change, [model, size, image], [image, mask, state, points, target_point])
+        mask.upload(on_mask_change, [mask], [image])
         btn.click(on_drag, inputs=[model, points, max_iters, state, size, mask], outputs=[image, state, progress]).then(
             on_show_save, outputs=save_panel).then(
             on_save_files, inputs=[image, state], outputs=[files]
         )
         reset_btn.click(on_reset, inputs=[points, image, state], outputs=[points, image])
         undo_btn.click(on_undo, inputs=[points, image, state, size], outputs=[points, image])
-        model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, size])
+        model_dropdown.change(on_change_model, inputs=[model_dropdown, model], outputs=[model, state, image, mask, size])
         new_btn.click(on_new_image, inputs=[model], outputs=[image, mask, state, points, target_point])
         max_iters.change(on_max_iter_change, inputs=max_iters, outputs=progress)
     return demo


 if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('-p', '--port', default=None)
+    parser.add_argument('--ip', default=None)
+    args = parser.parse_args()
+    device = args.device
     demo = main()
+    print('Successfully loaded, starting gradio demo')
+    demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)
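Editor's note: the new mask handling in `on_drag` converts the sketch layer returned by the Gradio image component into a binary float tensor that can be broadcast over feature maps. A standalone re-creation of that conversion on a synthetic mask array is sketched below; the 256x256 size, the painted region, and the device are illustrative assumptions.

import numpy as np
import torch
from PIL import Image

device = 'cpu'  # assumed; the app now reads this from --device

# Synthetic stand-in for mask['mask'] as returned by the sketch tool (H, W, 4 RGBA array).
raw = np.zeros((256, 256, 4), dtype=np.uint8)
raw[64:192, 64:192] = 255  # pretend the user painted this square

m = Image.fromarray(raw).convert('L')        # collapse to grayscale, as in on_drag
m = np.array(m) == 255                       # boolean: True where painted
m = torch.from_numpy(m).float().to(device)   # float tensor on the chosen device
m = m.unsqueeze(0).unsqueeze(0)              # shape (1, 1, H, W) for broadcasting
print(m.shape, m.mean().item())

With the new argparse block, the demo can also be launched with explicit options, e.g. `python gradio_app.py --device cuda --share -p 7860` (the port value here is only an example; `--port` defaults to None).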
requirements.txt CHANGED

@@ -5,4 +5,5 @@ numpy
 ninja
 fire
 imageio
-torchvision
+torchvision
+IPython
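Editor's note: the added IPython requirement comes from the vendored stylegan2/lpips modules, which do `from IPython import embed` at import time; torchvision appears on both the removed and added side most likely only because the previous last line lacked a trailing newline.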
stylegan2/{_init__.py → __init__.py}
RENAMED
File without changes
stylegan2/inversion.py ADDED
@@ -0,0 +1,209 @@
import math
import os

import torch
from torch import optim
from torch.nn import functional as FF
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import dataclasses

from .lpips import util


def noise_regularize(noises):
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss


def noise_normalize_(noises):
    for noise in noises:
        mean = noise.mean()
        std = noise.std()

        noise.data.add_(-mean).div_(std)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp


def latent_noise(latent, strength):
    noise = torch.randn_like(latent) * strength

    return latent + noise


def make_image(tensor):
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


@dataclasses.dataclass
class InverseConfig:
    lr_warmup = 0.05
    lr_decay = 0.25
    lr = 0.1
    noise = 0.05
    noise_decay = 0.75
    step = 1000
    noise_regularize = 1e5
    mse = 0
    w_plus = False,


def inverse_image(
    g_ema,
    image,
    image_size=256,
    config=InverseConfig()
):
    device = "cuda"
    args = config

    n_mean_latent = 10000

    resize = min(image_size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []
    img = transform(image)
    imgs.append(img)

    imgs = torch.stack(imgs, 0).to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    noises_single = g_ema.make_noise()
    noises = []
    for noise in noises_single:
        noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
        img_gen, F = g_ema.generate(latent, noise)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = FF.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        if (i + 1) % 100 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
    img_gen, F = g_ema.generate(latent, noise)

    img_ar = make_image(img_gen)

    i = 0

    noise_single = []
    for noise in noises:
        noise_single.append(noise[i: i + 1])

    result = {
        "latent": latent,
        "noise": noise_single,
        'F': F,
        "sample": img_gen,
    }

    pil_img = Image.fromarray(img_ar[i])
    pil_img.save('project.png')

    return result
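Editor's note: a minimal sketch of driving the new inversion path outside the UI, mirroring what `on_image_change` does in gradio_app.py. The `ModelWrapper` import path, checkpoint name, and image path are assumptions for illustration (the wrapper and its `ckpt`/`size` keywords appear in the gradio_app.py diff); note that `inverse_image` hard-codes `device = "cuda"`, so a GPU is required.

from PIL import Image

from gradio_app import ModelWrapper               # assumed import location of the app's wrapper
from stylegan2.inversion import inverse_image, InverseConfig

# Load the generator the same way the demo does (checkpoint/size are placeholders).
wrapped = ModelWrapper(ckpt='stylegan2-ffhq-config-f.pt', size=1024)

img = Image.open('my_face.png').convert('RGB')    # placeholder input image

cfg = InverseConfig()
cfg.step = 300                                    # fewer optimization steps than the default 1000

result = inverse_image(wrapped.g_ema, img, image_size=1024, config=cfg)

# The result carries the optimized latent, per-layer noise, intermediate features F,
# and the re-rendered sample that the app shows as the "GAN version" of the upload.
latent, noise, F, sample = result['latent'], result['noise'], result['F'], result['sample']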
stylegan2/lpips/__init__.py ADDED
@@ -0,0 +1,5 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

stylegan2/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
import os
import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
    def __init__(self):
        pass;

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=[0]):
        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids

    def forward(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s'%save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate():
        pass

    def get_image_paths(self):
        return self.image_paths

    def save_done(self, flag=False):
        np.save(os.path.join(self.save_dir, 'done_flag'),flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
stylegan2/lpips/dist_model.py ADDED
@@ -0,0 +1,314 @@

from __future__ import absolute_import

import sys
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
from .base_model import BaseModel
from scipy.ndimage import zoom
import fractions
import functools
import skimage.transform
from tqdm import tqdm
import urllib

from IPython import embed

from . import networks_basic as networks
from . import util


class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def get_path(base_path):
    BASE_DIR = os.path.join('checkpoints')

    save_path = os.path.join(BASE_DIR, base_path)
    if not os.path.exists(save_path):
        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
        print(f'{base_path} not found')
        print('Try to download from huggingface: ', url)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        download_url(url, save_path)
        print('Downloaded to ', save_path)
    return save_path


def download_url(url, output_path):
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)


class DistModel(BaseModel):
    def name(self):
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
                   use_gpu=True, printNet=False, spatial=False,
                   is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            model_path - if None, will look in weights/[NET_NAME].pth
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - [0] by default, gpus to use
        '''
        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]' % (model, net)

        if(self.model == 'net-lin'):  # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                                        use_dropout=True, spatial=spatial, version=version, lpips=True)
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                model_path = get_path('weights/v%s/%s.pth' % (version, net))

            if(not is_train):
                print('Loading model from: %s' % model_path)
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)

        elif(self.model == 'net'):  # pretrained network
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2', 'l2']):
            self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace)  # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train:  # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else:  # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0])  # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
        OUTPUT
            computed distances between in0 and in1
        '''

        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        for module in self.net.modules():
            if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
                module.weight.data = torch.clamp(module.weight.data, min=0)

    def set_input(self, data):
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
            self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
            self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
            self.input_judge = self.input_judge.to(device=self.gpu_ids[0])

        self.var_ref = Variable(self.input_ref, requires_grad=True)
        self.var_p0 = Variable(self.input_p0, requires_grad=True)
        self.var_p1 = Variable(self.input_p1, requires_grad=True)

    def forward_train(self):  # run forward pass
        # print(self.net.module.scaling_layer.shift)
        # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())

        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)

        self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())

        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                               ('acc_r', self.acc_r)])

        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])

        return retDict

    def get_current_visuals(self):
        zoom_factor = 256 / self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if(self.use_gpu):
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self, nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
        self.old_lr = lr


def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
            (closer to "0" for left patch p0, "1" for right patch p1,
            "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
        d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
        gts += data['judge'].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5

    return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))


def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
    OUTPUTS
        [0] - JND score in [0,1], mAP score (area under precision-recall curve)
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test triplets in data_loader
    '''

    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
        gts += data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs

    precs = TPs / (TPs + FPs)
    recs = TPs / (TPs + FNs)
    score = util.voc_ap(recs, precs)

    return(score, dict(ds=ds, sames=sames))
stylegan2/lpips/networks_basic.py ADDED
@@ -0,0 +1,188 @@

from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

from . import util


def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2,3],keepdim=keepdim)

def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
    in_H = in_tens.shape[2]
    scale_factor = 1.*out_H/in_H

    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)

# Learned perceptual metric
class PNetLin(nn.Module):
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1,self.L):
            val += res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val

class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)


class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
        if(use_sigmoid):
            layers += [nn.Sigmoid(),]
        self.model = nn.Sequential(*layers)

    def forward(self,d0,d1,eps=0.1):
        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))

class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge+1.)/2.
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, per)

# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace=colorspace

class L2(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            (N,C,X,Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
            return value
        elif(self.colorspace=='Lab'):
            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            if(self.use_gpu):
                ret_var = ret_var.cuda()
            return ret_var

class DSSIM(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        if(self.use_gpu):
            ret_var = ret_var.cuda()
        return ret_var

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network',net)
    print('Total number of parameters: %d' % num_params)
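Editor's note: a small sketch of what `PNetLin` computes — backbone features, unit-normalized per channel, squared differences, then a per-layer spatial average. It uses `lpips=False` so no calibration weights need to be fetched from the checkpoint URL; it assumes the repo's requirements (including IPython and scikit-image) are installed, and torchvision will download VGG-16 weights on first use. Inputs are random patches scaled to [-1, 1], as the DistModel docstring asks.

import torch
from stylegan2.lpips import networks_basic as networks

# lpips=False: plain feature-space distance, no learned linear weights required.
net = networks.PNetLin(pnet_type='vgg', lpips=False)
net.eval()

x = torch.rand(1, 3, 64, 64) * 2 - 1   # two random image patches in [-1, 1]
y = torch.rand(1, 3, 64, 64) * 2 - 1

with torch.no_grad():
    d = net(x, y)                      # one distance per image in the batch, shape (1, 1, 1, 1)
print(d.shape, d.item())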
stylegan2/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple
import torch
from torchvision import models as tv
from IPython import embed


class squeezenet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)

        return out


class alexnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(2):
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(10, 12):
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out


class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if(num == 18):
            self.net = tv.resnet18(pretrained=pretrained)
        elif(num == 34):
            self.net = tv.resnet34(pretrained=pretrained)
        elif(num == 50):
            self.net = tv.resnet50(pretrained=pretrained)
        elif(num == 101):
            self.net = tv.resnet101(pretrained=pretrained)
        elif(num == 152):
            self.net = tv.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
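The extractors above wrap torchvision backbones and return their intermediate activations as namedtuples, which is what the LPIPS distance model consumes. A minimal probing sketch (the `stylegan2.lpips.pretrained_networks` import path is assumed from the file location, and pretrained torchvision weights must be available):

# Sketch only: inspect the feature maps the vgg16 extractor exposes.
import torch
from stylegan2.lpips.pretrained_networks import vgg16  # assumed import path

extractor = vgg16(requires_grad=False, pretrained=True).eval()
x = torch.randn(1, 3, 256, 256)              # dummy NCHW image batch
with torch.no_grad():
    feats = extractor(x)                     # namedtuple with 5 intermediate activations
for name, f in zip(feats._fields, feats):
    print(name, tuple(f.shape))              # e.g. relu1_2 (1, 64, 256, 256)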
stylegan2/lpips/util.py
ADDED
@@ -0,0 +1,160 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from skimage.metrics import structural_similarity
import torch


from . import dist_model

class PerceptualLoss(torch.nn.Module):
    def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
    # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
        super(PerceptualLoss, self).__init__()
        print('Setting up Perceptual loss...')
        self.use_gpu = use_gpu
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model = dist_model.DistModel()
        self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
        print('...[%s] initialized' % self.model.name())
        print('...Done')

    def forward(self, pred, target, normalize=False):
        """
        Pred and target are Variables.
        If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
        If normalize is False, assumes the images are already between [-1,+1]

        Inputs pred and target are Nx3xHxW
        Output pytorch Variable N long
        """

        if normalize:
            target = 2 * target - 1
            pred = 2 * pred - 1

        return self.model.forward(target, pred)

def normalize_tensor(in_feat, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)

def l2(p0, p1, range=255.):
    return .5 * np.mean((p0 / range - p1 / range)**2)

def psnr(p0, p1, peak=255.):
    return 10 * np.log10(peak**2 / np.mean((1. * p0 - 1. * p1)**2))

def dssim(p0, p1, range=255.):
    return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.

def rgb2lab(in_img, mean_cent=False):
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if(mean_cent):
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    return img_lab

def tensor2np(tensor_obj):
    # change dimension of a tensor object into a numpy array
    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))

def np2tensor(np_obj):
    # change dimension of np array into tensor array
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    # image tensor to lab tensor
    from skimage import color

    img = tensor2im(image_tensor)
    img_lab = color.rgb2lab(img)
    if(mc_only):
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    if(to_norm and not mc_only):
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
        img_lab = img_lab / 100.

    return np2tensor(img_lab)

def tensorlab2tensor(lab_tensor, return_inbnd=False):
    from skimage import color
    import warnings
    warnings.filterwarnings("ignore")

    lab = tensor2np(lab_tensor) * 100.
    lab[:, :, 0] = lab[:, :, 0] + 50

    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
    if(return_inbnd):
        # convert back to lab, see if we match
        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
        mask = 1. * np.isclose(lab_back, lab, atol=2.)
        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
        return (im2tensor(rgb_back), mask)
    else:
        return im2tensor(rgb_back)

def rgb2lab(input):
    from skimage import color
    return color.rgb2lab(input / 255.)

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

def tensor2vec(vector_tensor):
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]

def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)

def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
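Most of these helpers convert between NCHW tensors in [-1, 1] and HWC uint8 images; note that the second definitions of `rgb2lab`, `tensor2im`, and `im2tensor` near the end of the file shadow the earlier ones. A small round-trip sketch (assuming the `stylegan2.lpips.util` import path and that the module's `dist_model` import succeeds):

# Sketch only: convert a [-1, 1] tensor to a uint8 image and back.
import torch
from stylegan2.lpips import util  # assumed import path

t = torch.rand(1, 3, 64, 64) * 2 - 1     # fake generator output in [-1, 1], NCHW
img = util.tensor2im(t)                  # HWC uint8 image in [0, 255]
t2 = util.im2tensor(img)                 # back to a 1x3xHxW float tensor in [-1, 1]
print(img.dtype, img.shape)              # uint8 (64, 64, 3)
print(float((t - t2).abs().max()))       # only uint8 quantization error remains (< ~0.01)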
stylegan2/model.py
CHANGED
@@ -5,7 +5,25 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from .op import
+from .op.fused_act import fused
+
+if fused is not None:
+    from .op.fused_act import FusedLeakyReLU, fused_leaky_relu
+else:
+    from .op import FusedLeakyReLU_Native as FusedLeakyReLU
+    from .op import fused_leaky_relu_native as fused_leaky_relu
+
+from .op.upfirdn2d import upfirdn2d_op
+
+if upfirdn2d_op is not None:
+    from .op.upfirdn2d import upfirdn2d
+else:
+    from .op import upfirdn2d_native as upfirdn2d
+
+from .op import conv2d_gradfix
+
+# https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py#L152
+# https://github.com/rosinality/stylegan2-pytorch/issues/70
 
 
 class PixelNorm(nn.Module):
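The change makes `model.py` prefer the compiled CUDA ops and fall back to the pure-PyTorch variants when the extensions cannot be built (for example on a CPU-only Space without a CUDA toolchain). A quick way to check which backend was selected at import time, a sketch using the module paths from the diff:

# Sketch only: report which implementation each op resolved to.
from stylegan2.op.fused_act import fused
from stylegan2.op.upfirdn2d import upfirdn2d_op

print("fused_act:", "CUDA extension" if fused is not None else "native PyTorch fallback")
print("upfirdn2d:", "CUDA extension" if upfirdn2d_op is not None else "native PyTorch fallback")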
stylegan2/op/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .fused_act import FusedLeakyReLU, fused_leaky_relu
-from .upfirdn2d import upfirdn2d
+from .fused_act import FusedLeakyReLU, fused_leaky_relu, fused_leaky_relu_native, FusedLeakyReLU_Native
+from .upfirdn2d import upfirdn2d, upfirdn2d_native
stylegan2/op/conv2d_gradfix.py
CHANGED
@@ -76,6 +76,8 @@ def conv_transpose2d(
 
 
 def could_use_op(input):
+    return False
+
     if (not enabled) or (not torch.backends.cudnn.enabled):
         return False
 
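The added `return False` at the top of `could_use_op` unconditionally disables the custom gradient-fix path, so every convolution in this module routes through the standard `torch.nn.functional` ops. A small equivalence sketch, assuming `conv2d_gradfix.conv2d` keeps its usual `F.conv2d`-style signature:

# Sketch only: with could_use_op() forced off, the wrapper should match F.conv2d.
import torch
import torch.nn.functional as F
from stylegan2.op import conv2d_gradfix

x = torch.randn(1, 3, 16, 16)
w = torch.randn(8, 3, 3, 3)
ref = F.conv2d(x, w, padding=1)
out = conv2d_gradfix.conv2d(x, w, padding=1)
print(torch.allclose(ref, out))   # expected True: the fallback path calls F.conv2d directly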
stylegan2/op/fused_act.py
CHANGED
@@ -6,15 +6,24 @@ from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
+import warnings
 
-module_path = os.path.dirname(__file__)
-
-
-
-
-
-
-)
+module_path = os.path.dirname(os.path.abspath(__file__))
+
+try:
+    fused = load(
+        "fused",
+        sources=[
+            os.path.join(module_path, "fused_bias_act.cpp"),
+            os.path.join(module_path, "fused_bias_act_kernel.cu"),
+        ],
+    )
+except:
+    warnings.warn(
+        f"(This is not error) Switch to native implementation"
+    )
+
+    fused = None
 
 
 class FusedLeakyReLUFunctionBackward(Function):
@@ -125,3 +134,24 @@ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
         return FusedLeakyReLUFunction.apply(
             input.contiguous(), bias, negative_slope, scale
         )
+
+
+class FusedLeakyReLU_Native(nn.Module):
+    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(channel))
+
+        else:
+            self.bias = None
+
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu_native(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu_native(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+    return scale * F.leaky_relu(input + bias.view((1, -1) + (1,) * (len(input.shape) - 2)), negative_slope=negative_slope)
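The native fallback reproduces the fused CUDA op with ordinary PyTorch: add a per-channel bias, apply a leaky ReLU, then rescale by sqrt(2) so activation magnitudes stay roughly constant across StyleGAN2 layers. Spelled out on a small tensor (a sketch of the same computation, not the extension itself):

# Sketch only: the computation performed by fused_leaky_relu_native above.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)                 # NCHW input
bias = torch.zeros(4)                       # one bias value per channel
scale = 2 ** 0.5

out = scale * F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2)
# bias.view((1, -1) + (1,) * (ndim - 2)) broadcasts the bias over (N, C, H, W);
# the sqrt(2) gain matches the fused CUDA kernel's default scale.
print(out.shape)                            # torch.Size([2, 4, 8, 8])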
stylegan2/op/upfirdn2d.py
CHANGED
@@ -5,16 +5,24 @@ import torch
 from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
+import warnings
 
+module_path = os.path.dirname(os.path.abspath(__file__))
 
-
-upfirdn2d_op = load(
-
-
-
-
-
-)
+try:
+    upfirdn2d_op = load(
+        "upfirdn2d",
+        sources=[
+            os.path.join(module_path, "upfirdn2d.cpp"),
+            os.path.join(module_path, "upfirdn2d_kernel.cu"),
+        ],
+    )
+except:
+    warnings.warn(
+        f"(This is not error) Switch to native implementation"
+    )
+
+    upfirdn2d_op = None
 
 
 class UpFirDn2dBackward(Function):
@@ -157,7 +165,7 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
         pad = (pad[0], pad[1], pad[0], pad[1])
 
     if input.device.type == "cpu":
-        out =
+        out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
 
     else:
         out = UpFirDn2d.apply(input, kernel, up, down, pad)
@@ -165,7 +173,22 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     return out
 
 
-def upfirdn2d_native(
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+    if not isinstance(up, abc.Iterable):
+        up = (up, up)
+
+    if not isinstance(down, abc.Iterable):
+        down = (down, down)
+
+    if len(pad) == 2:
+        pad = (pad[0], pad[1], pad[0], pad[1])
+
+    out = _upfirdn2d_native(input, kernel, *up, *down, *pad)
+
+    return out
+
+
+def _upfirdn2d_native(
     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
 ):
     _, channel, in_h, in_w = input.shape
@@ -183,8 +206,8 @@ def upfirdn2d_native(
     )
     out = out[
         :,
-        max(-pad_y0, 0)
-        max(-pad_x0, 0)
+        max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
+        max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
         :,
     ]
 
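With this change, `upfirdn2d_native` mirrors the signature of the CUDA-backed `upfirdn2d` (scalar or iterable `up`/`down`, 2- or 4-element `pad`) and delegates to the renamed `_upfirdn2d_native`. A sketch of a typical StyleGAN2-style 2x upsample with a 4-tap binomial blur, where the kernel is scaled by the square of the upsampling factor to preserve signal magnitude (import path assumed from the file location):

# Sketch only: 2x upsampling through the pure-PyTorch path.
import torch
from stylegan2.op.upfirdn2d import upfirdn2d_native  # assumed import path

k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum()                                      # normalized 4x4 blur kernel

x = torch.randn(1, 3, 16, 16)
y = upfirdn2d_native(x, k * 4, up=2, down=1, pad=(2, 1))
print(y.shape)                                       # torch.Size([1, 3, 32, 32])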