|
import math |
|
import os |
|
|
|
import torch |
|
from torch import optim |
|
from torch.nn import functional as FF |
|
from torchvision import transforms |
|
from PIL import Image |
|
from tqdm import tqdm |
|
import dataclasses |
|
|
|
from .lpips import util |
|
|
|
|
|
def noise_regularize(noises):
    """Penalize spatial autocorrelation of the noise maps at several scales.

    For each noise tensor, the squared mean of its product with a copy rolled
    by one pixel along dim 3 and along dim 2 is accumulated. The map is then
    average-pooled 2x2 (halving ``shape[2]``) and the penalty repeated, until
    the side reaches 8 or less.

    Args:
        noises: Iterable of 4-D noise tensors. The pooling step assumes square
            maps whose side halves evenly down to <= 8 (TODO confirm with the
            generator's ``make_noise``).

    Returns:
        The accumulated scalar penalty (a tensor), or the int ``0`` when
        ``noises`` is empty.
    """
    total = 0

    for level in noises:
        side = level.shape[2]
        current = level

        while True:
            horizontal = (current * torch.roll(current, shifts=1, dims=3)).mean().pow(2)
            vertical = (current * torch.roll(current, shifts=1, dims=2)).mean().pow(2)
            total = total + horizontal + vertical

            if side <= 8:
                break

            # 2x2 average pooling expressed as reshape + mean over the pooled axes.
            current = current.reshape([-1, 1, side // 2, 2, side // 2, 2]).mean([3, 5])
            side //= 2

    return total
|
|
|
|
|
def noise_normalize_(noises):
    """Standardize each noise tensor in place to zero mean and unit std.

    The statistics are taken over the whole tensor before it is modified. The
    update goes through ``.data``, so autograd does not record it.

    Args:
        noises: Iterable of tensors; each is modified in place.

    Returns:
        None.
    """
    for tensor in noises:
        mu, sigma = tensor.mean(), tensor.std()
        tensor.data.sub_(mu).div_(sigma)
|
|
|
|
|
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning rate at normalized progress ``t``.

    The schedule is a linear warm-up over the first ``rampup`` fraction of
    training. It is multiplied by a cosine ramp-down over the last
    ``rampdown`` fraction.

    Args:
        t: Progress through optimization (``step / total_steps``).
        initial_lr: Peak learning rate.
        rampdown: Fraction of training used for the cosine decay.
        rampup: Fraction of training used for the linear warm-up.

    Returns:
        The scaled learning rate (``0`` at ``t == 0`` and at ``t == 1``).
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)

    return initial_lr * (cosine_factor * warmup_factor)
|
|
|
|
|
def latent_noise(latent, strength):
    """Return ``latent`` plus Gaussian noise with standard deviation ``strength``.

    Args:
        latent: Tensor to perturb; it is not modified.
        strength: Scale applied to standard-normal noise of the same shape.

    Returns:
        A new tensor of the same shape as ``latent``.
    """
    perturbation = torch.randn_like(latent).mul(strength)
    return latent.add(perturbation)
|
|
|
|
|
def make_image(tensor):
    """Convert a batch of image tensors to uint8 numpy images.

    Values are clipped to [-1, 1], mapped linearly to [0, 255], truncated to
    uint8 and moved to channels-last layout on the CPU. The input tensor is
    left unchanged.

    Args:
        tensor: 4-D tensor (N, C, H, W), presumably generator output in [-1, 1].

    Returns:
        numpy uint8 array of shape (N, H, W, C).
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with ``tensor``, so the
        # in-place ``clamp_`` used previously silently clipped the caller's data.
        .clamp(min=-1, max=1)
        .add(1)
        .div_(2)  # safe in place: operates on the fresh tensor from ``add``
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
|
|
|
|
|
@dataclasses.dataclass
class InverseConfig:
    """Hyper-parameters for latent-space projection (``inverse_image``).

    Every attribute is annotated so ``@dataclass`` turns it into a real
    field. Without annotations the values were plain class attributes, and
    ``InverseConfig(lr=...)`` raised ``TypeError``.
    """

    # Fraction of the optimization used for linear learning-rate warm-up.
    lr_warmup: float = 0.05
    # Fraction of the optimization used for cosine learning-rate decay.
    lr_decay: float = 0.25
    # Peak Adam learning rate.
    lr: float = 0.1
    # Scale (relative to the latent std) of the noise added to the latent.
    noise: float = 0.05
    # Fraction of the optimization after which that latent noise reaches zero.
    noise_decay: float = 0.75
    # Number of optimization steps.
    step: int = 1000
    # Weight of the noise-map regularization loss.
    noise_regularize: float = 1e5
    # Weight of the pixel MSE loss (0 disables it).
    mse: float = 0.0
    # Optimize one latent per generator layer (W+) instead of a single W.
    # Was ``False,`` (a truthy 1-tuple), which forced W+ mode unconditionally.
    w_plus: bool = False
|
|
|
|
|
def inverse_image(
    g_ema,
    image,
    image_size=256,
    config=None,
    device="cuda",
    save_path="project.png",
):
    """Project ``image`` into the latent space of the generator ``g_ema``.

    The optimization starts from the mean of 10,000 mapped random vectors.
    Adam optimizes the latent code and the generator's noise inputs so that
    the generated image matches ``image``. The loss is an LPIPS/VGG
    perceptual term plus a weighted noise regularizer and a weighted pixel
    MSE, with the weights taken from ``config``.

    Args:
        g_ema: Generator exposing ``style``, ``make_noise``, ``prepare``,
            ``generate`` and, when ``config.w_plus`` is set, ``n_latent``.
        image: Input accepted by the torchvision transform below (presumably
            a PIL RGB image; TODO confirm against callers).
        image_size: The target is resized and center-cropped to
            ``min(image_size, 256)`` pixels.
        config: ``InverseConfig`` holding the hyper-parameters. When ``None``,
            a fresh default instance is created for this call instead of
            sharing one default object across calls.
        device: Device used for the target image and latent sampling.
        save_path: File the projected image is written to with PIL (format
            inferred from the extension). Pass ``None`` to skip writing.

    Returns:
        dict with keys ``"latent"`` (output of ``g_ema.prepare`` for the final
        code), ``"noise"`` (noise tensors sliced to the first batch item),
        ``"F"`` (second output of ``g_ema.generate``) and ``"sample"`` (the
        generated image tensor).
    """
    args = InverseConfig() if config is None else config

    n_mean_latent = 10000
    resize = min(image_size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            # Maps [0, 1] pixel values to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # Batch of one target image.
    imgs = transform(image).unsqueeze(0).to(device)
    batch_size = imgs.shape[0]

    with torch.no_grad():
        # Assumes a 512-dim input space for ``g_ema.style`` (TODO confirm).
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        # Scalar spread of the mapped latents; scales the exploration noise below.
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = util.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=str(device).startswith("cuda")
    )

    # Fresh standard-normal noise maps shaped like the generator's own.
    noises = [
        noise.repeat(batch_size, 1, 1, 1).normal_() for noise in g_ema.make_noise()
    ]

    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(batch_size, 1)
    if args.w_plus:
        # One latent per generator layer (W+ space).
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True
    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))
    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=args.lr_decay, rampup=args.lr_warmup)
        optimizer.param_groups[0]["lr"] = lr

        # Exploration noise on the latent, decaying quadratically to zero at
        # t == noise_decay.
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
        img_gen, F = g_ema.generate(latent, noise)

        batch, channel, height, width = img_gen.shape
        if height > 256:
            # Average-pool outputs above 256 px by an integer factor; assumes
            # the side lengths are multiples of 256 (TODO confirm).
            factor = height // 256
            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = FF.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Use the latent after the final step. The old code read the last
    # every-100-steps snapshot, which raised IndexError for step < 100 and
    # dropped the trailing steps when step % 100 != 0.
    final_latent = latent_in.detach().clone()
    latent, noise = g_ema.prepare([final_latent], input_is_latent=True, noise=noises)
    img_gen, F = g_ema.generate(latent, noise)

    img_ar = make_image(img_gen)

    # Only the first item of the batch is returned and saved.
    noise_single = [noise[0:1] for noise in noises]

    result = {
        "latent": latent,
        "noise": noise_single,
        'F': F,
        "sample": img_gen,
    }

    if save_path is not None:
        Image.fromarray(img_ar[0]).save(save_path)

    return result
|
|