import torch import numpy as np import gradio as gr import spaces from transformers import AutoTokenizer, AutoModel import time import re device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"Using device: {device}") # Load model and tokenizer tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) # Constants MASK_TOKEN = "[MASK]" MASK_ID = 126336 # The token ID of [MASK] in LLaDA def parse_constraints(constraints_text): """Parse constraints in format: 'position:word, position:word, ...'""" constraints = {} if not constraints_text: return constraints parts = constraints_text.split(',') for part in parts: if ':' not in part: continue pos_str, word = part.split(':', 1) try: pos = int(pos_str.strip()) word = word.strip() if word and pos >= 0: constraints[pos] = word except ValueError: continue return constraints def format_chat_history(history): """ Format chat history for the LLaDA model Args: history: List of [user_message, assistant_message] pairs Returns: Formatted conversation for the model """ messages = [] for user_msg, assistant_msg in history: messages.append({"role": "user", "content": user_msg}) if assistant_msg: # Skip if None (for the latest user message) messages.append({"role": "assistant", "content": assistant_msg}) return messages @spaces.GPU def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32, constraints=None): """ Generate text with LLaDA model with visualization of the denoising process Args: messages: List of message dictionaries with 'role' and 'content' Returns: List of visualization states showing the progression and final text """ # Process constraints if constraints is None: constraints = {} # Convert any string constraints to token IDs processed_constraints = {} for pos, word in constraints.items(): tokens = tokenizer.encode(" " + word, add_special_tokens=False) for i, token_id in enumerate(tokens): processed_constraints[pos + i] = token_id # Prepare the prompt using chat template chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) input_ids = tokenizer(chat_input)['input_ids'] input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) # For generation prompt_length = input_ids.shape[1] # Initialize the sequence with masks for the response part x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device) x[:, :prompt_length] = input_ids.clone() # Initialize visualization states for just the response part visualization_states = [] # Add initial state (all masked) - only for the response part initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)] visualization_states.append(initial_state) # Apply constraints to the initial state for pos, token_id in processed_constraints.items(): absolute_pos = prompt_length + pos if absolute_pos < x.shape[1]: x[:, absolute_pos] = token_id # Calculate timesteps timesteps = torch.linspace(1.0, 0.0, steps + 1)[:-1] # Keep track of already revealed tokens revealed_tokens = torch.zeros(1, gen_length, dtype=torch.bool).to(device) for step, t in enumerate(timesteps): # Current t to next t s = t - 1.0 / steps if step < steps - 1 else 0 # Get all mask positions in the current sequence mask_indices = (x == MASK_ID) # Skip if no masks if not mask_indices.any(): break # Get logits from the model logits = model(x).logits # Get the top predictions x0 = torch.argmax(logits, dim=-1) # Get probabilities for visualization probs = torch.softmax(logits, dim=-1) top_probs = torch.max(probs, dim=-1)[0] # Apply the predictions where we have masks x_old = x.clone() x = torch.where(mask_indices, x0, x) # Calculate how many tokens should remain masked at next step total_len = gen_length current_t_value = float(t) next_t_value = float(s) # Linear schedule: t=1 → all masked, t=0 → none masked current_masks_expected = int(current_t_value * total_len) next_masks_expected = int(next_t_value * total_len) # How many to unmask in this step tokens_to_unmask = current_masks_expected - next_masks_expected if tokens_to_unmask > 0 and mask_indices.any(): # Get confidence scores for currently masked tokens confidence_scores = top_probs[mask_indices] # Sort confidence scores sorted_indices = torch.argsort(confidence_scores, descending=True) # Select which tokens to keep masked (the lowest confidence ones) indices_to_remask = sorted_indices[tokens_to_unmask:] # Get the actual indices in the sequence mask_positions = torch.where(mask_indices)[1] positions_to_remask = mask_positions[indices_to_remask] # Remask these positions x[:, positions_to_remask] = MASK_ID # Ensure constraints are maintained for pos, token_id in processed_constraints.items(): absolute_pos = prompt_length + pos if absolute_pos < x.shape[1]: x[:, absolute_pos] = token_id # Create visualization state ONLY for the response part current_state = [] # Update which tokens are newly revealed in this step for i in range(gen_length): pos = prompt_length + i # Absolute position in the sequence if x[0, pos] == MASK_ID: # Still masked current_state.append((MASK_TOKEN, "#444444")) # Dark gray for masks elif x_old[0, pos] == MASK_ID: # Newly revealed in this step token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True) confidence = float(top_probs[0, pos].cpu()) # Color based on confidence: red (low) to green (high) if confidence < 0.3: color = "#FF6666" # Light red elif confidence < 0.7: color = "#FFAA33" # Orange else: color = "#66CC66" # Light green current_state.append((token, color)) revealed_tokens[0, i] = True else: # Previously revealed token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True) current_state.append((token, "#6699CC")) # Light blue visualization_states.append(current_state) # Extract final text (just the assistant's response) response_tokens = x[0, prompt_length:] response_text = tokenizer.decode(response_tokens, skip_special_tokens=True) # Clean the response text final_text = tokenizer.decode(response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True) return visualization_states, final_text css = ''' .category-legend{display:none} button{height: 60px} ''' def create_chatbot_demo(): with gr.Blocks(css=css) as demo: gr.Markdown("# LLaDA - Large Language Diffusion Model demo") gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)") # STATE MANAGEMENT - IMPORTANT # We use a dedicated state to track the full conversation history chat_history = gr.State([]) # UI COMPONENTS # Chatbot for displaying messages with gr.Row(): with gr.Column(scale=3): chatbot_ui = gr.Chatbot(label="Conversation", height=500) # Message input with gr.Group(): with gr.Row(): user_input = gr.Textbox( label="Your Message", placeholder="Type your message here...", show_label=False ) send_btn = gr.Button("Send") constraints_input = gr.Textbox( label="Word Constraints", info="This model allows for placing specific words at specific positions using 'position:word' format. Example: 1st word once, 6th word 'upon' and 11th word 'time', would be: '0:Once, 5:upon, 10:time", placeholder="0:Once, 5:upon, 10:time", value="" ) with gr.Column(scale=2): output_vis = gr.HighlightedText( label="Denoising Process Visualization", combine_adjacent=False, show_legend=True, ) # Visualization and response components with gr.Accordion("Generation Settings", open=False): with gr.Row(): gen_length = gr.Slider( minimum=16, maximum=128, value=64, step=8, label="Generation Length" ) steps = gr.Slider( minimum=8, maximum=64, value=32, step=4, label="Denoising Steps" ) visualization_delay = gr.Slider( minimum=0.0, maximum=1.0, value=0.1, step=0.1, visible=False, label="Visualization Delay (seconds)" ) # Current response text box current_response = gr.Textbox( label="Current Response", placeholder="The assistant's response will appear here...", lines=3, visible=False ) # Clear button clear_btn = gr.Button("Clear Conversation") # HELPER FUNCTIONS def add_message(history, message, response): """Add a message pair to the history and return the updated history""" history = history.copy() history.append([message, response]) return history def user_message_submitted(message, history, gen_length, steps, constraints, delay): """Process a submitted user message""" # Skip empty messages if not message.strip(): # Return current state unchanged history_for_display = history.copy() return history, history_for_display, "", [], "" # Add user message to history history = add_message(history, message, None) # Format for display - temporarily show user message with empty response history_for_display = history.copy() # Clear the input message_out = "" # Return immediately to update UI with user message return history, history_for_display, message_out, [], "" def bot_response(history, gen_length, steps, constraints, delay): """Generate bot response for the latest message""" if not history: return history, [], "" # Get the last user message last_user_message = history[-1][0] try: # Format all messages except the last one (which has no response yet) messages = format_chat_history(history[:-1]) # Add the last user message messages.append({"role": "user", "content": last_user_message}) # Parse constraints parsed_constraints = parse_constraints(constraints) # Generate response with visualization vis_states, response_text = generate_response_with_visualization( model, tokenizer, device, messages, gen_length=gen_length, steps=steps, constraints=parsed_constraints ) # Update history with the assistant's response history[-1][1] = response_text # Return the initial state immediately yield history, vis_states[0], response_text # Then animate through visualization states for state in vis_states[1:]: time.sleep(delay) yield history, state, response_text except Exception as e: error_msg = f"Error: {str(e)}" print(error_msg) # Show error in visualization error_vis = [(error_msg, "red")] # Don't update history with error yield history, error_vis, error_msg def clear_conversation(): """Clear the conversation history""" return [], [], "", [] # EVENT HANDLERS # Clear button handler clear_btn.click( fn=clear_conversation, inputs=[], outputs=[chat_history, chatbot_ui, current_response, output_vis] ) # User message submission flow (2-step process) # Step 1: Add user message to history and update UI msg_submit = user_input.submit( fn=user_message_submitted, inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay], outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response] ) # Also connect the send button send_click = send_btn.click( fn=user_message_submitted, inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay], outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response] ) # Step 2: Generate bot response # This happens after the user message is displayed msg_submit.then( fn=bot_response, inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay], outputs=[chatbot_ui, output_vis, current_response] ) send_click.then( fn=bot_response, inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay], outputs=[chatbot_ui, output_vis, current_response] ) return demo # Launch the demo if __name__ == "__main__": demo = create_chatbot_demo() demo.queue().launch(share=True)