fffiloni committed
Commit eee8ee3 · verified · 1 Parent(s): 95d366b

Update app.py

Files changed (1):
  1. app.py +49 -8
app.py CHANGED
@@ -2,12 +2,40 @@ import gradio as gr
 import os
 os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
 import torch
-import numpy as np
+import numpy as np
+import cv2
 import matplotlib.pyplot as plt
-from PIL import Image
+from PIL import Image, ImageFilter
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+def preprocess_image(image):
+    return image, gr.State, gr.State
+
+def get_point(tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
+    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+    tracking_points.value.append(evt.index)
+    print(f"TRACKING POINT: {tracking_points.value}")
+
+    trackings_input_label.value.append(1)
+    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
+    # for SAM2
+    # input_point = np.array(tracking_points.value)
+    # print(f"SAM2 INPUT POINT: {input_point}")
+    # input_label = np.array([1])
+
+    transparent_background = Image.open(first_frame_path).convert('RGBA')
+    w, h = transparent_background.size
+    transparent_layer = np.zeros((h, w, 4))
+    for track in tracking_points.value:
+        cv2.circle(transparent_layer, track, 5, (255, 0, 0, 255), -1)
+
+    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+    return tracking_points, trackings_input_label, selected_point_map
+
 # use bfloat16 for the entire notebook
 torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
@@ -71,7 +99,7 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
 
     return masks_store
 
-def sam_process(input_image):
+def sam_process(input_image, tracking_points, trackings_input_label):
     image = Image.open(input_image)
     image = np.array(image.convert("RGB"))
 
@@ -84,7 +112,7 @@ def sam_process(input_image):
 
     predictor.set_image(image)
 
-    input_point = np.array([[539, 384]])
+    input_point = np.array(tracking_points.value)
     input_label = np.array([1])
 
     print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
@@ -107,13 +135,26 @@ def sam_process(input_image):
     return results
 
 with gr.Blocks() as demo:
+    first_frame_path = gr.State()
+    tracking_points = gr.State([])
+    trackings_input_label = gr.State([])
     with gr.Column():
-        input_image = gr.Image(label="input image", type="filepath")
-        submit_btn = gr.Button("Submit")
-        output_result = gr.Gallery()
+        gr.Markdown("# SAM2 Image Predictor")
+        with gr.Row():
+            input_image = gr.Image(label="input image", interactive=True, type="filepath")
+            with gr.Column():
+                points_map = gr.Image(label="points map")
+                submit_btn = gr.Button("Submit")
+        output_result = gr.Gallery()
+
+        input_image.upload(preprocess_image, input_image, [first_frame_path, tracking_points, trackings_input_label])
+
+        input_image.select(get_point, [tracking_points, trackings_input_label, first_frame_path], [tracking_points, trackings_input_label, points_map])
+
+
     submit_btn.click(
         fn = sam_process,
-        inputs = [input_image],
+        inputs = [input_image, tracking_points, trackings_input_label],
         outputs = [output_result]
     )
 demo.launch()
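
Note on the point-selection flow introduced above: clicking a gr.Image fires a select event whose gr.SelectData.index carries the clicked (x, y) pixel, and this commit accumulates those clicks in gr.State lists before drawing them onto an RGBA overlay with cv2.circle. The snippet below is a minimal, self-contained sketch of that same pattern, not code from this commit; the component and function names are illustrative, and it passes the state value through the handler instead of mutating it via .value.

import gradio as gr
import numpy as np
import cv2
from PIL import Image

def add_point(points, image_path, evt: gr.SelectData):
    # evt.index is the (x, y) pixel position of the click on the displayed image
    points = points + [evt.index]

    # draw every collected point as a filled red dot on a transparent layer
    base = Image.open(image_path).convert("RGBA")
    w, h = base.size
    overlay = np.zeros((h, w, 4), dtype=np.uint8)
    for x, y in points:
        cv2.circle(overlay, (int(x), int(y)), 5, (255, 0, 0, 255), -1)

    preview = Image.alpha_composite(base, Image.fromarray(overlay))
    return points, preview

with gr.Blocks() as demo:
    points = gr.State([])  # per-session list of clicked points
    input_image = gr.Image(label="input image", type="filepath")
    points_map = gr.Image(label="points map")
    input_image.select(add_point, [points, input_image], [points, points_map])

demo.launch()

Returning the updated list from the handler keeps the accumulated clicks per session, which is the usual Gradio idiom for this kind of state.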
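
On the prediction side, sam_process now builds input_point from the collected clicks rather than the hard-coded [[539, 384]]. The sketch below shows how such point prompts typically reach the SAM 2 image predictor; it is an illustration rather than the Space's exact setup, and the checkpoint/config names and the one-label-per-point convention are assumptions not taken from this diff.

import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# assumed checkpoint and config names; adjust to whatever the Space actually ships
sam2_model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

image = np.array(Image.open("example.jpg").convert("RGB"))
predictor.set_image(image)

# one (x, y) row per collected click, one positive (foreground) label per point
input_point = np.array([[539, 384], [220, 100]])
input_label = np.ones(len(input_point), dtype=np.int64)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

With multimask_output=True, predict returns several candidate masks with their scores, which is what a show_masks-style gallery would iterate over.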