Upload 11 files

- .gitattributes +3 -0
- GUIDELINES.md +140 -0
- LICENSE +21 -0
- chat.py +45 -0
- generate.py +128 -0
- get_log_likelihood.py +96 -0
- imgs/LLaDA_vs_LLaMA.svg +2772 -0
- imgs/LLaDA_vs_LLaMA_chat.svg +2665 -0
- imgs/diff_remask.gif +3 -0
- imgs/sample.png +3 -0
- imgs/transformer1.png +0 -0
- imgs/transformer2.png +3 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+imgs/diff_remask.gif filter=lfs diff=lfs merge=lfs -text
+imgs/sample.png filter=lfs diff=lfs merge=lfs -text
+imgs/transformer2.png filter=lfs diff=lfs merge=lfs -text
```
GUIDELINES.md
ADDED
# Guidelines

Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA.

## Model Architecture

LLaDA employs a Transformer encoder as the network architecture for its mask predictor.
In terms of trainable parameters, the Transformer encoder is identical to the Transformer
decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply
removing the causal mask from the self-attention mechanism, as illustrated below.

<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
    <img src="imgs/transformer1.png" style="width: 90%;" />
    <img src="imgs/transformer2.png" style="width: 90%;" />
</div>

In addition, LLaDA designates a reserved token as the mask token (i.e., 126336).
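To make this concrete, here is a minimal sketch of scaled dot-product self-attention (our own illustration, not code from the LLaDA repository): the autoregressive backbone and LLaDA's mask predictor differ only in whether the causal mask is applied.

```python
import torch
import torch.nn.functional as F

def self_attention(q, k, v, causal: bool):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    if causal:
        # Autoregressive decoder: position i attends only to positions <= i.
        seq_len = q.shape[-2]
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
    # LLaDA's mask predictor simply omits the causal mask, so every
    # position attends to the full sequence (bidirectional attention).
    return F.softmax(scores, dim=-1) @ v
```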
## Pre-training
The pre-training of LLaDA is straightforward and simple. Starting from an existing
autoregressive model training codebase, only a few lines need to be modified.
We provide the core code (i.e., the loss computation) here.

```python
import torch
import torch.nn.functional as F

def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is used for [MASK] token
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask

# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

# Dividing each token's loss by p_mask importance-weights the masked positions,
# giving an unbiased Monte Carlo estimate of the training objective.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
## SFT
First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing,
the data format is as follows. For simplicity, we treat each word as a token and set the batch size to 2
in the following visualization; a minimal packing sketch follows the example.

```
input_ids:
<BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
<BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>

prompt_lengths:
[17, 17]
```
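For illustration only, this format can be produced along the following lines. This is a sketch under our own assumptions, not the Appendix B.1 code; `EOS_ID` and the helper are hypothetical, and tokenizer BOS handling is glossed over.

```python
import torch

EOS_ID = 126081  # <EOS>; placeholder constant for this sketch

def pack_sft_batch(pairs, tokenizer):
    """pairs: list of (prompt_text, answer_text), with the prompt already wrapped in the chat template."""
    prompt_ids = [tokenizer(p)['input_ids'] for p, _ in pairs]
    full_ids = [pi + tokenizer(a)['input_ids'] for pi, (_, a) in zip(prompt_ids, pairs)]
    max_len = max(len(ids) for ids in full_ids)
    # Pad every sequence to the same length with <EOS>, as in the visualization above.
    input_ids = torch.tensor([ids + [EOS_ID] * (max_len - len(ids)) for ids in full_ids])
    prompt_lengths = torch.tensor([len(pi) for pi in prompt_ids])
    return {"input_ids": input_ids, "prompt_lengths": prompt_lengths}
```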
After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code.
The key difference from pre-training is that SFT does not add noise to the prompt.
```python
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens)
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
# Normalize by each sequence's answer length so every example contributes equally.
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```
## Sampling
Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding.
**It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code.**
However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable.
These three sampling methods are illustrated in the figure below.

<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
    <img src="imgs/sample.png" style="width: 100%;" />
</div>

For each step in the above three sampling processes, as detailed in Section 2.4 of our paper, the mask predictor
first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions are remasked.
To determine which predicted tokens should be remasked, we can adopt one of two strategies: *random remasking* or
*low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes
mentioned above; a single-step sketch of low-confidence remasking follows.
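The sketch below condenses the low-confidence strategy into one sampling step. It is a simplified, greedy (temperature 0) reading of the `generate.py` in this upload; see that file for the full block-wise loop with temperature and classifier-free guidance.

```python
import torch
import torch.nn.functional as F

def low_confidence_remask_step(model, x, mask_id, num_transfer):
    """One step: predict all masked tokens, keep only the `num_transfer`
    most confident predictions, and leave the rest masked."""
    mask_index = (x == mask_id)
    logits = model(x).logits
    x0 = torch.argmax(logits, dim=-1)                          # predictions for every position
    p = F.softmax(logits.to(torch.float64), dim=-1)
    x0_p = torch.gather(p, -1, x0.unsqueeze(-1)).squeeze(-1)   # confidence of each prediction
    confidence = torch.where(mask_index, x0_p, -float('inf'))  # only masked slots compete
    for j in range(x.shape[0]):
        _, keep = torch.topk(confidence[j], k=num_transfer)
        x[j, keep] = x0[j, keep]                               # accept top-k; the rest stay masked
    return x
```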
For the LLaDA-Base model, we apply low-confidence remasking to the three sampling processes mentioned above.
We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin
performs slightly worse.
For the LLaDA-Instruct model, the situation is slightly more complex.

First, if the semi-autoregressive-origin method is used, the Instruct model performs poorly.
This is because, during SFT, each sequence is a complete sentence (whereas in pre-training,
many sequences are truncated sentences). As a result, during sampling, given a generation length, regardless of whether it is
long or short, the Instruct model tends to generate a complete sentence. Unlike the Base model, it does not encounter cases
where a sentence is only partially generated and needs to be continued.

When performing fixed-length sampling with a long answer length (e.g., greater than 512),
we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in
the generated sentences, which severely impacts the model's performance. In contrast, this
issue does not arise when random remasking is used.

Furthermore, since low-confidence remasking achieved better results with the Base model, we also hoped that it could be applied to
the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates
the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves
slightly better results than random remasking with fixed-length sampling.
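In terms of the `generate` function in this upload's `generate.py`, these configurations map onto its arguments roughly as follows (the step counts are illustrative; semi-autoregressive-origin has no corresponding code, as noted above):

```python
from generate import generate

# model and prompt prepared as in generate.py's main().

# Fixed-length sampling: a single block spanning the whole answer.
out = generate(model, prompt, steps=128, gen_length=128, block_length=128,
               remasking='low_confidence')  # or remasking='random'

# Semi-autoregressive-padding: block_length < gen_length, so blocks are
# generated left to right; combined with low-confidence remasking, this is
# the configuration that works well for the Instruct model.
out = generate(model, prompt, steps=128, gen_length=128, block_length=32,
               remasking='low_confidence')
```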
You can find more details about the sampling method in our paper.
LICENSE
ADDED
MIT License

Copyright (c) 2025 NieShenRuc

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
chat.py
ADDED
```python
import torch

from generate import generate
from transformers import AutoTokenizer, AutoModel


def chat():
    device = 'cuda'
    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    gen_length = 128
    steps = 128
    print('*' * 66)
    print(f'** Answer Length: {gen_length} | Sampling Steps: {steps} **')
    print('*' * 66)

    conversation_num = 0
    while True:
        user_input = input("Enter your question: ")

        m = [{"role": "user", "content": user_input}]
        user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
        input_ids = tokenizer(user_input)['input_ids']
        input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

        if conversation_num == 0:
            prompt = input_ids
        else:
            # Append the new turn to the running conversation, dropping its leading <BOS>.
            prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')

        answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
        print(f"Bot's reply: {answer}")

        # Remove the <EOS> padding (token id 126081) before the next turn.
        prompt = out[out != 126081].unsqueeze(0)
        conversation_num += 1
        print('-----------------------------------------------------------------------')


if __name__ == "__main__":
    chat()
```
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
    Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
    the expected number of tokens transitioned at each step should be consistent.

    This function precomputes the number of tokens that need to be transitioned at each step.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    # Distribute the remainder over the first few steps.
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens


@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, l).
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, semi-autoregressive remasking is used.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
    '''
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: the unconditional branch masks out the prompt.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Tokens beyond the current block are never unmasked in this step.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


def main():
    device = 'cuda'

    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out = generate(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])


if __name__ == '__main__':
    main()
```
get_log_likelihood.py
ADDED
```python
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel


def forward_process(batch, prompt_index, mask_id):
    b, l = batch.shape

    target_len = (l - prompt_index.sum()).item()
    k = torch.randint(1, target_len + 1, (), device=batch.device)

    # Stratified sampling of mask counts: spread the b samples evenly over [1, target_len].
    x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
    x = ((x - 1) % target_len) + 1
    assert x.min() >= 1 and x.max() <= target_len

    indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
    is_mask = indices < x.unsqueeze(1)
    for i in range(b):
        is_mask[i] = is_mask[i][torch.randperm(target_len)]

    # Never mask the prompt positions.
    is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)
    noisy_batch = torch.where(is_mask, mask_id, batch)

    # Return the masked batch and the mask ratio
    return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)


def get_logits(model, batch, prompt_index, cfg_scale, mask_id):
    if cfg_scale > 0.:
        assert len(prompt_index) == batch.shape[1]
        prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
        un_batch = batch.clone()
        un_batch[prompt_index] = mask_id
        batch = torch.cat([batch, un_batch])

    logits = model(batch).logits

    if cfg_scale > 0.:
        logits, un_logits = torch.chunk(logits, 2, dim=0)
        logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
    return logits


@torch.no_grad()
def get_log_likelihood(model, prompt, answer, mc_num=128, batch_size=16, cfg_scale=0., mask_id=126336):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (l1).
        answer: A tensor of shape (l2).
        mc_num: Monte Carlo estimation times.
                As detailed in Appendix B.5, since MMLU, CMMLU, and C-EVAL only require the likelihood of a single token, a
                single Monte Carlo estimate is sufficient for these benchmarks. For all other benchmarks, we find that 128
                Monte Carlo samples are adequate to produce stable results.
        batch_size: Mini batch size.
        cfg_scale: Unsupervised classifier-free guidance scale.
        mask_id: The token id of [MASK] is 126336.
    '''
    seq = torch.concatenate([prompt, answer])[None, :]
    seq = seq.repeat((batch_size, 1)).to(model.device)
    prompt_index = torch.arange(seq.shape[1], device=model.device) < len(prompt)

    loss_ = []
    for _ in range(mc_num // batch_size):
        perturbed_seq, p_mask = forward_process(seq, prompt_index, mask_id)
        mask_index = perturbed_seq == mask_id

        logits = get_logits(model, perturbed_seq, prompt_index, cfg_scale, mask_id)

        loss = F.cross_entropy(logits[mask_index], seq[mask_index], reduction='none') / p_mask[mask_index]
        loss = loss.sum() / batch_size

        loss_.append(loss.item())

    return - sum(loss_) / len(loss_)


def main():
    device = 'cuda'

    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Base', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Base', trust_remote_code=True)

    # This prompt and answer are from the HellaSwag dataset.
    prompt = 'Roof shingle removal: A man is sitting on a roof. He'
    answer = ' is using wrap to wrap a pair of skis.'

    prompt = torch.tensor(tokenizer(prompt)['input_ids']).to(device)
    answer = torch.tensor(tokenizer(answer)['input_ids']).to(device)
    print(get_log_likelihood(model, prompt, answer, mc_num=128))


if __name__ == '__main__':
    main()
```
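As a usage sketch (our own illustration; the second candidate ending is hypothetical), a HellaSwag-style multiple-choice item can be scored by estimating the log-likelihood of each candidate ending and taking the argmax:

```python
# Assumes model, tokenizer, and device are set up as in main() above.
candidates = [' is using wrap to wrap a pair of skis.',
              ' is removing shingles from the roof.']
prompt_ids = torch.tensor(tokenizer('Roof shingle removal: A man is sitting on a roof. He')['input_ids']).to(device)
scores = []
for ending in candidates:
    answer_ids = torch.tensor(tokenizer(ending)['input_ids']).to(device)
    scores.append(get_log_likelihood(model, prompt_ids, answer_ids, mc_num=128))
print(candidates[scores.index(max(scores))])  # highest-likelihood ending
```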
imgs/LLaDA_vs_LLaMA.svg
ADDED

imgs/LLaDA_vs_LLaMA_chat.svg
ADDED

imgs/diff_remask.gif
ADDED (stored with Git LFS)

imgs/sample.png
ADDED (stored with Git LFS)

imgs/transformer1.png
ADDED

imgs/transformer2.png
ADDED (stored with Git LFS)