|
# Guidelines |
|
Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA. |
|
|
|
## Model Architecture |
|
|
|
LLaDA employs a Transformer Encoder as the network architecture for its mask predictor. In terms of trainable parameters, the Transformer Encoder is identical to the Transformer Decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply removing the causal mask from the self-attention mechanism, as illustrated below.
|
|
|
<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;"> |
|
<img src="imgs/transformer1.png" style="width: 90%;" /> |
|
<img src="imgs/transformer2.png" style="width: 90%;" /> |
|
</div> |
|
|
|
In addition, LLaDA designates a reserved token as the mask token (i.e., 126336). |
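To make the change concrete, here is a minimal, hypothetical sketch of a bidirectional self-attention layer in PyTorch. The class and argument names are illustrative and this is not LLaDA's actual implementation; the point is only that dropping the causal mask (`is_causal=False`) leaves the trainable parameters unchanged while letting every token attend to the full sequence.

```python
import torch
import torch.nn.functional as F
from torch import nn


class BidirectionalSelfAttention(nn.Module):
    """Hypothetical sketch: same trainable parameters as a decoder's
    self-attention, but without the causal mask."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, l, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, l, d) -> (b, n_heads, l, head_dim)
        q, k, v = (t.view(b, l, self.n_heads, -1).transpose(1, 2) for t in (q, k, v))
        # is_causal=False: every token attends to the full sequence,
        # which is what allows all masked positions to be predicted at once.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        return self.out(y.transpose(1, 2).reshape(b, l, d))
```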
|
|
|
|
|
## Pre-training |
|
The pre-training of LLaDA is straightforward: starting from the training code of an existing autoregressive model, only a few lines need to be modified. We provide the core code (i.e., the loss computation) below.
|
|
|
```python
import torch
import torch.nn.functional as F

def forward_process(input_ids, eps=1e-3):
    b, l = input_ids.shape
    # Sample a masking ratio per sequence and clamp it to [eps, 1].
    t = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, l)

    # Mask each token independently with probability p_mask.
    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    # 126336 is used for the [MASK] token.
    noisy_batch = torch.where(masked_indices, 126336, input_ids)
    return noisy_batch, masked_indices, p_mask

# The data is an integer tensor of shape (b, 4096),
# where b represents the batch size and 4096 is the sequence length.
input_ids = batch["input_ids"]

# We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
# The following implementation is not elegant and involves some data waste.
# However, the data waste is minimal, so we ignore it.
if torch.rand(1) < 0.01:
    random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
    input_ids = input_ids[:, :random_length]

noisy_batch, masked_indices, p_mask = forward_process(input_ids)
logits = model(input_ids=noisy_batch).logits

# Cross-entropy on the masked positions only, reweighted by 1 / p_mask,
# then averaged over all tokens in the batch.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
```
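To build intuition, `forward_process` can be inspected on a toy batch. The token ids below are arbitrary placeholders, purely for illustration:

```python
# Toy, hypothetical example: inspect what forward_process produces.
toy_input = torch.randint(0, 1000, (2, 8))       # (batch=2, length=8), arbitrary ids
noisy, masked, p = forward_process(toy_input)

print(p[:, 0])   # one masking probability per sequence, constant across its tokens
print(masked)    # True where a token was replaced by the mask token
print(noisy)     # the corrupted batch: masked positions now contain 126336
```

Each sequence gets its own masking ratio, so a batch mixes lightly and heavily corrupted sequences; the `1 / p_mask` term in the loss reweights each masked token according to how heavily its sequence was corrupted.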
|
|
|
## SFT |
|
First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing, the data format is as follows. For simplicity, we treat each word as a token and set the batch size to 2 in the following visualization.
|
```
input_ids:
<BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
<BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>

prompt_lengths:
[17, 17]
```
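For reference, a batch like the one above could be assembled along the following lines. This is a hypothetical sketch, not the preprocessing code from Appendix B.1; the `tokenizer` interface and `eos_id` argument are placeholders.

```python
import torch

def build_sft_batch(pairs, tokenizer, eos_id):
    """Hypothetical sketch: tokenize chat-formatted (prompt, response) pairs,
    pad every response with <EOS> so all sequences share the same length,
    and record each prompt length."""
    input_ids, prompt_lengths = [], []
    for prompt, response in pairs:
        p = tokenizer(prompt)["input_ids"]               # <BOS> ... <end_id>\n
        r = tokenizer(response)["input_ids"] + [eos_id]  # response + <EOS>
        input_ids.append(p + r)
        prompt_lengths.append(len(p))

    max_len = max(len(seq) for seq in input_ids)
    # Pad shorter sequences with <EOS>; the padded <EOS> tokens are treated as
    # part of the response and contribute to the SFT loss below.
    input_ids = [seq + [eos_id] * (max_len - len(seq)) for seq in input_ids]
    return torch.tensor(input_ids), torch.tensor(prompt_lengths)
```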
|
|
|
After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code. |
|
The key difference from pre-training is that SFT does not add noise to the prompt. |
|
```python
input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]

noisy_batch, _, p_mask = forward_process(input_ids)

# Do not add noise to the prompt.
token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
noisy_batch[prompt_mask] = input_ids[prompt_mask]

# Calculate the answer length (including the padded <EOS> tokens).
prompt_mask = prompt_mask.to(torch.int64)
answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])

masked_indices = (noisy_batch == 126336)

logits = model(input_ids=noisy_batch).logits

# Same reweighted cross-entropy as in pre-training, but normalized by the
# answer length of each sequence instead of the full sequence length.
token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
```
|
|
|
## Sampling |
|
Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding. |
|
**It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code**. |
|
However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable. |
|
These three sampling methods are illustrated in the figure below. |
|
|
|
|
|
<div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;"> |
|
<img src="imgs/sample.png" style="width: 100%;" /> |
|
</div> |
|
|
|
For each step in the above three sampling processes, as detailed in Section 2.4 of our paper, the mask predictor first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions is remasked. To determine which predicted tokens should be re-masked, we can adopt one of two strategies: *random remasking* or *low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes mentioned above.
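As an illustration of one such step, here is a simplified, hypothetical sketch; it is not the code used in our experiments. In practice, `num_to_reveal` follows a schedule over the sampling steps, and the `candidate_mask` argument (used by the semi-autoregressive sketch later) restricts which positions may be revealed.

```python
import torch
import torch.nn.functional as F

MASK_ID = 126336

@torch.no_grad()
def denoise_step(model, x, num_to_reveal, strategy="low_confidence", candidate_mask=None):
    """Hypothetical sketch of one sampling step: predict every masked position,
    reveal `num_to_reveal` of them according to the chosen strategy, and keep
    the rest masked for later steps."""
    masked = (x == MASK_ID)
    candidates = masked if candidate_mask is None else masked & candidate_mask

    logits = model(input_ids=x).logits
    conf, pred = F.softmax(logits, dim=-1).max(dim=-1)   # per-token confidence and argmax

    if strategy == "low_confidence":
        score = conf                      # reveal the most confident predictions
    else:                                 # random remasking
        score = torch.rand_like(conf)     # reveal a random subset
    score = torch.where(candidates, score, torch.full_like(score, -1.0))

    # Positions with the highest scores are revealed; everything else stays [MASK].
    keep = torch.topk(score, k=num_to_reveal, dim=-1).indices
    keep_mask = torch.zeros_like(masked)
    keep_mask[torch.arange(x.shape[0], device=x.device).unsqueeze(1), keep] = True
    return torch.where(keep_mask, pred, x)
```

Repeatedly applying this step until no `[MASK]` tokens remain corresponds to fixed-length sampling.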
|
|
|
For the LLaDA-Base model, we apply low-confidence remasking to the three sampling processes mentioned above. We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin performs slightly worse.
|
|
|
For the LLaDA-Instruct model, the situation is slightly more complex. |
|
|
|
First, if the semi-autoregressive-origin method is used, the Instruct model performs poorly. This is because, during SFT, each sequence is a complete sentence (whereas in pre-training, many sequences are truncated sentences). As a result, during sampling, regardless of whether the given generation length is long or short, the Instruct model tends to produce a complete sentence. Unlike the Base model, it does not encounter cases where a sentence is only partially generated and needs to be continued.
|
|
|
When performing fixed-length sampling with a large answer length (e.g., greater than 512), we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in the generated sentences, which severely impacts the model's performance. In contrast, this issue does not arise when random remasking is used.
|
|
|
Furthermore, since low-confidence remasking achieved better results for the Base model, we also hoped that it could be applied to the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves slightly better results than random remasking with fixed-length sampling.
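For concreteness, continuing the hypothetical `denoise_step` sketch above (and reusing `MASK_ID`), semi-autoregressive-padding with low-confidence remasking could be wired up roughly as follows. The block size, step count, and generation length are illustrative defaults, not our exact settings.

```python
@torch.no_grad()
def generate_semi_ar_padding(model, prompt_ids, gen_length=128, block_length=32, steps_per_block=8):
    """Hypothetical sketch: append `gen_length` [MASK] tokens after the prompt,
    then denoise them block by block from left to right."""
    b, prompt_len = prompt_ids.shape
    x = torch.full((b, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=prompt_ids.device)
    x[:, :prompt_len] = prompt_ids

    assert block_length % steps_per_block == 0
    per_step = block_length // steps_per_block

    for start in range(prompt_len, prompt_len + gen_length, block_length):
        # Only positions inside the current block may be revealed; the remaining
        # blocks stay fully masked (the "padding") until their turn.
        block = torch.zeros_like(x, dtype=torch.bool)
        block[:, start:start + block_length] = True
        for _ in range(steps_per_block):
            x = denoise_step(model, x, num_to_reveal=per_step, candidate_mask=block)

    # Generated region; trailing <EOS> tokens can be stripped when decoding.
    return x[:, prompt_len:]
```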
|
|
|
You can find more details about the sampling method in our paper. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|