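# Spaces: Runtime error
# (The line above appears to be a hosting-page status banner captured together
# with the source; it is not part of the program.)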
import os | |
from typing import Dict, List | |
import cv2 | |
import numpy as np | |
import streamlit as st | |
import torch | |
import wget | |
from PIL import Image | |
from streamlit_drawable_canvas import st_canvas | |
from isegm.inference import clicker as ck | |
from isegm.inference import utils | |
from isegm.inference.predictors import BasePredictor, get_predictor | |
###################################
# Global scope objects.
###################################
# Base URL the checkpoints are downloaded from, and the checkpoint file name
# for each selectable method.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}

# Canvas stroke colors; a click whose stroke equals POS_COLOR is treated as positive.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"

# Size of the drawing canvas in pixels.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600

# Constant offsets added to the canvas "left"/"top" click coordinates before
# they are scaled to image coordinates.
# NOTE(review): presumably compensates for the canvas point-marker offset - confirm.
ERR_X = 5.5
ERR_Y = 1.0

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mutable script state: the click container, the predictor (assigned after the
# model is loaded) and the uploaded image (assigned once a file is provided).
clicker = ck.Clicker()
predictor = None
image = None
###################################
# Functions.
###################################
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load a segmentation checkpoint and wrap it in a predictor.

    Args:
        model_path: Path of the checkpoint file to load.
        device: Device the network is loaded onto and the predictor runs on.

    Returns:
        The predictor built by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    network = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(network, device=device, brs_mode="NoBRS")
def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Translate canvas click objects into image-space clicks on ``clicker``.

    Each canvas object is shifted by ``ERR_X`` / ``ERR_Y``, scaled from canvas
    size to image size, clamped to the valid pixel range, and added to
    ``clicker`` as a ``ck.Click`` with ``(row, column)`` coordinates.

    Args:
        clicker: Click container that receives one click per canvas object.
        clicks: Canvas objects; each must provide the ``"left"``, ``"top"``
            and ``"stroke"`` keys.
        image_width: Width of the target image in pixels.
        image_height: Height of the target image in pixels.

    Raises:
        KeyError: If a canvas object lacks one of the required keys.
    """
    ratio_w = image_width / CANVAS_WIDTH
    ratio_h = image_height / CANVAS_HEIGHT

    # Pixels are indexed 0..size-1, so the largest valid coordinate is size - 1
    # (clamping to `size` itself would point one pixel outside the image).
    max_x = max(image_width - 1, 0)
    max_y = max(image_height - 1, 0)

    for canvas_obj in clicks:
        x = (canvas_obj["left"] + ERR_X) * ratio_w
        y = (canvas_obj["top"] + ERR_Y) * ratio_h
        x = min(max_x, max(0, x))
        y = min(max_y, max(0, y))

        # The stroke color of the drawn point encodes the click polarity.
        is_positive = canvas_obj["stroke"] == POS_COLOR
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))
def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Run the module-level predictor on ``image`` and return a binary mask.

    Uses the module-level ``predictor`` and ``clicker`` objects.

    Args:
        image: Input image; converted to an array and set as the predictor input.
        mask: Previous mask passed to the predictor as ``prev_mask``.
        threshold: Values of the resized prediction strictly greater than this
            become 1.0, all others 0.

    Returns:
        Array resized to the canvas size (``CANVAS_HEIGHT`` rows by
        ``CANVAS_WIDTH`` columns) containing only 1.0 and 0 values.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
        # cv2.resize expects dsize as (width, height), not (height, width).
        pred = cv2.resize(
            pred,
            dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
            interpolation=cv2.INTER_CUBIC,
        )
        # Binarize after resizing so interpolated values are thresholded too.
        pred = np.where(pred > threshold, 1.0, 0)
    return pred
###################################
# Sidebar GUI
###################################
# Sidebar widgets: segmentation method, binarization threshold, polarity of the
# next click, and the image to segment.
model = st.sidebar.selectbox(label="Select a Method:", options=tuple(MODELS))
threshold = st.sidebar.slider(
    label="Threshold: ", min_value=0.0, max_value=1.0, value=0.5
)
marking_type = st.sidebar.radio(
    label="Click Type:", options=("Positive", "Negative")
)
image_path = st.sidebar.file_uploader(
    label="Background Image:", type=["png", "jpg", "jpeg"]
)

# Decode the upload as an RGB image; `image` keeps its previous value when
# nothing has been uploaded.
if image_path:
    image = Image.open(image_path)
    image = image.convert("RGB")
###################################
# Preparation
###################################
# Model: download the selected method's checkpoint unless a file with that
# name already exists in the working directory.
with st.spinner("Wait for downloading a model..."):
    checkpoint_missing = not os.path.exists(MODELS[model])
    if checkpoint_missing:
        _ = wget.download("/".join((URL_PREFIX, MODELS[model])))

# Predictor: build it from the checkpoint on the selected device.
# NOTE(review): this is top-level code, so the model is loaded each time the
# script body executes - consider caching if that proves costly.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(model_path=MODELS[model], device=device)
###################################
# GUI
###################################
# Drawing canvas in point mode; the uploaded image (if any) is the background
# and each new point is drawn in the color of the selected click type.
st.title("Canvas:")
canvas_result = st_canvas(
    key="canvas",
    drawing_mode="point",
    point_display_radius=3,
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
    background_color="#eee",
    background_image=image,
    stroke_width=3,
    # Anything other than "Positive" is drawn with the negative color.
    stroke_color={"Positive": POS_COLOR}.get(marking_type, NEG_COLOR),
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    update_streamlit=True,
)
###################################
# Prediction
###################################
# Predict only when the canvas reports at least one object and an image has
# been uploaded.
st.title("Prediction:")
if canvas_result.json_data and canvas_result.json_data["objects"] and image:
    image_width, image_height = image.width, image.height
    feed_clicks(
        clicker,
        canvas_result.json_data["objects"],
        image_width,
        image_height,
    )

    # Start from an all-zero previous mask at the image resolution.
    mask = torch.zeros(1, 1, image_height, image_width, device=device)
    pred = predict(image, mask, threshold)

    # Display the thresholded prediction.
    st.image(pred, caption="")