Spaces:
Runtime error
Runtime error
File size: 4,126 Bytes
1615d09 e82cf8b 2cdd41c 1615d09 2cdd41c 1615d09 a74cdb0 1615d09 7d80b1e 1615d09 2cdd41c 1615d09 5fd2412 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 96d7d21 1615d09 96d7d21 2cdd41c 1615d09 2cdd41c 96d7d21 1615d09 96d7d21 ad0c87f 5fd2412 1615d09 ad0c87f 96d7d21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
from typing import Dict, List
import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor
###################################
# Global scope objects.
###################################
# Base URL the model checkpoints are downloaded from.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"

# Dimensions handed to the drawable canvas; also the reference size used when
# mapping canvas clicks back to image coordinates.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600

# Stroke colors that mark a click as positive or negative.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"

# Offsets added to a click's "left"/"top" before scaling to image coordinates.
# NOTE(review): presumably compensates for how the canvas reports point
# positions -- confirm against the canvas component.
ERR_X = 5.5
ERR_Y = 1.0

# Display name -> checkpoint file name (relative to URL_PREFIX / working dir).
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Collects the user's clicks for the current script run.
clicker = ck.Clicker()

# Placeholders; both are assigned further down once the model is loaded and
# an image has been uploaded.
predictor = None
image = None
###################################
# Functions.
###################################
@st.cache(allow_output_mutation=True)
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load the segmentation checkpoint and wrap it in a predictor.

    The result is memoized by ``st.cache``; ``allow_output_mutation=True``
    is set because the returned predictor is mutated by later calls
    (e.g. ``set_input_image``).

    Args:
        model_path: Path of the checkpoint file to load.
        device: Device the model and the predictor should run on.

    Returns:
        A predictor built with BRS disabled (``brs_mode="NoBRS"``).
    """
    network = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(network, device=device, brs_mode="NoBRS")
def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas click objects to image-space clicks and register them.

    Each canvas object is shifted by ``ERR_X``/``ERR_Y``, scaled from canvas
    size to image size, clamped to the image bounds, and added to ``clicker``.

    Args:
        clicker: Clicker that receives the converted clicks (mutated in place).
        clicks: Canvas objects; each must provide ``"left"``, ``"top"`` and
            ``"stroke"``.
        image_width: Width of the source image in pixels.
        image_height: Height of the source image in pixels.
    """
    ratio_h = image_height / CANVAS_HEIGHT
    ratio_w = image_width / CANVAS_WIDTH

    # Largest valid coordinate on each axis. Clamping to the full width/height
    # would allow a click one pixel past the last row/column of the image.
    max_x = max(0, image_width - 1)
    max_y = max(0, image_height - 1)

    for canvas_click in clicks:
        x = (canvas_click["left"] + ERR_X) * ratio_w
        y = (canvas_click["top"] + ERR_Y) * ratio_h
        x = min(max_x, max(0, x))
        y = min(max_y, max(0, y))

        # The stroke color is the only marker of the click's polarity.
        is_positive = canvas_click["stroke"] == POS_COLOR

        # Click coordinates are passed in (row, column) order.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))
def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Run the module-level predictor on ``image`` and binarize the result.

    Uses the module-level ``predictor`` and ``clicker`` objects.

    Args:
        image: Input image; converted to a NumPy array for the predictor.
        mask: Previous mask passed to the predictor as ``prev_mask``.
        threshold: Values strictly above this become 1.0, everything else 0.

    Returns:
        The prediction resized to the canvas size, with values 1.0 or 0.
        NOTE(review): assumes ``get_prediction`` returns a single-channel
        float array that ``cv2.resize`` accepts -- confirm.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
        pred = cv2.resize(
            pred,
            # cv2.resize expects dsize as (width, height), not (height, width).
            dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
            interpolation=cv2.INTER_CUBIC,
        )
        pred = np.where(pred > threshold, 1.0, 0)
    return pred
###################################
# Sidebar GUI
###################################
# Sidebar widgets: model choice, mask threshold, click polarity, image upload.
model = st.sidebar.selectbox("Select a Model:", options=tuple(MODELS))
threshold = st.sidebar.slider(
    "Threshold: ",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)
marking_type = st.sidebar.radio("Click Type:", options=("Positive", "Negative"))
image_path = st.sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)

# Only replace the current image once the user has actually uploaded a file.
if image_path:
    image = Image.open(image_path).convert("RGB")
###################################
# Preparation
###################################
# Model: fetch the selected checkpoint unless it is already on disk.
checkpoint_name = MODELS[model]
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint_name):
        _ = wget.download(f"{URL_PREFIX}/{checkpoint_name}")

# Predictor: build it from the checkpoint on the chosen device.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(checkpoint_name, device)
###################################
# GUI
###################################
# Draw the canvas on which the user places positive / negative clicks.
st.title("Canvas:")
canvas_result = st_canvas(
    # Geometry and backdrop.
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
    background_color="#eee",
    background_image=image,
    # Point drawing; the stroke color encodes the selected click type.
    drawing_mode="point",
    point_display_radius=3,
    stroke_width=3,
    stroke_color=NEG_COLOR if marking_type != "Positive" else POS_COLOR,
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    # Component behavior.
    update_streamlit=True,
    key="canvas",
)
###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
canvas_json = canvas_result.json_data
# Predict only when the canvas holds at least one object and an image is set.
if canvas_json and canvas_json["objects"] and image:
    image_width, image_height = image.size
    feed_clicks(
        clicker,
        canvas_json["objects"],
        image_width,
        image_height,
    )

    # Run prediction, starting from an all-zero previous mask.
    mask_shape = (1, 1, image_height, image_width)
    mask = torch.zeros(mask_shape, device=device)
    pred = predict(image, mask, threshold)

    # Show the prediction result.
    st.image(pred, caption="")
|