diff --git a/dataset.py b/dataset.py index 185fc66..7084a5d 100644 --- a/dataset.py +++ b/dataset.py @@ -100,7 +100,7 @@ def decode(self, ids: list) -> str: characters = [] for id in ids: characters.append(self.id_to_char[id]) - return ''.join(characters) + return "".join(characters) def get_batch(data: torch.Tensor, block_size: int, batch_size: int): @@ -152,8 +152,10 @@ def get_batch(data: torch.Tensor, block_size: int, batch_size: int): y_list = [] for pos in positions: - x_list.append(data[pos : pos + block_size]) # Input: chars 0 to n-1 - y_list.append(data[pos + 1 : pos + block_size + 1]) # Target: chars 1 to n (shifted by 1) + x_list.append(data[pos : pos + block_size]) # Input: chars 0 to n-1 + y_list.append( + data[pos + 1 : pos + block_size + 1] + ) # Target: chars 1 to n (shifted by 1) # 3. Stack into batch tensors: (batch_size, block_size) x = torch.stack(x_list) @@ -179,7 +181,7 @@ def load_data(block_size: int = 256, train_split: float = 0.9): download_shakespeare() # 2. Load the text file - with open(DATA_PATH, 'r', encoding='utf-8') as file: + with open(DATA_PATH, "r", encoding="utf-8") as file: text = file.read() print(f"\nDataset size: {len(text):,} characters") diff --git a/generate.py b/generate.py index 59d6c4d..b4d94d5 100644 --- a/generate.py +++ b/generate.py @@ -18,6 +18,7 @@ # Load Model # ============================================================================== + def load_model(checkpoint_path: str = "checkpoints/model.pt"): """ Load a trained GPT model from checkpoint. @@ -39,26 +40,26 @@ def load_model(checkpoint_path: str = "checkpoints/model.pt"): checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # 3. Get the model configuration that was saved during training - config = checkpoint['config'] + config = checkpoint["config"] # 4. Create the tokenizer (we need the same vocabulary as training) download_shakespeare() - with open(DATA_PATH, 'r', encoding='utf-8') as file: + with open(DATA_PATH, "r", encoding="utf-8") as file: text = file.read() tokenizer = CharacterTokenizer(text) # 5. Create the model with the saved configuration model = GPT( - vocab_size=config['vocab_size'], - embedding_dim=config['embedding_dim'], - num_heads=config['num_heads'], - num_layers=config['num_layers'], - block_size=config['block_size'], - dropout=0.0 # No dropout during generation + vocab_size=config["vocab_size"], + embedding_dim=config["embedding_dim"], + num_heads=config["num_heads"], + num_layers=config["num_layers"], + block_size=config["block_size"], + dropout=0.0, # No dropout during generation ) # 6. Load the trained weights into the model - model.load_state_dict(checkpoint['model_state_dict']) + model.load_state_dict(checkpoint["model_state_dict"]) # 7. Move model to device and set to evaluation mode model = model.to(device) @@ -76,8 +77,16 @@ def load_model(checkpoint_path: str = "checkpoints/model.pt"): # Generate Text # ============================================================================== + @torch.no_grad() -def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, temperature: float = 0.8): +def generate( + model, + tokenizer, + device, + prompt: str, + max_tokens: int = 500, + temperature: float = 0.8, +): """ Generate text given a starting prompt. @@ -103,13 +112,13 @@ def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, tempe # 1. Convert prompt text to token IDs prompt_ids = tokenizer.encode(prompt) input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device) - input_ids = input_ids.unsqueeze(0) # Add batch dimension: shape becomes (1, seq_len) + input_ids = input_ids.unsqueeze( + 0 + ) # Add batch dimension: shape becomes (1, seq_len) # 2. Generate new tokens using the model's generate method output_ids = model.generate( - input_ids=input_ids, - max_new_tokens=max_tokens, - temperature=temperature + input_ids=input_ids, max_new_tokens=max_tokens, temperature=temperature ) # 3. Convert token IDs back to text @@ -123,7 +132,6 @@ def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, tempe # ============================================================================== if __name__ == "__main__": - # 1. Load the trained model print("=" * 60) print("Shakespeare GPT - Text Generation") @@ -152,7 +160,7 @@ def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, tempe device=device, prompt=prompt, max_tokens=300, - temperature=0.8 + temperature=0.8, ) print(generated_text) @@ -169,7 +177,7 @@ def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, tempe try: prompt = input("\nYour prompt: ") - if prompt.lower() in ['quit', 'exit', 'q']: + if prompt.lower() in ["quit", "exit", "q"]: print("Farewell!") break @@ -182,7 +190,7 @@ def generate(model, tokenizer, device, prompt: str, max_tokens: int = 500, tempe device=device, prompt=prompt, max_tokens=500, - temperature=0.8 + temperature=0.8, ) print("\n" + generated_text) diff --git a/model.py b/model.py index ab27128..bd0a58f 100644 --- a/model.py +++ b/model.py @@ -75,10 +75,9 @@ class TransformerBlock(nn.Module): # The key relationship: head_size = embedding_dim / num_heads = 64 # This head_size=64 is consistent across most GPT models. # ------------------------------------------------------------------------- - def __init__(self, - embedding_dim: int = 384, - num_heads: int = 6, - dropout: float = 0.1): + def __init__( + self, embedding_dim: int = 384, num_heads: int = 6, dropout: float = 0.1 + ): super().__init__() # 2. Create the first Layer Normalization @@ -98,7 +97,7 @@ def __init__(self, embed_dim=embedding_dim, num_heads=num_heads, dropout=dropout, - batch_first=True # Input shape: (batch, sequence, embedding) + batch_first=True, # Input shape: (batch, sequence, embedding) ) # 4. Create the second Layer Normalization @@ -117,9 +116,9 @@ def __init__(self, # patterns, then compresses back to the original size. self.mlp = nn.Sequential( nn.Linear(embedding_dim, 4 * embedding_dim), # Expand: 384 β†’ 1536 - nn.GELU(), # Activation function + nn.GELU(), # Activation function nn.Linear(4 * embedding_dim, embedding_dim), # Project back: 1536 β†’ 384 - nn.Dropout(dropout) # Regularization + nn.Dropout(dropout), # Regularization ) # 6. Create the forward method @@ -140,7 +139,7 @@ def forward(self, x: torch.Tensor, causal_mask: torch.Tensor) -> torch.Tensor: key=x_norm, value=x_norm, attn_mask=causal_mask, - is_causal=False # We provide our own mask + is_causal=False, # We provide our own mask ) x = x + attn_output # Residual connection @@ -165,13 +164,15 @@ class GPT(nn.Module): """ # 1. Initialize the class with hyperparameters - def __init__(self, - vocab_size: int, - embedding_dim: int = 384, - num_heads: int = 6, - num_layers: int = 6, - block_size: int = 256, - dropout: float = 0.1): + def __init__( + self, + vocab_size: int, + embedding_dim: int = 384, + num_heads: int = 6, + num_layers: int = 6, + block_size: int = 256, + dropout: float = 0.1, + ): super().__init__() # 2. Store block_size for generation @@ -210,30 +211,28 @@ def __init__(self, # (e.g., start of sentence behaves differently from middle) # ----------------------------------------------------------------------- self.token_embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim + num_embeddings=vocab_size, embedding_dim=embedding_dim ) # 4. Create Position Embedding layer # Each position (0, 1, 2, ..., block_size-1) gets its own learnable vector # Same as nn.Parameter(torch.randn(block_size, embedding_dim)) in ViT self.position_embedding = nn.Embedding( - num_embeddings=block_size, - embedding_dim=embedding_dim + num_embeddings=block_size, embedding_dim=embedding_dim ) # 5. Create Embedding Dropout self.dropout = nn.Dropout(dropout) # 6. Create stack of Transformer Blocks - self.blocks = nn.ModuleList([ - TransformerBlock( - embedding_dim=embedding_dim, - num_heads=num_heads, - dropout=dropout - ) - for _ in range(num_layers) - ]) + self.blocks = nn.ModuleList( + [ + TransformerBlock( + embedding_dim=embedding_dim, num_heads=num_heads, dropout=dropout + ) + for _ in range(num_layers) + ] + ) # 7. Create Final Layer Normalization self.ln_final = nn.LayerNorm(embedding_dim) @@ -285,8 +284,7 @@ def __init__(self, # torch.triu creates an upper triangular matrix of True values. # True = masked (blocked), False = allowed to attend causal_mask = torch.triu( - torch.ones(block_size, block_size, dtype=torch.bool), - diagonal=1 + torch.ones(block_size, block_size, dtype=torch.bool), diagonal=1 ) # register_buffer: saves tensor with model & moves it to GPU with model, # but it's NOT a learnable parameter (optimizer won't update it) @@ -320,9 +318,7 @@ def _init_weights(self, module): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) # 12. Create the forward method - def forward(self, - input_ids: torch.Tensor, - targets: torch.Tensor = None) -> tuple: + def forward(self, input_ids: torch.Tensor, targets: torch.Tensor = None) -> tuple: """ Forward pass of the GPT model. @@ -397,7 +393,9 @@ def forward(self, # 21. Generate method for text generation @torch.no_grad() - def generate(self, input_ids: torch.Tensor, max_new_tokens: int, temperature: float = 1.0): + def generate( + self, input_ids: torch.Tensor, max_new_tokens: int, temperature: float = 1.0 + ): """ Generate new tokens one at a time (autoregressive generation). @@ -410,7 +408,6 @@ def generate(self, input_ids: torch.Tensor, max_new_tokens: int, temperature: fl temperature: Controls randomness (0.5=predictable, 1.0=normal, 1.5=creative) """ for _ in range(max_new_tokens): - # 22. If sequence is longer than block_size, crop to last block_size tokens # .size(1) gets the sequence length (dim 0=batch, dim 1=sequence) # -self.block_size uses negative indexing to take the LAST 256 tokens @@ -419,7 +416,7 @@ def generate(self, input_ids: torch.Tensor, max_new_tokens: int, temperature: fl if input_ids.size(1) <= self.block_size: current_input = input_ids else: - current_input = input_ids[:, -self.block_size:] + current_input = input_ids[:, -self.block_size :] # 23. Get model predictions logits, _ = self.forward(current_input) @@ -457,11 +454,7 @@ def generate(self, input_ids: torch.Tensor, max_new_tokens: int, temperature: fl # 2. Create model model = GPT( - vocab_size=65, - embedding_dim=384, - num_heads=6, - num_layers=6, - block_size=256 + vocab_size=65, embedding_dim=384, num_heads=6, num_layers=6, block_size=256 ).to(device) # 3. Create dummy input diff --git a/train.py b/train.py index ff8da1b..5191ac6 100644 --- a/train.py +++ b/train.py @@ -16,27 +16,45 @@ from model import GPT from dataset import load_data, get_batch +from rich.console import Console +from rich.progress import ( + Progress, + SpinnerColumn, + BarColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, +) +from rich.table import Table +from rich.panel import Panel +from rich import box +from rich.rule import Rule + +console = Console() + # ============================================================================== # Hyperparameters (Settings) # ============================================================================== # Model architecture -EMBEDDING_DIM = 384 # Size of embeddings (how big each vector is) -NUM_HEADS = 6 # Number of attention heads -NUM_LAYERS = 6 # Number of transformer blocks -BLOCK_SIZE = 256 # Maximum sequence length -DROPOUT = 0.1 # Dropout rate for regularization +EMBEDDING_DIM = 384 # Size of embeddings (how big each vector is) +NUM_HEADS = 6 # Number of attention heads +NUM_LAYERS = 6 # Number of transformer blocks +BLOCK_SIZE = 256 # Maximum sequence length +DROPOUT = 0.1 # Dropout rate for regularization # Training settings -BATCH_SIZE = 64 # Number of sequences per batch -MAX_ITERS = 5000 # Total training iterations -EVAL_INTERVAL = 500 # Evaluate every N iterations -LEARNING_RATE = 3e-4 # Learning rate -WARMUP_ITERS = 100 # Warmup iterations (gradually increase LR) +BATCH_SIZE = 64 # Number of sequences per batch +MAX_ITERS = 5000 # Total training iterations +EVAL_INTERVAL = 500 # Evaluate every N iterations +LEARNING_RATE = 3e-4 # Learning rate +WARMUP_ITERS = 100 # Warmup iterations (gradually increase LR) # System -DEVICE = "mps" if torch.backends.mps.is_available() else "cpu" +DEVICE = "cuda" if torch.backends.mps.is_available() else "cpu" CHECKPOINT_PATH = "checkpoints/model.pt" @@ -44,6 +62,7 @@ # Learning Rate Schedule (with Warmup) # ============================================================================== + def get_learning_rate(iteration): """ Learning rate with warmup. @@ -75,10 +94,8 @@ def get_learning_rate(iteration): Iteration 100+: LR stays at 0.0003 (constant) """ if iteration < WARMUP_ITERS: - # Warmup: linearly increase from 0 to LEARNING_RATE return LEARNING_RATE * (iteration / WARMUP_ITERS) else: - # After warmup: use constant learning rate return LEARNING_RATE @@ -86,24 +103,22 @@ def get_learning_rate(iteration): # Evaluation Function # ============================================================================== + @torch.no_grad() def evaluate(model, train_data, val_data): """Calculate average loss on training and validation data.""" model.eval() - + results = {} for name, data in [("train", train_data), ("val", val_data)]: total_loss = 0.0 - - # Run 100 batches to get a good estimate for _ in range(100): x, y = get_batch(data, BLOCK_SIZE, BATCH_SIZE) x, y = x.to(DEVICE), y.to(DEVICE) _, loss = model(x, y) total_loss += loss.item() - results[name] = total_loss / 100 - + model.train() return results @@ -112,137 +127,275 @@ def evaluate(model, train_data, val_data): # Sample Generation (to see progress) # ============================================================================== + @torch.no_grad() def generate_sample(model, tokenizer): """Generate a text sample to see how the model is learning.""" model.eval() - + prompt = "ROMEO:" prompt_ids = tokenizer.encode(prompt) input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=DEVICE).unsqueeze(0) - + output_ids = model.generate(input_ids, max_new_tokens=200, temperature=0.8) - + model.train() return tokenizer.decode(output_ids[0]) +# ============================================================================== +# Rich UI Helpers +# ============================================================================== + + +def format_time(seconds): + """Format seconds into human-readable HH:MM:SS.""" + h = int(seconds // 3600) + m = int((seconds % 3600) // 60) + s = int(seconds % 60) + if h > 0: + return f"{h:02d}:{m:02d}:{s:02d}" + return f"{m:02d}:{s:02d}" + + +def build_stats_table(history): + """Build a rich table showing training history.""" + table = Table( + title="πŸ“Š Training History", + box=box.ROUNDED, + border_style="bright_blue", + header_style="bold cyan", + show_lines=True, + expand=False, + ) + table.add_column("Iteration", style="bold white", justify="right", width=10) + table.add_column("Train Loss", style="green", justify="center", width=12) + table.add_column("Val Loss", style="yellow", justify="center", width=12) + table.add_column("LR", style="magenta", justify="center", width=10) + table.add_column("Elapsed", style="cyan", justify="center", width=10) + + for row in history: + # Color val loss: green if improving, red if worsening + val_color = "bright_green" + if len(history) > 1: + prev_val = ( + history[-2]["val_loss"] if history[-2] != row else row["val_loss"] + ) + val_color = "bright_green" if row["val_loss"] <= prev_val else "bright_red" + + table.add_row( + f"{row['iter']:,}", + f"{row['train_loss']:.4f}", + f"[{val_color}]{row['val_loss']:.4f}[/{val_color}]", + f"{row['lr']:.2e}", + format_time(row["elapsed"]), + ) + + return table + + +def build_info_panel(iteration, loss, lr, elapsed, iters_per_sec): + """Build a live status panel.""" + remaining_iters = MAX_ITERS - iteration + eta = remaining_iters / iters_per_sec if iters_per_sec > 0 else 0 + progress_pct = 100 * iteration / MAX_ITERS + + info = ( + f"[bold cyan]Iter:[/bold cyan] {iteration:,} / {MAX_ITERS:,}\n" + f"[bold cyan]Progress:[/bold cyan] {progress_pct:.1f}%\n" + f"[bold cyan]Loss:[/bold cyan] [bold yellow]{loss:.4f}[/bold yellow]\n" + f"[bold cyan]LR:[/bold cyan] [magenta]{lr:.2e}[/magenta]\n" + f"[bold cyan]Speed:[/bold cyan] {iters_per_sec:.1f} iter/s\n" + f"[bold cyan]Elapsed:[/bold cyan] [green]{format_time(elapsed)}[/green]\n" + f"[bold cyan]ETA:[/bold cyan] [red]{format_time(eta)}[/red]" + ) + return Panel( + info, title="⚑ Live Stats", border_style="bright_yellow", expand=False + ) + + # ============================================================================== # Main Training Function # ============================================================================== + def train(): - print("=" * 60) - print("Shakespeare GPT Training") - print("=" * 60) - print(f"Device: {DEVICE}") - - # ------------------------------------------------------------------------- - # 1. Load Data - # ------------------------------------------------------------------------- - print("\nLoading data...") - train_data, val_data, tokenizer = load_data(block_size=BLOCK_SIZE) - vocab_size = tokenizer.vocab_size - - # ------------------------------------------------------------------------- - # 2. Create Model - # ------------------------------------------------------------------------- - print("\nCreating model...") - model = GPT( - vocab_size=vocab_size, - embedding_dim=EMBEDDING_DIM, - num_heads=NUM_HEADS, - num_layers=NUM_LAYERS, - block_size=BLOCK_SIZE, - dropout=DROPOUT + # ── Header ──────────────────────────────────────────────────────────────── + console.print() + console.print( + Panel.fit( + "[bold magenta]🎭 Shakespeare GPT β€” Training[/bold magenta]\n" + f"[dim]Device: {DEVICE} | Batch: {BATCH_SIZE} | LR: {LEARNING_RATE} |" + f" Iters: {MAX_ITERS} | Heads: {NUM_HEADS} | Layers: {NUM_LAYERS}[/dim]", + border_style="magenta", + box=box.DOUBLE_EDGE, + ) ) - model = model.to(DEVICE) - - # ------------------------------------------------------------------------- - # 3. Create Optimizer - # ------------------------------------------------------------------------- + console.print() + + # ── Load Data ───────────────────────────────────────────────────────────── + with console.status("[bold green]Loading dataset...[/bold green]", spinner="dots"): + train_data, val_data, tokenizer = load_data(block_size=BLOCK_SIZE) + vocab_size = tokenizer.vocab_size + console.print(f" βœ… Data loaded β€” vocab size: [bold]{vocab_size:,}[/bold]") + + # ── Create Model ────────────────────────────────────────────────────────── + with console.status("[bold green]Building model...[/bold green]", spinner="dots"): + model = GPT( + vocab_size=vocab_size, + embedding_dim=EMBEDDING_DIM, + num_heads=NUM_HEADS, + num_layers=NUM_LAYERS, + block_size=BLOCK_SIZE, + dropout=DROPOUT, + ) + model = model.to(DEVICE) + param_count = sum(p.numel() for p in model.parameters()) / 1e6 + + console.print(f" βœ… Model ready β€” [bold]{param_count:.2f}M[/bold] parameters") + + # ── Optimizer ───────────────────────────────────────────────────────────── optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) - # ------------------------------------------------------------------------- - # 4. Training Loop - # ------------------------------------------------------------------------- - print("\nStarting training...") - print(f"Max iterations: {MAX_ITERS}") - print(f"Eval interval: {EVAL_INTERVAL}\n") + # ── Training Loop ───────────────────────────────────────────────────────── + console.print() + console.print(Rule("[bold blue]Training[/bold blue]", style="blue")) + console.print() os.makedirs("checkpoints", exist_ok=True) + + history = [] # list of eval snapshots start_time = time.time() + iter_times = [] # rolling window for speed estimate + current_loss = float("inf") + + with Progress( + SpinnerColumn(style="bold magenta"), + TextColumn("[progress.description]{task.description}"), + BarColumn(bar_width=40, style="magenta", complete_style="bold green"), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TextColumn("ETA"), + TimeRemainingColumn(), + console=console, + refresh_per_second=5, + transient=False, + ) as progress: + task = progress.add_task("[bold cyan]Training…[/bold cyan]", total=MAX_ITERS) + + for iteration in range(MAX_ITERS): + iter_start = time.time() + + # ── LR warmup ────────────────────────────────────────────────── + lr = get_learning_rate(iteration) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + # ── Forward / backward ───────────────────────────────────────── + x, y = get_batch(train_data, BLOCK_SIZE, BATCH_SIZE) + x, y = x.to(DEVICE), y.to(DEVICE) - for iteration in range(MAX_ITERS): - - # 4.1 Update learning rate (warmup schedule) - lr = get_learning_rate(iteration) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - # 4.2 Get a batch of training data - x, y = get_batch(train_data, BLOCK_SIZE, BATCH_SIZE) - x, y = x.to(DEVICE), y.to(DEVICE) - - # 4.3 Forward pass - get predictions and loss - logits, loss = model(x, y) - - # 4.4 Backward pass - compute gradients - optimizer.zero_grad() - loss.backward() - - # 4.5 Gradient clipping - prevent exploding gradients - # In deep networks, gradients can become very large during backprop. - # This clips them to a maximum norm of 1.0 for training stability. - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - - # 4.6 Update weights - optimizer.step() - - # 4.7 Evaluate periodically - if iteration % EVAL_INTERVAL == 0 or iteration == MAX_ITERS - 1: - losses = evaluate(model, train_data, val_data) - elapsed = time.time() - start_time - - print(f"Iter {iteration:5d} | " - f"Train Loss: {losses['train']:.4f} | " - f"Val Loss: {losses['val']:.4f} | " - f"LR: {lr:.2e} | " - f"Time: {elapsed:.0f}s") - - # Show a sample generation - if iteration > 0: - print("\n--- Sample ---") - print(generate_sample(model, tokenizer)[:400]) - print("--------------\n") - - # ------------------------------------------------------------------------- - # 5. Save Model - # ------------------------------------------------------------------------- - checkpoint = { - 'model_state_dict': model.state_dict(), - 'iteration': MAX_ITERS, - 'val_loss': losses['val'], - 'config': { - 'vocab_size': vocab_size, - 'embedding_dim': EMBEDDING_DIM, - 'num_heads': NUM_HEADS, - 'num_layers': NUM_LAYERS, - 'block_size': BLOCK_SIZE, - } - } - torch.save(checkpoint, CHECKPOINT_PATH) - - # ------------------------------------------------------------------------- - # 6. Done! - # ------------------------------------------------------------------------- + logits, loss = model(x, y) + current_loss = loss.item() + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + # ── Speed tracking ───────────────────────────────────────────── + iter_times.append(time.time() - iter_start) + if len(iter_times) > 50: + iter_times.pop(0) + iters_per_sec = 1.0 / (sum(iter_times) / len(iter_times)) + + # ── Update progress bar label ────────────────────────────────── + phase = ( + "[yellow]warmup[/yellow]" + if iteration < WARMUP_ITERS + else "[green]train[/green]" + ) + progress.update( + task, + advance=1, + description=( + f"[bold cyan]{phase}[/bold cyan] " + f"loss=[bold yellow]{current_loss:.4f}[/bold yellow] " + f"lr=[magenta]{lr:.1e}[/magenta] " + f"[dim]{iters_per_sec:.1f} it/s[/dim]" + ), + ) + + # ── Periodic evaluation ──────────────────────────────────────── + if iteration % EVAL_INTERVAL == 0 or iteration == MAX_ITERS - 1: + elapsed = time.time() - start_time + losses = evaluate(model, train_data, val_data) + + history.append( + { + "iter": iteration, + "train_loss": losses["train"], + "val_loss": losses["val"], + "lr": lr, + "elapsed": elapsed, + } + ) + + # Print a snapshot below the bar + progress.console.print( + f" [dim]β”‚[/dim] [bold white]iter {iteration:5,}[/bold white] " + f"train=[green]{losses['train']:.4f}[/green] " + f"val=[yellow]{losses['val']:.4f}[/yellow] " + f"elapsed=[cyan]{format_time(elapsed)}[/cyan] " + f"ETA=[red]{format_time((MAX_ITERS - iteration) / iters_per_sec)}[/red]" + ) + + # Save checkpoint + checkpoint = { + "model_state_dict": model.state_dict(), + "iteration": iteration, + "val_loss": losses["val"], + "config": { + "vocab_size": vocab_size, + "embedding_dim": EMBEDDING_DIM, + "num_heads": NUM_HEADS, + "num_layers": NUM_LAYERS, + "block_size": BLOCK_SIZE, + }, + } + torch.save(checkpoint, CHECKPOINT_PATH) + + # Show text sample (skip iteration 0) + if iteration > 0: + sample = generate_sample(model, tokenizer)[:400] + progress.console.print( + Panel( + f"[italic]{sample}[/italic]", + title="🎭 Generated Sample", + border_style="dim blue", + expand=False, + ) + ) + + # ── History Table ───────────────────────────────────────────────────────── + console.print() + console.print(build_stats_table(history)) + + # ── Final Summary ───────────────────────────────────────────────────────── total_time = time.time() - start_time - print("\n" + "=" * 60) - print("Training Complete!") - print("=" * 60) - print(f"Total time: {total_time / 60:.1f} minutes") - print(f"Final val loss: {losses['val']:.4f}") - print(f"Model saved to: {CHECKPOINT_PATH}") + console.print() + console.print( + Panel( + f"[bold green]βœ… Training complete![/bold green]\n\n" + f" Total time : [cyan]{format_time(total_time)}[/cyan] " + f"([dim]{total_time / 60:.1f} min[/dim])\n" + f" Final val loss : [bold yellow]{history[-1]['val_loss']:.4f}[/bold yellow]\n" + f" Checkpoint : [dim]{CHECKPOINT_PATH}[/dim]", + border_style="bright_green", + box=box.DOUBLE_EDGE, + ) + ) if __name__ == "__main__":