ktllc commited on
Commit
3882946
·
1 Parent(s): 1f09605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -8,7 +8,6 @@ from io import BytesIO
8
  import torch
9
  import clip
10
 
11
-
12
  # Load the segmentation model
13
  sam_checkpoint = "sam_vit_h_4b8939.pth"
14
  model_type = "vit_h"
@@ -48,6 +47,7 @@ def find_similarity(base64_image, text_input):
48
  except Exception as e:
49
  return str(e)
50
 
 
51
  def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
@@ -76,6 +76,10 @@ def segment_image(input_image, text_input):
76
  x, y, w, h = map(int, mask_dict['bbox'])
77
  cropped_region = segmented_region[y:y+h, x:x+w]
78
 
 
 
 
 
79
  # Convert to base64 image
80
  _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
81
  segmented_image_base64 = base64.b64encode(buffer).decode()
@@ -95,10 +99,10 @@ def segment_image(input_image, text_input):
95
  # Return the segmented images in descending order of similarity
96
  return segmented_regions
97
 
 
98
  # Create Gradio components
99
  input_image = gr.Textbox(label="Base64 Image", lines=8)
100
  text_input = gr.Textbox(label="Text Input") # Use Textbox with a label
101
- #output_images = gr.outputs.JSON()
102
 
103
  # Create a Gradio interface
104
  gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs="text").launch()
 
8
  import torch
9
  import clip
10
 
 
11
  # Load the segmentation model
12
  sam_checkpoint = "sam_vit_h_4b8939.pth"
13
  model_type = "vit_h"
 
47
  except Exception as e:
48
  return str(e)
49
 
50
+
51
  def segment_image(input_image, text_input):
52
  image_bytes = base64.b64decode(input_image)
53
  image = Image.open(BytesIO(image_bytes))
 
76
  x, y, w, h = map(int, mask_dict['bbox'])
77
  cropped_region = segmented_region[y:y+h, x:x+w]
78
 
79
+ if not cropped_region.size:
80
+ # If the cropped region is empty, return the input image as is
81
+ return input_image
82
+
83
  # Convert to base64 image
84
  _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
85
  segmented_image_base64 = base64.b64encode(buffer).decode()
 
99
  # Return the segmented images in descending order of similarity
100
  return segmented_regions
101
 
102
+
103
  # Create Gradio components
104
  input_image = gr.Textbox(label="Base64 Image", lines=8)
105
  text_input = gr.Textbox(label="Text Input") # Use Textbox with a label
 
106
 
107
  # Create a Gradio interface
108
  gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs="text").launch()