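"""Gradio demo for interactive SAM2 point-based image segmentation.

Users click include/exclude points on an uploaded image; the points are
drawn as green/red dots, mirrored into a JSON holder, and collected for
the SAM2 predictor.
"""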
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import spaces
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Load the predictor once at startup on CPU; the @spaces.GPU-decorated
# handlers below are where GPU work runs on ZeroGPU hardware.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small", device="cpu")
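# Hedged sketch (not wired into the UI below): the "checkpoint" dropdown
# offers tiny/small/base-plus/large, which correspond to the published
# facebook/sam2-hiera-* checkpoints on the Hub. A loader along these lines
# could honor that selection; `load_predictor` is an illustrative name.
def load_predictor(checkpoint: str = "small", device: str = "cpu") -> SAM2ImagePredictor:
    return SAM2ImagePredictor.from_pretrained(f"facebook/sam2-hiera-{checkpoint}", device=device)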
def preprocess_image(image):
    # Reset both point states (wrapped in gr.State so handlers can mutate
    # .value in place), mirror the image, and clear the JSON holder.
    return image, gr.State([]), gr.State([]), image, None
def get_point(
point_type,
tracking_points,
trackings_input_label,
first_frame_path,
evt: gr.SelectData,
):
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
tracking_points.value.append(evt.index)
print(f"TRACKING POINT: {tracking_points.value}")
if point_type == "include":
trackings_input_label.value.append(1)
elif point_type == "exclude":
trackings_input_label.value.append(0)
print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
# Open the image and get its dimensions
transparent_background = Image.open(first_frame_path).convert("RGBA")
w, h = transparent_background.size
# Define the circle radius as a fraction of the smaller dimension
fraction = 0.02 # You can adjust this value as needed
radius = int(fraction * min(w, h))
# Create a transparent layer to draw on
transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    # Collect each click as {"label": [x, y]} for the JSON holder and draw it:
    # green = include (label 1), red = exclude (label 0).
    holder_list = []
    for index, track in enumerate(tracking_points.value):
        holder_list.append({str(trackings_input_label.value[index]): track})
        if trackings_input_label.value[index] == 1:
            cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
# Convert the transparent layer back to an image
transparent_layer = Image.fromarray(transparent_layer, "RGBA")
selected_point_map = Image.alpha_composite(
transparent_background, transparent_layer
)
return tracking_points, trackings_input_label, selected_point_map, holder_list
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
    # Placeholder: segmentation from the point states is not wired up yet.
    return None, None
@spaces.GPU
def sam_process2(input_image, checkpoint, holder):
    # Placeholder: rebuilds points/labels from the JSON holder. The dict keys
    # are strings ("0"/"1"), so cast them back to int before use.
    tracking_points, trackings_input_label = [], []
    for i in holder:
        trackings_input_label.append(int(list(i.keys())[0]))
        tracking_points.append(list(i.values())[0])
    return None, None
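# Hedged sketch of the segmentation step the two stubs above are meant to
# perform, using SAM2ImagePredictor's public set_image/predict API.
# `run_sam2` is an illustrative helper name, not part of the original app.
def run_sam2(image_path, points, labels):
    image = np.array(Image.open(image_path).convert("RGB"))
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array(points, dtype=np.float32),
        point_labels=np.array(labels, dtype=np.int32),
        multimask_output=False,
    )
    # masks has shape (1, H, W); return the image and the mask as uint8.
    return image, (masks[0] * 255).astype(np.uint8)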
with gr.Blocks() as demo:
first_frame_path = gr.State()
tracking_points = gr.State([])
trackings_input_label = gr.State([])
with gr.Column():
gr.Markdown("# SAM2 Image Segmenter")
with gr.Row():
with gr.Column():
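            # Hidden mirror of the uploaded image path: points_map shows the
            # dot overlay while this keeps a clean copy for the predictor.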
input_image = gr.Image(
label="input image",
interactive=False,
type="filepath",
visible=False,
)
points_map = gr.Image(
label="points map", type="filepath", interactive=True
)
with gr.Row():
point_type = gr.Radio(
label="point type",
choices=["include", "exclude"],
value="include",
)
clear_points_btn = gr.Button("Clear Points")
checkpoint = gr.Dropdown(
label="Checkpoint",
choices=["tiny", "small", "base-plus", "large"],
value="tiny",
)
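            # JSON mirror of the clicked points, e.g. [{"1": [x, y]}, {"0": [x, y]}]
            # (key = include/exclude label as a string, value = pixel coords).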
holder = gr.Json()
submit_btn = gr.Button("Submit")
            sub2 = gr.Button("Submit (JSON points)")
with gr.Column():
output_result = gr.Image()
output_result_mask = gr.Image()
clear_points_btn.click(
fn=preprocess_image,
inputs=input_image,
outputs=[
first_frame_path,
tracking_points,
trackings_input_label,
points_map,
holder,
],
queue=False,
)
points_map.upload(
fn=preprocess_image,
inputs=[points_map],
        outputs=[first_frame_path, tracking_points, trackings_input_label, input_image, holder],  # five outputs to match preprocess_image's five return values
queue=False,
)
points_map.select(
fn=get_point,
inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
outputs=[tracking_points, trackings_input_label, points_map, holder],
queue=False,
)
submit_btn.click(
fn=sam_process,
inputs=[input_image, checkpoint, tracking_points, trackings_input_label],
outputs=[output_result, output_result_mask],
)
sub2.click(
fn=sam_process2,
inputs=[input_image, checkpoint, holder],
outputs=[output_result, output_result_mask],
)
demo.launch(debug=True, show_error=True)