# Heekyung's picture
# Upload 71 files
# 8377658
import copy
import random
import torch
import torch.nn.functional as FF
import torch.optim
from . import utils
from .stylegan2.model import Generator
class CustomGenerator(Generator):
    """Generator with synthesis split into ``prepare`` and ``generate``.

    ``prepare`` converts style codes into a per-layer latent tensor and a
    noise list; ``generate`` runs the synthesis layers from those and also
    returns an intermediate feature map resized to the image resolution.
    """

    def prepare(
        self,
        styles,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Build the per-layer latent tensor and the noise list.

        Args:
            styles: Sequence of one or two style tensors. Each is passed
                through ``self.style`` unless ``input_is_latent`` is true.
            inject_index: With two styles, the number of leading layers that
                receive the first style; drawn at random from
                ``[1, self.n_latent - 1]`` when ``None``. With a single style
                it is overwritten with ``self.n_latent``.
            truncation: When below 1, every style is interpolated towards
                ``truncation_latent`` by this factor.
            truncation_latent: Anchor used by the truncation step; must be
                provided when ``truncation < 1``.
            input_is_latent: Skip the ``self.style`` mapping when true.
            noise: Explicit per-layer noise list. When ``None`` it is either
                a list of ``None`` (``randomize_noise``) or the stored
                ``self.noises.noise_{i}`` attributes.
            randomize_noise: Selects which default noise list is used.

        Returns:
            Tuple ``(latent, noise)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                # A 2-D style is broadcast to every layer.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already has a per-layer axis; use it unchanged.
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            # Style mixing: first style for the leading layers, second style
            # for the remaining ones.
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        return latent, noise

    def generate(
        self,
        latent,
        noise,
        feature_size=256,
    ):
        """Run the synthesis layers and return the image plus a feature map.

        Args:
            latent: Per-layer latent tensor, indexed as ``latent[:, k]``.
            noise: Per-layer noise list as returned by ``prepare``.
            feature_size: Width of the intermediate activation to capture.
                Only the outputs of the paired ``self.convs`` stages are
                inspected. Defaults to 256, the previously hard-coded value.

        Returns:
            Tuple ``(image, feature)`` where ``feature`` is the captured
            activation bilinearly resized to ``image.shape[-2:]``.

        Raises:
            ValueError: If no stage produces an activation whose last
                dimension equals ``feature_size``.
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        feature = None
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            if out.shape[-1] == feature_size:
                feature = out
            i += 2

        if feature is None:
            # Previously this surfaced as an UnboundLocalError further down.
            raise ValueError(
                f"no synthesis stage produced a feature map of width {feature_size}"
            )

        image = skip
        feature = FF.interpolate(feature, image.shape[-2:], mode='bilinear')
        return image, feature
def stylegan2(
    size=1024,
    channel_multiplier=2,
    latent=512,
    n_mlp=8,
    ckpt='stylegan2-ffhq-config-f.pt'
):
    """Create a frozen ``CustomGenerator`` and load its weights from ``ckpt``.

    Args:
        size: Forwarded to the generator constructor as its first argument.
        channel_multiplier: Forwarded as ``channel_multiplier``.
        latent: Forwarded as the second constructor argument.
        n_mlp: Forwarded as the third constructor argument.
        ckpt: Checkpoint name resolved through ``utils.get_path``. The
            constructor's ``human`` flag is true when the name contains
            the substring ``"human"``.

    Returns:
        The generator in eval mode with gradients disabled on all parameters.
    """
    g_ema = CustomGenerator(
        size,
        latent,
        n_mlp,
        channel_multiplier=channel_multiplier,
        human='human' in ckpt,
    )
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # hosts; load_state_dict copies the values into the module's own
    # parameters, so callers can still move the model to any device later.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(utils.get_path(ckpt), map_location="cpu")
    # strict=False: keys that do not match the module are ignored.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.requires_grad_(False)
    g_ema.eval()
    return g_ema
def drag_gan(
    g_ema,
    latent: torch.Tensor,
    noise,
    F,
    handle_points,
    target_points,
    mask,
    max_iters=1000,
    r1=3,
    r2=12,
    lam=20,
    d=2,
    lr=2e-3,
):
    """Iteratively optimize the latent so handle points move to their targets.

    This is a generator: after every optimization step it yields the current
    state so the caller can display progress or stop early.

    Args:
        g_ema: Generator exposing ``generate(latent, noise)`` that returns
            ``(image, feature_map)``.
        latent: Per-layer latent tensor; only ``latent[:, :6]`` is optimized,
            the remaining layers are kept fixed.
        noise: Noise list forwarded to ``g_ema.generate``.
        F: Feature map of the unedited image; a detached copy serves as the
            reference for both the mask term and point tracking.
        handle_points: Sequence of point tensors to be moved.
        target_points: Sequence of destination point tensors.
        mask: Optional tensor; feature changes are penalized with weight
            ``(1 - mask)``. Skipped when ``None``.
        max_iters: Maximum number of optimization steps.
        r1: Radius passed to the motion-supervision loss.
        r2: Radius passed to point tracking.
        lam: Weight of the mask term.
        d: Absolute tolerance of the ``torch.allclose`` stop test between
            handle and target points.
        lr: Adam learning rate.

    Yields:
        Tuple ``(sample2, latent, F2, handle_points)``: the generated image,
        the full latent, the feature map, and the tracked handle points, all
        computed after the step under ``torch.no_grad``.
    """
    handle_points = torch.stack(handle_points)
    # torch.stack already produced a fresh tensor, so cloning it is enough to
    # remember the initial positions (no deepcopy of the input list needed).
    handle_points0 = handle_points.clone()
    target_points = torch.stack(target_points)

    F0 = F.detach().clone()
    device = latent.device

    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
    latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)

    optimizer = torch.optim.Adam([latent_trainable], lr=lr)
    for _ in range(max_iters):
        if torch.allclose(handle_points, target_points, atol=d):
            break

        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
        _, F2 = g_ema.generate(latent, noise)

        # Motion supervision, plus an optional penalty on feature changes
        # weighted by (1 - mask).
        loss = motion_supervison(handle_points, target_points, F2, r1, device)
        if mask is not None:
            loss += ((F2 - F0) * (1 - mask)).abs().mean() * lam

        loss.backward()
        optimizer.step()

        # Re-generate with the updated latent and relocate the handle points
        # by matching against the initial features.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            handle_points = point_tracking(F2, F0, handle_points, handle_points0, r2, device)

        yield sample2, latent, F2, handle_points
def motion_supervison(handle_points, target_points, F2, r1, device):
    """Sum the motion-supervision L1 losses over all handle points.

    For each handle point, the features inside a circular mask of radius
    ``r1`` around it (detached) are compared with the features sampled one
    step further towards the corresponding target point.

    Args:
        handle_points: Indexable collection of current handle points.
        target_points: Indexable collection of target points, same order.
        F2: Feature map; dimensions 2 and 3 are used as height and width.
        r1: Radius handed to ``utils.create_circular_mask``.
        device: Device the mask is moved to.

    Returns:
        The accumulated loss; the integer ``0`` when there are no points.
    """
    height, width = F2.shape[2], F2.shape[3]
    total = 0

    for idx, handle in enumerate(handle_points):
        offset = target_points[idx] - handle
        distance = torch.norm(offset)

        # Normalized direction; if that would overshoot the target, step by
        # the remaining offset instead.
        step = offset / (distance + 1e-7)
        if torch.norm(step) > distance:
            step = offset

        region = utils.create_circular_mask(
            height, width, center=handle.tolist(), radius=r1
        ).to(device)

        # Features at the masked positions: [batch, C, num_points].
        reference = F2[:, :, region]

        # Masked positions, one (dim-0, dim-1) index pair per row, moved by
        # the step.
        shifted = torch.nonzero(region).float() + step[None]

        # Map to grid_sample's [-1, 1] range and put the width coordinate
        # first, since grid_sample expects (x, y) ordering.
        grid_y = (2.0 * shifted[:, 0] / (height - 1)) - 1
        grid_x = (2.0 * shifted[:, 1] / (width - 1)) - 1
        grid = torch.stack((grid_x, grid_y), dim=-1)[None, None]  # [1, 1, num_points, 2]
        grid = grid.clamp(-1, 1)

        # Bilinear lookup gives [batch, C, 1, num_points]; drop the dummy axis.
        sampled = torch.nn.functional.grid_sample(
            F2, grid, mode="bilinear", align_corners=True
        ).squeeze(2)

        total = total + torch.nn.functional.l1_loss(reference.detach(), sampled)

    return total
def point_tracking(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """Relocate each handle point by nearest-feature search in a local patch.

    For every handle point, the positions selected by
    ``utils.create_square_mask`` (radius ``r2``) around its current location
    are candidates. The candidate whose feature in ``F`` has the smallest L1
    distance to the feature of the initial handle point in ``F0`` becomes the
    new location.

    Args:
        F: Current feature map; dimensions 2 and 3 are the spatial axes.
        F0: Reference feature map, indexed at the initial handle points.
        handle_points: Current handle points, one per row.
        handle_points0: Initial handle points, one per row.
        r2: Radius handed to ``utils.create_square_mask``.
        device: Device the patch mask is moved to.

    Returns:
        Tensor shaped and typed like ``handle_points`` with the new points.
    """
    height, width = F.shape[2], F.shape[3]
    tracked = torch.zeros_like(handle_points)

    for idx, point in enumerate(handle_points):
        window = utils.create_square_mask(
            height, width, center=point.tolist(), radius=r2
        ).to(device)

        # One index pair per candidate position inside the window.
        candidates = torch.nonzero(window)
        rows, cols = candidates[:, 0], candidates[:, 1]
        candidate_features = F[:, :, rows, cols]

        # Feature of the initial handle point in the reference map.
        row0 = handle_points0[idx][0].long()
        col0 = handle_points0[idx][1].long()
        reference = F0[:, :, row0, col0]

        # L1 distance over the channel axis; the closest candidate wins.
        l1 = torch.norm(candidate_features - reference.unsqueeze(-1), p=1, dim=1)
        tracked[idx] = candidates[torch.argmin(l1)]

    return tracked