"""Gradio demo: zero-shot object detection on video frames with OWLv2.

The user uploads a video and a comma-separated list of text labels. Frames
(at most the first two seconds of the clip) are run through OWLv2, the
detections are drawn onto each frame, and the annotated video is returned.
"""
import os
import uuid
from typing import List

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2ForObjectDetection, Owlv2Processor

MODEL_ID = "google/owlv2-base-patch16-ensemble"
# Minimum score for a detection to be kept by post-processing.
DETECTION_THRESHOLD = 0.3
# Only this many seconds of the input video are processed.
MAX_DURATION_SECONDS = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Owlv2Processor.from_pretrained(MODEL_ID)
model = Owlv2ForObjectDetection.from_pretrained(MODEL_ID).to(device)

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


def calculate_end_frame_index(source_video_path, max_seconds=MAX_DURATION_SECONDS):
    """Return the number of frames of ``source_video_path`` to process.

    This is the smaller of the video's total frame count and
    ``fps * max_seconds`` (presumably to bound demo runtime).
    """
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return int(min(video_info.total_frames, video_info.fps * max_seconds))


def _parse_labels(labels: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty labels."""
    return [label.strip() for label in labels.split(",") if label.strip()]


def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections,
    labels: List[str],
) -> np.ndarray:
    """Draw masks, bounding boxes and text labels for ``detections`` on the image."""
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


def query(image: np.ndarray, texts: List[str], threshold: float = DETECTION_THRESHOLD):
    """Run OWLv2 on a single RGB image with the given text queries.

    Args:
        image: RGB image array of shape (height, width, 3).
        texts: Text queries; the returned ``"labels"`` are indices into this list.
        threshold: Minimum score for a detection to be kept.

    Returns:
        The list returned by ``processor.post_process_object_detection``
        (one dict per image with ``"scores"``, ``"labels"`` and ``"boxes"``),
        with boxes in pixel coordinates of ``image``.
    """
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # OWLv2 pads inputs to a square (bottom/right) before resizing, so the
    # predicted boxes are normalized to that padded square. Scaling by the
    # longer side maps them back onto the original frame. Newer transformers
    # releases may already do this internally; a square target size gives the
    # same result either way.
    height, width = image.shape[:2]
    side = max(height, width)
    target_sizes = torch.Tensor([[side, side]])
    return processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )


@spaces.GPU
def process_video(input_video, labels, progress=gr.Progress(track_tqdm=True)):
    """Detect ``labels`` in the first frames of ``input_video`` and save the result.

    Args:
        input_video: Path to the uploaded video.
        labels: Comma-separated text labels to detect.
        progress: Gradio progress tracker. It is not referenced directly;
            ``track_tqdm=True`` mirrors the tqdm loop below in the UI.

    Returns:
        Path of the annotated ``.mp4`` written to the current directory.

    Raises:
        gr.Error: If no video was given or no non-empty label was entered.
    """
    if not input_video:
        raise gr.Error("Please upload a video.")
    label_list = _parse_labels(labels or "")
    if not label_list:
        raise gr.Error("Please enter at least one label.")

    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(source_path=input_video, end=total)

    result_file_path = os.path.join("./", f"{uuid.uuid4()}.mp4")
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        # Iterate the generator itself so a video that yields fewer frames
        # than reported ends the loop cleanly instead of raising StopIteration.
        for frame in tqdm(frame_generator, total=total, desc="Processing video.."):
            # supervision decodes with OpenCV (BGR); the OWLv2 processor expects RGB.
            rgb_frame = np.ascontiguousarray(frame[:, :, ::-1])
            result = query(rgb_frame, label_list)[0]
            detections = sv.Detections.from_transformers(result)
            frame_labels = [label_list[index] for index in result["labels"].tolist()]
            # Annotate the original BGR frame, which is what VideoSink writes.
            annotated = annotate_image(
                input_image=frame,
                detections=detections,
                labels=frame_labels,
            )
            sink.write_frame(annotated)
    return result_file_path


with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Object Tracking with OWLv2 🦉")
    gr.Markdown("This is a demo for zero-shot object tracking using [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.")
    gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. 👇")
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video = gr.Video(label='Input Video')
            output_video = gr.Video(label='Output Video')
        with gr.Row():
            candidate_labels = gr.Textbox(
                label='Labels',
                placeholder='Labels separated by a comma',
            )
            submit = gr.Button()
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "dog,cat"]],
            inputs=[input_video, candidate_labels],
            outputs=output_video,
        )

    submit.click(
        fn=process_video,
        inputs=[input_video, candidate_labels],
        outputs=output_video,
    )

demo.launch(debug=False, show_error=True)