feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options
This commit is contained in:
+92
-10
@@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue
|
||||
from train import train_nnue, burst_train
|
||||
from export import export_weights_to_binary
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
@@ -97,24 +97,27 @@ def show_main_menu():
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||
console.print("[cyan]2[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]4[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]5[/cyan] - Exit")
|
||||
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
|
||||
console.print("[cyan]3[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]5[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]6[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
export_interactive()
|
||||
burst_train_interactive()
|
||||
elif choice == "3":
|
||||
extract_tactical_interactive()
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
extract_tactical_interactive()
|
||||
elif choice == "5":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "5":
|
||||
elif choice == "6":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
@@ -172,6 +175,7 @@ def train_interactive():
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
@@ -190,6 +194,7 @@ def train_interactive():
|
||||
console.print(f" Positions file: {positions_file}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
@@ -263,7 +268,8 @@ def train_interactive():
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping
|
||||
early_stopping_patience=early_stopping,
|
||||
subsample_ratio=subsample_ratio
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
@@ -280,6 +286,82 @@ def train_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def burst_train_interactive():
|
||||
"""Interactive burst training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||
|
||||
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||
|
||||
# Data file
|
||||
default_labels = str(get_data_dir() / "training_data.jsonl")
|
||||
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
|
||||
|
||||
# Optional initial checkpoint
|
||||
available = list_checkpoints()
|
||||
checkpoint = None
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
if Confirm.ask("Start from an existing checkpoint?", default=False):
|
||||
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
|
||||
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
|
||||
|
||||
# Training hyperparameters
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
console.print(f" Data file: {labels_file}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
|
||||
if not Confirm.ask("\nStart burst training?", default=True):
|
||||
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||
return
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
if not Path(labels_file).exists():
|
||||
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||
burst_train(
|
||||
data_file=labels_file,
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
duration_minutes=duration_minutes,
|
||||
epochs_per_season=epochs_per_season,
|
||||
early_stopping_patience=early_stopping_patience,
|
||||
batch_size=batch_size,
|
||||
initial_checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
subsample_ratio=subsample_ratio,
|
||||
)
|
||||
console.print("[green]✓ Burst training complete[/green]")
|
||||
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
|
||||
+308
-124
@@ -10,7 +10,7 @@ import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
@@ -152,22 +152,12 @@ def save_metadata(weights_file, metadata):
|
||||
|
||||
return metadata_file
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4):
|
||||
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||
def _setup_training(data_file, batch_size, subsample_ratio):
|
||||
"""Set up device, dataset, and data loaders.
|
||||
|
||||
Args:
|
||||
data_file: Path to training_data.jsonl
|
||||
output_file: Where to save best weights (or base name if use_versioning=True)
|
||||
epochs: Number of training epochs (default: 100)
|
||||
batch_size: Training batch size (default: 16384)
|
||||
lr: Learning rate (default: 0.001)
|
||||
checkpoint: Optional path to checkpoint file to resume from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||
Returns:
|
||||
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
|
||||
"""
|
||||
|
||||
print("Checking GPU availability...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
@@ -184,7 +174,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
print(f"Dataset size: {num_positions}")
|
||||
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||
|
||||
# Print dataset diagnostics
|
||||
evals_array = np.array(dataset.evals)
|
||||
print()
|
||||
print("=" * 60)
|
||||
@@ -198,19 +187,19 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Split 90% train, 10% validation
|
||||
train_size = int(0.9 * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
|
||||
from torch.utils.data import random_split
|
||||
from torch.utils.data import random_split, RandomSampler
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
|
||||
|
||||
# DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
sampler=train_sampler,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
@@ -224,54 +213,40 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
persistent_workers=True
|
||||
)
|
||||
|
||||
# Model
|
||||
model = NNUE().to(device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
|
||||
|
||||
# Cosine annealing learning rate scheduler
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||
def _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
start_epoch, epochs, early_stopping_patience,
|
||||
season_start_time, deadline=None, initial_best_val_loss=float('inf')
|
||||
):
|
||||
"""Run one training season until epoch limit, early stopping, or deadline.
|
||||
|
||||
# Mixed precision training
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
start_epoch = 0
|
||||
best_val_loss = float('inf')
|
||||
if checkpoint:
|
||||
print(f"Loading checkpoint: {checkpoint}")
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
best_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no optimizer state)")
|
||||
|
||||
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||
Args:
|
||||
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
|
||||
toward early stopping and do not save snapshots.
|
||||
Returns:
|
||||
(best_val_loss, best_model_state, last_epoch)
|
||||
best_model_state is None if no epoch beat initial_best_val_loss.
|
||||
"""
|
||||
best_val_loss = initial_best_val_loss
|
||||
best_model_state = None
|
||||
epochs_without_improvement = 0
|
||||
|
||||
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||
print(f"Mixed precision training: enabled")
|
||||
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
|
||||
if early_stopping_patience:
|
||||
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||
print()
|
||||
|
||||
training_start_time = datetime.now()
|
||||
total_epochs = start_epoch + epochs
|
||||
last_epoch = start_epoch - 1
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + epochs):
|
||||
if deadline and datetime.now() >= deadline:
|
||||
print("Time limit reached, stopping season.")
|
||||
break
|
||||
|
||||
epoch_display = epoch + 1
|
||||
|
||||
# Train
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
epoch_display = epoch + 1
|
||||
total_epochs = start_epoch + epochs
|
||||
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
|
||||
for batch_features, batch_targets in train_loader:
|
||||
batch_features = batch_features.to(device)
|
||||
@@ -279,7 +254,6 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Mixed precision forward and backward
|
||||
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
outputs = model(batch_features)
|
||||
loss = criterion(outputs, batch_targets)
|
||||
@@ -310,25 +284,20 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
|
||||
val_loss /= len(val_dataset)
|
||||
|
||||
# Update learning rate
|
||||
scheduler.step()
|
||||
|
||||
# Print GPU memory usage
|
||||
if torch.cuda.is_available():
|
||||
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
|
||||
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
|
||||
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
|
||||
|
||||
# Calculate and print estimated time remaining
|
||||
elapsed_time = datetime.now() - training_start_time
|
||||
elapsed_time = datetime.now() - season_start_time
|
||||
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
|
||||
remaining_epochs = total_epochs - epoch_display
|
||||
eta_seconds = time_per_epoch * remaining_epochs
|
||||
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
|
||||
|
||||
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
|
||||
|
||||
# Save checkpoint after every epoch
|
||||
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
@@ -339,74 +308,264 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
||||
"best_val_loss": best_val_loss,
|
||||
}, checkpoint_file)
|
||||
|
||||
if val_loss < :
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_state = model.state_dict().copy()
|
||||
epochs_without_improvement = 0
|
||||
# Save best model snapshot
|
||||
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
|
||||
torch.save(best_model_state, snapshot_file)
|
||||
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
|
||||
# Early stopping check
|
||||
last_epoch = epoch
|
||||
|
||||
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||
break
|
||||
|
||||
# Save best model
|
||||
return best_val_loss, best_model_state, last_epoch
|
||||
|
||||
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time, extra_metadata=None):
|
||||
"""Save the best model with optional versioning and metadata."""
|
||||
final_output_file = output_file
|
||||
metadata = {}
|
||||
|
||||
if use_versioning:
|
||||
base_name = output_file.replace(".pt", "")
|
||||
version = find_next_version(base_name)
|
||||
final_output_file = f"{base_name}_v{version}.pt"
|
||||
|
||||
metadata = {
|
||||
"version": version,
|
||||
"date": training_start_time.isoformat(),
|
||||
"num_positions": num_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"final_val_loss": float(best_val_loss),
|
||||
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
torch.save({
|
||||
"model_state_dict": best_model_state,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"epoch": last_epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
}, final_output_file)
|
||||
print(f"Best model saved to {final_output_file}")
|
||||
|
||||
if use_versioning and metadata:
|
||||
metadata_file = save_metadata(final_output_file, metadata)
|
||||
print(f"Metadata saved to {metadata_file}")
|
||||
print(f"\nTraining Summary:")
|
||||
for key, val in metadata.items():
|
||||
print(f" {key}: {val}")
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
|
||||
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||
|
||||
Args:
|
||||
data_file: Path to training_data.jsonl
|
||||
output_file: Where to save best weights (or base name if use_versioning=True)
|
||||
epochs: Number of training epochs (default: 100)
|
||||
batch_size: Training batch size (default: 16384)
|
||||
lr: Learning rate (default: 0.001)
|
||||
checkpoint: Optional path to checkpoint file to resume from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
||||
"""
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
model = NNUE().to(device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
start_epoch = 0
|
||||
best_val_loss = float('inf')
|
||||
if checkpoint:
|
||||
print(f"Loading checkpoint: {checkpoint}")
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
best_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no optimizer state)")
|
||||
|
||||
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||
print(f"Mixed precision training: enabled")
|
||||
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
|
||||
if subsample_ratio < 1.0:
|
||||
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
|
||||
if early_stopping_patience:
|
||||
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||
print()
|
||||
|
||||
training_start_time = datetime.now()
|
||||
|
||||
best_val_loss, best_model_state, last_epoch = _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
start_epoch, epochs, early_stopping_patience,
|
||||
training_start_time
|
||||
)
|
||||
|
||||
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
|
||||
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
|
||||
print("No new model created.")
|
||||
return
|
||||
|
||||
if best_model_state is not None:
|
||||
# Determine final output file with versioning
|
||||
final_output_file = output_file
|
||||
metadata = {}
|
||||
_save_versioned_model(
|
||||
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time,
|
||||
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
||||
"checkpoint": str(checkpoint) if checkpoint else None}
|
||||
)
|
||||
|
||||
if use_versioning:
|
||||
base_name = output_file.replace(".pt", "")
|
||||
version = find_next_version(base_name)
|
||||
final_output_file = f"{base_name}_v{version}.pt"
|
||||
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
epochs_per_season=50, early_stopping_patience=10,
|
||||
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
||||
stockfish_depth=12, use_versioning=True,
|
||||
weight_decay=1e-4, subsample_ratio=1.0):
|
||||
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
||||
|
||||
# Prepare metadata
|
||||
metadata = {
|
||||
"version": version,
|
||||
"date": training_start_time.isoformat(),
|
||||
"num_positions": num_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"epochs": epochs,
|
||||
"batch_size": batch_size,
|
||||
"learning_rate": lr,
|
||||
"final_val_loss": float(best_val_loss),
|
||||
"device": str(device),
|
||||
"checkpoint": str(checkpoint) if checkpoint else None,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Each season trains with early stopping. When early stopping fires, the model reloads the
|
||||
global best weights and begins a fresh season with a reset optimizer and scheduler.
|
||||
This prevents the model from drifting away from its best known state.
|
||||
|
||||
torch.save({
|
||||
"model_state_dict": best_model_state,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"epoch": epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
}, final_output_file)
|
||||
print(f"Best model saved to {final_output_file}")
|
||||
Args:
|
||||
data_file: Path to training_data.jsonl
|
||||
output_file: Output file base name
|
||||
duration_minutes: Total training budget in minutes
|
||||
epochs_per_season: Max epochs per restart season (default: 50)
|
||||
early_stopping_patience: Patience for early stopping within each season (default: 10)
|
||||
batch_size: Training batch size (default: 16384)
|
||||
lr: Learning rate reset to this value at the start of each season (default: 0.001)
|
||||
initial_checkpoint: Optional weights-only .pt file to start from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
weight_decay: L2 regularization strength (default: 1e-4)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
||||
"""
|
||||
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
||||
|
||||
# Save metadata if versioning is enabled
|
||||
if use_versioning and metadata:
|
||||
metadata_file = save_metadata(final_output_file, metadata)
|
||||
print(f"Metadata saved to {metadata_file}")
|
||||
print(f"\nTraining Summary:")
|
||||
print(f" Version: v{metadata['version']}")
|
||||
print(f" Positions: {metadata['num_positions']}")
|
||||
print(f" Stockfish depth: {metadata['stockfish_depth']}")
|
||||
print(f" Epochs: {metadata['epochs']}")
|
||||
print(f" Final validation loss: {metadata['final_val_loss']:.6f}")
|
||||
print(f" Device: {metadata['device']}")
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
model = NNUE().to(device)
|
||||
criterion = nn.MSELoss()
|
||||
best_global_val_loss = float('inf')
|
||||
|
||||
if initial_checkpoint:
|
||||
print(f"Loading initial weights: {initial_checkpoint}")
|
||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||
if best_global_val_loss < float('inf'):
|
||||
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
|
||||
else:
|
||||
print("Initial weights loaded (no val loss in checkpoint).")
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no val loss info).")
|
||||
|
||||
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
||||
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
||||
print()
|
||||
|
||||
burst_start_time = datetime.now()
|
||||
season = 0
|
||||
best_global_state = None
|
||||
last_optimizer = None
|
||||
last_scheduler = None
|
||||
last_scaler = None
|
||||
last_epoch = 0
|
||||
|
||||
while datetime.now() < deadline:
|
||||
season += 1
|
||||
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
|
||||
if best_global_val_loss < float('inf'):
|
||||
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
season_start_time = datetime.now()
|
||||
val_loss, model_state, last_epoch = _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
0, epochs_per_season, early_stopping_patience,
|
||||
season_start_time, deadline=deadline,
|
||||
initial_best_val_loss=best_global_val_loss
|
||||
)
|
||||
|
||||
last_optimizer = optimizer
|
||||
last_scheduler = scheduler
|
||||
last_scaler = scaler
|
||||
|
||||
if model_state is not None and val_loss < best_global_val_loss:
|
||||
best_global_val_loss = val_loss
|
||||
best_global_state = model_state
|
||||
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
|
||||
|
||||
# Reload global best for the next season so we never drift backwards
|
||||
if best_global_state is not None:
|
||||
model.load_state_dict(best_global_state)
|
||||
|
||||
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
|
||||
print(f"Best val loss: {best_global_val_loss:.6f}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
if best_global_state is None:
|
||||
print("No model improvement found. No file saved.")
|
||||
return
|
||||
|
||||
_save_versioned_model(
|
||||
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
||||
best_global_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, burst_start_time,
|
||||
extra_metadata={
|
||||
"mode": "burst",
|
||||
"duration_minutes": duration_minutes,
|
||||
"epochs_per_season": epochs_per_season,
|
||||
"early_stopping_patience": early_stopping_patience,
|
||||
"seasons_completed": season,
|
||||
"batch_size": batch_size,
|
||||
"learning_rate": lr,
|
||||
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
|
||||
}
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
@@ -432,18 +591,43 @@ if __name__ == "__main__":
|
||||
help="Disable automatic versioning (save directly to output file)")
|
||||
parser.add_argument("--weight-decay", type=float, default=5e-5,
|
||||
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
||||
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
||||
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
||||
|
||||
# Burst mode
|
||||
parser.add_argument("--burst-duration", type=float, default=None,
|
||||
help="Enable burst mode: total training budget in minutes")
|
||||
parser.add_argument("--epochs-per-season", type=int, default=50,
|
||||
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
train_nnue(
|
||||
data_file=args.data_file,
|
||||
output_file=args.output_file,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning,
|
||||
early_stopping_patience=args.early_stopping,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
if args.burst_duration is not None:
|
||||
burst_train(
|
||||
data_file=args.data_file,
|
||||
output_file=args.output_file,
|
||||
duration_minutes=args.burst_duration,
|
||||
epochs_per_season=args.epochs_per_season,
|
||||
early_stopping_patience=args.early_stopping or 10,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
initial_checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
)
|
||||
else:
|
||||
train_nnue(
|
||||
data_file=args.data_file,
|
||||
output_file=args.output_file,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning,
|
||||
early_stopping_patience=args.early_stopping,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 4,
|
||||
"date": "2026-04-09T00:28:07.572209",
|
||||
"num_positions": 2009355,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 40,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 9.106677896235248e-05,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 5,
|
||||
"date": "2026-04-09T18:50:27.845632",
|
||||
"num_positions": 2009355,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 9.180421525105905e-05,
|
||||
"device": "cuda",
|
||||
"checkpoint": null,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 6,
|
||||
"date": "2026-04-09T21:28:21.000832",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.2984530149085532,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 7,
|
||||
"date": "2026-04-09T22:06:50.439858",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.2997283308762831,
|
||||
"device": "cuda",
|
||||
"checkpoint": null,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 8,
|
||||
"date": "2026-04-09T22:22:47.859730",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.24803777390839968,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
Reference in New Issue
Block a user