feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options

This commit is contained in:
2026-04-10 02:16:01 +02:00
parent 5d4cf5f13c
commit 01c8a0f8fe
14 changed files with 465 additions and 134 deletions
+92 -10
View File
@@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from train import train_nnue, burst_train
from export import export_weights_to_binary
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
@@ -97,24 +97,27 @@ def show_main_menu():
console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Export Weights to Scala")
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
console.print("[cyan]4[/cyan] - View Checkpoints")
console.print("[cyan]5[/cyan] - Exit")
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
console.print("[cyan]3[/cyan] - Export Weights to Scala")
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
console.print("[cyan]5[/cyan] - View Checkpoints")
console.print("[cyan]6[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
if choice == "1":
train_interactive()
elif choice == "2":
export_interactive()
burst_train_interactive()
elif choice == "3":
extract_tactical_interactive()
export_interactive()
elif choice == "4":
extract_tactical_interactive()
elif choice == "5":
show_header()
show_checkpoints_table()
Prompt.ask("\nPress Enter to continue")
elif choice == "5":
elif choice == "6":
console.print("[yellow]👋 Goodbye![/yellow]")
return
@@ -172,6 +175,7 @@ def train_interactive():
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -190,6 +194,7 @@ def train_interactive():
console.print(f" Positions file: {positions_file}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if early_stopping:
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
@@ -263,7 +268,8 @@ def train_interactive():
batch_size=batch_size,
checkpoint=checkpoint,
use_versioning=True,
early_stopping_patience=early_stopping
early_stopping_patience=early_stopping,
subsample_ratio=subsample_ratio
)
console.print("[green]✓ Training complete[/green]")
@@ -280,6 +286,82 @@ def train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def burst_train_interactive():
"""Interactive burst training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
# Data file
default_labels = str(get_data_dir() / "training_data.jsonl")
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
# Optional initial checkpoint
available = list_checkpoints()
checkpoint = None
if available:
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
if Confirm.ask("Start from an existing checkpoint?", default=False):
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
# Training hyperparameters
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
console.print(f" Data file: {labels_file}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if not Confirm.ask("\nStart burst training?", default=True):
console.print("[yellow]Burst training cancelled[/yellow]")
return
weights_dir = get_weights_dir()
try:
if not Path(labels_file).exists():
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold cyan]Burst Training[/bold cyan]")
burst_train(
data_file=labels_file,
output_file=str(weights_dir / "nnue_weights.pt"),
duration_minutes=duration_minutes,
epochs_per_season=epochs_per_season,
early_stopping_patience=early_stopping_patience,
batch_size=batch_size,
initial_checkpoint=checkpoint,
use_versioning=True,
subsample_ratio=subsample_ratio,
)
console.print("[green]✓ Burst training complete[/green]")
available = list_checkpoints()
if available:
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def export_interactive():
"""Interactive export menu."""
console = Console()
+308 -124
View File
@@ -10,7 +10,7 @@ import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime
from datetime import datetime, timedelta
import re
import numpy as np
@@ -152,22 +152,12 @@ def save_metadata(weights_file, metadata):
return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
def _setup_training(data_file, batch_size, subsample_ratio):
"""Set up device, dataset, and data loaders.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
Returns:
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
"""
print("Checking GPU availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
@@ -184,7 +174,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
# Print dataset diagnostics
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
@@ -198,19 +187,19 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
print("=" * 60)
print()
# Split 90% train, 10% validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
from torch.utils.data import random_split
from torch.utils.data import random_split, RandomSampler
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
# DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
sampler=train_sampler,
num_workers=8,
pin_memory=True,
persistent_workers=True
@@ -224,54 +213,40 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
persistent_workers=True
)
# Model
model = NNUE().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
# Cosine annealing learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
def _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
"""Run one training season until epoch limit, early stopping, or deadline.
# Mixed precision training
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
start_epoch = 0
best_val_loss = float('inf')
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
Args:
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
toward early stopping and do not save snapshots.
Returns:
(best_val_loss, best_model_state, last_epoch)
best_model_state is None if no epoch beat initial_best_val_loss.
"""
best_val_loss = initial_best_val_loss
best_model_state = None
epochs_without_improvement = 0
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
total_epochs = start_epoch + epochs
last_epoch = start_epoch - 1
for epoch in range(start_epoch, start_epoch + epochs):
if deadline and datetime.now() >= deadline:
print("Time limit reached, stopping season.")
break
epoch_display = epoch + 1
# Train
model.train()
train_loss = 0.0
epoch_display = epoch + 1
total_epochs = start_epoch + epochs
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
for batch_features, batch_targets in train_loader:
batch_features = batch_features.to(device)
@@ -279,7 +254,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
optimizer.zero_grad()
# Mixed precision forward and backward
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
@@ -310,25 +284,20 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
val_loss /= len(val_dataset)
# Update learning rate
scheduler.step()
# Print GPU memory usage
if torch.cuda.is_available():
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
# Calculate and print estimated time remaining
elapsed_time = datetime.now() - training_start_time
elapsed_time = datetime.now() - season_start_time
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
remaining_epochs = total_epochs - epoch_display
eta_seconds = time_per_epoch * remaining_epochs
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
# Save checkpoint after every epoch
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
torch.save({
"epoch": epoch,
@@ -339,74 +308,264 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
"best_val_loss": best_val_loss,
}, checkpoint_file)
if val_loss < :
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = model.state_dict().copy()
epochs_without_improvement = 0
# Save best model snapshot
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
torch.save(best_model_state, snapshot_file)
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
else:
epochs_without_improvement += 1
# Early stopping check
last_epoch = epoch
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break
# Save best model
return best_val_loss, best_model_state, last_epoch
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time, extra_metadata=None):
"""Save the best model with optional versioning and metadata."""
final_output_file = output_file
metadata = {}
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss),
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
if extra_metadata:
metadata.update(extra_metadata)
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch,
"best_val_loss": best_val_loss,
}, final_output_file)
print(f"Best model saved to {final_output_file}")
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
for key, val in metadata.items():
print(f" {key}: {val}")
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
"""
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
model = NNUE().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
start_epoch = 0
best_val_loss = float('inf')
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if subsample_ratio < 1.0:
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
best_val_loss, best_model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
training_start_time
)
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
print("No new model created.")
return
if best_model_state is not None:
# Determine final output file with versioning
final_output_file = output_file
metadata = {}
_save_versioned_model(
best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time,
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
"checkpoint": str(checkpoint) if checkpoint else None}
)
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
epochs_per_season=50, early_stopping_patience=10,
batch_size=16384, lr=0.001, initial_checkpoint=None,
stockfish_depth=12, use_versioning=True,
weight_decay=1e-4, subsample_ratio=1.0):
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
# Prepare metadata
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": lr,
"final_val_loss": float(best_val_loss),
"device": str(device),
"checkpoint": str(checkpoint) if checkpoint else None,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Each season trains with early stopping. When early stopping fires, the model reloads the
global best weights and begins a fresh season with a reset optimizer and scheduler.
This prevents the model from drifting away from its best known state.
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": epoch,
"best_val_loss": best_val_loss,
}, final_output_file)
print(f"Best model saved to {final_output_file}")
Args:
data_file: Path to training_data.jsonl
output_file: Output file base name
duration_minutes: Total training budget in minutes
epochs_per_season: Max epochs per restart season (default: 50)
early_stopping_patience: Patience for early stopping within each season (default: 10)
batch_size: Training batch size (default: 16384)
lr: Learning rate reset to this value at the start of each season (default: 0.001)
initial_checkpoint: Optional weights-only .pt file to start from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
weight_decay: L2 regularization strength (default: 1e-4)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
"""
deadline = datetime.now() + timedelta(minutes=duration_minutes)
# Save metadata if versioning is enabled
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
print(f" Version: v{metadata['version']}")
print(f" Positions: {metadata['num_positions']}")
print(f" Stockfish depth: {metadata['stockfish_depth']}")
print(f" Epochs: {metadata['epochs']}")
print(f" Final validation loss: {metadata['final_val_loss']:.6f}")
print(f" Device: {metadata['device']}")
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
model = NNUE().to(device)
criterion = nn.MSELoss()
best_global_val_loss = float('inf')
if initial_checkpoint:
print(f"Loading initial weights: {initial_checkpoint}")
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
if best_global_val_loss < float('inf'):
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
else:
print("Initial weights loaded (no val loss in checkpoint).")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no val loss info).")
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
print()
burst_start_time = datetime.now()
season = 0
best_global_state = None
last_optimizer = None
last_scheduler = None
last_scaler = None
last_epoch = 0
while datetime.now() < deadline:
season += 1
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
if best_global_val_loss < float('inf'):
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
season_start_time = datetime.now()
val_loss, model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
0, epochs_per_season, early_stopping_patience,
season_start_time, deadline=deadline,
initial_best_val_loss=best_global_val_loss
)
last_optimizer = optimizer
last_scheduler = scheduler
last_scaler = scaler
if model_state is not None and val_loss < best_global_val_loss:
best_global_val_loss = val_loss
best_global_state = model_state
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
# Reload global best for the next season so we never drift backwards
if best_global_state is not None:
model.load_state_dict(best_global_state)
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
print(f"Best val loss: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
if best_global_state is None:
print("No model improvement found. No file saved.")
return
_save_versioned_model(
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
best_global_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, burst_start_time,
extra_metadata={
"mode": "burst",
"duration_minutes": duration_minutes,
"epochs_per_season": epochs_per_season,
"early_stopping_patience": early_stopping_patience,
"seasons_completed": season,
"batch_size": batch_size,
"learning_rate": lr,
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
}
)
if __name__ == "__main__":
import argparse
@@ -432,18 +591,43 @@ if __name__ == "__main__":
help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
# Burst mode
parser.add_argument("--burst-duration", type=float, default=None,
help="Enable burst mode: total training budget in minutes")
parser.add_argument("--epochs-per-season", type=int, default=50,
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
args = parser.parse_args()
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay
)
if args.burst_duration is not None:
burst_train(
data_file=args.data_file,
output_file=args.output_file,
duration_minutes=args.burst_duration,
epochs_per_season=args.epochs_per_season,
early_stopping_patience=args.early_stopping or 10,
batch_size=args.batch_size,
lr=args.lr,
initial_checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
else:
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
)
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 4,
"date": "2026-04-09T00:28:07.572209",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 40,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 9.106677896235248e-05,
"device": "cuda",
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 5,
"date": "2026-04-09T18:50:27.845632",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 9.180421525105905e-05,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 6,
"date": "2026-04-09T21:28:21.000832",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2984530149085532,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 7,
"date": "2026-04-09T22:06:50.439858",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2997283308762831,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 8,
"date": "2026-04-09T22:22:47.859730",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.24803777390839968,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.