shb777 committed on
Commit
b225623
·
1 Parent(s): 4371bd7
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
@@ -3,7 +3,8 @@ import random
3
  import torch
4
  import hashlib
5
  import gradio as gr
6
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
 
7
 
8
  model_id = "ibm-granite/granite-vision-3.1-2b-preview"
9
  processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
@@ -32,8 +33,8 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversat
32
  user_content.append({"type": "text", "text": text.strip()})
33
 
34
  if not user_content:
35
- return conversation_display(conversation), conversation
36
-
37
  conversation.append({
38
  "role": "user",
39
  "content": user_content
@@ -63,17 +64,32 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversat
63
  generation_kwargs["temperature"] = temperature
64
  generation_kwargs["do_sample"] = True
65
 
66
- output = model.generate(**inputs, **generation_kwargs)
67
- raw_response = processor.decode(output[0], skip_special_tokens=True)
68
- assistant_text = extract_answer(raw_response)
69
-
70
- # Append the assistant's answer.
71
  conversation.append({
72
  "role": "assistant",
73
- "content": [{"type": "text", "text": assistant_text}]
74
  })
75
 
76
- return conversation_display(conversation), conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  def extract_answer(response):
79
  if "<|assistant|>" in response:
@@ -142,8 +158,11 @@ def conversation_display(conversation):
142
  })
143
  return chat_history
144
 
145
- def clear_chat():
146
- return [], [], "", None
 
 
 
147
 
148
  with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
149
  gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
@@ -164,18 +183,19 @@ with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as
164
  send_button = gr.Button("Chat")
165
  clear_button = gr.Button("Clear Chat")
166
 
167
- state = gr.State([])
 
168
 
169
  send_button.click(
170
  chat_inference,
171
- inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
172
- outputs=[chatbot, state]
173
  )
174
 
175
  clear_button.click(
176
  clear_chat,
177
- inputs=None,
178
- outputs=[chatbot, state, text_input, image_input]
179
  )
180
 
181
  gr.Examples(
 
3
  import torch
4
  import hashlib
5
  import gradio as gr
6
+ import threading
7
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
8
 
9
  model_id = "ibm-granite/granite-vision-3.1-2b-preview"
10
  processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
 
33
  user_content.append({"type": "text", "text": text.strip()})
34
 
35
  if not user_content:
36
+ return conversation_display(conversation), conversation, "", False
37
+
38
  conversation.append({
39
  "role": "user",
40
  "content": user_content
 
64
  generation_kwargs["temperature"] = temperature
65
  generation_kwargs["do_sample"] = True
66
 
 
 
 
 
 
67
  conversation.append({
68
  "role": "assistant",
69
+ "content": [{"type": "text", "text": ""}]
70
  })
71
 
72
+ yield conversation_display(conversation), conversation, "Processing...", True
73
+
74
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
75
+ generation_kwargs["streamer"] = streamer
76
+
77
+ def generate_thread():
78
+ model.generate(**inputs, **generation_kwargs)
79
+
80
+ thread = threading.Thread(target=generate_thread)
81
+ thread.start()
82
+
83
+ assistant_text = ""
84
+ for new_text in streamer:
85
+ assistant_text += new_text
86
+ conversation[-1]["content"][0]["text"] = extract_answer(assistant_text)
87
+ yield conversation_display(conversation), conversation, "Processing...", True
88
+
89
+ thread.join()
90
+
91
+ yield conversation_display(conversation), conversation, "", False
92
+ return
93
 
94
  def extract_answer(response):
95
  if "<|assistant|>" in response:
 
158
  })
159
  return chat_history
160
 
161
+ def clear_chat(chat_history, conversation, text_value, image, is_generating):
162
+ if is_generating:
163
+ return chat_history, conversation, text_value, image, is_generating
164
+ else:
165
+ return [], [], "", None, is_generating
166
 
167
  with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
168
  gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
 
183
  send_button = gr.Button("Chat")
184
  clear_button = gr.Button("Clear Chat")
185
 
186
+ conversation_state = gr.State([])
187
+ is_generating = gr.State(False)
188
 
189
  send_button.click(
190
  chat_inference,
191
+ inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, conversation_state],
192
+ outputs=[chatbot, conversation_state, text_input, is_generating]
193
  )
194
 
195
  clear_button.click(
196
  clear_chat,
197
+ inputs=[chatbot, conversation_state, text_input, image_input, is_generating],
198
+ outputs=[chatbot, conversation_state, text_input, image_input, is_generating]
199
  )
200
 
201
  gr.Examples(