Akshayram1 committed
Commit 57c929f · verified · 1 Parent(s): 900613f

Update app.py

Files changed (1)
  1. app.py +113 -39
app.py CHANGED
@@ -36,19 +36,113 @@ def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return result[0][len(text):].lstrip("\n")

- # Image Captioning
- def generate_caption(image: PIL.Image.Image) -> str:
-     return infer(image, "caption", max_new_tokens=50)

- # Object Detection
- def detect_objects(image: PIL.Image.Image) -> str:
-     return infer(image, "detect objects", max_new_tokens=200)

- # Visual Question Answering (VQA)
- def vqa(image: PIL.Image.Image, question: str) -> str:
-     return infer(image, f"Q: {question} A:", max_new_tokens=50)

- # Gradio App
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")
@@ -59,43 +153,23 @@ with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
                caption_btn = gr.Button("Generate Caption")
            with gr.Column():
                caption_output = gr.Text(label="Generated Caption")
-         caption_btn.click(fn=generate_caption, inputs=[caption_image], outputs=[caption_output])

-     # Tab 2: Object Detection
-     with gr.Tab("Object Detection"):
        with gr.Row():
            with gr.Column():
                detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
-                 detect_btn = gr.Button("Detect Objects")
-             with gr.Column():
-                 detect_output = gr.Text(label="Detected Objects")
-         detect_btn.click(fn=detect_objects, inputs=[detect_image], outputs=[detect_output])
-
-     # Tab 3: Visual Question Answering (VQA)
-     with gr.Tab("Visual Question Answering"):
-         with gr.Row():
-             with gr.Column():
-                 vqa_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
-                 vqa_question = gr.Text(label="Ask a Question", placeholder="What is in the image?")
-                 vqa_btn = gr.Button("Ask")
-             with gr.Column():
-                 vqa_output = gr.Text(label="Answer")
-         vqa_btn.click(fn=vqa, inputs=[vqa_image, vqa_question], outputs=[vqa_output])
-
-     # Tab 4: Text Generation (Original Feature)
-     with gr.Tab("Text Generation"):
-         with gr.Row():
-             with gr.Column():
-                 text_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
-                 text_input = gr.Text(label="Input Text", placeholder="Describe the image...")
-                 text_btn = gr.Button("Generate Text")
            with gr.Column():
-                 text_output = gr.Text(label="Generated Text")
-         text_btn.click(fn=infer, inputs=[text_image, text_input, gr.Slider(10, 200, value=50)], outputs=[text_output])

# Launch the App
if __name__ == "__main__":
-     demo.queue(max_size=10).launch(debug=True)
 
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return result[0][len(text):].lstrip("\n")

+ # Image Captioning (with user input for improvement)
+ def generate_caption(image: PIL.Image.Image, caption_improvement: str) -> str:
+     return infer(image, f"caption: {caption_improvement}", max_new_tokens=50)

+ # Object Detection/Segmentation
+ def parse_segmentation(input_image, input_text):
+     out = infer(input_image, input_text, max_new_tokens=200)
+     objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+     labels = set(obj.get('name') for obj in objs if obj.get('name'))
+     color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+     highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+     annotated_img = (
+         input_image,
+         [
+             (
+                 obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                 obj['name'] or '',
+             )
+             for obj in objs
+             if 'mask' in obj or 'xyxy' in obj
+         ],
+     )
+     has_annotations = bool(annotated_img[1])
+     return annotated_img

+ # Helper functions for object detection/segmentation
+ def _get_params(checkpoint):
+     def transp(kernel):
+         return np.transpose(kernel, (2, 3, 1, 0))
+
+     def conv(name):
+         return {
+             'bias': checkpoint[name + '.bias'],
+             'kernel': transp(checkpoint[name + '.weight']),
+         }
+
+     def resblock(name):
+         return {
+             'Conv_0': conv(name + '.0'),
+             'Conv_1': conv(name + '.2'),
+             'Conv_2': conv(name + '.4'),
+         }
+
+     return {
+         '_embeddings': checkpoint['_vq_vae._embedding'],
+         'Conv_0': conv('decoder.0'),
+         'ResBlock_0': resblock('decoder.2.net'),
+         'ResBlock_1': resblock('decoder.3.net'),
+         'ConvTranspose_0': conv('decoder.4'),
+         'ConvTranspose_1': conv('decoder.6'),
+         'ConvTranspose_2': conv('decoder.8'),
+         'ConvTranspose_3': conv('decoder.10'),
+         'Conv_1': conv('decoder.12'),
+     }
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
+
+ def extract_objs(text, width, height, unique_labels=False):
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+         seg_indices = gs[4:20]
+         if seg_indices[0] is None:
+             mask = None
+         else:
+             seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+             m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+             m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+             m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+             mask = np.zeros([height, width])
+             if y2 > y1 and x2 > x1:
+                 mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
+ # Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# PaliGemma Multi-Modal App")
    gr.Markdown("Upload an image and explore its features using the PaliGemma model!")
 
        with gr.Row():
            with gr.Column():
                caption_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
+                 caption_improvement_input = gr.Textbox(label="Improvement Input", placeholder="Enter description to improve caption")
                caption_btn = gr.Button("Generate Caption")
            with gr.Column():
                caption_output = gr.Text(label="Generated Caption")
+         caption_btn.click(fn=generate_caption, inputs=[caption_image, caption_improvement_input], outputs=[caption_output])

+     # Tab 2: Segment/Detect
+     with gr.Tab("Segment/Detect"):
        with gr.Row():
            with gr.Column():
                detect_image = gr.Image(type="pil", label="Upload Image", width=512, height=512)
+                 detect_text = gr.Textbox(label="Entities to Detect", placeholder="List entities to segment/detect")
+                 detect_btn = gr.Button("Detect/Segment")
            with gr.Column():
+                 detect_output = gr.AnnotatedImage(label="Annotated Image")
+         detect_btn.click(fn=parse_segmentation, inputs=[detect_image, detect_text], outputs=[detect_output])

# Launch the App
if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)
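Note on names the new code references but this diff does not define: parse_segmentation and extract_objs rely on module-level helpers COLORS, _SEGMENT_DETECT_RE and _get_reconstruct_masks() that must exist elsewhere in app.py. The sketch below is an assumption about what the first two could look like, chosen only so that it is consistent with how extract_objs slices m.groups() (one text prefix, four <loc> box tokens, an optional run of sixteen <seg> mask tokens, and an optional label); the actual values and pattern in the file may differ.

import re

# Assumed palette for colouring one label per detected entity; any list of
# hex strings would satisfy COLORS[i % len(COLORS)] in parse_segmentation.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Assumed pattern for one detection/segmentation span in PaliGemma output:
# optional leading text, four <locXXXX> box tokens, optionally sixteen
# <segXXX> mask tokens, then an optional label -- 22 capture groups in total,
# matching gs.pop(0), gs[:4], gs[4:20] and the final gs.pop() in extract_objs.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

_get_reconstruct_masks() is presumably a factory that loads the VQ-VAE mask-decoder checkpoint, converts it with _get_params, and returns a function mapping a [batch, 16] array of <seg> codebook indices to [batch, 64, 64, 1] masks in [-1, 1], which is why extract_objs rescales with * 0.5 + 0.5. The tuple parse_segmentation returns, (input_image, [(mask_or_box, label), ...]), is the (base image, annotations) format that gr.AnnotatedImage accepts.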