feat: integrate NNUE bot and add Python training pipeline with weight export functionality
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to binary format for runtime loading."""
|
||||
|
||||
import torch
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def export_weights_to_binary(weights_file, output_file):
    """Load PyTorch weights and export as a little-endian binary file.

    File layout: 4-byte magic b'NNUE', uint32 version, then for each of the
    six layers (l1/l2/l3 weight+bias, in that fixed order): uint32 rank,
    the uint32 dims, and the flattened float32 data.

    Args:
        weights_file: Path to the PyTorch state-dict file (.pt).
        output_file: Destination path for the binary weights.

    Exits the process with status 1 if the weights file is missing or a
    required layer is absent from the state dict.
    """
    if not Path(weights_file).exists():
        print(f"Error: Weights file not found at {weights_file}")
        sys.exit(1)

    # Load weights. weights_only=True restricts unpickling to tensors —
    # all a state dict contains — and avoids arbitrary-code execution
    # from untrusted checkpoint files.
    state_dict = torch.load(weights_file, map_location='cpu', weights_only=True)

    # Debug: print available layers
    print(f"Available layers in {weights_file}:")
    for key in sorted(state_dict.keys()):
        print(f" {key}: {state_dict[key].shape}")

    # Create output directory if needed
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, 'wb') as f:
        # Write magic number and version
        f.write(b'NNUE')
        f.write(struct.pack('<I', 1))  # version 1

        # Write each weight tensor in the fixed order the runtime expects.
        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias']:
            if layer_name not in state_dict:
                print(f"Error: Missing layer {layer_name}")
                sys.exit(1)

            tensor = state_dict[layer_name]
            # Convert to float32 and flatten
            data = tensor.float().flatten().cpu().numpy()

            # Write shape (allows validation on load)
            shape = list(tensor.shape)
            f.write(struct.pack('<I', len(shape)))
            for dim in shape:
                f.write(struct.pack('<I', dim))

            # Write flattened data as little-endian float32. astype('<f4')
            # pins the byte order explicitly and tobytes() is far faster
            # than unpacking the whole array into struct.pack element by
            # element.
            f.write(data.astype('<f4').tobytes())

            print(f" {layer_name}: shape {shape}, {len(data)} floats")

    file_size_mb = output_path.stat().st_size / (1024**2)
    print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
|
||||
|
||||
if __name__ == "__main__":
    # Positional overrides: argv[1] = input weights, argv[2] = output binary.
    cli_args = sys.argv[1:]
    src = cli_args[0] if len(cli_args) >= 1 else "nnue_weights.pt"
    dst = cli_args[1] if len(cli_args) >= 2 else "../src/main/resources/nnue_weights.bin"
    export_weights_to_binary(src, dst)
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate 500,000 random chess positions for NNUE training."""
|
||||
|
||||
import chess
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
    """Play random games and save positions after 8-20 random moves.

    Each game contributes at most ONE position: the board state after the
    random opening moves, written as a FEN line to ``output_file``.
    Positions are discarded when the game already ended, the side to move
    is in check, or (optionally) any capture is available — keeping the
    dataset "quiet" for evaluation training.

    Args:
        output_file: Text file that receives one FEN per line (overwritten).
        total_games: Number of random games to play.
        filter_captures: When True, drop positions with any capture available.

    Returns:
        Number of valid positions saved
    """
    positions_count = 0
    # Per-reason rejection counters, reported in the final summary.
    filtered_check = 0
    filtered_captures = 0
    filtered_game_over = 0

    with open(output_file, 'w') as f:
        with tqdm(total=total_games, desc="Generating positions") as pbar:
            for game_num in range(total_games):
                board = chess.Board()

                # Play 8-20 random opening moves
                num_moves = random.randint(8, 20)

                for move_num in range(num_moves):
                    if board.is_game_over():
                        break

                    legal_moves = list(board.legal_moves)
                    if not legal_moves:
                        break

                    move = random.choice(legal_moves)
                    board.push(move)

                # Skip if game over
                if board.is_game_over():
                    filtered_game_over += 1
                    pbar.update(1)
                    continue

                # Skip if in check
                if board.is_check():
                    filtered_check += 1
                    pbar.update(1)
                    continue

                # Check if any captures are available (if filtering enabled)
                if filter_captures:
                    has_captures = any(board.is_capture(move) for move in board.legal_moves)
                    if has_captures:
                        filtered_captures += 1
                        pbar.update(1)
                        continue

                # Save valid position
                fen = board.fen()
                f.write(fen + '\n')
                positions_count += 1

                pbar.update(1)

    # Print summary
    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Total games: {total_games}")
    print(f"Saved positions: {positions_count}")
    print(f"Filtered (check): {filtered_check}")
    print(f"Filtered (captures): {filtered_captures}")
    print(f"Filtered (game over): {filtered_game_over}")
    print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
    print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
    print("=" * 60)
    print()

    if positions_count == 0:
        print("WARNING: No valid positions were generated!")
        print("This might happen if:")
        print(" - The filter criteria are too strict (captures, checks)")
        print(" - Try using: --no-filter-captures to accept positions with captures")
        return 0

    return positions_count
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    parser.add_argument("output_file", nargs="?", default="positions.txt",
                        help="Output file for positions (default: positions.txt)")
    # Default aligned with the advertised value: the help text, the module
    # docstring, and the function's own default all say 500000, but this
    # was previously 5000.
    parser.add_argument("--games", type=int, default=500000,
                        help="Number of games to play (default: 500000)")
    parser.add_argument("--no-filter-captures", action="store_true",
                        help="Include positions with available captures (increases output)")

    args = parser.parse_args()

    count = play_random_game_and_collect_positions(
        output_file=args.output_file,
        total_games=args.games,
        filter_captures=not args.no_filter_captures
    )

    # Exit nonzero when nothing was generated so shell pipelines can fail fast.
    sys.exit(0 if count > 0 else 1)
|
||||
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Label positions with Stockfish evaluations."""
|
||||
|
||||
import json
|
||||
import chess.engine
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False):
    """Read positions and label them with Stockfish evaluations.

    Appends JSON lines of {"fen": ..., "eval": centipawns} to
    ``output_file``; scores are clamped to [-2000, 2000] and mate scores
    map to +/-2000 (from White's point of view). FENs already present in
    an existing output file are skipped, so interrupted runs can resume.

    Args:
        positions_file: Path to positions.txt
        output_file: Path to training_data.jsonl
        stockfish_path: Path to stockfish binary, or a command on $PATH
        batch_size: Batch size (not used, kept for compatibility)
        depth: Stockfish depth
        verbose: Print detailed error messages

    Returns:
        True if at least one position was evaluated, False otherwise.
    """
    import shutil

    # Check if stockfish exists. Accept either a real file path or a bare
    # command resolvable on $PATH: the default value is just "stockfish",
    # which would always fail a plain Path.exists() check.
    if not Path(stockfish_path).exists() and shutil.which(stockfish_path) is None:
        print(f"Error: Stockfish not found at {stockfish_path}")
        print(f"Tried: {stockfish_path}")
        print(f"Set STOCKFISH_PATH environment variable or pass as argument")
        sys.exit(1)

    print(f"Using Stockfish: {stockfish_path}")

    # Check if positions file exists
    if not Path(positions_file).exists():
        print(f"Error: Positions file not found at {positions_file}")
        sys.exit(1)

    # Load existing evaluations if resuming
    evaluated_fens = set()
    position_count = 0

    if Path(output_file).exists():
        with open(output_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    evaluated_fens.add(data['fen'])
                    position_count += 1
                except json.JSONDecodeError:
                    pass
        print(f"Resuming from {position_count} already evaluated positions")

    # Count total positions
    with open(positions_file, 'r') as f:
        total_lines = sum(1 for _ in f)

    if total_lines == 0:
        print(f"Error: Positions file is empty ({positions_file})")
        sys.exit(1)

    print(f"Total positions to process: {total_lines}")
    print(f"Using depth: {depth}")
    print()

    # Initialize engine
    try:
        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    except Exception as e:
        print(f"Error: Could not start Stockfish engine")
        print(f" Stockfish path: {stockfish_path}")
        print(f" Error: {e}")
        sys.exit(1)

    # Track statistics
    evaluated = 0
    skipped_invalid = 0
    skipped_duplicate = 0
    errors = 0

    try:
        with open(positions_file, 'r') as f:
            with open(output_file, 'a') as out:
                with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
                    for fen in f:
                        fen = fen.strip()

                        # Skip empty lines
                        if not fen:
                            skipped_invalid += 1
                            pbar.update(1)
                            continue

                        # Skip already evaluated
                        if fen in evaluated_fens:
                            skipped_duplicate += 1
                            pbar.update(1)
                            continue

                        try:
                            # Validate FEN
                            board = chess.Board(fen)
                            if not board.is_valid():
                                skipped_invalid += 1
                                pbar.update(1)
                                continue

                            # Evaluate at specified depth
                            info = engine.analyse(board, chess.engine.Limit(depth=depth))

                            if info.get('score') is None:
                                skipped_invalid += 1
                                pbar.update(1)
                                continue

                            score = info['score'].white()

                            # Convert to centipawns
                            if score.is_mate():
                                # Use large values for mate scores
                                eval_cp = 2000 if score.mate() > 0 else -2000
                            else:
                                eval_cp = score.cp

                            # Clamp to [-2000, 2000]
                            eval_cp = max(-2000, min(2000, eval_cp))

                            # Save evaluation
                            data = {"fen": fen, "eval": eval_cp}
                            out.write(json.dumps(data) + '\n')
                            out.flush()  # Force write to disk
                            # Remember this FEN so duplicates later in the
                            # SAME input file are not re-evaluated (previously
                            # only pre-existing output lines were deduplicated).
                            evaluated_fens.add(fen)
                            evaluated += 1

                        except Exception as e:
                            errors += 1
                            if verbose:
                                print(f"Error evaluating position: {fen[:50]}...")
                                print(f" {type(e).__name__}: {e}")
                            pbar.update(1)
                            continue

                        pbar.update(1)

    finally:
        # Always shut the engine subprocess down, even on Ctrl-C.
        engine.quit()

    # Print summary
    print()
    print("=" * 60)
    print("LABELING SUMMARY")
    print("=" * 60)
    print(f"Successfully evaluated: {evaluated}")
    print(f"Skipped (duplicates): {skipped_duplicate}")
    print(f"Skipped (invalid): {skipped_invalid}")
    print(f"Errors: {errors}")
    print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
    print("=" * 60)
    print()

    if evaluated == 0:
        print("WARNING: No positions were successfully evaluated!")
        print("Check that:")
        print(" 1. positions.txt is not empty")
        print(" 2. positions.txt contains valid FENs")
        print(" 3. Stockfish is installed and working")
        print(" 4. Stockfish path is correct")
        return False

    print(f"✓ Labeling complete. Output saved to {output_file}")
    return True
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
    ap.add_argument("positions_file", nargs="?", default="positions.txt",
                    help="Input positions file (default: positions.txt)")
    ap.add_argument("output_file", nargs="?", default="training_data.jsonl",
                    help="Output file (default: training_data.jsonl)")
    ap.add_argument("stockfish_path", nargs="?", default=None,
                    help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    ap.add_argument("--depth", type=int, default=12,
                    help="Stockfish depth (default: 12)")
    ap.add_argument("--verbose", action="store_true",
                    help="Print detailed error messages")
    opts = ap.parse_args()

    # CLI argument wins; otherwise fall back to $STOCKFISH_PATH, then "stockfish".
    engine_path = opts.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")

    ok = label_positions_with_stockfish(
        positions_file=opts.positions_file,
        output_file=opts.output_file,
        stockfish_path=engine_path,
        depth=opts.depth,
        verbose=opts.verbose,
    )
    sys.exit(0 if ok else 1)
|
||||
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train NNUE neural network for chess evaluation."""
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
class NNUEDataset(Dataset):
    """Chess positions paired with centipawn evaluations, read from JSONL.

    Lines that are not valid JSON, or that lack the 'fen'/'eval' keys,
    are silently skipped.
    """

    def __init__(self, data_file):
        # Parallel lists: positions[i] is the FEN whose label is evals[i].
        self.positions = []
        self.evals = []
        with open(data_file, 'r') as handle:
            for raw_line in handle:
                try:
                    record = json.loads(raw_line)
                    pair = (record['fen'], record['eval'])
                except (json.JSONDecodeError, KeyError):
                    continue
                self.positions.append(pair[0])
                self.evals.append(pair[1])

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        # Map the centipawn score through sigmoid(cp / 400) so targets lie
        # in (0, 1).
        features = fen_to_features(self.positions[idx])
        scaled = torch.tensor(self.evals[idx] / 400.0, dtype=torch.float32)
        return features, torch.sigmoid(scaled)
|
||||
|
||||
def fen_to_features(fen):
    """Convert FEN to a 768-dimensional binary feature vector.

    Layout: 12 piece planes (black p,n,b,r,q,k then white P,N,B,R,Q,K),
    each spanning 64 squares, so feature index = piece_index * 64 + square.
    A malformed FEN yields the all-zeros vector.
    """
    # Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
    piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                    'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}

    features = torch.zeros(768, dtype=torch.float32)

    try:
        board = chess.Board(fen)

        # 12 piece types × 64 squares = 768
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None:
                piece_char = piece.symbol()
                if piece_char in piece_to_idx:
                    piece_idx = piece_to_idx[piece_char]
                    feature_idx = piece_idx * 64 + square
                    features[feature_idx] = 1.0
    except ValueError:
        # chess.Board raises ValueError on malformed FENs. The previous bare
        # `except:` swallowed everything, including KeyboardInterrupt and
        # genuine programming errors.
        pass

    return features
|
||||
|
||||
class NNUE(nn.Module):
    """NNUE evaluation network: 768 -> 256 -> 32 -> 1 with ReLU hidden layers.

    NOTE(review): self.sigmoid is constructed but forward() returns the raw
    linear output; training targets are sigmoid-scaled, so the activation is
    presumably applied elsewhere — confirm before removing the attribute.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (l1/l2/l3) are load-bearing: the weight exporter
        # reads tensors from the state dict by these exact keys.
        self.l1 = nn.Linear(768, 256)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(256, 32)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu1(self.l1(x))
        hidden = self.relu2(self.l2(hidden))
        return self.l3(hidden)
|
||||
|
||||
def find_next_version(base_name="nnue_weights"):
    """Return the next free version number for `{base_name}_v{N}.pt` files.

    Scans the current directory and returns max(N) + 1, or 1 when no
    versioned weight files exist yet.
    """
    version_re = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
    found = [
        int(m.group(1))
        for m in (version_re.match(p.name) for p in Path(".").glob(f"{base_name}_v*.pt"))
        if m
    ]
    return max(found) + 1 if found else 1
|
||||
|
||||
def save_metadata(weights_file, metadata):
    """Save training metadata alongside the weights file.

    Args:
        weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
        metadata: Dictionary with training info

    Returns:
        Path (str) of the metadata JSON file that was written.
    """
    # str.replace(".pt", ...) would hit the FIRST ".pt" anywhere in the
    # path, and when the suffix is absent it was a no-op that silently
    # overwrote the weights file itself. Strip only a trailing ".pt".
    if weights_file.endswith(".pt"):
        metadata_file = weights_file[:-len(".pt")] + "_metadata.json"
    else:
        metadata_file = weights_file + "_metadata.json"

    # default=str stringifies non-JSON values (e.g. datetimes) on purpose.
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2, default=str)

    return metadata_file
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
    """Train the NNUE model.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs
        batch_size: Training batch size
        lr: Learning rate
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
    """

    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
    print(f"Dataset size: {num_positions}")

    # Guard: random_split (and the later per-set loss normalization) would
    # fail on an empty dataset; bail out with a clear message instead.
    if num_positions == 0:
        print(f"Error: No training samples found in {data_file}")
        return

    # Split 90% train, 10% validation
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    from torch.utils.data import random_split
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Load checkpoint if provided
    checkpoint_to_load = checkpoint
    if checkpoint_to_load is None and Path(output_file).exists():
        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
        checkpoint_to_load = output_file

    start_epoch = 0
    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
        print(f"Loading checkpoint from {checkpoint_to_load}...")
        try:
            # Only the weights are restored; optimizer state and epoch count
            # are not saved, so resuming restarts the schedule.
            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
            model.load_state_dict(checkpoint_state)
            print(f"✓ Checkpoint loaded successfully")
        except Exception as e:
            print(f"Warning: Could not load checkpoint: {e}")
            print("Training from scratch instead")

    best_val_loss = float('inf')
    best_model_state = None

    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
    print()

    training_start_time = datetime.now()

    for epoch in range(start_epoch, start_epoch + epochs):
        # Train
        model.train()
        train_loss = 0.0
        epoch_display = epoch + 1
        total_epochs = start_epoch + epochs
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
            for batch_features, batch_targets in train_loader:
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)

                optimizer.zero_grad()
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                loss.backward()
                optimizer.step()

                # Weight by batch size so the average is per-sample.
                train_loss += loss.item() * batch_features.size(0)
                pbar.update(1)

        train_loss /= len(train_dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
                for batch_features, batch_targets in val_loader:
                    batch_features = batch_features.to(device)
                    batch_targets = batch_targets.to(device).unsqueeze(1)

                    outputs = model(batch_features)
                    loss = criterion(outputs, batch_targets)
                    val_loss += loss.item() * batch_features.size(0)
                    pbar.update(1)

        val_loss /= len(val_dataset)

        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone every tensor: state_dict() returns references to the live
            # parameters, which optimizer.step() keeps mutating in-place, so a
            # shallow OrderedDict.copy() would silently end up holding the
            # LAST epoch's weights instead of the best ones.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Save best model
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
        metadata = {}

        if use_versioning:
            base_name = output_file.replace(".pt", "")
            version = find_next_version(base_name)
            final_output_file = f"{base_name}_v{version}.pt"

            # Prepare metadata
            metadata = {
                "version": version,
                "date": training_start_time.isoformat(),
                "num_positions": num_positions,
                "stockfish_depth": stockfish_depth,
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "final_val_loss": float(best_val_loss),
                "device": str(device),
                "checkpoint": str(checkpoint) if checkpoint else None,
                "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
            }

        torch.save(best_model_state, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
        if use_versioning and metadata:
            metadata_file = save_metadata(final_output_file, metadata)
            print(f"Metadata saved to {metadata_file}")
            print(f"\nTraining Summary:")
            print(f" Version: v{metadata['version']}")
            print(f" Positions: {metadata['num_positions']}")
            print(f" Stockfish depth: {metadata['stockfish_depth']}")
            print(f" Epochs: {metadata['epochs']}")
            print(f" Final validation loss: {metadata['final_val_loss']:.6f}")
            print(f" Device: {metadata['device']}")
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    cli.add_argument("data_file", nargs="?", default="training_data.jsonl",
                     help="Path to training_data.jsonl (default: training_data.jsonl)")
    cli.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                     help="Output file base name (default: nnue_weights.pt)")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to checkpoint file to resume training from (optional)")
    cli.add_argument("--epochs", type=int, default=20,
                     help="Number of epochs to train (default: 20)")
    cli.add_argument("--batch-size", type=int, default=4096,
                     help="Batch size (default: 4096)")
    cli.add_argument("--lr", type=float, default=1e-3,
                     help="Learning rate (default: 1e-3)")
    cli.add_argument("--stockfish-depth", type=int, default=12,
                     help="Stockfish depth used for evaluations (for metadata, default: 12)")
    cli.add_argument("--no-versioning", action="store_true",
                     help="Disable automatic versioning (save directly to output file)")
    opts = cli.parse_args()

    train_nnue(
        data_file=opts.data_file,
        output_file=opts.output_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        checkpoint=opts.checkpoint,
        stockfish_depth=opts.stockfish_depth,
        use_versioning=not opts.no_versioning,
    )
|
||||
Reference in New Issue
Block a user