Spaces:
Runtime error
Runtime error
Fix bug
Browse files
app.py
CHANGED
@@ -30,7 +30,7 @@ image = None
|
|
30 |
###################################
|
31 |
# Functions.
|
32 |
###################################
|
33 |
-
# @st.cache_resource
|
34 |
def load_model(model_path: str, device: torch.device) -> BasePredictor:
|
35 |
model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
|
36 |
predictor_params = {"brs_mode": "NoBRS"}
|
@@ -54,9 +54,7 @@ def feed_clicks(
|
|
54 |
clicker.add_click(click)
|
55 |
|
56 |
|
57 |
-
def predict(
|
58 |
-
image: Image, mask: torch.Tensor, threshold: float = 0.5
|
59 |
-
) -> torch.Tensor:
|
60 |
predictor.set_input_image(np.array(image))
|
61 |
with st.spinner("Wait for prediction..."):
|
62 |
pred = predictor.get_prediction(clicker, prev_mask=mask)
|
@@ -120,7 +118,7 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
|
|
120 |
feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
|
121 |
|
122 |
# Run prediction.
|
123 |
-
mask = torch.zeros((1, 1,
|
124 |
pred = predict(image, mask, threshold)
|
125 |
|
126 |
# Show the prediction result.
|
|
|
30 |
###################################
|
31 |
# Functions.
|
32 |
###################################
|
33 |
+
# @st.cache_resource # TODO: this doesn't work on Huggingface!
|
34 |
def load_model(model_path: str, device: torch.device) -> BasePredictor:
|
35 |
model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
|
36 |
predictor_params = {"brs_mode": "NoBRS"}
|
|
|
54 |
clicker.add_click(click)
|
55 |
|
56 |
|
57 |
+
def predict(image: Image, mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
|
|
|
|
|
58 |
predictor.set_input_image(np.array(image))
|
59 |
with st.spinner("Wait for prediction..."):
|
60 |
pred = predictor.get_prediction(clicker, prev_mask=mask)
|
|
|
118 |
feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
|
119 |
|
120 |
# Run prediction.
|
121 |
+
mask = torch.zeros((1, 1, image_height, image_width), device=device)
|
122 |
pred = predict(image, mask, threshold)
|
123 |
|
124 |
# Show the prediction result.
|