File size: 3,233 Bytes
2c0d085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import cv2
from ultralytics import YOLO

# Load the YOLO model once at import time; detect_in_image and
# detect_in_video both run inference through this single shared instance.
model = YOLO('last.torchscript')  # Swap in another exported weights file (e.g. 'best.onnx' or 'best.torchscript') if converted

# Function for image inference
def detect_in_image(image):
    """Run the detector on a single image and return it with boxes drawn.

    Args:
        image: Image as a NumPy array in RGB channel order (what the
            ``gr.Image(type="numpy")`` input delivers), or ``None`` when the
            user pressed "Detect" without uploading anything.

    Returns:
        The annotated image as an RGB NumPy array, or ``None`` when no image
        was supplied (which leaves the output component empty instead of
        raising inside OpenCV).
    """
    if image is None:
        return None

    # The prediction/plot path works in OpenCV's BGR order, so swap the
    # channels on the way in and swap them back on the way out.
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    results = model.predict(source=bgr_image, save=False, save_txt=False)
    annotated_bgr = results[0].plot()  # Annotated frame with bounding boxes
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

# Function for video inference
def detect_in_video(video, frame_skip=6):
    """Annotate a video frame by frame, streaming results and saving a copy.

    Args:
        video: Path of the video file to process, or ``None`` when nothing
            was uploaded.
        frame_skip: Run detection on every ``frame_skip``-th frame only
            (1 processes every frame). Must be at least 1.

    Returns:
        A ``(frame_generator, output_path)`` tuple. The generator yields each
        annotated frame as an RGB NumPy array and, as a side effect, writes
        the same frames to ``output_path``. The capture and the writer are
        opened lazily on the first iteration and are always released, even
        if the consumer stops iterating early or inference raises. A missing
        or unreadable video produces a generator that yields nothing.

    Raises:
        ValueError: If ``frame_skip`` is smaller than 1.
    """
    if frame_skip < 1:
        raise ValueError("frame_skip must be at least 1")

    output_path = "output_video.mp4"

    # Frame generator for live streaming
    def frame_generator():
        if video is None:
            return

        cap = cv2.VideoCapture(video)
        out = None
        try:
            if not cap.isOpened():
                return

            # Get video properties
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                # Some containers report 0 (or NaN) FPS; fall back to a
                # common default so the writer still gets a usable rate.
                fps = 30.0

            # Only every frame_skip-th frame is written, so the output rate
            # is scaled down to keep the saved video's duration equal to the
            # source instead of playing frame_skip times too fast.
            out = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc(*'mp4v'),
                fps / frame_skip,
                (width, height),
            )

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_skip == 0:  # Process every nth frame
                    results = model.predict(source=frame, save=False, save_txt=False)
                    annotated_frame = results[0].plot()  # Annotated frame with bounding boxes

                    # Save annotated frame to output video
                    out.write(annotated_frame)

                    # Convert frame to RGB for display
                    yield cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

                frame_count += 1
        finally:
            # Runs on normal exhaustion, on errors, and when the generator is
            # closed early (e.g. the consumer stops iterating).
            cap.release()
            if out is not None:
                out.release()

    return frame_generator(), output_path

# Build the Gradio interface
with gr.Blocks(css=".header {font-size: 30px; color: #4CAF50; font-weight: bold; text-align: center;} .image-output {max-width: 400px; margin: auto;}") as app:
    gr.Markdown("<h1 class='header'>🐾 Rat Paw Detection App 🐾</h1>")

    # Image detection tab
    with gr.Tab("Image Detection"):
        image_input = gr.Image(label="Upload an Image", type="numpy")
        # The stylesheet above targets the CSS *class* `.image-output`, so the
        # component must carry it via elem_classes; elem_id would set an HTML
        # id instead, which that rule never matches (and ids must be unique).
        image_output = gr.Image(label="Annotated Image", type="numpy", elem_classes="image-output")
        image_button = gr.Button("Detect", variant="primary")
        image_button.click(detect_in_image, inputs=image_input, outputs=image_output)

    # Video detection tab
    with gr.Tab("Video Detection"):
        video_input = gr.Video(label="Upload a Video")
        video_display = gr.Image(label="Live Detection", elem_classes="image-output")

        def video_handler(video):
            """Stream annotated frames of the uploaded video to the display.

            Yields one update per processed frame so the image component
            refreshes while detection is still running.
            """
            # Only the frame stream is shown here; the saved file's path is
            # not used by this handler.
            frame_gen, _ = detect_in_video(video)
            for frame in frame_gen:
                yield {video_display: frame}  # Live update for each processed frame

        video_button = gr.Button("Detect", variant="primary")
        video_button.click(fn=video_handler, inputs=video_input, outputs=[video_display])

# Launch the app. This runs whenever the module is executed or imported,
# since there is no `if __name__ == "__main__":` guard around it.
app.launch()