multimodalart (HF staff) committed (verified)
Commit f6d8cac · 1 Parent(s): 7a998bc

Upload 11 files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ imgs/diff_remask.gif filter=lfs diff=lfs merge=lfs -text
+ imgs/sample.png filter=lfs diff=lfs merge=lfs -text
+ imgs/transformer2.png filter=lfs diff=lfs merge=lfs -text
GUIDELINES.md ADDED
@@ -0,0 +1,140 @@
+ # Guidelines
+ Here, we provide guidelines for the model architecture, pre-training, SFT, and inference of LLaDA.
+
+ ## Model Architecture
+
+ LLaDA employs a Transformer Encoder as the network architecture for its mask predictor.
+ In terms of trainable parameters, the Transformer Encoder is identical to the Transformer
+ Decoder. Starting from an autoregressive model, we derive the backbone of LLaDA by simply
+ removing the causal mask from the self-attention mechanism, as illustrated below.
+
+ <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
+   <img src="imgs/transformer1.png" style="width: 90%;" />
+   <img src="imgs/transformer2.png" style="width: 90%;" />
+ </div>
+
+ In addition, LLaDA designates a reserved token as the mask token (i.e., 126336).
+
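+ As a rough illustration (a hypothetical sketch, not code from this repository), the change amounts to dropping the
+ causal mask in self-attention so that every token can attend to every other token:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def self_attention(q, k, v, causal: bool):
+     # q, k, v: (batch, seq_len, dim)
+     scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
+     if causal:
+         # Autoregressive decoder: token i may only attend to positions <= i.
+         seq_len = q.shape[1]
+         future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
+         scores = scores.masked_fill(future, float('-inf'))
+     # LLaDA's mask predictor simply omits this mask (causal=False), making attention
+     # bidirectional while leaving the trainable parameters unchanged.
+     return F.softmax(scores, dim=-1) @ v
+ ```
+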
+ ## Pre-training
+ The pre-training of LLaDA is straightforward and simple. Starting from an existing
+ autoregressive model training codebase, only a few lines need to be modified.
+ We provide the core code (i.e., the loss computation) here.
+
+ ```python
+ def forward_process(input_ids, eps=1e-3):
+     b, l = input_ids.shape
+     t = torch.rand(b, device=input_ids.device)
+     p_mask = (1 - eps) * t + eps
+     p_mask = p_mask[:, None].repeat(1, l)
+
+     masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
+     # 126336 is used for the [MASK] token
+     noisy_batch = torch.where(masked_indices, 126336, input_ids)
+     return noisy_batch, masked_indices, p_mask
+
+ # The data is an integer tensor of shape (b, 4096),
+ # where b represents the batch size and 4096 is the sequence length.
+ input_ids = batch["input_ids"]
+
+ # We set 1% of the pre-training data to a random length that is uniformly sampled from the range [1, 4096].
+ # The following implementation is not elegant and involves some data waste.
+ # However, the data waste is minimal, so we ignore it.
+ if torch.rand(1) < 0.01:
+     random_length = torch.randint(1, input_ids.shape[1] + 1, (1,))
+     input_ids = input_ids[:, :random_length]
+
+ noisy_batch, masked_indices, p_mask = forward_process(input_ids)
+ logits = model(input_ids=noisy_batch).logits
+
+ token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
+ loss = token_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
+ ```
+
+ ## SFT
+ First, please refer to Appendix B.1 for the preprocessing of the SFT data. After preprocessing,
+ the data format is as follows. For simplicity, we treat each word as a token and set the batch size to 2
+ in the following visualization.
+ ```
+ input_ids:
+ <BOS><start_id>user<end_id>\nWhat is the capital of France?<eot_id><start_id>assistant<end_id>\nParis.<EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS><EOS>
+ <BOS><start_id>user<end_id>\nWhat is the capital of Canada?<eot_id><start_id>assistant<end_id>\nThe capital of Canada is Ottawa, located in Ontario.<EOS>
+
+ prompt_lengths:
+ [17, 17]
+ ```
+
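+ A hypothetical sketch of this preprocessing is given below (the authoritative procedure is Appendix B.1; the helper
+ name and details here are illustrative). It pads every answer with <EOS> tokens up to a fixed sequence length and
+ records the prompt length of each sample; the <EOS> id 126081 matches the one used in `chat.py`:
+
+ ```python
+ import torch
+
+ def build_sft_batch(pairs, tokenizer, seq_len=4096, eos_id=126081):
+     # pairs: list of (prompt_text, answer_text); prompt_text is assumed to already
+     # contain the chat-template special tokens shown in the visualization above.
+     input_ids, prompt_lengths = [], []
+     for prompt_text, answer_text in pairs:
+         prompt_ids = tokenizer(prompt_text)['input_ids']
+         answer_ids = tokenizer(answer_text)['input_ids']
+         ids = prompt_ids + answer_ids
+         # Pad the answer (never the prompt) with <EOS> so all sequences share one length.
+         ids = ids[:seq_len] + [eos_id] * max(0, seq_len - len(ids))
+         input_ids.append(ids)
+         prompt_lengths.append(len(prompt_ids))
+     return {"input_ids": torch.tensor(input_ids), "prompt_lengths": torch.tensor(prompt_lengths)}
+ ```
+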
+ After preprocessing the SFT data, we can obtain the SFT code by making simple modifications to the pre-training code.
+ The key difference from pre-training is that SFT does not add noise to the prompt.
+ ```python
+ input_ids, prompt_lengths = batch["input_ids"], batch["prompt_lengths"]
+
+ noisy_batch, _, p_mask = forward_process(input_ids)
+
+ # Do not add noise to the prompt
+ token_positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).expand(noisy_batch.size(0), noisy_batch.size(1))
+ prompt_mask = (token_positions < prompt_lengths.unsqueeze(1))
+ noisy_batch[prompt_mask] = input_ids[prompt_mask]
+
+ # Calculate the answer length (including the padded <EOS> tokens)
+ prompt_mask = prompt_mask.to(torch.int64)
+ answer_lengths = torch.sum((1 - prompt_mask), dim=-1, keepdim=True)
+ answer_lengths = answer_lengths.repeat(1, noisy_batch.shape[1])
+
+ masked_indices = (noisy_batch == 126336)
+
+ logits = model(input_ids=noisy_batch).logits
+
+ token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
+ ce_loss = torch.sum(token_loss / answer_lengths[masked_indices]) / input_ids.shape[0]
+ ```
+
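+ The normalization differs from pre-training: each masked token's loss is divided by the answer length of its own
+ sequence, so every SFT example contributes equally regardless of how many padded <EOS> tokens it carries. A toy
+ check of that bookkeeping (illustrative values only):
+
+ ```python
+ import torch
+
+ # Two sequences of length 6 with prompt lengths 2 and 4.
+ prompt_lengths = torch.tensor([2, 4])
+ token_positions = torch.arange(6).expand(2, 6)
+ prompt_mask = token_positions < prompt_lengths.unsqueeze(1)
+ answer_lengths = (~prompt_mask).sum(dim=-1, keepdim=True).repeat(1, 6)
+ print(answer_lengths[:, 0])  # tensor([4, 2]): answer tokens per sequence, <EOS> padding included
+ ```
+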
+ ## Sampling
+ Overall, we categorize LLaDA's sampling process into three types: fixed-length, semi-autoregressive-origin, and semi-autoregressive-padding.
+ **It is worth noting that the semi-autoregressive-origin method was not mentioned in our paper, nor did we provide the corresponding code**.
+ However, we include it here because we believe that sharing both our failures and insights from the exploration process is valuable.
+ These three sampling methods are illustrated in the figure below.
+
+ <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 50px;">
+   <img src="imgs/sample.png" style="width: 100%;" />
+ </div>
+
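+ In terms of the `generate` function in `generate.py`, the fixed-length and semi-autoregressive-padding variants
+ differ only in the `block_length` argument (a usage sketch; the argument values are illustrative):
+
+ ```python
+ # Fixed-length: the whole answer is a single block.
+ out = generate(model, prompt, steps=128, gen_length=128, block_length=128,
+                temperature=0., cfg_scale=0., remasking='low_confidence')
+
+ # Semi-autoregressive-padding: the answer is generated block by block, from left to right.
+ out = generate(model, prompt, steps=128, gen_length=128, block_length=32,
+                temperature=0., cfg_scale=0., remasking='low_confidence')
+ ```
+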
+ For each step in the above three sampling processes, as detailed in Section 2.4 of our paper, the mask predictor
+ first predicts all masked tokens simultaneously. Then, a certain proportion of these predictions are remasked.
+ To determine which predicted tokens should be remasked, we can adopt two strategies: *random remasking* or
+ *low-confidence remasking*. Notably, both remasking strategies can be applied to all three sampling processes
+ mentioned above.
+
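+ Stripped down to a single step (a simplified sketch of the logic in `generate.py`, not the full sampler), the two
+ strategies differ only in how the confidence used for remasking is defined:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def reveal_one_step(logits, x, mask_id, k, remasking='low_confidence'):
+     # Reveal the k most trusted predictions among the currently masked positions of x.
+     mask_index = (x == mask_id)
+     x0 = logits.argmax(dim=-1)  # predict every position simultaneously
+     if remasking == 'low_confidence':
+         p = F.softmax(logits.to(torch.float64), dim=-1)
+         conf = p.gather(-1, x0.unsqueeze(-1)).squeeze(-1)  # probability of the predicted token
+     else:  # 'random'
+         conf = torch.rand_like(x0, dtype=torch.float64)
+     conf = conf.masked_fill(~mask_index, float('-inf'))  # never re-select already revealed tokens
+     keep = torch.topk(conf, k=k, dim=-1).indices          # positions to keep this step
+     return x.scatter(-1, keep, x0.gather(-1, keep))       # all other masked positions stay masked
+ ```
+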
+ For the LLaDA-Base model, we apply low-confidence remasking to the three sampling processes mentioned above.
+ We find that fixed-length and semi-autoregressive-padding achieve similar results, whereas semi-autoregressive-origin
+ performs slightly worse.
+
+ For the LLaDA-Instruct model, the situation is slightly more complex.
+
+ First, if the semi-autoregressive-origin method is used, the Instruct model performs poorly. This is because,
+ during SFT, each sequence is a complete sentence (whereas in pre-training, many sequences are truncated sentences).
+ As a result, during sampling, given a generation length, regardless of whether it is long or short, the Instruct
+ model tends to generate a complete sentence. Unlike the Base model, it does not encounter cases where a sentence
+ is only partially generated and needs to be continued.
+
+ When performing fixed-length sampling with a large answer length (e.g., greater than 512),
+ we find that low-confidence remasking results in an unusually high proportion of `<EOS>` tokens in
+ the generated sentences, which severely impacts the model's performance. In contrast, this
+ issue does not arise when random remasking is used.
+
+ Furthermore, since low-confidence remasking achieved better results with the Base model, we also hoped that it could be applied to
+ the Instruct model. We found that combining low-confidence remasking with semi-autoregressive-padding effectively mitigates
+ the issue of generating an excessively high proportion of `<EOS>` tokens. Moreover, this combination achieves
+ slightly better results than random remasking with fixed-length sampling.
+
+ You can find more details about the sampling method in our paper.
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 NieShenRuc
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
chat.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+
+ from generate import generate
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def chat():
+     device = 'cuda'
+     model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
+     tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
+
+     gen_length = 128
+     steps = 128
+     print('*' * 66)
+     print(f'** Answer Length: {gen_length} | Sampling Steps: {steps} **')
+     print('*' * 66)
+
+     conversation_num = 0
+     while True:
+         user_input = input("Enter your question: ")
+
+         m = [{"role": "user", "content": user_input}]
+         user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
+         input_ids = tokenizer(user_input)['input_ids']
+         input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
+
+         if conversation_num == 0:
+             prompt = input_ids
+         else:
+             prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)
+
+         out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
+
+         answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
+         print(f"Bot's reply: {answer}")
+
+         # remove the <EOS> tokens (id 126081) before the next turn is appended
+         prompt = out[out != 126081].unsqueeze(0)
+         conversation_num += 1
+         print('-----------------------------------------------------------------------')
+
+
+ if __name__ == "__main__":
+     chat()
generate.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def add_gumbel_noise(logits, temperature):
+     '''
+     The Gumbel max is a method for sampling categorical distributions.
+     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
+     Thus, we use float64.
+     '''
+     logits = logits.to(torch.float64)
+     noise = torch.rand_like(logits, dtype=torch.float64)
+     gumbel_noise = (- torch.log(noise)) ** temperature
+     return logits.exp() / gumbel_noise
+
+
+ def get_num_transfer_tokens(mask_index, steps):
+     '''
+     In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
+     Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
+     the expected number of tokens transitioned at each step should be consistent.
+
+     This function is designed to precompute the number of tokens that need to be transitioned at each step.
+     '''
+     mask_num = mask_index.sum(dim=1, keepdim=True)
+
+     base = mask_num // steps
+     remainder = mask_num % steps
+
+     num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+     for i in range(mask_num.size(0)):
+         num_transfer_tokens[i, :remainder[i]] += 1
+
+     return num_transfer_tokens
+
+
+ @ torch.no_grad()
+ def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
+              cfg_scale=0., remasking='low_confidence', mask_id=126336):
+     '''
+     Args:
+         model: Mask predictor.
+         prompt: A tensor of shape (1, l).
+         steps: Sampling steps, less than or equal to gen_length.
+         gen_length: Generated answer length.
+         block_length: Block length, less than or equal to gen_length. If less than gen_length, semi-autoregressive remasking is used.
+         temperature: Categorical distribution sampling temperature.
+         cfg_scale: Unsupervised classifier-free guidance scale.
+         remasking: Remasking strategy. 'low_confidence' or 'random'.
+         mask_id: The token id of [MASK] is 126336.
+     '''
+     x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
+     x[:, :prompt.shape[1]] = prompt.clone()
+
+     prompt_index = (x != mask_id)
+
+     assert gen_length % block_length == 0
+     num_blocks = gen_length // block_length
+
+     assert steps % num_blocks == 0
+     steps = steps // num_blocks
+
+     for num_block in range(num_blocks):
+         block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
+         num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
+         for i in range(steps):
+             mask_index = (x == mask_id)
+             if cfg_scale > 0.:
+                 un_x = x.clone()
+                 un_x[prompt_index] = mask_id
+                 x_ = torch.cat([x, un_x], dim=0)
+                 logits = model(x_).logits
+                 logits, un_logits = torch.chunk(logits, 2, dim=0)
+                 logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+             else:
+                 logits = model(x).logits
+
+             logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+             x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
+
+             if remasking == 'low_confidence':
+                 p = F.softmax(logits.to(torch.float64), dim=-1)
+                 x0_p = torch.squeeze(
+                     torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
+             elif remasking == 'random':
+                 x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+             else:
+                 raise NotImplementedError(remasking)
+
+             x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
+
+             x0 = torch.where(mask_index, x0, x)
+             confidence = torch.where(mask_index, x0_p, -np.inf)
+
+             transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+             for j in range(confidence.shape[0]):
+                 _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
+                 transfer_index[j, select_index] = True
+             x[transfer_index] = x0[transfer_index]
+
+     return x
+
+
+ def main():
+     device = 'cuda'
+
+     model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
+     tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)
+
+     prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"
+
+     # Add special tokens for the Instruct model. The Base model does not require the following two lines.
+     m = [{"role": "user", "content": prompt}, ]
+     prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
+
+     input_ids = tokenizer(prompt)['input_ids']
+     input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
+
+     out = generate(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
+     print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
+
+
+ if __name__ == '__main__':
+     main()
get_log_likelihood.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def forward_process(batch, prompt_index, mask_id):
+     b, l = batch.shape
+
+     target_len = (l - prompt_index.sum()).item()
+     k = torch.randint(1, target_len + 1, (), device=batch.device)
+
+     x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
+     x = ((x - 1) % target_len) + 1
+     assert x.min() >= 1 and x.max() <= target_len
+
+     indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
+     is_mask = indices < x.unsqueeze(1)
+     for i in range(b):
+         is_mask[i] = is_mask[i][torch.randperm(target_len)]
+
+     is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)
+     noisy_batch = torch.where(is_mask, mask_id, batch)
+
+     # Return the masked batch and the mask ratio
+     return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
+
+
+ def get_logits(model, batch, prompt_index, cfg_scale, mask_id):
+     if cfg_scale > 0.:
+         assert len(prompt_index) == batch.shape[1]
+         prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
+         un_batch = batch.clone()
+         un_batch[prompt_index] = mask_id
+         batch = torch.cat([batch, un_batch])
+
+     input = batch
+     logits = model(input).logits
+
+     if cfg_scale > 0.:
+         logits, un_logits = torch.chunk(logits, 2, dim=0)
+         logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+     return logits
+
+
+ @ torch.no_grad()
+ def get_log_likelihood(model, prompt, answer, mc_num=128, batch_size=16, cfg_scale=0., mask_id=126336):
+     '''
+     Args:
+         model: Mask predictor.
+         prompt: A tensor of shape (l1).
+         answer: A tensor of shape (l2).
+         mc_num: Monte Carlo estimation times.
+                 As detailed in Appendix B.5, since MMLU, CMMLU, and C-EVAL only require the likelihood of a single token, a
+                 single Monte Carlo estimate is sufficient for these benchmarks. For all other benchmarks, we find that 128
+                 Monte Carlo samples are adequate to produce stable results.
+         batch_size: Mini batch size.
+         cfg_scale: Unsupervised classifier-free guidance scale.
+         mask_id: The token id of [MASK] is 126336.
+     '''
+     seq = torch.concatenate([prompt, answer])[None, :]
+     seq = seq.repeat((batch_size, 1)).to(model.device)
+     prompt_index = torch.arange(seq.shape[1], device=model.device) < len(prompt)
+
+     loss_ = []
+     for _ in range(mc_num // batch_size):
+         perturbed_seq, p_mask = forward_process(seq, prompt_index, mask_id)
+         mask_index = perturbed_seq == mask_id
+
+         logits = get_logits(model, perturbed_seq, prompt_index, cfg_scale, mask_id)
+
+         loss = F.cross_entropy(logits[mask_index], seq[mask_index], reduction='none') / p_mask[mask_index]
+         loss = loss.sum() / batch_size
+
+         loss_.append(loss.item())
+
+     return - sum(loss_) / len(loss_)
+
+
+ def main():
+     device = 'cuda'
+
+     model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Base', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
+     tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Base', trust_remote_code=True)
+
+     # This prompt and answer pair is from the HellaSwag dataset.
+     prompt = 'Roof shingle removal: A man is sitting on a roof. He'
+     answer = ' is using wrap to wrap a pair of skis.'
+
+     prompt = torch.tensor(tokenizer(prompt)['input_ids']).to(device)
+     answer = torch.tensor(tokenizer(answer)['input_ids']).to(device)
+     print(get_log_likelihood(model, prompt, answer, mc_num=128))
+
+
+ if __name__ == '__main__':
+     main()
imgs/LLaDA_vs_LLaMA.svg ADDED
imgs/LLaDA_vs_LLaMA_chat.svg ADDED
imgs/diff_remask.gif ADDED

Git LFS Details

  • SHA256: 0c97f2e338df118984e08456964abc5d0da2119e867066429c218c2a26f7dd3a
  • Pointer size: 132 Bytes
  • Size of remote file: 9.13 MB
imgs/sample.png ADDED

Git LFS Details

  • SHA256: 4e35901be05e2cf4bbde8fc79c32286cd127600b3b62763f240aae989dda12ca
  • Pointer size: 131 Bytes
  • Size of remote file: 298 kB
imgs/transformer1.png ADDED
imgs/transformer2.png ADDED

Git LFS Details

  • SHA256: 8b00226e8c9a653c8efdd0a858d1baf53f7a71817853e3c8d2e60b82b21e5b5c
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB