# curt-park — last commit 9dac542: "Rephrase texts"
import os
from typing import Dict, List
import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor
###################################
# Global scope objects.
###################################
# Base URL of the repository hosting the model checkpoints.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
# Method name shown in the sidebar -> checkpoint file name.
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
# Stroke colours for positive / negative clicks drawn on the canvas.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"
# Drawing-canvas size in pixels.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600
# Offsets (canvas pixels) added to a point's left/top before it is rescaled to
# image coordinates; presumably compensates for where the canvas reports a
# drawn point relative to the click — TODO confirm.
ERR_X = 5.5
ERR_Y = 1.0

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Script-level state: the clicker accumulates clicks, while the predictor and
# the uploaded image are assigned further down the script.
clicker = ck.Clicker()
predictor = None
image = None
###################################
# Functions.
###################################
@st.cache(allow_output_mutation=True)
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load a segmentation checkpoint and wrap it in a predictor.

    The result is cached by Streamlit, so reruns with the same arguments reuse
    the already-built predictor instead of reloading the weights.

    Args:
        model_path: Path of the checkpoint file to load.
        device: Device the model and predictor run on.

    Returns:
        A predictor built by ``get_predictor`` with BRS disabled ("NoBRS").
    """
    net = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(net, device=device, brs_mode="NoBRS")
def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas point objects into clicks and add them to ``clicker``.

    Each object's ``left``/``top`` position (canvas pixels) is shifted by
    ``(ERR_X, ERR_Y)``, rescaled from the ``CANVAS_WIDTH x CANVAS_HEIGHT``
    canvas to the image resolution, and clamped to the valid pixel range
    ``[0, size - 1]``. Objects whose stroke is ``POS_COLOR`` become positive
    clicks; any other stroke colour yields a negative click.

    Args:
        clicker: Accumulator that receives the converted clicks.
        clicks: Canvas objects; each must provide ``left``, ``top`` and
            ``stroke`` (the latter is a colour string, not a float).
        image_width: Width of the source image in pixels.
        image_height: Height of the source image in pixels.
    """
    ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
    # Largest valid pixel coordinates. Clamping to the size itself would allow
    # a coordinate one past the last column/row.
    max_x, max_y = max(0, image_width - 1), max(0, image_height - 1)
    for obj in clicks:
        x = (obj["left"] + ERR_X) * ratio_w
        y = (obj["top"] + ERR_Y) * ratio_h
        x, y = min(max_x, max(0, x)), min(max_y, max(0, y))
        is_positive = obj["stroke"] == POS_COLOR
        # Coordinates are passed row-first, i.e. as (y, x).
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))
def predict(image: Image.Image, mask: torch.Tensor, threshold: float = 0.5) -> np.ndarray:
    """Segment ``image`` from the current clicks and binarise the result.

    Uses the module-level ``predictor`` and ``clicker``. The raw prediction is
    resized to the drawing-canvas size (``CANVAS_HEIGHT`` rows by
    ``CANVAS_WIDTH`` columns, assuming the predictor returns a 2-D map — TODO
    confirm), so the output matches the canvas, not the source image.

    Args:
        image: RGB image to segment.
        mask: Previous mask, passed to the predictor as ``prev_mask``.
        threshold: Cut-off applied after resizing; values strictly above it
            become 1.0, all others 0.

    Returns:
        A float array containing only 0 and 1.0.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    # OpenCV's dsize is (width, height), not (height, width).
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
        interpolation=cv2.INTER_CUBIC,
    )
    return np.where(pred > threshold, 1.0, 0)
###################################
# Sidebar GUI
###################################
# Sidebar widgets, created top to bottom in display order.
model = st.sidebar.selectbox("Select a Method:", options=tuple(MODELS))
threshold = st.sidebar.slider("Threshold: ", min_value=0.0, max_value=1.0, value=0.5)
marking_type = st.sidebar.radio("Click Type:", options=("Positive", "Negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
# Decode the upload (if any) as 3-channel RGB; without an upload, `image`
# keeps its earlier value.
if image_path:
    image = Image.open(image_path).convert("RGB")
###################################
# Preparation
###################################
# Model: download the checkpoint from URL_PREFIX unless a file with that name
# already exists relative to the working directory.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(MODELS[model]):
        wget.download("/".join((URL_PREFIX, MODELS[model])))
# Predictor: load_model is wrapped in st.cache, so reruns with the same
# arguments reuse the cached predictor.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(MODELS[model], device)
###################################
# GUI
###################################
# Drawing canvas: every click drops a point whose stroke colour encodes the
# click type currently selected in the sidebar.
st.title("Canvas:")
canvas_result = st_canvas(
    # Appearance.
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=NEG_COLOR if marking_type != "Positive" else POS_COLOR,
    background_color="#eee",
    background_image=image,
    point_display_radius=3,
    # Interaction: point-drawing mode, sending canvas data back to Streamlit.
    update_streamlit=True,
    drawing_mode="point",
    # Identity and geometry.
    key="canvas",
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
)
###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
# Missing canvas data or a missing "objects" key both count as "no clicks".
canvas_objects = (canvas_result.json_data or {}).get("objects")
if canvas_objects and image is not None:
    image_width, image_height = image.size
    feed_clicks(clicker, canvas_objects, image_width, image_height)
    # Run prediction, starting from an all-zero previous mask.
    mask = torch.zeros((1, 1, image_height, image_width), device=device)
    pred = predict(image, mask, threshold)
    # Show the prediction result.
    st.image(pred, caption="")