feat: NCS-41 Bot Platform (#33)

Co-authored-by: Janis <janis@nowchess.de>
Reviewed-on: #33
Co-authored-by: Janis <janis.e.20@gmx.de>
Co-committed-by: Janis <janis.e.20@gmx.de>
commit dceab0875e (parent 5f4d33f3ca), committed by Janis on 2026-04-19 15:52:08 +02:00
117 changed files with 2531201 additions and 424 deletions
@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""Dataset versioning and management for NNUE training data."""
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
from rich.console import Console
from rich.table import Table
def get_datasets_dir() -> Path:
"""Get/create datasets directory."""
datasets_dir = Path(__file__).parent.parent / "datasets"
datasets_dir.mkdir(exist_ok=True)
return datasets_dir
def next_dataset_version() -> int:
"""Find the next available dataset version number."""
datasets_dir = get_datasets_dir()
versions = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
versions.append(v)
except (ValueError, IndexError):
pass
return max(versions) + 1 if versions else 1
def list_datasets() -> List[Tuple[int, Dict]]:
"""List all datasets with their metadata.
Returns:
List of (version, metadata_dict) tuples, sorted by version.
"""
datasets_dir = get_datasets_dir()
datasets = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
metadata_file = d / "metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
datasets.append((v, metadata))
except (ValueError, IndexError, json.JSONDecodeError):
pass
return sorted(datasets, key=lambda x: x[0])
def load_dataset_metadata(version: int) -> Optional[Dict]:
"""Load metadata for a specific dataset version.
Returns:
Metadata dict or None if not found.
"""
datasets_dir = get_datasets_dir()
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
if not metadata_file.exists():
return None
with open(metadata_file, 'r') as f:
return json.load(f)
def save_dataset_metadata(version: int, metadata: Dict) -> None:
"""Save metadata for a dataset version."""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
metadata_file = dataset_dir / "metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
def create_dataset(
version: int,
labeled_jsonl_path: str,
sources: List[Dict],
stockfish_depth: int = 12
) -> Path:
"""Create a new versioned dataset.
Args:
version: Dataset version number
labeled_jsonl_path: Path to labeled.jsonl to copy
sources: List of source dicts (see plan for schema)
stockfish_depth: Depth used for labeling
Returns:
Path to the created dataset directory.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
# Copy labeled data with deduplication (in case source has duplicates)
source_path = Path(labeled_jsonl_path)
if source_path.exists():
dest_path = dataset_dir / "labeled.jsonl"
seen_fens = set()
unique_count = 0
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
for line in src:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in seen_fens:
dst.write(line)
seen_fens.add(fen)
unique_count += 1
except json.JSONDecodeError:
# Skip malformed lines
pass
# Count positions
total_positions = 0
if (dataset_dir / "labeled.jsonl").exists():
with open(dataset_dir / "labeled.jsonl", 'r') as f:
total_positions = sum(1 for _ in f)
# Create metadata
metadata = {
"version": version,
"created": datetime.now().isoformat(),
"total_positions": total_positions,
"stockfish_depth": stockfish_depth,
"sources": sources
}
save_dataset_metadata(version, metadata)
return dataset_dir
def extend_dataset(
version: int,
new_labeled_path: str,
new_source_entry: Dict
) -> bool:
"""Extend an existing dataset with new labeled positions (with deduplication).
Args:
version: Dataset version to extend
new_labeled_path: Path to new labeled.jsonl to merge
new_source_entry: Source entry to add to metadata
Returns:
True if successful, False otherwise.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
labeled_file = dataset_dir / "labeled.jsonl"
new_labeled_file = Path(new_labeled_path)
if not new_labeled_file.exists():
return False
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
existing_fens = set()
if labeled_file.exists():
with open(labeled_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data.get('fen')
if fen:
existing_fens.add(fen)
except json.JSONDecodeError:
pass
# Merge new positions, skipping duplicates
new_count = 0
new_lines = []
with open(new_labeled_file, 'r') as f_new:
for line in f_new:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in existing_fens:
new_lines.append(line)
existing_fens.add(fen)
new_count += 1
except json.JSONDecodeError:
pass
# Append only the new, unique positions
if new_lines:
with open(labeled_file, 'a') as f_append:
for line in new_lines:
f_append.write(line)
# Update metadata
metadata = load_dataset_metadata(version)
if metadata:
# Count total positions
total_positions = 0
with open(labeled_file, 'r') as f:
total_positions = sum(1 for _ in f)
metadata['total_positions'] = total_positions
# Update the source entry with actual count of new positions added
new_source_entry['actual_count'] = new_count
metadata['sources'].append(new_source_entry)
save_dataset_metadata(version, metadata)
return True
def get_dataset_labeled_path(version: int) -> Optional[Path]:
"""Get the path to a dataset's labeled.jsonl file.
Returns:
Path to labeled.jsonl or None if dataset doesn't exist.
"""
datasets_dir = get_datasets_dir()
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
if labeled_file.exists():
return labeled_file
return None
def delete_dataset(version: int) -> bool:
"""Delete a dataset (recursively removes directory).
Args:
version: Dataset version to delete
Returns:
True if successful.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
import shutil
shutil.rmtree(dataset_dir)
return True
def show_datasets_table(console: Optional[Console] = None) -> None:
"""Display all datasets in a Rich table."""
if console is None:
console = Console()
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets found yet[/yellow]")
return
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("Positions", justify="right")
table.add_column("Sources", justify="left")
table.add_column("Depth", justify="center")
table.add_column("Created", justify="left")
for v, metadata in datasets:
positions = metadata.get('total_positions', 0)
sources = metadata.get('sources', [])
source_str = ", ".join([s.get('type', '?') for s in sources])
depth = metadata.get('stockfish_depth', '?')
created = metadata.get('created', '?')
if created != '?':
created = created.split('T')[0] # Just the date
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
console.print(table)
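
For orientation, a minimal usage sketch of the versioning API above. The module name (dataset) and the fields of the source entry are assumptions for illustration only; the actual source schema is defined in the plan referenced by create_dataset's docstring.

    # Hypothetical workflow: create a dataset version, extend it later, then list all versions.
    from dataset import next_dataset_version, create_dataset, extend_dataset, show_datasets_table

    version = next_dataset_version()
    create_dataset(
        version,
        labeled_jsonl_path="training_data.jsonl",
        sources=[{"type": "random", "requested_count": 100_000}],  # illustrative source entry
        stockfish_depth=12,
    )
    # Later: merge freshly labeled positions into the same version (deduplicated by FEN).
    extend_dataset(version, "more_labeled.jsonl", {"type": "lichess_evals"})
    show_datasets_table()
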
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""Export NNUE weights to .nbai format for runtime loading."""
import json
import struct
import sys
from datetime import datetime
from pathlib import Path
import torch
MAGIC = 0x4942_414E  # file magic, written as a little-endian int32 (bytes 0x4E 0x41 0x42 0x49 on disk)
VERSION = 1
def _read_sidecar(weights_file: str) -> dict:
sidecar = weights_file.replace(".pt", "_metadata.json")
if Path(sidecar).exists():
with open(sidecar) as f:
return json.load(f)
return {}
def _infer_layers(state_dict: dict) -> list[dict]:
"""Derive layer descriptors from state_dict weight shapes.
Assumes layers named l1, l2, ..., lN.
All hidden layers get activation 'relu'; the last gets 'linear'.
"""
names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
layers = []
for i, name in enumerate(names):
out_size, in_size = state_dict[f"{name}.weight"].shape
activation = "linear" if i == len(names) - 1 else "relu"
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
return layers
def _write_floats(f, tensor):
data = tensor.float().flatten().cpu().numpy()
f.write(struct.pack("<I", len(data)))
f.write(struct.pack(f"<{len(data)}f", *data))
def export_to_nbai(
weights_file: str,
output_file: str,
trained_by: str = "unknown",
train_loss: float = 0.0,
):
if not Path(weights_file).exists():
print(f"Error: weights file not found at {weights_file}")
sys.exit(1)
loaded = torch.load(weights_file, map_location="cpu")
state_dict = (
loaded["model_state_dict"]
if isinstance(loaded, dict) and "model_state_dict" in loaded
else loaded
)
sidecar = _read_sidecar(weights_file)
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
trained_at = sidecar.get("date", datetime.now().isoformat())
training_data_count = int(sidecar.get("num_positions", 0))
metadata = {
"trainedBy": trained_by,
"trainedAt": trained_at,
"trainingDataCount": training_data_count,
"valLoss": val_loss,
"trainLoss": train_loss,
}
layers = _infer_layers(state_dict)
layer_names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
print(f"Architecture ({len(layers)} layers):")
for i, l in enumerate(layers):
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
with open(output_file, "wb") as f:
# Header
f.write(struct.pack("<I", MAGIC))
f.write(struct.pack("<H", VERSION))
# Metadata (length-prefixed UTF-8 JSON)
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
f.write(struct.pack("<I", len(meta_bytes)))
f.write(meta_bytes)
# Layer descriptors
f.write(struct.pack("<H", len(layers)))
for layer in layers:
name_bytes = layer["activation"].encode("ascii")
f.write(struct.pack("<B", len(name_bytes)))
f.write(name_bytes)
f.write(struct.pack("<I", layer["inputSize"]))
f.write(struct.pack("<I", layer["outputSize"]))
# Weights: weight tensor then bias tensor per layer
for name in layer_names:
w = state_dict[f"{name}.weight"]
b = state_dict[f"{name}.bias"]
_write_floats(f, w)
_write_floats(f, b)
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
if __name__ == "__main__":
weights_file = "nnue_weights.pt"
output_file = "../src/main/resources/nnue_weights.nbai"
trained_by = "unknown"
train_loss = 0.0
if len(sys.argv) > 1:
weights_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
if len(sys.argv) > 3:
trained_by = sys.argv[3]
if len(sys.argv) > 4:
train_loss = float(sys.argv[4])
export_to_nbai(weights_file, output_file, trained_by, train_loss)
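
For reference, a sketch of a reader that mirrors the byte layout written above (magic, version, length-prefixed JSON metadata, layer descriptors, then float32 weight and bias blocks per layer). The real consumer is the runtime loader on the Java side, which is not part of this diff; this is only an illustration of the format.

    import json
    import struct

    def read_nbai(path):
        """Parse a .nbai file laid out exactly as export_to_nbai writes it."""
        with open(path, "rb") as f:
            magic, = struct.unpack("<I", f.read(4))
            version, = struct.unpack("<H", f.read(2))
            meta_len, = struct.unpack("<I", f.read(4))
            metadata = json.loads(f.read(meta_len).decode("utf-8"))
            n_layers, = struct.unpack("<H", f.read(2))
            layers = []
            for _ in range(n_layers):
                name_len, = struct.unpack("<B", f.read(1))
                activation = f.read(name_len).decode("ascii")
                in_size, out_size = struct.unpack("<II", f.read(8))
                layers.append({"activation": activation, "inputSize": in_size, "outputSize": out_size})
            weights = []
            for _ in layers:
                w_count, = struct.unpack("<I", f.read(4))
                w = struct.unpack(f"<{w_count}f", f.read(4 * w_count))
                b_count, = struct.unpack("<I", f.read(4))
                b = struct.unpack(f"<{b_count}f", f.read(4 * b_count))
                weights.append((w, b))
        return magic, version, metadata, layers, weights
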
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""Generate random chess positions for NNUE training with multiprocessing."""
import chess
import random
import sys
from tqdm import tqdm
from multiprocessing import Pool
from datetime import datetime
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
"""Generate games for one worker.
Returns:
list of FENs generated by this worker
"""
positions = []
for game_num in range(games_per_worker):
board = chess.Board()
move_history = []
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
for idx in sample_indices:
sampled_board = move_history[idx]
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid() or sampled_board.is_game_over():
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
positions.append(fen)
return positions
def play_random_game_and_collect_positions(
output_file,
total_positions=3000000,
samples_per_game=1,
min_move=1,
max_move=50,
num_workers=8
):
"""Generate positions using multiprocessing with multiple workers.
Args:
output_file: Output file for positions
total_positions: Target number of positions to generate
samples_per_game: Number of positions to sample per game (1-N)
min_move: Minimum move number to start sampling from
max_move: Maximum move number for sampling
num_workers: Number of parallel worker processes
Returns:
Number of valid positions saved
"""
# Estimate games needed (samples_per_game positions per game; some positions
# are filtered out, so the actual total may fall slightly short of the target)
total_games = max(total_positions // samples_per_game, num_workers)
games_per_worker = total_games // num_workers
print(f"Generating {total_positions:,} positions using {num_workers} workers")
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
print()
start_time = datetime.now()
# Generate positions in parallel
worker_tasks = [
(i, games_per_worker, samples_per_game, min_move, max_move)
for i in range(num_workers)
]
positions_count = 0
all_positions = []
with Pool(num_workers) as pool:
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
for positions in pool.starmap(_worker_generate_games, worker_tasks):
all_positions.extend(positions)
positions_count += len(positions)
pbar.update(1)
# Write all positions to file
print(f"Writing {positions_count:,} positions to {output_file}...")
with open(output_file, 'w') as f:
for fen in all_positions:
f.write(fen + '\n')
elapsed_time = datetime.now() - start_time
elapsed_seconds = elapsed_time.total_seconds()
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
# Print summary
print()
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
print(f"Target positions: {total_positions:,}")
print(f"Actual positions saved: {positions_count:,}")
print(f"Workers: {num_workers}")
print(f"Games per worker: {games_per_worker:,}")
print(f"Samples per game: {samples_per_game}")
print(f"Move range: {min_move}-{max_move}")
print(f"Elapsed time: {elapsed_time}")
print(f"Throughput: {positions_per_second:.0f} positions/second")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
return 0
return positions_count
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)")
parser.add_argument("--positions", type=int, default=3000000,
help="Target number of positions to generate (default: 3000000)")
parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)")
parser.add_argument("--workers", type=int, default=8,
help="Number of parallel worker processes (default: 8)")
args = parser.parse_args()
count = play_random_game_and_collect_positions(
output_file=args.output_file,
total_positions=args.positions,
samples_per_game=args.samples_per_game,
min_move=args.min_move,
max_move=args.max_move,
num_workers=args.workers
)
sys.exit(0 if count > 0 else 1)
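
A quick smoke-test invocation of the generator from Python (the module name generate matches the import used by the tactical-positions script later in this diff; the small parameter values here are only for a local test run):

    from generate import play_random_game_and_collect_positions

    count = play_random_game_and_collect_positions(
        output_file="sample_positions.txt",
        total_positions=1_000,   # small target for a quick local run
        samples_per_game=2,
        min_move=4,              # skip the very first opening plies
        max_move=60,
        num_workers=4,
    )
    print(f"saved {count} positions")
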
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""Label positions with Stockfish evaluations and analyze distribution."""
import json
import chess.engine
import sys
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
"""Normalize centipawn evaluation to a bounded range.
Args:
cp_value: Centipawn evaluation from Stockfish
method: 'tanh' (default) or 'sigmoid'
scale: Scale factor (tanh: 300 is typical)
Returns:
Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
"""
if method == 'tanh':
return np.tanh(cp_value / scale)
elif method == 'sigmoid':
return 1.0 / (1.0 + np.exp(-cp_value / scale))
else:
return cp_value / 100.0
def _evaluate_fen_batch(args):
"""Worker function to evaluate a batch of FENs with Stockfish threading.
Args:
args: tuple of (fens, stockfish_path, depth, normalize)
Returns:
list of (fen, eval_normalized, eval_raw) tuples
"""
fens, stockfish_path, depth, normalize = args
results = []
try:
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
except Exception:
return []
try:
for fen in fens:
try:
board = chess.Board(fen)
if not board.is_valid():
continue
info = engine.analyse(board, chess.engine.Limit(depth=depth))
if info.get('score') is None:
continue
score = info['score'].white()
if score.is_mate():
eval_cp = 2000 if score.mate() > 0 else -2000
else:
eval_cp = score.cp
eval_cp = max(-2000, min(2000, eval_cp))
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
results.append((fen, eval_normalized, eval_cp))
except Exception:
continue
finally:
engine.quit()
return results
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
"""Read positions and label them with Stockfish evaluations.
Args:
positions_file: Path to positions.txt
output_file: Path to training_data.jsonl
stockfish_path: Path to stockfish binary
batch_size: Batch size for processing (positions per worker task, default: 1000)
depth: Stockfish depth
verbose: Print detailed error messages
normalize: If True, normalize evals using tanh
num_workers: Number of parallel Stockfish processes
"""
# Check if stockfish exists
if not Path(stockfish_path).exists():
print(f"Error: Stockfish not found at {stockfish_path}")
print(f"Tried: {stockfish_path}")
print(f"Set STOCKFISH_PATH environment variable or pass as argument")
sys.exit(1)
print(f"Using Stockfish: {stockfish_path}")
print(f"Number of workers: {num_workers}")
# Check if positions file exists
if not Path(positions_file).exists():
print(f"Error: Positions file not found at {positions_file}")
sys.exit(1)
# Load existing evaluations if resuming
evaluated_fens = set()
position_count = 0
if Path(output_file).exists():
with open(output_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
evaluated_fens.add(data['fen'])
position_count += 1
except json.JSONDecodeError:
pass
print(f"Resuming from {position_count} already evaluated positions")
# Load all FENs that need evaluation
fens_to_evaluate = []
fens_seen_in_batch = set() # Track duplicates within the input file itself
skipped_invalid = 0
skipped_duplicate = 0
with open(positions_file, 'r') as f:
for fen in f:
fen = fen.strip()
if not fen:
skipped_invalid += 1
continue
if fen in evaluated_fens:
skipped_duplicate += 1
continue
if fen in fens_seen_in_batch:
skipped_duplicate += 1
continue
fens_to_evaluate.append(fen)
fens_seen_in_batch.add(fen)
total_to_evaluate = len(fens_to_evaluate)
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
if total_to_evaluate == 0:
if position_count == 0:
print(f"Error: No valid positions to evaluate in {positions_file}")
sys.exit(1)
else:
print(f"All positions already evaluated. No new positions to process.")
return True
print(f"Total positions to process: {total_lines}")
print(f"New positions to evaluate: {total_to_evaluate}")
print(f"Using depth: {depth}")
print()
# Split FENs into batches for workers
batches = []
for i in range(0, total_to_evaluate, batch_size):
batch = fens_to_evaluate[i:i+batch_size]
batches.append((batch, stockfish_path, depth, normalize))
# Process batches in parallel
evaluated = 0
errors = 0
raw_evals = []
normalized_evals = []
import time
start_time = time.time()
with Pool(num_workers) as pool:
with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
with open(output_file, 'a') as out:
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
for fen, eval_normalized, eval_cp in batch_results:
# Skip if already evaluated in output file during this run
if fen in evaluated_fens:
continue
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
out.write(json.dumps(data) + '\n')
evaluated_fens.add(fen) # Track as evaluated
evaluated += 1
raw_evals.append(eval_cp)
normalized_evals.append(eval_normalized)
pbar.update(1)
# Keep the progress bar in sync for evaluations that failed in this batch
# (imap_unordered loses the batch/result pairing, so the first batch's size
# serves as an estimate; a smaller final batch may be counted slightly off)
batch_size_actual = len(batches[0][0]) if batches else batch_size
failed = batch_size_actual - len(batch_results)
if failed > 0:
errors += failed
pbar.update(failed)
# Calculate and show throughput and ETA
elapsed = time.time() - start_time
throughput = evaluated / elapsed if elapsed > 0 else 0
remaining_positions = total_to_evaluate - evaluated
eta_seconds = remaining_positions / throughput if throughput > 0 else 0
eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
pbar.set_postfix({
'rate': f'{throughput:.0f} pos/s',
'eta': eta_str
})
# Print summary and analysis
print()
print("=" * 60)
print("LABELING SUMMARY")
print("=" * 60)
print(f"Successfully evaluated: {evaluated}")
print(f"Skipped (duplicates): {skipped_duplicate}")
print(f"Skipped (invalid): {skipped_invalid}")
print(f"Errors: {errors}")
print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
print("=" * 60)
print()
if evaluated == 0:
print("WARNING: No positions were successfully evaluated!")
print("Check that:")
print(" 1. positions.txt is not empty")
print(" 2. positions.txt contains valid FENs")
print(" 3. Stockfish is installed and working")
print(" 4. Stockfish path is correct")
return False
# Print distribution analysis
if raw_evals:
raw_evals_arr = np.array(raw_evals)
norm_evals_arr = np.array(normalized_evals)
print("=" * 60)
print("EVALUATION DISTRIBUTION ANALYSIS")
print("=" * 60)
print()
print("Raw Evaluations (centipawns):")
print(f" Min: {raw_evals_arr.min():.1f}")
print(f" Max: {raw_evals_arr.max():.1f}")
print(f" Mean: {raw_evals_arr.mean():.1f}")
print(f" Median: {np.median(raw_evals_arr):.1f}")
print(f" Std: {raw_evals_arr.std():.1f}")
print()
print("Normalized Evaluations (tanh):")
print(f" Min: {norm_evals_arr.min():.4f}")
print(f" Max: {norm_evals_arr.max():.4f}")
print(f" Mean: {norm_evals_arr.mean():.4f}")
print(f" Median: {np.median(norm_evals_arr):.4f}")
print(f" Std: {norm_evals_arr.std():.4f}")
print()
# Distribution buckets
print("Raw Evaluation Buckets (counts):")
buckets = [
(-float('inf'), -500, "< -5.00"),
(-500, -300, "[-5.00, -3.00)"),
(-300, -100, "[-3.00, -1.00)"),
(-100, 0, "[-1.00, 0.00)"),
(0, 100, "[0.00, 1.00)"),
(100, 300, "[1.00, 3.00)"),
(300, 500, "[3.00, 5.00)"),
(500, float('inf'), "> 5.00"),
]
for low, high, label in buckets:
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
pct = 100.0 * count / len(raw_evals_arr)
print(f" {label}: {count:6d} ({pct:5.1f}%)")
print("=" * 60)
print()
print(f"✓ Labeling complete. Output saved to {output_file}")
return True
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
parser.add_argument("positions_file", nargs="?", default="positions.txt",
help="Input positions file (default: positions.txt)")
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
help="Output file (default: training_data.jsonl)")
parser.add_argument("stockfish_path", nargs="?", default=None,
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)")
parser.add_argument("--batch-size", type=int, default=1000,
help="Batch size for processing (default: 1000)")
parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)")
parser.add_argument("--verbose", action="store_true",
help="Print detailed error messages")
parser.add_argument("--workers", type=int, default=1,
help="Number of parallel Stockfish processes (default: 1)")
args = parser.parse_args()
# Determine Stockfish path
stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")
success = label_positions_with_stockfish(
positions_file=args.positions_file,
output_file=args.output_file,
stockfish_path=stockfish_path,
batch_size=args.batch_size,
depth=args.depth,
normalize=not args.no_normalize,
verbose=args.verbose,
num_workers=args.workers
)
sys.exit(0 if success else 1)
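
For intuition about the labels this produces: normalize_evaluation squashes centipawn scores with tanh(cp / 300), so small advantages stay roughly linear while decisive scores saturate near ±1. A tiny worked example (values rounded to four decimals):

    import numpy as np

    for cp in (0, 100, 300, 600, 2000):
        print(cp, round(float(np.tanh(cp / 300.0)), 4))
    # 0    0.0
    # 100  0.3215
    # 300  0.7616
    # 600  0.964
    # 2000 1.0   (the ±2000 clamp above is already deep in the saturated region)
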
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""Import pre-labeled positions from the Lichess evaluation database.
Source: https://database.lichess.org/#evals
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
Each line:
{
"fen": "<pieces> <turn> <castling> <ep>",
"evals": [
{
"knodes": <int>,
"depth": <int>,
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
}
]
}
cp and mate are from White's perspective (positive = White winning), matching
the sign convention used by label.py (score.white()) and expected by train.py.
"""
import json
import sys
import numpy as np
from pathlib import Path
from tqdm import tqdm
MATE_CP = 20000
SCALE = 300.0
def _best_eval(evals: list) -> dict | None:
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
if not evals:
return None
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
def _cp_from_pv(pv: dict) -> int | None:
"""Extract centipawn value from a principal variation entry."""
if "cp" in pv:
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
if "mate" in pv:
return MATE_CP if pv["mate"] > 0 else -MATE_CP
return None
def _normalize(cp: int) -> float:
return float(np.tanh(cp / SCALE))
def import_lichess_evals(
input_path: str,
output_file: str,
max_positions: int | None = None,
min_depth: int = 0,
verbose: bool = False,
) -> int:
"""Stream the Lichess eval database and write a labeled.jsonl file.
Args:
input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
output_file: Destination labeled.jsonl (appended — supports resuming).
max_positions: Stop after this many new positions (None = no limit).
min_depth: Skip positions whose best eval has depth < min_depth.
verbose: Print warnings for skipped lines.
Returns:
Number of new positions written.
"""
import zstandard as zstd
input_path = Path(input_path)
if not input_path.exists():
print(f"Error: {input_path} not found")
sys.exit(1)
# Resume: collect already-written FENs so we skip duplicates.
seen_fens: set[str] = set()
if Path(output_file).exists():
with open(output_file, "r") as f:
for line in f:
try:
seen_fens.add(json.loads(line)["fen"])
except (json.JSONDecodeError, KeyError):
pass
if seen_fens:
print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
written = 0
skipped_depth = 0
skipped_no_eval = 0
skipped_dup = 0
def iter_lines():
"""Yield decoded text lines from either a .zst or plain .jsonl file."""
import io
if input_path.suffix == ".zst":
dctx = zstd.ZstdDecompressor()
with open(input_path, "rb") as fh:
with dctx.stream_reader(fh) as reader:
text_stream = io.TextIOWrapper(reader, encoding="utf-8")
yield from text_stream
else:
with open(input_path, "r", encoding="utf-8") as fh:
yield from fh
with open(output_file, "a") as out:
with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
for raw_line in iter_lines():
line = raw_line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
if verbose:
print("Warning: malformed JSON line skipped")
continue
fen = data.get("fen", "")
if not fen:
skipped_no_eval += 1
continue
if fen in seen_fens:
skipped_dup += 1
continue
best = _best_eval(data.get("evals", []))
if best is None:
skipped_no_eval += 1
continue
if best.get("depth", 0) < min_depth:
skipped_depth += 1
continue
pvs = best.get("pvs", [])
if not pvs:
skipped_no_eval += 1
continue
cp = _cp_from_pv(pvs[0])
if cp is None:
skipped_no_eval += 1
continue
record = {
"fen": fen,
"eval": _normalize(cp),
"eval_raw": cp,
}
out.write(json.dumps(record) + "\n")
seen_fens.add(fen)
written += 1
pbar.update(1)
if max_positions and written >= max_positions:
print(f"\nReached max_positions limit ({max_positions:,})")
break
print()
print("=" * 60)
print("LICHESS IMPORT SUMMARY")
print("=" * 60)
print(f"Positions written: {written:,}")
print(f"Skipped (dup): {skipped_dup:,}")
print(f"Skipped (no eval): {skipped_no_eval:,}")
print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
print("=" * 60)
print(f"\n✓ Output: {output_file}")
return written
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Import Lichess pre-labeled positions into labeled.jsonl"
)
parser.add_argument("input_path",
help="Path to lichess_db_eval.jsonl.zst")
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
help="Output labeled.jsonl (default: training_data.jsonl)")
parser.add_argument("--max-positions", type=int, default=None,
help="Stop after N positions (default: no limit)")
parser.add_argument("--min-depth", type=int, default=0,
help="Minimum eval depth to accept (default: 0)")
parser.add_argument("--verbose", action="store_true",
help="Print warnings for skipped lines")
args = parser.parse_args()
count = import_lichess_evals(
input_path=args.input_path,
output_file=args.output_file,
max_positions=args.max_positions,
min_depth=args.min_depth,
verbose=args.verbose,
)
sys.exit(0 if count > 0 else 1)
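
To make the selection logic concrete, here is what one input line turns into (the line below is illustrative, not a real database row): _best_eval keeps the deepest entry, the first PV supplies the centipawn value, and the record stores both the tanh-normalized and raw scores.

    import math

    # Illustrative eval-database line (FEN has the four fields described above).
    line = {
        "fen": "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq -",
        "evals": [
            {"knodes": 100, "depth": 20, "pvs": [{"cp": 35, "line": "..."}]},
            {"knodes": 500, "depth": 30, "pvs": [{"cp": 28, "line": "..."}]},
        ],
    }
    # _best_eval -> the depth-30 entry; _cp_from_pv -> 28; _normalize -> tanh(28/300)
    print(round(math.tanh(28 / 300.0), 4))   # 0.0931
    # Written record: {"fen": line["fen"], "eval": 0.0931, "eval_raw": 28}
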
@@ -0,0 +1,249 @@
import chess
import csv
import json
import sys
import urllib.request
from pathlib import Path
from typing import Set
try:
import zstandard as zstd
except ImportError:
print("zstandard library not found. Install with: pip install zstandard")
sys.exit(1)
from generate import play_random_game_and_collect_positions
def download_and_extract_puzzle_db(
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
output_dir: str = 'tactical_data'
):
"""Download and extract the Lichess puzzle database."""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
csv_file = output_path / 'lichess_db_puzzle.csv'
zst_file = output_path / 'lichess_db_puzzle.csv.zst'
# Download if not already present
if not zst_file.exists():
print(f"Downloading puzzle database from {url}...")
try:
urllib.request.urlretrieve(url, zst_file)
print(f"Downloaded to {zst_file}")
except Exception as e:
print(f"Failed to download: {e}")
return None
# Extract if CSV doesn't exist
if not csv_file.exists():
print(f"Extracting {zst_file}...")
try:
# Stream-decompress to avoid holding the whole CSV in memory
dctx = zstd.ZstdDecompressor()
with open(zst_file, 'rb') as f, open(csv_file, 'wb') as out:
dctx.copy_stream(f, out)
print(f"Extracted to {csv_file}")
except Exception as e:
print(f"Failed to extract: {e}")
return None
return str(csv_file)
def extract_puzzle_positions(
puzzle_csv: str,
max_puzzles: int = 300_000
) -> Set[str]:
"""
Extract, from each puzzle, the position where the tactic is available
(the position reached after the opponent's setup blunder).
This is exactly the type of position where tactical
recognition matters most.
Returns a set of unique FENs.
"""
positions = set()
with open(puzzle_csv) as f:
reader = csv.DictReader(f)
for row in reader:
if len(positions) >= max_puzzles:
break
try:
board = chess.Board(row['FEN'])
# In the Lichess puzzle DB, the FEN is the position BEFORE the setup move
# and Moves[0] is the opponent's blunder; applying it yields the position
# where the tactic must be found (learn to spot it, not just play it out)
moves = row['Moves'].split()
board.push_uci(moves[0])  # apply the setup move (the opponent's blunder)
fen = board.fen()
# Filter for useful tactical themes
themes = row.get('Themes', '')
useful = any(t in themes for t in [
'fork', 'pin', 'skewer', 'discoveredAttack',
'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
'trappedPiece', 'sacrifice'
])
if useful:
positions.add(fen)
except Exception:
continue
return positions
def load_positions_from_file(file_path: str) -> Set[str]:
"""Load positions from a text file (one FEN per line)."""
positions = set()
try:
with open(file_path) as f:
for line in f:
line = line.strip()
if line:
positions.add(line)
print(f"Loaded {len(positions)} positions from {file_path}")
return positions
except Exception as e:
print(f"Failed to load from {file_path}: {e}")
return set()
def merge_positions(
tactical: Set[str],
other: Set[str],
output_file: str = 'position.txt'
):
"""Merge two position sets and write to file."""
merged = tactical | other
with open(output_file, 'w') as f:
for fen in merged:
f.write(fen + '\n')
overlap = len(tactical & other)
print(f"\n{'='*60}")
print(f"MERGE SUMMARY")
print(f"{'='*60}")
print(f"Tactical positions: {len(tactical):,}")
print(f"Other positions: {len(other):,}")
print(f"Overlap (deduplicated): {overlap:,}")
print(f"Total merged positions: {len(merged):,}")
print(f"Written to: {output_file}")
print(f"{'='*60}\n")
def extract_tactical_only(
puzzle_csv: str,
output_file: str,
max_puzzles: int = 300_000
) -> int:
"""Extract tactical positions and save to file (no merge prompts).
Args:
puzzle_csv: Path to Lichess puzzle CSV
output_file: Where to save the FEN positions
max_puzzles: Maximum puzzles to extract
Returns:
Number of positions extracted
"""
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
with open(output_file, 'w') as f:
for fen in tactical_positions:
f.write(fen + '\n')
return len(tactical_positions)
def interactive_merge_positions(
puzzle_csv: str,
output_file: str = 'position.txt',
max_puzzles: int = 300_000
):
"""Interactive workflow: extract tactical positions and merge with user selection."""
print("\n" + "="*60)
print("TACTICAL POSITION EXTRACTOR & MERGER")
print("="*60 + "\n")
# Extract tactical positions
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
# Ask what to merge with
print("What would you like to merge with these tactical positions?")
print("1. Load from a position file")
print("2. Generate random positions")
print("3. Skip merging (save tactical only)")
choice = input("\nEnter choice (1-3): ").strip()
other_positions = set()
if choice == '1':
file_path = input("Enter path to position file: ").strip()
other_positions = load_positions_from_file(file_path)
elif choice == '2':
positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
try:
positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
except ValueError:
positions_to_gen = 1000000
temp_file = 'temp_generated_positions.txt'
print(f"\nGenerating {positions_to_gen:,} random positions...")
play_random_game_and_collect_positions(
output_file=temp_file,
total_positions=positions_to_gen,
samples_per_game=1,
min_move=1,
max_move=50,
num_workers=8
)
other_positions = load_positions_from_file(temp_file)
elif choice == '3':
print("Skipping merge, saving tactical positions only...")
else:
print("Invalid choice, saving tactical positions only...")
merge_positions(tactical_positions, other_positions, output_file)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
help="URL to download puzzle database from")
parser.add_argument("--output-dir", default='trainingdata',
help="Directory to extract puzzle database to")
parser.add_argument("--max-puzzles", type=int, default=300_000,
help="Maximum puzzles to extract (default: 300000)")
parser.add_argument("--output-file", default='position.txt',
help="Output file for merged positions (default: position.txt)")
args = parser.parse_args()
# Download and extract
csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
if csv_path:
# Interactive merge
interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
else:
print("Failed to download/extract puzzle database")
sys.exit(1)
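
To illustrate the extraction step above: the Lichess puzzle CSV stores the position one ply before the tactic, so pushing the first move in Moves yields the position the network should learn to evaluate. The row below is made up for illustration (a Scholar's-mate pattern), not taken from the database:

    import chess

    row = {
        "FEN": "r1bqkbnr/pppp1ppp/2n5/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR b KQkq - 3 3",
        "Moves": "g8f6 h5f7",                 # setup blunder, then the tactical reply
        "Themes": "mate mateIn1 hangingPiece",
    }
    board = chess.Board(row["FEN"])
    board.push_uci(row["Moves"].split()[0])   # apply the setup move (Nf6??)
    print(board.fen())                        # position with Qxf7# available: this FEN is kept
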
@@ -0,0 +1,676 @@
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime, timedelta
import re
import numpy as np
class NNUEDataset(Dataset):
"""Dataset of chess positions with evaluations."""
def __init__(self, data_file):
self.positions = []
self.evals = []
self.evals_raw = []
self.is_normalized = None
with open(data_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data['fen']
eval_val = data['eval']
self.positions.append(fen)
self.evals.append(eval_val)
# Check if normalized or raw
if self.is_normalized is None:
# If eval is in range [-1, 1], assume normalized
self.is_normalized = abs(eval_val) <= 1.0
# Store raw if available
if 'eval_raw' in data:
self.evals_raw.append(data['eval_raw'])
else:
self.evals_raw.append(eval_val)
except (json.JSONDecodeError, KeyError):
pass
def __len__(self):
return len(self.positions)
def __getitem__(self, idx):
fen = self.positions[idx]
eval_val = self.evals[idx]
features = fen_to_features(fen)
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
if self.is_normalized:
target = torch.tensor(eval_val, dtype=torch.float32)
else:
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
return features, target
def fen_to_features(fen):
"""Convert FEN to 768-dimensional binary feature vector."""
# Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
features = torch.zeros(768, dtype=torch.float32)
try:
board = chess.Board(fen)
# 12 piece types × 64 squares = 768
for square in chess.SQUARES:
piece = board.piece_at(square)
if piece is not None:
piece_char = piece.symbol()
if piece_char in piece_to_idx:
piece_idx = piece_to_idx[piece_char]
feature_idx = piece_idx * 64 + square
features[feature_idx] = 1.0
except Exception:
pass  # invalid FEN: leave the feature vector all zeros
return features
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
class NNUE(nn.Module):
"""NNUE neural network with configurable hidden layers.
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
Layer attributes follow the naming l1, l2, ..., lN so export.py can
infer the architecture directly from the state_dict.
"""
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
super().__init__()
if hidden_sizes is None:
hidden_sizes = DEFAULT_HIDDEN_SIZES
self.hidden_sizes = list(hidden_sizes)
sizes = [768] + self.hidden_sizes + [1]
num_hidden = len(self.hidden_sizes)
for i in range(num_hidden):
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
setattr(self, f"relu{i + 1}", nn.ReLU())
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
self._num_hidden = num_hidden
def forward(self, x):
for i in range(1, self._num_hidden + 1):
layer = getattr(self, f"l{i}")
relu = getattr(self, f"relu{i}")
drop = getattr(self, f"drop{i}")
x = drop(relu(layer(x)))
return getattr(self, f"l{self._num_hidden + 1}")(x)
def find_next_version(base_name="nnue_weights"):
"""Find the next version number for model versioning.
Looks for nnue_weights_v*.pt files and returns the next version number.
If no versioned files exist, returns 1.
"""
base_path = Path(base_name)
directory = base_path.parent
filename = base_path.name
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
versions = []
for file in directory.glob(f"{filename}_v*.pt"):
match = pattern.match(file.name)
if match:
versions.append(int(match.group(1)))
if versions:
return max(versions) + 1
return 1
def save_metadata(weights_file, metadata):
"""Save training metadata alongside the weights file.
Args:
weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
metadata: Dictionary with training info
"""
metadata_file = weights_file.replace(".pt", "_metadata.json")
with open(metadata_file, "w") as f:
json.dump(metadata, f, indent=2, default=str)
return metadata_file
def _setup_training(data_file, batch_size, subsample_ratio):
"""Set up device, dataset, and data loaders.
Returns:
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
"""
print("Checking GPU availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
print("⚠ GPU not available, using CPU")
print(f"Using device: {device}")
print()
print("Loading dataset...")
dataset = NNUEDataset(data_file)
num_positions = len(dataset)
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
print("TRAINING DATASET DIAGNOSTICS")
print("=" * 60)
print(f"Min evaluation: {evals_array.min():.4f}")
print(f"Max evaluation: {evals_array.max():.4f}")
print(f"Mean evaluation: {evals_array.mean():.4f}")
print(f"Median evaluation: {np.median(evals_array):.4f}")
print(f"Std deviation: {evals_array.std():.4f}")
print("=" * 60)
print()
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
from torch.utils.data import random_split, RandomSampler
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=train_sampler,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
def _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
"""Run one training season until epoch limit, early stopping, or deadline.
Args:
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
toward early stopping and do not save snapshots.
Returns:
(best_val_loss, best_model_state, last_epoch)
best_model_state is None if no epoch beat initial_best_val_loss.
"""
best_val_loss = initial_best_val_loss
best_model_state = None
epochs_without_improvement = 0
total_epochs = start_epoch + epochs
last_epoch = start_epoch - 1
for epoch in range(start_epoch, start_epoch + epochs):
if deadline and datetime.now() >= deadline:
print("Time limit reached, stopping season.")
break
epoch_display = epoch + 1
# Train
model.train()
train_loss = 0.0
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
for batch_features, batch_targets in train_loader:
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device).unsqueeze(1)
optimizer.zero_grad()
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item() * batch_features.size(0)
pbar.update(1)
train_loss /= len(train_dataset)
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
for batch_features, batch_targets in val_loader:
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device).unsqueeze(1)
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
val_loss += loss.item() * batch_features.size(0)
pbar.update(1)
val_loss /= len(val_dataset)
scheduler.step()
if torch.cuda.is_available():
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
elapsed_time = datetime.now() - season_start_time
time_per_epoch = elapsed_time.total_seconds() / (epoch - start_epoch + 1)
remaining_epochs = total_epochs - epoch_display
eta_seconds = time_per_epoch * remaining_epochs
eta_str = str(timedelta(seconds=int(eta_seconds)))
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"best_val_loss": best_val_loss,
"hidden_sizes": model.hidden_sizes,
}, checkpoint_file)
if val_loss < best_val_loss:
best_val_loss = val_loss
# Clone tensors: a shallow dict copy would keep tracking the live parameters
best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
epochs_without_improvement = 0
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
torch.save(best_model_state, snapshot_file)
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
else:
epochs_without_improvement += 1
last_epoch = epoch
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break
return best_val_loss, best_model_state, last_epoch
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time, hidden_sizes=None,
extra_metadata=None):
"""Save the best model with optional versioning and metadata."""
final_output_file = output_file
metadata = {}
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss),
"architecture": architecture,
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
if extra_metadata:
metadata.update(extra_metadata)
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch,
"best_val_loss": best_val_loss,
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
}, final_output_file)
print(f"Best model saved to {final_output_file}")
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
for key, val in metadata.items():
print(f" {key}: {val}")
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384,
lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True,
early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
"""
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
start_epoch = 0
best_val_loss = float('inf')
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
ckpt_hidden = ckpt.get("hidden_sizes")
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
print(f" Using architecture from checkpoint: {ckpt_hidden}")
resolved_hidden_sizes = ckpt_hidden
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
if checkpoint:
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
arch_str = " -> ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
print(f"Architecture: {arch_str}")
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if subsample_ratio < 1.0:
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
best_val_loss, best_model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
training_start_time
)
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
print("No new model created.")
return
_save_versioned_model(
best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time,
hidden_sizes=resolved_hidden_sizes,
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
"checkpoint": str(checkpoint) if checkpoint else None}
)
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
epochs_per_season=50, early_stopping_patience=10,
batch_size=16384, lr=0.001, initial_checkpoint=None,
stockfish_depth=12, use_versioning=True,
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
Each season trains with early stopping. When early stopping fires, the model reloads the
global best weights and begins a fresh season with a reset optimizer and scheduler.
This prevents the model from drifting away from its best known state.
Args:
data_file: Path to training_data.jsonl
output_file: Output file base name
duration_minutes: Total training budget in minutes
epochs_per_season: Max epochs per restart season (default: 50)
early_stopping_patience: Patience for early stopping within each season (default: 10)
batch_size: Training batch size (default: 16384)
lr: Learning rate reset to this value at the start of each season (default: 0.001)
initial_checkpoint: Optional weights-only .pt file to start from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
weight_decay: L2 regularization strength (default: 1e-4)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
"""
deadline = datetime.now() + timedelta(minutes=duration_minutes)
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
if initial_checkpoint:
print(f"Loading initial weights: {initial_checkpoint}")
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
ckpt_hidden = ckpt.get("hidden_sizes")
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
print(f" Using architecture from checkpoint: {ckpt_hidden}")
resolved_hidden_sizes = ckpt_hidden
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
criterion = nn.MSELoss()
best_global_val_loss = float('inf')
if initial_checkpoint:
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
if best_global_val_loss < float('inf'):
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
else:
print("Initial weights loaded (no val loss in checkpoint).")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no val loss info).")
arch_str = " -> ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
print(f"Architecture: {arch_str}")
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
print()
burst_start_time = datetime.now()
season = 0
best_global_state = None
last_optimizer = None
last_scheduler = None
last_scaler = None
last_epoch = 0
while datetime.now() < deadline:
season += 1
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
if best_global_val_loss < float('inf'):
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
season_start_time = datetime.now()
val_loss, model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
0, epochs_per_season, early_stopping_patience,
season_start_time, deadline=deadline,
initial_best_val_loss=best_global_val_loss
)
last_optimizer = optimizer
last_scheduler = scheduler
last_scaler = scaler
if model_state is not None and val_loss < best_global_val_loss:
best_global_val_loss = val_loss
best_global_state = model_state
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
# Reload global best for the next season so we never drift backwards
if best_global_state is not None:
model.load_state_dict(best_global_state)
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
print(f"Best val loss: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
if best_global_state is None:
print("No model improvement found. No file saved.")
return
_save_versioned_model(
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
best_global_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, burst_start_time,
hidden_sizes=resolved_hidden_sizes,
extra_metadata={
"mode": "burst",
"duration_minutes": duration_minutes,
"epochs_per_season": epochs_per_season,
"early_stopping_patience": early_stopping_patience,
"seasons_completed": season,
"batch_size": batch_size,
"learning_rate": lr,
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
}
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
help="Path to training_data.jsonl (default: training_data.jsonl)")
parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
help="Output file base name (default: nnue_weights.pt)")
parser.add_argument("--checkpoint", type=str, default=None,
help="Path to checkpoint file to resume training from (optional)")
parser.add_argument("--epochs", type=int, default=100,
help="Number of epochs to train (default: 100)")
parser.add_argument("--batch-size", type=int, default=16384,
help="Batch size (default: 16384)")
parser.add_argument("--lr", type=float, default=0.001,
help="Learning rate (default: 0.001)")
parser.add_argument("--early-stopping", type=int, default=None,
help="Stop if val loss doesn't improve for N epochs (optional)")
parser.add_argument("--stockfish-depth", type=int, default=12,
help="Stockfish depth used for evaluations (for metadata, default: 12)")
parser.add_argument("--no-versioning", action="store_true",
help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
parser.add_argument("--hidden-layers", type=str, default=None,
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
# Burst mode
parser.add_argument("--burst-duration", type=float, default=None,
help="Enable burst mode: total training budget in minutes")
parser.add_argument("--epochs-per-season", type=int, default=50,
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
args = parser.parse_args()
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
if args.burst_duration is not None:
burst_train(
data_file=args.data_file,
output_file=args.output_file,
duration_minutes=args.burst_duration,
epochs_per_season=args.epochs_per_season,
early_stopping_patience=args.early_stopping or 10,
batch_size=args.batch_size,
lr=args.lr,
initial_checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
hidden_sizes=hidden_sizes,
)
else:
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
hidden_sizes=hidden_sizes,
)
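
Finally, a sense of scale for the default architecture (768 -> 1536 -> 1024 -> 512 -> 256 -> 1): a small sketch of the parameter count, which is what export.py ultimately serializes as float32 blocks into the .nbai file.

    # Parameter count for DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256].
    sizes = [768, 1536, 1024, 512, 256, 1]
    params = sum(i * o + o for i, o in zip(sizes, sizes[1:]))  # weights + biases per layer
    print(params)                 # 3411457 (~3.4M parameters)
    print(params * 4 / 2 ** 20)   # ~13.0 MiB of float32 weights in the exported file
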