Update app.py
app.py
CHANGED
@@ -8,7 +8,6 @@ from io import BytesIO
 import torch
 import clip
 
-
 # Load the segmentation model
 sam_checkpoint = "sam_vit_h_4b8939.pth"
 model_type = "vit_h"
@@ -48,6 +47,7 @@ def find_similarity(base64_image, text_input):
     except Exception as e:
         return str(e)
 
+
 def segment_image(input_image, text_input):
     image_bytes = base64.b64decode(input_image)
     image = Image.open(BytesIO(image_bytes))
@@ -76,6 +76,10 @@ def segment_image(input_image, text_input):
         x, y, w, h = map(int, mask_dict['bbox'])
         cropped_region = segmented_region[y:y+h, x:x+w]
 
+        if not cropped_region.size:
+            # If the cropped region is empty, return the input image as is
+            return input_image
+
         # Convert to base64 image
         _, buffer = cv2.imencode(".png", cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
         segmented_image_base64 = base64.b64encode(buffer).decode()
@@ -95,10 +99,10 @@ def segment_image(input_image, text_input):
     # Return the segmented images in descending order of similarity
     return segmented_regions
 
+
 # Create Gradio components
 input_image = gr.Textbox(label="Base64 Image", lines=8)
 text_input = gr.Textbox(label="Text Input")  # Use Textbox with a label
-#output_images = gr.outputs.JSON()
 
 # Create a Gradio interface
 gr.Interface(fn=segment_image, inputs=[input_image, text_input], outputs="text").launch()
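The substantive change in this commit is the guard on cropped_region.size before the base64 encoding step. A minimal sketch of why that check works (using an assumed 100x100 dummy array in place of segmented_region): NumPy slicing with an out-of-range bbox silently yields an empty array whose size attribute is 0, so "not cropped_region.size" is true exactly for crops that would otherwise cause cv2.imencode to raise.

import numpy as np

# Dummy stand-in for segmented_region (assumed shape, for illustration only)
segmented_region = np.zeros((100, 100, 3), dtype=np.uint8)

# A bbox that falls entirely outside the image produces an empty slice, not an error
x, y, w, h = 150, 150, 20, 20
cropped_region = segmented_region[y:y+h, x:x+w]

print(cropped_region.size)  # 0 -> nothing was actually cropped
if not cropped_region.size:
    # Mirrors the early return added in this commit: skip encoding empty crops
    print("empty crop, returning the input image as is")

Note that the early return hands back the original base64 input string rather than a segmented region, so callers of segment_image should be prepared for either result.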