Co-authored-by: Janis <janis@nowchess.de>
Reviewed-on: #33
Co-authored-by: Janis <janis.e.20@gmx.de>
Co-committed-by: Janis <janis.e.20@gmx.de>
@@ -0,0 +1,676 @@
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""

import json
import re
from datetime import datetime, timedelta
from pathlib import Path

import chess
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split
from tqdm import tqdm


class NNUEDataset(Dataset):
    """Dataset of chess positions with evaluations."""

    def __init__(self, data_file):
        self.positions = []
        self.evals = []
        self.evals_raw = []
        self.is_normalized = None

        with open(data_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    fen = data['fen']
                    eval_val = data['eval']
                    self.positions.append(fen)
                    self.evals.append(eval_val)

                    # Decide once, from the first parsed line, whether the
                    # data is normalized: an eval in [-1, 1] is assumed
                    # normalized, anything larger is treated as centipawns
                    if self.is_normalized is None:
                        self.is_normalized = abs(eval_val) <= 1.0

                    # Store raw eval if available
                    if 'eval_raw' in data:
                        self.evals_raw.append(data['eval_raw'])
                    else:
                        self.evals_raw.append(eval_val)
                except (json.JSONDecodeError, KeyError):
                    # Skip malformed or incomplete lines
                    pass

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        fen = self.positions[idx]
        eval_val = self.evals[idx]
        features = fen_to_features(fen)

        # Use evaluation as-is if normalized, otherwise squash raw centipawns
        # through a sigmoid (the /400 scale is a common centipawn convention)
        if self.is_normalized:
            target = torch.tensor(eval_val, dtype=torch.float32)
        else:
            target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))

        return features, target
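

# A worked example of the target scaling in NNUEDataset.__getitem__ above
# (illustrative): a raw Stockfish score of +100 centipawns becomes
# sigmoid(100 / 400) = sigmoid(0.25) ≈ 0.562, so a modest advantage maps to
# a target just above 0.5, while already-normalized data is used unchanged.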


def fen_to_features(fen):
    """Convert FEN to 768-dimensional binary feature vector."""
    # Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}

    features = torch.zeros(768, dtype=torch.float32)

    try:
        board = chess.Board(fen)

        # 12 piece types × 64 squares = 768
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                piece_char = piece.symbol()
                if piece_char in piece_to_idx:
                    piece_idx = piece_to_idx[piece_char]
                    feature_idx = piece_idx * 64 + square
                    features[feature_idx] = 1.0
    except ValueError:
        # Invalid FEN: fall through and return the all-zero feature vector
        pass

    return features
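

# Example of the indexing scheme (illustrative): a white pawn 'P' has
# piece_idx 6, and e2 is square 12 in python-chess numbering (a1 = 0,
# h8 = 63), so the active feature is 6 * 64 + 12 = 396:
#
#   fen_to_features(chess.STARTING_FEN)[396]  # tensor(1.)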


DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]


class NNUE(nn.Module):
    """NNUE neural network with configurable hidden layers.

    Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
    Layer attributes follow the naming l1, l2, ..., lN so export.py can
    infer the architecture directly from the state_dict.
    """

    def __init__(self, hidden_sizes=None, dropout_rate=0.2):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = DEFAULT_HIDDEN_SIZES
        self.hidden_sizes = list(hidden_sizes)
        sizes = [768] + self.hidden_sizes + [1]
        num_hidden = len(self.hidden_sizes)

        for i in range(num_hidden):
            setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
            setattr(self, f"relu{i + 1}", nn.ReLU())
            setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
        setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
        self._num_hidden = num_hidden

    def forward(self, x):
        for i in range(1, self._num_hidden + 1):
            layer = getattr(self, f"l{i}")
            relu = getattr(self, f"relu{i}")
            drop = getattr(self, f"drop{i}")
            x = drop(relu(layer(x)))
        return getattr(self, f"l{self._num_hidden + 1}")(x)
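

# A sketch of how a consumer like export.py could recover the architecture
# from the l1..lN naming convention documented above. This is an assumption
# about export.py's approach, shown for illustration only:
#
#   sd = NNUE().state_dict()
#   n = max(int(k[1:].split('.')[0]) for k in sd if k.startswith('l'))
#   sizes = [sd['l1.weight'].shape[1]] + [sd[f'l{i}.weight'].shape[0]
#                                         for i in range(1, n + 1)]
#   # sizes == [768, 1536, 1024, 512, 256, 1]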


def find_next_version(base_name="nnue_weights"):
    """Find the next version number for model versioning.

    Looks for nnue_weights_v*.pt files and returns the next version number.
    If no versioned files exist, returns 1.
    """
    base_path = Path(base_name)
    directory = base_path.parent
    filename = base_path.name

    pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
    versions = []

    for file in directory.glob(f"{filename}_v*.pt"):
        match = pattern.match(file.name)
        if match:
            versions.append(int(match.group(1)))

    if versions:
        return max(versions) + 1
    return 1
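

# Example (illustrative): with nnue_weights_v1.pt and nnue_weights_v3.pt
# already on disk, find_next_version("nnue_weights") returns 4; with no
# versioned files present it returns 1, so the first save becomes
# nnue_weights_v1.pt.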


def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info
    """
    metadata_file = weights_file.replace(".pt", "_metadata.json")

    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)

    return metadata_file


def _setup_training(data_file, batch_size, subsample_ratio):
    """Set up device, dataset, and data loaders.

    Returns:
        (device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
    """
    print("Checking GPU availability...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
        print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        print("⚠ GPU not available, using CPU")
    print(f"Using device: {device}")
    print()

    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")
    print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'}")

    evals_array = np.array(dataset.evals)
    print()
    print("=" * 60)
    print("TRAINING DATASET DIAGNOSTICS")
    print("=" * 60)
    print(f"Min evaluation: {evals_array.min():.4f}")
    print(f"Max evaluation: {evals_array.max():.4f}")
    print(f"Mean evaluation: {evals_array.mean():.4f}")
    print(f"Median evaluation: {np.median(evals_array):.4f}")
    print(f"Std deviation: {evals_array.std():.4f}")
    print("=" * 60)
    print()

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

    # Sample a (possibly strict) subset of the training split each epoch
    subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
    train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True
    )

    return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
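

# Note on the subsampling above (illustrative numbers): RandomSampler with
# replacement=False and num_samples set draws a fresh random subset each
# epoch, so subsample_ratio=0.25 on a 1,000,000-position train split visits
# 250,000 distinct positions per epoch, re-drawn every epoch.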


def _run_training_season(
    model, optimizer, scheduler, scaler,
    train_loader, val_loader, train_dataset, val_dataset,
    device, criterion, output_file,
    start_epoch, epochs, early_stopping_patience,
    season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
    """Run one training season until epoch limit, early stopping, or deadline.

    Args:
        initial_best_val_loss: Baseline to beat; epochs that don't improve on this
            count toward early stopping and do not save snapshots.

    Returns:
        (best_val_loss, best_model_state, last_epoch)
        best_model_state is None if no epoch beat initial_best_val_loss.
    """
    best_val_loss = initial_best_val_loss
    best_model_state = None
    epochs_without_improvement = 0
    total_epochs = start_epoch + epochs
    last_epoch = start_epoch - 1

    for epoch in range(start_epoch, start_epoch + epochs):
        if deadline and datetime.now() >= deadline:
            print("Time limit reached, stopping season.")
            break

        epoch_display = epoch + 1

        # Train
        model.train()
        train_loss = 0.0
        train_samples = 0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)

                optimizer.zero_grad()

                with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item() * batch_features.size(0)
                train_samples += batch_features.size(0)
                pbar.update(1)

        # Average over the samples actually seen (the sampler may subsample)
        train_loss /= max(train_samples, 1)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)

                    with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
                        outputs = model(batch_features)
                        loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)

        val_loss /= len(val_dataset)

        scheduler.step()

        if torch.cuda.is_available():
            gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
            gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
            print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")

        elapsed_time = datetime.now() - season_start_time
        # Base the estimate on epochs completed this season, not the absolute epoch index
        time_per_epoch = elapsed_time.total_seconds() / (epoch - start_epoch + 1)
        remaining_epochs = total_epochs - epoch_display
        eta_seconds = time_per_epoch * remaining_epochs
        eta_str = str(timedelta(seconds=int(eta_seconds)))
        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")

        checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_val_loss": best_val_loss,
            "hidden_sizes": model.hidden_sizes,
        }, checkpoint_file)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone tensors: a shallow .copy() would keep views into the live
            # parameters, which later epochs would silently overwrite
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
            snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
            torch.save(best_model_state, snapshot_file)
            print(f"  Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
        else:
            epochs_without_improvement += 1

        last_epoch = epoch

        if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
            break

    return best_val_loss, best_model_state, last_epoch


def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
                          best_val_loss, output_file, use_versioning, num_positions,
                          stockfish_depth, training_start_time, hidden_sizes=None,
                          extra_metadata=None):
    """Save the best model with optional versioning and metadata."""
    final_output_file = output_file
    metadata = {}
    architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]

    if use_versioning:
        base_name = output_file.replace(".pt", "")
        version = find_next_version(base_name)
        final_output_file = f"{base_name}_v{version}.pt"

        metadata = {
            "version": version,
            "date": training_start_time.isoformat(),
            "num_positions": num_positions,
            "stockfish_depth": stockfish_depth,
            "final_val_loss": float(best_val_loss),
            "architecture": architecture,
            "device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
            "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
        }
        if extra_metadata:
            metadata.update(extra_metadata)

    torch.save({
        "model_state_dict": best_model_state,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "epoch": last_epoch,
        "best_val_loss": best_val_loss,
        "hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
    }, final_output_file)
    print(f"Best model saved to {final_output_file}")

    if use_versioning and metadata:
        metadata_file = save_metadata(final_output_file, metadata)
        print(f"Metadata saved to {metadata_file}")
        print("\nTraining Summary:")
        for key, val in metadata.items():
            print(f"  {key}: {val}")


def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384,
               lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True,
               early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0,
               hidden_sizes=None):
    """Train the NNUE model with GPU optimizations and automatic mixed precision.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs (default: 100)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate (default: 0.001)
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
        weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
    """
    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    start_epoch = 0
    best_val_loss = float('inf')
    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)

    # Load the checkpoint once; first to resolve the architecture, then
    # (after the model is built) to restore the training state
    ckpt = None
    if checkpoint:
        print(f"Loading checkpoint: {checkpoint}")
        ckpt = torch.load(checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            ckpt_hidden = ckpt.get("hidden_sizes")
            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
                resolved_hidden_sizes = ckpt_hidden

    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

    if ckpt is not None:
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            start_epoch = ckpt["epoch"] + 1
            best_val_loss = ckpt.get("best_val_loss", float('inf'))
            print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no optimizer state)")

    checkpoint_val_loss = best_val_loss if checkpoint else float('inf')

    subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
    print(f"Architecture: {arch_str}")
    print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
    print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
    print("Mixed precision training: enabled")
    print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
    if subsample_ratio < 1.0:
        print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
    if early_stopping_patience:
        print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
    print()

    training_start_time = datetime.now()

    # Pass the checkpoint's val loss as the baseline so snapshots and early
    # stopping are measured against it, as _run_training_season documents
    best_val_loss, best_model_state, last_epoch = _run_training_season(
        model, optimizer, scheduler, scaler,
        train_loader, val_loader, train_dataset, val_dataset,
        device, criterion, output_file,
        start_epoch, epochs, early_stopping_patience,
        training_start_time,
        initial_best_val_loss=checkpoint_val_loss
    )

    if best_model_state is None or best_val_loss >= checkpoint_val_loss:
        print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
        print("No new model created.")
        return

    _save_versioned_model(
        best_model_state, optimizer, scheduler, scaler, last_epoch,
        best_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, training_start_time,
        hidden_sizes=resolved_hidden_sizes,
        extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                        "checkpoint": str(checkpoint) if checkpoint else None}
    )


def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
                epochs_per_season=50, early_stopping_patience=10,
                batch_size=16384, lr=0.001, initial_checkpoint=None,
                stockfish_depth=12, use_versioning=True,
                weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
    """Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.

    Each season trains with early stopping. When early stopping fires, the model reloads the
    global best weights and begins a fresh season with a reset optimizer and scheduler.
    This prevents the model from drifting away from its best known state.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Output file base name
        duration_minutes: Total training budget in minutes
        epochs_per_season: Max epochs per restart season (default: 50)
        early_stopping_patience: Patience for early stopping within each season (default: 10)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate reset to this value at the start of each season (default: 0.001)
        initial_checkpoint: Optional weights-only .pt file to start from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        weight_decay: L2 regularization strength (default: 1e-4)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
    """
    deadline = datetime.now() + timedelta(minutes=duration_minutes)

    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)

    # Load the checkpoint once; resolve the architecture before building the model
    ckpt = None
    if initial_checkpoint:
        print(f"Loading initial weights: {initial_checkpoint}")
        ckpt = torch.load(initial_checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            ckpt_hidden = ckpt.get("hidden_sizes")
            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
                print(f"  Using architecture from checkpoint: {ckpt_hidden}")
                resolved_hidden_sizes = ckpt_hidden

    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
    criterion = nn.MSELoss()
    best_global_val_loss = float('inf')

    if ckpt is not None:
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
            if best_global_val_loss < float('inf'):
                print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
            else:
                print("Initial weights loaded (no val loss in checkpoint).")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no val loss info).")

    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
    print(f"Architecture: {arch_str}")
    print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
    print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
    print()

    burst_start_time = datetime.now()
    season = 0
    best_global_state = None
    last_optimizer = None
    last_scheduler = None
    last_scaler = None
    last_epoch = 0

    while datetime.now() < deadline:
        season += 1
        remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
        print(f"\n{'=' * 60}")
        print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
        if best_global_val_loss < float('inf'):
            print(f"Global best val loss so far: {best_global_val_loss:.6f}")
        print(f"{'=' * 60}\n")

        # Fresh optimizer, scheduler, and scaler each season
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
        scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

        season_start_time = datetime.now()
        val_loss, model_state, last_epoch = _run_training_season(
            model, optimizer, scheduler, scaler,
            train_loader, val_loader, train_dataset, val_dataset,
            device, criterion, output_file,
            0, epochs_per_season, early_stopping_patience,
            season_start_time, deadline=deadline,
            initial_best_val_loss=best_global_val_loss
        )

        last_optimizer = optimizer
        last_scheduler = scheduler
        last_scaler = scaler

        if model_state is not None and val_loss < best_global_val_loss:
            best_global_val_loss = val_loss
            best_global_state = model_state
            print(f"  New global best: {best_global_val_loss:.6f} (season {season})")

        # Reload global best for the next season so we never drift backwards
        if best_global_state is not None:
            model.load_state_dict(best_global_state)

    total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
    print(f"\n{'=' * 60}")
    print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
    print(f"Best val loss: {best_global_val_loss:.6f}")
    print(f"{'=' * 60}\n")

    if best_global_state is None:
        print("No model improvement found. No file saved.")
        return

    _save_versioned_model(
        best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
        best_global_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, burst_start_time,
        hidden_sizes=resolved_hidden_sizes,
        extra_metadata={
            "mode": "burst",
            "duration_minutes": duration_minutes,
            "epochs_per_season": epochs_per_season,
            "early_stopping_patience": early_stopping_patience,
            "seasons_completed": season,
            "batch_size": batch_size,
            "learning_rate": lr,
            "initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
        }
    )
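

# Example burst run (illustrative; paths are placeholders):
#
#   burst_train("training_data.jsonl", duration_minutes=120,
#               epochs_per_season=30, early_stopping_patience=5)
#
# This trains for up to two hours, reloads the global best weights whenever
# a season ends, and saves a single versioned model at the end.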


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
                        help="Path to training_data.jsonl (default: training_data.jsonl)")
    parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                        help="Output file base name (default: nnue_weights.pt)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint file to resume training from (optional)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs to train (default: 100)")
    parser.add_argument("--batch-size", type=int, default=16384,
                        help="Batch size (default: 16384)")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="Learning rate (default: 0.001)")
    parser.add_argument("--early-stopping", type=int, default=None,
                        help="Stop if val loss doesn't improve for N epochs (optional)")
    parser.add_argument("--stockfish-depth", type=int, default=12,
                        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    parser.add_argument("--no-versioning", action="store_true",
                        help="Disable automatic versioning (save directly to output file)")
    parser.add_argument("--weight-decay", type=float, default=1e-4,
                        help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
    parser.add_argument("--subsample-ratio", type=float, default=1.0,
                        help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
    parser.add_argument("--hidden-layers", type=str, default=None,
                        help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")

    # Burst mode
    parser.add_argument("--burst-duration", type=float, default=None,
                        help="Enable burst mode: total training budget in minutes")
    parser.add_argument("--epochs-per-season", type=int, default=50,
                        help="Max epochs per burst season before restarting (default: 50, burst mode only)")

    args = parser.parse_args()

    hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None

    if args.burst_duration is not None:
        burst_train(
            data_file=args.data_file,
            output_file=args.output_file,
            duration_minutes=args.burst_duration,
            epochs_per_season=args.epochs_per_season,
            early_stopping_patience=args.early_stopping or 10,
            batch_size=args.batch_size,
            lr=args.lr,
            initial_checkpoint=args.checkpoint,
            stockfish_depth=args.stockfish_depth,
            use_versioning=not args.no_versioning,
            weight_decay=args.weight_decay,
            subsample_ratio=args.subsample_ratio,
            hidden_sizes=hidden_sizes,
        )
    else:
        train_nnue(
            data_file=args.data_file,
            output_file=args.output_file,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            checkpoint=args.checkpoint,
            stockfish_depth=args.stockfish_depth,
            use_versioning=not args.no_versioning,
            early_stopping_patience=args.early_stopping,
            weight_decay=args.weight_decay,
            subsample_ratio=args.subsample_ratio,
            hidden_sizes=hidden_sizes,
        )
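
# Example invocations (illustrative; the script filename train_nnue.py is an
# assumption, not taken from this commit):
#
#   python train_nnue.py training_data.jsonl --epochs 50 --early-stopping 5
#   python train_nnue.py training_data.jsonl --burst-duration 120 \
#       --epochs-per-season 30 --subsample-ratio 0.25
#   python train_nnue.py --checkpoint nnue_weights_checkpoint.pt --lr 0.0005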