Files
NowChessSystems/modules/bot/python/train_nnue.py
T

302 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime
import re
class NNUEDataset(Dataset):
    """Dataset of chess positions with Stockfish evaluations.

    Reads a JSONL file where each line is {"fen": ..., "eval": ...} with
    the eval in centipawns. Malformed or incomplete lines are skipped.
    """

    def __init__(self, data_file):
        self.positions = []
        self.evals = []
        with open(data_file, 'r') as f:
            for raw_line in f:
                try:
                    record = json.loads(raw_line)
                    fen, score = record['fen'], record['eval']
                except (json.JSONDecodeError, KeyError):
                    # Best-effort ingest: drop lines that aren't valid
                    # records instead of aborting the whole load.
                    continue
                self.positions.append(fen)
                self.evals.append(score)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        # Squash the centipawn score through sigmoid(cp / 400) so the
        # regression target lives in (0, 1) on a win-probability-like scale.
        score = self.evals[idx]
        target = torch.sigmoid(torch.tensor(score / 400.0, dtype=torch.float32))
        return fen_to_features(self.positions[idx]), target
def fen_to_features(fen):
    """Convert a FEN string to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K)
    times 64 squares; feature index = piece_plane * 64 + square.

    Args:
        fen: Position in FEN notation.

    Returns:
        torch.FloatTensor of shape (768,). An all-zero vector is returned
        for FENs that python-chess rejects, so the data pipeline never
        crashes on a bad record.
    """
    # Piece symbol -> plane index (lowercase = black, uppercase = white).
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
    features = torch.zeros(768, dtype=torch.float32)
    try:
        board = chess.Board(fen)
        # 12 piece types x 64 squares = 768
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is None:
                continue
            piece_idx = piece_to_idx.get(piece.symbol())
            if piece_idx is not None:
                features[piece_idx * 64 + square] = 1.0
    except ValueError:
        # chess.Board raises ValueError on an invalid FEN. The original bare
        # `except:` also swallowed KeyboardInterrupt/SystemExit — narrowed.
        pass
    return features
class NNUE(nn.Module):
    """NNUE evaluation network: 768 -> 256 -> 32 -> 1 with ReLU activations.

    The final sigmoid maps the output into (0, 1), the same scale as the
    training targets (sigmoid(eval_cp / 400) produced by the dataset).
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return win-probability-scaled evaluation(s) in (0, 1)."""
        x = self.relu1(self.l1(x))
        x = self.relu2(self.l2(x))
        # BUG FIX: self.sigmoid was constructed but never applied, so an
        # unbounded linear output was regressed (MSE) against targets that
        # are sigmoid-squashed into (0, 1). Apply it so predictions and
        # targets share the same scale.
        return self.sigmoid(self.l3(x))
def find_next_version(base_name="nnue_weights", directory="."):
    """Find the next version number for model versioning.

    Scans *directory* for ``{base_name}_v<N>.pt`` files and returns
    max(N) + 1, or 1 if no versioned file exists.

    Args:
        base_name: Weights file base name (without version suffix).
        directory: Directory to scan (default: current directory, which
            preserves the original behavior).

    Returns:
        int: the next free version number (>= 1).
    """
    pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
    versions = []
    for file in Path(directory).glob(f"{base_name}_v*.pt"):
        # fullmatch (not match) so stray names like "foo_v1.pt.pt" — which
        # the glob accepts — are not miscounted as versions.
        match = pattern.fullmatch(file.name)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) + 1 if versions else 1
def save_metadata(weights_file, metadata):
    """Save training metadata as JSON alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt).
        metadata: Dictionary with training info. Non-JSON values (e.g.
            datetimes) are stringified via ``default=str``.

    Returns:
        str: path of the metadata file that was written.
    """
    # BUG FIX: the old `weights_file.replace(".pt", ...)` rewrote the FIRST
    # ".pt" anywhere in the path, and for a name with no ".pt" at all the
    # metadata path equaled the weights path (the JSON would clobber the
    # weights file). Swap only a trailing ".pt" suffix instead.
    if weights_file.endswith(".pt"):
        metadata_file = weights_file[:-len(".pt")] + "_metadata.json"
    else:
        metadata_file = weights_file + "_metadata.json"
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)
    return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
    """
    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    from torch.utils.data import random_split
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model, loss, optimizer
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load checkpoint if provided
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
        checkpoint_to_load = output_file
    # NOTE(review): checkpoints hold only model weights — optimizer state and
    # the true epoch counter are not persisted, so start_epoch is always 0.
    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            # Best-effort resume: a corrupt/incompatible checkpoint should not
            # abort a training run.
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None
    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()
    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # --- Train pass ---
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)
                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch average is per-sample
                # even when the last batch is smaller.
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)
        train_loss /= len(train_dataset)

        # --- Validation pass ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)
                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)
        val_loss /= len(val_dataset)
        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # BUG FIX: `model.state_dict().copy()` is a SHALLOW dict copy —
            # the tensors keep aliasing the live parameters, so subsequent
            # epochs silently overwrite the "best" snapshot and the model
            # saved below is the LAST one, not the best. Clone each tensor
            # to freeze the snapshot at this epoch.
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}

    # Save best model
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
        metadata = {}
        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"
            # Prepare metadata describing this training run
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }
        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f"  Version: v{metadata['version']}")
            print(f"  Positions: {metadata['num_positions']}")
            print(f"  Stockfish depth: {metadata['stockfish_depth']}")
            print(f"  Epochs: {metadata['epochs']}")
            print(f"  Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f"  Device: {metadata['device']}")
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse arguments, then hand everything to train_nnue.
    cli = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    cli.add_argument("data_file", nargs="?", default="training_data.jsonl",
                     help="Path to training_data.jsonl (default: training_data.jsonl)")
    cli.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                     help="Output file base name (default: nnue_weights.pt)")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to checkpoint file to resume training from (optional)")
    cli.add_argument("--epochs", type=int, default=20,
                     help="Number of epochs to train (default: 20)")
    cli.add_argument("--batch-size", type=int, default=4096,
                     help="Batch size (default: 4096)")
    cli.add_argument("--lr", type=float, default=1e-3,
                     help="Learning rate (default: 1e-3)")
    cli.add_argument("--stockfish-depth", type=int, default=12,
                     help="Stockfish depth used for evaluations (for metadata, default: 12)")
    cli.add_argument("--no-versioning", action="store_true",
                     help="Disable automatic versioning (save directly to output file)")
    opts = cli.parse_args()

    train_nnue(
        data_file=opts.data_file,
        output_file=opts.output_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        checkpoint=opts.checkpoint,
        stockfish_depth=opts.stockfish_depth,
        use_versioning=not opts.no_versioning,
    )