Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 import gradio as gr
 import spaces
+import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModel
 import time
 import re
@@ -57,13 +58,56 @@ def format_chat_history(history):
 
     return messages
 
+def add_gumbel_noise(logits, temperature):
+    '''
+    The Gumbel max is a method for sampling categorical distributions.
+    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+    Thus, we use float64.
+    '''
+    if temperature <= 0:
+        return logits
+
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
+
+def get_num_transfer_tokens(mask_index, steps):
+    '''
+    In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
+    Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
+    the expected number of tokens transitioned at each step should be consistent.
+
+    This function is designed to precompute the number of tokens that need to be transitioned at each step.
+    '''
+    mask_num = mask_index.sum(dim=1, keepdim=True)
+
+    base = mask_num // steps
+    remainder = mask_num % steps
+
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+    for i in range(mask_num.size(0)):
+        num_transfer_tokens[i, :remainder[i]] += 1
+
+    return num_transfer_tokens
+
 @spaces.GPU
-def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32,
+def generate_response_with_visualization(model, tokenizer, device, messages, gen_length=64, steps=32,
+                                         constraints=None, temperature=0.0, cfg_scale=0.0, block_length=32,
+                                         remasking='low_confidence'):
     """
-    Generate text with LLaDA model with visualization
+    Generate text with LLaDA model with visualization using the same sampling as in generate.py
 
     Args:
        messages: List of message dictionaries with 'role' and 'content'
+       gen_length: Length of text to generate
+       steps: Number of denoising steps
+       constraints: Dictionary mapping positions to words
+       temperature: Sampling temperature
+       cfg_scale: Classifier-free guidance scale
+       block_length: Block length for semi-autoregressive generation
+       remasking: Remasking strategy ('low_confidence' or 'random')
 
     Returns:
        List of visualization states showing the progression and final text
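A quick standalone sketch (torch only, not part of the commit) of what the two new helpers compute: get_num_transfer_tokens spreads the masked positions evenly over the denoising steps, and add_gumbel_noise at temperature 1.0 turns argmax decoding into Gumbel-max sampling from softmax(logits), falling back to plain greedy argmax at temperature 0.

import torch

# Token-transfer schedule: 10 masked positions over 4 steps -> [3, 3, 2, 2]
mask_index = torch.tensor([[True] * 10 + [False] * 6])   # (batch=1, length=16)
steps = 4
mask_num = mask_index.sum(dim=1, keepdim=True)           # tensor([[10]])
base, remainder = mask_num // steps, mask_num % steps    # 2 and 2
schedule = torch.zeros(1, steps, dtype=torch.int64) + base
schedule[0, :int(remainder)] += 1
print(schedule.tolist())                                  # [[3, 3, 2, 2]]

# Gumbel-max draw: argmax(exp(logits) / (-log u)) == argmax(logits + Gumbel noise)
logits = torch.tensor([2.0, 1.0, 0.5], dtype=torch.float64)
u = torch.rand_like(logits)
sampled = torch.argmax(logits.exp() / (-torch.log(u)))    # stochastic draw from softmax(logits)
greedy = torch.argmax(logits)                             # what temperature <= 0 reduces to
print(int(sampled), int(greedy))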
@@ -92,10 +136,10 @@ def generate_response_with_visualization(model, tokenizer, device, messages, gen
     x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
     x[:, :prompt_length] = input_ids.clone()
 
-    # Initialize visualization states for
+    # Initialize visualization states for the response part
     visualization_states = []
 
-    # Add initial state (all masked)
+    # Add initial state (all masked)
     initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
     visualization_states.append(initial_state)
 
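For orientation (illustrative values, not taken from the commit): each entry appended to visualization_states is a list of (token, color) pairs, one per response position, and is what the output_vis HighlightedText component displays, one state per denoising step. Using the color scheme defined in the generation loop below, one intermediate state might look like:

MASK_TOKEN = "[MASK]"            # stand-in for the app's mask display token
state = [
    ("Once", "#66CC66"),         # newly revealed, high confidence (light green)
    ("upon", "#FFAA33"),         # newly revealed, medium confidence (orange)
    ("a", "#FF6666"),            # newly revealed, low confidence (light red)
    (MASK_TOKEN, "#444444"),     # still masked (dark gray)
    ("time", "#6699CC"),         # revealed in an earlier step (light blue)
]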
@@ -105,114 +149,144 @@ def generate_response_with_visualization(model, tokenizer, device, messages, gen
         if absolute_pos < x.shape[1]:
             x[:, absolute_pos] = token_id
 
-    probs = torch.softmax(logits, dim=-1)
-    top_probs = torch.max(probs, dim=-1)[0]
-
-    # Apply the predictions where we have masks
-    x_old = x.clone()
-    x = torch.where(mask_indices, x0, x)
-
-    current_t_value = float(t)
-    next_t_value = float(s)
-
-    # Create visualization state ONLY for the response part
-    current_state = []
-
-    # Update which tokens are newly revealed in this step
-    for i in range(gen_length):
-        pos = prompt_length + i  # Absolute position in the sequence
-        elif confidence < 0.7:
-            color = "#FFAA33"  # Orange
-        else:
-            color = "#66CC66"  # Light green
+    # Mark prompt positions to exclude them from masking during classifier-free guidance
+    prompt_index = (x != MASK_ID)
+
+    # Ensure block_length is valid
+    if block_length > gen_length:
+        block_length = gen_length
+
+    # Calculate number of blocks
+    num_blocks = gen_length // block_length
+    if gen_length % block_length != 0:
+        num_blocks += 1
+
+    # Adjust steps per block
+    steps_per_block = steps // num_blocks
+    if steps_per_block < 1:
+        steps_per_block = 1
+
+    # Track the current state of x for visualization
+    current_x = x.clone()
+
+    # Process each block
+    for num_block in range(num_blocks):
+        # Calculate the start and end indices for the current block
+        block_start = prompt_length + num_block * block_length
+        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
+
+        # Get mask indices for the current block
+        block_mask_index = (x[:, block_start:block_end] == MASK_ID)
+
+        # Skip if no masks in this block
+        if not block_mask_index.any():
+            continue
+
+        # Calculate number of tokens to unmask at each step
+        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
+
+        # Process each step
+        for i in range(steps_per_block):
+            # Get all mask positions in the current sequence
+            mask_index = (x == MASK_ID)
+
+            # Skip if no masks
+            if not mask_index.any():
+                break
+
+            # Apply classifier-free guidance if enabled
+            if cfg_scale > 0.0:
+                un_x = x.clone()
+                un_x[prompt_index] = MASK_ID
+                x_ = torch.cat([x, un_x], dim=0)
+                logits = model(x_).logits
+                logits, un_logits = torch.chunk(logits, 2, dim=0)
+                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+            else:
+                logits = model(x).logits
+
+            # Apply Gumbel noise for sampling
+            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+            x0 = torch.argmax(logits_with_noise, dim=-1)
+
+            # Calculate confidence scores for remasking
+            if remasking == 'low_confidence':
+                p = F.softmax(logits.to(torch.float64), dim=-1)
+                x0_p = torch.squeeze(
+                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
+            elif remasking == 'random':
+                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+            else:
+                raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
+
+            # Don't consider positions beyond the current block
+            x0_p[:, block_end:] = -float('inf')
+
+            # Apply predictions where we have masks
+            old_x = x.clone()
+            x0 = torch.where(mask_index, x0, x)
+            confidence = torch.where(mask_index, x0_p, -float('inf'))
+
+            # Select tokens to unmask based on confidence
+            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+            for j in range(confidence.shape[0]):
+                # Only consider positions within the current block for unmasking
+                block_confidence = confidence[j, block_start:block_end]
+                if i < steps_per_block - 1:  # Not the last step
+                    # Take top-k confidences
+                    _, select_indices = torch.topk(block_confidence,
+                                                   k=min(num_transfer_tokens[j, i].item(),
+                                                         block_confidence.numel()))
+                    # Adjust indices to global positions
+                    select_indices = select_indices + block_start
+                    transfer_index[j, select_indices] = True
+                else:  # Last step - unmask everything remaining
+                    transfer_index[j, block_start:block_end] = mask_index[j, block_start:block_end]
+
+            # Apply the selected tokens
+            x = torch.where(transfer_index, x0, x)
+
+            # Ensure constraints are maintained
+            for pos, token_id in processed_constraints.items():
+                absolute_pos = prompt_length + pos
+                if absolute_pos < x.shape[1]:
+                    x[:, absolute_pos] = token_id
+
+            # Create visualization state only for the response part
+            current_state = []
+            for i in range(gen_length):
+                pos = prompt_length + i  # Absolute position in the sequence
+
+                if x[0, pos] == MASK_ID:
+                    # Still masked
+                    current_state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
+
+                elif old_x[0, pos] == MASK_ID:
+                    # Newly revealed in this step
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    # Color based on confidence
+                    confidence = float(x0_p[0, pos].cpu())
+                    if confidence < 0.3:
+                        color = "#FF6666"  # Light red
+                    elif confidence < 0.7:
+                        color = "#FFAA33"  # Orange
+                    else:
+                        color = "#66CC66"  # Light green
+
+                    current_state.append((token, color))
+
+                else:
+                    # Previously revealed
+                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
+                    current_state.append((token, "#6699CC"))  # Light blue
+
+            visualization_states.append(current_state)
 
     # Extract final text (just the assistant's response)
     response_tokens = x[0, prompt_length:]
-    response_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
-
-    # Clean the response text
     final_text = tokenizer.decode(response_tokens,
+                                  skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=True)
 
     return visualization_states, final_text
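With the defaults in this commit (gen_length=64, block_length=32, steps=32), the loop above runs num_blocks = 2 blocks with steps_per_block = 16 steps each. Within a step the core operations are classifier-free guidance mixing, an argmax prediction, and confidence-based unmasking. A toy, self-contained sketch of one such step (random tensors stand in for the model; names mirror the loop above, sizes are made up):

import torch
import torch.nn.functional as F

MASK_ID = 0
gen_len, vocab = 8, 16
x = torch.full((1, gen_len), MASK_ID)            # all positions start masked
logits = torch.randn(1, gen_len, vocab)          # stand-in for model(x).logits

# Classifier-free guidance: push the conditional logits away from an unconditional pass
cfg_scale = 1.0
un_logits = torch.randn(1, gen_len, vocab)       # stand-in for model(un_x).logits
guided = un_logits + (cfg_scale + 1) * (logits - un_logits)

# Predicted tokens and their probability under the guided distribution
x0 = torch.argmax(guided, dim=-1)
p = F.softmax(guided.to(torch.float64), dim=-1)
x0_p = torch.gather(p, -1, x0.unsqueeze(-1)).squeeze(-1)

# Low-confidence remasking: reveal only the k most confident predictions this step
mask_index = (x == MASK_ID)
confidence = x0_p.masked_fill(~mask_index, -float('inf'))
k = 3                                            # plays the role of num_transfer_tokens[j, i]
_, keep = torch.topk(confidence[0], k)
transfer = torch.zeros_like(x, dtype=torch.bool)
transfer[0, keep] = True
x = torch.where(transfer, x0, x)                 # k positions revealed, the rest stay MASK_ID
print(x)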
@@ -222,15 +296,13 @@ button{height: 60px}
 '''
 def create_chatbot_demo():
     with gr.Blocks(css=css) as demo:
-        gr.Markdown("# LLaDA - Large Language Diffusion Model
+        gr.Markdown("# LLaDA - Large Language Diffusion Model Demo")
         gr.Markdown("[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)")
 
-        # STATE MANAGEMENT
-        # We use a dedicated state to track the full conversation history
+        # STATE MANAGEMENT
         chat_history = gr.State([])
 
         # UI COMPONENTS
-        # Chatbot for displaying messages
         with gr.Row():
             with gr.Column(scale=3):
                 chatbot_ui = gr.Chatbot(label="Conversation", height=500)
@@ -257,7 +329,8 @@ def create_chatbot_demo():
                 combine_adjacent=False,
                 show_legend=True,
             )
+
+            # Advanced generation settings
             with gr.Accordion("Generation Settings", open=False):
                 with gr.Row():
                     gen_length = gr.Slider(
@@ -268,14 +341,32 @@
                         minimum=8, maximum=64, value=32, step=4,
                         label="Denoising Steps"
                     )
+                with gr.Row():
+                    temperature = gr.Slider(
+                        minimum=0.0, maximum=1.0, value=0.0, step=0.1,
+                        label="Temperature"
+                    )
+                    cfg_scale = gr.Slider(
+                        minimum=0.0, maximum=2.0, value=0.0, step=0.1,
+                        label="CFG Scale"
+                    )
+                with gr.Row():
+                    block_length = gr.Slider(
+                        minimum=8, maximum=128, value=32, step=8,
+                        label="Block Length"
+                    )
+                    remasking_strategy = gr.Radio(
+                        choices=["low_confidence", "random"],
+                        value="low_confidence",
+                        label="Remasking Strategy"
+                    )
+                with gr.Row():
+                    visualization_delay = gr.Slider(
+                        minimum=0.0, maximum=1.0, value=0.1, step=0.1,
+                        label="Visualization Delay (seconds)"
+                    )
 
-            # Current response text box
+            # Current response text box (hidden)
             current_response = gr.Textbox(
                 label="Current Response",
                 placeholder="The assistant's response will appear here...",
@@ -313,7 +404,7 @@
             # Return immediately to update UI with user message
             return history, history_for_display, message_out, [], ""
 
-        def bot_response(history, gen_length, steps, constraints, delay):
+        def bot_response(history, gen_length, steps, constraints, delay, temperature, cfg_scale, block_length, remasking):
            """Generate bot response for the latest message"""
            if not history:
                return history, [], ""
@@ -337,7 +428,11 @@
                messages,
                gen_length=gen_length,
                steps=steps,
-               constraints=parsed_constraints
+               constraints=parsed_constraints,
+               temperature=temperature,
+               cfg_scale=cfg_scale,
+               block_length=block_length,
+               remasking=remasking
            )
 
            # Update history with the assistant's response
@@ -393,13 +488,21 @@
        # This happens after the user message is displayed
        msg_submit.then(
            fn=bot_response,
-           inputs=[
+           inputs=[
+               chat_history, gen_length, steps, constraints_input,
+               visualization_delay, temperature, cfg_scale, block_length,
+               remasking_strategy
+           ],
            outputs=[chatbot_ui, output_vis, current_response]
        )
 
        send_click.then(
            fn=bot_response,
-           inputs=[
+           inputs=[
+               chat_history, gen_length, steps, constraints_input,
+               visualization_delay, temperature, cfg_scale, block_length,
+               remasking_strategy
+           ],
            outputs=[chatbot_ui, output_vis, current_response]
        )
 
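For reference, a hypothetical direct call showing how the new sliders map onto the sampler's keyword arguments (model, tokenizer and device are the objects the app loads at startup; the values are placeholders):

vis_states, reply = generate_response_with_visualization(
    model, tokenizer, device,
    messages=[{"role": "user", "content": "Hello!"}],
    gen_length=64, steps=32, constraints=None,
    temperature=0.2, cfg_scale=0.0,
    block_length=32, remasking='low_confidence',
)
# vis_states: one list of (token, color) pairs per denoising step, replayed in the UI
#             with the chosen visualization delay
# reply: the decoded assistant text that bot_response appends to chat_history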