yumyum2081 committed
Commit 858dae2 · verified · 1 Parent(s): 2ed6668

Update app.py

Files changed (1): app.py (+42 −94)
app.py CHANGED
@@ -1,12 +1,10 @@
 import gradio as gr
-
 import os
-os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
 import torch
 import numpy as np
 import cv2
 import matplotlib.pyplot as plt
-from PIL import Image, ImageFilter
+from PIL import Image
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
@@ -25,15 +23,12 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_path):
     trackings_input_label.value.append(0)
     print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
 
-    # Open the image and get its dimensions
     transparent_background = Image.open(first_frame_path).convert('RGBA')
     w, h = transparent_background.size
 
-    # Define the circle radius as a fraction of the smaller dimension
-    fraction = 0.02 # You can adjust this value as needed
+    fraction = 0.02
     radius = int(fraction * min(w, h))
 
-    # Create a transparent layer to draw on
     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
 
     for index, track in enumerate(tracking_points.value):
@@ -42,21 +37,15 @@ def get_point(point_type, tracking_points, trackings_input_label, first_frame_path):
         else:
             cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
 
-    # Convert the transparent layer back to an image
     transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
 
     return tracking_points, trackings_input_label, selected_point_map
-
-# use bfloat16 for the entire notebook
-torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-def show_mask(mask, ax, random_color=False, borders = True):
+# Remove all CUDA-specific configurations
+torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
+
+def show_mask(mask, ax, random_color=False, borders=True):
     if random_color:
         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
     else:
@@ -65,9 +54,7 @@ def show_mask(mask, ax, random_color=False, borders = True):
     mask = mask.astype(np.uint8)
     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
     if borders:
-        import cv2
         contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-        # Try to smooth contours
         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
     ax.imshow(mask_image)
@@ -84,94 +71,66 @@ def show_box(box, ax):
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
 
 def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
-    combined_images = []  # List to store filenames of images with masks overlaid
-    mask_images = []  # List to store filenames of separate mask images
+    combined_images = []
+    mask_images = []
 
     for i, (mask, score) in enumerate(zip(masks, scores)):
-        # ---- Original Image with Mask Overlaid ----
         plt.figure(figsize=(10, 10))
         plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders)  # Draw the mask with borders
-        """
-        if point_coords is not None:
-            assert input_labels is not None
-            show_points(point_coords, input_labels, plt.gca())
-        """
-        if box_coords is not None:
-            show_box(box_coords, plt.gca())
-        if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+        show_mask(mask, plt.gca(), borders=borders)
         plt.axis('off')
 
-        # Save the figure as a JPG file
         combined_filename = f"combined_image_{i+1}.jpg"
         plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
         combined_images.append(combined_filename)
+        plt.close()
 
-        plt.close()  # Close the figure to free up memory
-
-        # ---- Separate Mask Image (White Mask on Black Background) ----
-        # Create a black image
         mask_image = np.zeros_like(image, dtype=np.uint8)
-
-        # The mask is a binary array where the masked area is 1, else 0.
-        # Convert the mask to a white color in the mask_image
         mask_layer = (mask > 0).astype(np.uint8) * 255
-        for c in range(3):  # Assuming RGB, repeat mask for all channels
+        for c in range(3):
             mask_image[:, :, c] = mask_layer
 
-        # Save the mask image
         mask_filename = f"mask_image_{i+1}.png"
         Image.fromarray(mask_image).save(mask_filename)
         mask_images.append(mask_filename)
 
-        plt.close()  # Close the figure to free up memory
-
     return combined_images, mask_images
 
-
 def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
     image = Image.open(input_image)
     image = np.array(image.convert("RGB"))
 
-    if checkpoint == "tiny":
-        sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
-        model_cfg = "sam2_hiera_t.yaml"
-    elif checkpoint == "samll":
-        sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
-        model_cfg = "sam2_hiera_s.yaml"
-    elif checkpoint == "base-plus":
-        sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
-        model_cfg = "sam2_hiera_b+.yaml"
-    elif checkpoint == "large":
-        sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
-        model_cfg = "sam2_hiera_l.yaml"
+    checkpoint_map = {
+        "tiny": ("./checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
+        "small": ("./checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),
+        "base-plus": ("./checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
+        "large": ("./checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml")
+    }
 
-    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
+    sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
 
+    # Use CPU for both model and computations
+    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
     predictor = SAM2ImagePredictor(sam2_model)
-
     predictor.set_image(image)
 
    input_point = np.array(tracking_points.value)
    input_label = np.array(trackings_input_label.value)
 
-    print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
-
     masks, scores, logits = predictor.predict(
         point_coords=input_point,
         point_labels=input_label,
         multimask_output=False,
     )
+
     sorted_ind = np.argsort(scores)[::-1]
     masks = masks[sorted_ind]
     scores = scores[sorted_ind]
-    logits = logits[sorted_ind]
 
-    print(masks.shape)
-
-    results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
-    print(results)
+    results, mask_results = show_masks(image, masks, scores,
+                                        point_coords=input_point,
+                                        input_labels=input_label,
+                                        borders=True)
 
     return results[0], mask_results[0]
 
@@ -180,23 +139,12 @@ with gr.Blocks() as demo:
     tracking_points = gr.State([])
     trackings_input_label = gr.State([])
     with gr.Column():
-        gr.Markdown("# SAM2 Image Predictor")
-        gr.Markdown("This is a simple demo for image segmentation with SAM2.")
-        gr.Markdown("""Instructions:
-
-        1. Upload your image
-        2. With 'include' point type selected, Click on the object to mask
-        3. Switch to 'exclude' point type if you want to specify an area to avoid
-        4. Submit !
-        """)
+        gr.Markdown("# SAM2 Image Predictor (CPU Version)")
+        gr.Markdown("This version runs entirely on CPU")
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
-                points_map = gr.Image(
-                    label="points map",
-                    type="filepath",
-                    interactive=True
-                )
+                points_map = gr.Image(label="points map", type="filepath", interactive=True)
                 with gr.Row():
                     point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
                     clear_points_btn = gr.Button("Clear Points")
@@ -207,30 +155,30 @@
             output_result_mask = gr.Image()
 
     clear_points_btn.click(
-        fn = preprocess_image,
-        inputs = input_image,
-        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
+        fn=preprocess_image,
+        inputs=input_image,
+        outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
         queue=False
     )
 
     points_map.upload(
-        fn = preprocess_image,
-        inputs = [points_map],
-        outputs = [first_frame_path, tracking_points, trackings_input_label, input_image],
-        queue = False
+        fn=preprocess_image,
+        inputs=[points_map],
+        outputs=[first_frame_path, tracking_points, trackings_input_label, input_image],
+        queue=False
     )
 
     points_map.select(
-        fn = get_point,
-        inputs = [point_type, tracking_points, trackings_input_label, first_frame_path],
-        outputs = [tracking_points, trackings_input_label, points_map],
-        queue = False
+        fn=get_point,
+        inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
+        outputs=[tracking_points, trackings_input_label, points_map],
+        queue=False
     )
 
     submit_btn.click(
-        fn = sam_process,
-        inputs = [input_image, checkpoint, tracking_points, trackings_input_label],
-        outputs = [output_result, output_result_mask]
+        fn=sam_process,
+        inputs=[input_image, checkpoint, tracking_points, trackings_input_label],
+        outputs=[output_result, output_result_mask]
     )
 
 demo.launch(show_api=False, show_error=True)
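
A note on the new autocast line: on CPU, torch.autocast only accepts reduced-precision dtypes (torch.bfloat16, plus torch.float16 in recent PyTorch releases), so requesting dtype=torch.float32 typically makes PyTorch warn and disable autocast rather than cast anything; the line is effectively a no-op. If reduced precision is actually wanted on CPU, the documented pattern is a scoped context with a supported dtype. A minimal sketch with illustrative tensors:

import torch

x = torch.randn(4, 8)
w = torch.randn(8, 2)

# Inside a CPU autocast region, matmul-family ops run in bfloat16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ w

print(y.dtype)  # torch.bfloat16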
 
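
For reference, the CPU inference path this commit lands on can be exercised outside Gradio. A minimal sketch using the app's own imports and the "tiny" entry from checkpoint_map; the image path and click coordinates are illustrative assumptions:

import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Same config/checkpoint pairing as checkpoint_map["tiny"], loaded on CPU.
sam2_model = build_sam2("sam2_hiera_t.yaml", "./checkpoints/sam2_hiera_tiny.pt", device="cpu")
predictor = SAM2ImagePredictor(sam2_model)

image = np.array(Image.open("example.jpg").convert("RGB"))  # illustrative input
predictor.set_image(image)

masks, scores, logits = predictor.predict(
    point_coords=np.array([[100, 150]]),  # one (x, y) click; illustrative
    point_labels=np.array([1]),           # 1 = include, 0 = exclude
    multimask_output=False,
)
print(masks.shape, scores)  # (1, H, W) mask and its confidence score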