Ceyda Cinarel commited on
Commit
578e499
·
1 Parent(s): 5214785

add explanation to app

Browse files
Files changed (2) hide show
  1. .streamlit/config.toml +2 -0
  2. app.py +33 -19
.streamlit/config.toml CHANGED
@@ -1 +1,3 @@
 
 
1
  maxUploadSize = 1
 
1
+ [server]
2
+
3
  maxUploadSize = 1
app.py CHANGED
@@ -8,18 +8,12 @@ import jax.numpy as jnp
8
  import os
9
  import jax
10
 
11
- st.header('Under construction')
12
 
13
-
14
- st.title("CLIP Reply Demo")
15
- st.sidebar.markdown(
16
- """
17
-
18
- Validation set: 351 images/273 deduped (There are still duplicates)
19
-
20
- Example Queries :
21
- """
22
- )
23
  @st.cache(allow_output_mutation=True)
24
  def load_model():
25
  model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
@@ -41,7 +35,7 @@ image_index = load_image_index()
41
  model, processor = load_model()
42
 
43
  col_count=4
44
- top_k=10
45
 
46
  show_val=st.sidebar.button("show all validation set images")
47
  if show_val:
@@ -76,15 +70,35 @@ def query_with_images(query_images,query_text):
76
  st.write(results)
77
  return zip(*results)
78
 
79
- q_cols=st.beta_columns(2)
80
- query_text = q_cols[0].text_input("Input text", value="I love you")
81
- query_images = q_cols[1].file_uploader("(optional) upload query image",type=['jpg','jpeg'], accept_multiple_files=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  if query_images:
84
- st.write("Ranking uploaded images with respect to input text")
85
  ids, dists = query_with_images(query_images,query_text)
86
  else:
87
- st.write("Finding within validation set")
88
  proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
89
  vec = np.asarray(model.get_text_features(**proc))
90
  ids, dists = image_index.knnQuery(vec, k=top_k)
@@ -96,8 +110,8 @@ for i,(id_, dist) in enumerate(zip(ids, dists)):
96
  if isinstance(id_, np.int32):
97
  st.image("./imgs/"+file_names[id_])
98
  # st.write(file_names[id_])
99
- st.write(1.0 - dist)
100
  else:
101
  st.image(id_)
102
- st.write(dist)
103
 
 
8
  import os
9
  import jax
10
 
11
+ # st.header('Under construction')
12
 
13
+ st.sidebar.write("")
14
+ st.title("CLIP React Demo")
15
+ st.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
16
+ st.write(" ")
 
 
 
 
 
 
17
  @st.cache(allow_output_mutation=True)
18
  def load_model():
19
  model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
 
35
  model, processor = load_model()
36
 
37
  col_count=4
38
+ top_k=st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
39
 
40
  show_val=st.sidebar.button("show all validation set images")
41
  if show_val:
 
70
  st.write(results)
71
  return zip(*results)
72
 
73
+ q_cols=st.beta_columns([5,2,5])
74
+
75
+ q_cols[0].markdown(
76
+ """
77
+ Example Queries :
78
+
79
+ - I'm so scared right now
80
+ - I got the job 🎉
81
+ - OMG that is disgusting
82
+ - I'm awesome
83
+
84
+ """
85
+ )
86
+ q_cols[2].markdown(
87
+ """
88
+ Searches among the validation set images if not specified
89
+ (There may be non-exact duplicates)
90
+
91
+ """
92
+ )
93
+
94
+ query_text = q_cols[0].text_input("Input text you want to get reaction for", value="I love you ❤️")
95
+ query_images = q_cols[2].file_uploader("(optional) Upload images to rank them",type=['jpg','jpeg'], accept_multiple_files=True)
96
 
97
  if query_images:
98
+ st.write("Ranking your uploaded images with respect to input text:")
99
  ids, dists = query_with_images(query_images,query_text)
100
  else:
101
+ st.write("Found these images within validation set:")
102
  proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
103
  vec = np.asarray(model.get_text_features(**proc))
104
  ids, dists = image_index.knnQuery(vec, k=top_k)
 
110
  if isinstance(id_, np.int32):
111
  st.image("./imgs/"+file_names[id_])
112
  # st.write(file_names[id_])
113
+ st.write(1.0 - dist, help="score")
114
  else:
115
  st.image(id_)
116
+ st.write(dist, help="score")
117