Spaces:
Runtime error
Runtime error
import nmslib | |
import numpy as np | |
import streamlit as st | |
from transformers import AutoTokenizer, CLIPProcessor | |
from model import FlaxHybridCLIP | |
from PIL import Image | |
import jax.numpy as jnp | |
import os | |
import jax | |
# st.header('Under construction') | |
# Sidebar: title, model-card link, logo with a tagline, and a motivation blurb.
st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")

logo_col, tagline_col = st.sidebar.columns(2)
logo_col.image("./huggingface_explode3.png", width=150)
# Two blank spacer writes before the tagline, presumably to push it down
# next to the logo.
for _ in range(2):
    tagline_col.write(" ")
tagline_col.markdown("## Researching fun")

with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(
        """
        Reaction GIFs became an integral part of communication.
        They convey complex emotions with many levels, in a short compact format.
        If a picture is worth a thousand words then a GIF is worth more.
        A lot of people would agree it is not always easy to find the perfect reaction GIF.
        This is just a first step in the more ambitious goal of GIF/Image generation.
        """
    )
# How many nearest neighbours to fetch from the image index.
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
# Column count shared by every image grid in the app.
col_count = 4
# Sorted validation-set file names. Search results use positions in this list
# as image ids, so this assumes the nmslib index was built in the same sorted
# order (TODO confirm).
file_names = sorted(os.listdir("./jpg"))

if st.sidebar.button("show all validation set images"):
    grid = st.sidebar.columns(col_count)
    for position, name in enumerate(file_names):
        # Fill the grid row by row, wrapping after col_count images.
        grid[position % col_count].image(f"./jpg/{name}")
# Main page heading followed by two spacer lines.
st.write("# Search Reaction GIFs with CLIP ")
for _ in range(2):
    st.write(" ")
def _cache_resource(func):
    """Cache *func*'s result across Streamlit reruns when the API allows it.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached loader reloads the model, processor and index each time.

    Prefers ``st.cache_resource`` (newer Streamlit), then
    ``st.experimental_singleton`` (older releases). If neither exists, *func*
    is returned unchanged, which matches the original uncached behaviour.

    Cached objects are shared by all sessions on the server, so anything that
    mutates them (e.g. adding points to the index) affects every user.
    """
    cache = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )
    return cache(func) if cache is not None else func


@_cache_resource
def load_model():
    """Load the CLIP-reply model and a CLIP processor with a swapped tokenizer.

    Returns:
        (model, processor):
        - model: the ``FlaxHybridCLIP`` checkpoint ``ceyda/clip-reply``.
        - processor: the ``openai/clip-vit-base-patch32`` ``CLIPProcessor``,
          with its tokenizer replaced by ``cardiffnlp/twitter-roberta-base``
          (presumably to match the model's text encoder).
    """
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor


@_cache_resource
def load_image_index():
    """Load the prebuilt nmslib HNSW index (cosine space) of image embeddings.

    The index and its stored data points are read from
    ``./features/image_embeddings``.
    """
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


# With a caching API available these load once per server process instead of
# once per script rerun.
image_index = load_image_index()
model, processor = load_model()
# TODO: not called anywhere yet. A point added here has no matching entry in
# ``file_names``, which the results grid uses to map ids to files. Presumably
# the HNSW graph also has to be rebuilt (``createIndex``) before ``knnQuery``
# can return the new point. Verify both before wiring this up.
def add_image_emb(image, data_id=None):
    """Embed an image with the model and add the embedding to ``image_index``.

    Args:
        image: a path or file-like object accepted by ``PIL.Image.open``.
        data_id: nmslib id to store the point under. Defaults to
            ``len(image_index)``, the next id after the existing points.

    Returns:
        The id the embedding was stored under.
    """
    with Image.open(image) as img:
        # Only the first frame is used (the app notes that only first frames
        # were used in training). seek(0) is a no-op for still images and
        # must happen before convert(), which returns a single-frame copy.
        img.seek(0)
        rgb = img.convert("RGB")
    # Placeholder empty text: presumably the full model forward also needs
    # text inputs, even though only the image embedding is read here.
    inputs = processor(text=[""], images=rgb, return_tensors="jax", padding=True)
    # Reorder axes (0, 2, 3, 1): presumably NCHW from the processor to the
    # NHWC layout the Flax model expects.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    if data_id is None:
        data_id = len(image_index)
    # nmslib's addDataPoint takes (id, vector). image_embeds is presumably a
    # (1, dim) batch, so take its single row as a 1-D numpy vector.
    image_index.addDataPoint(data_id, np.asarray(features)[0])
    return data_id
def query_with_images(query_images, query_text):
    """Rank uploaded images by how well they match ``query_text``.

    Args:
        query_images: iterable of image files readable by ``PIL.Image.open``
            (e.g. the list returned by ``st.file_uploader``).
        query_text: the text every image is scored against.

    Returns:
        A pair ``(images, probs)`` of equal-length tuples:
        - images: the RGB PIL images, sorted by descending probability.
        - probs: the matching probabilities as Python floats. They are a
          softmax over the uploaded images, so they sum to 1.
        Both tuples are empty when ``query_images`` is empty. Images with
        equal probability keep their upload order.
    """
    images = []
    for upload in query_images:
        img = Image.open(upload)
        # Use the first frame of animated GIFs (a no-op for stills). This must
        # happen before convert(), which returns a new single-frame image.
        img.seek(0)
        images.append(img.convert("RGB"))
    if not images:
        return (), ()

    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    # Reorder axes (0, 2, 3, 1): presumably NCHW -> NHWC for the Flax model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # One text against N images: flatten to one logit per image.
    logits_per_image = outputs.logits_per_image.reshape(-1)
    probs = np.asarray(jax.nn.softmax(logits_per_image))
    # A stable sort on the negated values gives descending order and keeps
    # ties in upload order, like sorted(..., reverse=True).
    order = np.argsort(-probs, kind="stable")
    ranked_images = tuple(images[i] for i in order)
    ranked_probs = tuple(float(probs[i]) for i in order)
    return ranked_images, ranked_probs
# Query area: text on the left, optional uploads on the right. The narrow
# middle column is never written to and only acts as a gap.
left_col, _spacer, right_col = st.columns([5, 2, 5])
examples = [
    "OMG that is disgusting",
    "I'm so scared right now",
    " I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]
example_input = left_col.radio(
    "Example Queries :",
    examples,
    index=4,
    help="These are examples I wrote off the top of my head. They don't occur in the dataset",
)
right_col.markdown(
    """
    Searches among the validation set images if not specified
    (There may be non-exact duplicates)
    """
)
# The selected example pre-fills the text box; the user can overwrite it.
query_text = left_col.text_input("Write text you want to get reaction for", value=example_input)
query_images = right_col.file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)
if not query_images:
    # No uploads: embed the text and fetch its top_k nearest neighbours from
    # the prebuilt nmslib image index.
    st.write("Found these images within validation set:")
    with st.spinner("Calculating..."):
        text_inputs = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
        query_vec = np.asarray(model.get_text_features(**text_inputs))
        ids, dists = image_index.knnQuery(query_vec, k=top_k)
else:
    # Uploads present: rank only the uploaded images against the text.
    st.write("Ranking your uploaded images with respect to input text:")
    with st.spinner("Calculating..."):
        ids, dists = query_with_images(query_images, query_text)
show_gif = st.checkbox("Play GIFs", value=True, help="Will play the original animation. Only first frame is used in training!")
# Results are loaded from ./gif (animations) or ./jpg (stills). The path is
# built from the file_names stem, so both directories are assumed to share
# base file names.
ext = "gif" if show_gif else "jpg"
res_cols = st.columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        if isinstance(id_, (int, np.integer)):
            # Index hit: id_ is a position in file_names. The index uses
            # nmslib's "cosinesimil" space, so 1 - dist is the cosine
            # similarity. splitext copes with extensions of any length.
            stem = os.path.splitext(file_names[id_])[0]
            st.image(f"./{ext}/{stem}.{ext}")
            score = 1.0 - float(dist)
        else:
            # Uploaded image ranked by query_with_images: id_ is the image
            # and dist is its softmax probability.
            st.image(id_)
            score = float(dist)
        # st.write ignores a help= keyword, so label the score explicitly.
        st.caption(f"score: {score:.4f}")
# Sidebar footer crediting the app's author.
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")