Antoine Chaffin commited on
Commit
349b5c2
·
1 Parent(s): d7e0b8c

Initial commit

Browse files
Files changed (4) hide show
  1. app.py +106 -0
  2. model.py +118 -0
  3. requirements.txt +9 -0
  4. voyager_index.py +221 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from qwen_vl_utils import process_vision_info
6
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
7
+ from voyager_index import Voyager
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ device = "cpu"
11
+
12
+ # Initialize the model and processor
13
+ model = (
14
+ Qwen2VLForConditionalGeneration.from_pretrained(
15
+ "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
16
+ )
17
+ .to(device)
18
+ .eval()
19
+ )
20
+
21
+ processor = AutoProcessor.from_pretrained(
22
+ "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True
23
+ )
24
+
25
+
26
+ def create_index(session_id):
27
+ return Voyager(embedding_size=1536, override=True, index_name=f"{session_id}")
28
+
29
+
30
+ def add_to_index(files, index):
31
+ index.add_documents([file.name for file in files], batch_size=1)
32
+ return f"Added {len(files)} files to the index."
33
+
34
+
35
+ def query_index(query, index):
36
+ res = index(query, k=1)
37
+ retrieved_image = res["documents"][0][0]["image"]
38
+
39
+ messages = [
40
+ {
41
+ "role": "user",
42
+ "content": [
43
+ {
44
+ "type": "image",
45
+ "image": retrieved_image,
46
+ },
47
+ {"type": "text", "text": query},
48
+ ],
49
+ }
50
+ ]
51
+ text = processor.apply_chat_template(
52
+ messages, tokenize=False, add_generation_prompt=True
53
+ )
54
+
55
+ image_inputs, video_inputs = process_vision_info(messages)
56
+ inputs = processor(
57
+ text=[text],
58
+ images=image_inputs,
59
+ videos=video_inputs,
60
+ padding=True,
61
+ return_tensors="pt",
62
+ )
63
+ inputs = inputs.to(device)
64
+ generated_ids = model.generate(**inputs, max_new_tokens=200)
65
+ generated_ids_trimmed = [
66
+ out_ids[len(in_ids) :]
67
+ for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
68
+ ]
69
+ output_text = processor.batch_decode(
70
+ generated_ids_trimmed,
71
+ skip_special_tokens=True,
72
+ clean_up_tokenization_spaces=False,
73
+ )
74
+
75
+ return output_text[0], retrieved_image
76
+
77
+
78
+ # Define the Gradio interface
79
+ with gr.Blocks() as demo:
80
+ session_id = gr.State(lambda: str(uuid.uuid4()))
81
+ index = gr.State(lambda: create_index(session_id.value))
82
+
83
+ gr.Markdown("# Full vision pipeline demo")
84
+
85
+ with gr.Tab("Add to Index"):
86
+ file_input = gr.File(file_count="multiple", label="Upload Files")
87
+ add_button = gr.Button("Add to Index")
88
+ add_output = gr.Textbox(label="Result")
89
+
90
+ add_button.click(add_to_index, inputs=[file_input, index], outputs=add_output)
91
+
92
+ with gr.Tab("Query Index"):
93
+ query_input = gr.Textbox(label="Enter your query")
94
+ query_button = gr.Button("Submit Query")
95
+ with gr.Row():
96
+ query_output = gr.Textbox(label="Answer")
97
+ image_output = gr.Image(label="Retrieved Image")
98
+
99
+ query_button.click(
100
+ query_index,
101
+ inputs=[query_input, index],
102
+ outputs=[query_output, image_output],
103
+ )
104
+
105
+ # Launch the interface
106
+ demo.launch()
model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from qwen_vl_utils import process_vision_info
4
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ # device = "cpu"
8
+
9
+ min_pixels = 1 * 28 * 28
10
+ max_pixels = 256 * 28 * 28 # 2560 * 28 * 28
11
+
12
+
13
+ processor = AutoProcessor.from_pretrained(
14
+ "MrLight/dse-qwen2-2b-mrl-v1", min_pixels=min_pixels, max_pixels=max_pixels
15
+ )
16
+ model = (
17
+ Qwen2VLForConditionalGeneration.from_pretrained(
18
+ "MrLight/dse-qwen2-2b-mrl-v1",
19
+ # attn_implementation="eager",
20
+ attn_implementation="flash_attention_2"
21
+ if device == "cuda"
22
+ else "eager", # flash_attn is required but is a pain to install on spaces
23
+ torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
24
+ )
25
+ .to(device)
26
+ .eval()
27
+ )
28
+ processor.tokenizer.padding_side = "left"
29
+ model.padding_side = "left"
30
+
31
+
32
+ def get_embedding(last_hidden_state: torch.Tensor, dimension: int):
33
+ reps = last_hidden_state[:, -1]
34
+ reps = torch.nn.functional.normalize(reps[:, :dimension], p=2, dim=-1)
35
+ return reps.to(torch.float32).cpu().numpy()
36
+
37
+
38
+ def encode_queries(queries: list):
39
+ if isinstance(queries, str):
40
+ queries = [queries]
41
+ query_messages = []
42
+ for query in queries:
43
+ message = [
44
+ {
45
+ "role": "user",
46
+ "content": [
47
+ {
48
+ "type": "image",
49
+ "image": Image.new("RGB", (28, 28)),
50
+ "resized_height": 1,
51
+ "resized_width": 1,
52
+ }, # need a dummy image here for an easier process.
53
+ {"type": "text", "text": f"Query: {query}"},
54
+ ],
55
+ }
56
+ ]
57
+ query_messages.append(message)
58
+ query_texts = [
59
+ processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
60
+ + "<|endoftext|>"
61
+ for msg in query_messages
62
+ ]
63
+ query_image_inputs, query_video_inputs = process_vision_info(query_messages)
64
+ query_inputs = processor(
65
+ text=query_texts,
66
+ images=query_image_inputs,
67
+ videos=query_video_inputs,
68
+ padding="longest",
69
+ return_tensors="pt",
70
+ ).to(device)
71
+ query_inputs = model.prepare_inputs_for_generation(**query_inputs, use_cache=False)
72
+ with torch.no_grad():
73
+ output = model(**query_inputs, return_dict=True, output_hidden_states=True)
74
+ query_embeddings = get_embedding(
75
+ output.hidden_states[-1], 1536
76
+ ) # adjust dimensionality for efficiency trade-off, e.g. 512
77
+ return query_embeddings
78
+
79
+
80
+ def encode_images(images: list):
81
+ if isinstance(images, Image.Image):
82
+ images = [images]
83
+ doc_messages = []
84
+ for image in images:
85
+ message = [
86
+ {
87
+ "role": "user",
88
+ "content": [
89
+ {
90
+ "type": "image",
91
+ "image": image,
92
+ }, #'resized_height':680 , 'resized_width':680} # adjust the image size for efficiency trade-off
93
+ {"type": "text", "text": "What is shown in this image?"},
94
+ ],
95
+ }
96
+ ]
97
+ doc_messages.append(message)
98
+ doc_texts = [
99
+ processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
100
+ + "<|endoftext|>"
101
+ for msg in doc_messages
102
+ ]
103
+ doc_image_inputs, doc_video_inputs = process_vision_info(doc_messages)
104
+ doc_inputs = processor(
105
+ text=doc_texts,
106
+ images=doc_image_inputs,
107
+ videos=doc_video_inputs,
108
+ padding="longest",
109
+ return_tensors="pt",
110
+ ).to(device)
111
+ doc_inputs = model.prepare_inputs_for_generation(**doc_inputs, use_cache=False)
112
+ output = model(**doc_inputs, return_dict=True, output_hidden_states=True)
113
+ with torch.no_grad():
114
+ output = model(**doc_inputs, return_dict=True, output_hidden_states=True)
115
+ doc_embeddings = get_embedding(
116
+ output.hidden_states[-1], 1536
117
+ ) # adjust dimensionality for efficiency trade-off e.g. 512
118
+ return doc_embeddings
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ git+https://github.com/huggingface/transformers.git@refs/pull/33654/head#egg=transformers #git+https://github.com/huggingface/transformers #transformers
4
+ qwen-vl-utils
5
+ gradio
6
+ pypdfium2
7
+ # flash_attn # https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch1.12cxx11abiFALSE-cp310-cp310-linux_x86_64.whl #flash_attn
8
+ sqlitedict
9
+ voyager
voyager_index.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import pypdfium2 as pdfium
5
+ import torch
6
+ import tqdm
7
+ from model import encode_images, encode_queries
8
+ from PIL import Image
9
+ from sqlitedict import SqliteDict
10
+ from voyager import Index, Space
11
+
12
+
13
+ def iter_batch(
14
+ X: list[str], batch_size: int, tqdm_bar: bool = True, desc: str = ""
15
+ ) -> list:
16
+ """Iterate over a list of elements by batch."""
17
+ batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]
18
+
19
+ if tqdm_bar:
20
+ for batch in tqdm.tqdm(
21
+ iterable=batchs,
22
+ position=0,
23
+ total=1 + len(X) // batch_size,
24
+ desc=desc,
25
+ ):
26
+ yield batch
27
+ else:
28
+ yield from batchs
29
+
30
+
31
+ class Voyager:
32
+ """Voyager index. The Voyager index is a fast and efficient index for approximate nearest neighbor search.
33
+
34
+ Parameters
35
+ ----------
36
+ name
37
+ The name of the collection.
38
+ override
39
+ Whether to override the collection if it already exists.
40
+ embedding_size
41
+ The number of dimensions of the embeddings.
42
+ M
43
+ The number of subquantizers.
44
+ ef_construction
45
+ The number of candidates to evaluate during the construction of the index.
46
+ ef_search
47
+ The number of candidates to evaluate during the search.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ index_folder: str = "indexes",
53
+ index_name: str = "base_collection",
54
+ override: bool = False,
55
+ embedding_size: int = 128,
56
+ M: int = 64,
57
+ ef_construction: int = 200,
58
+ ef_search: int = 200,
59
+ ) -> None:
60
+ self.ef_search = ef_search
61
+
62
+ if not os.path.exists(path=index_folder):
63
+ os.makedirs(name=index_folder)
64
+
65
+ self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
66
+ self.page_ids_to_data_path = os.path.join(
67
+ index_folder, f"{index_name}_page_ids_to_data.sqlite"
68
+ )
69
+
70
+ self.index = self._create_collection(
71
+ index_path=self.index_path,
72
+ embedding_size=embedding_size,
73
+ M=M,
74
+ ef_constructions=ef_construction,
75
+ override=override,
76
+ )
77
+
78
+ def _load_page_ids_to_data(self) -> SqliteDict:
79
+ """Load the SQLite database that maps document IDs to images."""
80
+ return SqliteDict(self.page_ids_to_data_path, outer_stack=False)
81
+
82
+ def _create_collection(
83
+ self,
84
+ index_path: str,
85
+ embedding_size: int,
86
+ M: int,
87
+ ef_constructions: int,
88
+ override: bool,
89
+ ) -> None:
90
+ """Create a new Voyager collection.
91
+
92
+ Parameters
93
+ ----------
94
+ index_path
95
+ The path to the index.
96
+ embedding_size
97
+ The size of the embeddings.
98
+ M
99
+ The number of subquantizers.
100
+ ef_constructions
101
+ The number of candidates to evaluate during the construction of the index.
102
+ override
103
+ Whether to override the collection if it already exists.
104
+
105
+ """
106
+ if os.path.exists(path=index_path) and not override:
107
+ return Index.load(index_path)
108
+
109
+ if os.path.exists(path=index_path):
110
+ os.remove(index_path)
111
+
112
+ # Create the Voyager index
113
+ index = Index(
114
+ Space.Cosine,
115
+ num_dimensions=embedding_size,
116
+ M=M,
117
+ ef_construction=ef_constructions,
118
+ )
119
+
120
+ index.save(index_path)
121
+
122
+ if override and os.path.exists(path=self.page_ids_to_data_path):
123
+ os.remove(path=self.page_ids_to_data_path)
124
+
125
+ # Create the SQLite databases
126
+ page_ids_to_data = self._load_page_ids_to_data()
127
+ page_ids_to_data.close()
128
+ return index
129
+
130
+ def add_documents(
131
+ self,
132
+ paths: str | list[str],
133
+ batch_size: int = 1,
134
+ ) -> None:
135
+ """Add documents to the index. Note that batch_size means the number of pages to encode at once, not documents."""
136
+ if isinstance(paths, str):
137
+ paths = [paths]
138
+
139
+ page_ids_to_data = self._load_page_ids_to_data()
140
+
141
+ images = []
142
+ num_pages = []
143
+
144
+ for path in paths:
145
+ if path.lower().endswith(".pdf"):
146
+ pdf = pdfium.PdfDocument(path)
147
+ n_pages = len(pdf)
148
+ num_pages.append(n_pages)
149
+ for page_number in range(n_pages):
150
+ page = pdf.get_page(page_number)
151
+ pil_image = page.render(
152
+ scale=1,
153
+ rotation=0,
154
+ )
155
+ pil_image = pil_image.to_pil()
156
+ images.append(pil_image)
157
+ pdf.close()
158
+ else:
159
+ pil_image = Image.open(path)
160
+ images.append(pil_image)
161
+ num_pages.append(1)
162
+
163
+ embeddings = []
164
+ for batch in iter_batch(
165
+ X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
166
+ ):
167
+ embeddings.extend(encode_images(batch))
168
+
169
+ embeddings_ids = self.index.add_items(embeddings)
170
+ current_index = 0
171
+
172
+ for i, path in enumerate(paths):
173
+ for page_number in range(num_pages[i]):
174
+ page_ids_to_data[embeddings_ids[current_index]] = {
175
+ "path": path,
176
+ "image": images[current_index],
177
+ "page_number": page_number,
178
+ }
179
+ current_index += 1
180
+
181
+ page_ids_to_data.commit()
182
+ self.index.save(self.index_path)
183
+
184
+ return self
185
+
186
+ def __call__(
187
+ self,
188
+ queries: np.ndarray | torch.Tensor,
189
+ k: int = 10,
190
+ ) -> dict:
191
+ """Query the index for the nearest neighbors of the queries embeddings.
192
+
193
+ Parameters
194
+ ----------
195
+ queries_embeddings
196
+ The queries embeddings.
197
+ k
198
+ The number of nearest neighbors to return.
199
+
200
+ """
201
+
202
+ queries_embeddings = encode_queries(queries)
203
+ page_ids_to_data = self._load_page_ids_to_data()
204
+ k = min(k, len(page_ids_to_data))
205
+
206
+ n_queries = len(queries_embeddings)
207
+ indices, distances = self.index.query(
208
+ queries_embeddings, k, query_ef=self.ef_search
209
+ )
210
+
211
+ if len(indices) == 0:
212
+ raise ValueError("Index is empty, add documents before querying.")
213
+ documents = [
214
+ [page_ids_to_data[str(indice)] for indice in query_indices]
215
+ for query_indices in indices
216
+ ]
217
+ page_ids_to_data.close()
218
+ return {
219
+ "documents": documents,
220
+ "distances": distances.reshape(n_queries, -1, k),
221
+ }