feat: integrate NNUE bot and add Python training pipeline with weight export functionality

2026-04-07 23:33:20 +02:00
parent 6a9ac55b31
commit b25be99dcf
29 changed files with 338 additions and 2538 deletions
+66
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""Export NNUE weights to binary format for runtime loading."""
import torch
import struct
import sys
from pathlib import Path


def export_weights_to_binary(weights_file, output_file):
    """Load PyTorch weights and export them as a binary file."""
    if not Path(weights_file).exists():
        print(f"Error: Weights file not found at {weights_file}")
        sys.exit(1)

    # Load weights
    state_dict = torch.load(weights_file, map_location='cpu')

    # Debug: print available layers
    print(f"Available layers in {weights_file}:")
    for key in sorted(state_dict.keys()):
        print(f" {key}: {state_dict[key].shape}")

    # Create the output directory if needed
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, 'wb') as f:
        # Write magic number and version
        f.write(b'NNUE')
        f.write(struct.pack('<I', 1))  # version 1

        # Write each weight tensor in a fixed order
        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias']:
            if layer_name not in state_dict:
                print(f"Error: Missing layer {layer_name}")
                sys.exit(1)
            tensor = state_dict[layer_name]
            # Convert to float32 and flatten
            data = tensor.float().flatten().cpu().numpy()
            # Write the shape (allows validation on load)
            shape = list(tensor.shape)
            f.write(struct.pack('<I', len(shape)))
            for dim in shape:
                f.write(struct.pack('<I', dim))
            # Write the flattened data as little-endian float32
            f.write(struct.pack(f'<{len(data)}f', *data))
            print(f" {layer_name}: shape {shape}, {len(data)} floats")

    file_size_mb = output_path.stat().st_size / (1024 ** 2)
    print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")


if __name__ == "__main__":
    weights_file = "nnue_weights.pt"
    output_file = "../src/main/resources/nnue_weights.bin"
    if len(sys.argv) > 1:
        weights_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    export_weights_to_binary(weights_file, output_file)
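As a sanity check on the format above (magic b'NNUE', a version word, then per-tensor rank, dims, and little-endian float32 payload), a minimal reader sketch that round-trips the file. It is an illustration only; the helper name and the NumPy usage are mine, not part of the commit:

#!/usr/bin/env python3
"""Sketch: read back nnue_weights.bin and sanity-check the layout (assumes the writer above)."""
import struct
import numpy as np

def read_weights_from_binary(path):
    """Return the tensors in the order they were written."""
    tensors = []
    with open(path, 'rb') as f:
        assert f.read(4) == b'NNUE', "bad magic number"
        (version,) = struct.unpack('<I', f.read(4))
        assert version == 1, f"unsupported version {version}"
        while True:
            rank_bytes = f.read(4)
            if not rank_bytes:
                break  # clean EOF after the last tensor
            (rank,) = struct.unpack('<I', rank_bytes)
            shape = struct.unpack(f'<{rank}I', f.read(4 * rank))
            count = int(np.prod(shape))
            data = np.frombuffer(f.read(4 * count), dtype='<f4')
            tensors.append(data.reshape(shape))
    return tensors

if __name__ == "__main__":
    for i, t in enumerate(read_weights_from_binary("../src/main/resources/nnue_weights.bin")):
        print(f"tensor {i}: shape {t.shape}, mean {t.mean():.4f}")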
+110
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""Generate random chess positions for NNUE training."""
import chess
import random
import sys
from pathlib import Path
from tqdm import tqdm


def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
    """Play random games and save the position reached after 8-20 random moves.

    Returns:
        Number of valid positions saved
    """
    positions_count = 0
    filtered_check = 0
    filtered_captures = 0
    filtered_game_over = 0

    with open(output_file, 'w') as f:
        with tqdm(total=total_games, desc="Generating positions") as pbar:
            for game_num in range(total_games):
                board = chess.Board()

                # Play 8-20 random opening moves
                num_moves = random.randint(8, 20)
                for move_num in range(num_moves):
                    if board.is_game_over():
                        break
                    legal_moves = list(board.legal_moves)
                    if not legal_moves:
                        break
                    move = random.choice(legal_moves)
                    board.push(move)

                # Skip if the game is over
                if board.is_game_over():
                    filtered_game_over += 1
                    pbar.update(1)
                    continue

                # Skip if the side to move is in check
                if board.is_check():
                    filtered_check += 1
                    pbar.update(1)
                    continue

                # Skip positions with available captures (if filtering is enabled)
                if filter_captures:
                    has_captures = any(board.is_capture(move) for move in board.legal_moves)
                    if has_captures:
                        filtered_captures += 1
                        pbar.update(1)
                        continue

                # Save the valid position
                fen = board.fen()
                f.write(fen + '\n')
                positions_count += 1
                pbar.update(1)

    # Print summary
    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Total games: {total_games}")
    print(f"Saved positions: {positions_count}")
    print(f"Filtered (check): {filtered_check}")
    print(f"Filtered (captures): {filtered_captures}")
    print(f"Filtered (game over): {filtered_game_over}")
    print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
    print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
    print("=" * 60)
    print()

    if positions_count == 0:
        print("WARNING: No valid positions were generated!")
        print("This can happen if the filter criteria are too strict (captures, checks).")
        print("Try --no-filter-captures to accept positions with available captures.")
        return 0

    return positions_count


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    parser.add_argument("output_file", nargs="?", default="positions.txt",
                        help="Output file for positions (default: positions.txt)")
    parser.add_argument("--games", type=int, default=5000,
                        help="Number of games to play (default: 5000)")
    parser.add_argument("--no-filter-captures", action="store_true",
                        help="Include positions with available captures (increases output)")
    args = parser.parse_args()

    count = play_random_game_and_collect_positions(
        output_file=args.output_file,
        total_games=args.games,
        filter_captures=not args.no_filter_captures
    )
    sys.exit(0 if count > 0 else 1)
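Before spending engine time on labeling, it can be worth re-checking the filter contract on the generated file: every saved position should be quiet (not in check, no captures available). A minimal sketch, assuming positions.txt from the script above; the helper name is hypothetical:

#!/usr/bin/env python3
"""Sketch: verify positions.txt contains only quiet positions."""
import chess

def verify_quiet(positions_file="positions.txt"):
    bad = 0
    total = 0
    with open(positions_file) as f:
        for line in f:
            fen = line.strip()
            if not fen:
                continue
            total += 1
            board = chess.Board(fen)
            # A position violates the filter if in check or any capture is legal
            if board.is_check() or any(board.is_capture(m) for m in board.legal_moves):
                bad += 1
    print(f"{total} positions checked, {bad} violate the quiet-position filter")
    return bad == 0

if __name__ == "__main__":
    verify_quiet()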
+198
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""Label positions with Stockfish evaluations."""
import json
import chess
import chess.engine
import shutil
import sys
import os
from pathlib import Path
from tqdm import tqdm


def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
                                   batch_size=100, depth=12, verbose=False):
    """Read positions and label them with Stockfish evaluations.

    Args:
        positions_file: Path to positions.txt
        output_file: Path to training_data.jsonl
        stockfish_path: Path to the Stockfish binary
        batch_size: Unused; kept for call-site compatibility
        depth: Stockfish search depth
        verbose: Print detailed error messages
    """
    # Check that Stockfish exists as a file or is discoverable on $PATH
    if not Path(stockfish_path).exists() and shutil.which(stockfish_path) is None:
        print(f"Error: Stockfish not found at {stockfish_path}")
        print("Set the STOCKFISH_PATH environment variable or pass the path as an argument")
        sys.exit(1)
    print(f"Using Stockfish: {stockfish_path}")

    # Check that the positions file exists
    if not Path(positions_file).exists():
        print(f"Error: Positions file not found at {positions_file}")
        sys.exit(1)

    # Load existing evaluations if resuming
    evaluated_fens = set()
    position_count = 0
    if Path(output_file).exists():
        with open(output_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    evaluated_fens.add(data['fen'])
                    position_count += 1
                except json.JSONDecodeError:
                    pass
        print(f"Resuming from {position_count} already evaluated positions")

    # Count total positions
    with open(positions_file, 'r') as f:
        total_lines = sum(1 for _ in f)
    if total_lines == 0:
        print(f"Error: Positions file is empty ({positions_file})")
        sys.exit(1)
    print(f"Total positions to process: {total_lines}")
    print(f"Using depth: {depth}")
    print()

    # Initialize the engine
    try:
        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    except Exception as e:
        print("Error: Could not start Stockfish engine")
        print(f"  Stockfish path: {stockfish_path}")
        print(f"  Error: {e}")
        sys.exit(1)

    # Track statistics
    evaluated = 0
    skipped_invalid = 0
    skipped_duplicate = 0
    errors = 0

    try:
        with open(positions_file, 'r') as f:
            with open(output_file, 'a') as out:
                # Every input line updates the bar exactly once, so the total is total_lines
                with tqdm(total=total_lines, desc="Labeling positions") as pbar:
                    for fen in f:
                        fen = fen.strip()

                        # Skip empty lines
                        if not fen:
                            skipped_invalid += 1
                            pbar.update(1)
                            continue

                        # Skip already evaluated positions
                        if fen in evaluated_fens:
                            skipped_duplicate += 1
                            pbar.update(1)
                            continue

                        try:
                            # Validate the FEN
                            board = chess.Board(fen)
                            if not board.is_valid():
                                skipped_invalid += 1
                                pbar.update(1)
                                continue

                            # Evaluate at the specified depth
                            info = engine.analyse(board, chess.engine.Limit(depth=depth))
                            if info.get('score') is None:
                                skipped_invalid += 1
                                pbar.update(1)
                                continue
                            score = info['score'].white()

                            # Convert to centipawns
                            if score.is_mate():
                                # Use large fixed values for mate scores
                                eval_cp = 2000 if score.mate() > 0 else -2000
                            else:
                                eval_cp = score.cp
                            # Clamp to [-2000, 2000]
                            eval_cp = max(-2000, min(2000, eval_cp))

                            # Save the evaluation
                            data = {"fen": fen, "eval": eval_cp}
                            out.write(json.dumps(data) + '\n')
                            out.flush()  # Force write to disk so resuming is safe
                            evaluated_fens.add(fen)  # Also dedupe repeats within this run
                            evaluated += 1
                        except Exception as e:
                            errors += 1
                            if verbose:
                                print(f"Error evaluating position: {fen[:50]}...")
                                print(f"  {type(e).__name__}: {e}")
                            pbar.update(1)
                            continue
                        pbar.update(1)
    finally:
        engine.quit()

    # Print summary
    print()
    print("=" * 60)
    print("LABELING SUMMARY")
    print("=" * 60)
    print(f"Successfully evaluated: {evaluated}")
    print(f"Skipped (duplicates): {skipped_duplicate}")
    print(f"Skipped (invalid): {skipped_invalid}")
    print(f"Errors: {errors}")
    print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
    print("=" * 60)
    print()

    if evaluated == 0:
        print("WARNING: No positions were successfully evaluated!")
        print("Check that:")
        print("  1. positions.txt is not empty")
        print("  2. positions.txt contains valid FENs")
        print("  3. Stockfish is installed and working")
        print("  4. The Stockfish path is correct")
        return False

    print(f"✓ Labeling complete. Output saved to {output_file}")
    return True


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
    parser.add_argument("positions_file", nargs="?", default="positions.txt",
                        help="Input positions file (default: positions.txt)")
    parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
                        help="Output file (default: training_data.jsonl)")
    parser.add_argument("stockfish_path", nargs="?", default=None,
                        help="Path to the Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    parser.add_argument("--depth", type=int, default=12,
                        help="Stockfish depth (default: 12)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print detailed error messages")
    args = parser.parse_args()

    # Determine the Stockfish path
    stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")

    success = label_positions_with_stockfish(
        positions_file=args.positions_file,
        output_file=args.output_file,
        stockfish_path=stockfish_path,
        depth=args.depth,
        verbose=args.verbose
    )
    sys.exit(0 if success else 1)
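Random openings cluster near equality and mate scores land on the ±2000 clamp, so a quick look at the label distribution before training is cheap insurance. A minimal sketch over the JSONL written above; the helper name is mine and this is not part of the pipeline:

#!/usr/bin/env python3
"""Sketch: summarize the eval distribution in training_data.jsonl."""
import json

def summarize(data_file="training_data.jsonl"):
    evals = []
    with open(data_file) as f:
        for line in f:
            try:
                evals.append(json.loads(line)["eval"])
            except (json.JSONDecodeError, KeyError):
                continue
    if not evals:
        print("No labeled positions found")
        return
    clamped = sum(1 for e in evals if abs(e) == 2000)
    print(f"positions: {len(evals)}")
    print(f"mean eval: {sum(evals) / len(evals):.1f} cp")
    print(f"min/max:   {min(evals)} / {max(evals)} cp")
    print(f"at clamp (±2000): {clamped}")

if __name__ == "__main__":
    summarize()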
+301
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""
import copy
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime
import re


class NNUEDataset(Dataset):
    """Dataset of chess positions with evaluations."""

    def __init__(self, data_file):
        self.positions = []
        self.evals = []
        with open(data_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    fen = data['fen']
                    eval_cp = data['eval']
                    self.positions.append(fen)
                    self.evals.append(eval_cp)
                except (json.JSONDecodeError, KeyError):
                    pass

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        fen = self.positions[idx]
        eval_cp = self.evals[idx]
        features = fen_to_features(fen)
        # Map centipawns into (0, 1): target = sigmoid(eval / 400)
        target = torch.sigmoid(torch.tensor(eval_cp / 400.0, dtype=torch.float32))
        return features, target


def fen_to_features(fen):
    """Convert a FEN to a 768-dimensional binary feature vector."""
    # Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    # (lowercase = black at indices 0-5, uppercase = white at 6-11)
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
    features = torch.zeros(768, dtype=torch.float32)
    try:
        board = chess.Board(fen)
        # 12 piece types × 64 squares = 768 features
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                piece_char = piece.symbol()
                if piece_char in piece_to_idx:
                    piece_idx = piece_to_idx[piece_char]
                    feature_idx = piece_idx * 64 + square
                    features[feature_idx] = 1.0
    except ValueError:
        # An invalid FEN yields an all-zero feature vector
        pass
    return features


class NNUE(nn.Module):
    """NNUE network: 768 -> 256 -> 32 -> 1 with ReLU hidden layers."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.l1(x)
        x = self.relu1(x)
        x = self.l2(x)
        x = self.relu2(x)
        x = self.l3(x)
        # Squash to (0, 1) so outputs match the sigmoid-scaled targets
        return self.sigmoid(x)


def find_next_version(base_name="nnue_weights"):
    """Find the next version number for model versioning.

    Looks for {base_name}_v*.pt files and returns the next version number.
    If no versioned files exist, returns 1.
    """
    pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
    versions = []
    for file in Path(".").glob(f"{base_name}_v*.pt"):
        match = pattern.match(file.name)
        if match:
            versions.append(int(match.group(1)))
    if versions:
        return max(versions) + 1
    return 1


def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info
    """
    metadata_file = weights_file.replace(".pt", "_metadata.json")
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)
    return metadata_file


def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096,
               lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save the best weights (or the base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to a checkpoint file to resume from
        stockfish_depth: Depth used for the Stockfish evaluations (recorded in metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
    """
    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load a checkpoint if provided; otherwise auto-resume from an existing output file
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        checkpoint_to_load = output_file

    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print("✓ Checkpoint loaded successfully")
        except Exception as e:
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None

    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()
    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # Train
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)
                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)
        train_loss /= len(train_dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)
                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)
        val_loss /= len(val_dataset)

        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep-copy so later epochs do not mutate the saved tensors
            best_model_state = copy.deepcopy(model.state_dict())

    # Save the best model
    if best_model_state is not None:
        # Determine the final output file, with versioning if enabled
        final_output_file = output_file
        metadata = {}
        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"
            # Prepare metadata
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }

        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print("\nTraining Summary:")
            print(f"  Version: v{metadata['version']}")
            print(f"  Positions: {metadata['num_positions']}")
            print(f"  Stockfish depth: {metadata['stockfish_depth']}")
            print(f"  Epochs: {metadata['epochs']}")
            print(f"  Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f"  Device: {metadata['device']}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
                        help="Path to training_data.jsonl (default: training_data.jsonl)")
    parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                        help="Output file base name (default: nnue_weights.pt)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to a checkpoint file to resume training from (optional)")
    parser.add_argument("--epochs", type=int, default=20,
                        help="Number of epochs to train (default: 20)")
    parser.add_argument("--batch-size", type=int, default=4096,
                        help="Batch size (default: 4096)")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Learning rate (default: 1e-3)")
    parser.add_argument("--stockfish-depth", type=int, default=12,
                        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    parser.add_argument("--no-versioning", action="store_true",
                        help="Disable automatic versioning (save directly to the output file)")
    args = parser.parse_args()

    train_nnue(
        data_file=args.data_file,
        output_file=args.output_file,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        checkpoint=args.checkpoint,
        stockfish_depth=args.stockfish_depth,
        use_versioning=not args.no_versioning
    )
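For a quick smoke test of a trained net, a sketch that scores a single FEN and maps the sigmoid output back to centipawns by inverting the target = sigmoid(eval / 400) scaling used in NNUEDataset. The module name train_nnue and the weights filename are assumptions for illustration, not part of this commit:

#!/usr/bin/env python3
"""Sketch: score one FEN with trained weights (module and file names assumed)."""
import math
import torch
from train_nnue import NNUE, fen_to_features  # assumed importable name of the training script

def evaluate_fen(fen, weights_file="nnue_weights_v1.pt"):
    """Return an approximate centipawn score from White's perspective."""
    model = NNUE()
    model.load_state_dict(torch.load(weights_file, map_location='cpu'))
    model.eval()
    with torch.no_grad():
        p = model(fen_to_features(fen).unsqueeze(0)).item()
    # Invert target = sigmoid(eval_cp / 400): eval_cp = 400 * log(p / (1 - p))
    p = min(max(p, 1e-6), 1.0 - 1e-6)  # avoid infinities at the sigmoid extremes
    return 400.0 * math.log(p / (1.0 - p))

if __name__ == "__main__":
    start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    print(f"Start position: {evaluate_fen(start):.0f} cp")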