Spaces:
Runtime error
Runtime error
Commit
Β·
5a67fb4
1
Parent(s):
1085c64
Implement interpolation between labels
Browse files- app.py +2 -1
- src/app/interpolate_labels.py +141 -0
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
|
3 |
# Custom imports
|
4 |
from src.app import MultiPage
|
5 |
-
from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models
|
6 |
|
7 |
# Create an instance of the app
|
8 |
app = MultiPage()
|
@@ -15,6 +15,7 @@ app.add_page('Compare models', compare_models.app)
|
|
15 |
app.add_page('Explore BigGAN', explore_biggan.app)
|
16 |
app.add_page('Explore cVAE', explore_cvae.app)
|
17 |
app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
|
|
|
18 |
|
19 |
# The main app
|
20 |
app.run()
|
|
|
2 |
|
3 |
# Custom imports
|
4 |
from src.app import MultiPage
|
5 |
+
from src.app import explore_infoscc_gan, explore_biggan, explore_cvae, compare_models, interpolate_labels
|
6 |
|
7 |
# Create an instance of the app
|
8 |
app = MultiPage()
|
|
|
15 |
app.add_page('Explore BigGAN', explore_biggan.app)
|
16 |
app.add_page('Explore cVAE', explore_cvae.app)
|
17 |
app.add_page('Explore InfoSCC-GAN', explore_infoscc_gan.app)
|
18 |
+
app.add_page('Interpolate labels', interpolate_labels.app)
|
19 |
|
20 |
# The main app
|
21 |
app.run()
|
src/app/interpolate_labels.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import src.app.params as params
|
11 |
+
from src.models import ConditionalGenerator as InfoSCC_GAN
|
12 |
+
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
13 |
+
from src.models import ConditionalDecoder as cVAE
|
14 |
+
from src.data import get_labels_train
|
15 |
+
from src.utils import download_file, sample_labels
|
16 |
+
|
17 |
+
|
18 |
+
device = params.device
|
19 |
+
size = params.size
|
20 |
+
n_layers = int(math.log2(size) - 2)
|
21 |
+
bs = 12
|
22 |
+
lin_space = torch.linspace(0, 1, bs).unsqueeze(1)
|
23 |
+
captions = [f'label_a * {(1 - x):.02f} + label_b * {x:.02f}' for x in lin_space.squeeze().numpy()]
|
24 |
+
|
25 |
+
|
26 |
+
@st.cache(allow_output_mutation=True)
|
27 |
+
def load_model(model_type: str):
|
28 |
+
|
29 |
+
print(f'Loading model: {model_type}')
|
30 |
+
if model_type == 'InfoSCC-GAN':
|
31 |
+
g = InfoSCC_GAN(size=params.size,
|
32 |
+
y_size=params.shape_label,
|
33 |
+
z_size=params.noise_dim)
|
34 |
+
|
35 |
+
if not Path(params.path_infoscc_gan).exists():
|
36 |
+
download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
|
37 |
+
|
38 |
+
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
39 |
+
g.load_state_dict(ckpt['g_ema'])
|
40 |
+
elif model_type == 'BigGAN':
|
41 |
+
g = BigGAN2Generator()
|
42 |
+
|
43 |
+
if not Path(params.path_biggan).exists():
|
44 |
+
download_file(params.drive_id_biggan, params.path_biggan)
|
45 |
+
|
46 |
+
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
47 |
+
g.load_state_dict(ckpt)
|
48 |
+
elif model_type == 'cVAE':
|
49 |
+
g = cVAE()
|
50 |
+
|
51 |
+
if not Path(params.path_cvae).exists():
|
52 |
+
download_file(params.drive_id_cvae, params.path_cvae)
|
53 |
+
|
54 |
+
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
55 |
+
g.load_state_dict(ckpt)
|
56 |
+
else:
|
57 |
+
raise ValueError('Unsupported model')
|
58 |
+
g = g.eval().to(device=params.device)
|
59 |
+
return g
|
60 |
+
|
61 |
+
|
62 |
+
@st.cache
|
63 |
+
def get_labels() -> torch.Tensor:
|
64 |
+
path_labels = params.path_labels
|
65 |
+
|
66 |
+
if not Path(path_labels).exists():
|
67 |
+
download_file(params.drive_id_labels, path_labels)
|
68 |
+
|
69 |
+
labels_train = get_labels_train(path_labels)
|
70 |
+
return labels_train
|
71 |
+
|
72 |
+
|
73 |
+
def get_eps(n: int) -> torch.Tensor:
|
74 |
+
eps = torch.randn((n, params.dim_z), device=device)
|
75 |
+
return eps
|
76 |
+
|
77 |
+
|
78 |
+
def app():
|
79 |
+
|
80 |
+
global lin_space, captions
|
81 |
+
|
82 |
+
st.title('Interpolate Labels')
|
83 |
+
st.markdown('This app allows the generation of the images with the labels that are interpolated between two labels.')
|
84 |
+
st.markdown('In each row there are images generated with the same interpolated label by one of the models')
|
85 |
+
|
86 |
+
biggan = load_model('BigGAN')
|
87 |
+
infoscc_gan = load_model('InfoSCC-GAN')
|
88 |
+
cvae = load_model('cVAE')
|
89 |
+
labels_train = get_labels()
|
90 |
+
|
91 |
+
# ==================== Labels ==============================================
|
92 |
+
label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
|
93 |
+
label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
|
94 |
+
label_interpolated = (1 - lin_space) * label_a + lin_space * label_b
|
95 |
+
|
96 |
+
sample_label = st.button('Sample label')
|
97 |
+
if sample_label:
|
98 |
+
label_a = sample_labels(labels_train, n=1).repeat(bs, 1)
|
99 |
+
label_b = sample_labels(labels_train, n=1).repeat(bs, 1)
|
100 |
+
label_interpolated = (1 - lin_space) * label_a + lin_space * label_b
|
101 |
+
# ==================== Labels ==============================================
|
102 |
+
|
103 |
+
# ==================== Noise ==============================================
|
104 |
+
eps = get_eps(1).repeat(bs, 1)
|
105 |
+
eps_infoscc = infoscc_gan.sample_eps(1).repeat(bs, 1)
|
106 |
+
|
107 |
+
zs = np.array([[0.0] * params.n_basis] * n_layers, dtype=np.float32)
|
108 |
+
zs_torch = torch.from_numpy(zs).unsqueeze(0).repeat(bs, 1, 1).to(device)
|
109 |
+
|
110 |
+
st.subheader('Noise')
|
111 |
+
st.markdown(r'Click on __Change eps__ button to change input $\varepsilon$ latent space')
|
112 |
+
change_eps = st.button('Change eps')
|
113 |
+
if change_eps:
|
114 |
+
eps = get_eps(1).repeat(bs, 1)
|
115 |
+
eps_infoscc = infoscc_gan.sample_eps(1).repeat(bs, 1)
|
116 |
+
# ==================== Noise ==============================================
|
117 |
+
|
118 |
+
with torch.no_grad():
|
119 |
+
imgs_biggan = biggan(eps, label_interpolated).squeeze(0).cpu()
|
120 |
+
imgs_infoscc = infoscc_gan(label_interpolated, eps_infoscc, zs_torch).squeeze(0).cpu()
|
121 |
+
imgs_cvae = cvae(eps, label_interpolated).squeeze(0).cpu()
|
122 |
+
|
123 |
+
if params.upsample:
|
124 |
+
imgs_biggan = F.interpolate(imgs_biggan, (size * 4, size * 4), mode='bicubic')
|
125 |
+
imgs_infoscc = F.interpolate(imgs_infoscc, (size * 4, size * 4), mode='bicubic')
|
126 |
+
imgs_cvae = F.interpolate(imgs_cvae, (size * 4, size * 4), mode='bicubic')
|
127 |
+
|
128 |
+
imgs_biggan = torch.clip(imgs_biggan, 0, 1)
|
129 |
+
imgs_biggan = [(imgs_biggan[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in range(bs)]
|
130 |
+
imgs_infoscc = [(imgs_infoscc[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
|
131 |
+
imgs_cvae = [(imgs_cvae[i].permute(1, 2, 0).numpy() * 127.5 + 127.5).astype(np.uint8) for i in range(bs)]
|
132 |
+
|
133 |
+
c1, c2, c3 = st.columns(3)
|
134 |
+
c1.header('BigGAN')
|
135 |
+
c1.image(imgs_biggan, use_column_width=True, caption=captions)
|
136 |
+
|
137 |
+
c2.header('InfoSCC-GAN')
|
138 |
+
c2.image(imgs_infoscc, use_column_width=True, caption=captions)
|
139 |
+
|
140 |
+
c3.header('cVAE')
|
141 |
+
c3.image(imgs_cvae, use_column_width=True, caption=captions)
|