Spaces:
Runtime error
Runtime error
Make gpu compatible
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "
|
|
26 |
|
27 |
# Objects for prediction.
|
28 |
clicker = ck.Clicker()
|
29 |
-
device = torch.device("cpu")
|
30 |
predictor = None
|
31 |
with st.spinner("Wait for downloading a model..."):
|
32 |
if not os.path.exists(models[model]):
|
@@ -43,6 +43,7 @@ if image_path:
|
|
43 |
image = Image.open(image_path).convert("RGB")
|
44 |
canvas_height, canvas_width = 600, 600
|
45 |
pos_color, neg_color = "#3498DB", "#C70039"
|
|
|
46 |
st.title("Canvas:")
|
47 |
canvas_result = st_canvas(
|
48 |
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
@@ -75,11 +76,15 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
|
|
75 |
click = ck.Click(is_positive=is_positive, coords=(y, x))
|
76 |
clicker.add_click(click)
|
77 |
|
78 |
-
# prediction.
|
79 |
pred = None
|
80 |
predictor.set_input_image(np.array(image))
|
|
|
|
|
81 |
with st.spinner("Wait for prediction..."):
|
82 |
-
pred = predictor.get_prediction(clicker, prev_mask=
|
83 |
pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
|
84 |
pred = np.where(pred > threshold, 1.0, 0)
|
|
|
|
|
85 |
st.image(pred, caption="")
|
|
|
26 |
|
27 |
# Objects for prediction.
|
28 |
clicker = ck.Clicker()
|
29 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
predictor = None
|
31 |
with st.spinner("Wait for downloading a model..."):
|
32 |
if not os.path.exists(models[model]):
|
|
|
43 |
image = Image.open(image_path).convert("RGB")
|
44 |
canvas_height, canvas_width = 600, 600
|
45 |
pos_color, neg_color = "#3498DB", "#C70039"
|
46 |
+
|
47 |
st.title("Canvas:")
|
48 |
canvas_result = st_canvas(
|
49 |
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
|
|
76 |
click = ck.Click(is_positive=is_positive, coords=(y, x))
|
77 |
clicker.add_click(click)
|
78 |
|
79 |
+
# Run prediction.
|
80 |
pred = None
|
81 |
predictor.set_input_image(np.array(image))
|
82 |
+
init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
|
83 |
+
|
84 |
with st.spinner("Wait for prediction..."):
|
85 |
+
pred = predictor.get_prediction(clicker, prev_mask=init_mask)
|
86 |
pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
|
87 |
pred = np.where(pred > threshold, 1.0, 0)
|
88 |
+
|
89 |
+
# Show the prediction result.
|
90 |
st.image(pred, caption="")
|