import json
import re
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm import tqdm


class TrieNode:
    """Node in the prefix tree (trie) for fast token matching.

    NOTE(review): nothing in this module currently builds or uses a trie.
    """

    def __init__(self):
        self.children = {}
        self.is_token = False
        self.token = None


class BPETokenizer:
    """Character-level byte-pair-encoding (BPE) tokenizer for preprocessed Hindi text.

    The base vocabulary is every character seen in the (preprocessed) training
    text. Training repeatedly merges frequent adjacent token pairs into new
    tokens until ``len(self.stoi)`` reaches ``vocab_size``.
    """

    # Upper bound on how many candidate pairs are merged per pair-counting pass.
    _MAX_MERGES_PER_PASS = 1000

    def __init__(self, vocab_size=5000):
        self.vocab_size = vocab_size
        self.chars = []  # Sorted unique characters of the training text
        self.stoi = {}  # Token string -> id
        self.itos = {}  # Id -> token string
        self.data = []  # Training text encoded as token ids
        # NOTE(review): these special tokens are empty strings, so they collapse to a
        # single ``stoi`` entry while occupying ids 0-3 in ``itos``. They look like
        # placeholders (e.g. "<pad>", "<unk>") that were lost upstream -- confirm the
        # intended values before relying on them.
        self.special_tokens = ["", "", "", ""]

        # Statistics tracking (one entry per merge after the initial values)
        self.stats = {
            "vocab_sizes": [],
            "data_sizes": [],
            "compression_ratios": [],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }
        self.original_length = 0
        self.max_token_length = 1

    def initialize_vocab(self, text):
        """Build the base character vocabulary from ``text`` and encode it into ``self.data``.

        The text is passed through ``preprocess_hindi_text`` first. Ids start with
        the special tokens, followed by the characters in sorted order.
        """
        text = preprocess_hindi_text(text)

        self.chars = sorted(set(text))
        all_tokens = self.special_tokens + self.chars

        self.stoi = {token: i for i, token in enumerate(all_tokens)}
        self.itos = dict(enumerate(all_tokens))

        self.data = [self.stoi[c] for c in text]
        self.original_length = len(self.data)

        self.stats["vocab_sizes"].append(len(self.stoi))
        self.stats["data_sizes"].append(len(self.data))
        self.stats["compression_ratios"].append(1.0)

    def get_digram_stats(self):
        """Count every adjacent (id, id) pair in ``self.data`` (overlapping pairs included)."""
        return Counter((int(a), int(b)) for a, b in zip(self.data, self.data[1:]))

    def replace_byte_pair_in_data(self, pair, new_token):
        """Return ``self.data`` with occurrences of ``pair`` replaced by ``new_token``.

        Matches are taken greedily left to right, so in a run like (a, a, a) only
        the first (a, a) is replaced. ``self.data`` itself is not modified.
        """
        first, second = pair
        data = self.data
        n = len(data)
        result = []
        i = 0
        while i < n:
            if i + 1 < n and data[i] == first and data[i + 1] == second:
                result.append(new_token)
                i += 2
            else:
                result.append(data[i])
                i += 1
        return result

    def encode_pair(self, pair):
        """Add the concatenation of the two tokens in ``pair`` as a new vocabulary entry.

        The new id is ``len(self.itos)``. If the merged string already exists in
        ``stoi`` it is repointed to the new id; the old id stays in ``itos``.

        Returns:
            int: The id assigned to the merged token.
        """
        pair_str = self.itos[pair[0]] + self.itos[pair[1]]
        next_idx = len(self.itos)
        self.stoi[pair_str] = next_idx
        self.itos[next_idx] = pair_str

        self.max_token_length = max(self.max_token_length, len(pair_str))
        return next_idx

    @staticmethod
    def _top_pairs(data, limit, min_count):
        """Return up to ``limit`` adjacent id pairs of ``data``, most frequent first.

        Pairs are counted over a sliding window (overlapping occurrences all
        count); pairs seen fewer than ``min_count`` times are dropped. Ties are
        ordered by ascending (first, second).
        """
        if len(data) < 2 or limit <= 0:
            return []
        # Pack each pair into a single int64 key so numpy can count them all in C.
        base = int(data.max()) + 1
        keys = data[:-1].astype(np.int64) * base + data[1:]
        unique_keys, counts = np.unique(keys, return_counts=True)
        frequent = counts >= min_count
        unique_keys, counts = unique_keys[frequent], counts[frequent]
        order = np.argsort(-counts, kind="stable")[:limit]
        return [(int(key) // base, int(key) % base) for key in unique_keys[order]]

    @staticmethod
    def _find_pair_starts(data, first, second):
        """Return start indices of non-overlapping (first, second) matches in ``data``.

        Overlaps are only possible when ``first == second``; they are resolved
        greedily left to right, matching ``replace_byte_pair_in_data``.
        """
        if len(data) < 2:
            return np.empty(0, dtype=np.intp)
        starts = np.flatnonzero((data[:-1] == first) & (data[1:] == second))
        if first == second and starts.size > 1:
            kept = []
            next_free = -1
            for start in starts.tolist():
                if start >= next_free:
                    kept.append(start)
                    next_free = start + 2
            starts = np.asarray(kept, dtype=np.intp)
        return starts

    @staticmethod
    def _apply_merge(data, starts, new_idx):
        """Return a copy of ``data`` where each pair starting at ``starts`` becomes ``new_idx``."""
        merged = data.copy()
        merged[starts] = new_idx
        keep = np.ones(len(data), dtype=bool)
        keep[starts + 1] = False  # drop the second element of every merged pair
        return merged[keep]

    def train(self, texts, min_frequency=2, print_interval=500):
        """Learn BPE merges from ``texts`` until the vocabulary reaches ``vocab_size``.

        To avoid recounting all pairs after every merge, each pass counts pairs
        once and then merges up to ``_MAX_MERGES_PER_PASS`` of the most frequent
        ones in order. Because those counts go stale as merges happen, every pair
        is recounted just before it is merged; a pair whose occurrences were
        consumed by an earlier merge in the same pass, or that falls below
        ``min_frequency``, is skipped without adding a vocabulary entry.

        Args:
            texts: Iterable of strings; they are joined with spaces and preprocessed.
            min_frequency: Minimum number of non-overlapping occurrences a pair
                needs in order to be merged.
            print_interval: Flush statistics and print progress every this many
                merges. A falsy value flushes after every merge and prints nothing.
        """
        print("Initializing vocabulary...")
        self.initialize_vocab(" ".join(texts))
        data = np.asarray(self.data, dtype=np.int32)

        threshold = max(1, min_frequency)
        remaining = max(0, self.vocab_size - len(self.stoi))
        batch_size = max(1, min(self._MAX_MERGES_PER_PASS, remaining))
        flush_every = print_interval or 1
        stats_buffer = []

        pbar = tqdm(total=remaining, desc="Training BPE", position=0)
        while len(self.stoi) < self.vocab_size:
            candidates = self._top_pairs(data, batch_size, threshold)
            if not candidates:
                break  # no pair occurs often enough to be merged

            merged_this_pass = False
            for first, second in candidates:
                if len(self.stoi) >= self.vocab_size:
                    break

                # Recount against the current data: earlier merges in this pass may
                # have consumed occurrences, and runs like (a, a, a) contain fewer
                # non-overlapping matches than sliding-window counts suggest.
                starts = self._find_pair_starts(data, first, second)
                count = int(starts.size)
                if count < threshold:
                    continue

                new_idx = self.encode_pair((first, second))
                data = self._apply_merge(data, starts, new_idx)
                merged_this_pass = True

                stats_buffer.append({
                    "vocab_size": len(self.stoi),
                    "data_size": len(data),
                    "merge_count": count,
                    "new_token": self.itos[new_idx],
                    "max_token_length": self.max_token_length,
                })
                pbar.update(1)

                if len(stats_buffer) >= flush_every:
                    self._update_stats_batch(stats_buffer)
                    if print_interval:
                        last = stats_buffer[-1]
                        self.print_progress(
                            len(self.stoi),
                            last["new_token"],
                            last["merge_count"],
                            data_size=len(data),
                        )
                    stats_buffer = []

            if not merged_this_pass:
                # Data is unchanged, so counting again would yield the same
                # candidates; stop instead of looping forever.
                break

        if stats_buffer:
            self._update_stats_batch(stats_buffer)
        pbar.close()

        self.data = data.tolist()

        final_ratio = self.original_length / len(self.data) if self.data else 1.0
        print(f"\nTraining completed. Final vocabulary size: {len(self.stoi)}")
        print(f"Final compression ratio: {final_ratio:.2f}")

    def _update_stats_batch(self, stats_buffer):
        """Append the per-merge records in ``stats_buffer`` to ``self.stats``."""
        if not stats_buffer:
            return

        self.stats["vocab_sizes"].extend(s["vocab_size"] for s in stats_buffer)
        self.stats["data_sizes"].extend(s["data_size"] for s in stats_buffer)
        self.stats["merge_counts"].extend(s["merge_count"] for s in stats_buffer)
        self.stats["tokens_created"].extend(s["new_token"] for s in stats_buffer)
        self.stats["compression_ratios"].extend(
            self.original_length / s["data_size"] for s in stats_buffer
        )
        # Use the length recorded at merge time; fall back to the current maximum
        # for records that do not carry one.
        self.stats["max_token_lengths"].extend(
            s.get("max_token_length", self.max_token_length) for s in stats_buffer
        )

    def print_progress(self, iteration, new_token, merge_count, data_size=None):
        """Print a progress summary for the most recent merge.

        Args:
            iteration: Number shown as the iteration label (``train`` passes the
                current vocabulary size).
            new_token: Token string created by the merge.
            merge_count: Number of occurrences that were merged.
            data_size: Current length of the encoded data. Defaults to
                ``len(self.data)``, which ``train`` only updates when it finishes.
        """
        if data_size is None:
            data_size = len(self.data)
        ratios = self.stats["compression_ratios"]
        ratio = ratios[-1] if ratios else 1.0

        print(f"\nIteration {iteration:,}")
        print(f"Created token: '{new_token}' (merged {merge_count:,} times)")
        print(f"Current vocabulary size: {len(self.stoi):,}")
        print(f"Current data size: {data_size:,}")
        print(f"Current compression ratio: {ratio:.2f}")
        print("-" * 80)

    def plot_statistics(self):
        """Plot training statistics in a 2x2 grid and show the figure."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # Plot 1: Vocabulary Size vs Data Size
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_xlabel("Vocabulary Size")
        ax1.set_ylabel("Dataset Size")
        ax1.set_title("Vocabulary Size vs Dataset Size")

        # Plot 2: Compression Ratio vs Vocabulary Size
        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_xlabel("Vocabulary Size")
        ax2.set_ylabel("Compression Ratio")
        ax2.set_title("Compression Ratio vs Vocabulary Size")

        # Plot 3: Merge Counts Distribution
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_xlabel("Number of Merges")
            ax3.set_ylabel("Frequency")
            ax3.set_title("Distribution of Merge Counts")

        # Plot 4: Token Lengths Over Time
        if self.stats["tokens_created"]:
            token_lengths = [len(token) for token in self.stats["tokens_created"]]
            ax4.plot(range(len(token_lengths)), token_lengths)
            ax4.set_xlabel("Merge Operation")
            ax4.set_ylabel("New Token Length")
            ax4.set_title("Token Length Evolution")

        plt.tight_layout()
        plt.show()

    def save(self, filepath: str) -> None:
        """Save tokenizer state (vocabulary, stats, settings) to a UTF-8 JSON file."""
        state = {
            "vocab_size": self.vocab_size,
            "stoi": self.stoi,
            "itos": self.itos,  # JSON turns the int keys into strings; load() converts back
            "max_token_length": self.max_token_length,
            "stats": self.stats,
            "special_tokens": self.special_tokens,
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, filepath: str) -> "BPETokenizer":
        """Load tokenizer state from a JSON file written by ``save``.

        Files written before ``vocab_size`` was saved keep the constructor default.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            state = json.load(f)

        instance = cls()
        if "vocab_size" in state:
            instance.vocab_size = int(state["vocab_size"])

        # JSON object keys are always strings; restore the integer ids.
        instance.itos = {int(k): v for k, v in state["itos"].items()}
        instance.stoi = {k: int(v) for k, v in state["stoi"].items()}
        instance.max_token_length = state.get(
            "max_token_length",
            max((len(token) for token in instance.stoi), default=1),
        )
        instance.stats = state.get("stats", instance.stats)
        instance.special_tokens = state.get("special_tokens", instance.special_tokens)

        # Debug info
        print(f"Loaded vocabulary size: {len(instance.itos)}")
        print(f"Max token ID: {max(instance.itos, default=None)}")
        print(f"Sample tokens: {list(instance.itos.items())[:5]}")

        return instance

    def encode(self, text: str):
        """Convert text to token ids by greedy longest-match over the vocabulary.

        The text is preprocessed first. Whitespace is tokenized like any other
        character, so when every character is in the vocabulary,
        ``decode(encode(t)[0])`` equals ``preprocess_hindi_text(t)``. Characters
        with no vocabulary entry are skipped. Special tokens are never emitted.

        Returns:
            tuple[list[int], list[str]]: The token ids and the matching token strings.
        """
        text = preprocess_hindi_text(text)
        specials = set(self.special_tokens)
        max_len = max(1, self.max_token_length)

        token_ids = []
        tokens = []
        pos = 0
        n = len(text)
        while pos < n:
            # Try the longest candidate first; at most max_len dict lookups per token.
            for length in range(min(max_len, n - pos), 0, -1):
                piece = text[pos:pos + length]
                idx = self.stoi.get(piece)
                if idx is not None and piece not in specials:
                    tokens.append(piece)
                    token_ids.append(idx)
                    pos += length
                    break
            else:
                # Unknown character: skip it and continue
                pos += 1

        return token_ids, tokens

    def decode(self, token_ids: list) -> str:
        """Convert token ids back to text.

        Ids may be ints or numeric strings. Unknown or malformed ids and special
        tokens are skipped. Tokens are concatenated (they carry their own spaces)
        and runs of whitespace are collapsed to single spaces.
        """
        specials = set(self.special_tokens)
        pieces = []
        for raw in token_ids:
            try:
                idx = int(raw) if isinstance(raw, str) else raw
                token = self.itos.get(idx)
            except (ValueError, TypeError):
                continue
            if token is None or token in specials:
                continue
            pieces.append(token)

        return " ".join("".join(pieces).split())


def download_dataset(url, filepath, max_size_gb=2, timeout=30):
    """
    Download up to ``max_size_gb`` of ``url`` into ``filepath``, resuming a partial file.

    If ``filepath`` already holds data, an HTTP ``Range`` request asks for the
    rest. If the server ignores the range (status 200 instead of 206) it sends
    the whole file, so the download restarts from scratch instead of appending a
    duplicate copy.

    Args:
        url (str): URL of the dataset
        filepath (Path | str): Path where the file should be saved
        max_size_gb (float): Maximum size to download, in units of 1024**3 bytes
        timeout (float): Connect/read timeout in seconds passed to ``requests.get``

    Raises:
        requests.exceptions.RequestException: On network or HTTP errors. Any
            partial file is left in place so a later call can resume it.
    """
    filepath = Path(filepath)
    max_size_bytes = int(max_size_gb * 1024 * 1024 * 1024)

    current_size = filepath.stat().st_size if filepath.exists() else 0
    if current_size >= max_size_bytes:
        print(f"Already have {max_size_gb}GB of data, skipping download.")
        return

    print(f"Downloading first {max_size_gb}GB from {url}")
    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    try:
        with requests.get(url, stream=True, headers=headers, timeout=timeout) as response:
            if current_size > 0 and response.status_code == 416:
                # Requested range starts at/after the end of the remote file.
                print("Remote file is already fully downloaded.")
                return
            response.raise_for_status()
            if current_size > 0 and response.status_code != 206:
                print("Server ignored the resume request; restarting download from scratch.")
                current_size = 0

            content_length = int(response.headers.get("content-length", 0))
            total_size = (
                min(content_length + current_size, max_size_bytes) if content_length else None
            )
            mode = "ab" if current_size > 0 else "wb"
            written = current_size

            with open(filepath, mode) as file, tqdm(
                desc="Downloading",
                initial=current_size,
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue  # keep-alive chunk
                    # Trim the last chunk so the file never exceeds the limit.
                    chunk = chunk[: max_size_bytes - written]
                    size = file.write(chunk)
                    written += size
                    progress_bar.update(size)

                    if written >= max_size_bytes:
                        print(f"\nReached {max_size_gb}GB limit, stopping download.")
                        break

    except requests.exceptions.RequestException as e:
        print(f"Error during download: {e}")
        if filepath.exists():
            print("Partial download remains available for resume.")
        raise


def prepare_dataset(input_path, sample_size=None, max_lines=None):
    """
    Read the non-empty lines of the raw dataset.

    Args:
        input_path (Path): Path to the raw dataset
        sample_size (int, optional): Stop after collecting this many non-empty
            lines. These are the first such lines in the file, not a random
            sample. If None (or 0), keep every non-empty line that is read.
        max_lines (int, optional): Maximum number of lines (empty or not) to read
            from the file. If None (or 0), read the whole file.

    Returns:
        list: The kept lines, unmodified (trailing newlines included)
    """
    print("Reading and preparing dataset...")
    lines = []

    with open(input_path, "r", encoding="utf-8") as file:
        for line_number, line in enumerate(tqdm(file, desc="Reading lines")):
            if max_lines and line_number >= max_lines:
                break
            if not line.strip():
                continue
            lines.append(line)
            if sample_size and len(lines) >= sample_size:
                break

    return lines


# Everything outside the Devanagari block (U+0900-U+097F), whitespace and a few
# punctuation marks is removed by preprocess_hindi_text.
_NON_HINDI_RE = re.compile(r"[^\u0900-\u097F\s।,.!?\-]")
# ASCII digits and Devanagari digits.
_DIGIT_RE = re.compile(r"[0-9०-९]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_hindi_text(text):
    """
    Clean Hindi text: keep only Devanagari, whitespace and basic punctuation.

    Digits (ASCII and Devanagari) are removed, the danda "।" becomes ".", and
    whitespace runs collapse to single spaces with the ends stripped. Applying
    the function twice gives the same result as applying it once.

    Args:
        text (str): Raw Hindi text input

    Returns:
        str: Cleaned and normalized text
    """
    text = _NON_HINDI_RE.sub("", text)
    text = _DIGIT_RE.sub("", text)
    text = text.replace("।", ".")
    return _WHITESPACE_RE.sub(" ", text).strip()


def calculate_compression_ratio(tokenizer, corpus_path):
    """
    Calculate characters-per-token for the tokenizer on a UTF-8 text corpus.

    Args:
        tokenizer (BPETokenizer): Trained BPE tokenizer
        corpus_path (str): Path to the preprocessed corpus (one text per line)

    Returns:
        float: Compression ratio (characters/tokens); 0.0 when the corpus
            yields no tokens at all.
    """
    total_chars = 0
    total_tokens = 0
    with open(corpus_path, "r", encoding="utf-8") as file:
        for line in file:
            total_chars += len(line)
            token_ids, _ = tokenizer.encode(line)
            total_tokens += len(token_ids)

    if total_tokens == 0:
        return 0.0
    return total_chars / total_tokens


def encode_text(tokenizer, text):
    """Preprocess ``text`` and encode it; returns ``(token_ids, tokens)``."""
    cleaned_text = preprocess_hindi_text(text)
    return tokenizer.encode(cleaned_text)


def decode_text(tokenizer, token_ids):
    """Decode ``token_ids`` back to a string with ``tokenizer``."""
    return tokenizer.decode(token_ids)


def test_tokenizer(tokenizer, test_text):
    """
    Tests the tokenizer by encoding and decoding sample text.

    Args:
        tokenizer (BPETokenizer): Trained BPE tokenizer
        test_text (str): Sample text for testing
    """
    print("\nTokenizer Test:")
    print("-" * 50)
    print(f"Original Text: {test_text}")

    # Encode
    token_ids, tokens = encode_text(tokenizer, test_text)
    print(f"\nTokens: {tokens}")
    print(f"Token IDs: {token_ids}")

    # Decode
    decoded_text = decode_text(tokenizer, token_ids)
    print(f"\nDecoded Text: {decoded_text}")


def main():
    """Download the Hindi corpus, preprocess it, then train, save and test a BPE tokenizer.

    Returns:
        BPETokenizer | None: The trained tokenizer, or None if the dataset could
            not be obtained or read.
    """
    # Create output directory if it doesn't exist
    output_dir = Path("output")
    output_dir.mkdir(exist_ok=True)

    # Dataset URL and paths
    dataset_url = "https://objectstore.e2enetworks.net/ai4b-public-nlu-nlg/v1-indiccorp/hi.txt"
    raw_dataset_path = Path("raw_hindi_dataset.txt")
    preprocessed_path = output_dir / "preprocessed_hindi.txt"
    max_download_gb = 10

    # Step 1: Download dataset if it doesn't exist or is too small
    if (
        not raw_dataset_path.exists()
        or raw_dataset_path.stat().st_size < max_download_gb * 1024 * 1024 * 1024
    ):
        print(f"Step 1: Downloading dataset ({max_download_gb}GB limit)...")
        try:
            download_dataset(dataset_url, raw_dataset_path, max_size_gb=max_download_gb)
        except requests.exceptions.RequestException as e:
            print(f"Error downloading dataset: {e}")
            if not raw_dataset_path.exists():
                return
            print("Continuing with existing partial download...")
    else:
        print("Sufficient dataset already exists, skipping download.")

    # Step 2: Prepare and preprocess the dataset
    print("Step 2: Preprocessing dataset...")
    try:
        # Keep the first 2 million non-empty lines found within the first 3 million lines
        raw_data = prepare_dataset(
            raw_dataset_path,
            sample_size=2_000_000,
            max_lines=3_000_000,
        )
    except FileNotFoundError:
        print(f"Error: Input file '{raw_dataset_path}' not found!")
        return
    except Exception as e:
        print(f"Error preparing dataset: {e}")
        return

    # Preprocess the text
    print("Cleaning and normalizing text...")
    preprocessed_data = [preprocess_hindi_text(line) for line in tqdm(raw_data)]

    # Save the preprocessed dataset
    with open(preprocessed_path, "w", encoding="utf-8") as file:
        file.write("\n".join(preprocessed_data))

    # Initialize and train our custom BPE tokenizer
    tokenizer = BPETokenizer(vocab_size=5000)
    tokenizer.train(preprocessed_data, min_frequency=2)

    # Save the tokenizer
    config_path = output_dir / "hindi_encoder.json"
    tokenizer.save(str(config_path))

    # Test the tokenizer
    test_text = "फिर पानी भी कम मात्रा में"
    test_tokenizer(tokenizer, test_text)

    return tokenizer


def load_tokenizer(config_path):
    """
    Loads a previously trained tokenizer from a configuration file.

    Args:
        config_path (str): Path to the tokenizer configuration file

    Returns:
        BPETokenizer: Loaded tokenizer
    """
    return BPETokenizer.load(config_path)


if __name__ == "__main__":
    main()