sxela commited on
Commit
8810a39
·
1 Parent(s): d4092e8
Files changed (3) hide show
  1. app.py +171 -0
  2. obama.webm +0 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Thanks to nateraw for making this scape happen!
3
+ This code has been mostly taken from https://huggingface.co/spaces/nateraw/animegan-v2-for-videos/tree/main
4
+ """
5
+ import os
6
+ os.system("wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.3/ArcaneGANv0.3.jit")
7
+
8
+ import sys
9
+ from subprocess import call
10
+ def run_cmd(command):
11
+ try:
12
+ print(command)
13
+ call(command, shell=True)
14
+ except KeyboardInterrupt:
15
+ print("Process interrupted")
16
+ sys.exit(1)
17
+
18
+ print("⬇️ Installing latest gradio==2.4.7b9")
19
+ run_cmd("pip install --upgrade pip")
20
+ run_cmd('pip install gradio==2.4.7b9')
21
+
22
+ import gc
23
+ import math
24
+
25
+
26
+ import gradio as gr
27
+ import numpy as np
28
+ import torch
29
+ from encoded_video import EncodedVideo, write_video
30
+ from PIL import Image
31
+ from torchvision.transforms.functional import center_crop, to_tensor
32
+
33
+
34
+
35
+
36
+ print("🧠 Loading Model...")
37
+ model = torch.jit.load('ArcaneGANv0.3.jit').cuda().eval().half()
38
+
39
+ # This function is taken from pytorchvideo!
40
+ def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:
41
+ """
42
+ Uniformly subsamples num_samples indices from the temporal dimension of the video.
43
+ When num_samples is larger than the size of temporal dimension of the video, it
44
+ will sample frames based on nearest neighbor interpolation.
45
+ Args:
46
+ x (torch.Tensor): A video tensor with dimension larger than one with torch
47
+ tensor type includes int, long, float, complex, etc.
48
+ num_samples (int): The number of equispaced samples to be selected
49
+ temporal_dim (int): dimension of temporal to perform temporal subsample.
50
+ Returns:
51
+ An x-like Tensor with subsampled temporal dimension.
52
+ """
53
+ t = x.shape[temporal_dim]
54
+ assert num_samples > 0 and t > 0
55
+ # Sample by nearest neighbor interpolation if num_samples > t.
56
+ indices = torch.linspace(0, t - 1, num_samples)
57
+ indices = torch.clamp(indices, 0, t - 1).long()
58
+ return torch.index_select(x, temporal_dim, indices)
59
+
60
+
61
+ # This function is taken from pytorchvideo!
62
+ def short_side_scale(
63
+ x: torch.Tensor,
64
+ size: int,
65
+ interpolation: str = "bilinear",
66
+ ) -> torch.Tensor:
67
+ """
68
+ Determines the shorter spatial dim of the video (i.e. width or height) and scales
69
+ it to the given size. To maintain aspect ratio, the longer side is then scaled
70
+ accordingly.
71
+ Args:
72
+ x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.
73
+ size (int): The size the shorter side is scaled to.
74
+ interpolation (str): Algorithm used for upsampling,
75
+ options: nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
76
+ Returns:
77
+ An x-like Tensor with scaled spatial dims.
78
+ """
79
+ assert len(x.shape) == 4
80
+ assert x.dtype == torch.float32
81
+ c, t, h, w = x.shape
82
+ if w < h:
83
+ new_h = int(math.floor((float(h) / w) * size))
84
+ new_w = size
85
+ else:
86
+ new_h = size
87
+ new_w = int(math.floor((float(w) / h) * size))
88
+
89
+ return torch.nn.functional.interpolate(x, size=(new_h, new_w), mode=interpolation, align_corners=False)
90
+
91
+ means = [0.485, 0.456, 0.406]
92
+ stds = [0.229, 0.224, 0.225]
93
+
94
+ from torchvision import transforms
95
+ norm = transforms.Normalize(means,stds)
96
+
97
+ norms = torch.tensor(means)[None,:,None,None].cuda()
98
+ stds = torch.tensor(stds)[None,:,None,None].cuda()
99
+
100
+ def inference_step(vid, start_sec, duration, out_fps):
101
+ clip = vid.get_clip(start_sec, start_sec + duration)
102
+ video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)
103
+ audio_arr = np.expand_dims(clip['audio'], 0)
104
+ audio_fps = None if not vid._has_audio else vid._container.streams.audio[0].sample_rate
105
+
106
+ x = uniform_temporal_subsample(video_arr, duration * out_fps)
107
+ x = center_crop(short_side_scale(x, 512), 512)
108
+ x /= 255.
109
+ x = x.permute(1, 0, 2, 3)
110
+ x = norm(x)
111
+
112
+ with torch.no_grad():
113
+ output = model(x.to('cuda').half())
114
+ output = (output * stds + norms).clip(0, 1) * 255.
115
+
116
+ output_video = output.permute(0, 2, 3, 1).float().detach().cpu().numpy()
117
+
118
+ return output_video, audio_arr, out_fps, audio_fps
119
+
120
+
121
+ def predict_fn(filepath, start_sec, duration, out_fps):
122
+ # out_fps=12
123
+ vid = EncodedVideo.from_path(filepath)
124
+ for i in range(duration):
125
+ video, audio, fps, audio_fps = inference_step(
126
+ vid = vid,
127
+ start_sec = i + start_sec,
128
+ duration = 1,
129
+ out_fps = out_fps
130
+ )
131
+ gc.collect()
132
+ if i == 0:
133
+ video_all = video
134
+ audio_all = audio
135
+ else:
136
+ video_all = np.concatenate((video_all, video))
137
+ audio_all = np.hstack((audio_all, audio))
138
+
139
+ write_video(
140
+ 'out.mp4',
141
+ video_all,
142
+ fps=fps,
143
+ audio_array=audio_all,
144
+ audio_fps=audio_fps,
145
+ audio_codec='aac'
146
+ )
147
+
148
+ del video_all
149
+ del audio_all
150
+
151
+ return 'out.mp4'
152
+
153
+
154
+ title = "ArcaneGAN"
155
+ description = "Gradio demo for ArcaneGAN, video to Arcane style. To use it, simply upload your video, or click one of the examples to load them. Follow <a href='https://twitter.com/devdef' target='_blank'>Alex Spirin</a> for more info and updates."
156
+ article = "<div style='text-align: center;'>ArcaneGan by <a href='https://twitter.com/devdef' target='_blank'>Alex Spirin</a> | <a href='https://github.com/Sxela/ArcaneGAN' target='_blank'>Github Repo</a> | <center><img src='https://visitor-badge.glitch.me/badge?page_id=sxela_arcanegan_video_hf' alt='visitor badge'></center></div>"
157
+
158
+
159
+ gr.Interface(
160
+ predict_fn,
161
+ inputs=[gr.inputs.Video(), gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0), gr.inputs.Slider(minimum=1, maximum=10, step=1, default=2), gr.inputs.Slider(minimum=12, maximum=30, step=6, default=24)],
162
+ outputs=gr.outputs.Video(),
163
+ title='ArcaneGAN On Videos',
164
+ description="Applying ArcaneGAN to frame from video clips",
165
+ article = article,
166
+ enable_queue=True,
167
+ examples=[
168
+ ['obama.webm', 23, 10, 30],
169
+ ],
170
+ allow_flagging=False
171
+ ).launch()
obama.webm ADDED
Binary file (1.21 MB). View file
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ gdown
6
+ numpy
7
+ scipy
8
+ opencv-python-headless
9
+ encoded-video