feat: refactor AlphaBetaSearch and ClassicalBot for improved evaluation and organization
This commit is contained in:
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train NNUE neural network for chess evaluation."""
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
class NNUEDataset(Dataset):
    """Chess positions paired with centipawn evaluations, loaded from a JSONL file."""

    def __init__(self, data_file):
        # Each valid JSONL line contributes one (fen, eval) pair; malformed
        # or incomplete lines are skipped (best-effort loading).
        self.positions = []
        self.evals = []

        with open(data_file, 'r') as f:
            for raw_line in f:
                try:
                    record = json.loads(raw_line)
                    fen_string = record['fen']
                    centipawns = record['eval']
                except (json.JSONDecodeError, KeyError):
                    continue
                self.positions.append(fen_string)
                self.evals.append(centipawns)

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        # Squash the centipawn score into (0, 1); 400 cp is the conventional
        # scaling constant for mapping evals to a win-probability-like target.
        features = fen_to_features(self.positions[idx])
        scaled_eval = torch.tensor(self.evals[idx] / 400.0, dtype=torch.float32)
        return features, torch.sigmoid(scaled_eval)
|
||||
|
||||
def fen_to_features(fen):
    """Convert a FEN string to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K),
    64 squares each; feature index = piece_plane * 64 + square.
    An invalid FEN yields an all-zero vector.
    """
    # Piece type to plane index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}

    features = torch.zeros(768, dtype=torch.float32)

    try:
        board = chess.Board(fen)

        # 12 piece types × 64 squares = 768
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                piece_char = piece.symbol()
                if piece_char in piece_to_idx:
                    piece_idx = piece_to_idx[piece_char]
                    feature_idx = piece_idx * 64 + square
                    features[feature_idx] = 1.0
    except ValueError:
        # BUG FIX: chess.Board raises ValueError for a malformed FEN; the
        # previous bare `except` swallowed *every* exception, including
        # KeyboardInterrupt/SystemExit and genuine programming errors.
        pass

    return features
|
||||
|
||||
class NNUE(nn.Module):
    """NNUE evaluation network: 768 -> 256 -> 32 -> 1 with a sigmoid head.

    The output lies in (0, 1), matching the training targets which are
    sigmoid(eval_cp / 400) win-probability-style values.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.l1(x)
        x = self.relu1(x)
        x = self.l2(x)
        x = self.relu2(x)
        x = self.l3(x)
        # BUG FIX: self.sigmoid was constructed but never applied, so raw
        # logits were regressed (MSE) against sigmoid-squashed targets.
        # Apply it so predictions live on the same (0, 1) scale.
        return self.sigmoid(x)
|
||||
|
||||
def find_next_version(base_name="nnue_weights"):
    """Return the next version number for model versioning.

    Scans the current directory for {base_name}_v*.pt files and returns
    one past the highest version found, or 1 if none exist.
    """
    version_re = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")

    found = [
        int(m.group(1))
        for candidate in Path(".").glob(f"{base_name}_v*.pt")
        if (m := version_re.match(candidate.name))
    ]

    # max(..., default=0) + 1 collapses the "no versions yet" case to 1.
    return max(found, default=0) + 1
|
||||
|
||||
def save_metadata(weights_file, metadata):
    """Save training metadata as JSON alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info (non-JSON values such as
            datetimes are stringified via default=str)

    Returns:
        Path of the metadata JSON file that was written.
    """
    # BUG FIX: str.replace(".pt", ...) rewrote *every* ".pt" occurrence in
    # the path (e.g. "a.pt.pt" -> doubly-mangled name) and was a silent
    # no-op clobbering the weights file when the suffix was absent.
    # Strip only a trailing ".pt" instead.
    base = weights_file[:-3] if weights_file.endswith(".pt") else weights_file
    metadata_file = base + "_metadata.json"

    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)

    return metadata_file
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
    """

    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    from torch.utils.data import random_split
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model, loss, optimizer
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load checkpoint if provided
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
        checkpoint_to_load = output_file

    # NOTE(review): only model weights are checkpointed — optimizer state is
    # not, and the epoch counter always restarts at 0 even when resuming.
    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            # Best-effort resume: a bad/incompatible checkpoint falls back
            # to fresh training rather than aborting the run.
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None

    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()

    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # ---- Train pass ----
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the epoch average is per-sample
                # (the final batch may be smaller than batch_size).
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)

        train_loss /= len(train_dataset)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)

                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)

        val_loss /= len(val_dataset)

        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # BUG FIX: state_dict().copy() is a *shallow* dict copy whose
            # tensors are the live parameters, which optimizer.step() mutates
            # in-place — so the "best" snapshot silently tracked the latest
            # weights. Clone each tensor for a true point-in-time snapshot.
            best_model_state = {k: v.detach().clone()
                                for k, v in model.state_dict().items()}

    # ---- Save best model ----
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
        metadata = {}

        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"

            # Prepare metadata describing this training run
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }

        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f"  Version: v{metadata['version']}")
            print(f"  Positions: {metadata['num_positions']}")
            print(f"  Stockfish depth: {metadata['stockfish_depth']}")
            print(f"  Epochs: {metadata['epochs']}")
            print(f"  Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f"  Device: {metadata['device']}")
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse options, then kick off training.
    cli = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    cli.add_argument("data_file", nargs="?", default="training_data.jsonl",
                     help="Path to training_data.jsonl (default: training_data.jsonl)")
    cli.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                     help="Output file base name (default: nnue_weights.pt)")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to checkpoint file to resume training from (optional)")
    cli.add_argument("--epochs", type=int, default=20,
                     help="Number of epochs to train (default: 20)")
    cli.add_argument("--batch-size", type=int, default=4096,
                     help="Batch size (default: 4096)")
    cli.add_argument("--lr", type=float, default=1e-3,
                     help="Learning rate (default: 1e-3)")
    cli.add_argument("--stockfish-depth", type=int, default=12,
                     help="Stockfish depth used for evaluations (for metadata, default: 12)")
    cli.add_argument("--no-versioning", action="store_true",
                     help="Disable automatic versioning (save directly to output file)")

    opts = cli.parse_args()

    train_nnue(
        data_file=opts.data_file,
        output_file=opts.output_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        checkpoint=opts.checkpoint,
        stockfish_depth=opts.stockfish_depth,
        use_versioning=not opts.no_versioning
    )
|
||||
Reference in New Issue
Block a user