diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py
index 588c7c4..58bf155 100644
--- a/modules/bot/python/nnue.py
+++ b/modules/bot/python/nnue.py
@@ -17,7 +17,7 @@
 sys.path.insert(0, str(Path(__file__).parent / "src"))
 from generate import play_random_game_and_collect_positions
 from label import label_positions_with_stockfish
-from train import train_nnue, burst_train
+from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
 from export import export_to_nbai
 from tactical_positions_extractor import (
     download_and_extract_puzzle_db,
@@ -624,13 +624,22 @@ def train_interactive():
     epochs = int(Prompt.ask("Number of epochs", default="100"))
     batch_size = int(Prompt.ask("Batch size", default="16384"))
     subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
+    default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
+    hidden_layers_str = Prompt.ask(
+        "Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
+        default=default_layers
+    )
+    hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
     early_stopping = None
     if Confirm.ask("Enable early stopping?", default=False):
         early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
 
+    arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
+
     # Confirm and start
     console.print("\n[bold]Configuration Summary:[/bold]")
     console.print(f"  Dataset: ds_v{dataset_version}")
+    console.print(f"  Architecture: {arch_str}")
     console.print(f"  Epochs: {epochs}")
     console.print(f"  Batch size: {batch_size}")
     console.print(f"  Subsample ratio: {subsample_ratio:.0%}")
@@ -665,7 +674,8 @@ def train_interactive():
         checkpoint=checkpoint,
         use_versioning=True,
         early_stopping_patience=early_stopping,
-        subsample_ratio=subsample_ratio
+        subsample_ratio=subsample_ratio,
+        hidden_sizes=hidden_sizes,
     )
 
     console.print("[green]✓ Training complete[/green]")
@@ -727,10 +737,18 @@ def burst_train_interactive():
     # Training hyperparameters
     batch_size = int(Prompt.ask("Batch size", default="16384"))
     subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
+    default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
+    hidden_layers_str = Prompt.ask(
+        "Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
+        default=default_layers
+    )
+    hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
+    arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
 
     # Summary
     console.print("\n[bold]Configuration Summary:[/bold]")
     console.print(f"  Dataset: ds_v{dataset_version}")
+    console.print(f"  Architecture: {arch_str}")
     console.print(f"  Duration: {duration_minutes:.0f} minutes")
     console.print(f"  Epochs per season: {epochs_per_season}")
     console.print(f"  Patience: {early_stopping_patience}")
@@ -757,6 +775,7 @@
         initial_checkpoint=checkpoint,
         use_versioning=True,
         subsample_ratio=subsample_ratio,
+        hidden_sizes=hidden_sizes,
     )
 
     console.print("[green]✓ Burst training complete[/green]")
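The comma parsing added to both interactive flows above tolerates surrounding spaces and a stray trailing comma. A quick illustration of the round-trip (the input string is an example, not output from a real session):

```python
raw = " 1536, 1024,512,256, "
hidden_sizes = [int(x.strip()) for x in raw.split(",") if x.strip()]
assert hidden_sizes == [1536, 1024, 512, 256]

# Rendered the same way as the "Architecture:" row in the config summary:
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
assert arch_str == "768 → 1536 → 1024 → 512 → 256 → 1"
```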
diff --git a/modules/bot/python/src/train.py b/modules/bot/python/src/train.py
index cdead7f..3b85755 100644
--- a/modules/bot/python/src/train.py
+++ b/modules/bot/python/src/train.py
@@ -86,35 +86,39 @@ def fen_to_features(fen):
     return features
 
+DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
+
+
 class NNUE(nn.Module):
-    """NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization."""
+    """NNUE neural network with configurable hidden layers.
 
-    def __init__(self, dropout_rate=0.2):
+    Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
+    Layer attributes follow the naming l1, l2, ..., lN so export.py can
+    infer the architecture directly from the state_dict.
+    """
+
+    def __init__(self, hidden_sizes=None, dropout_rate=0.2):
         super().__init__()
-        self.l1 = nn.Linear(768, 1536)
-        self.relu1 = nn.ReLU()
-        self.drop1 = nn.Dropout(dropout_rate)
+        if hidden_sizes is None:
+            hidden_sizes = DEFAULT_HIDDEN_SIZES
+        self.hidden_sizes = list(hidden_sizes)
+        sizes = [768] + self.hidden_sizes + [1]
+        num_hidden = len(self.hidden_sizes)
 
-        self.l2 = nn.Linear(1536, 1024)
-        self.relu2 = nn.ReLU()
-        self.drop2 = nn.Dropout(dropout_rate)
-
-        self.l3 = nn.Linear(1024, 512)
-        self.relu3 = nn.ReLU()
-        self.drop3 = nn.Dropout(dropout_rate)
-
-        self.l4 = nn.Linear(512, 256)
-        self.relu4 = nn.ReLU()
-        self.drop4 = nn.Dropout(dropout_rate)
-
-        self.l5 = nn.Linear(256, 1)
+        for i in range(num_hidden):
+            setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
+            setattr(self, f"relu{i + 1}", nn.ReLU())
+            setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
+        setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
+        self._num_hidden = num_hidden
 
     def forward(self, x):
-        x = self.drop1(self.relu1(self.l1(x)))
-        x = self.drop2(self.relu2(self.l2(x)))
-        x = self.drop3(self.relu3(self.l3(x)))
-        x = self.drop4(self.relu4(self.l4(x)))
-        return self.l5(x)
+        for i in range(1, self._num_hidden + 1):
+            layer = getattr(self, f"l{i}")
+            relu = getattr(self, f"relu{i}")
+            drop = getattr(self, f"drop{i}")
+            x = drop(relu(layer(x)))
+        return getattr(self, f"l{self._num_hidden + 1}")(x)
 
 
 def find_next_version(base_name="nnue_weights"):
     """Find the next version number for model versioning.
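The l1, l2, ..., lN naming promised in the docstring is what lets a weights consumer recover the topology without any metadata. A minimal sketch of that inference; the corresponding export.py change is not part of this diff, and `infer_hidden_sizes` is a hypothetical helper:

```python
def infer_hidden_sizes(state_dict):
    """Recover hidden layer sizes from l1.weight, l2.weight, ... tensor shapes."""
    out_sizes = []
    i = 1
    while f"l{i}.weight" in state_dict:
        # nn.Linear stores its weight as (out_features, in_features)
        out_sizes.append(state_dict[f"l{i}.weight"].shape[0])
        i += 1
    return out_sizes[:-1]  # the last entry is the scalar output head

# For the default network this returns [1536, 1024, 512, 256].
```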
@@ -306,6 +310,7 @@ def _run_training_season(
                 "scheduler_state_dict": scheduler.state_dict(),
                 "scaler_state_dict": scaler.state_dict(),
                 "best_val_loss": best_val_loss,
+                "hidden_sizes": model.hidden_sizes,
             }, checkpoint_file)
 
         if val_loss < best_val_loss:
@@ -328,10 +333,12 @@
 
 def _save_versioned_model(best_model_state, optimizer, scheduler, scaler,
                           last_epoch, best_val_loss, output_file, use_versioning, num_positions,
-                          stockfish_depth, training_start_time, extra_metadata=None):
+                          stockfish_depth, training_start_time, hidden_sizes=None,
+                          extra_metadata=None):
     """Save the best model with optional versioning and metadata."""
     final_output_file = output_file
     metadata = {}
+    architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
 
     if use_versioning:
         base_name = output_file.replace(".pt", "")
@@ -344,6 +351,7 @@
             "num_positions": num_positions,
             "stockfish_depth": stockfish_depth,
             "final_val_loss": float(best_val_loss),
+            "architecture": architecture,
             "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
             "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
         }
@@ -357,6 +365,7 @@
         "scaler_state_dict": scaler.state_dict(),
         "epoch": last_epoch,
         "best_val_loss": best_val_loss,
+        "hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
    }, final_output_file)
 
     print(f"Best model saved to {final_output_file}")
@@ -367,7 +376,7 @@
     for key, val in metadata.items():
         print(f"  {key}: {val}")
 
-def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
+def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
     """Train the NNUE model with GPU optimizations and automatic mixed precision.
 
     Args:
@@ -382,20 +391,31 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
         early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
         weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
         subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
+        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
     """
     device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
         _setup_training(data_file, batch_size, subsample_ratio)
 
-    model = NNUE().to(device)
+    start_epoch = 0
+    best_val_loss = float('inf')
+    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
+
+    if checkpoint:
+        print(f"Loading checkpoint: {checkpoint}")
+        ckpt = torch.load(checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            ckpt_hidden = ckpt.get("hidden_sizes")
+            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
+                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
+                resolved_hidden_sizes = ckpt_hidden
+
+    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
     scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
 
-    start_epoch = 0
-    best_val_loss = float('inf')
     if checkpoint:
-        print(f"Loading checkpoint: {checkpoint}")
        ckpt = torch.load(checkpoint, map_location=device)
         if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
             model.load_state_dict(ckpt["model_state_dict"])
@@ -412,6 +432,8 @@
     checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
     subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
 
+    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
+    print(f"Architecture: {arch_str}")
     print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
     print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
     print(f"Mixed precision training: enabled")
@@ -441,6 +463,7 @@
         best_model_state, optimizer, scheduler, scaler, last_epoch,
         best_val_loss, output_file, use_versioning, num_positions,
         stockfish_depth, training_start_time,
+        hidden_sizes=resolved_hidden_sizes,
         extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                         "checkpoint": str(checkpoint) if checkpoint else None}
     )
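Note the precedence: a `hidden_sizes` list stored in the checkpoint always overrides the requested sizes, so resuming can never hit a state_dict shape mismatch, while older checkpoints without the key fall back to the requested or default architecture. A hypothetical resume call, file names for illustration only:

```python
# The requested [1024, 512, 256] is replaced (with a console notice) by
# whatever architecture nnue_weights_v3.pt was actually trained with.
train_nnue(
    "positions_labeled.jsonl",        # hypothetical dataset file
    checkpoint="nnue_weights_v3.pt",  # hypothetical checkpoint
    hidden_sizes=[1024, 512, 256],
)
```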
@@ -449,7 +472,7 @@
 def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
                 epochs_per_season=50, early_stopping_patience=10, batch_size=16384,
                 lr=0.001, initial_checkpoint=None, stockfish_depth=12, use_versioning=True,
-                weight_decay=1e-4, subsample_ratio=1.0):
+                weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
     """Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
 
     Each season trains with early stopping. When early stopping fires, the model reloads the
@@ -469,18 +492,29 @@
         use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
         weight_decay: L2 regularization strength (default: 1e-4)
         subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
+        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
     """
 
     deadline = datetime.now() + timedelta(minutes=duration_minutes)
     device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
         _setup_training(data_file, batch_size, subsample_ratio)
 
-    model = NNUE().to(device)
+    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
+
+    if initial_checkpoint:
+        print(f"Loading initial weights: {initial_checkpoint}")
+        ckpt = torch.load(initial_checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            ckpt_hidden = ckpt.get("hidden_sizes")
+            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
+                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
+                resolved_hidden_sizes = ckpt_hidden
+
+    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
     criterion = nn.MSELoss()
 
     best_global_val_loss = float('inf')
     if initial_checkpoint:
-        print(f"Loading initial weights: {initial_checkpoint}")
         ckpt = torch.load(initial_checkpoint, map_location=device)
         if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
             model.load_state_dict(ckpt["model_state_dict"])
@@ -493,6 +527,8 @@
             model.load_state_dict(ckpt)
             print("Loaded weights-only checkpoint (no val loss info).")
 
+    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
+    print(f"Architecture: {arch_str}")
     print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
     print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
     print()
@@ -555,6 +591,7 @@
         best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
         best_global_val_loss, output_file, use_versioning, num_positions,
         stockfish_depth, burst_start_time,
+        hidden_sizes=resolved_hidden_sizes,
         extra_metadata={
             "mode": "burst",
             "duration_minutes": duration_minutes,
@@ -593,6 +630,8 @@ if __name__ == "__main__":
                         help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
     parser.add_argument("--subsample-ratio", type=float, default=1.0,
                         help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
+    parser.add_argument("--hidden-layers", type=str, default=None,
+                        help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
 
     # Burst mode
     parser.add_argument("--burst-duration", type=float, default=None,
@@ -602,6 +641,8 @@
 
     args = parser.parse_args()
 
+    hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
+
     if args.burst_duration is not None:
         burst_train(
             data_file=args.data_file,
@@ -616,6 +657,7 @@
             use_versioning=not args.no_versioning,
             weight_decay=args.weight_decay,
             subsample_ratio=args.subsample_ratio,
+            hidden_sizes=hidden_sizes,
         )
     else:
         train_nnue(
@@ -630,4 +672,5 @@
             early_stopping_patience=args.early_stopping,
             weight_decay=args.weight_decay,
             subsample_ratio=args.subsample_ratio,
+            hidden_sizes=hidden_sizes,
         )
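Assuming `data_file` remains the existing positional argument, the new flag would be exercised like this (paths are placeholders):

```sh
python src/train.py positions_labeled.jsonl --hidden-layers 1024,512,256
python src/train.py positions_labeled.jsonl --burst-duration 90 --hidden-layers 1024,512,256
```

Omitting `--hidden-layers` passes `hidden_sizes=None` through, so both entry points keep the previous 768 → 1536 → 1024 → 512 → 256 → 1 network by default.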