multimodalart (HF staff) committed · verified
Commit b41aa9b · 1 parent: f6d8cac

Create app.py

Files changed (1): app.py (+439, −0)
app.py ADDED
@@ -0,0 +1,439 @@
import torch
import numpy as np
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModel
import time
import re

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).to(device).eval()

# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA
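# MASK_ID is hardcoded to follow the LLaDA reference code. As an optional
# sanity check -- this assumes the tokenizer actually registers "[MASK]" as a
# token, which is worth verifying before relying on it -- you could run:
#   assert tokenizer.convert_tokens_to_ids(MASK_TOKEN) == MASK_ID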
def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue

    return constraints

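# Illustrative usage (not called here; positions are 0-based offsets into
# the generated response):
#   parse_constraints("0:Once, 5:upon, 10:time")
#   == {0: 'Once', 5: 'upon', 10: 'time'}
# Malformed parts (no colon, empty word, non-integer or negative position)
# are silently skipped.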
def format_chat_history(history):
    """
    Format chat history for the LLaDA model

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted conversation for the model
    """
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (for the latest user message)
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages

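# Illustrative example (hypothetical messages):
#   format_chat_history([["Hi", "Hello!"], ["Tell me a joke", None]])
#   == [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Tell me a joke"}]
# The trailing None is dropped so the model generates that reply.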
@spaces.GPU
def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32, constraints=None):
    """
    Generate text with the LLaDA model while recording each denoising step

    Args:
        model, tokenizer, device: the loaded LLaDA model, its tokenizer, and the torch device
        messages: List of message dictionaries with 'role' and 'content'
        gen_length: number of tokens to generate for the response
        steps: number of denoising steps
        constraints: optional dict mapping response positions to words to pin

    Returns:
        List of visualization states showing the progression, and the final text
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Process constraints
    if constraints is None:
        constraints = {}

    # Convert any string constraints to token IDs
    processed_constraints = {}
    for pos, word in constraints.items():
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        for i, token_id in enumerate(tokens):
            processed_constraints[pos + i] = token_id

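    # A constraint word may encode to several token IDs, which then occupy
    # consecutive token positions in the response. E.g. with hypothetical ids:
    #   tokenizer.encode(" upon", add_special_tokens=False)  ->  [1234, 567]
    # a constraint {5: "upon"} becomes processed_constraints {5: 1234, 6: 567},
    # so positions here are token offsets, not strictly word counts.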
    # Prepare the prompt using chat template
    chat_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    # For generation
    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()

    # Initialize visualization states for just the response part
    visualization_states = []

    # Add initial state (all masked) - only for the response part
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    # Apply constraints to the initial state
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id

    # Calculate timesteps
    timesteps = torch.linspace(1.0, 0.0, steps + 1)[:-1]

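    # Worked example of the linear mask schedule, assuming steps=4 and
    # gen_length=64:
    #   timesteps = [1.00, 0.75, 0.50, 0.25]
    #   expected masked tokens per step: 64 -> 48 -> 32 -> 16 -> 0
    # i.e. each step unmasks the 16 most confident predictions and remasks
    # the rest, reaching zero masks after the final step.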
    # Keep track of already revealed tokens
    revealed_tokens = torch.zeros(1, gen_length, dtype=torch.bool).to(device)

    for step, t in enumerate(timesteps):
        # Time level at the next step (s < t; 0 on the last step)
        s = t - 1.0 / steps if step < steps - 1 else 0

        # Get all mask positions in the current sequence
        mask_indices = (x == MASK_ID)

        # Stop if no masks remain
        if not mask_indices.any():
            break

        # Get logits from the model
        logits = model(x).logits

        # Get the top predictions
        x0 = torch.argmax(logits, dim=-1)

        # Get probabilities for visualization
        probs = torch.softmax(logits, dim=-1)
        top_probs = torch.max(probs, dim=-1)[0]

        # Apply the predictions where we have masks
        x_old = x.clone()
        x = torch.where(mask_indices, x0, x)

        # Calculate how many tokens should remain masked at the next step
        total_len = gen_length
        current_t_value = float(t)
        next_t_value = float(s)

        # Linear schedule: t=1 → all masked, t=0 → none masked
        current_masks_expected = int(current_t_value * total_len)
        next_masks_expected = int(next_t_value * total_len)

        # How many to unmask in this step
        tokens_to_unmask = current_masks_expected - next_masks_expected

        if tokens_to_unmask > 0 and mask_indices.any():
            # Get confidence scores for the currently masked tokens
            confidence_scores = top_probs[mask_indices]

            # Sort confidence scores
            sorted_indices = torch.argsort(confidence_scores, descending=True)

            # Select which tokens to keep masked (the lowest-confidence ones)
            indices_to_remask = sorted_indices[tokens_to_unmask:]

            # Get the actual indices in the sequence
            mask_positions = torch.where(mask_indices)[1]
            positions_to_remask = mask_positions[indices_to_remask]

            # Remask these positions
            x[:, positions_to_remask] = MASK_ID

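        # This is a low-confidence remasking strategy: every masked position
        # is predicted at each step, but only the most confident predictions
        # are kept; the rest are re-masked and retried on the next pass
        # (similar to the low-confidence remasking described in the LLaDA paper).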
        # Ensure constraints are maintained
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_length + pos
            if absolute_pos < x.shape[1]:
                x[:, absolute_pos] = token_id

        # Create visualization state ONLY for the response part
        current_state = []

        # Update which tokens are newly revealed in this step
        for i in range(gen_length):
            pos = prompt_length + i  # Absolute position in the sequence

            if x[0, pos] == MASK_ID:
                # Still masked
                current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks

            elif x_old[0, pos] == MASK_ID:
                # Newly revealed in this step
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                confidence = float(top_probs[0, pos].cpu())

                # Color based on confidence: red (low) to green (high)
                if confidence < 0.3:
                    color = "#FF6666"  # Light red
                elif confidence < 0.7:
                    color = "#FFAA33"  # Orange
                else:
                    color = "#66CC66"  # Light green

                current_state.append((token, color))
                revealed_tokens[0, i] = True

            else:
                # Previously revealed
                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                current_state.append((token, "#6699CC"))  # Light blue

        visualization_states.append(current_state)

    # Extract final text (just the assistant's response)
    response_tokens = x[0, prompt_length:]
    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)

    # Clean the response text
    final_text = clean_output_text(response_text)

    return visualization_states, final_text

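# Illustrative direct call (the Gradio handlers below wire this up; the
# values are hypothetical):
#   states, text = generate_response_with_visualization(
#       model, tokenizer, device,
#       [{"role": "user", "content": "Tell me a short joke"}],
#       gen_length=64, steps=32, constraints={0: "Why"})
#   # len(states) <= steps + 1: the all-masked initial state plus one
#   # snapshot per denoising step (fewer if every mask resolves early).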
def clean_output_text(text):
    """Clean the output text to remove special tokens and fix spacing"""
    # Remove any remaining [MASK] tokens
    text = text.replace(MASK_TOKEN, "")

    # Fix common spacing issues from tokenization
    text = re.sub(r'\s+', ' ', text)  # Collapse runs of whitespace (including newlines)
    text = re.sub(r' \.', '.', text)  # Fix spacing before periods
    text = re.sub(r' ,', ',', text)   # Fix spacing before commas
    text = re.sub(r' !', '!', text)   # Fix spacing before exclamation marks
    text = re.sub(r' \?', '?', text)  # Fix spacing before question marks
    text = re.sub(r' ;', ';', text)   # Fix spacing before semicolons
    text = re.sub(r' :', ':', text)   # Fix spacing before colons

    # Trim leading and trailing whitespace
    text = text.strip()

    return text

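# Example:
#   clean_output_text("Hello , world ! [MASK]") == "Hello, world!"
# Note that the \s+ collapse also flattens newlines, so multi-paragraph
# responses come out as a single line.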
css = '''
.category-legend{display:none}
'''

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model demo")
        gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)")

        # STATE MANAGEMENT
        # We use a dedicated state to track the full conversation history
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                # Chatbot for displaying messages
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Place specific words at specific positions using the 'position:word' format (0-indexed). Example: to make the 1st word 'Once', the 6th 'upon', and the 11th 'time', use: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )

        # Visualization and response components
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8,
                    label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=64, value=32, step=4,
                    label="Denoising Steps"
                )

            visualization_delay = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.1, step=0.1, visible=False,
                label="Visualization Delay (seconds)"
            )

        # Current response text box (hidden; used for event wiring)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False
        )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # Example inputs
        gr.Examples(
            [
                ["Tell me a short joke", 64, 32, ""],
                ["Write a short story", 64, 32, "0:Once, 5:upon, 10:time"],
                ["Explain quantum computing", 64, 32, ""],
            ],
            [user_input, gen_length, steps, constraints_input],
        )

        # HELPER FUNCTIONS
        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(message, history, gen_length, steps, constraints, delay):
            """Process a submitted user message"""
            # Skip empty messages
            if not message.strip():
                # Return current state unchanged
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Add user message to history
            history = add_message(history, message, None)

            # Format for display - temporarily show the user message with an empty response
            history_for_display = history.copy()

            # Clear the input
            message_out = ""

            # Return immediately to update the UI with the user message
            return history, history_for_display, message_out, [], ""

        def bot_response(history, gen_length, steps, constraints, delay):
            """Generate the bot response for the latest message"""
            if not history:
                # Inside a generator, a bare `return value` never reaches the
                # caller, so yield the unchanged state instead
                yield history, [], ""
                return

            # Get the last user message
            last_user_message = history[-1][0]

            try:
                # Format all messages except the last one (which has no response yet)
                messages = format_chat_history(history[:-1])

                # Add the last user message
                messages.append({"role": "user", "content": last_user_message})

                # Parse constraints
                parsed_constraints = parse_constraints(constraints)

                # Generate the response with visualization states
                vis_states, response_text = generate_response_with_visualization(
                    model, tokenizer, device,
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints
                )

                # Update history with the assistant's response
                history[-1][1] = response_text

                # Return the initial state immediately
                yield history, vis_states[0], response_text

                # Then animate through the visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show the error in the visualization
                error_vis = [(error_msg, "red")]

                # Don't update history with the error
                yield history, error_vis, error_msg

        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", []

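        # Because bot_response is a generator, Gradio streams each yielded
        # (history, visualization state, text) triple to the UI as it
        # arrives, which is what animates the denoising process; the sleep
        # between yields only paces the animation.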
        # EVENT HANDLERS

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis]
        )

        # User message submission flow (2-step process)
        # Step 1: Add the user message to history and update the UI
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response]
        )

        # Step 2: Generate the bot response
        # This happens after the user message is displayed
        msg_submit.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

        send_click.then(
            fn=bot_response,
            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
            outputs=[chatbot_ui, output_vis, current_response]
        )

    return demo

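# Note: queue() is what allows the generator-based handlers above to stream
# their intermediate yields to the browser. share=True mainly matters for
# local runs; Hugging Face Spaces serves the app itself and ignores the flag.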
# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(share=True)