import random
import spaces
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

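# Run on the GPU when one is available; otherwise fall back to CPU.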
device = "cuda" if torch.cuda.is_available() else "cpu"

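# Load the Granite Vision processor and model once at startup.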
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForVision2Seq.from_pretrained(model_path).to(device)

def get_text_from_content(content):
    # Render a structured message's content as plain text; images become a placeholder.
    texts = []
    for item in content:
        if item["type"] == "text":
            texts.append(item["text"])
        elif item["type"] == "image":
            texts.append("[Image]")
    return " ".join(texts)

@spaces.GPU()
def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
    if conversation is None:
        conversation = []
        
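    # Build the user turn from the optional image upload and text box.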
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return conversation_display(conversation), conversation

    conversation.append({
        "role": "user",
        "content": user_content
    })

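    # Tokenize the whole conversation with the model's chat template.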
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

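    # Re-seed the global RNG so sampling can vary between calls.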
    torch.manual_seed(random.randint(0, 10000))

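    # Sampling parameters come straight from the UI sliders.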
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
    }

    output = model.generate(**inputs, **generation_kwargs)
    # Decode only the newly generated tokens; decoding the full sequence would
    # store the entire prompt in the conversation and duplicate it every turn.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    assistant_response = processor.decode(new_tokens, skip_special_tokens=True)

    conversation.append({
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_response.strip()}]
    })
    
    return conversation_display(conversation), conversation

def conversation_display(conversation):
    # Flatten the structured conversation into the plain message list that
    # gr.Chatbot(type="messages") expects.
    chat_history = []
    user_text = ""
    for msg in conversation:
        if msg["role"] == "user":
            user_text = get_text_from_content(msg["content"])
        elif msg["role"] == "assistant":
            chat_history.append({"role": "user", "content": user_text})
            chat_history.append({"role": "assistant", "content": msg["content"][0]["text"]})
    return chat_history

def clear_chat():
    return [], [], "", None
    
with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
    gr.Markdown("# Granite Vision 3.1 2B")
    
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            with gr.Column():
                temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
                top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
                max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")
            
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type="messages")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
            with gr.Row():
                send_button = gr.Button("Chat")
                clear_button = gr.Button("Clear Chat")
    state = gr.State([])

    send_button.click(
        chat_inference,
        inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
        outputs=[chatbot, state]
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, state, text_input, image_input]
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
        ],
        inputs=[image_input, text_input]
    )

if __name__ == "__main__":
    demo.launch()