MilindChawre committed
Commit b7ca7fe · 1 Parent(s): 52bbfb5

Updating README and splitting training logic

Files changed (4):
  1. README.md +2 -3
  2. app.py +9 -5
  3. train.py +112 -0
  4. transformer.py +0 -118
README.md CHANGED
@@ -48,10 +48,9 @@ This project implements a transformer-based language model using PyTorch. The mo
    git clone https://github.com/yourusername/transformer-model-training.git
    cd transformer-model-training
    ```
-
-2. Install the required packages:
+2. To train the model, run the training script:
    ```bash
-   pip install -r requirements.txt
+   python train.py
    ```
 
 ## Usage
app.py CHANGED
@@ -9,7 +9,8 @@ def load_model():
     config = GPTConfig()
     model = GPT(config)
     try:
-        model.load_state_dict(torch.load('trained_model_quantized.pt'))
+        # Load the model with map_location to handle CPU-only environments
+        model.load_state_dict(torch.load('trained_model_quantized.pt', map_location=torch.device('cpu')), strict=False)
         model.eval()  # Set the model to evaluation mode
         st.success("Model loaded successfully!")
     except Exception as e:
@@ -24,7 +25,7 @@ def load_tokenizer():
 def generate_text(model, tokenizer, input_text, length, num_sequences):
     # Encode the input text
     input_ids = tokenizer.encode(input_text)
-    input_tensor = torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension
+    input_tensor = torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension (shape: [1, T])
 
     generated_sequences = []
     for _ in range(num_sequences):
@@ -35,7 +36,10 @@ def generate_text(model, tokenizer, input_text, length, num_sequences):
             next_token_logits = logits[:, -1, :]  # Get the last token's logits
             next_token_probs = torch.softmax(next_token_logits, dim=-1)
             next_token = torch.multinomial(next_token_probs, num_samples=1)  # Sample from the distribution
-            input_tensor = torch.cat((input_tensor, next_token.unsqueeze(0)), dim=1)  # Append the new token
+
+            # Ensure the next_token has the correct shape for concatenation
+            next_token = next_token.view(1, -1)  # Reshape to [1, 1] if necessary
+            input_tensor = torch.cat((input_tensor, next_token), dim=1)  # Append the new token
 
         # Decode the generated tokens
         generated_sequences.append(tokenizer.decode(input_tensor[0].tolist()))
@@ -51,8 +55,8 @@ length = st.slider("Predict Additional Text of Length", 1, 50, 10)
 num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)
 
 if st.button("Generate"):
-    model = load_model()
-    tokenizer = load_tokenizer()
+    model = load_model()  # Load the model for inference
+    tokenizer = load_tokenizer()  # Load the tokenizer
     st.write("Generating text...")
     generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
     st.write("Text generation complete.")
train.py ADDED
@@ -0,0 +1,112 @@
+import os
+import time
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from transformer import GPT, GPTConfig, DataLoaderLite  # Import your model and data loader
+
+# Select the training device: CUDA if available, otherwise CPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Initialize the model and data loader
+config = GPTConfig()
+model = GPT(config)
+model.to(device)
+train_loader = DataLoaderLite(B=4, T=1024)
+
+# Define the optimizer
+optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+
+# Function to load the most recent checkpoint
+def load_latest_checkpoint(model):
+    checkpoint_file = 'checkpoint.pt'
+    if not os.path.exists(checkpoint_file):
+        return 0  # No checkpoint found, start from epoch 0
+
+    print(f'Loading checkpoint from {checkpoint_file}')
+    checkpoint = torch.load(checkpoint_file)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    return checkpoint['epoch']
+
+# Load the latest checkpoint if available
+start_epoch = load_latest_checkpoint(model)
+
+# Training loop
+num_epochs = 78
+
+# Start time tracking
+start_time = time.time()
+
+for epoch in range(start_epoch, num_epochs):  # Start from the loaded epoch
+    epoch_loss = 0.0  # Initialize epoch loss
+    num_steps = 0  # Initialize step counter for the epoch
+    last_loss = None  # Variable to store the last loss
+
+    # Calculate total steps for the progress bar
+    total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
+
+    # Use tqdm to create a progress bar
+    with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
+        for step in range(total_steps):  # Iterate over the number of steps
+            x, y = train_loader.next_batch()
+            x, y = x.to(device), y.to(device)
+            optimizer.zero_grad()
+            logits, loss = model(x, y)
+            loss.backward()
+            optimizer.step()
+
+            epoch_loss += loss.item()  # Accumulate loss
+            num_steps += 1  # Increment step counter
+            last_loss = loss.item()  # Store the last loss
+            pbar.update(1)  # Update progress bar
+
+            # Check if the loss is below the threshold
+            if last_loss < 0.099999:
+                print(f'Loss below threshold: {last_loss:.6f}')  # Print loss before breaking
+                break  # Exit the step loop if the loss condition is met
+
+    # Print the loss at the end of the epoch
+    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {last_loss:.6f}')
+
+    # Check if the loss condition was met to break out of the epoch loop
+    if last_loss < 0.099999:
+        print(f'Early stopping at epoch {epoch + 1} because the loss threshold was reached.')
+        break  # Exit the epoch loop if the loss condition is met
+
+    # Checkpointing: save the model and the current epoch after each epoch
+    checkpoint_path = 'checkpoint.pt'  # Save to a single checkpoint file
+    torch.save({
+        'epoch': epoch + 1,  # Save the current epoch number
+        'model_state_dict': model.state_dict(),  # Save the model state
+    }, checkpoint_path)
+    print(f'Checkpoint saved to {checkpoint_path}')
+
+# End time tracking
+end_time = time.time()
+training_duration = end_time - start_time
+
+# Convert training duration to minutes and seconds
+minutes = int(training_duration // 60)
+seconds = int(training_duration % 60)
+
+# Print the total training time in minutes and seconds
+print(f'Total training time: {minutes} minutes and {seconds} seconds')
+
+# After training, apply dynamic quantization and save the model with compression
+def save_model_with_quantization(model, file_path):
+    # Switch model to evaluation mode
+    model.eval()
+
+    # Apply dynamic quantization
+    quantized_model = torch.quantization.quantize_dynamic(
+        model,              # the model to be quantized
+        {nn.Linear},        # layers to quantize
+        dtype=torch.qint8   # quantization type
+    )
+
+    # Save the quantized model with compression
+    torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
+    print(f'Model saved to {file_path} with quantization and compression.')
+
+# Call this function after training the model
+save_model_with_quantization(model, 'trained_model_quantized.pt')
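
The checkpoint above stores only the epoch counter and the model weights, so a resumed run rebuilds AdamW with fresh moment estimates. A minimal sketch of a checkpoint that also round-trips optimizer state; the helper names here are illustrative, not part of this commit:

```python
import torch

def save_checkpoint(model, optimizer, epoch, path='checkpoint.pt'):
    # Persist everything needed to resume training where it stopped.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path='checkpoint.pt'):
    # map_location keeps the load working on CPU-only machines.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']  # epoch to resume from
```

Restoring `optimizer_state_dict` matters for AdamW in particular: its per-parameter first and second moments otherwise restart at zero, which can make the first steps after a resume noticeably noisier.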
transformer.py CHANGED
@@ -233,121 +233,3 @@ class DataLoaderLite:
         if self.current_position + (B * T + 1) > len(self.tokens):
             self.current_position = 0
         return x, y
-
-# Initialize the data loader with batch size 4 and sequence length 1024
-train_loader = DataLoaderLite(B=4, T=1024)
-
-# Initialize the model
-model = GPT(GPTConfig())
-model.to(device)
-
-# Print number of model parameters
-model.print_num_parameters()
-
-# Define the optimizer
-optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
-
-# Function to load the most recent checkpoint
-def load_latest_checkpoint(model):
-    # Find the checkpoint file
-    checkpoint_file = 'checkpoint.pt'
-    if not os.path.exists(checkpoint_file):
-        return 0  # No checkpoint found, start from epoch 0
-
-    print(f'Loading checkpoint from {checkpoint_file}')
-
-    # Load the model state and epoch number
-    checkpoint = torch.load(checkpoint_file)
-
-    # Ensure the checkpoint contains the expected keys
-    if 'model_state_dict' not in checkpoint or 'epoch' not in checkpoint:
-        raise KeyError("Checkpoint does not contain required keys.")
-
-    model.load_state_dict(checkpoint['model_state_dict'])
-
-    # Return the epoch number
-    return checkpoint['epoch']
-
-# Load the latest checkpoint if available
-start_epoch = load_latest_checkpoint(model)
-
-# NEW CODE: Training loop until loss is less than 0.099999
-loss = float('inf')  # Initialize loss to a large value
-num_epochs = 78  # Set the number of epochs to 78
-
-# Start time tracking
-start_time = time.time()
-
-for epoch in range(start_epoch, num_epochs):  # Start from the loaded epoch
-    epoch_loss = 0.0  # Initialize epoch loss
-    num_steps = 0  # Initialize step counter for the epoch
-    last_loss = None  # Variable to store the last loss
-
-    # Calculate total steps for the progress bar
-    total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
-
-    # Use tqdm to create a progress bar
-    with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
-        for step in range(total_steps):  # Iterate over the number of steps
-            x, y = train_loader.next_batch()
-            x, y = x.to(device), y.to(device)
-            optimizer.zero_grad()
-            logits, loss = model(x, y)
-            loss.backward()
-            optimizer.step()
-
-            epoch_loss += loss.item()  # Accumulate loss
-            num_steps += 1  # Increment step counter
-            last_loss = loss.item()  # Store the last loss
-            pbar.update(1)  # Update progress bar
-
-            # Check if the loss is below the threshold
-            if last_loss < 0.099999:
-                print(f'Loss below threshold: {last_loss:.6f}')  # Print loss before breaking
-                break  # Exit the loop if the loss condition is met
-
-    # Print the loss at the end of the epoch
-    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {last_loss:.6f}')
-
-    # Check if the loss condition was met to break out of the epoch loop
-    if last_loss < 0.099999:
-        print(f'Early stopping at epoch {epoch + 1} due to loss condition met.')
-        break  # Exit the epoch loop if the loss condition is met
-
-    # Checkpointing: Save the model and the current epoch after each epoch
-    checkpoint_path = 'checkpoint.pt'  # Save to a single checkpoint file
-    torch.save({
-        'epoch': epoch + 1,  # Save the current epoch number
-        'model_state_dict': model.state_dict(),  # Save the model state
-    }, checkpoint_path)
-    print(f'Checkpoint saved to {checkpoint_path}')
-
-# End time tracking
-end_time = time.time()
-training_duration = end_time - start_time
-
-# Convert training duration to minutes and seconds
-minutes = int(training_duration // 60)
-seconds = int(training_duration % 60)
-
-# Print the total training time in minute:second format
-print(f'Total training time: {minutes} minutes and {seconds} seconds')
-
-# After training your model, apply quantization and save it with compression
-def save_model_with_quantization(model, file_path):
-    # Switch model to evaluation mode
-    model.eval()
-
-    # Apply dynamic quantization
-    quantized_model = torch.quantization.quantize_dynamic(
-        model,              # the model to be quantized
-        {nn.Linear},        # layers to quantize
-        dtype=torch.qint8   # quantization type
-    )
-
-    # Save the quantized model with compression
-    torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
-    print(f'Model saved to {file_path} with quantization and compression.')
-
-# Call this function after training your model
-save_model_with_quantization(model, 'trained_model_quantized.pt')