feat: Allow configurable hidden layer sizes for NNUE architecture

2026-04-14 21:34:40 +02:00
parent 227f2e4a41
commit 8db2c8ca7f
2 changed files with 96 additions and 34 deletions
+21 -2
@@ -17,7 +17,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
 from generate import play_random_game_and_collect_positions
 from label import label_positions_with_stockfish
-from train import train_nnue, burst_train
+from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
 from export import export_to_nbai
 from tactical_positions_extractor import (
     download_and_extract_puzzle_db,
@@ -624,13 +624,22 @@ def train_interactive():
     epochs = int(Prompt.ask("Number of epochs", default="100"))
     batch_size = int(Prompt.ask("Batch size", default="16384"))
     subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
+    default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
+    hidden_layers_str = Prompt.ask(
+        "Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
+        default=default_layers
+    )
+    hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
     early_stopping = None
     if Confirm.ask("Enable early stopping?", default=False):
         early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
+
+    arch_str = "→".join(str(s) for s in [768] + hidden_sizes + [1])
 
     # Confirm and start
     console.print("\n[bold]Configuration Summary:[/bold]")
     console.print(f"  Dataset: ds_v{dataset_version}")
+    console.print(f"  Architecture: {arch_str}")
     console.print(f"  Epochs: {epochs}")
     console.print(f"  Batch size: {batch_size}")
     console.print(f"  Subsample ratio: {subsample_ratio:.0%}")
@@ -665,7 +674,8 @@ def train_interactive():
         checkpoint=checkpoint,
         use_versioning=True,
         early_stopping_patience=early_stopping,
-        subsample_ratio=subsample_ratio
+        subsample_ratio=subsample_ratio,
+        hidden_sizes=hidden_sizes,
     )
 
     console.print("[green]✓ Training complete[/green]")
@@ -727,10 +737,18 @@ def burst_train_interactive():
     # Training hyperparameters
     batch_size = int(Prompt.ask("Batch size", default="16384"))
     subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
+    default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
+    hidden_layers_str = Prompt.ask(
+        "Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
+        default=default_layers
+    )
+    hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
+    arch_str = "→".join(str(s) for s in [768] + hidden_sizes + [1])
 
     # Summary
     console.print("\n[bold]Configuration Summary:[/bold]")
     console.print(f"  Dataset: ds_v{dataset_version}")
+    console.print(f"  Architecture: {arch_str}")
     console.print(f"  Duration: {duration_minutes:.0f} minutes")
     console.print(f"  Epochs per season: {epochs_per_season}")
     console.print(f"  Patience: {early_stopping_patience}")
@@ -757,6 +775,7 @@ def burst_train_interactive():
         initial_checkpoint=checkpoint,
         use_versioning=True,
         subsample_ratio=subsample_ratio,
+        hidden_sizes=hidden_sizes,
     )
 
     console.print("[green]✓ Burst training complete[/green]")
+75 -32
@@ -86,35 +86,39 @@ def fen_to_features(fen):
     return features
 
 
+DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
+
+
 class NNUE(nn.Module):
-    """NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization."""
+    """NNUE neural network with configurable hidden layers.
+
+    Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
+
+    Layer attributes follow the naming l1, l2, ..., lN so export.py can
+    infer the architecture directly from the state_dict.
+    """
 
-    def __init__(self, dropout_rate=0.2):
+    def __init__(self, hidden_sizes=None, dropout_rate=0.2):
         super().__init__()
-        self.l1 = nn.Linear(768, 1536)
-        self.relu1 = nn.ReLU()
-        self.drop1 = nn.Dropout(dropout_rate)
-
-        self.l2 = nn.Linear(1536, 1024)
-        self.relu2 = nn.ReLU()
-        self.drop2 = nn.Dropout(dropout_rate)
-
-        self.l3 = nn.Linear(1024, 512)
-        self.relu3 = nn.ReLU()
-        self.drop3 = nn.Dropout(dropout_rate)
-
-        self.l4 = nn.Linear(512, 256)
-        self.relu4 = nn.ReLU()
-        self.drop4 = nn.Dropout(dropout_rate)
-
-        self.l5 = nn.Linear(256, 1)
+        if hidden_sizes is None:
+            hidden_sizes = DEFAULT_HIDDEN_SIZES
+        self.hidden_sizes = list(hidden_sizes)
+        sizes = [768] + self.hidden_sizes + [1]
+        num_hidden = len(self.hidden_sizes)
+        for i in range(num_hidden):
+            setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
+            setattr(self, f"relu{i + 1}", nn.ReLU())
+            setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
+        setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
+        self._num_hidden = num_hidden
 
     def forward(self, x):
-        x = self.drop1(self.relu1(self.l1(x)))
-        x = self.drop2(self.relu2(self.l2(x)))
-        x = self.drop3(self.relu3(self.l3(x)))
-        x = self.drop4(self.relu4(self.l4(x)))
-        return self.l5(x)
+        for i in range(1, self._num_hidden + 1):
+            layer = getattr(self, f"l{i}")
+            relu = getattr(self, f"relu{i}")
+            drop = getattr(self, f"drop{i}")
+            x = drop(relu(layer(x)))
+        return getattr(self, f"l{self._num_hidden + 1}")(x)
 
 
 def find_next_version(base_name="nnue_weights"):
     """Find the next version number for model versioning.
@@ -306,6 +310,7 @@ def _run_training_season(
"scheduler_state_dict": scheduler.state_dict(), "scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(), "scaler_state_dict": scaler.state_dict(),
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,
"hidden_sizes": model.hidden_sizes,
}, checkpoint_file) }, checkpoint_file)
if val_loss < best_val_loss: if val_loss < best_val_loss:
@@ -328,10 +333,12 @@ def _run_training_season(
 def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
                           best_val_loss, output_file, use_versioning, num_positions,
-                          stockfish_depth, training_start_time, extra_metadata=None):
+                          stockfish_depth, training_start_time, hidden_sizes=None,
+                          extra_metadata=None):
     """Save the best model with optional versioning and metadata."""
     final_output_file = output_file
     metadata = {}
+    architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
 
     if use_versioning:
         base_name = output_file.replace(".pt", "")
@@ -344,6 +351,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
"num_positions": num_positions, "num_positions": num_positions,
"stockfish_depth": stockfish_depth, "stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss), "final_val_loss": float(best_val_loss),
"architecture": architecture,
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")), "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)" "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
} }
@@ -357,6 +365,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
"scaler_state_dict": scaler.state_dict(), "scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch, "epoch": last_epoch,
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
}, final_output_file) }, final_output_file)
print(f"Best model saved to {final_output_file}") print(f"Best model saved to {final_output_file}")
@@ -367,7 +376,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
         for key, val in metadata.items():
             print(f"  {key}: {val}")
 
 
-def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
+def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
     """Train the NNUE model with GPU optimizations and automatic mixed precision.
 
     Args:
@@ -382,20 +391,31 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
         early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
         weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
         subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
+        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
     """
     device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
         _setup_training(data_file, batch_size, subsample_ratio)
 
-    model = NNUE().to(device)
+    start_epoch = 0
+    best_val_loss = float('inf')
+    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
+    if checkpoint:
+        print(f"Loading checkpoint: {checkpoint}")
+        ckpt = torch.load(checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            ckpt_hidden = ckpt.get("hidden_sizes")
+            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
+                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
+                resolved_hidden_sizes = ckpt_hidden
+
+    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
     scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
 
-    start_epoch = 0
-    best_val_loss = float('inf')
     if checkpoint:
-        print(f"Loading checkpoint: {checkpoint}")
         ckpt = torch.load(checkpoint, map_location=device)
         if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
             model.load_state_dict(ckpt["model_state_dict"])
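The ordering above matters: the checkpoint is inspected once before the model is built, so a checkpoint trained with a different architecture wins over the requested sizes and `load_state_dict` cannot hit a shape mismatch. A hedged sketch of the round trip (the file path is hypothetical):

```python
import torch
from train import NNUE

ckpt = torch.load("nnue_weights_v3.pt", map_location="cpu")  # hypothetical path
hidden = ckpt.get("hidden_sizes")   # None for checkpoints saved before this change
model = NNUE(hidden_sizes=hidden)   # None falls back to DEFAULT_HIDDEN_SIZES
model.load_state_dict(ckpt["model_state_dict"])
```

Note that a pre-change checkpoint (no `hidden_sizes` key) is assumed to use the default architecture; if it does not, `load_state_dict` raises a size-mismatch error.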
@@ -412,6 +432,8 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
     checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
     subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
 
+    arch_str = "→".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
+    print(f"Architecture: {arch_str}")
     print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
     print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
     print(f"Mixed precision training: enabled")
@@ -441,6 +463,7 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
         best_model_state, optimizer, scheduler, scaler, last_epoch,
         best_val_loss, output_file, use_versioning, num_positions,
         stockfish_depth, training_start_time,
+        hidden_sizes=resolved_hidden_sizes,
         extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                         "checkpoint": str(checkpoint) if checkpoint else None}
     )
@@ -449,7 +472,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
                 epochs_per_season=50, early_stopping_patience=10,
                 batch_size=16384, lr=0.001, initial_checkpoint=None,
                 stockfish_depth=12, use_versioning=True,
-                weight_decay=1e-4, subsample_ratio=1.0):
+                weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
     """Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
 
     Each season trains with early stopping. When early stopping fires, the model reloads the
@@ -469,18 +492,29 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
         use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
         weight_decay: L2 regularization strength (default: 1e-4)
         subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
+        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
     """
     deadline = datetime.now() + timedelta(minutes=duration_minutes)
     device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
         _setup_training(data_file, batch_size, subsample_ratio)
 
-    model = NNUE().to(device)
+    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
+    if initial_checkpoint:
+        print(f"Loading initial weights: {initial_checkpoint}")
+        ckpt = torch.load(initial_checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            ckpt_hidden = ckpt.get("hidden_sizes")
+            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
+                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
+                resolved_hidden_sizes = ckpt_hidden
+
+    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
     criterion = nn.MSELoss()
 
     best_global_val_loss = float('inf')
     if initial_checkpoint:
-        print(f"Loading initial weights: {initial_checkpoint}")
         ckpt = torch.load(initial_checkpoint, map_location=device)
         if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
             model.load_state_dict(ckpt["model_state_dict"])
@@ -493,6 +527,8 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
             model.load_state_dict(ckpt)
             print("Loaded weights-only checkpoint (no val loss info).")
 
+    arch_str = "→".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
+    print(f"Architecture: {arch_str}")
     print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
     print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
     print()
@@ -555,6 +591,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
         best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
         best_global_val_loss, output_file, use_versioning, num_positions,
         stockfish_depth, burst_start_time,
+        hidden_sizes=resolved_hidden_sizes,
         extra_metadata={
             "mode": "burst",
             "duration_minutes": duration_minutes,
@@ -593,6 +630,8 @@ if __name__ == "__main__":
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)") help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0, parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)") help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
parser.add_argument("--hidden-layers", type=str, default=None,
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
# Burst mode # Burst mode
parser.add_argument("--burst-duration", type=float, default=None, parser.add_argument("--burst-duration", type=float, default=None,
@@ -602,6 +641,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
+    hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
+
     if args.burst_duration is not None:
         burst_train(
             data_file=args.data_file,
@@ -616,6 +657,7 @@ if __name__ == "__main__":
             use_versioning=not args.no_versioning,
             weight_decay=args.weight_decay,
             subsample_ratio=args.subsample_ratio,
+            hidden_sizes=hidden_sizes,
         )
     else:
         train_nnue(
@@ -630,4 +672,5 @@ if __name__ == "__main__":
             early_stopping_patience=args.early_stopping,
             weight_decay=args.weight_decay,
             subsample_ratio=args.subsample_ratio,
+            hidden_sizes=hidden_sizes,
         )
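End to end, the new knob could be exercised like this; a hedged sketch where the script and dataset names are illustrative and only the flag and keyword names come from the diff:

```python
# CLI form (flag from this diff; script/dataset names illustrative):
#   python train.py positions.json --hidden-layers 1024,512,256
# Programmatic form:
from train import train_nnue

train_nnue(
    data_file="positions.json",     # hypothetical dataset file
    epochs=100,
    batch_size=16384,
    hidden_sizes=[1024, 512, 256],  # a smaller net than the default 1536,1024,512,256
)
```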