Spaces:
Sleeping
Sleeping
Commit
·
61d0253
1
Parent(s):
1b4fddf
Adding code for transformer model
Browse files- README.md +82 -1
- app.py +63 -0
- checkpoint.pt +3 -0
- input.txt +0 -0
- trained_model_quantized.pt +3 -0
- training.log +253 -0
- transformer.py +353 -0
README.md
CHANGED
@@ -10,4 +10,85 @@ pinned: false
|
|
10 |
short_description: Transformer trained on Shakespeare play dataset
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
short_description: Transformer trained on Shakespeare play dataset
|
11 |
---
|
12 |
|
13 |
+
# Transformer Model Training
|
14 |
+
|
15 |
+
This project implements a transformer-based language model using PyTorch. The model is designed to learn from a text corpus and can be trained and fine-tuned for various natural language processing tasks.
|
16 |
+
|
17 |
+
## Table of Contents
|
18 |
+
- [Features](#features)
|
19 |
+
- [Requirements](#requirements)
|
20 |
+
- [Installation](#installation)
|
21 |
+
- [Usage](#usage)
|
22 |
+
- [Training](#training)
|
23 |
+
- [Actual Training](#actual-training)
|
24 |
+
- [Checkpointing](#checkpointing)
|
25 |
+
- [Model Compression](#model-compression)
|
26 |
+
- [License](#license)
|
27 |
+
- [Acknowledgments](#acknowledgments)
|
28 |
+
|
29 |
+
## Features
|
30 |
+
- Transformer architecture with causal self-attention and feedforward layers.
|
31 |
+
- Efficient data loading and batching.
|
32 |
+
- Checkpointing to resume training.
|
33 |
+
- Support for multiple devices (CPU, CUDA, MPS).
|
34 |
+
- Model compression for reduced file size.
|
35 |
+
- Streamlit application for text generation using the trained model.
|
36 |
+
|
37 |
+
## Requirements
|
38 |
+
- Python 3.6 or higher
|
39 |
+
- PyTorch 1.7 or higher
|
40 |
+
- tqdm
|
41 |
+
- tiktoken
|
42 |
+
- streamlit
|
43 |
+
- transformers
|
44 |
+
|
45 |
+
## Installation
|
46 |
+
1. Clone the repository:
|
47 |
+
```bash
|
48 |
+
git clone https://github.com/yourusername/transformer-model-training.git
|
49 |
+
cd transformer-model-training
|
50 |
+
```
|
51 |
+
|
52 |
+
2. Install the required packages:
|
53 |
+
```bash
|
54 |
+
pip install -r requirements.txt
|
55 |
+
```
|
56 |
+
|
57 |
+
## Usage
|
58 |
+
1. Prepare your text data in a file named `input.txt`. The model will read this file to load tokens for training.
|
59 |
+
|
60 |
+
2. Run the training script:
|
61 |
+
```bash
|
62 |
+
python transformer.py
|
63 |
+
```
|
64 |
+
|
65 |
+
3. The model will save checkpoints after each epoch in `checkpoint.pt` and the final model in `trained_model_quantized.pt`.
|
66 |
+
|
67 |
+
4. To generate text using the trained model, run the Streamlit application:
|
68 |
+
```bash
|
69 |
+
streamlit run app.py
|
70 |
+
```
|
71 |
+
|
72 |
+
5. Enter your text and specify the length of additional text to generate in the Streamlit interface.
|
73 |
+
|
74 |
+
## Training
|
75 |
+
- The model is trained using a batch size of 4 and a learning rate of 3e-4.
|
76 |
+
- The training loop includes loss calculation, backpropagation, and optimizer steps.
|
77 |
+
- The loss is monitored, and checkpoints are saved to allow for resuming training.
|
78 |
+
- The training process is logged in `training.log`, which contains detailed statistics for each epoch, including loss values and checkpointing information.
|
79 |
+
|
80 |
+
## Actual Training
|
81 |
+
The model was trained for a total of **78 epochs**. The final loss achieved at the end of training was approximately **0.904894**. The training log file contains detailed statistics for each epoch, including loss values and checkpointing information. You can find the log file named `training.log` in the project directory.
|
82 |
+
|
83 |
+
## Checkpointing
|
84 |
+
- The model state and current epoch are saved in a single checkpoint file (`checkpoint.pt`).
|
85 |
+
- To resume training from the last checkpoint, simply run the training script again. The model will automatically load the latest checkpoint.
|
86 |
+
|
87 |
+
## Model Compression
|
88 |
+
- The final model is saved with compression to reduce file size. The model file will be saved as `trained_model_quantized.pt`.
|
89 |
+
|
90 |
+
## License
|
91 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
92 |
+
|
93 |
+
## Acknowledgments
|
94 |
+
- This project is inspired by the original GPT architecture and various resources available in the NLP community.
|
app.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import tiktoken
|
4 |
+
from transformer import GPT, GPTConfig # Ensure you import your model class
|
5 |
+
|
6 |
+
# Load the trained model
|
7 |
+
@st.cache_resource
|
8 |
+
def load_model():
|
9 |
+
config = GPTConfig()
|
10 |
+
model = GPT(config)
|
11 |
+
try:
|
12 |
+
model.load_state_dict(torch.load('trained_model_quantized.pt'))
|
13 |
+
model.eval() # Set the model to evaluation mode
|
14 |
+
st.success("Model loaded successfully!")
|
15 |
+
except Exception as e:
|
16 |
+
st.error(f"Error loading model: {e}")
|
17 |
+
return model
|
18 |
+
|
19 |
+
# Load the tokenizer
|
20 |
+
def load_tokenizer():
|
21 |
+
return tiktoken.get_encoding('gpt2')
|
22 |
+
|
23 |
+
# Generate text function
|
24 |
+
def generate_text(model, tokenizer, input_text, length, num_sequences):
|
25 |
+
# Encode the input text
|
26 |
+
input_ids = tokenizer.encode(input_text)
|
27 |
+
input_tensor = torch.tensor(input_ids).unsqueeze(0) # Add batch dimension
|
28 |
+
|
29 |
+
generated_sequences = []
|
30 |
+
for _ in range(num_sequences):
|
31 |
+
# Generate additional tokens
|
32 |
+
with torch.no_grad():
|
33 |
+
for _ in range(length):
|
34 |
+
logits = model(input_tensor)[0] # Get logits
|
35 |
+
next_token_logits = logits[:, -1, :] # Get the last token's logits
|
36 |
+
next_token_probs = torch.softmax(next_token_logits, dim=-1)
|
37 |
+
next_token = torch.multinomial(next_token_probs, num_samples=1) # Sample from the distribution
|
38 |
+
input_tensor = torch.cat((input_tensor, next_token.unsqueeze(0)), dim=1) # Append the new token
|
39 |
+
|
40 |
+
# Decode the generated tokens
|
41 |
+
generated_sequences.append(tokenizer.decode(input_tensor[0].tolist()))
|
42 |
+
|
43 |
+
return generated_sequences
|
44 |
+
|
45 |
+
# Streamlit app layout
|
46 |
+
st.title("GPT Text Generator")
|
47 |
+
st.write("Enter your text and specify the length of additional text to generate.")
|
48 |
+
|
49 |
+
input_text = st.text_area("Input Text", "Once upon a time", max_chars=512) # Limit to 512 characters
|
50 |
+
length = st.slider("Predict Additional Text of Length", 1, 50, 10)
|
51 |
+
num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)
|
52 |
+
|
53 |
+
if st.button("Generate"):
|
54 |
+
model = load_model()
|
55 |
+
tokenizer = load_tokenizer()
|
56 |
+
st.write("Generating text...")
|
57 |
+
generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
|
58 |
+
st.write("Text generation complete.")
|
59 |
+
|
60 |
+
st.write("Generated Texts:")
|
61 |
+
for i, text in enumerate(generated_texts):
|
62 |
+
st.subheader(f"Sequence {i + 1}")
|
63 |
+
st.write(text)
|
checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9f02348249d0b8457a59bc3331ac807b879f7d32b35886d60c8ab15d18fa6bd
|
3 |
+
size 548146590
|
input.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
trained_model_quantized.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8fec31d4b4fa71331f80d77d4066bb10a71d6118c0c757e341a143b630be08a6
|
3 |
+
size 331982620
|
training.log
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
using device: cuda
|
2 |
+
loaded 338025 tokens
|
3 |
+
1 epoch = 82 batches
|
4 |
+
Number of model parameters: 124439808
|
5 |
+
Epoch 1/70: 100% 82/82 [01:38<00:00, 1.20s/it]
|
6 |
+
Epoch 1/70, Loss: 6.169636
|
7 |
+
Checkpoint saved to checkpoint.pt
|
8 |
+
Epoch 2/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
9 |
+
Epoch 2/70, Loss: 5.720689
|
10 |
+
Checkpoint saved to checkpoint.pt
|
11 |
+
Epoch 3/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
12 |
+
Epoch 3/70, Loss: 5.390238
|
13 |
+
Checkpoint saved to checkpoint.pt
|
14 |
+
Epoch 4/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
15 |
+
Epoch 4/70, Loss: 5.164030
|
16 |
+
Checkpoint saved to checkpoint.pt
|
17 |
+
Epoch 5/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
18 |
+
Epoch 5/70, Loss: 5.051653
|
19 |
+
Checkpoint saved to checkpoint.pt
|
20 |
+
Epoch 6/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
21 |
+
Epoch 6/70, Loss: 4.947546
|
22 |
+
Checkpoint saved to checkpoint.pt
|
23 |
+
Epoch 7/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
24 |
+
Epoch 7/70, Loss: 4.893464
|
25 |
+
Checkpoint saved to checkpoint.pt
|
26 |
+
Epoch 8/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
27 |
+
Epoch 8/70, Loss: 4.785249
|
28 |
+
Checkpoint saved to checkpoint.pt
|
29 |
+
Epoch 9/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
30 |
+
Epoch 9/70, Loss: 4.773346
|
31 |
+
Checkpoint saved to checkpoint.pt
|
32 |
+
Epoch 10/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
33 |
+
Epoch 10/70, Loss: 4.669469
|
34 |
+
Checkpoint saved to checkpoint.pt
|
35 |
+
Epoch 11/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
36 |
+
Epoch 11/70, Loss: 4.617172
|
37 |
+
Checkpoint saved to checkpoint.pt
|
38 |
+
Epoch 12/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
39 |
+
Epoch 12/70, Loss: 4.594382
|
40 |
+
Checkpoint saved to checkpoint.pt
|
41 |
+
Epoch 13/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
42 |
+
Epoch 13/70, Loss: 4.554847
|
43 |
+
Checkpoint saved to checkpoint.pt
|
44 |
+
Epoch 14/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
45 |
+
Epoch 14/70, Loss: 4.506260
|
46 |
+
Checkpoint saved to checkpoint.pt
|
47 |
+
Epoch 15/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
48 |
+
Epoch 15/70, Loss: 4.416086
|
49 |
+
Checkpoint saved to checkpoint.pt
|
50 |
+
Epoch 16/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
51 |
+
Epoch 16/70, Loss: 4.370214
|
52 |
+
Checkpoint saved to checkpoint.pt
|
53 |
+
Epoch 17/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
54 |
+
Epoch 17/70, Loss: 4.278370
|
55 |
+
Checkpoint saved to checkpoint.pt
|
56 |
+
Epoch 18/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
57 |
+
Epoch 18/70, Loss: 4.304771
|
58 |
+
Checkpoint saved to checkpoint.pt
|
59 |
+
Epoch 19/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
60 |
+
Epoch 19/70, Loss: 4.209321
|
61 |
+
Checkpoint saved to checkpoint.pt
|
62 |
+
Epoch 20/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
63 |
+
Epoch 20/70, Loss: 4.175936
|
64 |
+
Checkpoint saved to checkpoint.pt
|
65 |
+
Epoch 21/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
66 |
+
Epoch 21/70, Loss: 4.071361
|
67 |
+
Checkpoint saved to checkpoint.pt
|
68 |
+
Epoch 22/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
69 |
+
Epoch 22/70, Loss: 4.071530
|
70 |
+
Checkpoint saved to checkpoint.pt
|
71 |
+
Epoch 23/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
72 |
+
Epoch 23/70, Loss: 4.053171
|
73 |
+
Checkpoint saved to checkpoint.pt
|
74 |
+
Epoch 24/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
75 |
+
Epoch 24/70, Loss: 3.923664
|
76 |
+
Checkpoint saved to checkpoint.pt
|
77 |
+
Epoch 25/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
78 |
+
Epoch 25/70, Loss: 3.827437
|
79 |
+
Checkpoint saved to checkpoint.pt
|
80 |
+
Epoch 26/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
81 |
+
Epoch 26/70, Loss: 3.767063
|
82 |
+
Checkpoint saved to checkpoint.pt
|
83 |
+
Epoch 27/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
84 |
+
Epoch 27/70, Loss: 3.711340
|
85 |
+
Checkpoint saved to checkpoint.pt
|
86 |
+
Epoch 28/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
87 |
+
Epoch 28/70, Loss: 3.622302
|
88 |
+
Checkpoint saved to checkpoint.pt
|
89 |
+
Epoch 29/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
90 |
+
Epoch 29/70, Loss: 3.583114
|
91 |
+
Checkpoint saved to checkpoint.pt
|
92 |
+
Epoch 30/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
93 |
+
Epoch 30/70, Loss: 3.517573
|
94 |
+
Checkpoint saved to checkpoint.pt
|
95 |
+
Epoch 31/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
96 |
+
Epoch 31/70, Loss: 3.445611
|
97 |
+
Checkpoint saved to checkpoint.pt
|
98 |
+
Epoch 32/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
99 |
+
Epoch 32/70, Loss: 3.410571
|
100 |
+
Checkpoint saved to checkpoint.pt
|
101 |
+
Epoch 33/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
102 |
+
Epoch 33/70, Loss: 3.282128
|
103 |
+
Checkpoint saved to checkpoint.pt
|
104 |
+
Epoch 34/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
105 |
+
Epoch 34/70, Loss: 3.307455
|
106 |
+
Checkpoint saved to checkpoint.pt
|
107 |
+
Epoch 35/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
108 |
+
Epoch 35/70, Loss: 3.126928
|
109 |
+
Checkpoint saved to checkpoint.pt
|
110 |
+
Epoch 36/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
111 |
+
Epoch 36/70, Loss: 3.057953
|
112 |
+
Checkpoint saved to checkpoint.pt
|
113 |
+
Epoch 37/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
114 |
+
Epoch 37/70, Loss: 3.082567
|
115 |
+
Checkpoint saved to checkpoint.pt
|
116 |
+
Epoch 38/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
117 |
+
Epoch 38/70, Loss: 3.066772
|
118 |
+
Checkpoint saved to checkpoint.pt
|
119 |
+
Epoch 39/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
120 |
+
Epoch 39/70, Loss: 2.943954
|
121 |
+
Checkpoint saved to checkpoint.pt
|
122 |
+
Epoch 40/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
123 |
+
Epoch 40/70, Loss: 2.874876
|
124 |
+
Checkpoint saved to checkpoint.pt
|
125 |
+
Epoch 41/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
126 |
+
Epoch 41/70, Loss: 2.781206
|
127 |
+
Checkpoint saved to checkpoint.pt
|
128 |
+
Epoch 42/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
129 |
+
Epoch 42/70, Loss: 2.729423
|
130 |
+
Checkpoint saved to checkpoint.pt
|
131 |
+
Epoch 43/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
132 |
+
Epoch 43/70, Loss: 2.656427
|
133 |
+
Checkpoint saved to checkpoint.pt
|
134 |
+
Epoch 44/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
135 |
+
Epoch 44/70, Loss: 2.641519
|
136 |
+
Checkpoint saved to checkpoint.pt
|
137 |
+
Epoch 45/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
138 |
+
Epoch 45/70, Loss: 2.593380
|
139 |
+
Checkpoint saved to checkpoint.pt
|
140 |
+
Epoch 46/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
141 |
+
Epoch 46/70, Loss: 2.504074
|
142 |
+
Checkpoint saved to checkpoint.pt
|
143 |
+
Epoch 47/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
144 |
+
Epoch 47/70, Loss: 2.510426
|
145 |
+
Checkpoint saved to checkpoint.pt
|
146 |
+
Epoch 48/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
147 |
+
Epoch 48/70, Loss: 2.465840
|
148 |
+
Checkpoint saved to checkpoint.pt
|
149 |
+
Epoch 49/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
150 |
+
Epoch 49/70, Loss: 2.339541
|
151 |
+
Checkpoint saved to checkpoint.pt
|
152 |
+
Epoch 50/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
153 |
+
Epoch 50/70, Loss: 2.288784
|
154 |
+
Checkpoint saved to checkpoint.pt
|
155 |
+
Epoch 51/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
156 |
+
Epoch 51/70, Loss: 2.272939
|
157 |
+
Checkpoint saved to checkpoint.pt
|
158 |
+
Epoch 52/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
159 |
+
Epoch 52/70, Loss: 2.150897
|
160 |
+
Checkpoint saved to checkpoint.pt
|
161 |
+
Epoch 53/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
162 |
+
Epoch 53/70, Loss: 2.096288
|
163 |
+
Checkpoint saved to checkpoint.pt
|
164 |
+
Epoch 54/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
165 |
+
Epoch 54/70, Loss: 2.057416
|
166 |
+
Checkpoint saved to checkpoint.pt
|
167 |
+
Epoch 55/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
168 |
+
Epoch 55/70, Loss: 1.962530
|
169 |
+
Checkpoint saved to checkpoint.pt
|
170 |
+
Epoch 56/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
171 |
+
Epoch 56/70, Loss: 1.930993
|
172 |
+
Checkpoint saved to checkpoint.pt
|
173 |
+
Epoch 57/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
174 |
+
Epoch 57/70, Loss: 1.854412
|
175 |
+
Checkpoint saved to checkpoint.pt
|
176 |
+
Epoch 58/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
177 |
+
Epoch 58/70, Loss: 1.818957
|
178 |
+
Checkpoint saved to checkpoint.pt
|
179 |
+
Epoch 59/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
180 |
+
Epoch 59/70, Loss: 1.764919
|
181 |
+
Checkpoint saved to checkpoint.pt
|
182 |
+
Epoch 60/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
183 |
+
Epoch 60/70, Loss: 1.741000
|
184 |
+
Checkpoint saved to checkpoint.pt
|
185 |
+
Epoch 61/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
186 |
+
Epoch 61/70, Loss: 1.694582
|
187 |
+
Checkpoint saved to checkpoint.pt
|
188 |
+
Epoch 62/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
189 |
+
Epoch 62/70, Loss: 1.751990
|
190 |
+
Checkpoint saved to checkpoint.pt
|
191 |
+
Epoch 63/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
192 |
+
Epoch 63/70, Loss: 1.664971
|
193 |
+
Checkpoint saved to checkpoint.pt
|
194 |
+
Epoch 64/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
195 |
+
Epoch 64/70, Loss: 1.557876
|
196 |
+
Checkpoint saved to checkpoint.pt
|
197 |
+
Epoch 65/70: 100% 82/82 [01:41<00:00, 1.24s/it]
|
198 |
+
Epoch 65/70, Loss: 1.543549
|
199 |
+
Checkpoint saved to checkpoint.pt
|
200 |
+
Epoch 66/70: 100% 82/82 [01:42<00:00, 1.25s/it]
|
201 |
+
Epoch 66/70, Loss: 1.436256
|
202 |
+
Checkpoint saved to checkpoint.pt
|
203 |
+
Epoch 67/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
204 |
+
Epoch 67/70, Loss: 1.352293
|
205 |
+
Checkpoint saved to checkpoint.pt
|
206 |
+
Epoch 68/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
207 |
+
Epoch 68/70, Loss: 1.361581
|
208 |
+
Checkpoint saved to checkpoint.pt
|
209 |
+
Epoch 69/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
210 |
+
Epoch 69/70, Loss: 1.308131
|
211 |
+
Checkpoint saved to checkpoint.pt
|
212 |
+
Epoch 70/70: 100% 82/82 [01:42<00:00, 1.24s/it]
|
213 |
+
Epoch 70/70, Loss: 1.287876
|
214 |
+
Checkpoint saved to checkpoint.pt
|
215 |
+
Total training time: 127 minutes and 37 seconds
|
216 |
+
Model saved to trained_model_quantized.pt with quantization and compression.
|
217 |
+
==================================================
|
218 |
+
Increased epoch to 78 to reach loss < 0.99999
|
219 |
+
==================================================
|
220 |
+
using device: cuda
|
221 |
+
loaded 338025 tokens
|
222 |
+
1 epoch = 82 batches
|
223 |
+
Number of model parameters: 124439808
|
224 |
+
Loading checkpoint from checkpoint.pt
|
225 |
+
/content/erav3-s12-transformer-model/erav3-s12-transformer-model/transformer.py:262: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
226 |
+
checkpoint = torch.load(checkpoint_file)
|
227 |
+
Epoch 71/78: 100% 82/82 [01:36<00:00, 1.18s/it]
|
228 |
+
Epoch 71/78, Loss: 1.453567
|
229 |
+
Checkpoint saved to checkpoint.pt
|
230 |
+
Epoch 72/78: 100% 82/82 [01:42<00:00, 1.25s/it]
|
231 |
+
Epoch 72/78, Loss: 1.162141
|
232 |
+
Checkpoint saved to checkpoint.pt
|
233 |
+
Epoch 73/78: 100% 82/82 [01:42<00:00, 1.24s/it]
|
234 |
+
Epoch 73/78, Loss: 1.174683
|
235 |
+
Checkpoint saved to checkpoint.pt
|
236 |
+
Epoch 74/78: 100% 82/82 [01:42<00:00, 1.25s/it]
|
237 |
+
Epoch 74/78, Loss: 1.089287
|
238 |
+
Checkpoint saved to checkpoint.pt
|
239 |
+
Epoch 75/78: 100% 82/82 [01:42<00:00, 1.25s/it]
|
240 |
+
Epoch 75/78, Loss: 1.010704
|
241 |
+
Checkpoint saved to checkpoint.pt
|
242 |
+
Epoch 76/78: 100% 82/82 [01:42<00:00, 1.24s/it]
|
243 |
+
Epoch 76/78, Loss: 0.979691
|
244 |
+
Checkpoint saved to checkpoint.pt
|
245 |
+
Epoch 77/78: 100% 82/82 [01:41<00:00, 1.24s/it]
|
246 |
+
Epoch 77/78, Loss: 0.918769
|
247 |
+
Checkpoint saved to checkpoint.pt
|
248 |
+
Epoch 78/78: 100% 82/82 [01:41<00:00, 1.24s/it]
|
249 |
+
Epoch 78/78, Loss: 0.904894
|
250 |
+
Checkpoint saved to checkpoint.pt
|
251 |
+
Total training time: 14 minutes and 37 seconds
|
252 |
+
Model saved to trained_model_quantized.pt with quantization and compression.
|
253 |
+
|
transformer.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Solving for residual std scaling issue
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from tqdm import tqdm # Import tqdm for progress bar
|
10 |
+
import torch.quantization # Import quantization module
|
11 |
+
import torch.nn.utils.prune as prune
|
12 |
+
import tiktoken
|
13 |
+
|
14 |
+
|
15 |
+
class CausalSelfAttention(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, config):
|
18 |
+
super().__init__()
|
19 |
+
assert config.n_embd % config.n_head == 0
|
20 |
+
# key, query, value projections for all heads, but in a batch
|
21 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
|
22 |
+
# output projection
|
23 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
24 |
+
self.c_proj.NANGPT_SCALE_INIT = 1
|
25 |
+
# regularization
|
26 |
+
self.n_head = config.n_head
|
27 |
+
self.n_embd = config.n_embd
|
28 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
32 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
33 |
+
# nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
|
34 |
+
# e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
|
35 |
+
qkv = self.c_attn(x)
|
36 |
+
q, k, v = qkv.split(self.n_embd, dim=2)
|
37 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
38 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
39 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
40 |
+
|
41 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
42 |
+
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
|
43 |
+
att = F.softmax(att, dim=-1)
|
44 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
45 |
+
|
46 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
47 |
+
# output projection
|
48 |
+
y = self.c_proj(y)
|
49 |
+
return y
|
50 |
+
|
51 |
+
|
52 |
+
class MLP(nn.Module):
|
53 |
+
|
54 |
+
def __init__(self, config):
|
55 |
+
super().__init__()
|
56 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
|
57 |
+
self.gelu = nn.GELU(approximate='tanh')
|
58 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
|
59 |
+
self.c_proj.NANOGPT_SCALE_INIT = 1
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x = self.c_fc(x)
|
63 |
+
x = self.gelu(x)
|
64 |
+
x = self.c_proj(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
class Block(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self, config):
|
70 |
+
super().__init__()
|
71 |
+
self.ln_1 = nn.LayerNorm(config.n_embd)
|
72 |
+
self.attn = CausalSelfAttention(config)
|
73 |
+
self.ln_2 = nn.LayerNorm(config.n_embd)
|
74 |
+
self.mlp = MLP(config)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
x = x + self.attn(self.ln_1(x))
|
78 |
+
x = x + self.mlp(self.ln_2(x))
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class GPTConfig:
|
84 |
+
block_size: int = 1024 # max sequence length
|
85 |
+
vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
|
86 |
+
n_layer: int = 12 # number of layers
|
87 |
+
n_head: int = 12 # number of heads
|
88 |
+
n_embd: int = 768 # embedding dimension
|
89 |
+
|
90 |
+
|
91 |
+
class GPT(nn.Module):
|
92 |
+
|
93 |
+
def __init__(self, config):
|
94 |
+
super().__init__()
|
95 |
+
self.config = config
|
96 |
+
|
97 |
+
self.transformer = nn.ModuleDict(dict(
|
98 |
+
wte = nn.Embedding(config.vocab_size, config.n_embd),
|
99 |
+
wpe = nn.Embedding(config.block_size, config.n_embd),
|
100 |
+
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
101 |
+
ln_f = nn.LayerNorm(config.n_embd),
|
102 |
+
))
|
103 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
104 |
+
|
105 |
+
# weight sharing
|
106 |
+
self.transformer.wte.weight = self.lm_head.weight
|
107 |
+
|
108 |
+
# weight initialization
|
109 |
+
self.apply(self._init_weights)
|
110 |
+
|
111 |
+
def _init_weights(self, module):
|
112 |
+
if isinstance(module, nn.Linear):
|
113 |
+
std = 0.02
|
114 |
+
if hasattr(module, 'NANGPT_SCALE_INIT'):
|
115 |
+
std *= (2 * self.config.n_layer) ** -0.5
|
116 |
+
torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
|
117 |
+
if module.bias is not None:
|
118 |
+
torch.nn.init.zeros_(module.bias)
|
119 |
+
elif isinstance(module, nn.Embedding):
|
120 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
|
121 |
+
|
122 |
+
def print_num_parameters(self):
|
123 |
+
num_params = sum(p.numel() for p in self.parameters())
|
124 |
+
print(f"Number of model parameters: {num_params}")
|
125 |
+
|
126 |
+
def forward(self, idx, targets=None):
|
127 |
+
# idx is of shape (B, T)
|
128 |
+
B, T = idx.size()
|
129 |
+
assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
|
130 |
+
# forward the token and posisition embeddings
|
131 |
+
pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
|
132 |
+
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
|
133 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
|
134 |
+
x = tok_emb + pos_emb
|
135 |
+
# forward the blocks of the transformer
|
136 |
+
for block in self.transformer.h:
|
137 |
+
x = block(x)
|
138 |
+
# forward the final layernorm and the classifier
|
139 |
+
x = self.transformer.ln_f(x)
|
140 |
+
logits = self.lm_head(x) # (B, T, vocab_size)
|
141 |
+
loss = None
|
142 |
+
if targets is not None:
|
143 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
144 |
+
return logits, loss
|
145 |
+
|
146 |
+
@classmethod
|
147 |
+
def from_pretrained(cls, model_type):
|
148 |
+
"""Loads pretrained GPT-2 model weights from huggingface"""
|
149 |
+
assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
|
150 |
+
from transformers import GPT2LMHeadModel
|
151 |
+
print("loading weights from pretrained gpt: %s" % model_type)
|
152 |
+
|
153 |
+
# n_layer, n_head and n_embd are determined from model_type
|
154 |
+
config_args = {
|
155 |
+
'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
|
156 |
+
'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
|
157 |
+
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
|
158 |
+
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
|
159 |
+
}[model_type]
|
160 |
+
config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
|
161 |
+
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
|
162 |
+
# create a from-scratch initialized minGPT model
|
163 |
+
config = GPTConfig(**config_args)
|
164 |
+
model = GPT(config)
|
165 |
+
sd = model.state_dict()
|
166 |
+
sd_keys = sd.keys()
|
167 |
+
sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
|
168 |
+
|
169 |
+
# init a huggingface/transformers model
|
170 |
+
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
|
171 |
+
sd_hf = model_hf.state_dict()
|
172 |
+
|
173 |
+
# copy while ensuring all of the parameters are aligned and match in names and shapes
|
174 |
+
sd_keys_hf = sd_hf.keys()
|
175 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
|
176 |
+
sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
|
177 |
+
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
178 |
+
# basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
|
179 |
+
# this means that we have to transpose these weights when we import them
|
180 |
+
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
|
181 |
+
for k in sd_keys_hf:
|
182 |
+
if any(k.endswith(w) for w in transposed):
|
183 |
+
# special treatment for the Conv1D weights we need to transpose
|
184 |
+
assert sd_hf[k].shape[::-1] == sd[k].shape
|
185 |
+
with torch.no_grad():
|
186 |
+
sd[k].copy_(sd_hf[k].t())
|
187 |
+
else:
|
188 |
+
# vanilla copy over the other parameters
|
189 |
+
assert sd_hf[k].shape == sd[k].shape
|
190 |
+
with torch.no_grad():
|
191 |
+
sd[k].copy_(sd_hf[k])
|
192 |
+
|
193 |
+
return model
|
194 |
+
|
195 |
+
|
196 |
+
device = 'cpu'
|
197 |
+
if torch.cuda.is_available():
|
198 |
+
device = 'cuda'
|
199 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
200 |
+
device = "mps"
|
201 |
+
print(f"using device: {device}")
|
202 |
+
|
203 |
+
# SEED
|
204 |
+
torch.manual_seed(1337)
|
205 |
+
if torch.cuda.is_available():
|
206 |
+
torch.cuda.manual_seed(1337)
|
207 |
+
|
208 |
+
class DataLoaderLite:
|
209 |
+
def __init__(self, B, T):
|
210 |
+
self.B = B
|
211 |
+
self.T = T
|
212 |
+
|
213 |
+
# at init load tokens from disk and store them in memory
|
214 |
+
with open('input.txt', 'r') as f:
|
215 |
+
text = f.read()
|
216 |
+
enc = tiktoken.get_encoding('gpt2')
|
217 |
+
tokens = enc.encode(text)
|
218 |
+
self.tokens = torch.tensor(tokens)
|
219 |
+
print(f'loaded {len(self.tokens)} tokens')
|
220 |
+
print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
|
221 |
+
|
222 |
+
# state
|
223 |
+
self.current_position = 0
|
224 |
+
|
225 |
+
def next_batch(self):
|
226 |
+
B, T = self.B, self.T
|
227 |
+
buf = self.tokens[self.current_position: self.current_position + B * T + 1]
|
228 |
+
x = (buf[:-1]).view(B, T) # inputs
|
229 |
+
y = (buf[1:]).view(B, T) # targets
|
230 |
+
# advance the position in the tensor
|
231 |
+
self.current_position += B*T
|
232 |
+
# if loading the next batch would be out of bounds, reset
|
233 |
+
if self.current_position + (B * T + 1) > len(self.tokens):
|
234 |
+
self.current_position = 0
|
235 |
+
return x, y
|
236 |
+
|
237 |
+
# Initialize the data loader with batch size 4 and sequence length 1024
|
238 |
+
train_loader = DataLoaderLite(B=4, T=1024)
|
239 |
+
|
240 |
+
# Initialize the model
|
241 |
+
model = GPT(GPTConfig())
|
242 |
+
model.to(device)
|
243 |
+
|
244 |
+
# Print number of model parameters
|
245 |
+
model.print_num_parameters()
|
246 |
+
|
247 |
+
# Define the optimizer
|
248 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
|
249 |
+
|
250 |
+
# Function to load the most recent checkpoint
|
251 |
+
def load_latest_checkpoint(model):
|
252 |
+
# Find the checkpoint file
|
253 |
+
checkpoint_file = 'checkpoint.pt'
|
254 |
+
if not os.path.exists(checkpoint_file):
|
255 |
+
return 0 # No checkpoint found, start from epoch 0
|
256 |
+
|
257 |
+
print(f'Loading checkpoint from {checkpoint_file}')
|
258 |
+
|
259 |
+
# Load the model state and epoch number
|
260 |
+
checkpoint = torch.load(checkpoint_file)
|
261 |
+
|
262 |
+
# Ensure the checkpoint contains the expected keys
|
263 |
+
if 'model_state_dict' not in checkpoint or 'epoch' not in checkpoint:
|
264 |
+
raise KeyError("Checkpoint does not contain required keys.")
|
265 |
+
|
266 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
267 |
+
|
268 |
+
# Return the epoch number
|
269 |
+
return checkpoint['epoch']
|
270 |
+
|
271 |
+
# Load the latest checkpoint if available
|
272 |
+
start_epoch = load_latest_checkpoint(model)
|
273 |
+
|
274 |
+
# NEW CODE: Training loop until loss is less than 0.099999
|
275 |
+
loss = float('inf') # Initialize loss to a large value
|
276 |
+
num_epochs = 78 # Set the number of epochs to 78
|
277 |
+
|
278 |
+
# Start time tracking
|
279 |
+
start_time = time.time()
|
280 |
+
|
281 |
+
for epoch in range(start_epoch, num_epochs): # Start from the loaded epoch
|
282 |
+
epoch_loss = 0.0 # Initialize epoch loss
|
283 |
+
num_steps = 0 # Initialize step counter for the epoch
|
284 |
+
last_loss = None # Variable to store the last loss
|
285 |
+
|
286 |
+
# Calculate total steps for the progress bar
|
287 |
+
total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
|
288 |
+
|
289 |
+
# Use tqdm to create a progress bar
|
290 |
+
with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
|
291 |
+
for step in range(total_steps): # Iterate over the number of steps
|
292 |
+
x, y = train_loader.next_batch()
|
293 |
+
x, y = x.to(device), y.to(device)
|
294 |
+
optimizer.zero_grad()
|
295 |
+
logits, loss = model(x, y)
|
296 |
+
loss.backward()
|
297 |
+
optimizer.step()
|
298 |
+
|
299 |
+
epoch_loss += loss.item() # Accumulate loss
|
300 |
+
num_steps += 1 # Increment step counter
|
301 |
+
last_loss = loss.item() # Store the last loss
|
302 |
+
pbar.update(1) # Update progress bar
|
303 |
+
|
304 |
+
# Check if the loss is below the threshold
|
305 |
+
if last_loss < 0.099999:
|
306 |
+
print(f'Loss below threshold: {last_loss:.6f}') # Print loss before breaking
|
307 |
+
break # Exit the loop if the loss condition is met
|
308 |
+
|
309 |
+
# Print the loss at the end of the epoch
|
310 |
+
print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {last_loss:.6f}')
|
311 |
+
|
312 |
+
# Check if the loss condition was met to break out of the epoch loop
|
313 |
+
if last_loss < 0.099999:
|
314 |
+
print(f'Early stopping at epoch {epoch + 1} due to loss condition met.')
|
315 |
+
break # Exit the epoch loop if the loss condition is met
|
316 |
+
|
317 |
+
# Checkpointing: Save the model and the current epoch after each epoch
|
318 |
+
checkpoint_path = 'checkpoint.pt' # Save to a single checkpoint file
|
319 |
+
torch.save({
|
320 |
+
'epoch': epoch + 1, # Save the current epoch number
|
321 |
+
'model_state_dict': model.state_dict(), # Save the model state
|
322 |
+
}, checkpoint_path)
|
323 |
+
print(f'Checkpoint saved to {checkpoint_path}')
|
324 |
+
|
325 |
+
# End time tracking
|
326 |
+
end_time = time.time()
|
327 |
+
training_duration = end_time - start_time
|
328 |
+
|
329 |
+
# Convert training duration to minutes and seconds
|
330 |
+
minutes = int(training_duration // 60)
|
331 |
+
seconds = int(training_duration % 60)
|
332 |
+
|
333 |
+
# Print the total training time in minute:second format
|
334 |
+
print(f'Total training time: {minutes} minutes and {seconds} seconds')
|
335 |
+
|
336 |
+
# After training your model, apply quantization and save it with compression
|
337 |
+
def save_model_with_quantization(model, file_path):
|
338 |
+
# Switch model to evaluation mode
|
339 |
+
model.eval()
|
340 |
+
|
341 |
+
# Apply dynamic quantization
|
342 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
343 |
+
model, # the model to be quantized
|
344 |
+
{nn.Linear}, # layers to quantize
|
345 |
+
dtype=torch.qint8 # quantization type
|
346 |
+
)
|
347 |
+
|
348 |
+
# Save the quantized model with compression
|
349 |
+
torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
|
350 |
+
print(f'Model saved to {file_path} with quantization and compression.')
|
351 |
+
|
352 |
+
# Call this function after training your model
|
353 |
+
save_model_with_quantization(model, 'trained_model_quantized.pt')
|