curt-park commited on
Commit
96d7d21
·
1 Parent(s): 98ddf8e

Uncomment all lines

Browse files
Files changed (1) hide show
  1. app.py +44 -44
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- # import numpy as np
4
  import cv2
5
  import wget
6
  import os
@@ -38,48 +38,48 @@ with st.spinner("Wait for loading a model..."):
38
  predictor = get_predictor(model, device=device, **predictor_params)
39
 
40
  # Create a canvas component.
41
- #image = None
42
- #if image_path:
43
- # image = Image.open(image_path)
44
- #canvas_height, canvas_width = 600, 600
45
- #pos_color, neg_color = "#3498DB", "#C70039"
46
- #st.title("Canvas:")
47
- #canvas_result = st_canvas(
48
- # fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
49
- # stroke_width=3,
50
- # stroke_color=pos_color if marking_type == "positive" else neg_color,
51
- # background_color="#eee",
52
- # background_image=image,
53
- # update_streamlit=True,
54
- # drawing_mode="point",
55
- # point_display_radius=3,
56
- # key="canvas",
57
- # width=canvas_width,
58
- # height=canvas_height,
59
- #)
60
 
61
  # Check the user inputs ans execute predictions.
62
- #st.title("Prediction:")
63
- #if canvas_result.json_data and canvas_result.json_data["objects"] and image:
64
- # objects = canvas_result.json_data["objects"]
65
- # image_width, image_height = image.size
66
- # ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
67
- #
68
- # err_x, err_y = 5.5, 1.0
69
- # pos_clicks, neg_clicks = [], []
70
- # for click in objects:
71
- # x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
72
- # x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
73
- #
74
- # is_positive = click["stroke"] == pos_color
75
- # click = ck.Click(is_positive=is_positive, coords=(y, x))
76
- # clicker.add_click(click)
77
- #
78
- # # prediction.
79
- # pred = None
80
- # predictor.set_input_image(np.array(image))
81
- # with st.spinner("Wait for prediction..."):
82
- # pred = predictor.get_prediction(clicker, prev_mask=None)
83
- # pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
84
- # pred = np.where(pred > threshold, 1.0, 0)
85
- # st.image(pred, caption="")
 
1
  import streamlit as st
2
  import torch
3
+ import numpy as np
4
  import cv2
5
  import wget
6
  import os
 
38
  predictor = get_predictor(model, device=device, **predictor_params)
39
 
40
  # Create a canvas component.
41
+ image = None
42
+ if image_path:
43
+ image = Image.open(image_path)
44
+ canvas_height, canvas_width = 600, 600
45
+ pos_color, neg_color = "#3498DB", "#C70039"
46
+ st.title("Canvas:")
47
+ canvas_result = st_canvas(
48
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
49
+ stroke_width=3,
50
+ stroke_color=pos_color if marking_type == "positive" else neg_color,
51
+ background_color="#eee",
52
+ background_image=image,
53
+ update_streamlit=True,
54
+ drawing_mode="point",
55
+ point_display_radius=3,
56
+ key="canvas",
57
+ width=canvas_width,
58
+ height=canvas_height,
59
+ )
60
 
61
  # Check the user inputs ans execute predictions.
62
+ st.title("Prediction:")
63
+ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
64
+ objects = canvas_result.json_data["objects"]
65
+ image_width, image_height = image.size
66
+ ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
67
+
68
+ err_x, err_y = 5.5, 1.0
69
+ pos_clicks, neg_clicks = [], []
70
+ for click in objects:
71
+ x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
72
+ x, y = min(image_width, max(0, x)), min(image_height, max(0, y))
73
+
74
+ is_positive = click["stroke"] == pos_color
75
+ click = ck.Click(is_positive=is_positive, coords=(y, x))
76
+ clicker.add_click(click)
77
+
78
+ # prediction.
79
+ pred = None
80
+ predictor.set_input_image(np.array(image))
81
+ with st.spinner("Wait for prediction..."):
82
+ pred = predictor.get_prediction(clicker, prev_mask=None)
83
+ pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
84
+ pred = np.where(pred > threshold, 1.0, 0)
85
+ st.image(pred, caption="")