curt-park commited on
Commit
7d80b1e
·
1 Parent(s): ad0c87f

Cache loaded model

Browse files
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -12,11 +12,24 @@ from isegm.inference import clicker as ck
12
  from isegm.inference import utils
13
  from isegm.inference.predictors import get_predictor
14
 
15
- # Model Path
16
- prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
17
- models = {
18
- "RITM": "ritm_coco_lvis_h18_itermask.pth",
19
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Items in the sidebar.
22
  model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
@@ -25,24 +38,16 @@ marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative"))
25
  image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
26
 
27
  # Objects for prediction.
28
- clicker = ck.Clicker()
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- predictor = None
31
  with st.spinner("Wait for downloading a model..."):
32
  if not os.path.exists(models[model]):
33
- _ = wget.download(f"{prefix}/{models[model]}")
34
 
35
  with st.spinner("Wait for loading a model..."):
36
- model = utils.load_is_model(models[model], device, cpu_dist_maps=True)
37
- predictor_params = {"brs_mode": "NoBRS"}
38
- predictor = get_predictor(model, device=device, **predictor_params)
39
 
40
  # Create a canvas component.
41
- image = None
42
  if image_path:
43
  image = Image.open(image_path).convert("RGB")
44
- canvas_height, canvas_width = 600, 600
45
- pos_color, neg_color = "#3498DB", "#C70039"
46
 
47
  st.title("Canvas:")
48
  canvas_result = st_canvas(
@@ -66,7 +71,6 @@ if canvas_result.json_data and canvas_result.json_data["objects"] and image:
66
  image_width, image_height = image.size
67
  ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
68
 
69
- err_x, err_y = 5.5, 1.0
70
  pos_clicks, neg_clicks = [], []
71
  for click in objects:
72
  x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h
 
12
  from isegm.inference import utils
13
  from isegm.inference.predictors import get_predictor
14
 
15
+ @st.cache_data
16
+ def load_model(model_path, device):
17
+ model = utils.load_is_model(model_path, device, cpu_dist_maps=True)
18
+ predictor_params = {"brs_mode": "NoBRS"}
19
+ predictor = get_predictor(model, device=device, **predictor_params)
20
+ return predictor
21
+
22
+
23
+ # Objects in the global scope
24
+ url_prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
25
+ models = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
26
+ clicker = ck.Clicker()
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ pos_color, neg_color = "#3498DB", "#C70039"
29
+ canvas_height, canvas_width = 600, 600
30
+ err_x, err_y = 5.5, 1.0
31
+ predictor = None
32
+ image = None
33
 
34
  # Items in the sidebar.
35
  model = st.sidebar.selectbox("Select a Model:", tuple(models.keys()))
 
38
  image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
39
 
40
  # Objects for prediction.
 
 
 
41
  with st.spinner("Wait for downloading a model..."):
42
  if not os.path.exists(models[model]):
43
+ _ = wget.download(f"{url_prefix}/{models[model]}")
44
 
45
  with st.spinner("Wait for loading a model..."):
46
+ predictor = load_model(models[model], device)
 
 
47
 
48
  # Create a canvas component.
 
49
  if image_path:
50
  image = Image.open(image_path).convert("RGB")
 
 
51
 
52
  st.title("Canvas:")
53
  canvas_result = st_canvas(
 
71
  image_width, image_height = image.size
72
  ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width
73
 
 
74
  pos_clicks, neg_clicks = [], []
75
  for click in objects:
76
  x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h