feat: Allow configurable hidden layer sizes for NNUE architecture

This commit is contained in:
2026-04-14 21:34:40 +02:00
parent 227f2e4a41
commit 8db2c8ca7f
2 changed files with 96 additions and 34 deletions
+21 -2
View File
@@ -17,7 +17,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue, burst_train
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
from export import export_to_nbai
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
@@ -624,13 +624,22 @@ def train_interactive():
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
hidden_layers_str = Prompt.ask(
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
default=default_layers
)
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
arch_str = "→".join(str(s) for s in [768] + hidden_sizes + [1])
# Confirm and start
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Architecture: {arch_str}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
@@ -665,7 +674,8 @@ def train_interactive():
checkpoint=checkpoint,
use_versioning=True,
early_stopping_patience=early_stopping,
subsample_ratio=subsample_ratio
subsample_ratio=subsample_ratio,
hidden_sizes=hidden_sizes,
)
console.print("[green]✓ Training complete[/green]")
@@ -727,10 +737,18 @@ def burst_train_interactive():
# Training hyperparameters
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
hidden_layers_str = Prompt.ask(
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
default=default_layers
)
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
arch_str = "→".join(str(s) for s in [768] + hidden_sizes + [1])
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Architecture: {arch_str}")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
@@ -757,6 +775,7 @@ def burst_train_interactive():
initial_checkpoint=checkpoint,
use_versioning=True,
subsample_ratio=subsample_ratio,
hidden_sizes=hidden_sizes,
)
console.print("[green]✓ Burst training complete[/green]")