MilindChawre committed on
Commit
61d0253
·
1 Parent(s): 1b4fddf

Adding code for transformer model

Files changed (7)
  1. README.md +82 -1
  2. app.py +63 -0
  3. checkpoint.pt +3 -0
  4. input.txt +0 -0
  5. trained_model_quantized.pt +3 -0
  6. training.log +253 -0
  7. transformer.py +353 -0
README.md CHANGED
@@ -10,4 +10,85 @@ pinned: false
10
  short_description: Transformer trained on Shakespeare play dataset
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+ # Transformer Model Training
14
+
15
+ This project implements a transformer-based (GPT-style) language model using PyTorch. The model is trained on a Shakespeare play corpus (`input.txt`) and can be retrained or fine-tuned on other plain-text corpora.
16
+
17
+ ## Table of Contents
18
+ - [Features](#features)
19
+ - [Requirements](#requirements)
20
+ - [Installation](#installation)
21
+ - [Usage](#usage)
22
+ - [Training](#training)
23
+ - [Actual Training](#actual-training)
24
+ - [Checkpointing](#checkpointing)
25
+ - [Model Compression](#model-compression)
26
+ - [License](#license)
27
+ - [Acknowledgments](#acknowledgments)
28
+
29
+ ## Features
30
+ - Transformer architecture with causal self-attention and feedforward layers (default configuration is shown after this list).
31
+ - Efficient data loading and batching.
32
+ - Checkpointing to resume training.
33
+ - Support for multiple devices (CPU, CUDA, MPS).
34
+ - Model compression for reduced file size.
35
+ - Streamlit application for text generation using the trained model.
36
+
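+ The default architecture is GPT-2-small sized (about 124M parameters). The configuration below is reproduced from `GPTConfig` in `transformer.py` as a quick reference:
+
+ ```python
+ from dataclasses import dataclass
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024   # max sequence length
+     vocab_size: int = 50257  # GPT-2 BPE vocabulary size
+     n_layer: int = 12        # transformer blocks
+     n_head: int = 12         # attention heads per block
+     n_embd: int = 768        # embedding dimension
+ ```
+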
37
+ ## Requirements
38
+ - Python 3.6 or higher
39
+ - PyTorch 1.7 or higher
40
+ - tqdm
41
+ - tiktoken
42
+ - streamlit
43
+ - transformers
44
+
45
+ ## Installation
46
+ 1. Clone the repository:
47
+ ```bash
48
+ git clone https://github.com/yourusername/transformer-model-training.git
49
+ cd transformer-model-training
50
+ ```
51
+
52
+ 2. Install the required packages (a suggested `requirements.txt` is sketched after this step):
53
+ ```bash
54
+ pip install -r requirements.txt
55
+ ```
56
+
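+ Note: this commit does not ship a `requirements.txt`; a minimal one consistent with the Requirements list above would look like the sketch below (add version pins as needed).
+
+ ```text
+ torch
+ tqdm
+ tiktoken
+ streamlit
+ transformers
+ ```
+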
57
+ ## Usage
58
+ 1. Prepare your text data in a file named `input.txt`. The training script tokenizes this file with the GPT-2 BPE tokenizer (`tiktoken`) to build the training tokens (see the snippet after this list).
59
+
60
+ 2. Run the training script:
61
+ ```bash
62
+ python transformer.py
63
+ ```
64
+
65
+ 3. The model will save checkpoints after each epoch in `checkpoint.pt` and the final model in `trained_model_quantized.pt`.
66
+
67
+ 4. To generate text using the trained model, run the Streamlit application:
68
+ ```bash
69
+ streamlit run app.py
70
+ ```
71
+
72
+ 5. Enter your text and specify the length of additional text to generate in the Streamlit interface.
73
+
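+ For reference, the self-contained snippet below condenses how `DataLoaderLite` in `transformer.py` turns `input.txt` into training tokens:
+
+ ```python
+ import tiktoken
+ import torch
+
+ enc = tiktoken.get_encoding('gpt2')
+ with open('input.txt', 'r') as f:
+     tokens = torch.tensor(enc.encode(f.read()))
+ print(f'loaded {len(tokens)} tokens')
+ ```
+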
74
+ ## Training
75
+ - The model is trained with a batch size of 4, a sequence length of 1024 tokens, and the AdamW optimizer with a learning rate of 3e-4.
76
+ - The training loop performs the forward pass, loss calculation, backpropagation, and optimizer step for each batch; a condensed sketch is shown after this list.
77
+ - The loss is monitored, and checkpoints are saved to allow for resuming training.
78
+ - The training process is logged in `training.log`, which contains detailed statistics for each epoch, including loss values and checkpointing information.
79
+
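+ The core of the loop in `transformer.py` reduces to the sketch below (it assumes the `model`, `train_loader`, and `device` objects defined in that script; checkpointing, logging, and the progress bar are omitted):
+
+ ```python
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+ total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
+ for step in range(total_steps):
+     x, y = train_loader.next_batch()
+     x, y = x.to(device), y.to(device)
+     optimizer.zero_grad()
+     logits, loss = model(x, y)  # the model returns (logits, loss) when targets are given
+     loss.backward()
+     optimizer.step()
+ ```
+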
80
+ ## Actual Training
81
+ The model was trained for a total of **78 epochs**. The final loss achieved at the end of training was approximately **0.904894**. The training log file contains detailed statistics for each epoch, including loss values and checkpointing information. You can find the log file named `training.log` in the project directory.
82
+
83
+ ## Checkpointing
84
+ - The model state and current epoch are saved in a single checkpoint file (`checkpoint.pt`).
85
+ - To resume training from the last checkpoint, simply run the training script again; it automatically loads the latest checkpoint and continues from the stored epoch (a condensed sketch is shown below).
86
+
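+ The sketch below condenses the save/resume logic from `transformer.py` (it assumes the `model`, `device`, and `epoch` objects defined there; `map_location` is added here so the file also loads on CPU-only machines):
+
+ ```python
+ # Saving (after each epoch)
+ torch.save({'epoch': epoch + 1, 'model_state_dict': model.state_dict()}, 'checkpoint.pt')
+
+ # Resuming
+ if os.path.exists('checkpoint.pt'):
+     ckpt = torch.load('checkpoint.pt', map_location=device)
+     model.load_state_dict(ckpt['model_state_dict'])
+     start_epoch = ckpt['epoch']
+ ```
+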
87
+ ## Model Compression
88
+ - The final model is saved with dynamic int8 quantization of its linear layers and zip-file serialization to reduce file size. The model file is saved as `trained_model_quantized.pt` (see the sketch below).
89
+
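+ Concretely, `transformer.py` applies dynamic int8 quantization to the `nn.Linear` layers before saving, roughly as sketched below (assuming the trained `model` and the `torch.nn as nn` import from that script):
+
+ ```python
+ quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
+ torch.save(quantized_model.state_dict(), 'trained_model_quantized.pt',
+            _use_new_zipfile_serialization=True)
+ ```
+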
90
+ ## License
91
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
92
+
93
+ ## Acknowledgments
94
+ - This project is inspired by the original GPT architecture and various resources available in the NLP community.
app.py ADDED
@@ -0,0 +1,63 @@
1
+ import streamlit as st
2
+ import torch
3
+ import tiktoken
4
+ from transformer import GPT, GPTConfig # Ensure you import your model class
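+ # NOTE: transformer.py currently runs its training loop at module level, so this import
+ # will trigger training unless that script is guarded with `if __name__ == '__main__':`.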
5
+
6
+ # Load the trained model
7
+ @st.cache_resource
8
+ def load_model():
9
+ config = GPTConfig()
10
+ model = GPT(config)
+ # trained_model_quantized.pt stores the state dict of a dynamically quantized model
+ # (see transformer.py), so quantize the freshly built model the same way before loading.
+ model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
+ try:
+ model.load_state_dict(torch.load('trained_model_quantized.pt', map_location='cpu'))
13
+ model.eval() # Set the model to evaluation mode
14
+ st.success("Model loaded successfully!")
15
+ except Exception as e:
16
+ st.error(f"Error loading model: {e}")
17
+ return model
18
+
19
+ # Load the tokenizer
20
+ def load_tokenizer():
21
+ return tiktoken.get_encoding('gpt2')
22
+
23
+ # Generate text function
24
+ def generate_text(model, tokenizer, input_text, length, num_sequences):
25
+ # Encode the input text
26
+ input_ids = tokenizer.encode(input_text)
27
+ input_tensor = torch.tensor(input_ids).unsqueeze(0) # Add batch dimension
28
+
29
+ generated_sequences = []
+ for _ in range(num_sequences):
+ # Start each sequence from the original prompt rather than the previous sequence's output
+ seq_tensor = input_tensor.clone()
+ # Generate additional tokens
+ with torch.no_grad():
+ for _ in range(length):
+ logits = model(seq_tensor)[0] # Forward pass; the model returns (logits, loss)
+ next_token_logits = logits[:, -1, :] # Get the last token's logits
+ next_token_probs = torch.softmax(next_token_logits, dim=-1)
+ next_token = torch.multinomial(next_token_probs, num_samples=1) # Sample from the distribution; shape (1, 1)
+ seq_tensor = torch.cat((seq_tensor, next_token), dim=1) # Append the new token (next_token is already 2-D)
+
+ # Decode the generated tokens
+ generated_sequences.append(tokenizer.decode(seq_tensor[0].tolist()))
42
+
43
+ return generated_sequences
44
+
45
+ # Streamlit app layout
46
+ st.title("GPT Text Generator")
47
+ st.write("Enter your text and specify the length of additional text to generate.")
48
+
49
+ input_text = st.text_area("Input Text", "Once upon a time", max_chars=512) # Limit to 512 characters
50
+ length = st.slider("Predict Additional Text of Length", 1, 50, 10)
51
+ num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)
52
+
53
+ if st.button("Generate"):
54
+ model = load_model()
55
+ tokenizer = load_tokenizer()
56
+ st.write("Generating text...")
57
+ generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
58
+ st.write("Text generation complete.")
59
+
60
+ st.write("Generated Texts:")
61
+ for i, text in enumerate(generated_texts):
62
+ st.subheader(f"Sequence {i + 1}")
63
+ st.write(text)
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9f02348249d0b8457a59bc3331ac807b879f7d32b35886d60c8ab15d18fa6bd
3
+ size 548146590
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
trained_model_quantized.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fec31d4b4fa71331f80d77d4066bb10a71d6118c0c757e341a143b630be08a6
3
+ size 331982620
training.log ADDED
@@ -0,0 +1,253 @@
1
+ using device: cuda
2
+ loaded 338025 tokens
3
+ 1 epoch = 82 batches
4
+ Number of model parameters: 124439808
5
+ Epoch 1/70: 100% 82/82 [01:38<00:00, 1.20s/it]
6
+ Epoch 1/70, Loss: 6.169636
7
+ Checkpoint saved to checkpoint.pt
8
+ Epoch 2/70: 100% 82/82 [01:42<00:00, 1.25s/it]
9
+ Epoch 2/70, Loss: 5.720689
10
+ Checkpoint saved to checkpoint.pt
11
+ Epoch 3/70: 100% 82/82 [01:42<00:00, 1.25s/it]
12
+ Epoch 3/70, Loss: 5.390238
13
+ Checkpoint saved to checkpoint.pt
14
+ Epoch 4/70: 100% 82/82 [01:42<00:00, 1.25s/it]
15
+ Epoch 4/70, Loss: 5.164030
16
+ Checkpoint saved to checkpoint.pt
17
+ Epoch 5/70: 100% 82/82 [01:42<00:00, 1.25s/it]
18
+ Epoch 5/70, Loss: 5.051653
19
+ Checkpoint saved to checkpoint.pt
20
+ Epoch 6/70: 100% 82/82 [01:42<00:00, 1.25s/it]
21
+ Epoch 6/70, Loss: 4.947546
22
+ Checkpoint saved to checkpoint.pt
23
+ Epoch 7/70: 100% 82/82 [01:42<00:00, 1.25s/it]
24
+ Epoch 7/70, Loss: 4.893464
25
+ Checkpoint saved to checkpoint.pt
26
+ Epoch 8/70: 100% 82/82 [01:42<00:00, 1.25s/it]
27
+ Epoch 8/70, Loss: 4.785249
28
+ Checkpoint saved to checkpoint.pt
29
+ Epoch 9/70: 100% 82/82 [01:42<00:00, 1.25s/it]
30
+ Epoch 9/70, Loss: 4.773346
31
+ Checkpoint saved to checkpoint.pt
32
+ Epoch 10/70: 100% 82/82 [01:42<00:00, 1.25s/it]
33
+ Epoch 10/70, Loss: 4.669469
34
+ Checkpoint saved to checkpoint.pt
35
+ Epoch 11/70: 100% 82/82 [01:42<00:00, 1.25s/it]
36
+ Epoch 11/70, Loss: 4.617172
37
+ Checkpoint saved to checkpoint.pt
38
+ Epoch 12/70: 100% 82/82 [01:42<00:00, 1.25s/it]
39
+ Epoch 12/70, Loss: 4.594382
40
+ Checkpoint saved to checkpoint.pt
41
+ Epoch 13/70: 100% 82/82 [01:42<00:00, 1.25s/it]
42
+ Epoch 13/70, Loss: 4.554847
43
+ Checkpoint saved to checkpoint.pt
44
+ Epoch 14/70: 100% 82/82 [01:42<00:00, 1.25s/it]
45
+ Epoch 14/70, Loss: 4.506260
46
+ Checkpoint saved to checkpoint.pt
47
+ Epoch 15/70: 100% 82/82 [01:42<00:00, 1.25s/it]
48
+ Epoch 15/70, Loss: 4.416086
49
+ Checkpoint saved to checkpoint.pt
50
+ Epoch 16/70: 100% 82/82 [01:42<00:00, 1.25s/it]
51
+ Epoch 16/70, Loss: 4.370214
52
+ Checkpoint saved to checkpoint.pt
53
+ Epoch 17/70: 100% 82/82 [01:42<00:00, 1.25s/it]
54
+ Epoch 17/70, Loss: 4.278370
55
+ Checkpoint saved to checkpoint.pt
56
+ Epoch 18/70: 100% 82/82 [01:42<00:00, 1.25s/it]
57
+ Epoch 18/70, Loss: 4.304771
58
+ Checkpoint saved to checkpoint.pt
59
+ Epoch 19/70: 100% 82/82 [01:42<00:00, 1.25s/it]
60
+ Epoch 19/70, Loss: 4.209321
61
+ Checkpoint saved to checkpoint.pt
62
+ Epoch 20/70: 100% 82/82 [01:42<00:00, 1.25s/it]
63
+ Epoch 20/70, Loss: 4.175936
64
+ Checkpoint saved to checkpoint.pt
65
+ Epoch 21/70: 100% 82/82 [01:42<00:00, 1.25s/it]
66
+ Epoch 21/70, Loss: 4.071361
67
+ Checkpoint saved to checkpoint.pt
68
+ Epoch 22/70: 100% 82/82 [01:42<00:00, 1.25s/it]
69
+ Epoch 22/70, Loss: 4.071530
70
+ Checkpoint saved to checkpoint.pt
71
+ Epoch 23/70: 100% 82/82 [01:42<00:00, 1.25s/it]
72
+ Epoch 23/70, Loss: 4.053171
73
+ Checkpoint saved to checkpoint.pt
74
+ Epoch 24/70: 100% 82/82 [01:42<00:00, 1.25s/it]
75
+ Epoch 24/70, Loss: 3.923664
76
+ Checkpoint saved to checkpoint.pt
77
+ Epoch 25/70: 100% 82/82 [01:42<00:00, 1.25s/it]
78
+ Epoch 25/70, Loss: 3.827437
79
+ Checkpoint saved to checkpoint.pt
80
+ Epoch 26/70: 100% 82/82 [01:42<00:00, 1.25s/it]
81
+ Epoch 26/70, Loss: 3.767063
82
+ Checkpoint saved to checkpoint.pt
83
+ Epoch 27/70: 100% 82/82 [01:42<00:00, 1.25s/it]
84
+ Epoch 27/70, Loss: 3.711340
85
+ Checkpoint saved to checkpoint.pt
86
+ Epoch 28/70: 100% 82/82 [01:42<00:00, 1.25s/it]
87
+ Epoch 28/70, Loss: 3.622302
88
+ Checkpoint saved to checkpoint.pt
89
+ Epoch 29/70: 100% 82/82 [01:42<00:00, 1.25s/it]
90
+ Epoch 29/70, Loss: 3.583114
91
+ Checkpoint saved to checkpoint.pt
92
+ Epoch 30/70: 100% 82/82 [01:42<00:00, 1.25s/it]
93
+ Epoch 30/70, Loss: 3.517573
94
+ Checkpoint saved to checkpoint.pt
95
+ Epoch 31/70: 100% 82/82 [01:42<00:00, 1.25s/it]
96
+ Epoch 31/70, Loss: 3.445611
97
+ Checkpoint saved to checkpoint.pt
98
+ Epoch 32/70: 100% 82/82 [01:42<00:00, 1.25s/it]
99
+ Epoch 32/70, Loss: 3.410571
100
+ Checkpoint saved to checkpoint.pt
101
+ Epoch 33/70: 100% 82/82 [01:42<00:00, 1.25s/it]
102
+ Epoch 33/70, Loss: 3.282128
103
+ Checkpoint saved to checkpoint.pt
104
+ Epoch 34/70: 100% 82/82 [01:42<00:00, 1.25s/it]
105
+ Epoch 34/70, Loss: 3.307455
106
+ Checkpoint saved to checkpoint.pt
107
+ Epoch 35/70: 100% 82/82 [01:42<00:00, 1.25s/it]
108
+ Epoch 35/70, Loss: 3.126928
109
+ Checkpoint saved to checkpoint.pt
110
+ Epoch 36/70: 100% 82/82 [01:42<00:00, 1.25s/it]
111
+ Epoch 36/70, Loss: 3.057953
112
+ Checkpoint saved to checkpoint.pt
113
+ Epoch 37/70: 100% 82/82 [01:42<00:00, 1.24s/it]
114
+ Epoch 37/70, Loss: 3.082567
115
+ Checkpoint saved to checkpoint.pt
116
+ Epoch 38/70: 100% 82/82 [01:42<00:00, 1.25s/it]
117
+ Epoch 38/70, Loss: 3.066772
118
+ Checkpoint saved to checkpoint.pt
119
+ Epoch 39/70: 100% 82/82 [01:42<00:00, 1.25s/it]
120
+ Epoch 39/70, Loss: 2.943954
121
+ Checkpoint saved to checkpoint.pt
122
+ Epoch 40/70: 100% 82/82 [01:42<00:00, 1.25s/it]
123
+ Epoch 40/70, Loss: 2.874876
124
+ Checkpoint saved to checkpoint.pt
125
+ Epoch 41/70: 100% 82/82 [01:42<00:00, 1.25s/it]
126
+ Epoch 41/70, Loss: 2.781206
127
+ Checkpoint saved to checkpoint.pt
128
+ Epoch 42/70: 100% 82/82 [01:42<00:00, 1.24s/it]
129
+ Epoch 42/70, Loss: 2.729423
130
+ Checkpoint saved to checkpoint.pt
131
+ Epoch 43/70: 100% 82/82 [01:42<00:00, 1.25s/it]
132
+ Epoch 43/70, Loss: 2.656427
133
+ Checkpoint saved to checkpoint.pt
134
+ Epoch 44/70: 100% 82/82 [01:42<00:00, 1.25s/it]
135
+ Epoch 44/70, Loss: 2.641519
136
+ Checkpoint saved to checkpoint.pt
137
+ Epoch 45/70: 100% 82/82 [01:42<00:00, 1.24s/it]
138
+ Epoch 45/70, Loss: 2.593380
139
+ Checkpoint saved to checkpoint.pt
140
+ Epoch 46/70: 100% 82/82 [01:42<00:00, 1.25s/it]
141
+ Epoch 46/70, Loss: 2.504074
142
+ Checkpoint saved to checkpoint.pt
143
+ Epoch 47/70: 100% 82/82 [01:41<00:00, 1.24s/it]
144
+ Epoch 47/70, Loss: 2.510426
145
+ Checkpoint saved to checkpoint.pt
146
+ Epoch 48/70: 100% 82/82 [01:42<00:00, 1.24s/it]
147
+ Epoch 48/70, Loss: 2.465840
148
+ Checkpoint saved to checkpoint.pt
149
+ Epoch 49/70: 100% 82/82 [01:41<00:00, 1.24s/it]
150
+ Epoch 49/70, Loss: 2.339541
151
+ Checkpoint saved to checkpoint.pt
152
+ Epoch 50/70: 100% 82/82 [01:42<00:00, 1.25s/it]
153
+ Epoch 50/70, Loss: 2.288784
154
+ Checkpoint saved to checkpoint.pt
155
+ Epoch 51/70: 100% 82/82 [01:42<00:00, 1.24s/it]
156
+ Epoch 51/70, Loss: 2.272939
157
+ Checkpoint saved to checkpoint.pt
158
+ Epoch 52/70: 100% 82/82 [01:42<00:00, 1.25s/it]
159
+ Epoch 52/70, Loss: 2.150897
160
+ Checkpoint saved to checkpoint.pt
161
+ Epoch 53/70: 100% 82/82 [01:42<00:00, 1.25s/it]
162
+ Epoch 53/70, Loss: 2.096288
163
+ Checkpoint saved to checkpoint.pt
164
+ Epoch 54/70: 100% 82/82 [01:42<00:00, 1.24s/it]
165
+ Epoch 54/70, Loss: 2.057416
166
+ Checkpoint saved to checkpoint.pt
167
+ Epoch 55/70: 100% 82/82 [01:42<00:00, 1.24s/it]
168
+ Epoch 55/70, Loss: 1.962530
169
+ Checkpoint saved to checkpoint.pt
170
+ Epoch 56/70: 100% 82/82 [01:41<00:00, 1.24s/it]
171
+ Epoch 56/70, Loss: 1.930993
172
+ Checkpoint saved to checkpoint.pt
173
+ Epoch 57/70: 100% 82/82 [01:41<00:00, 1.24s/it]
174
+ Epoch 57/70, Loss: 1.854412
175
+ Checkpoint saved to checkpoint.pt
176
+ Epoch 58/70: 100% 82/82 [01:42<00:00, 1.24s/it]
177
+ Epoch 58/70, Loss: 1.818957
178
+ Checkpoint saved to checkpoint.pt
179
+ Epoch 59/70: 100% 82/82 [01:42<00:00, 1.24s/it]
180
+ Epoch 59/70, Loss: 1.764919
181
+ Checkpoint saved to checkpoint.pt
182
+ Epoch 60/70: 100% 82/82 [01:42<00:00, 1.25s/it]
183
+ Epoch 60/70, Loss: 1.741000
184
+ Checkpoint saved to checkpoint.pt
185
+ Epoch 61/70: 100% 82/82 [01:42<00:00, 1.24s/it]
186
+ Epoch 61/70, Loss: 1.694582
187
+ Checkpoint saved to checkpoint.pt
188
+ Epoch 62/70: 100% 82/82 [01:42<00:00, 1.24s/it]
189
+ Epoch 62/70, Loss: 1.751990
190
+ Checkpoint saved to checkpoint.pt
191
+ Epoch 63/70: 100% 82/82 [01:42<00:00, 1.25s/it]
192
+ Epoch 63/70, Loss: 1.664971
193
+ Checkpoint saved to checkpoint.pt
194
+ Epoch 64/70: 100% 82/82 [01:41<00:00, 1.24s/it]
195
+ Epoch 64/70, Loss: 1.557876
196
+ Checkpoint saved to checkpoint.pt
197
+ Epoch 65/70: 100% 82/82 [01:41<00:00, 1.24s/it]
198
+ Epoch 65/70, Loss: 1.543549
199
+ Checkpoint saved to checkpoint.pt
200
+ Epoch 66/70: 100% 82/82 [01:42<00:00, 1.25s/it]
201
+ Epoch 66/70, Loss: 1.436256
202
+ Checkpoint saved to checkpoint.pt
203
+ Epoch 67/70: 100% 82/82 [01:42<00:00, 1.24s/it]
204
+ Epoch 67/70, Loss: 1.352293
205
+ Checkpoint saved to checkpoint.pt
206
+ Epoch 68/70: 100% 82/82 [01:42<00:00, 1.24s/it]
207
+ Epoch 68/70, Loss: 1.361581
208
+ Checkpoint saved to checkpoint.pt
209
+ Epoch 69/70: 100% 82/82 [01:42<00:00, 1.24s/it]
210
+ Epoch 69/70, Loss: 1.308131
211
+ Checkpoint saved to checkpoint.pt
212
+ Epoch 70/70: 100% 82/82 [01:42<00:00, 1.24s/it]
213
+ Epoch 70/70, Loss: 1.287876
214
+ Checkpoint saved to checkpoint.pt
215
+ Total training time: 127 minutes and 37 seconds
216
+ Model saved to trained_model_quantized.pt with quantization and compression.
217
+ ==================================================
218
+ Increased epoch to 78 to reach loss < 0.99999
219
+ ==================================================
220
+ using device: cuda
221
+ loaded 338025 tokens
222
+ 1 epoch = 82 batches
223
+ Number of model parameters: 124439808
224
+ Loading checkpoint from checkpoint.pt
225
+ /content/erav3-s12-transformer-model/erav3-s12-transformer-model/transformer.py:262: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
226
+ checkpoint = torch.load(checkpoint_file)
227
+ Epoch 71/78: 100% 82/82 [01:36<00:00, 1.18s/it]
228
+ Epoch 71/78, Loss: 1.453567
229
+ Checkpoint saved to checkpoint.pt
230
+ Epoch 72/78: 100% 82/82 [01:42<00:00, 1.25s/it]
231
+ Epoch 72/78, Loss: 1.162141
232
+ Checkpoint saved to checkpoint.pt
233
+ Epoch 73/78: 100% 82/82 [01:42<00:00, 1.24s/it]
234
+ Epoch 73/78, Loss: 1.174683
235
+ Checkpoint saved to checkpoint.pt
236
+ Epoch 74/78: 100% 82/82 [01:42<00:00, 1.25s/it]
237
+ Epoch 74/78, Loss: 1.089287
238
+ Checkpoint saved to checkpoint.pt
239
+ Epoch 75/78: 100% 82/82 [01:42<00:00, 1.25s/it]
240
+ Epoch 75/78, Loss: 1.010704
241
+ Checkpoint saved to checkpoint.pt
242
+ Epoch 76/78: 100% 82/82 [01:42<00:00, 1.24s/it]
243
+ Epoch 76/78, Loss: 0.979691
244
+ Checkpoint saved to checkpoint.pt
245
+ Epoch 77/78: 100% 82/82 [01:41<00:00, 1.24s/it]
246
+ Epoch 77/78, Loss: 0.918769
247
+ Checkpoint saved to checkpoint.pt
248
+ Epoch 78/78: 100% 82/82 [01:41<00:00, 1.24s/it]
249
+ Epoch 78/78, Loss: 0.904894
250
+ Checkpoint saved to checkpoint.pt
251
+ Total training time: 14 minutes and 37 seconds
252
+ Model saved to trained_model_quantized.pt with quantization and compression.
253
+
transformer.py ADDED
@@ -0,0 +1,353 @@
1
+ # Solving for residual std scaling issue
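+ # Residual-path projections (c_proj) are flagged so _init_weights scales their std by
+ # (2 * n_layer) ** -0.5, keeping activations from growing with network depth.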
2
+ import os
3
+ import math
4
+ import time
5
+ from dataclasses import dataclass
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ from tqdm import tqdm # Import tqdm for progress bar
10
+ import torch.quantization # Import quantization module
11
+ import torch.nn.utils.prune as prune
12
+ import tiktoken
13
+
14
+
15
+ class CausalSelfAttention(nn.Module):
16
+
17
+ def __init__(self, config):
18
+ super().__init__()
19
+ assert config.n_embd % config.n_head == 0
20
+ # key, query, value projections for all heads, but in a batch
21
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
22
+ # output projection
23
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
24
+ self.c_proj.NANOGPT_SCALE_INIT = 1
25
+ # regularization
26
+ self.n_head = config.n_head
27
+ self.n_embd = config.n_embd
28
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
29
+
30
+ def forward(self, x):
31
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
32
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
33
+ # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
34
+ # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
35
+ qkv = self.c_attn(x)
36
+ q, k, v = qkv.split(self.n_embd, dim=2)
37
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
38
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
39
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40
+
41
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
42
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
43
+ att = F.softmax(att, dim=-1)
44
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
45
+
46
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
47
+ # output projection
48
+ y = self.c_proj(y)
49
+ return y
50
+
51
+
52
+ class MLP(nn.Module):
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
57
+ self.gelu = nn.GELU(approximate='tanh')
58
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
59
+ self.c_proj.NANOGPT_SCALE_INIT = 1
60
+
61
+ def forward(self, x):
62
+ x = self.c_fc(x)
63
+ x = self.gelu(x)
64
+ x = self.c_proj(x)
65
+ return x
66
+
67
+ class Block(nn.Module):
68
+
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.ln_1 = nn.LayerNorm(config.n_embd)
72
+ self.attn = CausalSelfAttention(config)
73
+ self.ln_2 = nn.LayerNorm(config.n_embd)
74
+ self.mlp = MLP(config)
75
+
76
+ def forward(self, x):
77
+ x = x + self.attn(self.ln_1(x))
78
+ x = x + self.mlp(self.ln_2(x))
79
+ return x
80
+
81
+
82
+ @dataclass
83
+ class GPTConfig:
84
+ block_size: int = 1024 # max sequence length
85
+ vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
86
+ n_layer: int = 12 # number of layers
87
+ n_head: int = 12 # number of heads
88
+ n_embd: int = 768 # embedding dimension
89
+
90
+
91
+ class GPT(nn.Module):
92
+
93
+ def __init__(self, config):
94
+ super().__init__()
95
+ self.config = config
96
+
97
+ self.transformer = nn.ModuleDict(dict(
98
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
99
+ wpe = nn.Embedding(config.block_size, config.n_embd),
100
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
101
+ ln_f = nn.LayerNorm(config.n_embd),
102
+ ))
103
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
104
+
105
+ # weight sharing
106
+ self.transformer.wte.weight = self.lm_head.weight
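+ # Tying the input embedding and output projection weights (as in GPT-2) saves roughly 38M parameters.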
107
+
108
+ # weight initialization
109
+ self.apply(self._init_weights)
110
+
111
+ def _init_weights(self, module):
112
+ if isinstance(module, nn.Linear):
113
+ std = 0.02
114
+ if hasattr(module, 'NANOGPT_SCALE_INIT'):
115
+ std *= (2 * self.config.n_layer) ** -0.5
116
+ torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
117
+ if module.bias is not None:
118
+ torch.nn.init.zeros_(module.bias)
119
+ elif isinstance(module, nn.Embedding):
120
+ torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
121
+
122
+ def print_num_parameters(self):
123
+ num_params = sum(p.numel() for p in self.parameters())
124
+ print(f"Number of model parameters: {num_params}")
125
+
126
+ def forward(self, idx, targets=None):
127
+ # idx is of shape (B, T)
128
+ B, T = idx.size()
129
+ assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
130
+ # forward the token and position embeddings
131
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
132
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
133
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
134
+ x = tok_emb + pos_emb
135
+ # forward the blocks of the transformer
136
+ for block in self.transformer.h:
137
+ x = block(x)
138
+ # forward the final layernorm and the classifier
139
+ x = self.transformer.ln_f(x)
140
+ logits = self.lm_head(x) # (B, T, vocab_size)
141
+ loss = None
142
+ if targets is not None:
143
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
144
+ return logits, loss
145
+
146
+ @classmethod
147
+ def from_pretrained(cls, model_type):
148
+ """Loads pretrained GPT-2 model weights from huggingface"""
149
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
150
+ from transformers import GPT2LMHeadModel
151
+ print("loading weights from pretrained gpt: %s" % model_type)
152
+
153
+ # n_layer, n_head and n_embd are determined from model_type
154
+ config_args = {
155
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
156
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
157
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
158
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
159
+ }[model_type]
160
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
161
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
162
+ # create a from-scratch initialized minGPT model
163
+ config = GPTConfig(**config_args)
164
+ model = GPT(config)
165
+ sd = model.state_dict()
166
+ sd_keys = sd.keys()
167
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
168
+
169
+ # init a huggingface/transformers model
170
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
171
+ sd_hf = model_hf.state_dict()
172
+
173
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
174
+ sd_keys_hf = sd_hf.keys()
175
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
176
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
177
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
178
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
179
+ # this means that we have to transpose these weights when we import them
180
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
181
+ for k in sd_keys_hf:
182
+ if any(k.endswith(w) for w in transposed):
183
+ # special treatment for the Conv1D weights we need to transpose
184
+ assert sd_hf[k].shape[::-1] == sd[k].shape
185
+ with torch.no_grad():
186
+ sd[k].copy_(sd_hf[k].t())
187
+ else:
188
+ # vanilla copy over the other parameters
189
+ assert sd_hf[k].shape == sd[k].shape
190
+ with torch.no_grad():
191
+ sd[k].copy_(sd_hf[k])
192
+
193
+ return model
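+ # Example (not used by the training script below): model = GPT.from_pretrained('gpt2')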
194
+
195
+
196
+ device = 'cpu'
197
+ if torch.cuda.is_available():
198
+ device = 'cuda'
199
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
200
+ device = "mps"
201
+ print(f"using device: {device}")
202
+
203
+ # SEED
204
+ torch.manual_seed(1337)
205
+ if torch.cuda.is_available():
206
+ torch.cuda.manual_seed(1337)
207
+
208
+ class DataLoaderLite:
209
+ def __init__(self, B, T):
210
+ self.B = B
211
+ self.T = T
212
+
213
+ # at init load tokens from disk and store them in memory
214
+ with open('input.txt', 'r') as f:
215
+ text = f.read()
216
+ enc = tiktoken.get_encoding('gpt2')
217
+ tokens = enc.encode(text)
218
+ self.tokens = torch.tensor(tokens)
219
+ print(f'loaded {len(self.tokens)} tokens')
220
+ print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
221
+
222
+ # state
223
+ self.current_position = 0
224
+
225
+ def next_batch(self):
226
+ B, T = self.B, self.T
227
+ buf = self.tokens[self.current_position: self.current_position + B * T + 1]
228
+ x = (buf[:-1]).view(B, T) # inputs
229
+ y = (buf[1:]).view(B, T) # targets
230
+ # advance the position in the tensor
231
+ self.current_position += B*T
232
+ # if loading the next batch would be out of bounds, reset
233
+ if self.current_position + (B * T + 1) > len(self.tokens):
234
+ self.current_position = 0
235
+ return x, y
236
+
237
+ # Initialize the data loader with batch size 4 and sequence length 1024
238
+ train_loader = DataLoaderLite(B=4, T=1024)
239
+
240
+ # Initialize the model
241
+ model = GPT(GPTConfig())
242
+ model.to(device)
243
+
244
+ # Print number of model parameters
245
+ model.print_num_parameters()
246
+
247
+ # Define the optimizer
248
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
249
+
250
+ # Function to load the most recent checkpoint
251
+ def load_latest_checkpoint(model):
252
+ # Find the checkpoint file
253
+ checkpoint_file = 'checkpoint.pt'
254
+ if not os.path.exists(checkpoint_file):
255
+ return 0 # No checkpoint found, start from epoch 0
256
+
257
+ print(f'Loading checkpoint from {checkpoint_file}')
258
+
259
+ # Load the model state and epoch number
260
+ checkpoint = torch.load(checkpoint_file, map_location=device) # map_location lets the checkpoint load on CPU-only machines
261
+
262
+ # Ensure the checkpoint contains the expected keys
263
+ if 'model_state_dict' not in checkpoint or 'epoch' not in checkpoint:
264
+ raise KeyError("Checkpoint does not contain required keys.")
265
+
266
+ model.load_state_dict(checkpoint['model_state_dict'])
267
+
268
+ # Return the epoch number
269
+ return checkpoint['epoch']
270
+
271
+ # Load the latest checkpoint if available
272
+ start_epoch = load_latest_checkpoint(model)
273
+
274
+ # Training loop: run for num_epochs, stopping early if the loss drops below 0.099999
275
+ loss = float('inf') # Initialize loss to a large value
276
+ num_epochs = 78 # Set the number of epochs to 78
277
+
278
+ # Start time tracking
279
+ start_time = time.time()
280
+
281
+ for epoch in range(start_epoch, num_epochs): # Start from the loaded epoch
282
+ epoch_loss = 0.0 # Initialize epoch loss
283
+ num_steps = 0 # Initialize step counter for the epoch
284
+ last_loss = None # Variable to store the last loss
285
+
286
+ # Calculate total steps for the progress bar
287
+ total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
288
+
289
+ # Use tqdm to create a progress bar
290
+ with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
291
+ for step in range(total_steps): # Iterate over the number of steps
292
+ x, y = train_loader.next_batch()
293
+ x, y = x.to(device), y.to(device)
294
+ optimizer.zero_grad()
295
+ logits, loss = model(x, y)
296
+ loss.backward()
297
+ optimizer.step()
298
+
299
+ epoch_loss += loss.item() # Accumulate loss
300
+ num_steps += 1 # Increment step counter
301
+ last_loss = loss.item() # Store the last loss
302
+ pbar.update(1) # Update progress bar
303
+
304
+ # Check if the loss is below the threshold
305
+ if last_loss < 0.099999:
306
+ print(f'Loss below threshold: {last_loss:.6f}') # Print loss before breaking
307
+ break # Exit the loop if the loss condition is met
308
+
309
+ # Print the loss at the end of the epoch
310
+ print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {last_loss:.6f}')
311
+
312
+ # Check if the loss condition was met to break out of the epoch loop
313
+ if last_loss < 0.099999:
314
+ print(f'Early stopping at epoch {epoch + 1} due to loss condition met.')
315
+ break # Exit the epoch loop if the loss condition is met
316
+
317
+ # Checkpointing: Save the model and the current epoch after each epoch
318
+ checkpoint_path = 'checkpoint.pt' # Save to a single checkpoint file
319
+ torch.save({
320
+ 'epoch': epoch + 1, # Save the current epoch number
321
+ 'model_state_dict': model.state_dict(), # Save the model state
322
+ }, checkpoint_path)
323
+ print(f'Checkpoint saved to {checkpoint_path}')
324
+
325
+ # End time tracking
326
+ end_time = time.time()
327
+ training_duration = end_time - start_time
328
+
329
+ # Convert training duration to minutes and seconds
330
+ minutes = int(training_duration // 60)
331
+ seconds = int(training_duration % 60)
332
+
333
+ # Print the total training time in minute:second format
334
+ print(f'Total training time: {minutes} minutes and {seconds} seconds')
335
+
336
+ # After training your model, apply quantization and save it with compression
337
+ def save_model_with_quantization(model, file_path):
338
+ # Switch model to evaluation mode
339
+ model.eval()
340
+
341
+ # Apply dynamic quantization
342
+ quantized_model = torch.quantization.quantize_dynamic(
343
+ model, # the model to be quantized
344
+ {nn.Linear}, # layers to quantize
345
+ dtype=torch.qint8 # quantization type
346
+ )
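+ # NOTE: dynamically quantized int8 Linear layers are intended for CPU inference;
+ # move the model to CPU before quantizing if it is still on a GPU.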
347
+
348
+ # Save the quantized model with compression
349
+ torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
350
+ print(f'Model saved to {file_path} with quantization and compression.')
351
+
352
+ # Call this function after training your model
353
+ save_model_with_quantization(model, 'trained_model_quantized.pt')