import spaces
import random
import torch
import hashlib
import gradio as gr
import threading
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer

model_id = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


@spaces.GPU
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    # Start a fresh conversation with the system prompt if none exists yet.
    if not conversation:
        conversation = [{
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}]
        }]

    user_content = []
    if image is not None:
        # Downscale large images in place to keep the prompt size manageable.
        if image.width > 512 or image.height > 512:
            image.thumbnail((512, 512))
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        # Nothing to send; yield the current state unchanged. A plain `return value`
        # would be swallowed inside a generator, so yield first, then return.
        yield conversation_display(conversation), conversation, "", False
        return

    conversation.append({
        "role": "user",
        "content": user_content
    })

    conversation = preprocess_conversation(conversation)

    # Generate the input prompt using the chat template.
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")

    torch.manual_seed(random.randint(0, 10000))

    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        # Fall back to greedy decoding when temperature is 0.
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        generation_kwargs["temperature"] = temperature

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": ""}]
    })
    yield conversation_display(conversation), conversation, "Processing...", True

    # Stream tokens from a background thread so the UI updates incrementally.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs["streamer"] = streamer

    def generate_thread():
        model.generate(**inputs, **generation_kwargs)

    thread = threading.Thread(target=generate_thread)
    thread.start()

    assistant_text = ""
    for new_text in streamer:
        assistant_text += new_text
        conversation[-1]["content"][0]["text"] = extract_answer(assistant_text)
        yield conversation_display(conversation), conversation, "Processing...", True

    thread.join()
    yield conversation_display(conversation), conversation, "", False


def extract_answer(response):
    # Keep only the text after the assistant tag if the template leaks it into the output.
    if "<|assistant|>" in response:
        return response.split("<|assistant|>")[-1].strip()
    return response.strip()


def compute_image_hash(image):
    image = image.convert("RGB")
    image_bytes = image.tobytes()
    return hashlib.md5(image_bytes).hexdigest()


def preprocess_conversation(conversation):
    # Find the last sent image in previous user messages (excluding the latest message).
    last_image_hash = None
    for msg in reversed(conversation[:-1]):
        if msg.get("role") == "user":
            for item in msg.get("content", []):
                if item.get("type") == "image" and item.get("image") is not None:
                    try:
                        last_image_hash = compute_image_hash(item["image"])
                        break
                    except Exception:
                        continue
            if last_image_hash is not None:
                break

    # Process the latest user message.
    latest_msg = conversation[-1]
    if latest_msg.get("role") == "user":
        new_content = []
        for item in latest_msg.get("content", []):
            if item.get("type") == "image" and item.get("image") is not None:
                try:
                    current_hash = compute_image_hash(item["image"])
                except Exception:
                    current_hash = None
                # Drop the image if it matches the last sent image so it is not re-sent.
                if last_image_hash is not None and current_hash is not None and current_hash == last_image_hash:
                    continue
                new_content.append(item)
            else:
                new_content.append(item)
        latest_msg["content"] = new_content
    return conversation


def conversation_display(conversation):
    # Flatten the structured conversation into the message format expected by gr.Chatbot.
    chat_history = []
    for msg in conversation:
        if msg["role"] == "user":
            texts = []
            for item in msg["content"]:
                if item["type"] == "image":
                    # Images appear as an empty placeholder line in the chat history.
                    texts.append("")
                elif item["type"] == "text":
                    texts.append(item["text"])
            chat_history.append({
                "role": "user",
                "content": "\n".join(texts)
            })
        else:
            chat_history.append({
                "role": msg["role"],
                "content": msg["content"][0]["text"]
            })
    return chat_history


def clear_chat(chat_history, conversation, text_value, image, is_generating):
    # Do not clear the chat while a response is still being generated.
    if is_generating:
        return chat_history, conversation, text_value, image, is_generating
    return [], [], "", None, is_generating


with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# [Granite Vision 3.1 2B](https://huggingface.co/ibm-granite/granite-vision-3.1-2b-preview)")

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max Tokens")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")

    conversation_state = gr.State([])
    is_generating = gr.State(False)

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, conversation_state],
        outputs=[chatbot, conversation_state, text_input, is_generating]
    )

    clear_button.click(
        clear_chat,
        inputs=[chatbot, conversation_state, text_input, image_input, is_generating],
        outputs=[chatbot, conversation_state, text_input, image_input, is_generating]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch(show_api=False)