import os
import time

import torch
import torch.nn as nn
from tqdm import tqdm

from transformer import GPT, GPTConfig, DataLoaderLite  # Import your model and data loader
# Initialize the model and data loader
config = GPTConfig()
model = GPT(config)
train_loader = DataLoaderLite(B=4, T=1024)

# Select a device and move the model onto it (the original script used
# `device` in the training loop without defining it, so define it here)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Function to load the most recent checkpoint
def load_latest_checkpoint(model):
    checkpoint_file = 'checkpoint.pt'
    if not os.path.exists(checkpoint_file):
        return 0  # No checkpoint found, start from epoch 0
    print(f'Loading checkpoint from {checkpoint_file}')
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['epoch']
# Load the latest checkpoint if available
start_epoch = load_latest_checkpoint(model)

# Training loop
num_epochs = 91

# Start time tracking
start_time = time.time()
for epoch in range(start_epoch, num_epochs):  # Start from the loaded epoch
    epoch_loss = 0.0  # Initialize epoch loss
    num_steps = 0  # Initialize step counter for the epoch
    last_loss = None  # Variable to store the last loss

    # Calculate total steps for the progress bar: one pass over the tokens in B*T chunks
    total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)

    # Use tqdm to create a progress bar
    with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
        for step in range(total_steps):  # Iterate over the number of steps
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits, loss = model(x, y)
            loss.backward()
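            # Optional (assumption, not part of the original script): gradient
            # clipping is commonly used to stabilize GPT-style training.
            # Uncomment to enable:
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)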
            optimizer.step()

            epoch_loss += loss.item()  # Accumulate loss
            num_steps += 1  # Increment step counter
            last_loss = loss.item()  # Store the last loss
            pbar.update(1)  # Update progress bar

            # Check if the loss is below the threshold
            if last_loss < 0.099999:
                print(f'Loss below threshold: {last_loss:.6f}')  # Print loss before breaking
                break  # Exit the step loop if the loss condition is met
    # Print the average loss at the end of the epoch (epoch_loss was
    # accumulated but never used in the original script)
    avg_loss = epoch_loss / max(num_steps, 1)
    print(f'Epoch {epoch + 1}/{num_epochs}, Avg Loss: {avg_loss:.6f}')

    # Check if the loss condition was met to break out of the epoch loop
    # (last_loss can be None if the loader held fewer than B*T tokens)
    if last_loss is not None and last_loss < 0.099999:
        print(f'Early stopping at epoch {epoch + 1} due to loss condition met.')
        break  # Exit the epoch loop if the loss condition is met
    # Checkpointing: save the model and the current epoch after each epoch
    checkpoint_path = 'checkpoint.pt'  # Save to a single checkpoint file
    torch.save({
        'epoch': epoch + 1,  # Save the current epoch number
        'model_state_dict': model.state_dict(),  # Save the model state
    }, checkpoint_path)
    print(f'Checkpoint saved to {checkpoint_path}')
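    # Note (assumption, not in the original script): to resume training with
    # identical optimizer behaviour, you could also store
    # 'optimizer_state_dict': optimizer.state_dict() in the dict above and
    # restore it in load_latest_checkpoint().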
# End time tracking
end_time = time.time()
training_duration = end_time - start_time

# Convert training duration to minutes and seconds
minutes = int(training_duration // 60)
seconds = int(training_duration % 60)

# Print the total training time in minute:second format
print(f'Total training time: {minutes} minutes and {seconds} seconds')
# After training your model, apply dynamic quantization and save the result
def save_model_with_quantization(model, file_path):
    # Dynamic quantization runs on CPU-only kernels, so move the model off
    # the GPU before converting
    model.to('cpu')

    # Switch model to evaluation mode
    model.eval()

    # Apply dynamic quantization
    quantized_model = torch.quantization.quantize_dynamic(
        model,              # the model to be quantized
        {nn.Linear},        # layers to quantize
        dtype=torch.qint8   # quantization type
    )

    # Save the quantized model using the zipfile-based serialization format
    torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
    print(f'Quantized model saved to {file_path}.')
# Call this function after training your model
save_model_with_quantization(model, 'trained_model_quantized.pt')
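# Loading sketch (assumption, not in the original script): a dynamically
# quantized state dict must be loaded into a model that has been quantized
# the same way, so rebuild and re-quantize before calling load_state_dict:
# inference_model = GPT(GPTConfig())
# inference_model.eval()
# inference_model = torch.quantization.quantize_dynamic(
#     inference_model, {nn.Linear}, dtype=torch.qint8)
# inference_model.load_state_dict(torch.load('trained_model_quantized.pt'))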