Files
NowChessSystems/modules/bot/python/src/train.py
T

356 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime
import re
import numpy as np
class NNUEDataset(Dataset):
    """Chess positions with evaluations, loaded from a JSONL file.

    Each line of ``data_file`` should be a JSON object with at least
    ``fen`` and ``eval`` keys; an optional ``eval_raw`` key carries the
    un-normalized score. Malformed lines are skipped (best effort).
    """

    def __init__(self, data_file):
        self.positions = []
        self.evals = []
        self.evals_raw = []
        # Decided from the first valid record: an eval in [-1, 1] is
        # treated as already normalized.
        self.is_normalized = None
        with open(data_file, 'r') as f:
            for line in f:
                try:
                    record = json.loads(line)
                    fen, score = record['fen'], record['eval']
                except (json.JSONDecodeError, KeyError):
                    continue  # skip unparseable / incomplete lines
                self.positions.append(fen)
                self.evals.append(score)
                if self.is_normalized is None:
                    self.is_normalized = abs(score) <= 1.0
                # Fall back to the (possibly normalized) eval when no raw
                # score was recorded.
                self.evals_raw.append(record.get('eval_raw', score))

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        features = fen_to_features(self.positions[idx])
        score = self.evals[idx]
        if self.is_normalized:
            # Already scaled; use the value directly.
            target = torch.tensor(score, dtype=torch.float32)
        else:
            # Raw centipawns: squash into (0, 1) via sigmoid(cp / 400).
            target = torch.sigmoid(torch.tensor(score / 400.0, dtype=torch.float32))
        return features, target
def fen_to_features(fen):
    """Convert a FEN string to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K)
    of 64 squares each, so feature index = piece_plane * 64 + square.
    Returns an all-zero vector when the FEN cannot be parsed.
    """
    # Piece type to plane: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
    features = torch.zeros(768, dtype=torch.float32)
    try:
        board = chess.Board(fen)
    except ValueError:
        # FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. chess.Board raises ValueError on
        # invalid FEN; keep the best-effort all-zero fallback.
        return features
    # 12 piece types × 64 squares = 768
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            piece_idx = piece_to_idx.get(piece.symbol())
            if piece_idx is not None:
                features[piece_idx * 64 + square] = 1.0
    return features
class NNUE(nn.Module):
    """Small NNUE-style evaluation network: 768 -> 256 -> 32 -> 1."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        # NOTE(review): defined but never applied in forward(); the raw
        # linear output is returned. Presumably callers apply it (or the
        # loss targets are pre-squashed) — confirm before removing.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu1(self.l1(x))
        hidden = self.relu2(self.l2(hidden))
        return self.l3(hidden)
def find_next_version(base_name="nnue_weights"):
    """Return the next version number for model versioning.

    Scans the directory of ``base_name`` for files matching
    ``<base>_v<N>.pt`` and returns max(N) + 1, or 1 if none exist.
    """
    base_path = Path(base_name)
    directory = base_path.parent
    filename = base_path.name
    pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
    versions = []
    # BUG FIX: the glob previously used the literal prefix "(unknown)"
    # (a leaked placeholder), so existing versioned files were never
    # found and every run started over at v1.
    for file in directory.glob(f"{filename}_v*.pt"):
        match = pattern.fullmatch(file.name)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) + 1 if versions else 1
def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info

    Returns:
        Path of the metadata JSON file that was written.
    """
    # BUG FIX: str.replace(".pt", ...) replaced the FIRST ".pt" anywhere
    # in the path, and for a name without ".pt" it replaced nothing —
    # overwriting the weights file itself with JSON. Strip only a
    # trailing ".pt" suffix instead.
    if weights_file.endswith(".pt"):
        base = weights_file[: -len(".pt")]
    else:
        base = weights_file
    metadata_file = base + "_metadata.json"
    with open(metadata_file, "w") as f:
        # default=str keeps non-JSON types (e.g. datetime) serializable.
        json.dump(metadata, f, indent=2, default=str)
    return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3,
               checkpoint=None, stockfish_depth=12, use_versioning=True,
               early_stopping_patience=None):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
    """
    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")
    # BUG FIX: message previously ended with a stray ')'.
    print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'}")

    # Print dataset diagnostics.
    evals_array = np.array(dataset.evals)
    print()
    print("=" * 60)
    print("TRAINING DATASET DIAGNOSTICS")
    print("=" * 60)
    print(f"Min evaluation: {evals_array.min():.4f}")
    print(f"Max evaluation: {evals_array.max():.4f}")
    print(f"Mean evaluation: {evals_array.mean():.4f}")
    print(f"Median evaluation: {np.median(evals_array):.4f}")
    print(f"Std deviation: {evals_array.std():.4f}")
    print("=" * 60)
    print()

    # Split 90% train, 10% validation.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    from torch.utils.data import random_split
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Resume from an explicit checkpoint, or auto-resume when the output
    # file already exists.
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        checkpoint_to_load = output_file
    # NOTE(review): the epoch counter is not stored in checkpoints, so a
    # resumed run always restarts its epoch numbering at 0.
    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            # Best-effort resume: a corrupt/incompatible checkpoint should
            # not abort training.
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None
    epochs_without_improvement = 0
    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    if early_stopping_patience:
        print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
    print()
    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # --- Train ---
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)
                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch figure is a per-sample
                # average (last batch may be smaller).
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)
        train_loss /= len(train_dataset)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)
                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)
        val_loss /= len(val_dataset)
        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # BUG FIX: state_dict().copy() is a shallow dict copy — the
            # tensors still alias the live parameters, so later epochs
            # silently overwrote the "best" snapshot. Clone every tensor.
            best_model_state = {k: v.detach().cpu().clone()
                                for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Early stopping check
        if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
            break

    # Save best model
    if best_model_state is not None:
        # Determine final output file with versioning.
        final_output_file = output_file
        metadata = {}
        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"
            # Prepare metadata
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }
        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")
        # Save metadata if versioning is enabled.
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f" Version: v{metadata['version']}")
            print(f" Positions: {metadata['num_positions']}")
            print(f" Stockfish depth: {metadata['stockfish_depth']}")
            print(f" Epochs: {metadata['epochs']}")
            print(f" Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f" Device: {metadata['device']}")
if __name__ == "__main__":
    import argparse

    # Command-line front end; every option maps 1:1 onto a train_nnue()
    # keyword argument.
    cli = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    cli.add_argument("data_file", nargs="?", default="training_data.jsonl",
                     help="Path to training_data.jsonl (default: training_data.jsonl)")
    cli.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                     help="Output file base name (default: nnue_weights.pt)")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to checkpoint file to resume training from (optional)")
    cli.add_argument("--epochs", type=int, default=20,
                     help="Number of epochs to train (default: 20)")
    cli.add_argument("--batch-size", type=int, default=4096,
                     help="Batch size (default: 4096)")
    cli.add_argument("--lr", type=float, default=1e-3,
                     help="Learning rate (default: 1e-3)")
    cli.add_argument("--early-stopping", type=int, default=None,
                     help="Stop if val loss doesn't improve for N epochs (optional)")
    cli.add_argument("--stockfish-depth", type=int, default=12,
                     help="Stockfish depth used for evaluations (for metadata, default: 12)")
    cli.add_argument("--no-versioning", action="store_true",
                     help="Disable automatic versioning (save directly to output file)")
    opts = cli.parse_args()

    train_nnue(
        data_file=opts.data_file,
        output_file=opts.output_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        checkpoint=opts.checkpoint,
        stockfish_depth=opts.stockfish_depth,
        use_versioning=not opts.no_versioning,
        early_stopping_patience=opts.early_stopping,
    )