import time

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# LLaDA ships custom modeling code (hence trust_remote_code) and is loaded in
# bfloat16 to keep the 8B weights within a single-GPU memory budget.
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).to(device)

# LLaDA's mask token and its fixed id in the model vocabulary.
MASK_TOKEN = "[MASK]"
MASK_ID = 126336


def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue

    return constraints
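
# Example: parse_constraints("0:Once, 5:upon, 10:time")
# -> {0: 'Once', 5: 'upon', 10: 'time'}; malformed parts are skipped.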


def format_chat_history(history):
    """
    Format chat history for the LLaDA model.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted conversation for the model
    """
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # skip the pending turn whose reply is still None
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages
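
# Example: format_chat_history([["Hi", "Hello!"], ["Tell me a story", None]])
# -> [{'role': 'user', 'content': 'Hi'},
#     {'role': 'assistant', 'content': 'Hello!'},
#     {'role': 'user', 'content': 'Tell me a story'}]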


@spaces.GPU
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32, constraints=None):
    """
    Generate text with the LLaDA model, recording each denoising step for visualization.

    Args:
        messages: List of message dictionaries with 'role' and 'content'
        gen_length: Number of tokens to generate
        steps: Number of denoising steps
        constraints: Optional {position: word} dict pinning words in the output

    Returns:
        A list of visualization states showing the progression, and the final text
    """
    if constraints is None:
        constraints = {}

    # Convert word constraints to token ids; a word that tokenizes into several
    # pieces occupies consecutive positions starting at the requested one.
    processed_constraints = {}
    for pos, word in constraints.items():
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        for i, token_id in enumerate(tokens):
            processed_constraints[pos + i] = token_id
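
    # e.g. constraints {5: 'upon'} may become {5: id_a, 6: id_b} if " upon"
    # splits into two tokens (the concrete ids depend on the tokenizer).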

    # Build the prompt with the chat template and move it to the device.
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    prompt_length = input_ids.shape[1]

    # Append a fully masked generation region to the prompt.
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    visualization_states = []

    # Initial state: every generated position is still masked (dark gray).
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    # Pin constrained tokens before denoising begins.
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id

    # Noise levels from t=1 (fully masked) down toward 0 (fully revealed).
    timesteps = torch.linspace(1.0, 0.0, steps + 1)[:-1]
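
    # e.g. steps=4 yields t in {1.0, 0.75, 0.5, 0.25}, so about a quarter of
    # the generation region is revealed per step.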

    for step, t in enumerate(timesteps):
        s = t - 1.0 / steps if step < steps - 1 else 0  # next, lower noise level

        mask_indices = (x == MASK_ID)
        if not mask_indices.any():
            break  # everything is already revealed

        # Inference only; no_grad avoids keeping activations alive.
        with torch.no_grad():
            logits = model(x).logits

        # Greedy prediction and per-position confidence.
        x0 = torch.argmax(logits, dim=-1)
        probs = torch.softmax(logits, dim=-1)
        top_probs = torch.max(probs, dim=-1)[0]

        # Tentatively fill every masked position with its prediction.
        x_old = x.clone()
        x = torch.where(mask_indices, x0, x)

        # The schedule dictates how many tokens should still be masked at t and
        # at s; the difference is how many may be revealed this step.
        total_len = gen_length
        current_masks_expected = int(float(t) * total_len)
        next_masks_expected = int(float(s) * total_len)
        tokens_to_unmask = current_masks_expected - next_masks_expected
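
        # e.g. gen_length=64, steps=32: at t=1.0 and s=31/32 this gives
        # 64 - 62 = 2 tokens to reveal on the first step.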

        if tokens_to_unmask > 0 and mask_indices.any():
            # Rank masked positions by model confidence; keep the most
            # confident predictions and re-mask the rest.
            confidence_scores = top_probs[mask_indices]
            sorted_indices = torch.argsort(confidence_scores, descending=True)
            indices_to_remask = sorted_indices[tokens_to_unmask:]
            mask_positions = torch.where(mask_indices)[1]
            positions_to_remask = mask_positions[indices_to_remask]
            x[:, positions_to_remask] = MASK_ID

        # Re-apply constraints in case any were just re-masked.
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_length + pos
            if absolute_pos < x.shape[1]:
                x[:, absolute_pos] = token_id

        # Snapshot the generation region for the visualization.
        current_state = []
        for i in range(gen_length):
            pos = prompt_length + i

            if x[0, pos] == MASK_ID:
                # Still masked.
                current_state.append((MASK_TOKEN, "#444444"))

            elif x_old[0, pos] == MASK_ID:
                # Newly revealed this step: color by confidence
                # (red < 0.3, orange < 0.7, green otherwise).
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                confidence = float(top_probs[0, pos].cpu())
                if confidence < 0.3:
                    color = "#FF6666"
                elif confidence < 0.7:
                    color = "#FFAA33"
                else:
                    color = "#66CC66"
                current_state.append((token, color))

            else:
                # Revealed in an earlier step (blue).
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                current_state.append((token, "#6699CC"))

        visualization_states.append(current_state)

    # Decode only the generation region into the final response text.
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(response_tokens,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)

    return visualization_states, final_text
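
# Minimal usage sketch (with the globals loaded above):
#   vis_states, text = generate_response_with_visualization(
#       model, tokenizer, device,
#       [{"role": "user", "content": "Tell me a story"}],
#       gen_length=64, steps=32, constraints={0: "Once"})
# vis_states holds one (token, color) list per denoising step; text is the reply.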


css = '''
.category-legend{display:none}
button{height: 60px}
'''


def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model demo")
        gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)")

        # Conversation state as [user_message, assistant_message] pairs.
        chat_history = gr.State([])

        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Place specific words at specific 0-indexed positions using 'position:word'. Example: 1st word 'Once', 6th word 'upon' and 11th word 'time' would be '0:Once, 5:upon, 10:time'.",
                    placeholder="0:Once, 5:upon, 10:time",
                    value=""
                )

            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )

        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=64, value=32, step=4,
                    label="Denoising Steps"
                )

        # Hidden control that paces the streamed visualization states.
        visualization_delay = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.1, step=0.1, visible=False,
            label="Visualization Delay (seconds)"
        )

        # Hidden textbox mirroring the assistant's latest reply.
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )

        clear_btn = gr.Button("Clear Conversation")

        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(message, history, gen_length, steps, constraints, delay):
            """Process a submitted user message"""
            if not message.strip():
                # Ignore empty submissions.
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Record the user turn with the reply still pending (None).
            history = add_message(history, message, None)
            history_for_display = history.copy()

            # Clear the input box.
            message_out = ""

            return history, history_for_display, message_out, [], ""

        def bot_response(history, gen_length, steps, constraints, delay):
            """Generate bot response for the latest message"""
            if not history:
                # Yield (not return) so this generator still hands Gradio
                # a well-formed set of outputs.
                yield history, [], ""
                return

            last_user_message = history[-1][0]

            try:
                # Format every completed turn, then append the pending one.
                messages = format_chat_history(history[:-1])
                messages.append({"role": "user", "content": last_user_message})

                parsed_constraints = parse_constraints(constraints)

                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints
                )

                history[-1][1] = response_text

                # Stream the denoising snapshots to the UI one at a time.
                yield history, vis_states[0], response_text
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                error_vis = [(error_msg, "red")]
                yield history, error_vis, error_msg

        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", []

        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis]
        )

        # Pressing Enter or clicking Send first commits the user turn, then
        # chains into bot_response to stream the reply.
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        msg_submit.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

        send_click.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

    return demo


if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(share=True)