feat: Allow configurable hidden layer sizes for NNUE architecture
This commit is contained in:
@@ -17,7 +17,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue, burst_train
|
||||
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
|
||||
from export import export_to_nbai
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
@@ -624,13 +624,22 @@ def train_interactive():
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
@@ -665,7 +674,8 @@ def train_interactive():
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping,
|
||||
subsample_ratio=subsample_ratio
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
@@ -727,10 +737,18 @@ def burst_train_interactive():
|
||||
# Training hyperparameters
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
@@ -757,6 +775,7 @@ def burst_train_interactive():
|
||||
initial_checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Burst training complete[/green]")
|
||||
|
||||
|
||||
@@ -86,35 +86,39 @@ def fen_to_features(fen):
|
||||
|
||||
return features
|
||||
|
||||
# Default hidden-layer widths; [768] + these + [1] gives the historical
# fixed architecture 768 → 1536 → 1024 → 512 → 256 → 1.
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]


class NNUE(nn.Module):
    """NNUE neural network with configurable hidden layers.

    Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1

    Layer attributes follow the naming l1, l2, ..., lN (with matching
    relu{i}/drop{i} for each hidden layer) so export.py can infer the
    architecture directly from the state_dict, and so checkpoints saved
    by the old fixed-size class keep identical parameter keys.

    Args:
        hidden_sizes: Sequence of hidden-layer widths. ``None`` selects
            ``DEFAULT_HIDDEN_SIZES``. The sequence is copied, so callers
            may reuse/mutate their list safely.
        dropout_rate: Dropout probability applied after each hidden
            layer's ReLU (default 0.2).
    """

    def __init__(self, hidden_sizes=None, dropout_rate=0.2):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = DEFAULT_HIDDEN_SIZES
        # Defensive copy: decouple from the caller's (possibly shared) list.
        self.hidden_sizes = list(hidden_sizes)
        sizes = [768] + self.hidden_sizes + [1]
        num_hidden = len(self.hidden_sizes)

        # setattr (rather than nn.ModuleList) keeps state_dict keys as
        # "l1.weight", "l2.weight", ... for checkpoint/export compatibility.
        for i in range(num_hidden):
            setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
            setattr(self, f"relu{i + 1}", nn.ReLU())
            setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
        # Final linear output layer (no activation/dropout).
        setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
        self._num_hidden = num_hidden

    def forward(self, x):
        """Run x (batch, 768) through all hidden layers, return (batch, 1)."""
        for i in range(1, self._num_hidden + 1):
            layer = getattr(self, f"l{i}")
            relu = getattr(self, f"relu{i}")
            drop = getattr(self, f"drop{i}")
            x = drop(relu(layer(x)))
        return getattr(self, f"l{self._num_hidden + 1}")(x)
|
||||
|
||||
def find_next_version(base_name="nnue_weights"):
|
||||
"""Find the next version number for model versioning.
|
||||
@@ -306,6 +310,7 @@ def _run_training_season(
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": model.hidden_sizes,
|
||||
}, checkpoint_file)
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
@@ -328,10 +333,12 @@ def _run_training_season(
|
||||
|
||||
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time, extra_metadata=None):
|
||||
stockfish_depth, training_start_time, hidden_sizes=None,
|
||||
extra_metadata=None):
|
||||
"""Save the best model with optional versioning and metadata."""
|
||||
final_output_file = output_file
|
||||
metadata = {}
|
||||
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
|
||||
|
||||
if use_versioning:
|
||||
base_name = output_file.replace(".pt", "")
|
||||
@@ -344,6 +351,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
||||
"num_positions": num_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"final_val_loss": float(best_val_loss),
|
||||
"architecture": architecture,
|
||||
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
@@ -357,6 +365,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"epoch": last_epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
|
||||
}, final_output_file)
|
||||
print(f"Best model saved to {final_output_file}")
|
||||
|
||||
@@ -367,7 +376,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
||||
for key, val in metadata.items():
|
||||
print(f" {key}: {val}")
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||
|
||||
Args:
|
||||
@@ -382,20 +391,31 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
||||
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||
"""
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
model = NNUE().to(device)
|
||||
start_epoch = 0
|
||||
best_val_loss = float('inf')
|
||||
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||
|
||||
if checkpoint:
|
||||
print(f"Loading checkpoint: {checkpoint}")
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||
resolved_hidden_sizes = ckpt_hidden
|
||||
|
||||
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
start_epoch = 0
|
||||
best_val_loss = float('inf')
|
||||
if checkpoint:
|
||||
print(f"Loading checkpoint: {checkpoint}")
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
@@ -412,6 +432,8 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||
print(f"Architecture: {arch_str}")
|
||||
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||
print(f"Mixed precision training: enabled")
|
||||
@@ -441,6 +463,7 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time,
|
||||
hidden_sizes=resolved_hidden_sizes,
|
||||
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
||||
"checkpoint": str(checkpoint) if checkpoint else None}
|
||||
)
|
||||
@@ -449,7 +472,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
epochs_per_season=50, early_stopping_patience=10,
|
||||
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
||||
stockfish_depth=12, use_versioning=True,
|
||||
weight_decay=1e-4, subsample_ratio=1.0):
|
||||
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
||||
|
||||
Each season trains with early stopping. When early stopping fires, the model reloads the
|
||||
@@ -469,18 +492,29 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
weight_decay: L2 regularization strength (default: 1e-4)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
||||
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||
"""
|
||||
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
||||
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
model = NNUE().to(device)
|
||||
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||
|
||||
if initial_checkpoint:
|
||||
print(f"Loading initial weights: {initial_checkpoint}")
|
||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||
resolved_hidden_sizes = ckpt_hidden
|
||||
|
||||
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||
criterion = nn.MSELoss()
|
||||
best_global_val_loss = float('inf')
|
||||
|
||||
if initial_checkpoint:
|
||||
print(f"Loading initial weights: {initial_checkpoint}")
|
||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
@@ -493,6 +527,8 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no val loss info).")
|
||||
|
||||
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||
print(f"Architecture: {arch_str}")
|
||||
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
||||
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
||||
print()
|
||||
@@ -555,6 +591,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
||||
best_global_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, burst_start_time,
|
||||
hidden_sizes=resolved_hidden_sizes,
|
||||
extra_metadata={
|
||||
"mode": "burst",
|
||||
"duration_minutes": duration_minutes,
|
||||
@@ -593,6 +630,8 @@ if __name__ == "__main__":
|
||||
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
||||
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
||||
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
||||
parser.add_argument("--hidden-layers", type=str, default=None,
|
||||
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
|
||||
|
||||
# Burst mode
|
||||
parser.add_argument("--burst-duration", type=float, default=None,
|
||||
@@ -602,6 +641,8 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
|
||||
|
||||
if args.burst_duration is not None:
|
||||
burst_train(
|
||||
data_file=args.data_file,
|
||||
@@ -616,6 +657,7 @@ if __name__ == "__main__":
|
||||
use_versioning=not args.no_versioning,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
else:
|
||||
train_nnue(
|
||||
@@ -630,4 +672,5 @@ if __name__ == "__main__":
|
||||
early_stopping_patience=args.early_stopping,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user