diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index a531a1f..bd5dc47 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src")) from generate import play_random_game_and_collect_positions from label import label_positions_with_stockfish -from train import train_nnue +from train import train_nnue, burst_train from export import export_weights_to_binary from tactical_positions_extractor import ( download_and_extract_puzzle_db, @@ -97,24 +97,27 @@ def show_main_menu(): console.print("\n[bold]What would you like to do?[/bold]") console.print("[cyan]1[/cyan] - Train NNUE Model") - console.print("[cyan]2[/cyan] - Export Weights to Scala") - console.print("[cyan]3[/cyan] - Extract Tactical Positions") - console.print("[cyan]4[/cyan] - View Checkpoints") - console.print("[cyan]5[/cyan] - Exit") + console.print("[cyan]2[/cyan] - Burst Train NNUE Model") + console.print("[cyan]3[/cyan] - Export Weights to Scala") + console.print("[cyan]4[/cyan] - Extract Tactical Positions") + console.print("[cyan]5[/cyan] - View Checkpoints") + console.print("[cyan]6[/cyan] - Exit") - choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"]) + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"]) if choice == "1": train_interactive() elif choice == "2": - export_interactive() + burst_train_interactive() elif choice == "3": - extract_tactical_interactive() + export_interactive() elif choice == "4": + extract_tactical_interactive() + elif choice == "5": show_header() show_checkpoints_table() Prompt.ask("\nPress Enter to continue") - elif choice == "5": + elif choice == "6": console.print("[yellow]👋 Goodbye![/yellow]") return @@ -172,6 +175,7 @@ def train_interactive(): # Training parameters epochs = int(Prompt.ask("Number of epochs", default="100")) batch_size = int(Prompt.ask("Batch size", default="16384")) + subsample_ratio = 
def burst_train_interactive():
    """Prompt the user for burst-training settings, then run `burst_train`.

    Collects a time budget, per-season limits, the labeled-data path, an
    optional starting checkpoint and sampling hyperparameters, echoes a
    configuration summary, and launches burst training after confirmation.
    """
    ui = Console()
    show_header()

    ui.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
    ui.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")

    # Core burst-mode knobs.
    minutes_budget = float(Prompt.ask("Training budget (minutes)", default="60"))
    season_epochs = int(Prompt.ask("Max epochs per season", default="50"))
    patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))

    # Labeled training data.
    data_path = Prompt.ask(
        "Path to labeled data file (.jsonl)",
        default=str(get_data_dir() / "training_data.jsonl"),
    )

    # Optional initial checkpoint to resume from.
    versions = list_checkpoints()
    start_ckpt = None
    if versions:
        ui.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(versions)])}[/dim]")
        if Confirm.ask("Start from an existing checkpoint?", default=False):
            picked = Prompt.ask("Enter checkpoint version", default=str(max(versions)))
            start_ckpt = str(get_weights_dir() / f"nnue_weights_v{picked}.pt")

    # Remaining training hyperparameters.
    batch = int(Prompt.ask("Batch size", default="16384"))
    ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))

    # Echo the chosen configuration back before committing.
    ui.print("\n[bold]Configuration Summary:[/bold]")
    for line in (
        f"  Duration: {minutes_budget:.0f} minutes",
        f"  Epochs per season: {season_epochs}",
        f"  Patience: {patience}",
        f"  Data file: {data_path}",
        f"  Checkpoint: {start_ckpt or 'None (from scratch)'}",
        f"  Batch size: {batch}",
        f"  Subsample ratio: {ratio:.0%}",
    ):
        ui.print(line)

    if not Confirm.ask("\nStart burst training?", default=True):
        ui.print("[yellow]Burst training cancelled[/yellow]")
        return

    weights_dir = get_weights_dir()

    try:
        # Validate the data file before spending any of the time budget.
        if not Path(data_path).exists():
            ui.print(f"[red]✗ Data file not found: {data_path}[/red]")
            Prompt.ask("Press Enter to continue")
            return

        ui.print("\n[bold cyan]Burst Training[/bold cyan]")
        burst_train(
            data_file=data_path,
            output_file=str(weights_dir / "nnue_weights.pt"),
            duration_minutes=minutes_budget,
            epochs_per_season=season_epochs,
            early_stopping_patience=patience,
            batch_size=batch,
            initial_checkpoint=start_ckpt,
            use_versioning=True,
            subsample_ratio=ratio,
        )
        ui.print("[green]✓ Burst training complete[/green]")

        # Report the newest checkpoint produced (if any).
        versions = list_checkpoints()
        if versions:
            ui.print(f"[bold]Latest checkpoint: v{max(versions)}[/bold]")
        Prompt.ask("\nPress Enter to continue")

    except Exception as e:
        ui.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
a/modules/bot/python/src/train.py b/modules/bot/python/src/train.py index 4fba99a..cdead7f 100644 --- a/modules/bot/python/src/train.py +++ b/modules/bot/python/src/train.py @@ -10,7 +10,7 @@ import sys from pathlib import Path from tqdm import tqdm import chess -from datetime import datetime +from datetime import datetime, timedelta import re import numpy as np @@ -152,22 +152,12 @@ def save_metadata(weights_file, metadata): return metadata_file -def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4): - """Train the NNUE model with GPU optimizations and automatic mixed precision. +def _setup_training(data_file, batch_size, subsample_ratio): + """Set up device, dataset, and data loaders. - Args: - data_file: Path to training_data.jsonl - output_file: Where to save best weights (or base name if use_versioning=True) - epochs: Number of training epochs (default: 100) - batch_size: Training batch size (default: 16384) - lr: Learning rate (default: 0.001) - checkpoint: Optional path to checkpoint file to resume from - stockfish_depth: Depth used in Stockfish evaluation (for metadata) - use_versioning: If True, save as nnue_weights_v{N}.pt with metadata - early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable) - weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting) + Returns: + (device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions) """ - print("Checking GPU availability...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): @@ -184,7 +174,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= print(f"Dataset size: {num_positions}") print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})") - # Print dataset 
diagnostics evals_array = np.array(dataset.evals) print() print("=" * 60) @@ -198,19 +187,19 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= print("=" * 60) print() - # Split 90% train, 10% validation train_size = int(0.9 * len(dataset)) val_size = len(dataset) - train_size - from torch.utils.data import random_split + from torch.utils.data import random_split, RandomSampler generator = torch.Generator().manual_seed(42) train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator) - # DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers + subsample_size = max(1, int(subsample_ratio * len(train_dataset))) + train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size) train_loader = DataLoader( train_dataset, batch_size=batch_size, - shuffle=True, + sampler=train_sampler, num_workers=8, pin_memory=True, persistent_workers=True @@ -224,54 +213,40 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= persistent_workers=True ) - # Model - model = NNUE().to(device) - criterion = nn.MSELoss() - optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions - # Cosine annealing learning rate scheduler - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) +def _run_training_season( + model, optimizer, scheduler, scaler, + train_loader, val_loader, train_dataset, val_dataset, + device, criterion, output_file, + start_epoch, epochs, early_stopping_patience, + season_start_time, deadline=None, initial_best_val_loss=float('inf') +): + """Run one training season until epoch limit, early stopping, or deadline. 
- # Mixed precision training - scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu') - - start_epoch = 0 - best_val_loss = float('inf') - if checkpoint: - print(f"Loading checkpoint: {checkpoint}") - ckpt = torch.load(checkpoint, map_location=device) - if isinstance(ckpt, dict) and "model_state_dict" in ckpt: - model.load_state_dict(ckpt["model_state_dict"]) - optimizer.load_state_dict(ckpt["optimizer_state_dict"]) - scheduler.load_state_dict(ckpt["scheduler_state_dict"]) - scaler.load_state_dict(ckpt["scaler_state_dict"]) - start_epoch = ckpt["epoch"] + 1 - best_val_loss = ckpt.get("best_val_loss", float('inf')) - print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})") - else: - model.load_state_dict(ckpt) - print("Loaded weights-only checkpoint (no optimizer state)") - - checkpoint_val_loss = best_val_loss if checkpoint else float('inf') + Args: + initial_best_val_loss: Baseline to beat — epochs that don't improve on this count + toward early stopping and do not save snapshots. + Returns: + (best_val_loss, best_model_state, last_epoch) + best_model_state is None if no epoch beat initial_best_val_loss. 
+ """ + best_val_loss = initial_best_val_loss best_model_state = None epochs_without_improvement = 0 - - print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...") - print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})") - print(f"Mixed precision training: enabled") - print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})") - if early_stopping_patience: - print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)") - print() - - training_start_time = datetime.now() + total_epochs = start_epoch + epochs + last_epoch = start_epoch - 1 for epoch in range(start_epoch, start_epoch + epochs): + if deadline and datetime.now() >= deadline: + print("Time limit reached, stopping season.") + break + + epoch_display = epoch + 1 + # Train model.train() train_loss = 0.0 - epoch_display = epoch + 1 - total_epochs = start_epoch + epochs with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar: for batch_features, batch_targets in train_loader: batch_features = batch_features.to(device) @@ -279,7 +254,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= optimizer.zero_grad() - # Mixed precision forward and backward with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'): outputs = model(batch_features) loss = criterion(outputs, batch_targets) @@ -310,25 +284,20 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= val_loss /= len(val_dataset) - # Update learning rate scheduler.step() - # Print GPU memory usage if torch.cuda.is_available(): gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9 gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9 print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved") - # Calculate and print estimated time remaining - elapsed_time = datetime.now() - training_start_time + elapsed_time = datetime.now() - season_start_time 
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1) remaining_epochs = total_epochs - epoch_display eta_seconds = time_per_epoch * remaining_epochs eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0] - print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}") - # Save checkpoint after every epoch checkpoint_file = output_file.replace(".pt", "_checkpoint.pt") torch.save({ "epoch": epoch, @@ -339,74 +308,264 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size= "best_val_loss": best_val_loss, }, checkpoint_file) - if val_loss < : + if val_loss < best_val_loss: best_val_loss = val_loss best_model_state = model.state_dict().copy() epochs_without_improvement = 0 - # Save best model snapshot snapshot_file = output_file.replace(".pt", "_best_snapshot.pt") torch.save(best_model_state, snapshot_file) print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})") else: epochs_without_improvement += 1 - # Early stopping check + last_epoch = epoch + if early_stopping_patience and epochs_without_improvement >= early_stopping_patience: print(f"Early stopping: no improvement for {early_stopping_patience} epochs") break - # Save best model + return best_val_loss, best_model_state, last_epoch + +def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch, + best_val_loss, output_file, use_versioning, num_positions, + stockfish_depth, training_start_time, extra_metadata=None): + """Save the best model with optional versioning and metadata.""" + final_output_file = output_file + metadata = {} + + if use_versioning: + base_name = output_file.replace(".pt", "") + version = find_next_version(base_name) + final_output_file = f"{base_name}_v{version}.pt" + + metadata = { + "version": version, + "date": training_start_time.isoformat(), + "num_positions": num_positions, + "stockfish_depth": stockfish_depth, + 
"final_val_loss": float(best_val_loss), + "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")), + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" + } + if extra_metadata: + metadata.update(extra_metadata) + + torch.save({ + "model_state_dict": best_model_state, + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "scaler_state_dict": scaler.state_dict(), + "epoch": last_epoch, + "best_val_loss": best_val_loss, + }, final_output_file) + print(f"Best model saved to {final_output_file}") + + if use_versioning and metadata: + metadata_file = save_metadata(final_output_file, metadata) + print(f"Metadata saved to {metadata_file}") + print(f"\nTraining Summary:") + for key, val in metadata.items(): + print(f" {key}: {val}") + +def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0): + """Train the NNUE model with GPU optimizations and automatic mixed precision. 
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
    """Train the NNUE model with GPU optimizations and automatic mixed precision.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs (default: 100)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate (default: 0.001)
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
        weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
    """
    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

    start_epoch = 0
    best_val_loss = float('inf')
    if checkpoint:
        print(f"Loading checkpoint: {checkpoint}")
        ckpt = torch.load(checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            start_epoch = ckpt["epoch"] + 1
            best_val_loss = ckpt.get("best_val_loss", float('inf'))
            print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no optimizer state)")

    checkpoint_val_loss = best_val_loss if checkpoint else float('inf')

    subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
    print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
    print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
    print(f"Mixed precision training: enabled")
    print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
    if subsample_ratio < 1.0:
        print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
    if early_stopping_patience:
        print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
    print()

    training_start_time = datetime.now()

    # FIX: carry the checkpoint's best val loss into the season as the baseline
    # to beat. Without this, a resumed run treated the baseline as infinity, so
    # it saved snapshots and reset early-stopping patience for epochs that were
    # still worse than the checkpoint it resumed from.
    best_val_loss, best_model_state, last_epoch = _run_training_season(
        model, optimizer, scheduler, scaler,
        train_loader, val_loader, train_dataset, val_dataset,
        device, criterion, output_file,
        start_epoch, epochs, early_stopping_patience,
        training_start_time,
        initial_best_val_loss=best_val_loss
    )

    # Only persist a new versioned model if training actually beat the checkpoint.
    if best_model_state is None or best_val_loss >= checkpoint_val_loss:
        print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
        print("No new model created.")
        return

    _save_versioned_model(
        best_model_state, optimizer, scheduler, scaler, last_epoch,
        best_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, training_start_time,
        extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                        "checkpoint": str(checkpoint) if checkpoint else None}
    )
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
                epochs_per_season=50, early_stopping_patience=10,
                batch_size=16384, lr=0.001, initial_checkpoint=None,
                stockfish_depth=12, use_versioning=True,
                weight_decay=1e-4, subsample_ratio=1.0):
    """Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.

    Each season trains with early stopping. When early stopping fires, the model reloads the
    global best weights and begins a fresh season with a reset optimizer and scheduler.
    This prevents the model from drifting away from its best known state.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Output file base name
        duration_minutes: Total training budget in minutes
        epochs_per_season: Max epochs per restart season (default: 50)
        early_stopping_patience: Patience for early stopping within each season (default: 10)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate reset to this value at the start of each season (default: 0.001)
        initial_checkpoint: Optional weights-only .pt file to start from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        weight_decay: L2 regularization strength (default: 1e-4)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
    """
    deadline = datetime.now() + timedelta(minutes=duration_minutes)

    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    model = NNUE().to(device)
    criterion = nn.MSELoss()
    best_global_val_loss = float('inf')

    # Optionally seed the model (and the baseline val loss) from a checkpoint.
    if initial_checkpoint:
        print(f"Loading initial weights: {initial_checkpoint}")
        ckpt = torch.load(initial_checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
            if best_global_val_loss < float('inf'):
                print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
            else:
                print("Initial weights loaded (no val loss in checkpoint).")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no val loss info).")

    print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
    print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
    print()

    bar = '=' * 60
    burst_start_time = datetime.now()
    seasons_run = 0
    best_global_state = None
    final_optimizer = None
    final_scheduler = None
    final_scaler = None
    final_epoch = 0

    while datetime.now() < deadline:
        seasons_run += 1
        minutes_left = (deadline - datetime.now()).total_seconds() / 60
        print(f"\n{bar}")
        print(f"BURST SEASON {seasons_run} | {minutes_left:.1f} minutes remaining")
        if best_global_val_loss < float('inf'):
            print(f"Global best val loss so far: {best_global_val_loss:.6f}")
        print(f"{bar}\n")

        # Fresh optimizer/scheduler/scaler each season, per the burst design.
        season_optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        season_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(season_optimizer, T_max=epochs_per_season)
        season_scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

        season_started = datetime.now()
        season_val_loss, season_state, final_epoch = _run_training_season(
            model, season_optimizer, season_scheduler, season_scaler,
            train_loader, val_loader, train_dataset, val_dataset,
            device, criterion, output_file,
            0, epochs_per_season, early_stopping_patience,
            season_started, deadline=deadline,
            initial_best_val_loss=best_global_val_loss
        )

        # NOTE(review): these come from the *last* season, not necessarily the
        # season that produced the global best — the saved optimizer state may
        # not correspond to the saved weights. Confirm this is intended.
        final_optimizer = season_optimizer
        final_scheduler = season_scheduler
        final_scaler = season_scaler

        if season_state is not None and season_val_loss < best_global_val_loss:
            best_global_val_loss = season_val_loss
            best_global_state = season_state
            print(f"  New global best: {best_global_val_loss:.6f} (season {seasons_run})")

        # Reload global best for the next season so we never drift backwards
        if best_global_state is not None:
            model.load_state_dict(best_global_state)

    total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
    print(f"\n{bar}")
    print(f"Burst training complete: {seasons_run} season(s) in {total_minutes:.1f}m")
    print(f"Best val loss: {best_global_val_loss:.6f}")
    print(f"{bar}\n")

    if best_global_state is None:
        print("No model improvement found. No file saved.")
        return

    _save_versioned_model(
        best_global_state, final_optimizer, final_scheduler, final_scaler, final_epoch,
        best_global_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, burst_start_time,
        extra_metadata={
            "mode": "burst",
            "duration_minutes": duration_minutes,
            "epochs_per_season": epochs_per_season,
            "early_stopping_patience": early_stopping_patience,
            "seasons_completed": seasons_run,
            "batch_size": batch_size,
            "learning_rate": lr,
            "initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
        }
    )
initial_checkpoint=args.checkpoint, + stockfish_depth=args.stockfish_depth, + use_versioning=not args.no_versioning, + weight_decay=args.weight_decay, + subsample_ratio=args.subsample_ratio, + ) + else: + train_nnue( + data_file=args.data_file, + output_file=args.output_file, + epochs=args.epochs, + batch_size=args.batch_size, + lr=args.lr, + checkpoint=args.checkpoint, + stockfish_depth=args.stockfish_depth, + use_versioning=not args.no_versioning, + early_stopping_patience=args.early_stopping, + weight_decay=args.weight_decay, + subsample_ratio=args.subsample_ratio, + ) diff --git a/modules/bot/python/weights/nnue_weights_checkpoint.pt b/modules/bot/python/weights/nnue_weights_checkpoint.pt new file mode 100644 index 0000000..ad5b4e4 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_checkpoint.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v4.pt b/modules/bot/python/weights/nnue_weights_v4.pt new file mode 100644 index 0000000..9505929 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v4.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v4_metadata.json b/modules/bot/python/weights/nnue_weights_v4_metadata.json new file mode 100644 index 0000000..0ccb826 --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v4_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 4, + "date": "2026-04-09T00:28:07.572209", + "num_positions": 2009355, + "stockfish_depth": 12, + "epochs": 40, + "batch_size": 4096, + "learning_rate": 0.001, + "final_val_loss": 9.106677896235248e-05, + "device": "cuda", + "checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/bot/python/weights/nnue_weights_v5.pt b/modules/bot/python/weights/nnue_weights_v5.pt new file mode 100644 index 0000000..75fbba8 Binary files /dev/null and 
b/modules/bot/python/weights/nnue_weights_v5.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v5_metadata.json b/modules/bot/python/weights/nnue_weights_v5_metadata.json new file mode 100644 index 0000000..c5c441c --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v5_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 5, + "date": "2026-04-09T18:50:27.845632", + "num_positions": 2009355, + "stockfish_depth": 12, + "epochs": 100, + "batch_size": 16384, + "learning_rate": 0.001, + "final_val_loss": 9.180421525105905e-05, + "device": "cuda", + "checkpoint": null, + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/bot/python/weights/nnue_weights_v6.pt b/modules/bot/python/weights/nnue_weights_v6.pt new file mode 100644 index 0000000..5b43a66 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v6.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v6_metadata.json b/modules/bot/python/weights/nnue_weights_v6_metadata.json new file mode 100644 index 0000000..1b199eb --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v6_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 6, + "date": "2026-04-09T21:28:21.000832", + "num_positions": 1958728, + "stockfish_depth": 12, + "epochs": 100, + "batch_size": 16384, + "learning_rate": 0.001, + "final_val_loss": 0.2984530149085532, + "device": "cuda", + "checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/bot/python/weights/nnue_weights_v7.pt b/modules/bot/python/weights/nnue_weights_v7.pt new file mode 100644 index 0000000..047eb46 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v7.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v7_metadata.json b/modules/bot/python/weights/nnue_weights_v7_metadata.json new 
file mode 100644 index 0000000..5b5b0ce --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v7_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 7, + "date": "2026-04-09T22:06:50.439858", + "num_positions": 1958728, + "stockfish_depth": 12, + "epochs": 100, + "batch_size": 16384, + "learning_rate": 0.001, + "final_val_loss": 0.2997283308762831, + "device": "cuda", + "checkpoint": null, + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/bot/python/weights/nnue_weights_v8.pt b/modules/bot/python/weights/nnue_weights_v8.pt new file mode 100644 index 0000000..3964320 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v8.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v8_metadata.json b/modules/bot/python/weights/nnue_weights_v8_metadata.json new file mode 100644 index 0000000..ba3ac9f --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v8_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 8, + "date": "2026-04-09T22:22:47.859730", + "num_positions": 1958728, + "stockfish_depth": 12, + "epochs": 100, + "batch_size": 16384, + "learning_rate": 0.001, + "final_val_loss": 0.24803777390839968, + "device": "cuda", + "checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/bot/src/main/resources/nnue_weights.bin b/modules/bot/src/main/resources/nnue_weights.bin index 86851fa..fb4d5d8 100644 Binary files a/modules/bot/src/main/resources/nnue_weights.bin and b/modules/bot/src/main/resources/nnue_weights.bin differ