import gradio as gr
import numpy as np
import cv2
from PIL import Image
import spaces
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Loaded once at startup.
# NOTE(review): `predictor` is not used by any handler below — presumably it
# will be wired into sam_process/sam_process2 once they are implemented.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-small", device="cpu"
)


def preprocess_image(image):
    """Reset the UI for a newly uploaded (or cleared) image.

    Returns values for (first_frame_path, tracking_points,
    trackings_input_label, points_map/input_image, holder).

    NOTE: fresh ``gr.State([])`` objects are returned as the *values* of the
    point/label states. Downstream, ``get_point`` receives these State
    wrappers and appends to their ``.value`` in place — a deliberate hack so
    clicks accumulate without re-serializing the lists through Gradio.
    """
    return image, gr.State([]), gr.State([]), image, None


def get_point(
    point_type,
    tracking_points,
    trackings_input_label,
    first_frame_path,
    evt: gr.SelectData,
):
    """Record a clicked point and redraw the annotated point map.

    Args:
        point_type: "include" (positive, green) or "exclude" (negative, red).
        tracking_points: gr.State wrapper whose .value is the list of [x, y]
            click coordinates (mutated in place — see preprocess_image).
        trackings_input_label: gr.State wrapper whose .value is the parallel
            list of 1/0 labels (mutated in place).
        first_frame_path: filesystem path of the displayed image.
        evt: Gradio select event carrying the click coordinates in evt.index.

    Returns:
        (tracking_points, trackings_input_label, composited point map image,
        holder_list) where holder_list is a JSON-friendly
        [{"label": [x, y]}, ...] snapshot of all points so far.
    """
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    tracking_points.value.append(evt.index)
    print(f"TRACKING POINT: {tracking_points.value}")
    if point_type == "include":
        trackings_input_label.value.append(1)
    elif point_type == "exclude":
        trackings_input_label.value.append(0)
    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")

    # Open the image and get its dimensions.
    transparent_background = Image.open(first_frame_path).convert("RGBA")
    w, h = transparent_background.size

    # Marker radius scales with the image so dots stay visible at any size.
    fraction = 0.02
    radius = int(fraction * min(w, h))

    # Draw every recorded point onto a transparent overlay.
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    holder_list = []
    for index, track in enumerate(tracking_points.value):
        # Keys are stringified labels so the structure survives a JSON
        # round-trip through the gr.Json component (see sam_process2).
        holder_list.append({str(trackings_input_label.value[index]): track})
        if trackings_input_label.value[index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    # Composite the overlay back onto the image for display.
    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )
    return tracking_points, trackings_input_label, selected_point_map, holder_list


def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
    """Segment using the in-memory point/label states. Not yet implemented."""
    return None, None


@spaces.GPU
def sam_process2(input_image, checkpoint, holder):
    """Segment using the JSON `holder` snapshot. Not yet implemented.

    Reconstructs the parallel points/labels lists from the
    [{"label": [x, y]}, ...] structure built by get_point. Labels were
    stringified for JSON, so they are converted back to int here.
    """
    tracking_points, trackings_input_label = [], []
    for item in holder:
        label, point = next(iter(item.items()))
        trackings_input_label.append(int(label))
        tracking_points.append(point)
    return None, None


with gr.Blocks() as demo:
    # Per-session stores. preprocess_image replaces the list states with
    # gr.State([]) wrappers whose .value the handlers mutate in place.
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])

    with gr.Column():
        gr.Markdown("# SAM2 Image Segmenter")
        with gr.Row():
            with gr.Column():
                # Hidden canonical copy of the upload; points_map is the
                # interactive view the user clicks on.
                input_image = gr.Image(
                    label="input image",
                    interactive=False,
                    type="filepath",
                    visible=False,
                )
                points_map = gr.Image(
                    label="points map", type="filepath", interactive=True
                )
                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                    )
                    clear_points_btn = gr.Button("Clear Points")
                checkpoint = gr.Dropdown(
                    label="Checkpoint",
                    choices=["tiny", "small", "base-plus", "large"],
                    value="tiny",
                )
                holder = gr.Json()
                submit_btn = gr.Button("Submit")
                sub2 = gr.Button("sub2")
            with gr.Column():
                output_result = gr.Image()
                output_result_mask = gr.Image()

    clear_points_btn.click(
        fn=preprocess_image,
        inputs=input_image,
        outputs=[
            first_frame_path,
            tracking_points,
            trackings_input_label,
            points_map,
            holder,
        ],
        queue=False,
    )

    # BUG FIX: preprocess_image returns 5 values; `holder` was missing from
    # these outputs, so Gradio rejected the extra return value on upload.
    points_map.upload(
        fn=preprocess_image,
        inputs=[points_map],
        outputs=[
            first_frame_path,
            tracking_points,
            trackings_input_label,
            input_image,
            holder,
        ],
        queue=False,
    )

    points_map.select(
        fn=get_point,
        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
        outputs=[tracking_points, trackings_input_label, points_map, holder],
        queue=False,
    )

    submit_btn.click(
        fn=sam_process,
        inputs=[input_image, checkpoint, tracking_points, trackings_input_label],
        outputs=[output_result, output_result_mask],
    )

    sub2.click(
        fn=sam_process2,
        inputs=[input_image, checkpoint, holder],
        outputs=[output_result, output_result_mask],
    )

demo.launch(debug=True, show_error=True)