feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options

This commit is contained in:
2026-04-10 02:16:01 +02:00
parent 5d4cf5f13c
commit 01c8a0f8fe
14 changed files with 465 additions and 134 deletions
+92 -10
View File
@@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish from label import label_positions_with_stockfish
from train import train_nnue from train import train_nnue, burst_train
from export import export_weights_to_binary from export import export_weights_to_binary
from tactical_positions_extractor import ( from tactical_positions_extractor import (
download_and_extract_puzzle_db, download_and_extract_puzzle_db,
@@ -97,24 +97,27 @@ def show_main_menu():
console.print("\n[bold]What would you like to do?[/bold]") console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model") console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Export Weights to Scala") console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
console.print("[cyan]3[/cyan] - Extract Tactical Positions") console.print("[cyan]3[/cyan] - Export Weights to Scala")
console.print("[cyan]4[/cyan] - View Checkpoints") console.print("[cyan]4[/cyan] - Extract Tactical Positions")
console.print("[cyan]5[/cyan] - Exit") console.print("[cyan]5[/cyan] - View Checkpoints")
console.print("[cyan]6[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"]) choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
if choice == "1": if choice == "1":
train_interactive() train_interactive()
elif choice == "2": elif choice == "2":
export_interactive() burst_train_interactive()
elif choice == "3": elif choice == "3":
extract_tactical_interactive() export_interactive()
elif choice == "4": elif choice == "4":
extract_tactical_interactive()
elif choice == "5":
show_header() show_header()
show_checkpoints_table() show_checkpoints_table()
Prompt.ask("\nPress Enter to continue") Prompt.ask("\nPress Enter to continue")
elif choice == "5": elif choice == "6":
console.print("[yellow]👋 Goodbye![/yellow]") console.print("[yellow]👋 Goodbye![/yellow]")
return return
@@ -172,6 +175,7 @@ def train_interactive():
# Training parameters # Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100")) epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384")) batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
early_stopping = None early_stopping = None
if Confirm.ask("Enable early stopping?", default=False): if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5")) early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -190,6 +194,7 @@ def train_interactive():
console.print(f" Positions file: {positions_file}") console.print(f" Positions file: {positions_file}")
console.print(f" Epochs: {epochs}") console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}") console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if early_stopping: if early_stopping:
console.print(f" Early stopping: Yes (patience: {early_stopping})") console.print(f" Early stopping: Yes (patience: {early_stopping})")
else: else:
@@ -263,7 +268,8 @@ def train_interactive():
batch_size=batch_size, batch_size=batch_size,
checkpoint=checkpoint, checkpoint=checkpoint,
use_versioning=True, use_versioning=True,
early_stopping_patience=early_stopping early_stopping_patience=early_stopping,
subsample_ratio=subsample_ratio
) )
console.print("[green]✓ Training complete[/green]") console.print("[green]✓ Training complete[/green]")
@@ -280,6 +286,82 @@ def train_interactive():
traceback.print_exc() traceback.print_exc()
Prompt.ask("Press Enter to continue") Prompt.ask("Press Enter to continue")
def burst_train_interactive():
"""Interactive burst training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
# Data file
default_labels = str(get_data_dir() / "training_data.jsonl")
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
# Optional initial checkpoint
available = list_checkpoints()
checkpoint = None
if available:
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
if Confirm.ask("Start from an existing checkpoint?", default=False):
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
# Training hyperparameters
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
console.print(f" Data file: {labels_file}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if not Confirm.ask("\nStart burst training?", default=True):
console.print("[yellow]Burst training cancelled[/yellow]")
return
weights_dir = get_weights_dir()
try:
if not Path(labels_file).exists():
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold cyan]Burst Training[/bold cyan]")
burst_train(
data_file=labels_file,
output_file=str(weights_dir / "nnue_weights.pt"),
duration_minutes=duration_minutes,
epochs_per_season=epochs_per_season,
early_stopping_patience=early_stopping_patience,
batch_size=batch_size,
initial_checkpoint=checkpoint,
use_versioning=True,
subsample_ratio=subsample_ratio,
)
console.print("[green]✓ Burst training complete[/green]")
available = list_checkpoints()
if available:
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def export_interactive(): def export_interactive():
"""Interactive export menu.""" """Interactive export menu."""
console = Console() console = Console()
+308 -124
View File
@@ -10,7 +10,7 @@ import sys
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
import chess import chess
from datetime import datetime from datetime import datetime, timedelta
import re import re
import numpy as np import numpy as np
@@ -152,22 +152,12 @@ def save_metadata(weights_file, metadata):
return metadata_file return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4): def _setup_training(data_file, batch_size, subsample_ratio):
"""Train the NNUE model with GPU optimizations and automatic mixed precision. """Set up device, dataset, and data loaders.
Args: Returns:
data_file: Path to training_data.jsonl (device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
""" """
print("Checking GPU availability...") print("Checking GPU availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -184,7 +174,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print(f"Dataset size: {num_positions}") print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})") print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
# Print dataset diagnostics
evals_array = np.array(dataset.evals) evals_array = np.array(dataset.evals)
print() print()
print("=" * 60) print("=" * 60)
@@ -198,19 +187,19 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print("=" * 60) print("=" * 60)
print() print()
# Split 90% train, 10% validation
train_size = int(0.9 * len(dataset)) train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size val_size = len(dataset) - train_size
from torch.utils.data import random_split from torch.utils.data import random_split, RandomSampler
generator = torch.Generator().manual_seed(42) generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator) train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
# DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, sampler=train_sampler,
num_workers=8, num_workers=8,
pin_memory=True, pin_memory=True,
persistent_workers=True persistent_workers=True
@@ -224,54 +213,40 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
persistent_workers=True persistent_workers=True
) )
# Model return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
model = NNUE().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Cosine annealing learning rate scheduler def _run_training_season(
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
"""Run one training season until epoch limit, early stopping, or deadline.
# Mixed precision training Args:
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu') initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
toward early stopping and do not save snapshots.
start_epoch = 0 Returns:
best_val_loss = float('inf') (best_val_loss, best_model_state, last_epoch)
if checkpoint: best_model_state is None if no epoch beat initial_best_val_loss.
print(f"Loading checkpoint: {checkpoint}") """
ckpt = torch.load(checkpoint, map_location=device) best_val_loss = initial_best_val_loss
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
best_model_state = None best_model_state = None
epochs_without_improvement = 0 epochs_without_improvement = 0
total_epochs = start_epoch + epochs
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...") last_epoch = start_epoch - 1
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
for epoch in range(start_epoch, start_epoch + epochs): for epoch in range(start_epoch, start_epoch + epochs):
if deadline and datetime.now() >= deadline:
print("Time limit reached, stopping season.")
break
epoch_display = epoch + 1
# Train # Train
model.train() model.train()
train_loss = 0.0 train_loss = 0.0
epoch_display = epoch + 1
total_epochs = start_epoch + epochs
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar: with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
for batch_features, batch_targets in train_loader: for batch_features, batch_targets in train_loader:
batch_features = batch_features.to(device) batch_features = batch_features.to(device)
@@ -279,7 +254,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
optimizer.zero_grad() optimizer.zero_grad()
# Mixed precision forward and backward
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'): with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features) outputs = model(batch_features)
loss = criterion(outputs, batch_targets) loss = criterion(outputs, batch_targets)
@@ -310,25 +284,20 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
val_loss /= len(val_dataset) val_loss /= len(val_dataset)
# Update learning rate
scheduler.step() scheduler.step()
# Print GPU memory usage
if torch.cuda.is_available(): if torch.cuda.is_available():
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9 gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9 gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved") print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
# Calculate and print estimated time remaining elapsed_time = datetime.now() - season_start_time
elapsed_time = datetime.now() - training_start_time
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1) time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
remaining_epochs = total_epochs - epoch_display remaining_epochs = total_epochs - epoch_display
eta_seconds = time_per_epoch * remaining_epochs eta_seconds = time_per_epoch * remaining_epochs
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0] eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}") print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
# Save checkpoint after every epoch
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt") checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
torch.save({ torch.save({
"epoch": epoch, "epoch": epoch,
@@ -339,74 +308,264 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
"best_val_loss": best_val_loss, "best_val_loss": best_val_loss,
}, checkpoint_file) }, checkpoint_file)
if val_loss < : if val_loss < best_val_loss:
best_val_loss = val_loss best_val_loss = val_loss
best_model_state = model.state_dict().copy() best_model_state = model.state_dict().copy()
epochs_without_improvement = 0 epochs_without_improvement = 0
# Save best model snapshot
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt") snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
torch.save(best_model_state, snapshot_file) torch.save(best_model_state, snapshot_file)
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})") print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
else: else:
epochs_without_improvement += 1 epochs_without_improvement += 1
# Early stopping check last_epoch = epoch
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience: if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs") print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break break
# Save best model return best_val_loss, best_model_state, last_epoch
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time, extra_metadata=None):
"""Save the best model with optional versioning and metadata."""
final_output_file = output_file
metadata = {}
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss),
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
if extra_metadata:
metadata.update(extra_metadata)
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch,
"best_val_loss": best_val_loss,
}, final_output_file)
print(f"Best model saved to {final_output_file}")
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
for key, val in metadata.items():
print(f" {key}: {val}")
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
"""
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
model = NNUE().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
start_epoch = 0
best_val_loss = float('inf')
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if subsample_ratio < 1.0:
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
best_val_loss, best_model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
training_start_time
)
if best_model_state is None or best_val_loss >= checkpoint_val_loss: if best_model_state is None or best_val_loss >= checkpoint_val_loss:
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})") print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
print("No new model created.") print("No new model created.")
return return
if best_model_state is not None: _save_versioned_model(
# Determine final output file with versioning best_model_state, optimizer, scheduler, scaler, last_epoch,
final_output_file = output_file best_val_loss, output_file, use_versioning, num_positions,
metadata = {} stockfish_depth, training_start_time,
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
"checkpoint": str(checkpoint) if checkpoint else None}
)
if use_versioning: def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
base_name = output_file.replace(".pt", "") epochs_per_season=50, early_stopping_patience=10,
version = find_next_version(base_name) batch_size=16384, lr=0.001, initial_checkpoint=None,
final_output_file = f"{base_name}_v{version}.pt" stockfish_depth=12, use_versioning=True,
weight_decay=1e-4, subsample_ratio=1.0):
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
# Prepare metadata Each season trains with early stopping. When early stopping fires, the model reloads the
metadata = { global best weights and begins a fresh season with a reset optimizer and scheduler.
"version": version, This prevents the model from drifting away from its best known state.
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": lr,
"final_val_loss": float(best_val_loss),
"device": str(device),
"checkpoint": str(checkpoint) if checkpoint else None,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
torch.save({ Args:
"model_state_dict": best_model_state, data_file: Path to training_data.jsonl
"optimizer_state_dict": optimizer.state_dict(), output_file: Output file base name
"scheduler_state_dict": scheduler.state_dict(), duration_minutes: Total training budget in minutes
"scaler_state_dict": scaler.state_dict(), epochs_per_season: Max epochs per restart season (default: 50)
"epoch": epoch, early_stopping_patience: Patience for early stopping within each season (default: 10)
"best_val_loss": best_val_loss, batch_size: Training batch size (default: 16384)
}, final_output_file) lr: Learning rate reset to this value at the start of each season (default: 0.001)
print(f"Best model saved to {final_output_file}") initial_checkpoint: Optional weights-only .pt file to start from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
weight_decay: L2 regularization strength (default: 1e-4)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
"""
deadline = datetime.now() + timedelta(minutes=duration_minutes)
# Save metadata if versioning is enabled device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
if use_versioning and metadata: _setup_training(data_file, batch_size, subsample_ratio)
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}") model = NNUE().to(device)
print(f"\nTraining Summary:") criterion = nn.MSELoss()
print(f" Version: v{metadata['version']}") best_global_val_loss = float('inf')
print(f" Positions: {metadata['num_positions']}")
print(f" Stockfish depth: {metadata['stockfish_depth']}") if initial_checkpoint:
print(f" Epochs: {metadata['epochs']}") print(f"Loading initial weights: {initial_checkpoint}")
print(f" Final validation loss: {metadata['final_val_loss']:.6f}") ckpt = torch.load(initial_checkpoint, map_location=device)
print(f" Device: {metadata['device']}") if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
if best_global_val_loss < float('inf'):
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
else:
print("Initial weights loaded (no val loss in checkpoint).")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no val loss info).")
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
print()
burst_start_time = datetime.now()
season = 0
best_global_state = None
last_optimizer = None
last_scheduler = None
last_scaler = None
last_epoch = 0
while datetime.now() < deadline:
season += 1
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
if best_global_val_loss < float('inf'):
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
season_start_time = datetime.now()
val_loss, model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
0, epochs_per_season, early_stopping_patience,
season_start_time, deadline=deadline,
initial_best_val_loss=best_global_val_loss
)
last_optimizer = optimizer
last_scheduler = scheduler
last_scaler = scaler
if model_state is not None and val_loss < best_global_val_loss:
best_global_val_loss = val_loss
best_global_state = model_state
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
# Reload global best for the next season so we never drift backwards
if best_global_state is not None:
model.load_state_dict(best_global_state)
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
print(f"Best val loss: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
if best_global_state is None:
print("No model improvement found. No file saved.")
return
_save_versioned_model(
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
best_global_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, burst_start_time,
extra_metadata={
"mode": "burst",
"duration_minutes": duration_minutes,
"epochs_per_season": epochs_per_season,
"early_stopping_patience": early_stopping_patience,
"seasons_completed": season,
"batch_size": batch_size,
"learning_rate": lr,
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
}
)
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@@ -432,18 +591,43 @@ if __name__ == "__main__":
help="Disable automatic versioning (save directly to output file)") help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5, parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)") help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
# Burst mode
parser.add_argument("--burst-duration", type=float, default=None,
help="Enable burst mode: total training budget in minutes")
parser.add_argument("--epochs-per-season", type=int, default=50,
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
args = parser.parse_args() args = parser.parse_args()
train_nnue( if args.burst_duration is not None:
data_file=args.data_file, burst_train(
output_file=args.output_file, data_file=args.data_file,
epochs=args.epochs, output_file=args.output_file,
batch_size=args.batch_size, duration_minutes=args.burst_duration,
lr=args.lr, epochs_per_season=args.epochs_per_season,
checkpoint=args.checkpoint, early_stopping_patience=args.early_stopping or 10,
stockfish_depth=args.stockfish_depth, batch_size=args.batch_size,
use_versioning=not args.no_versioning, lr=args.lr,
early_stopping_patience=args.early_stopping, initial_checkpoint=args.checkpoint,
weight_decay=args.weight_decay stockfish_depth=args.stockfish_depth,
) use_versioning=not args.no_versioning,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
else:
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 4,
"date": "2026-04-09T00:28:07.572209",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 40,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 9.106677896235248e-05,
"device": "cuda",
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 5,
"date": "2026-04-09T18:50:27.845632",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 9.180421525105905e-05,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 6,
"date": "2026-04-09T21:28:21.000832",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2984530149085532,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 7,
"date": "2026-04-09T22:06:50.439858",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2997283308762831,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 8,
"date": "2026-04-09T22:22:47.859730",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.24803777390839968,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.