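"""Gradio Space: streaming chat demo for ibm-granite/granite-vision-3.1-2b-preview."""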
import spaces
import random
import torch
import hashlib
import gradio as gr
import threading
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
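# Load the processor and model once at import time; device_map="auto" places
# the weights on the available accelerator.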
model_id = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
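# spaces.GPU requests a GPU slot (ZeroGPU) for the duration of each call when
# running on Hugging Face Spaces.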
@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    # Start a fresh conversation with the system prompt when none exists yet.
    if not conversation:
        conversation = [{
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}]
        }]
    # Build the user turn from the optional image and/or text.
    user_content = []
    if image is not None:
        # Downscale large uploads in place; thumbnail() preserves aspect ratio.
        if image.width > 512 or image.height > 512:
            image.thumbnail((512, 512))
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        # Nothing to send: leave the conversation unchanged. This is a generator,
        # so state must be emitted with yield rather than return.
        yield conversation_display(conversation), conversation, "", False
        return
    conversation.append({
        "role": "user",
        "content": user_content
    })
    conversation = preprocess_conversation(conversation)
    # Generate the input prompt using the chat template.
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")
    torch.manual_seed(random.randint(0, 10000))
    # Sample only when temperature is positive; fall back to greedy decoding at 0.
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        generation_kwargs["temperature"] = temperature
    # Append an empty assistant turn that the streaming loop below fills in.
    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": ""}]
    })
    yield conversation_display(conversation), conversation, "Processing...", True
    # Run generation on a background thread and stream partial text as it arrives.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer

    def generate_thread():
        model.generate(**inputs, **generation_kwargs)

    thread = threading.Thread(target=generate_thread)
    thread.start()
    assistant_text = ""
    for new_text in streamer:
        assistant_text += new_text
        conversation[-1]["content"][0]["text"] = extract_answer(assistant_text)
        yield conversation_display(conversation), conversation, "Processing...", True
    thread.join()
    yield conversation_display(conversation), conversation, "", False
def extract_answer(response):
    # Strip any echoed prompt: keep only the text after the last assistant tag.
    if "<|assistant|>" in response:
        return response.split("<|assistant|>")[-1].strip()
    return response.strip()
def compute_image_hash(image):
    # Hash the raw RGB pixel data so repeated uploads of the same image can be detected.
    image = image.convert("RGB")
    image_bytes = image.tobytes()
    return hashlib.md5(image_bytes).hexdigest()
def preprocess_conversation(conversation):
    # Find the last sent image in previous user messages (excluding the latest message).
    last_image_hash = None
    for msg in reversed(conversation[:-1]):
        if msg.get("role") == "user":
            for item in msg.get("content", []):
                if item.get("type") == "image" and item.get("image") is not None:
                    try:
                        last_image_hash = compute_image_hash(item["image"])
                        break
                    except Exception:
                        continue
        if last_image_hash is not None:
            break
    # Process the latest user message.
    latest_msg = conversation[-1]
    if latest_msg.get("role") == "user":
        new_content = []
        for item in latest_msg.get("content", []):
            if item.get("type") == "image" and item.get("image") is not None:
                try:
                    current_hash = compute_image_hash(item["image"])
                except Exception:
                    current_hash = None
                # Remove the image if it matches the last sent image.
                if last_image_hash is not None and current_hash == last_image_hash:
                    continue
                new_content.append(item)
            else:
                new_content.append(item)
        latest_msg["content"] = new_content
    return conversation
def conversation_display(conversation):
    chat_history = []
    for msg in conversation:
        if msg["role"] == "user":
            texts = []
            for item in msg["content"]:
                if item["type"] == "image":
                    texts.append("<image>")
                elif item["type"] == "text":
                    texts.append(item["text"])
            chat_history.append({
                "role": "user",
                "content": "\n".join(texts)
            })
        else:
            chat_history.append({
                "role": msg["role"],
                "content": msg["content"][0]["text"]
            })
    return chat_history
def clear_chat(chat_history, conversation, text_value, image, is_generating):
    # Ignore the clear request while a generation is still streaming.
    if is_generating:
        return chat_history, conversation, text_value, image, is_generating
    return [], [], "", None, is_generating
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")
    conversation_state = gr.State([])
    is_generating = gr.State(False)
    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, conversation_state],
        outputs=[chatbot, conversation_state, text_input, is_generating]
    )
    clear_button.click(
        clear_chat,
        inputs=[chatbot, conversation_state, text_input, image_input, is_generating],
        outputs=[chatbot, conversation_state, text_input, image_input, is_generating]
    )
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )
if __name__ == "__main__":
    demo.launch(show_api=False)