feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options
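
Burst mode repeatedly restarts training from the best known weights until a wall-clock budget expires, and the new --subsample-ratio option draws a random fraction of the training split each epoch. A usage sketch of the new entry point (the module name is assumed for illustration; only the signature is taken from this commit):

    # Hypothetical import path; this diff does not show the file name.
    from nnue_training import burst_train

    burst_train(
        "training_data.jsonl",
        duration_minutes=120,        # total wall-clock budget
        epochs_per_season=50,        # cap per restart "season"
        early_stopping_patience=10,  # restart sooner if val loss stalls
        subsample_ratio=0.25,        # 25% of the train split per epoch
    )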

2026-04-10 02:16:01 +02:00
parent 5d4cf5f13c
commit 01c8a0f8fe
14 changed files with 465 additions and 134 deletions
+308 -124
@@ -10,7 +10,7 @@ import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime, timedelta
import re
import numpy as np
@@ -152,22 +152,12 @@ def save_metadata(weights_file, metadata):
return metadata_file

def _setup_training(data_file, batch_size, subsample_ratio):
    """Set up device, dataset, and data loaders.

    Returns:
        (device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
    """
print("Checking GPU availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
@@ -184,7 +174,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
# Print dataset diagnostics
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
@@ -198,19 +187,19 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print("=" * 60)
print()
# Split 90% train, 10% validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
    from torch.utils.data import random_split, RandomSampler
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
# DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
        sampler=train_sampler,
num_workers=8,
pin_memory=True,
persistent_workers=True
@@ -224,54 +213,40 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
persistent_workers=True
)
    return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions


def _run_training_season(
    model, optimizer, scheduler, scaler,
    train_loader, val_loader, train_dataset, val_dataset,
    device, criterion, output_file,
    start_epoch, epochs, early_stopping_patience,
    season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
    """Run one training season until epoch limit, early stopping, or deadline.

    Args:
        initial_best_val_loss: Baseline to beat; epochs that don't improve on it
            count toward early stopping and do not save snapshots.

    Returns:
        (best_val_loss, best_model_state, last_epoch)
        best_model_state is None if no epoch beat initial_best_val_loss.
    """
    best_val_loss = initial_best_val_loss
    best_model_state = None
    epochs_without_improvement = 0
    total_epochs = start_epoch + epochs
    last_epoch = start_epoch - 1
    for epoch in range(start_epoch, start_epoch + epochs):
        if deadline and datetime.now() >= deadline:
            print("Time limit reached, stopping season.")
            break
        epoch_display = epoch + 1
        # Train
        model.train()
        train_loss = 0.0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
for batch_features, batch_targets in train_loader:
batch_features = batch_features.to(device)
@@ -279,7 +254,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
optimizer.zero_grad()
# Mixed precision forward and backward
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
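                # The rest of the AMP step lies outside this hunk; with a
                # GradScaler the canonical continuation is (assumed context,
                # not lines changed by this commit):
                #   scaler.scale(loss).backward()
                #   scaler.step(optimizer)
                #   scaler.update()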
@@ -310,25 +284,20 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
val_loss /= len(val_dataset)
# Update learning rate
scheduler.step()
# Print GPU memory usage
if torch.cuda.is_available():
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
# Calculate and print estimated time remaining
        elapsed_time = datetime.now() - season_start_time
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
remaining_epochs = total_epochs - epoch_display
eta_seconds = time_per_epoch * remaining_epochs
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
# Save checkpoint after every epoch
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
torch.save({
"epoch": epoch,
@@ -339,74 +308,264 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
"best_val_loss": best_val_loss,
}, checkpoint_file)
        if val_loss < best_val_loss:
best_val_loss = val_loss
            # Snapshot with a deep copy: state_dict() values alias the live
            # parameter tensors, so a plain dict.copy() would keep mutating
            # as training continues.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
epochs_without_improvement = 0
# Save best model snapshot
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
torch.save(best_model_state, snapshot_file)
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
else:
epochs_without_improvement += 1
        last_epoch = epoch
        # Early stopping check
        if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
            break
    return best_val_loss, best_model_state, last_epoch
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time, extra_metadata=None):
"""Save the best model with optional versioning and metadata."""
final_output_file = output_file
metadata = {}
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss),
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
if extra_metadata:
metadata.update(extra_metadata)
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch,
"best_val_loss": best_val_loss,
}, final_output_file)
print(f"Best model saved to {final_output_file}")
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
for key, val in metadata.items():
print(f" {key}: {val}")
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
"""
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
model = NNUE().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
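    # Cosine annealing sweeps the learning rate from lr down toward eta_min
    # (0 by default) over T_max epochs:
    #   lr_t = eta_min + (lr - eta_min) * (1 + cos(pi * t / T_max)) / 2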
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
start_epoch = 0
best_val_loss = float('inf')
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if subsample_ratio < 1.0:
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
best_val_loss, best_model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
training_start_time
)
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
print("No new model created.")
return
_save_versioned_model(
best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time,
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
"checkpoint": str(checkpoint) if checkpoint else None}
)
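
# Illustration of the subsampling mechanics (a standalone demo, not part of
# this commit): RandomSampler with replacement=False and a fixed num_samples
# re-draws a fresh subset of indices on every pass over the DataLoader.
def _subsample_demo():
    import torch
    from torch.utils.data import TensorDataset, DataLoader, RandomSampler
    ds = TensorDataset(torch.arange(10))
    loader = DataLoader(ds, batch_size=5,
                        sampler=RandomSampler(ds, replacement=False, num_samples=5))
    for epoch in range(2):
        # A different 5-of-10 subset each epoch; no repeats within an epoch.
        print([x.tolist() for (x,) in loader])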
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
epochs_per_season=50, early_stopping_patience=10,
batch_size=16384, lr=0.001, initial_checkpoint=None,
stockfish_depth=12, use_versioning=True,
weight_decay=1e-4, subsample_ratio=1.0):
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
Each season trains with early stopping. When early stopping fires, the model reloads the
global best weights and begins a fresh season with a reset optimizer and scheduler.
This prevents the model from drifting away from its best known state.
Args:
data_file: Path to training_data.jsonl
output_file: Output file base name
duration_minutes: Total training budget in minutes
epochs_per_season: Max epochs per restart season (default: 50)
early_stopping_patience: Patience for early stopping within each season (default: 10)
batch_size: Training batch size (default: 16384)
lr: Learning rate reset to this value at the start of each season (default: 0.001)
initial_checkpoint: Optional weights-only .pt file to start from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
weight_decay: L2 regularization strength (default: 1e-4)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
"""
deadline = datetime.now() + timedelta(minutes=duration_minutes)
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
model = NNUE().to(device)
criterion = nn.MSELoss()
best_global_val_loss = float('inf')
if initial_checkpoint:
print(f"Loading initial weights: {initial_checkpoint}")
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
if best_global_val_loss < float('inf'):
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
else:
print("Initial weights loaded (no val loss in checkpoint).")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no val loss info).")
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
print()
burst_start_time = datetime.now()
season = 0
best_global_state = None
last_optimizer = None
last_scheduler = None
last_scaler = None
last_epoch = 0
while datetime.now() < deadline:
season += 1
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
if best_global_val_loss < float('inf'):
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
season_start_time = datetime.now()
val_loss, model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
0, epochs_per_season, early_stopping_patience,
season_start_time, deadline=deadline,
initial_best_val_loss=best_global_val_loss
)
last_optimizer = optimizer
last_scheduler = scheduler
last_scaler = scaler
if model_state is not None and val_loss < best_global_val_loss:
best_global_val_loss = val_loss
best_global_state = model_state
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
# Reload global best for the next season so we never drift backwards
if best_global_state is not None:
model.load_state_dict(best_global_state)
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
print(f"Best val loss: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
if best_global_state is None:
print("No model improvement found. No file saved.")
return
_save_versioned_model(
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
best_global_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, burst_start_time,
extra_metadata={
"mode": "burst",
"duration_minutes": duration_minutes,
"epochs_per_season": epochs_per_season,
"early_stopping_patience": early_stopping_patience,
"seasons_completed": season,
"batch_size": batch_size,
"learning_rate": lr,
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
}
)
if __name__ == "__main__":
import argparse
@@ -432,18 +591,43 @@ if __name__ == "__main__":
help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
# Burst mode
parser.add_argument("--burst-duration", type=float, default=None,
help="Enable burst mode: total training budget in minutes")
parser.add_argument("--epochs-per-season", type=int, default=50,
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
args = parser.parse_args()
if args.burst_duration is not None:
burst_train(
data_file=args.data_file,
output_file=args.output_file,
duration_minutes=args.burst_duration,
epochs_per_season=args.epochs_per_season,
early_stopping_patience=args.early_stopping or 10,
batch_size=args.batch_size,
lr=args.lr,
initial_checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
else:
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
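
Example invocations (the script path and the positional argument name are not
shown in this diff, so treat them as placeholders):

    python <script>.py training_data.jsonl --burst-duration 90 --epochs-per-season 50 --subsample-ratio 0.5

Omitting --burst-duration keeps the original single-run train_nnue() behavior.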