import random

import spaces  # kept ahead of torch: the spaces package hooks CUDA init on ZeroGPU

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForVision2Seq

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Granite Vision processor and model once at startup.
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForVision2Seq.from_pretrained(model_path).to(device)
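
# Optional, untested here: on a CUDA device the model could instead be loaded
# in half precision to roughly halve memory use, e.g.
#   model = AutoModelForVision2Seq.from_pretrained(
#       model_path, torch_dtype=torch.bfloat16
#   ).to(device)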


def get_text_from_content(content):
    """Flatten a multimodal message's content list into plain text."""
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            # Images appear as a placeholder in the text transcript.
            texts.append("[Image]")
    return " ".join(texts)
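
# Illustrative example: for content like
#   [{"type": "image", "image": img}, {"type": "text", "text": "Describe this"}]
# the helper returns "[Image] Describe this".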


@spaces.GPU()
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []

    # Build the user turn from whichever inputs were provided.
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(conversation), conversation

    conversation.append({
        "role": "user",
        "content": user_content
    })

    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Re-seed so repeated identical prompts do not produce identical samples.
    torch.manual_seed(random.randint(0, 10000))

    # temperature == 0 is invalid when sampling, so fall back to greedy decoding.
    do_sample = temperature > 0
    generation_kwargs = {
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
    }
    if do_sample:
        generation_kwargs.update({
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        })

    output = model.generate(**inputs, **generation_kwargs)
    # Decode only the newly generated tokens, not the prompt that precedes them,
    # so the stored assistant turn does not contain the whole transcript.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    assistant_response = processor.decode(new_tokens, skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })

    return conversation_display(conversation), conversation
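
# Note: generation above is blocking. A streaming variant could run
# model.generate in a background thread with transformers' TextIteratorStreamer
# and yield partial text as it arrives; that is left out here.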


def conversation_display(conversation):
    """Convert the stored conversation into the messages format gr.Chatbot expects."""
    chat_history = []
    user_text = ""  # guards against a malformed history that starts with an assistant turn
    for msg in conversation:
        if msg["role"] == "user":
            user_text = get_text_from_content(msg["content"])
        elif msg["role"] == "assistant":
            # Drop any chat-template marker that may remain in the stored text.
            assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
            chat_history.append({"role": "user", "content": user_text})
            chat_history.append({"role": "assistant", "content": assistant_text})
    return chat_history
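
# Illustrative shape of the returned history after one exchange, e.g.:
#   [{"role": "user", "content": "[Image] What is this?"},
#    {"role": "assistant", "content": "A bus."}]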


def clear_chat():
    # Reset, in order: chatbot display, conversation state, message box, image input.
    return [], [], "", None


with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    # Raw multimodal conversation carried across turns (distinct from the
    # display-only history held by the Chatbot component).
    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch()
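    # Optional: on shared hardware, demo.queue().launch() would process
    # concurrent requests one at a time instead of in parallel (not enabled here).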