Spaces:
Runtime error
Runtime error
import jax | |
import flax | |
import matplotlib.pyplot as plt | |
import nmslib | |
import numpy as np | |
import os | |
import requests | |
import streamlit as st | |
from tempfile import NamedTemporaryFile | |
from torchvision.transforms import Compose, Resize, ToPILImage | |
from transformers import CLIPProcessor, FlaxCLIPModel | |
from PIL import Image | |
import utils | |
BASELINE_MODEL = "openai/clip-vit-base-patch32" | |
MODEL_PATH = "flax-community/clip-rsicd-v2" | |
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv" | |
IMAGES_DIR = "./images" | |
DEMO_IMAGES_DIR = "./demo-images" | |
def split_image(X, patch_size=224):
    """Split an image array into non-overlapping square tiles.

    The image is tiled row by row starting at the top-left corner. Any
    pixels on the right or bottom edge that do not fill a whole tile are
    discarded.

    Args:
        X: array indexed as (row, column, channel); it must expose
            ``shape`` and support three-axis slicing (e.g. a NumPy array).
        patch_size: edge length of each square tile, in pixels. Defaults
            to 224, the tile size used throughout this demo.

    Returns:
        A tuple ``(num_rows, num_cols, patches)`` where ``patches`` is a
        list of ``num_rows * num_cols`` slices of ``X`` in row-major order.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")

    num_rows = X.shape[0] // patch_size
    num_cols = X.shape[1] // patch_size

    patches = []
    for row in range(num_rows):
        top = row * patch_size
        for col in range(num_cols):
            left = col * patch_size
            # Slices never exceed the last whole tile, so no pre-crop of
            # the image is needed.
            patches.append(X[top:top + patch_size, left:left + patch_size, :])
    return num_rows, num_cols, patches
def get_patch_probabilities(patches, searched_feature,
                            image_preprocesor,
                            model, processor):
    """Score every patch against a text prompt built from the feature name.

    Each patch is passed through ``image_preprocesor``; the results and the
    prompt ``"An aerial image of <searched_feature>"`` are handed to
    ``processor``, whose output is unpacked into ``model``. A softmax is
    taken over the last axis of the model's ``logits_per_text`` and the
    first row is returned.

    Args:
        patches: iterable of image patches.
        searched_feature: feature name (a string) inserted into the prompt.
        image_preprocesor: callable applied to each patch.
        model: callable accepting the processor output as keyword arguments
            and returning an object with a ``logits_per_text`` attribute.
        processor: callable accepting ``images``, ``text``,
            ``return_tensors`` and ``padding`` keyword arguments.

    Returns:
        NumPy array holding the first row of the softmaxed logits;
        presumably one probability per patch (depends on the model's
        output layout).
    """
    prepared_images = list(map(image_preprocesor, patches))
    prompt = "An aerial image of {:s}".format(searched_feature)
    model_inputs = processor(
        images=prepared_images,
        text=prompt,
        return_tensors="jax",
        padding=True,
    )
    logits = model(**model_inputs).logits_per_text
    # Only one prompt is supplied, so only the first row is of interest.
    return np.asarray(jax.nn.softmax(logits, axis=-1))[0]
def get_image_ranks(probs):
    """Return the zero-based rank of each entry when sorted descending.

    The entry with the largest value gets rank 0, the next largest rank 1,
    and so on.

    Args:
        probs: one-dimensional NumPy array of scores.

    Returns:
        Integer NumPy array of the same length; element ``k`` is the rank
        of ``probs[k]``.
    """
    descending_order = np.argsort(-probs)
    ranks = np.empty(descending_order.shape, dtype=descending_order.dtype)
    # Invert the permutation: the item found at sorted position r has rank r.
    ranks[descending_order] = np.arange(len(probs))
    return ranks
def download_and_prepare_image(image_url):
    """Download an image and fit it to a 3x4 grid of 224-pixel tiles.

    The image is converted to RGB, scaled (aspect ratio preserved) until it
    covers the target size, and then cropped from the top-left corner.
    Portrait images (width < height) target 672x896 (width x height); all
    other images target 896x672.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        NumPy array of the cropped RGB image, or ``None`` if the image could
        not be fetched or decoded, or if either side is shorter than 224
        pixels.
    """
    import io
    import logging

    min_side = 224
    short_side, long_side = 672, 896
    try:
        # A timeout keeps an unresponsive host from hanging the app, and
        # the context manager releases the connection.
        with requests.get(image_url, timeout=30) as response:
            response.raise_for_status()
            # ``content`` is decoded (gzip/deflate) and gives PIL a seekable
            # buffer, unlike the undecoded ``raw`` stream.
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

        width, height = image.size
        if width < min_side or height < min_side:
            return None

        if width < height:
            target_w, target_h = short_side, long_side
        else:
            target_w, target_h = long_side, short_side

        # Scale by the larger factor so both sides reach the target; a
        # crop past the image edge would be padded with black pixels.
        scale = max(target_w / width, target_h / height)
        resized_w = max(target_w, round(width * scale))
        resized_h = max(target_h, round(height * scale))
        image = image.resize((resized_w, resized_h))
        image = image.crop((0, 0, target_w, target_h))
        return np.asarray(image)
    except Exception:
        # Deliberately best-effort: the caller shows an error when it gets
        # None, so record the cause instead of crashing the page.
        logging.getLogger(__name__).warning(
            "Could not download or prepare image from %r",
            image_url, exc_info=True)
        return None
def app():
    """Render the "Find Features in Images" Streamlit page.

    Loads the model and processor, shows the introductory text and demo
    images, and collects a sample image (or image URL) plus a feature name.
    When the "Find" button is pressed, the chosen image is split into
    tiles, each tile is scored against the feature, and the tiles are shown
    in a grid captioned with their rank and probability.
    """
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)

    st.title("Find Features in Images")
    st.markdown("""
This demo shows the ability of the model to find specific features
(specified as text queries) in the image. As an example, say you wish to
find the parts of the following image that contain a `beach`, `houses`,
or `ships`. We partition the image into tiles of (224, 224) and report
how likely each of them are to contain each text features.
""")
    st.image(os.path.join(DEMO_IMAGES_DIR, "st_tropez_1.png"))
    st.image(os.path.join(DEMO_IMAGES_DIR, "st_tropez_2.png"))
    st.markdown("""
For this image and the queries listed above, our model reports that the
two left tiles are most likely to contain a `beach`, the two top right
tiles are most likely to contain `houses`, and the two bottom right tiles
are likely to contain `boats`.
We have provided a few representative images from [Unsplash](https://unsplash.com/s/photos/aerial-view)
that you can experiment with. Use the image name to put in an initial feature
to look for, this will show the original image, and you will get more ideas
for features that you can ask the model to identify.
""")

    image_file = st.selectbox(
        "Sample Image File",
        options=[
            "-- select one --",
            "St-Tropez-Port.jpg",
            "Acopulco-Bay.jpg",
            "Highway-through-Forest.jpg",
            "Forest-with-River.jpg",
            "Eagle-Bay-Coastline.jpg",
            "Multistoreyed-Buildings.jpg",
            "Street-View-Malayasia.jpg",
        ])
    image_url = st.text_input(
        "OR provide an image URL",
        value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")
    searched_feature = st.text_input("Feature to find", value="beach")

    if not st.button("Find"):
        return

    # The "-- select one --" placeholder means no sample was picked, so
    # fall back to the URL field.
    if image_file.startswith("--"):
        image = download_and_prepare_image(image_url)
    else:
        image = plt.imread(os.path.join(DEMO_IMAGES_DIR, image_file))

    if image is None:
        st.error("Image could not be downloaded, please try another one")
        return

    st.image(image, caption="Input Image")
    st.markdown("---")

    num_rows, num_cols, patches = split_image(image)
    if not patches:
        # Without this guard an empty batch would be sent to the model.
        st.error("Image is too small to split into tiles, please try another one")
        return

    image_preprocessor = Compose([
        ToPILImage(),
        Resize(224),
    ])
    patch_probs = get_patch_probabilities(
        patches,
        searched_feature,
        image_preprocessor,
        model,
        processor)
    patch_ranks = get_image_ranks(patch_probs)

    # Patches are stored row-major, so the flat index follows from the grid
    # position.
    for row in range(num_rows):
        for offset, col in enumerate(st.columns(num_cols)):
            pid = row * num_cols + offset
            caption = "#{:d} p({:s})={:.3f}".format(
                patch_ranks[pid] + 1, searched_feature, patch_probs[pid])
            col.image(patches[pid], caption=caption)