# Guidelines

Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA.

## Model Architecture

LLaDA employs a Transformer Encoder as the network architecture for its mask predictor.
In terms of trainable parameters, the Transformer Encoder is identical to the Transformer
Decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply
removing the causal mask from the self-attention mechanism, as shown below.
<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
  <img src="imgs/transformer1.png" style="width: 90%;" />
  <img src="imgs/transformer2.png" style="width: 90%;" />
</div>
In addition, LLaDA designates a reserved token (id 126336) as the mask token.
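
As a minimal sketch (not our actual implementation), the architectural difference can be illustrated as follows: the attention layer simply omits the causal mask, and the reserved mask token id is just another entry in the vocabulary. The module below uses illustrative dimensions, not LLaDA's actual configuration.

```python
import torch
import torch.nn as nn

MASK_TOKEN_ID = 126336  # the reserved [MASK] token id


class BidirectionalSelfAttention(nn.Module):
    """Standard multi-head self-attention, with the causal mask removed."""

    def __init__(self, d_model=1024, n_heads=16):  # illustrative sizes only
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        # An autoregressive decoder would pass a causal attn_mask here.
        # The LLaDA mask predictor omits it, so every position attends to
        # every other position, exactly like a Transformer Encoder.
        out, _ = self.attn(x, x, x, attn_mask=None)
        return out
```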

## Pre-training

The pre-training of LLaDA is straightforward. Starting from an existing
autoregressive model training codebase, only a few lines need to be modified.
We provide the core code (i.e., the loss computation) here.
```python
import torch
import torch.nn.functional as F


def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    # Sample a noise level t ~ U[0, 1) per sequence and map it to a masking
    # probability in [eps, 1).
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    # Mask each token independently with probability p_mask.
    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is used for [MASK] token
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask


# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

# Cross-entropy on the masked positions only, reweighted by 1 / p_mask,
# then averaged over all b * l tokens.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
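
As a quick, self-contained sanity check (the shapes and token ids below are arbitrary toy values, not real training data), `forward_process` can be exercised on a dummy batch:

```python
import torch

# Toy batch: 4 sequences of length 16 with arbitrary token ids.
dummy_ids = torch.randint(0, 1000, (4, 16))
noisy, masked, p_mask = forward_process(dummy_ids)

# Each sequence i is masked independently with probability p_mask[i, 0],
# which lies in [eps, 1); unmasked positions keep their original ids.
print(p_mask[:, 0])                # per-sequence masking probability
print(masked.float().mean(dim=1))  # empirical fraction of masked tokens
assert torch.all(noisy[~masked] == dummy_ids[~masked])
assert torch.all(noisy[masked] == 126336)
```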

## SFT

First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing the data,
the data format is as follows. For simplicity, we treat each word as a token and set the batch size to 2
in the following visualization.

```
input_ids:
<BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
<BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>

prompt_lengths:
[17, 17]
```
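
For reference only, here is a rough sketch of how batches in this format could be assembled; the helper name, arguments, and padding scheme are illustrative assumptions, and the authoritative description is Appendix B.1 of our paper.

```python
import torch


def build_sft_batch(tokenized_prompts, tokenized_answers, eos_id, seq_len):
    """Hypothetical helper: concatenate each prompt with its answer, pad the
    answer side with <EOS> up to a common length, and record the prompt length
    so that no noise is added to the prompt during SFT."""
    input_ids, prompt_lengths = [], []
    for prompt, answer in zip(tokenized_prompts, tokenized_answers):
        seq = prompt + answer
        seq = seq + [eos_id] * (seq_len - len(seq))  # pad with <EOS>
        input_ids.append(seq[:seq_len])
        prompt_lengths.append(len(prompt))
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "prompt_lengths": torch.tensor(prompt_lengths, dtype=torch.long),
    }
```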

After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code.
The key difference from pre-training is that SFT does not add noise to the prompt.
```python
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt: restore the original prompt tokens.
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens).
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

# Same reweighted cross-entropy as in pre-training, but normalized by the
# answer length instead of the full sequence length.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```

## Sampling

Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding.
**It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code**.
However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable.
These three sampling methods are illustrated in the figure below.

<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
  <img src="imgs/sample.png" style="width: 100%;" />
</div>

For each step in the above three sampling processes, as detailed in Section 2.4 of our paper, the mask predictor
first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions are remasked.
To determine which predicted tokens should be remasked, we can adopt one of two strategies: *random remasking* or
*low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes
mentioned above.
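
To make the second strategy concrete, here is a simplified single-step sketch of low-confidence remasking; it is a paraphrase of the idea rather than the exact code in our repository, and `num_to_keep` (how many predictions to commit at this step) would come from the sampling schedule.

```python
import torch
import torch.nn.functional as F

MASK_ID = 126336


@torch.no_grad()
def low_confidence_remask_step(model, x, num_to_keep):
    """Predict all masked tokens, commit only the most confident predictions,
    and leave the rest masked for later steps."""
    masked = (x == MASK_ID)
    logits = model(input_ids=x).logits          # (b, l, vocab_size)
    probs = F.softmax(logits, dim=-1)
    confidence, prediction = probs.max(dim=-1)  # confidence of the argmax token

    # Only still-masked positions compete; already revealed tokens stay fixed.
    confidence = torch.where(masked, confidence, torch.full_like(confidence, -1.0))

    x_next = x.clone()
    for i in range(x.shape[0]):
        k = min(num_to_keep, int(masked[i].sum()))
        keep = confidence[i].topk(k).indices
        x_next[i, keep] = prediction[i, keep]
    return x_next
```

Random remasking corresponds to replacing `confidence` with uniform random scores over the masked positions.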

For the LLaDA-Base model, we apply low-confidence remasking to the three sampling processes mentioned above.
We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin
performs slightly worse.

For the LLaDA-Instruct model, the situation is slightly more complex.

First, if the semi-autoregressive-origin method is used, the Instruct model performs poorly. This is because, during SFT,
each sequence is a complete sentence (whereas in pre-training, many sequences are truncated sentences). As a result,
during sampling, regardless of whether the specified generation length is long or short, the Instruct model tends to
generate a complete sentence. Unlike the Base model, it does not encounter cases where a sentence is only partially
generated and needs to be continued.
When performing fixed-length sampling with a high answer length (e.g., greater than 512),
we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in
the generated sentences, which severely impacts the model's performance. In contrast, this
issue does not arise when random remasking is used.

Furthermore, since low-confidence remasking achieved better results with the Base model, we also hoped that it could be applied to
the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates
the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves
slightly better results than random remasking with fixed-length sampling.
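
As a rough, hedged sketch of how semi-autoregressive-padding can be combined with low-confidence remasking (block size, step counts, and batch handling here are illustrative and assume batch size 1; see our released sampling code for the actual implementation):

```python
import torch
import torch.nn.functional as F

MASK_ID = 126336


@torch.no_grad()
def semi_ar_padding_sample(model, prompt_ids, gen_length=128, block_length=32, steps_per_block=8):
    """Fill the whole response region with mask tokens, then denoise it block by
    block from left to right; inside each block, the most confident predictions
    are committed first while later blocks stay fully masked."""
    device = prompt_ids.device
    prompt_len = prompt_ids.shape[1]
    x = torch.cat(
        [prompt_ids, torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)],
        dim=1,
    )

    for block_start in range(prompt_len, prompt_len + gen_length, block_length):
        block_end = block_start + block_length
        reveal_per_step = block_length // steps_per_block
        for _ in range(steps_per_block):
            logits = model(input_ids=x).logits
            confidence, prediction = F.softmax(logits, dim=-1).max(dim=-1)

            # Only still-masked positions inside the current block compete.
            candidates = (x == MASK_ID)
            candidates[:, :block_start] = False
            candidates[:, block_end:] = False
            confidence = torch.where(candidates, confidence, torch.full_like(confidence, -1.0))

            keep = confidence[0].topk(reveal_per_step).indices
            x[0, keep] = prediction[0, keep]

    return x[:, prompt_len:]
```

Roughly speaking, fixed-length sampling corresponds to the special case `block_length == gen_length`.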

You can find more details about the sampling method in our paper.