nikkar commited on
Commit
541f198
·
verified ·
1 Parent(s): d515d68

CoTracker3 demo

Browse files
.gitattributes ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ videos/apple.mp4 filter=lfs diff=lfs merge=lfs -text
2
+ videos/backpack.mp4 filter=lfs diff=lfs merge=lfs -text
3
+ videos/pillow.mp4 filter=lfs diff=lfs merge=lfs -text
4
+ videos/teddy.mp4 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,114 +1,614 @@
 
 
 
 
1
  import os
 
 
 
 
 
 
2
  import cv2
3
- import imutils
4
  import torch
 
 
 
 
5
  import numpy as np
6
- import gradio as gr
7
 
8
- from cotracker.utils.visualizer import Visualizer
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def parse_video(video_file):
12
- vs = cv2.VideoCapture(video_file)
13
 
14
- frames = []
15
- while True:
16
- (gotit, frame) = vs.read()
17
- if frame is not None:
18
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
19
- frames.append(frame)
20
- if not gotit:
21
- break
22
 
23
- return np.stack(frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
- def cotracker_demo(
27
- input_video,
28
- grid_size: int = 10,
29
- tracks_leave_trace: bool = False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  ):
31
- load_video = parse_video(input_video)
32
- load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
 
35
 
36
- if torch.cuda.is_available():
37
- model = model.cuda()
38
- load_video = load_video.cuda()
 
 
 
 
 
 
 
 
 
 
39
 
40
- model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
41
- for ind in range(0, load_video.shape[1] - model.step, model.step):
 
 
 
 
 
42
  pred_tracks, pred_visibility = model(
43
- video_chunk=load_video[:, ind : ind + model.step * 2]
 
 
 
44
  ) # B T N 2, B T N 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- linewidth = 2
47
- if grid_size < 10:
48
- linewidth = 4
49
- elif grid_size < 20:
50
- linewidth = 3
51
-
52
- vis = Visualizer(
53
- save_dir=os.path.join(os.path.dirname(__file__), "results"),
54
- grayscale=False,
55
- pad_value=100,
56
- fps=10,
57
- linewidth=linewidth,
58
- show_first_frame=5,
59
- tracks_leave_trace=-1 if tracks_leave_trace else 0,
60
- )
61
- import time
62
-
63
- def current_milli_time():
64
- return round(time.time() * 1000)
65
-
66
- filename = str(current_milli_time())
67
- vis.visualize(
68
- load_video.cpu(),
69
- tracks=pred_tracks.cpu(),
70
- visibility=pred_visibility.cpu(),
71
- filename=f"{filename}_pred_track",
72
- )
73
- return os.path.join(
74
- os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
75
- )
76
-
77
-
78
- apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
79
- bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
80
- paragliding_launch = os.path.join(
81
- os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
82
- )
83
- paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
84
-
85
- app = gr.Interface(
86
- title="🎨 CoTracker: It is Better to Track Together",
87
- description="<div style='text-align: left;'> \
88
- <p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
89
- Points are sampled on a regular grid and are tracked jointly. </p> \
90
- <p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
91
- <ul style='display: inline-block; text-align: left;'> \
92
- <li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
93
- <li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
94
- </ul> \
95
- <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
96
- </div>",
97
- fn=cotracker_demo,
98
- inputs=[
99
- gr.Video(type="file", label="Input video", interactive=True),
100
- gr.Slider(minimum=10, maximum=80, step=1, value=10, label="Grid Size"),
101
- gr.Checkbox(label="Visualize Track Traces"),
102
- ],
103
- outputs=gr.Video(label="Video with predicted tracks"),
104
- examples=[
105
- [apple, 30, False],
106
- [apple, 10, True],
107
- [bear, 10, False],
108
- [paragliding, 10, False],
109
- [paragliding_launch, 10, False],
110
- ],
111
- cache_examples=True,
112
- allow_flagging=False,
113
- )
114
- app.queue(max_size=20, concurrency_count=2).launch(debug=True)
 
1
+ # This Gradio demo code is from https://github.com/cvlab-kaist/locotrack/blob/main/demo/demo.py
2
+ # We updated it to work with CoTracker3 models. We thank authors of LocoTrack
3
+ # for such an amazing Gradio demo.
4
+
5
  import os
6
+ import sys
7
+ import uuid
8
+
9
+ import gradio as gr
10
+ import mediapy
11
+ import numpy as np
12
  import cv2
13
+ import matplotlib
14
  import torch
15
+ import colorsys
16
+ import random
17
+ from typing import List, Optional, Sequence, Tuple
18
+
19
  import numpy as np
 
20
 
 
21
 
22
+ # Generate random colormaps for visualizing different points.
23
+ def get_colors(num_colors: int) -> List[Tuple[int, int, int]]:
24
+ """Gets colormap for points."""
25
+ colors = []
26
+ for i in np.arange(0.0, 360.0, 360.0 / num_colors):
27
+ hue = i / 360.0
28
+ lightness = (50 + np.random.rand() * 10) / 100.0
29
+ saturation = (90 + np.random.rand() * 10) / 100.0
30
+ color = colorsys.hls_to_rgb(hue, lightness, saturation)
31
+ colors.append(
32
+ (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
33
+ )
34
+ random.shuffle(colors)
35
+ return colors
36
+
37
+ def get_points_on_a_grid(
38
+ size: int,
39
+ extent: Tuple[float, ...],
40
+ center: Optional[Tuple[float, ...]] = None,
41
+ device: Optional[torch.device] = torch.device("cpu"),
42
+ ):
43
+ r"""Get a grid of points covering a rectangular region
44
+
45
+ `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
46
+ :attr:`size` grid fo points distributed to cover a rectangular area
47
+ specified by `extent`.
48
+
49
+ The `extent` is a pair of integer :math:`(H,W)` specifying the height
50
+ and width of the rectangle.
51
+
52
+ Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
53
+ specifying the vertical and horizontal center coordinates. The center
54
+ defaults to the middle of the extent.
55
+
56
+ Points are distributed uniformly within the rectangle leaving a margin
57
+ :math:`m=W/64` from the border.
58
+
59
+ It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
60
+ points :math:`P_{ij}=(x_i, y_i)` where
61
+
62
+ .. math::
63
+ P_{ij} = \left(
64
+ c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
65
+ c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
66
+ \right)
67
+
68
+ Points are returned in row-major order.
69
+
70
+ Args:
71
+ size (int): grid size.
72
+ extent (tuple): height and with of the grid extent.
73
+ center (tuple, optional): grid center.
74
+ device (str, optional): Defaults to `"cpu"`.
75
+
76
+ Returns:
77
+ Tensor: grid.
78
+ """
79
+ if size == 1:
80
+ return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
81
+
82
+ if center is None:
83
+ center = [extent[0] / 2, extent[1] / 2]
84
+
85
+ margin = extent[1] / 64
86
+ range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
87
+ range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
88
+ grid_y, grid_x = torch.meshgrid(
89
+ torch.linspace(*range_y, size, device=device),
90
+ torch.linspace(*range_x, size, device=device),
91
+ indexing="ij",
92
+ )
93
+ return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
94
+
95
+ def paint_point_track(
96
+ frames: np.ndarray,
97
+ point_tracks: np.ndarray,
98
+ visibles: np.ndarray,
99
+ colormap: Optional[List[Tuple[int, int, int]]] = None,
100
+ ) -> np.ndarray:
101
+ """Converts a sequence of points to color code video.
102
+
103
+ Args:
104
+ frames: [num_frames, height, width, 3], np.uint8, [0, 255]
105
+ point_tracks: [num_points, num_frames, 2], np.float32, [0, width / height]
106
+ visibles: [num_points, num_frames], bool
107
+ colormap: colormap for points, each point has a different RGB color.
108
+
109
+ Returns:
110
+ video: [num_frames, height, width, 3], np.uint8, [0, 255]
111
+ """
112
+ num_points, num_frames = point_tracks.shape[0:2]
113
+ if colormap is None:
114
+ colormap = get_colors(num_colors=num_points)
115
+ height, width = frames.shape[1:3]
116
+ dot_size_as_fraction_of_min_edge = 0.015
117
+ radius = int(round(min(height, width) * dot_size_as_fraction_of_min_edge))
118
+ diam = radius * 2 + 1
119
+ quadratic_y = np.square(np.arange(diam)[:, np.newaxis] - radius - 1)
120
+ quadratic_x = np.square(np.arange(diam)[np.newaxis, :] - radius - 1)
121
+ icon = (quadratic_y + quadratic_x) - (radius**2) / 2.0
122
+ sharpness = 0.15
123
+ icon = np.clip(icon / (radius * 2 * sharpness), 0, 1)
124
+ icon = 1 - icon[:, :, np.newaxis]
125
+ icon1 = np.pad(icon, [(0, 1), (0, 1), (0, 0)])
126
+ icon2 = np.pad(icon, [(1, 0), (0, 1), (0, 0)])
127
+ icon3 = np.pad(icon, [(0, 1), (1, 0), (0, 0)])
128
+ icon4 = np.pad(icon, [(1, 0), (1, 0), (0, 0)])
129
+
130
+ video = frames.copy()
131
+ for t in range(num_frames):
132
+ # Pad so that points that extend outside the image frame don't crash us
133
+ image = np.pad(
134
+ video[t],
135
+ [
136
+ (radius + 1, radius + 1),
137
+ (radius + 1, radius + 1),
138
+ (0, 0),
139
+ ],
140
+ )
141
+ for i in range(num_points):
142
+ # The icon is centered at the center of a pixel, but the input coordinates
143
+ # are raster coordinates. Therefore, to render a point at (1,1) (which
144
+ # lies on the corner between four pixels), we need 1/4 of the icon placed
145
+ # centered on the 0'th row, 0'th column, etc. We need to subtract
146
+ # 0.5 to make the fractional position come out right.
147
+ x, y = point_tracks[i, t, :] + 0.5
148
+ x = min(max(x, 0.0), width)
149
+ y = min(max(y, 0.0), height)
150
+
151
+ if visibles[i, t]:
152
+ x1, y1 = np.floor(x).astype(np.int32), np.floor(y).astype(np.int32)
153
+ x2, y2 = x1 + 1, y1 + 1
154
+
155
+ # bilinear interpolation
156
+ patch = (
157
+ icon1 * (x2 - x) * (y2 - y)
158
+ + icon2 * (x2 - x) * (y - y1)
159
+ + icon3 * (x - x1) * (y2 - y)
160
+ + icon4 * (x - x1) * (y - y1)
161
+ )
162
+ x_ub = x1 + 2 * radius + 2
163
+ y_ub = y1 + 2 * radius + 2
164
+ image[y1:y_ub, x1:x_ub, :] = (1 - patch) * image[
165
+ y1:y_ub, x1:x_ub, :
166
+ ] + patch * np.array(colormap[i])[np.newaxis, np.newaxis, :]
167
+
168
+ # Remove the pad
169
+ video[t] = image[
170
+ radius + 1 : -radius - 1, radius + 1 : -radius - 1
171
+ ].astype(np.uint8)
172
+ return video
173
+
174
+
175
+ PREVIEW_WIDTH = 768 # Width of the preview video
176
+ VIDEO_INPUT_RESO = (384, 512) # Resolution of the input video
177
+ POINT_SIZE = 4 # Size of the query point in the preview video
178
+ FRAME_LIMIT = 300 # Limit the number of frames to process
179
+
180
+
181
+ def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
182
+ print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")
183
+
184
+ current_frame = video_queried_preview[int(frame_num)]
185
+
186
+ # Get the mouse click
187
+ query_points[int(frame_num)].append((evt.index[0], evt.index[1], frame_num))
188
+
189
+ # Choose the color for the point from matplotlib colormap
190
+ color = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
191
+ color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
192
+ # print(f"Color: {color}")
193
+ query_points_color[int(frame_num)].append(color)
194
+
195
+ # Draw the point on the frame
196
+ x, y = evt.index
197
+ current_frame_draw = cv2.circle(current_frame, (x, y), POINT_SIZE, color, -1)
198
+
199
+ # Update the frame
200
+ video_queried_preview[int(frame_num)] = current_frame_draw
201
+
202
+ # Update the query count
203
+ query_count += 1
204
+ return (
205
+ current_frame_draw, # Updated frame for preview
206
+ video_queried_preview, # Updated preview video
207
+ query_points, # Updated query points
208
+ query_points_color, # Updated query points color
209
+ query_count # Updated query count
210
+ )
211
+
212
+
213
+ def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
214
+ if len(query_points[int(frame_num)]) == 0:
215
+ return (
216
+ video_queried_preview[int(frame_num)],
217
+ video_queried_preview,
218
+ query_points,
219
+ query_points_color,
220
+ query_count
221
+ )
222
+
223
+ # Get the last point
224
+ query_points[int(frame_num)].pop(-1)
225
+ query_points_color[int(frame_num)].pop(-1)
226
+
227
+ # Redraw the frame
228
+ current_frame_draw = video_preview[int(frame_num)].copy()
229
+ for point, color in zip(query_points[int(frame_num)], query_points_color[int(frame_num)]):
230
+ x, y, _ = point
231
+ current_frame_draw = cv2.circle(current_frame_draw, (x, y), POINT_SIZE, color, -1)
232
+
233
+ # Update the query count
234
+ query_count -= 1
235
+
236
+ # Update the frame
237
+ video_queried_preview[int(frame_num)] = current_frame_draw
238
+ return (
239
+ current_frame_draw, # Updated frame for preview
240
+ video_queried_preview, # Updated preview video
241
+ query_points, # Updated query points
242
+ query_points_color, # Updated query points color
243
+ query_count # Updated query count
244
+ )
245
+
246
+
247
+ def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
248
+ query_count -= len(query_points[int(frame_num)])
249
 
250
+ query_points[int(frame_num)] = []
251
+ query_points_color[int(frame_num)] = []
252
 
253
+ video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy()
 
 
 
 
 
 
 
254
 
255
+ return (
256
+ video_preview[int(frame_num)], # Set the preview frame to the original frame
257
+ video_queried_preview,
258
+ query_points, # Cleared query points
259
+ query_points_color, # Cleared query points color
260
+ query_count # New query count
261
+ )
262
+
263
+
264
+
265
+ def clear_all_fn(frame_num, video_preview):
266
+ return (
267
+ video_preview[int(frame_num)],
268
+ video_preview.copy(),
269
+ [[] for _ in range(len(video_preview))],
270
+ [[] for _ in range(len(video_preview))],
271
+ 0
272
+ )
273
+
274
+
275
+ def choose_frame(frame_num, video_preview_array):
276
+ return video_preview_array[int(frame_num)]
277
 
278
 
279
+ def preprocess_video_input(video_path):
280
+ video_arr = mediapy.read_video(video_path)
281
+ video_fps = video_arr.metadata.fps
282
+ num_frames = video_arr.shape[0]
283
+ if num_frames > FRAME_LIMIT:
284
+ gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
285
+ video_arr = video_arr[:FRAME_LIMIT]
286
+ num_frames = FRAME_LIMIT
287
+
288
+ # Resize to preview size for faster processing, width = PREVIEW_WIDTH
289
+ height, width = video_arr.shape[1:3]
290
+ new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
291
+
292
+ preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
293
+ input_video = mediapy.resize_video(video_arr, VIDEO_INPUT_RESO)
294
+
295
+ preview_video = np.array(preview_video)
296
+ input_video = np.array(input_video)
297
+
298
+ interactive = True
299
+
300
+ return (
301
+ video_arr, # Original video
302
+ preview_video, # Original preview video, resized for faster processing
303
+ preview_video.copy(), # Copy of preview video for visualization
304
+ input_video, # Resized video input for model
305
+ # None, # video_feature, # Extracted feature
306
+ video_fps, # Set the video FPS
307
+ gr.update(open=False), # Close the video input drawer
308
+ # tracking_mode, # Set the tracking mode
309
+ preview_video[0], # Set the preview frame to the first frame
310
+ gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=interactive), # Set slider interactive
311
+ [[] for _ in range(num_frames)], # Set query_points to empty
312
+ [[] for _ in range(num_frames)], # Set query_points_color to empty
313
+ [[] for _ in range(num_frames)],
314
+ 0, # Set query count to 0
315
+ gr.update(interactive=interactive), # Make the buttons interactive
316
+ gr.update(interactive=interactive),
317
+ gr.update(interactive=interactive),
318
+ gr.update(interactive=True),
319
+ )
320
+
321
+
322
+ def track(
323
+ video_preview,
324
+ video_input,
325
+ video_fps,
326
+ query_points,
327
+ query_points_color,
328
+ query_count,
329
  ):
330
+ tracking_mode = 'selected'
331
+ if query_count == 0:
332
+ tracking_mode='grid'
333
+
334
+ device = "cuda" if torch.cuda.is_available() else "cpu"
335
+ dtype = torch.float if device == "cuda" else torch.float
336
+
337
+ # Convert query points to tensor, normalize to input resolution
338
+ if tracking_mode!='grid':
339
+ query_points_tensor = []
340
+ for frame_points in query_points:
341
+ query_points_tensor.extend(frame_points)
342
+
343
+ query_points_tensor = torch.tensor(query_points_tensor).float()
344
+ query_points_tensor *= torch.tensor([
345
+ VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0], 1
346
+ ]) / torch.tensor([
347
+ [video_preview.shape[2], video_preview.shape[1], 1]
348
+ ])
349
+ query_points_tensor = query_points_tensor[None].flip(-1).to(device, dtype) # xyt -> tyx
350
+ query_points_tensor = query_points_tensor[:, :, [0, 2, 1]] # tyx -> txy
351
+
352
+ video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
353
 
354
+ model = torch.hub.load("facebookresearch/co-tracker:release_cotracker3", "cotracker3_online")
355
+ model = model.to(device)
356
 
357
+ video_input = video_input.permute(0, 1, 4, 2, 3)
358
+ if tracking_mode=='grid':
359
+ xy = get_points_on_a_grid(15, video_input.shape[3:], device=device)
360
+ queries = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
361
+ add_support_grid=False
362
+ cmap = matplotlib.colormaps.get_cmap("gist_rainbow")
363
+ query_points_color = [[]]
364
+ query_count = queries.shape[1]
365
+ for i in range(query_count):
366
+ # Choose the color for the point from matplotlib colormap
367
+ color = cmap(i / float(query_count))
368
+ color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
369
+ query_points_color[0].append(color)
370
 
371
+ else:
372
+ queries = query_points_tensor
373
+ add_support_grid=True
374
+
375
+ model(video_chunk=video_input, is_first_step=True, grid_size=0, queries=queries, add_support_grid=add_support_grid)
376
+ #
377
+ for ind in range(0, video_input.shape[1] - model.step, model.step):
378
  pred_tracks, pred_visibility = model(
379
+ video_chunk=video_input[:, ind : ind + model.step * 2],
380
+ grid_size=0,
381
+ queries=queries,
382
+ add_support_grid=add_support_grid
383
  ) # B T N 2, B T N 1
384
+ tracks = (pred_tracks * torch.tensor([video_preview.shape[2], video_preview.shape[1]]).to(device) / torch.tensor([VIDEO_INPUT_RESO[1], VIDEO_INPUT_RESO[0]]).to(device))[0].permute(1, 0, 2).cpu().numpy()
385
+ pred_occ = pred_visibility[0].permute(1, 0).cpu().numpy()
386
+
387
+ # make color array
388
+ colors = []
389
+ for frame_colors in query_points_color:
390
+ colors.extend(frame_colors)
391
+ colors = np.array(colors)
392
+
393
+ painted_video = paint_point_track(video_preview,tracks,pred_occ,colors)
394
+
395
+ # save video
396
+ video_file_name = uuid.uuid4().hex + ".mp4"
397
+ video_path = os.path.join(os.path.dirname(__file__), "tmp")
398
+ video_file_path = os.path.join(video_path, video_file_name)
399
+ os.makedirs(video_path, exist_ok=True)
400
+
401
+ mediapy.write_video(video_file_path, painted_video, fps=video_fps)
402
+
403
+ return video_file_path
404
+
405
+
406
+ with gr.Blocks() as demo:
407
+ video = gr.State()
408
+ video_queried_preview = gr.State()
409
+ video_preview = gr.State()
410
+ video_input = gr.State()
411
+ video_fps = gr.State(24)
412
+
413
+ query_points = gr.State([])
414
+ query_points_color = gr.State([])
415
+ is_tracked_query = gr.State([])
416
+ query_count = gr.State(0)
417
+
418
+ gr.Markdown("# 🎨 CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos")
419
+ gr.Markdown("<div style='text-align: left;'> \
420
+ <p>Welcome to <a href='https://cotracker3.github.io/' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
421
+ The model tracks points on a grid or points selected by you. </p> \
422
+ <p> To get started, simply upload your <b>.mp4</b> video or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
423
+ <p> After you uploaded a video, please click \"Submit\" and then click \"Track\" for grid tracking or specify points you want to track before clicking. Enjoy the results! </p>\
424
+ <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐. We thank the authors of LocoTrack for their interactive demo.</p> \
425
+ </div>"
426
+ )
427
+
428
+
429
+ gr.Markdown("## First step: upload your video or select an example video, and click submit.")
430
+ with gr.Row():
431
+
432
+
433
+ with gr.Accordion("Your video input", open=True) as video_in_drawer:
434
+ video_in = gr.Video(label="Video Input", format="mp4")
435
+ submit = gr.Button("Submit", scale=0)
436
+
437
+ import os
438
+ apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
439
+ bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
440
+ paragliding_launch = os.path.join(
441
+ os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
442
+ )
443
+ paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
444
+ cat = os.path.join(os.path.dirname(__file__), "videos", "cat.mp4")
445
+ pillow = os.path.join(os.path.dirname(__file__), "videos", "pillow.mp4")
446
+ teddy = os.path.join(os.path.dirname(__file__), "videos", "teddy.mp4")
447
+ backpack = os.path.join(os.path.dirname(__file__), "videos", "backpack.mp4")
448
+
449
+
450
+ gr.Examples(examples=[bear, apple, paragliding, paragliding_launch, cat, pillow, teddy, backpack],
451
+ inputs = [
452
+ video_in
453
+ ],
454
+ )
455
+
456
+
457
+ gr.Markdown("## Second step: Simply click \"Track\" to track a grid of points or select query points on the video before clicking")
458
+ with gr.Row():
459
+ with gr.Column():
460
+ with gr.Row():
461
+ query_frames = gr.Slider(
462
+ minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False)
463
+ with gr.Row():
464
+ undo = gr.Button("Undo", interactive=False)
465
+ clear_frame = gr.Button("Clear Frame", interactive=False)
466
+ clear_all = gr.Button("Clear All", interactive=False)
467
+
468
+ with gr.Row():
469
+ current_frame = gr.Image(
470
+ label="Click to add query points",
471
+ type="numpy",
472
+ interactive=False
473
+ )
474
+
475
+ with gr.Row():
476
+ track_button = gr.Button("Track", interactive=False)
477
+
478
+ with gr.Column():
479
+ output_video = gr.Video(
480
+ label="Output Video",
481
+ interactive=False,
482
+ autoplay=True,
483
+ loop=True,
484
+ )
485
+
486
+
487
+
488
+ submit.click(
489
+ fn = preprocess_video_input,
490
+ inputs = [video_in],
491
+ outputs = [
492
+ video,
493
+ video_preview,
494
+ video_queried_preview,
495
+ video_input,
496
+ video_fps,
497
+ video_in_drawer,
498
+ current_frame,
499
+ query_frames,
500
+ query_points,
501
+ query_points_color,
502
+ is_tracked_query,
503
+ query_count,
504
+ undo,
505
+ clear_frame,
506
+ clear_all,
507
+ track_button,
508
+ ],
509
+ queue = False
510
+ )
511
+
512
+ query_frames.change(
513
+ fn = choose_frame,
514
+ inputs = [query_frames, video_queried_preview],
515
+ outputs = [
516
+ current_frame,
517
+ ],
518
+ queue = False
519
+ )
520
+
521
+ current_frame.select(
522
+ fn = get_point,
523
+ inputs = [
524
+ query_frames,
525
+ video_queried_preview,
526
+ query_points,
527
+ query_points_color,
528
+ query_count,
529
+ ],
530
+ outputs = [
531
+ current_frame,
532
+ video_queried_preview,
533
+ query_points,
534
+ query_points_color,
535
+ query_count
536
+ ],
537
+ queue = False
538
+ )
539
+
540
+ undo.click(
541
+ fn = undo_point,
542
+ inputs = [
543
+ query_frames,
544
+ video_preview,
545
+ video_queried_preview,
546
+ query_points,
547
+ query_points_color,
548
+ query_count
549
+ ],
550
+ outputs = [
551
+ current_frame,
552
+ video_queried_preview,
553
+ query_points,
554
+ query_points_color,
555
+ query_count
556
+ ],
557
+ queue = False
558
+ )
559
+
560
+ clear_frame.click(
561
+ fn = clear_frame_fn,
562
+ inputs = [
563
+ query_frames,
564
+ video_preview,
565
+ video_queried_preview,
566
+ query_points,
567
+ query_points_color,
568
+ query_count
569
+ ],
570
+ outputs = [
571
+ current_frame,
572
+ video_queried_preview,
573
+ query_points,
574
+ query_points_color,
575
+ query_count
576
+ ],
577
+ queue = False
578
+ )
579
+
580
+ clear_all.click(
581
+ fn = clear_all_fn,
582
+ inputs = [
583
+ query_frames,
584
+ video_preview,
585
+ ],
586
+ outputs = [
587
+ current_frame,
588
+ video_queried_preview,
589
+ query_points,
590
+ query_points_color,
591
+ query_count
592
+ ],
593
+ queue = False
594
+ )
595
+
596
+
597
+ track_button.click(
598
+ fn = track,
599
+ inputs = [
600
+ video_preview,
601
+ video_input,
602
+ video_fps,
603
+ query_points,
604
+ query_points_color,
605
+ query_count,
606
+ ],
607
+ outputs = [
608
+ output_video,
609
+ ],
610
+ queue = True,
611
+ )
612
 
613
+
614
+ demo.launch(show_api=False, show_error=True, debug=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,12 @@
1
- matplotlib
 
 
 
 
 
2
  imageio[ffmpeg]
3
  opencv-python
4
- flow_vis
5
- imutils
6
  numpy
7
- imageio
8
- gradio
9
- git+https://github.com/facebookresearch/co-tracker.git
 
1
+ torch==1.13.0
2
+ torchvision==0.14.0
3
+ matplotlib==3.7.5
4
+ moviepy==1.0.3
5
+ flow_vis
6
+ gradio
7
  imageio[ffmpeg]
8
  opencv-python
9
+ imutils==0.5.4
10
+ mediapy==1.2.2
11
  numpy
12
+ git+https://github.com/facebookresearch/co-tracker.git
 
 
videos/apple.mp4 CHANGED
Binary files a/videos/apple.mp4 and b/videos/apple.mp4 differ
 
videos/backpack.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
3
+ size 1208738
videos/cat.mp4 ADDED
Binary file (253 kB). View file
 
videos/pillow.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
3
+ size 1407147
videos/teddy.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:720503173c3b23b1d3d3fefa0e930558f944f0562e6a7b3c23810fc7046b39c7
3
+ size 1337504