curt-park commited on
Commit
ad0c87f
·
1 Parent(s): 2535e18

Make gpu compatible

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -26,7 +26,7 @@ image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "
26
 
27
  # Objects for prediction.
28
  clicker = ck.Clicker()
29
- device = torch.device("cpu")
30
  predictor = None
31
  with st.spinner("Wait for downloading a model..."):
32
  if not os.path.exists(models[model]):
@@ -43,6 +43,7 @@ if image_path:
43
  image = Image.open(image_path).convert("RGB")
44
  canvas_height, canvas_width = 600, 600
45
  pos_color, neg_color = "#3498DB", "#C70039"
 
46
  st.title("Canvas:")
47
  canvas_result = st_canvas(
48
  fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
@@ -75,11 +76,15 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
75
  click = ck.Click(is_positive=is_positive, coords=(y, x))
76
  clicker.add_click(click)
77
 
78
- # prediction.
79
  pred = None
80
  predictor.set_input_image(np.array(image))
 
 
81
  with st.spinner("Wait for prediction..."):
82
- pred = predictor.get_prediction(clicker, prev_mask=None)
83
  pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
84
  pred = np.where(pred > threshold, 1.0, 0)
 
 
85
  st.image(pred, caption="")
 
26
 
27
  # Objects for prediction.
28
  clicker = ck.Clicker()
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  predictor = None
31
  with st.spinner("Wait for downloading a model..."):
32
  if not os.path.exists(models[model]):
 
43
  image = Image.open(image_path).convert("RGB")
44
  canvas_height, canvas_width = 600, 600
45
  pos_color, neg_color = "#3498DB", "#C70039"
46
+
47
  st.title("Canvas:")
48
  canvas_result = st_canvas(
49
  fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
 
76
  click = ck.Click(is_positive=is_positive, coords=(y, x))
77
  clicker.add_click(click)
78
 
79
+ # Run prediction.
80
  pred = None
81
  predictor.set_input_image(np.array(image))
82
+ init_mask = torch.zeros((1, 1, image_height, image_width), device=device)
83
+
84
  with st.spinner("Wait for prediction..."):
85
+ pred = predictor.get_prediction(clicker, prev_mask=init_mask)
86
  pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC)
87
  pred = np.where(pred > threshold, 1.0, 0)
88
+
89
+ # Show the prediction result.
90
  st.image(pred, caption="")