File size: 4,126 Bytes
1615d09
 
 
 
 
e82cf8b
2cdd41c
 
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
1615d09
 
 
a74cdb0
1615d09
7d80b1e
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
1615d09
5fd2412
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
1615d09
2cdd41c
1615d09
2cdd41c
1615d09
 
2cdd41c
1615d09
 
 
 
2cdd41c
1615d09
 
 
2cdd41c
1615d09
2cdd41c
1615d09
 
 
2cdd41c
96d7d21
 
1615d09
 
 
 
 
 
 
 
 
 
 
96d7d21
2cdd41c
1615d09
 
 
2cdd41c
96d7d21
 
 
1615d09
96d7d21
ad0c87f
5fd2412
1615d09
ad0c87f
 
96d7d21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
from typing import Dict, List

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor

###################################
# Global scope objects.
###################################
# Base URL the model checkpoints are downloaded from.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"

# Size of the drawing canvas, in pixels.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600

# Stroke colors that tag a canvas click as positive / negative.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"

# Offsets added to a click's canvas "left"/"top" before it is scaled to image
# coordinates.
# NOTE(review): presumably compensates for where the canvas places the point
# marker relative to the cursor -- confirm.
ERR_X = 5.5
ERR_Y = 1.0

# Display name -> checkpoint file name.
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

clicker = ck.Clicker()
predictor = None  # Assigned once the model has been loaded.
image = None  # Assigned when the user uploads a background image.


###################################
# Functions.
###################################
@st.cache(allow_output_mutation=True)
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load the segmentation model and wrap it in a predictor.

    Args:
        model_path: Path of the checkpoint handed to ``utils.load_is_model``.
        device: Torch device used both for loading the model and by the predictor.

    Returns:
        The predictor built by ``get_predictor`` with ``brs_mode="NoBRS"``.
    """
    net = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(net, device=device, brs_mode="NoBRS")


def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas clicks to image coordinates and add them to ``clicker``.

    Each canvas object's ``"left"``/``"top"`` position is shifted by
    ``ERR_X``/``ERR_Y``, scaled from canvas size to image size, and clamped to
    the valid pixel index range of the image. A click is positive when its
    ``"stroke"`` equals ``POS_COLOR`` and negative otherwise.

    Args:
        clicker: Clicker that receives the converted clicks (mutated in place).
        clicks: Canvas objects; the keys ``"left"``, ``"top"`` and ``"stroke"``
            are read from each one.
        image_width: Width of the target image in pixels.
        image_height: Height of the target image in pixels.
    """
    ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
    # The last valid pixel index is size - 1; clamping to the size itself would
    # let a click on the far edge fall one pixel outside the image.
    max_x, max_y = image_width - 1, image_height - 1
    for canvas_click in clicks:
        x = (canvas_click["left"] + ERR_X) * ratio_w
        y = (canvas_click["top"] + ERR_Y) * ratio_h
        x, y = min(max_x, max(0, x)), min(max_y, max(0, y))

        is_positive = canvas_click["stroke"] == POS_COLOR
        # Coordinates are passed to Click in (row, column) order, i.e. (y, x).
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))


def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Predict a binary mask for ``image`` and resize it to the canvas size.

    Uses the module-level ``predictor`` and the clicks held by the module-level
    ``clicker``.

    Args:
        image: Input image; converted to a NumPy array for the predictor.
        mask: Previous mask passed to the predictor as ``prev_mask``.
        threshold: Values strictly above this become 1.0, all others 0.

    Returns:
        Array resized to the canvas dimensions containing only 0 and 1.0.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    # cv2.resize expects dsize as (width, height), not (height, width).
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
        interpolation=cv2.INTER_CUBIC,
    )
    # Binarize after resizing so the interpolated values are thresholded too.
    pred = np.where(pred > threshold, 1.0, 0)
    return pred


###################################
# Sidebar GUI
###################################
# Sidebar widgets: model choice, mask threshold, click type and image upload.
model = st.sidebar.selectbox("Select a Model:", tuple(MODELS))
threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = st.sidebar.radio("Click Type:", ("Positive", "Negative"))
image_path = st.sidebar.file_uploader(
    "Background Image:",
    type=["png", "jpg", "jpeg"],
)
# Decode the upload as RGB; `image` keeps its initial None until a file is given.
if image_path:
    uploaded_image = Image.open(image_path)
    image = uploaded_image.convert("RGB")

###################################
# Preparation
###################################
# Model.
weights_file = MODELS[model]
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(weights_file):
        # Pin the output path so the checkpoint is saved exactly where the
        # existence check above and the load below look for it, instead of
        # under whatever name wget derives from the URL or response headers.
        _ = wget.download(f"{URL_PREFIX}/{weights_file}", out=weights_file)
# Predictor.
with st.spinner("Wait for loading a model..."):
    predictor = load_model(weights_file, device)

###################################
# GUI
###################################
# Create a canvas component.
st.title("Canvas:")
# Points are drawn in the color matching the click type picked in the sidebar.
stroke_color = POS_COLOR if marking_type == "Positive" else NEG_COLOR
canvas_result = st_canvas(
    drawing_mode="point",
    point_display_radius=3,
    stroke_width=3,
    stroke_color=stroke_color,
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
    key="canvas",
)

###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
canvas_json = canvas_result.json_data
drawn_objects = canvas_json["objects"] if canvas_json else None
# Predict only when at least one click was drawn and an image was uploaded.
if drawn_objects and image:
    # PIL reports the size as (width, height).
    image_width, image_height = image.size
    feed_clicks(clicker, drawn_objects, image_width, image_height)

    # Run prediction, starting from an all-zero previous mask.
    mask = torch.zeros((1, 1, image_height, image_width), device=device)
    pred = predict(image, mask, threshold)

    # Show the prediction result.
    st.image(pred, caption="")