curt-park commited on
Commit
5fd2412
·
1 Parent(s): 1615d09
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -30,7 +30,7 @@ image = None
30
  ###################################
31
  # Functions.
32
  ###################################
33
- # @st.cache_resource
34
  def load_model(model_path: str, device: torch.device) -> BasePredictor:
35
  model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
36
  predictor_params = {"brs_mode": "NoBRS"}
@@ -54,9 +54,7 @@ def feed_clicks(
54
  clicker.add_click(click)
55
 
56
 
57
- def predict(
58
- image: Image, mask: torch.Tensor, threshold: float = 0.5
59
- ) -> torch.Tensor:
60
  predictor.set_input_image(np.array(image))
61
  with st.spinner("Wait for prediction..."):
62
  pred = predictor.get_prediction(clicker, prev_mask=mask)
@@ -120,7 +118,7 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
120
  feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
121
 
122
  # Run prediction.
123
- mask = torch.zeros((1, 1, image_width, image_height), device=device)
124
  pred = predict(image, mask, threshold)
125
 
126
  # Show the prediction result.
 
30
  ###################################
31
  # Functions.
32
  ###################################
33
+ # @st.cache_resource # TODO: this doesn't work on Huggingface!
34
  def load_model(model_path: str, device: torch.device) -> BasePredictor:
35
  model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
36
  predictor_params = {"brs_mode": "NoBRS"}
 
54
  clicker.add_click(click)
55
 
56
 
57
+ def predict(image: Image, mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
 
 
58
  predictor.set_input_image(np.array(image))
59
  with st.spinner("Wait for prediction..."):
60
  pred = predictor.get_prediction(clicker, prev_mask=mask)
 
118
  feed_clicks(clicker, canvas_result.json_data["objects"], image_width, image_height)
119
 
120
  # Run prediction.
121
+ mask = torch.zeros((1, 1, image_height, image_width), device=device)
122
  pred = predict(image, mask, threshold)
123
 
124
  # Show the prediction result.