import gradio as gr
import numpy as np
import cv2
from PIL import Image
import spaces
from sam2.sam2_image_predictor import SAM2ImagePredictor
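# Load the SAM2 predictor once at import time, on CPU. Note the "Checkpoint"
# dropdown below is passed to the handlers but not yet used to switch models.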
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small", device="cpu")
def preprocess_image(image):
    # Reset the stored click points, labels, and JSON holder whenever a new
    # image is loaded or "Clear Points" is pressed. Plain values are returned
    # for the gr.State outputs rather than nested gr.State objects.
    return image, [], [], image, None
def get_point(
    point_type,
    tracking_points,
    trackings_input_label,
    first_frame_path,
    evt: gr.SelectData,
):
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    # evt.index is the (x, y) pixel position of the click on the points map.
    tracking_points.append(evt.index)
    print(f"TRACKING POINTS: {tracking_points}")
    # SAM-style prompt labels: 1 marks an "include" (foreground) point,
    # 0 an "exclude" (background) point.
    if point_type == "include":
        trackings_input_label.append(1)
    elif point_type == "exclude":
        trackings_input_label.append(0)
    print(f"TRACKING INPUT LABELS: {trackings_input_label}")
    # Open the image and get its dimensions
    transparent_background = Image.open(first_frame_path).convert("RGBA")
    w, h = transparent_background.size
    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.02  # You can adjust this value as needed
    radius = int(fraction * min(w, h))
    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    # Mirror the points into the JSON holder as {"label": [x, y]} entries and
    # draw each one: green for include (label 1), red for exclude (label 0).
    holder_list = []
    for index, track in enumerate(tracking_points):
        holder_list.append({str(trackings_input_label[index]): track})
        if trackings_input_label[index] == 1:
            cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )
    return tracking_points, trackings_input_label, selected_point_map, holder_list
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
    # Placeholder: the Submit button does not run segmentation yet.
    return None, None
@spaces.GPU
def sam_process2(input_image, checkpoint, holder):
    # Rebuild the point and label lists from the JSON holder. The labels were
    # stored as string dict keys in get_point, so cast them back to ints.
    tracking_points, trackings_input_label = [], []
    for i in holder:
        trackings_input_label.append(int(list(i.keys())[0]))
        tracking_points.append(list(i.values())[0])
    # Placeholder: actual SAM2 inference is not wired up yet.
    return None, None
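# A minimal sketch of what a working inference handler could look like, based
# on the point-prompt API that SAM2ImagePredictor exposes (set_image + predict).
# Illustrative only: it is not wired to any button below, and since the
# predictor above was built with device="cpu", it would run on CPU even inside
# this @spaces.GPU context unless the model is moved to the GPU first.
@spaces.GPU
def sam_inference_sketch(image_path, tracking_points, trackings_input_label):
    image = np.array(Image.open(image_path).convert("RGB"))
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array(tracking_points, dtype=np.float32),
        point_labels=np.array(trackings_input_label, dtype=np.int64),
        multimask_output=False,  # one mask for the whole prompt set
    )
    mask = (masks[0] > 0).astype(np.uint8) * 255  # H x W binary mask
    overlay = image.copy()
    overlay[mask == 0] = overlay[mask == 0] // 3  # dim everything outside the mask
    # Could be wired up, e.g.:
    # submit_btn.click(fn=sam_inference_sketch,
    #                  inputs=[input_image, tracking_points, trackings_input_label],
    #                  outputs=[output_result, output_result_mask])
    return Image.fromarray(overlay), Image.fromarray(mask)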
with gr.Blocks() as demo:
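    # Per-session state: the clicked points, their 1/0 include-exclude labels,
    # and the path of the uploaded image used as the drawing background.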
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    with gr.Column():
        gr.Markdown("# SAM2 Image Segmenter")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="input image",
                    interactive=False,
                    type="filepath",
                    visible=False,
                )
                points_map = gr.Image(
                    label="points map", type="filepath", interactive=True
                )
                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                    )
                    clear_points_btn = gr.Button("Clear Points")
                checkpoint = gr.Dropdown(
                    label="Checkpoint",
                    choices=["tiny", "small", "base-plus", "large"],
                    value="tiny",
                )
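                # JSON mirror of the clicked points ({"label": [x, y]} dicts)
                # produced by get_point and consumed by sam_process2.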
                holder = gr.Json()
                submit_btn = gr.Button("Submit")
                sub2 = gr.Button("sub2")
            with gr.Column():
                output_result = gr.Image()
                output_result_mask = gr.Image()
    clear_points_btn.click(
        fn=preprocess_image,
        inputs=input_image,
        outputs=[
            first_frame_path,
            tracking_points,
            trackings_input_label,
            points_map,
            holder,
        ],
        queue=False,
    )
    points_map.upload(
        fn=preprocess_image,
        inputs=[points_map],
        # preprocess_image returns five values, so holder must be listed here
        # too, or Gradio will complain about a return/outputs count mismatch.
        outputs=[
            first_frame_path,
            tracking_points,
            trackings_input_label,
            input_image,
            holder,
        ],
        queue=False,
    )
    points_map.select(
        fn=get_point,
        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
        outputs=[tracking_points, trackings_input_label, points_map, holder],
        queue=False,
    )
    submit_btn.click(
        fn=sam_process,
        inputs=[input_image, checkpoint, tracking_points, trackings_input_label],
        outputs=[output_result, output_result_mask],
    )
    sub2.click(
        fn=sam_process2,
        inputs=[input_image, checkpoint, holder],
        outputs=[output_result, output_result_mask],
    )
demo.launch(debug=True, show_error=True)