#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""

import json
import re
import sys
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import chess
from tqdm import tqdm

class NNUEDataset(Dataset):
    """Dataset of chess positions with evaluations.

    Reads a JSONL file where each line holds a "fen" string and an
    "eval" centipawn score. Lines that fail to parse or are missing a
    key are skipped (best-effort load).
    """

    def __init__(self, data_file):
        self.positions = []
        self.evals = []

        with open(data_file, 'r') as f:
            for raw_line in f:
                try:
                    record = json.loads(raw_line)
                    position_fen = record['fen']
                    centipawns = record['eval']
                except (json.JSONDecodeError, KeyError):
                    # Malformed or incomplete line: ignore and keep going.
                    continue
                self.positions.append(position_fen)
                self.evals.append(centipawns)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        # Squash the centipawn score into (0, 1) via sigmoid(eval / 400)
        # so the target matches the network's output scale.
        scaled = torch.tensor(self.evals[idx] / 400.0, dtype=torch.float32)
        return fen_to_features(self.positions[idx]), torch.sigmoid(scaled)

def fen_to_features(fen):
    """Convert FEN to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K),
    64 squares each; feature index = piece_idx * 64 + square.
    A FEN that cannot be parsed yields an all-zero vector.
    """
    # Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}

    features = torch.zeros(768, dtype=torch.float32)

    try:
        board = chess.Board(fen)

        # 12 piece types × 64 squares = 768
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                piece_char = piece.symbol()
                if piece_char in piece_to_idx:
                    piece_idx = piece_to_idx[piece_char]
                    feature_idx = piece_idx * 64 + square
                    features[feature_idx] = 1.0
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; narrowed while keeping the zeros-on-bad-FEN fallback.
        pass

    return features

class NNUE(nn.Module):
    """NNUE neural network architecture: a 768 -> 256 -> 32 -> 1 MLP."""

    def __init__(self):
        super().__init__()
        # Attribute names are load-bearing: checkpoint state_dict keys are
        # derived from them, so they must stay l1/l2/l3.
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        # NOTE(review): self.sigmoid is constructed but never applied in
        # forward(), while training targets are sigmoid-squashed — confirm
        # whether callers apply the sigmoid externally.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu1(self.l1(x))
        hidden = self.relu2(self.l2(hidden))
        # Raw (un-squashed) scalar score per position.
        return self.l3(hidden)

def find_next_version(base_name="nnue_weights"):
    """Find the next version number for model versioning.

    Scans the current directory for {base_name}_v<N>.pt files and returns
    the highest N plus one; returns 1 when no versioned file exists.
    """
    version_re = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")

    found = []
    for path in Path(".").glob(f"{base_name}_v*.pt"):
        hit = version_re.match(path.name)
        if hit:
            found.append(int(hit.group(1)))

    # max(..., default=0) + 1 collapses the "no versions yet" case to 1.
    return max(found, default=0) + 1

def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info

    Returns:
        Path (str) of the metadata JSON file that was written.
    """
    # Strip only the *trailing* ".pt". The previous str.replace(".pt", ...)
    # rewrote the first ".pt" anywhere in the path, and when no ".pt" was
    # present the metadata path equaled the weights path, clobbering it.
    if weights_file.endswith(".pt"):
        base = weights_file[:-len(".pt")]
    else:
        base = weights_file
    metadata_file = base + "_metadata.json"

    with open(metadata_file, "w") as f:
        # default=str stringifies non-JSON types (e.g. datetime, Path).
        json.dump(metadata, f, indent=2, default=str)

    return metadata_file

def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
    """

    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    from torch.utils.data import random_split
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Resolve which checkpoint to load: an explicit one wins; otherwise an
    # existing output file is treated as a resume point.
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
        checkpoint_to_load = output_file

    # Only the weights are restored; the epoch counter always restarts at 0.
    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            # Best-effort resume: a bad checkpoint falls back to fresh training.
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None

    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()

    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # Train
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                # unsqueeze(1) makes targets (batch, 1) to match model output.
                batch_targets = batch_targets.to(device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the final division gives a true
                # per-sample mean even with a ragged last batch.
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)

        train_loss /= len(train_dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)

                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)

        val_loss /= len(val_dataset)

        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # BUG FIX: state_dict().copy() is a *shallow* dict copy — its
            # tensor values are live references that the optimizer keeps
            # mutating, so the "best" snapshot silently tracked the latest
            # weights. Clone each tensor to take a real snapshot.
            best_model_state = {k: v.detach().clone()
                                for k, v in model.state_dict().items()}

    # Save best model
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
        metadata = {}

        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"

            # Prepare metadata
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }

        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f"  Version: v{metadata['version']}")
            print(f"  Positions: {metadata['num_positions']}")
            print(f"  Stockfish depth: {metadata['stockfish_depth']}")
            print(f"  Epochs: {metadata['epochs']}")
            print(f"  Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f"  Device: {metadata['device']}")

if __name__ == "__main__":
    import argparse

    # Command-line front end for train_nnue(); every option maps 1:1 onto
    # a train_nnue() keyword argument.
    cli = argparse.ArgumentParser(
        description="Train NNUE neural network for chess evaluation")
    cli.add_argument(
        "data_file", nargs="?", default="training_data.jsonl",
        help="Path to training_data.jsonl (default: training_data.jsonl)")
    cli.add_argument(
        "output_file", nargs="?", default="nnue_weights.pt",
        help="Output file base name (default: nnue_weights.pt)")
    cli.add_argument(
        "--checkpoint", type=str, default=None,
        help="Path to checkpoint file to resume training from (optional)")
    cli.add_argument(
        "--epochs", type=int, default=20,
        help="Number of epochs to train (default: 20)")
    cli.add_argument(
        "--batch-size", type=int, default=4096,
        help="Batch size (default: 4096)")
    cli.add_argument(
        "--lr", type=float, default=1e-3,
        help="Learning rate (default: 1e-3)")
    cli.add_argument(
        "--stockfish-depth", type=int, default=12,
        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    cli.add_argument(
        "--no-versioning", action="store_true",
        help="Disable automatic versioning (save directly to output file)")

    opts = cli.parse_args()

    train_nnue(
        data_file=opts.data_file,
        output_file=opts.output_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        checkpoint=opts.checkpoint,
        stockfish_depth=opts.stockfish_depth,
        use_versioning=not opts.no_versioning
    )