import os
from typing import Dict, List

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor

###################################
# Global scope objects.
###################################
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
# Display name -> checkpoint file name (downloaded next to the script on demand).
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
# Stroke colors that encode whether a canvas point is a positive or negative click.
POS_COLOR, NEG_COLOR = "#3498DB", "#C70039"
CANVAS_HEIGHT, CANVAS_WIDTH = 600, 600
# Offsets added to the canvas "left"/"top" of each point before scaling.
# NOTE(review): presumably compensates for the drawn marker's anchor offset
# on the canvas -- confirm against streamlit_drawable_canvas.
ERR_X, ERR_Y = 5.5, 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clicker = ck.Clicker()
predictor = None
image = None


###################################
# Functions.
###################################
# NOTE(review): `st.cache` is deprecated in recent Streamlit releases; kept
# as-is to stay compatible with the Streamlit version this script targets.
@st.cache(allow_output_mutation=True)
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load the segmentation model and wrap it in a predictor.

    Args:
        model_path: Path to the model checkpoint file.
        device: Torch device the model and predictor are created on.

    Returns:
        A predictor built with ``brs_mode="NoBRS"``.
    """
    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    predictor_params = {"brs_mode": "NoBRS"}
    return get_predictor(model, device=device, **predictor_params)


def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas points into image-space clicks and add them to ``clicker``.

    Each canvas object is read through its ``"left"``, ``"top"`` and
    ``"stroke"`` keys. Positions are shifted by ``ERR_X``/``ERR_Y``, scaled
    from canvas size to image size, and clamped to valid pixel indices.

    Args:
        clicker: Clicker that receives the converted clicks (mutated in place).
        clicks: Canvas objects describing the points drawn by the user.
        image_width: Width of the source image in pixels.
        image_height: Height of the source image in pixels.
    """
    ratio_h = image_height / CANVAS_HEIGHT
    ratio_w = image_width / CANVAS_WIDTH

    # Last valid pixel index is size - 1; clamping to `size` itself would
    # place a click one pixel outside the image.
    max_x = max(0, image_width - 1)
    max_y = max(0, image_height - 1)

    for canvas_point in clicks:
        x = (canvas_point["left"] + ERR_X) * ratio_w
        y = (canvas_point["top"] + ERR_Y) * ratio_h
        x = min(max_x, max(0, x))
        y = min(max_y, max(0, y))

        # The stroke color tells which click type was active when drawing.
        is_positive = canvas_point["stroke"] == POS_COLOR
        # Coordinates are handed to the clicker in (y, x) order.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))


def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Run the module-level predictor on ``image`` using the module-level clicker.

    Args:
        image: Input image; converted to a NumPy array for the predictor.
        mask: Previous mask passed to the predictor as ``prev_mask``.
        threshold: Values strictly above this become 1.0, everything else 0.

    Returns:
        The thresholded prediction resized to the canvas size.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    # cv2.resize expects dsize as (width, height).
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
        interpolation=cv2.INTER_CUBIC,
    )
    pred = np.where(pred > threshold, 1.0, 0)
    return pred


###################################
# Sidebar GUI
###################################
# Items in the sidebar.
model = st.sidebar.selectbox("Select a Method:", tuple(MODELS.keys()))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Click Type:", ("Positive", "Negative"))
image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
if image_path:
    image = Image.open(image_path).convert("RGB")

###################################
# Preparation
###################################
# Model: download the checkpoint only when it is not already on disk.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(MODELS[model]):
        _ = wget.download(f"{URL_PREFIX}/{MODELS[model]}")

# Predictor.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(MODELS[model], device)

###################################
# GUI
###################################
# Create a canvas component.
st.title("Canvas:")
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=3,
    stroke_color=POS_COLOR if marking_type == "Positive" else NEG_COLOR,
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    drawing_mode="point",
    point_display_radius=3,
    key="canvas",
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
)

###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
if (
    canvas_result.json_data
    and canvas_result.json_data["objects"]
    and image is not None
):
    image_width, image_height = image.size
    feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)

    # Run prediction, starting from an all-zero previous mask.
    mask = torch.zeros((1, 1, image_height, image_width), device=device)
    pred = predict(image, mask, threshold)

    # Show the prediction result.
    st.image(pred, caption="")