multimodalart committed (verified)
Commit ef6b1de · Parent: 691f73d

Update app.py

Files changed (1): app.py (+214, −111)
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 import gradio as gr
 import spaces
+import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModel
 import time
 import re
@@ -57,13 +58,56 @@ def format_chat_history(history):
 
     return messages
 
+def add_gumbel_noise(logits, temperature):
+    '''
+    The Gumbel max is a method for sampling categorical distributions.
+    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+    Thus, we use float64.
+    '''
+    if temperature <= 0:
+        return logits
+
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
+
+def get_num_transfer_tokens(mask_index, steps):
+    '''
+    In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
+    Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
+    the expected number of tokens transitioned at each step should be consistent.
+
+    This function is designed to precompute the number of tokens that need to be transitioned at each step.
+    '''
+    mask_num = mask_index.sum(dim=1, keepdim=True)
+
+    base = mask_num // steps
+    remainder = mask_num % steps
+
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+    for i in range(mask_num.size(0)):
+        num_transfer_tokens[i, :remainder[i]] += 1
+
+    return num_transfer_tokens
+
 @spaces.GPU
-def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32, constraints=None):
+def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32,
+                                         constraints=None, temperature=0.0, cfg_scale=0.0, block_length=32,
+                                         remasking='low_confidence'):
     """
-    Generate text with LLaDA model with visualization of the denoising process
+    Generate text with LLaDA model with visualization using the same sampling as in generate.py
 
     Args:
         messages: List of message dictionaries with 'role' and 'content'
+        gen_length: Length of text to generate
+        steps: Number of denoising steps
+        constraints: Dictionary mapping positions to words
+        temperature: Sampling temperature
+        cfg_scale: Classifier-free guidance scale
+        block_length: Block length for semi-autoregressive generation
+        remasking: Remasking strategy ('low_confidence' or 'random')
 
     Returns:
         List of visualization states showing the progression and final text
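
The two helpers added here are the core of the new sampler. `add_gumbel_noise` implements the Gumbel-max trick: taking the argmax of `exp(logits) / (-log u)^temperature` is equivalent to sampling from `softmax(logits / temperature)`, with `temperature <= 0` falling back to greedy argmax. `get_num_transfer_tokens` splits each block's masked-token budget evenly across the denoising steps, assigning the remainder to the earliest steps. A standalone sketch (not part of app.py) that sanity-checks both on toy inputs:

```python
# Standalone sketch (not part of app.py): sanity-check the two helpers
# on toy inputs. Only PyTorch is assumed.
import torch

def add_gumbel_noise(logits, temperature):
    # Gumbel-max sampling: argmax(exp(logits) / (-log u)^T) draws from softmax(logits / T)
    if temperature <= 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    return logits.exp() / (- torch.log(noise)) ** temperature

def get_num_transfer_tokens(mask_index, steps):
    # Spread each row's masked-token budget evenly over `steps`;
    # the remainder goes to the earliest steps
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    out = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        out[i, :remainder[i]] += 1
    return out

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5])

# Empirical frequencies should approach softmax(logits / T) for T = 1
samples = torch.argmax(add_gumbel_noise(logits.expand(20000, 3), temperature=1.0), dim=-1)
print(torch.bincount(samples, minlength=3) / 20000.0)  # approx. tensor([0.2312, 0.6285, 0.1403])
print(torch.softmax(logits, dim=-1))                   # tensor([0.2312, 0.6285, 0.1403])

# 10 masked tokens over 4 steps -> unmask 3, 3, 2, 2 tokens per step
print(get_num_transfer_tokens(torch.ones(1, 10, dtype=torch.bool), steps=4))  # tensor([[3, 3, 2, 2]])
```

The float64 cast in the real helper follows the note in its docstring (arXiv:2409.02908): low-precision Gumbel noise measurably hurts generation quality.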
@@ -92,10 +136,10 @@ def generate_response_with_visualization(model, tokenizer, device, messages, gen
     x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
     x[:, :prompt_length] = input_ids.clone()
 
-    # Initialize visualization states for just the response part
+    # Initialize visualization states for the response part
     visualization_states = []
 
-    # Add initial state (all masked) - only for the response part
+    # Add initial state (all masked)
     initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
     visualization_states.append(initial_state)
 
@@ -105,114 +149,144 @@ def generate_response_with_visualization(model, tokenizer, device, messages, gen
         if absolute_pos < x.shape[1]:
             x[:, absolute_pos] = token_id
 
-    # Calculate timesteps
-    timesteps = torch.linspace(1.0, 0.0, steps + 1)[:-1]
-
-    # Keep track of already revealed tokens
-    revealed_tokens = torch.zeros(1, gen_length, dtype=torch.bool).to(device)
-
-    for step, t in enumerate(timesteps):
-        # Current t to next t
-        s = t - 1.0 / steps if step < steps - 1 else 0
-
-        # Get all mask positions in the current sequence
-        mask_indices = (x == MASK_ID)
-
-        # Skip if no masks
-        if not mask_indices.any():
-            break
-
-        # Get logits from the model
-        logits = model(x).logits
-
-        # Get the top predictions
-        x0 = torch.argmax(logits, dim=-1)
-
-        # Get probabilities for visualization
-        probs = torch.softmax(logits, dim=-1)
-        top_probs = torch.max(probs, dim=-1)[0]
-
-        # Apply the predictions where we have masks
-        x_old = x.clone()
-        x = torch.where(mask_indices, x0, x)
-
-        # Calculate how many tokens should remain masked at next step
-        total_len = gen_length
-        current_t_value = float(t)
-        next_t_value = float(s)
-
-        # Linear schedule: t=1 all masked, t=0 → none masked
-        current_masks_expected = int(current_t_value * total_len)
-        next_masks_expected = int(next_t_value * total_len)
-
-        # How many to unmask in this step
-        tokens_to_unmask = current_masks_expected - next_masks_expected
-
-        if tokens_to_unmask > 0 and mask_indices.any():
-            # Get confidence scores for currently masked tokens
-            confidence_scores = top_probs[mask_indices]
-
-            # Sort confidence scores
-            sorted_indices = torch.argsort(confidence_scores, descending=True)
-
-            # Select which tokens to keep masked (the lowest confidence ones)
-            indices_to_remask = sorted_indices[tokens_to_unmask:]
-
-            # Get the actual indices in the sequence
-            mask_positions = torch.where(mask_indices)[1]
-            positions_to_remask = mask_positions[indices_to_remask]
-
-            # Remask these positions
-            x[:, positions_to_remask] = MASK_ID
-
-        # Ensure constraints are maintained
-        for pos, token_id in processed_constraints.items():
-            absolute_pos = prompt_length + pos
-            if absolute_pos < x.shape[1]:
-                x[:, absolute_pos] = token_id
-
-        # Create visualization state ONLY for the response part
-        current_state = []
-
-        # Update which tokens are newly revealed in this step
-        for i in range(gen_length):
-            pos = prompt_length + i  # Absolute position in the sequence
-
-            if x[0, pos] == MASK_ID:
-                # Still masked
-                current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
-
-            elif x_old[0, pos] == MASK_ID:
-                # Newly revealed in this step
-                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
-                confidence = float(top_probs[0, pos].cpu())
-
-                # Color based on confidence: red (low) to green (high)
-                if confidence < 0.3:
-                    color = "#FF6666"  # Light red
-                elif confidence < 0.7:
-                    color = "#FFAA33"  # Orange
-                else:
-                    color = "#66CC66"  # Light green
-
-                current_state.append((token, color))
-                revealed_tokens[0, i] = True
-
-            else:
-                # Previously revealed
-                token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
-                current_state.append((token, "#6699CC"))  # Light blue
-
-        visualization_states.append(current_state)
+    # Mark prompt positions to exclude them from masking during classifier-free guidance
+    prompt_index = (x != MASK_ID)
+
+    # Ensure block_length is valid
+    if block_length > gen_length:
+        block_length = gen_length
+
+    # Calculate number of blocks
+    num_blocks = gen_length // block_length
+    if gen_length % block_length != 0:
+        num_blocks += 1
+
+    # Adjust steps per block
+    steps_per_block = steps // num_blocks
+    if steps_per_block < 1:
+        steps_per_block = 1
+
+    # Track the current state of x for visualization
+    current_x = x.clone()
+
+    # Process each block
+    for num_block in range(num_blocks):
+        # Calculate the start and end indices for the current block
+        block_start = prompt_length + num_block * block_length
+        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
+
+        # Get mask indices for the current block
+        block_mask_index = (x[:, block_start:block_end] == MASK_ID)
+
+        # Skip if no masks in this block
+        if not block_mask_index.any():
+            continue
+
+        # Calculate number of tokens to unmask at each step
+        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
+
+        # Process each step
+        for i in range(steps_per_block):
+            # Get all mask positions in the current sequence
+            mask_index = (x == MASK_ID)
+
+            # Skip if no masks
+            if not mask_index.any():
+                break
+
+            # Apply classifier-free guidance if enabled
+            if cfg_scale > 0.0:
+                un_x = x.clone()
+                un_x[prompt_index] = MASK_ID
+                x_ = torch.cat([x, un_x], dim=0)
+                logits = model(x_).logits
+                logits, un_logits = torch.chunk(logits, 2, dim=0)
+                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+            else:
+                logits = model(x).logits
+
+            # Apply Gumbel noise for sampling
+            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+            x0 = torch.argmax(logits_with_noise, dim=-1)
+
+            # Calculate confidence scores for remasking
+            if remasking == 'low_confidence':
+                p = F.softmax(logits.to(torch.float64), dim=-1)
+                x0_p = torch.squeeze(
+                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
+            elif remasking == 'random':
+                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+            else:
+                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
+
+            # Don't consider positions beyond the current block
+            x0_p[:, block_end:] = -float('inf')
+
+            # Apply predictions where we have masks
+            old_x = x.clone()
+            x0 = torch.where(mask_index, x0, x)
+            confidence = torch.where(mask_index, x0_p, -float('inf'))
+
+            # Select tokens to unmask based on confidence
+            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+            for j in range(confidence.shape[0]):
+                # Only consider positions within the current block for unmasking
+                block_confidence = confidence[j, block_start:block_end]
+                if i < steps_per_block - 1:  # Not the last step
+                    # Take top-k confidences
+                    _, select_indices = torch.topk(block_confidence,
+                                                   k=min(num_transfer_tokens[j, i].item(),
+                                                         block_confidence.numel()))
+                    # Adjust indices to global positions
+                    select_indices = select_indices + block_start
+                    transfer_index[j, select_indices] = True
+                else:  # Last step - unmask everything remaining
+                    transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]
+
+            # Apply the selected tokens
+            x = torch.where(transfer_index, x0, x)
+
+            # Ensure constraints are maintained
+            for pos, token_id in processed_constraints.items():
+                absolute_pos = prompt_length + pos
+                if absolute_pos < x.shape[1]:
+                    x[:, absolute_pos] = token_id
+
+            # Create visualization state only for the response part
+            current_state = []
+            for i in range(gen_length):
+                pos = prompt_length + i  # Absolute position in the sequence
+
+                if x[0, pos] == MASK_ID:
+                    # Still masked
+                    current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
+
+                elif old_x[0, pos] == MASK_ID:
+                    # Newly revealed in this step
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    # Color based on confidence
+                    confidence = float(x0_p[0, pos].cpu())
+                    if confidence < 0.3:
+                        color = "#FF6666"  # Light red
+                    elif confidence < 0.7:
+                        color = "#FFAA33"  # Orange
+                    else:
+                        color = "#66CC66"  # Light green
+
+                    current_state.append((token, color))
+
+                else:
+                    # Previously revealed
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    current_state.append((token, "#6699CC"))  # Light blue
+
+            visualization_states.append(current_state)
 
     # Extract final text (just the assistant's response)
     response_tokens = x[0, prompt_length:]
-    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
-
-    # Clean the response text
     final_text = tokenizer.decode(response_tokens,
-                                  skip_special_tokens=True,
-                                  clean_up_tokenization_spaces=True)
+                                  skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=True)
 
     return visualization_states, final_text
 
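Two details of the new loop are easy to miss. First, with the default `cfg_scale=0.0`, the guidance formula `un_logits + (cfg_scale + 1) * (logits - un_logits)` reduces exactly to `logits`, which is why the else branch skips the second (unconditional) forward pass. Second, remasking is a per-step top-k over prediction confidences restricted to the current block. A standalone sketch (not part of app.py) of one 'low_confidence' transfer step on toy tensors:

```python
# Standalone sketch of one low-confidence transfer step on toy data
# (mirrors the inner loop above; MASK_ID and all values are made up).
import torch

MASK_ID = 0
x = torch.tensor([[7, 7, MASK_ID, MASK_ID, MASK_ID, MASK_ID]])  # 2 prompt + 4 masked tokens
x0 = torch.tensor([[7, 7, 11, 12, 13, 14]])                     # model's argmax guesses
x0_p = torch.tensor([[1.0, 1.0, 0.9, 0.2, 0.8, 0.1]])           # their confidences

mask_index = (x == MASK_ID)
# Only masked positions compete; already-revealed ones get -inf confidence
confidence = torch.where(mask_index, x0_p, torch.tensor(-float('inf')))

k = 2  # num_transfer_tokens for this step
transfer_index = torch.zeros_like(x, dtype=torch.bool)
_, select = torch.topk(confidence[0], k=k)
transfer_index[0, select] = True

# Keep only the k most confident predictions; the rest stay masked
x = torch.where(transfer_index, x0, x)
print(x)  # tensor([[ 7,  7, 11,  0, 13,  0]])
```

Because everything outside the mask is pinned at `-inf` confidence, the prompt and previously revealed tokens can never be re-selected; on a block's final step the remaining masks are flushed unconditionally.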
@@ -222,15 +296,13 @@ button{height: 60px}
 '''
 def create_chatbot_demo():
     with gr.Blocks(css=css) as demo:
-        gr.Markdown("# LLaDA - Large Language Diffusion Model demo")
+        gr.Markdown("# LLaDA - Large Language Diffusion Model Demo")
         gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)")
 
-        # STATE MANAGEMENT - IMPORTANT
-        # We use a dedicated state to track the full conversation history
+        # STATE MANAGEMENT
        chat_history = gr.State([])
 
         # UI COMPONENTS
-        # Chatbot for displaying messages
         with gr.Row():
             with gr.Column(scale=3):
                 chatbot_ui = gr.Chatbot(label="Conversation", height=500)
@@ -257,7 +329,8 @@ def create_chatbot_demo():
                 combine_adjacent=False,
                 show_legend=True,
             )
-        # Visualization and response components
+
+        # Advanced generation settings
         with gr.Accordion("Generation Settings", open=False):
             with gr.Row():
                 gen_length = gr.Slider(
@@ -268,14 +341,32 @@ def create_chatbot_demo():
                     minimum=8, maximum=64, value=32, step=4,
                     label="Denoising Steps"
                 )
-
-
-            visualization_delay = gr.Slider(
-                minimum=0.0, maximum=1.0, value=0.1, step=0.1, visible=False,
-                label="Visualization Delay (seconds)"
-            )
+            with gr.Row():
+                temperature = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.0, step=0.1,
+                    label="Temperature"
+                )
+                cfg_scale = gr.Slider(
+                    minimum=0.0, maximum=2.0, value=0.0, step=0.1,
+                    label="CFG Scale"
+                )
+            with gr.Row():
+                block_length = gr.Slider(
+                    minimum=8, maximum=128, value=32, step=8,
+                    label="Block Length"
+                )
+                remasking_strategy = gr.Radio(
+                    choices=["low_confidence", "random"],
+                    value="low_confidence",
+                    label="Remasking Strategy"
+                )
+            with gr.Row():
+                visualization_delay = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.1, step=0.1,
+                    label="Visualization Delay (seconds)"
+                )
 
-        # Current response text box
+        # Current response text box (hidden)
         current_response = gr.Textbox(
             label="Current Response",
             placeholder="The assistant's response will appear here...",
@@ -313,7 +404,7 @@ def create_chatbot_demo():
             # Return immediately to update UI with user message
             return history, history_for_display, message_out, [], ""
 
-        def bot_response(history, gen_length, steps, constraints, delay):
+        def bot_response(history, gen_length, steps, constraints, delay, temperature, cfg_scale, block_length, remasking):
             """Generate bot response for the latest message"""
             if not history:
                 return history, [], ""
@@ -337,7 +428,11 @@ def create_chatbot_demo():
                     messages,
                     gen_length=gen_length,
                     steps=steps,
-                    constraints=parsed_constraints
+                    constraints=parsed_constraints,
+                    temperature=temperature,
+                    cfg_scale=cfg_scale,
+                    block_length=block_length,
+                    remasking=remasking
                 )
 
                 # Update history with the assistant's response
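With the new keyword arguments threaded through, a direct (non-Gradio) call to the generator would look roughly like this. This is a hypothetical usage sketch: model, tokenizer, and device are loaded elsewhere in app.py (not shown in this diff), and the constraint value format follows the docstring above.

```python
# Hypothetical direct usage of the updated generator (model, tokenizer and
# device come from the rest of app.py; all values here are illustrative).
messages = [{"role": "user", "content": "Write a haiku about diffusion."}]

vis_states, final_text = generate_response_with_visualization(
    model, tokenizer, device, messages,
    gen_length=64,
    steps=32,
    constraints={0: "Tokens"},       # position -> word, per the docstring
    temperature=0.0,                 # greedy; > 0 enables Gumbel sampling
    cfg_scale=0.0,                   # 0 skips the unconditional forward pass
    block_length=32,                 # semi-autoregressive block size
    remasking='low_confidence',      # or 'random'
)
print(final_text)
print(len(vis_states), "visualization frames")
```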
@@ -393,13 +488,21 @@ def create_chatbot_demo():
         # This happens after the user message is displayed
         msg_submit.then(
             fn=bot_response,
-            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
+            inputs=[
+                chat_history, gen_length, steps, constraints_input,
+                visualization_delay, temperature, cfg_scale, block_length,
+                remasking_strategy
+            ],
             outputs=[chatbot_ui, output_vis, current_response]
         )
 
         send_click.then(
             fn=bot_response,
-            inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay],
+            inputs=[
+                chat_history, gen_length, steps, constraints_input,
+                visualization_delay, temperature, cfg_scale, block_length,
+                remasking_strategy
+            ],
             outputs=[chatbot_ui, output_vis, current_response]
         )
 
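The handlers above use Gradio's standard two-stage pattern: the submit handler returns immediately so the user's message renders, and `.then()` chains the slow generation afterwards. A minimal standalone sketch of the pattern, with hypothetical handler names (not from app.py):

```python
# Minimal sketch of the two-stage submit pattern used above
# (hypothetical handlers; not part of app.py).
import gradio as gr

def show_user_message(message, history):
    # Stage 1: echo the user turn right away and clear the textbox
    history = history + [(message, None)]
    return history, history, ""

def slow_bot_reply(history):
    # Stage 2: fill in the assistant turn (the expensive call goes here)
    history[-1] = (history[-1][0], "(generated reply)")
    return history, history

with gr.Blocks() as demo:
    state = gr.State([])
    chat = gr.Chatbot()
    box = gr.Textbox()

    evt = box.submit(show_user_message, inputs=[box, state], outputs=[state, chat, box])
    evt.then(slow_bot_reply, inputs=[state], outputs=[state, chat])

# demo.launch()
```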