andsteing committed
Commit ab79e7e · 1 Parent(s): ab808e5

Minimal version with lit-tuning-demo data.

Files changed (3):
  1. README.md +17 -1
  2. app.py +102 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -10,4 +10,20 @@ pinned: false
 license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+Simple space for matching texts to images with a contrastive model.
+
+Matching Colab:
+https://colab.research.google.com/drive/1f5MpJgE0XCU8ElT34uK4kTUkPnUqvJUt
+
+
+Local development:
+
+1. `pyenv version 3.10.0`
+2. `pip install virtualenv`
+3. `python -m virtualenv env`
+4. `. env/bin/activate`
+5. `pip install -r requirements.txt`
+6. `pip install gradio`
+7. `python app.py`
+
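After these steps, a quick sanity check that does not launch the UI is to call `generate_answers` from app.py directly. A minimal sketch, assuming the virtualenv above is active; the image path is a placeholder:

```python
# Smoke test for the local setup (run inside the activated env).
# 'cat.jpg' is a hypothetical path; substitute any local image.
from app import generate_answers

print(generate_answers('cat.jpg', 'a photo of a cat, a photo of a dog'))
# Returns a list of (prompt, probability) tuples, one per comma-separated prompt.
```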
app.py ADDED
@@ -0,0 +1,102 @@
+import contextlib
+import functools
+import json
+import logging
+import os
+import time
+import urllib.request
+
+import gradio as gr
+import open_clip  # works on open-clip-torch>=2.23.0, timm>=0.9.8
+import PIL.Image
+import torch
+import torch.nn.functional as F
+
+
+INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
+IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
+
+
+@contextlib.contextmanager
+def timed(name):
+    t0 = time.monotonic()
+    try:
+        yield
+    finally:
+        logging.info('Timed %s: %.1f secs', name, time.monotonic() - t0)
+
+
+@functools.cache
+def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
+    with timed('loading model, preprocess, tokenizer'):
+        t0 = time.time()
+        model, preprocess = open_clip.create_model_from_pretrained(name)
+        tokenizer = open_clip.get_tokenizer(name)
+        logging.info('loaded in %.1fs', time.time() - t0)
+    return model, preprocess, tokenizer
+
+
+def generate_answers(image_path, prompts):
+
+    model, preprocess, tokenizer = load_model()
+
+    with torch.no_grad(), torch.cuda.amp.autocast():
+        logging.info('Opening image "%s"', image_path)
+        with timed(f'opening image "{image_path}"'):
+            image = PIL.Image.open(image_path)
+        with timed('image features'):
+            image = preprocess(image).unsqueeze(0)
+            image_features = model.encode_image(image)
+
+        with timed('text features'):
+            prompts = prompts.split(', ')
+            text = tokenizer(prompts, context_length=model.context_length)
+            text_features = model.encode_text(text)
+        image_features = F.normalize(image_features, dim=-1)
+        text_features = F.normalize(text_features, dim=-1)
+
+        exp, bias = model.logit_scale.exp(), model.logit_bias
+        text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)
+    return list(zip(prompts, [round(p.item(), 3) for p in text_probs[0]]))
+
+
+def create_app():
+    info = json.load(urllib.request.urlopen(INFO_URL))
+
+    with gr.Blocks() as demo:
+
+        gr.Markdown('Minimal gradio clone of [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/)')
+        gr.Markdown('Using `open_clip` implementation of SigLIP model `timm/ViT-SO400M-14-SigLIP-384`')
+
+        with gr.Row():
+            image = gr.Image(label='input_image', type='filepath')
+            with gr.Column():
+                prompts = gr.Textbox(label='prompts')
+                answer = gr.Textbox(label='answer')
+                run = gr.Button('Run')
+
+        gr.Examples(
+            examples=[
+                [IMG_URL_FMT.format(ex['id']), ex['prompts']]
+                for ex in info
+            ],
+            inputs=[image, prompts],
+            outputs=[answer],
+        )
+
+        run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])
+
+    return demo
+
+
+if __name__ == "__main__":
+
+    logging.basicConfig(level=logging.INFO,
+                        format='%(asctime)s - %(levelname)s - %(message)s')
+
+    for k, v in os.environ.items():
+        logging.info('environ["%s"] = %r', k, v)
+
+    _ = load_model()
+
+    create_app().queue().launch()
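The scoring at the end of `generate_answers` follows SigLIP's sigmoid formulation rather than CLIP's softmax, so each prompt gets an independent probability. A minimal sketch of just that step, with random tensors standing in for the encoder outputs (the 1152-dim feature width and the scale/bias values are placeholders, not values from the checkpoint):

```python
import torch
import torch.nn.functional as F

# Stand-ins for model.encode_image / model.encode_text outputs.
image_features = F.normalize(torch.randn(1, 1152), dim=-1)  # one image
text_features = F.normalize(torch.randn(3, 1152), dim=-1)   # three prompts

# Placeholders for the learned model.logit_scale.exp() and model.logit_bias.
logit_scale, logit_bias = torch.tensor(100.0), torch.tensor(-10.0)

# Sigmoid gives each prompt an independent probability in [0, 1];
# unlike softmax, the scores need not sum to 1 across prompts.
text_probs = torch.sigmoid(image_features @ text_features.T * logit_scale + logit_bias)
print(text_probs.shape)  # torch.Size([1, 3])
```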
requirements.txt ADDED
@@ -0,0 +1 @@
+open-clip-torch
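The UI examples in app.py are built from the lit-tuning-demo image metadata mentioned in the commit message. A short sketch of fetching that metadata the same way `create_app` does; the `'id'` and `'prompts'` fields are assumed from how the examples list is constructed:

```python
import json
import urllib.request

INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'

# Fetch the example metadata used by create_app's gr.Examples.
info = json.load(urllib.request.urlopen(INFO_URL))
for ex in info[:3]:
    # Each entry is expected to provide an image id and a prompts string.
    print(IMG_URL_FMT.format(ex['id']), '->', ex['prompts'])
```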