Canyu committed on
Commit
cd0e33c
·
1 Parent(s): a51380e
Files changed (1) hide show
  1. app.py +139 -43
app.py CHANGED
@@ -20,8 +20,28 @@ class Examples(gr.helpers.Examples):
20
  self.create()
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # When the user clicks the image, record the point and draw it on the image
24
  def get_point(img, sel_pix, evt: gr.SelectData):
 
25
  if len(sel_pix) < 5:
26
  sel_pix.append((evt.index, 1)) # default foreground_point
27
  img = cv2.imread(img)
@@ -54,11 +74,11 @@ def undo_points(orig_img, sel_pix):
54
  return temp, sel_pix
55
 
56
 
57
- # HF_TOKEN = os.environ.get('HF_KEY')
58
 
59
- # client = Client("Canyu/Diception",
60
- # max_workers=3,
61
- # hf_token=HF_TOKEN)
62
 
63
  colors = [(255, 0, 0), (0, 255, 0)]
64
  markers = [1, 5]
@@ -89,12 +109,6 @@ def load_additional_params(model_name):
89
  return additional_params
90
 
91
  def process_image_check(path_input, prompt, sel_points, semantic):
92
- print('=========== PROCESS IMAGE CHECK ===========')
93
- print(f"Image Path: {path_input}")
94
- print(f"Prompt: {prompt}")
95
- print(f"Selected Points (before processing): {sel_points}")
96
- print(f"Semantic Input: {semantic}")
97
- print('===========================================')
98
  if path_input is None:
99
  raise gr.Error(
100
  "Missing image in the left pane: please upload an image first."
@@ -103,23 +117,6 @@ def process_image_check(path_input, prompt, sel_points, semantic):
103
  raise gr.Error(
104
  "At least 1 prediction type is needed."
105
  )
106
- if 'point segmentation' in prompt and len(sel_points) == 0:
107
- raise gr.Error(
108
- "At least 1 point is needed."
109
- )
110
- if 'point segmentation' not in prompt and len(sel_points) != 0:
111
- raise gr.Error(
112
- "You must select 'point segmentation' when performing point segmentation."
113
- )
114
-
115
- if 'semantic segmentation' in prompt and semantic == None:
116
- raise gr.Error(
117
- "Target category is needed."
118
- )
119
- if 'semantic segmentation' not in prompt and semantic != None:
120
- raise gr.Error(
121
- "You must select 'semantic segmentation' when performing semantic segmentation."
122
- )
123
 
124
 
125
 
@@ -146,14 +143,51 @@ def process_image_4(image_path, prompt):
146
 
147
 
148
  def inf(image_path, prompt, sel_points, semantic):
149
- inputs = process_image_4(image_path, prompt, sel_points, semantic)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # return None
151
- return client.predict(
152
- image=handle_file(image_path),
153
- data=inputs,
 
 
 
 
 
 
154
  api_name="/inf"
155
  )
156
 
 
 
 
157
  def clear_cache():
158
  return None, None
159
 
@@ -162,18 +196,76 @@ def run_demo_server():
162
  gradio_theme = gr.themes.Default()
163
  with gr.Blocks(
164
  theme=gradio_theme,
165
- title="Matting",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  ) as demo:
167
  selected_points = gr.State([]) # store points
168
  original_image = gr.State(value=None) # store original image without points, default None
169
- with gr.Row():
170
- gr.Markdown("# Diception Demo")
171
- with gr.Row():
172
- gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  with gr.Row():
174
  checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
175
  with gr.Row():
176
  semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
 
 
177
  with gr.Row():
178
  with gr.Column():
179
  input_image = gr.Image(
@@ -184,20 +276,22 @@ def run_demo_server():
184
  with gr.Column():
185
  with gr.Row():
186
  gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
187
- undo_button = gr.Button('Undo point')
188
 
189
- with gr.Row():
190
  matting_image_submit_btn = gr.Button(
191
- value="Estimate Matting", variant="primary"
192
  )
 
 
 
193
  matting_image_reset_btn = gr.Button(value="Reset")
194
 
195
- with gr.Row():
196
- img_clear_button = gr.Button("Clear Cache")
197
 
198
  with gr.Column():
199
  # matting_image_output = gr.Image(label='Output')
200
- matting_image_output = gr.Image(label='Matting Output')
 
201
 
202
  # label="Matting Output",
203
  # type="filepath",
@@ -210,7 +304,7 @@ def run_demo_server():
210
 
211
 
212
 
213
- img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])
214
 
215
  matting_image_submit_btn.click(
216
  fn=process_image_check,
@@ -230,11 +324,13 @@ def run_demo_server():
230
  fn=lambda: (
231
  None,
232
  None,
 
233
  ),
234
  inputs=[],
235
  outputs=[
236
  input_image,
237
  matting_image_output,
 
238
  ],
239
  queue=False,
240
  )
 
20
  self.create()
21
 
22
 
23
def postprocess(output, prompt):
    """Split a result image into one captioned vertical strip per prompt entry.

    The image opened from ``output`` is cut into ``len(prompt)`` strips of
    equal width, left to right. The last strip extends to the right edge so
    that remainder pixels from the integer division are not dropped. Strip
    ``i`` is paired with ``prompt[i]`` as its caption.

    Args:
        output: Anything ``Image.open`` accepts (e.g. a file path).
        prompt: Sequence of captions, one per strip.

    Returns:
        A list of ``(cropped_image, caption)`` tuples, in prompt order.
        Empty when ``prompt`` is empty.
    """
    n = len(prompt)
    if n == 0:
        # Nothing to caption; also avoids dividing the width by zero below.
        return []

    result = []
    # Crop while the file is open so the source handle is released afterwards.
    # NOTE(review): relies on Image.crop() producing independent images
    # (not lazy views) — true for current Pillow; confirm for the pinned version.
    with Image.open(output) as image:
        w, h = image.size
        slice_width = w // n
        for i, caption in enumerate(prompt):
            left = i * slice_width
            # The final strip absorbs any leftover pixels from w // n.
            right = (i + 1) * slice_width if i < n - 1 else w
            result.append((image.crop((left, 0, right, h)), caption))
    return result
41
+
42
  # When the user clicks the image, record the point and draw it on the image
43
  def get_point(img, sel_pix, evt: gr.SelectData):
44
+ print(sel_pix)
45
  if len(sel_pix) < 5:
46
  sel_pix.append((evt.index, 1)) # default foreground_point
47
  img = cv2.imread(img)
 
74
  return temp, sel_pix
75
 
76
 
77
+ HF_TOKEN = os.environ.get('HF_KEY')
78
 
79
+ client = Client("Canyu/Diception",
80
+ max_workers=3,
81
+ hf_token=HF_TOKEN)
82
 
83
  colors = [(255, 0, 0), (0, 255, 0)]
84
  markers = [1, 5]
 
109
  return additional_params
110
 
111
  def process_image_check(path_input, prompt, sel_points, semantic):
 
 
 
 
 
 
112
  if path_input is None:
113
  raise gr.Error(
114
  "Missing image in the left pane: please upload an image first."
 
117
  raise gr.Error(
118
  "At least 1 prediction type is needed."
119
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
 
 
143
 
144
 
145
  def inf(image_path, prompt, sel_points, semantic):
146
+ print('=========== PROCESS IMAGE CHECK ===========')
147
+ print(f"Image Path: {image_path}")
148
+ print(f"Prompt: {prompt}")
149
+ print(f"Selected Points (before processing): {sel_points}")
150
+ print(f"Semantic Input: {semantic}")
151
+ print('===========================================')
152
+
153
+ if 'point segmentation' in prompt and len(sel_points) == 0:
154
+ raise gr.Error(
155
+ "At least 1 point is needed."
156
+ )
157
+ return
158
+ if 'point segmentation' not in prompt and len(sel_points) != 0:
159
+ raise gr.Error(
160
+ "You must select 'point segmentation' when performing point segmentation."
161
+ )
162
+ return
163
+
164
+ if 'semantic segmentation' in prompt and semantic == '':
165
+ raise gr.Error(
166
+ "Target category is needed."
167
+ )
168
+ return
169
+ if 'semantic segmentation' not in prompt and semantic != '':
170
+ raise gr.Error(
171
+ "You must select 'semantic segmentation' when performing semantic segmentation."
172
+ )
173
+ return
174
+
175
  # return None
176
+ # inputs = process_image_4(image_path, prompt, sel_points, semantic)
177
+
178
+ prompt_str = str(sel_points)
179
+
180
+ result = client.predict(
181
+ input_image=handle_file(image_path),
182
+ checkbox_group=prompt,
183
+ selected_points=prompt_str,
184
+ semantic_input=semantic,
185
  api_name="/inf"
186
  )
187
 
188
+ result = postprocess(result, prompt)
189
+ return result
190
+
191
  def clear_cache():
192
  return None, None
193
 
 
196
  gradio_theme = gr.themes.Default()
197
  with gr.Blocks(
198
  theme=gradio_theme,
199
+ title="Diception",
200
+ css="""
201
+ #download {
202
+ height: 118px;
203
+ }
204
+ .slider .inner {
205
+ width: 5px;
206
+ background: #FFF;
207
+ }
208
+ .viewport {
209
+ aspect-ratio: 4/3;
210
+ }
211
+ .tabs button.selected {
212
+ font-size: 20px !important;
213
+ color: crimson !important;
214
+ }
215
+ h1 {
216
+ text-align: center;
217
+ display: block;
218
+ }
219
+ h2 {
220
+ text-align: center;
221
+ display: block;
222
+ }
223
+ h3 {
224
+ text-align: center;
225
+ display: block;
226
+ }
227
+ .md_feedback li {
228
+ margin-bottom: 0px !important;
229
+ }
230
+ """,
231
+ head="""
232
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
233
+ <script>
234
+ window.dataLayer = window.dataLayer || [];
235
+ function gtag() {dataLayer.push(arguments);}
236
+ gtag('js', new Date());
237
+ gtag('config', 'G-1FWSVCGZTG');
238
+ </script>
239
+ """,
240
  ) as demo:
241
  selected_points = gr.State([]) # store points
242
  original_image = gr.State(value=None) # store original image without points, default None
243
+ gr.Markdown(
244
+ """
245
+ # DICEPTION: A Generalist Diffusion Model for Vision Perception
246
+ <p align="center">
247
+ <a title="arXiv" href="https://arxiv.org" target="_blank" rel="noopener noreferrer"
248
+ style="display: inline-block;">
249
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
250
+ </a>
251
+ <a title="Github" href="https://github.com/aim-uofa/Diception" target="_blank" rel="noopener noreferrer"
252
+ style="display: inline-block;">
253
+ <img src="https://img.shields.io/github/stars/aim-uofa/GenPercept?label=GitHub%20%E2%98%85&logo=github&color=C8C"
254
+ alt="badge-github-stars">
255
+ </a>
256
+ </p>
257
+ <p align="justify">
258
+ One single model solves multiple perception tasks, producing impressive results!
259
+ </p>
260
+ """
261
+ )
262
+
263
  with gr.Row():
264
  checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
265
  with gr.Row():
266
  semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
267
+ with gr.Row():
268
+ gr.Markdown('For non-human image inputs, the pose results may have issues. Same when perform semantic segmentation with categories that are not in COCO.')
269
  with gr.Row():
270
  with gr.Column():
271
  input_image = gr.Image(
 
276
  with gr.Column():
277
  with gr.Row():
278
  gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
 
279
 
 
280
  matting_image_submit_btn = gr.Button(
281
+ value="Run", variant="primary"
282
  )
283
+
284
+ with gr.Row():
285
+ undo_button = gr.Button('Undo point')
286
  matting_image_reset_btn = gr.Button(value="Reset")
287
 
288
+ # with gr.Row():
289
+ # img_clear_button = gr.Button("Clear Cache")
290
 
291
  with gr.Column():
292
  # matting_image_output = gr.Image(label='Output')
293
+ # matting_image_output = gr.Image(label='Results')
294
+ matting_image_output = gr.Gallery(label="Results")
295
 
296
  # label="Matting Output",
297
  # type="filepath",
 
304
 
305
 
306
 
307
+ # img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])
308
 
309
  matting_image_submit_btn.click(
310
  fn=process_image_check,
 
324
  fn=lambda: (
325
  None,
326
  None,
327
+ []
328
  ),
329
  inputs=[],
330
  outputs=[
331
  input_image,
332
  matting_image_output,
333
+ selected_points
334
  ],
335
  queue=False,
336
  )