Cache loaded model
app.py
CHANGED
@@ -12,11 +12,24 @@ from isegm.inference import clicker as ck
 from isegm.inference import utils
 from isegm.inference.predictors import get_predictor
 
-
-
-
-"
-
+@st.cache_data
+def load_model(model_path, device):
+    model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
+    predictor_params = {"brs_mode": "NoBRS"}
+    predictor = get_predictor(model, device=device, **predictor_params)
+    return predictor
+
+
+# Objects in the global scope
+url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
+models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
+clicker = ck.Clicker()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+pos_color, neg_color = "#3498DB", "#C70039"
+canvas_height, canvas_width = 600, 600
+err_x, err_y = 5.5, 1.0
+predictor = None
+image = None
 
 # Items in the sidebar.
 model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
@@ -25,24 +38,16 @@ marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
 image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
 
 # Objects for prediction.
-clicker = ck.Clicker()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-predictor = None
 with st.spinner("Wait for downloading a model..."):
     if not os.path.exists(models[model]):
-        _ = wget.download(f"{
+        _ = wget.download(f"{url_prefix}/{models[model]}")
 
 with st.spinner("Wait for loading a model..."):
-
-    predictor_params = {"brs_mode": "NoBRS"}
-    predictor = get_predictor(model, device=device, **predictor_params)
+    predictor = load_model(models[model], device)
 
 # Create a canvas component.
-image = None
 if image_path:
     image = Image.open(image_path).convert("RGB")
-canvas_height, canvas_width = 600, 600
-pos_color, neg_color = "#3498DB", "#C70039"
 
 st.title("Canvas:")
 canvas_result = st_canvas(
@@ -66,7 +71,6 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
     image_width, image_height = image.size
     ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
 
-    err_x, err_y = 5.5, 1.0
     pos_clicks, neg_clicks = [], []
     for click in objects:
        x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
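The first hunk wraps model loading in a function decorated with @st.cache_data, so a Streamlit rerun (triggered by any widget interaction) reuses the predictor built on the first run instead of reloading the checkpoint each time. Below is a minimal, self-contained sketch of that caching pattern; the function name and the time.sleep stand-in are illustrative and not code from this Space. Depending on the Streamlit version, st.cache_resource may be the better fit for objects that are hard to serialize, such as torch models, since it stores the object itself rather than a pickled copy.

# Minimal sketch of Streamlit's caching pattern (illustrative names only).
import time

import streamlit as st


@st.cache_data
def load_predictor_sketch(checkpoint_path: str) -> str:
    # Stand-in for the expensive part: reading weights and building a predictor.
    time.sleep(2)
    return f"predictor built from {checkpoint_path}"


# The first run pays the full cost; later reruns with the same argument
# return the cached value immediately instead of re-executing the body.
predictor = load_predictor_sketch("ritm_coco_lvis_h18_itermask.pth")
st.write(predictor)

Run it with "streamlit run sketch.py": interacting with any widget reruns the script, but the two-second load happens only once per distinct argument value.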
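The second hunk keeps the download-once guard but now builds the URL from the new url_prefix constant. A standalone sketch of that guard, assuming the same URL and filename as in the diff (the checkpoint is large, so only run this intentionally):

# Sketch of the download-once guard used in app.py.
import os

import wget

url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
checkpoint = "ritm_coco_lvis_h18_itermask.pth"

if not os.path.exists(checkpoint):
    # wget.download saves the file into the working directory and returns
    # the local filename; the underscore simply discards that return value.
    _ = wget.download(f"{url_prefix}/{checkpoint}")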
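The last hunk only moves err_x and err_y into the global scope; the mapping itself is unchanged. Clicks arrive in 600x600 canvas coordinates and are scaled by the image-to-canvas ratios, with the small constant offsets compensating for where the canvas draws its markers. A standalone sketch of that conversion, assuming the same constants (the helper name is not from the repo):

# Sketch of the canvas-to-image coordinate mapping from app.py.
def canvas_to_image_coords(click, image_size, canvas_size=(600, 600), err=(5.5, 1.0)):
    image_w, image_h = image_size
    canvas_w, canvas_h = canvas_size
    err_x, err_y = err
    ratio_w, ratio_h = image_w / canvas_w, image_h / canvas_h
    # st_canvas reports each object's top-left corner in canvas pixels.
    x = (click["left"] + err_x) * ratio_w
    y = (click["top"] + err_y) * ratio_h
    return x, y


# A click near the canvas centre maps to roughly the centre of a 1200x900 image.
print(canvas_to_image_coords({"left": 300, "top": 300}, (1200, 900)))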