#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""

import json
import re
import sys
from datetime import datetime
from pathlib import Path

import chess
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm


class NNUEDataset(Dataset):
    """Dataset of chess positions with engine evaluations.

    Reads a JSONL file where each line is ``{"fen": ..., "eval": <centipawns>}``.
    Malformed or incomplete lines are skipped (best-effort load).
    """

    def __init__(self, data_file):
        self.positions = []
        self.evals = []
        with open(data_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    fen = data['fen']
                    eval_cp = data['eval']
                except (json.JSONDecodeError, KeyError):
                    # Deliberate best-effort: skip bad lines instead of aborting.
                    continue
                self.positions.append(fen)
                self.evals.append(eval_cp)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        fen = self.positions[idx]
        eval_cp = self.evals[idx]
        features = fen_to_features(fen)
        # Squash the centipawn eval into (0, 1); 400 cp is the conventional
        # pawn-advantage scale for the eval -> win-probability sigmoid.
        target = torch.sigmoid(torch.tensor(eval_cp / 400.0, dtype=torch.float32))
        return features, target


def fen_to_features(fen):
    """Convert FEN to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K)
    times 64 squares. An invalid FEN yields an all-zero vector.
    """
    # Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
    features = torch.zeros(768, dtype=torch.float32)
    try:
        board = chess.Board(fen)
    except ValueError:
        # BUGFIX: was a bare `except:` around the whole loop, which would
        # also hide programming errors. Only an unparsable FEN is expected
        # here; return the zero vector for it.
        return features
    # 12 piece types x 64 squares = 768
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            piece_char = piece.symbol()
            if piece_char in piece_to_idx:
                piece_idx = piece_to_idx[piece_char]
                feature_idx = piece_idx * 64 + square
                features[feature_idx] = 1.0
    return features


class NNUE(nn.Module):
    """NNUE network: 768 -> 256 -> 32 -> 1 with ReLU hidden layers.

    The output passes through a sigmoid so predictions live in the same
    (0, 1) win-probability space as the training targets produced by
    ``NNUEDataset.__getitem__``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu1(self.l1(x))
        x = self.relu2(self.l2(x))
        x = self.l3(x)
        # BUGFIX: self.sigmoid was constructed but never applied, so the raw
        # linear output was regressed against sigmoid-squashed targets.
        # Applying it here matches the target space; the state_dict layout
        # is unchanged (Sigmoid has no parameters).
        return self.sigmoid(x)


def find_next_version(base_name="nnue_weights"):
    """Find the next version number for model versioning.

    Looks for {base_name}_v*.pt files in the current directory and returns
    max(version) + 1. If no versioned files exist, returns 1.
    """
    pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
    versions = []
    for file in Path(".").glob(f"{base_name}_v*.pt"):
        match = pattern.fullmatch(file.name)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) + 1 if versions else 1


def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info; non-JSON-serializable
            values (e.g. datetime) fall back to str() via ``default=str``.

    Returns:
        Path of the written metadata JSON file.
    """
    metadata_file = weights_file.replace(".pt", "_metadata.json")
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)
    return metadata_file


def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20,
               batch_size=4096, lr=1e-3, checkpoint=None,
               stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata

    Raises:
        ValueError: If the data file contains no usable positions.
    """
    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")
    if num_positions == 0:
        # Fail loudly instead of crashing later in random_split / the
        # per-epoch loss division.
        raise ValueError(f"No usable positions found in {data_file}")

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load checkpoint if provided; if none given but the output file already
    # exists, resume from it.
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        checkpoint_to_load = output_file

    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            # NOTE(security): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources (weights_only=True is
            # safer where the installed torch version supports it).
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            # Best-effort resume: a bad checkpoint should not kill training.
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None

    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()

    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # ---- Train ----
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the final mean is per-position.
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)
        train_loss /= len(train_dataset)

        # ---- Validation ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)
                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)
        val_loss /= len(val_dataset)

        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # BUGFIX: state_dict() returns references to the live parameter
            # tensors and dict.copy() is shallow, so later epochs silently
            # overwrote the "best" snapshot. Clone each tensor instead.
            best_model_state = {k: v.detach().clone()
                                for k, v in model.state_dict().items()}

    # Save best model
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
        metadata = {}

        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"

            # Prepare metadata
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }

        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f" Version: v{metadata['version']}")
            print(f" Positions: {metadata['num_positions']}")
            print(f" Stockfish depth: {metadata['stockfish_depth']}")
            print(f" Epochs: {metadata['epochs']}")
            print(f" Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f" Device: {metadata['device']}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
                        help="Path to training_data.jsonl (default: training_data.jsonl)")
    parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                        help="Output file base name (default: nnue_weights.pt)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint file to resume training from (optional)")
    parser.add_argument("--epochs", type=int, default=20,
                        help="Number of epochs to train (default: 20)")
    parser.add_argument("--batch-size", type=int, default=4096,
                        help="Batch size (default: 4096)")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Learning rate (default: 1e-3)")
    parser.add_argument("--stockfish-depth", type=int, default=12,
                        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    parser.add_argument("--no-versioning", action="store_true",
                        help="Disable automatic versioning (save directly to output file)")

    args = parser.parse_args()

    train_nnue(
        data_file=args.data_file,
        output_file=args.output_file,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        checkpoint=args.checkpoint,
        stockfish_depth=args.stockfish_depth,
        use_versioning=not args.no_versioning
    )