Files
NowChessSystems/modules/bot/python/nnue.py
T

276 lines
8.6 KiB
Python

#!/usr/bin/env python3
"""Central NNUE pipeline CLI for training and exporting models."""
import argparse
import os
import sys
import subprocess
from pathlib import Path
def get_python_cmd():
    """Return the Python interpreter command used to launch the src/ scripts.

    Prefers ``sys.executable`` -- the interpreter running this CLI -- so the
    child scripts see the same virtualenv and installed packages. Falls back
    to a PATH lookup when ``sys.executable`` is unavailable (it can be empty
    or None, e.g. in embedded interpreters).

    Returns:
        str: An executable path or command name suitable for ``argv[0]``.
    """
    import shutil

    if sys.executable:
        return sys.executable
    if os.name == 'nt':
        return "python"
    # shutil.which avoids spawning a shell; the previous
    # os.popen("which python3") leaked a pipe and required a POSIX `which`.
    return "python3" if shutil.which("python3") else "python"
def get_src_module(module_name):
    """Return the absolute path of ``src/<module_name>.py`` next to this file.

    The path is resolved to an absolute one because helper scripts may be
    launched with a different working directory (``run_train`` sets ``cwd``
    to the weights directory). A relative ``__file__`` -- as seen for
    ``__main__`` before Python 3.9 -- would otherwise point at the wrong place.

    Args:
        module_name: Script name without the ``.py`` extension.

    Returns:
        pathlib.Path: Absolute path to the script (existence is not checked).
    """
    return Path(__file__).resolve().parent / "src" / f"{module_name}.py"
def get_data_dir():
    """Return the ``data/`` directory next to this file, creating it if needed.

    Returns:
        pathlib.Path: An absolute path. It stays valid when handed to
        subprocesses that run with a different working directory, even if
        ``__file__`` was relative (``__main__`` before Python 3.9).
    """
    data_dir = Path(__file__).resolve().parent / "data"
    data_dir.mkdir(exist_ok=True)
    return data_dir
def get_weights_dir():
    """Return the ``weights/`` directory next to this file, creating it if needed.

    Returns:
        pathlib.Path: An absolute path, so it is usable as a subprocess
        ``cwd`` and for building checkpoint paths regardless of the caller's
        working directory.
    """
    weights_dir = Path(__file__).resolve().parent / "weights"
    weights_dir.mkdir(exist_ok=True)
    return weights_dir
def list_checkpoints(weights_dir=None):
    """Return saved checkpoint version numbers in ascending numeric order.

    A checkpoint is a file named exactly ``nnue_weights_v<N>.pt``, where
    ``<N>`` is a decimal integer. Other files matched by the looser glob
    (for example ``nnue_weights_v2_best.pt``) are skipped instead of making
    ``int()`` raise.

    Args:
        weights_dir: Directory to scan. Defaults to ``get_weights_dir()``.

    Returns:
        list[int]: Sorted version numbers; empty when none are found.
    """
    import re

    if weights_dir is None:
        weights_dir = get_weights_dir()
    pattern = re.compile(r"nnue_weights_v(\d+)")
    versions = []
    for cp in Path(weights_dir).glob("nnue_weights_v*.pt"):
        match = pattern.fullmatch(cp.stem)
        if match:
            versions.append(int(match.group(1)))
    # Sort the integers, not the paths: path order put v10 ahead of v2.
    return sorted(versions)
def run_generate_positions(num_games):
    """Run ``src/generate.py`` to write positions to ``data/positions.txt``.

    Args:
        num_games: Value forwarded to the script's ``--games`` option.

    Returns:
        bool: False if the script exits non-zero; otherwise whether the
        positions file exists afterwards.
    """
    target = get_data_dir() / "positions.txt"
    print(f"Generating {num_games} positions...")
    command = [
        get_python_cmd(),
        str(get_src_module("generate")),
        str(target),
        "--games",
        str(num_games),
    ]
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode == 0:
        return target.exists()
    print("ERROR: Position generation failed")
    return False
def run_label_positions(stockfish_path):
    """Score ``data/positions.txt`` with Stockfish via ``src/label.py``.

    The labelled output is written to ``data/training_data.jsonl``.

    Args:
        stockfish_path: Path to the Stockfish binary, passed through to the script.

    Returns:
        bool: False when the input file is missing or the script exits
        non-zero; otherwise whether the output file exists afterwards.
    """
    data_dir = get_data_dir()
    source = data_dir / "positions.txt"
    destination = data_dir / "training_data.jsonl"
    if not source.exists():
        print("ERROR: positions.txt not found")
        return False
    print("Labeling positions with Stockfish...")
    command = [
        get_python_cmd(),
        str(get_src_module("label")),
        str(source),
        str(destination),
        stockfish_path,
    ]
    if subprocess.run(command, capture_output=False).returncode != 0:
        print("ERROR: Position labeling failed")
        return False
    return destination.exists()
def run_train(positions_file, output_weights, from_checkpoint=None):
    """Run ``src/train.py`` on a labelled data file.

    The child process runs with its working directory set to ``weights/`` so
    that the files it writes land there. The script path and the data path
    are therefore resolved to absolute paths here; a relative path would
    otherwise be re-interpreted against ``weights/`` by the child.

    Args:
        positions_file: Labelled training data, absolute or relative to the
            caller's current directory.
        output_weights: Output weights name, forwarded unchanged. It is
            presumably interpreted relative to ``weights/`` by train.py --
            TODO confirm.
        from_checkpoint: Optional checkpoint forwarded via ``--checkpoint``.
            Relative names resolve against ``weights/``, the child's cwd.

    Returns:
        bool: True if training exited with status 0, else False.
    """
    data_path = Path(positions_file)
    if not data_path.exists():
        print(f"ERROR: {positions_file} not found")
        return False
    weights_dir = get_weights_dir()
    print(f"Training model (output: {output_weights})...")
    if from_checkpoint:
        print(f" Starting from checkpoint: {from_checkpoint}")
    cmd = [
        get_python_cmd(),
        # Absolute paths: the child's cwd is weights_dir, not ours.
        str(get_src_module("train").resolve()),
        str(data_path.resolve()),
        str(output_weights),
    ]
    if from_checkpoint:
        cmd.extend(["--checkpoint", str(from_checkpoint)])
    # Run from weights directory so outputs save there
    result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
    if result.returncode != 0:
        print("ERROR: Training failed")
        return False
    return True
def run_export(weights_file, output_file):
    """Export a checkpoint from ``weights/`` for the Scala side via ``src/export.py``.

    Only the file name of ``weights_file`` is used; the file is always
    looked up inside the weights directory.

    Args:
        weights_file: Checkpoint name (any directory part is ignored).
        output_file: Destination path, forwarded to the script.

    Returns:
        bool: False when the checkpoint is missing or the script exits
        non-zero; otherwise whether ``output_file`` exists afterwards.
    """
    weights_dir = get_weights_dir()
    checkpoint = weights_dir / Path(weights_file).name
    if not checkpoint.exists():
        print(f"ERROR: {weights_file} not found in {weights_dir}")
        return False
    print(f"Exporting {weights_file} to Scala...")
    argv = [get_python_cmd(), str(get_src_module("export")), str(checkpoint), output_file]
    exit_code = subprocess.run(argv, capture_output=False).returncode
    if exit_code:
        print("ERROR: Export failed")
        return False
    return Path(output_file).exists()
def cmd_train(args):
    """Handle the ``train`` sub-command: prepare data, label it, then train.

    Steps:
      1. Check that the Stockfish binary is reachable, failing fast before
         any long-running step.
      2. Pick the starting checkpoint: ``--from-checkpoint N`` if given (it
         must exist), otherwise the highest existing version, otherwise none.
      3. Obtain positions: copy ``--positions-file`` to ``data/positions.txt``,
         or generate ``--games`` games (default 500000).
      4. Label ``data/positions.txt`` with Stockfish.
      5. Train on ``data/training_data.jsonl``.

    Args:
        args: Parsed ``argparse`` namespace from the ``train`` sub-parser.

    Returns:
        bool: True on success, False at the first failing step.
    """
    import shutil

    stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()

    # Fail fast: generation can take a long time, and labelling needs Stockfish.
    if shutil.which(stockfish_path) is None and not Path(stockfish_path).is_file():
        print(f"ERROR: Stockfish not found at {stockfish_path}")
        return False

    # Determine the starting checkpoint (a file name relative to weights/).
    # `is not None` so that an explicit `--from-checkpoint 0` is honoured.
    checkpoint = None
    if args.from_checkpoint is not None:
        checkpoint = f"nnue_weights_v{args.from_checkpoint}.pt"
        if not (weights_dir / checkpoint).exists():
            print(f"ERROR: Checkpoint {checkpoint} not found")
            return False
    else:
        available = list_checkpoints()
        if available:
            latest = max(available)
            checkpoint = f"nnue_weights_v{latest}.pt"
            print(f"No checkpoint specified, using latest: v{latest}")

    # Generate or use existing positions
    positions_file = data_dir / "positions.txt"
    if args.positions_file:
        custom = Path(args.positions_file)
        if not custom.exists():
            print(f"ERROR: {args.positions_file} not found")
            return False
        # The labelling step always reads data/positions.txt, so stage the
        # user's file there. Previously it was validated and then ignored.
        if custom.resolve() != positions_file.resolve():
            shutil.copyfile(custom, positions_file)
    else:
        num_games = args.games or 500000
        if not run_generate_positions(num_games):
            return False

    # Label positions
    if not run_label_positions(stockfish_path):
        return False

    versions_before = set(list_checkpoints())
    print("\nStarting training...")
    # Absolute data path: training runs with cwd=weights/, while the
    # checkpoint name is intentionally relative to that directory.
    training_data = str((data_dir / "training_data.jsonl").resolve())
    if not run_train(training_data, "nnue_weights.pt", checkpoint):
        return False

    # Report the checkpoint this run produced, not merely the highest on disk.
    new_versions = set(list_checkpoints()) - versions_before
    if new_versions:
        print(f"\n✓ Training complete: nnue_weights_v{max(new_versions)}.pt")
    else:
        print("\n✓ Training complete (no new versioned checkpoint found)")
    return True
def cmd_export(args):
    """Handle the ``export`` sub-command.

    ``args.weights`` is either a checkpoint file name (``nnue_weights_v2.pt``)
    or a bare version (``2``; a leading ``v`` as in ``v2`` is also accepted).
    The result is written to ``src/main/resources/nnue_weights.bin``, two
    directory levels above this script.

    Args:
        args: Parsed ``argparse`` namespace from the ``export`` sub-parser.

    Returns:
        bool: True on success, False if the export step failed.
    """
    weights_file = args.weights
    # Auto-detect if a version (e.g. "2" or "v2") was given instead of a file.
    if not weights_file.endswith(".pt"):
        version = weights_file[1:] if weights_file.startswith("v") else weights_file
        weights_file = f"nnue_weights_v{version}.pt"
    # Resolve first: with a relative __file__ (``__main__`` before Python 3.9)
    # ``.parent.parent`` collapses to "." and the output would land under
    # python/src/... instead of the module's resources directory.
    resources_dir = Path(__file__).resolve().parent.parent / "src" / "main" / "resources"
    resources_dir.mkdir(parents=True, exist_ok=True)
    # Output to resources directory as binary format
    output_file = str(resources_dir / "nnue_weights.bin")
    if not run_export(weights_file, output_file):
        return False
    print(f"✓ Export complete: {output_file}")
    return True
def cmd_list(args):
    """Handle the ``list`` sub-command: print each checkpoint version and its size.

    Args:
        args: Parsed ``argparse`` namespace (unused).

    Returns:
        bool: Always True.
    """
    versions = list_checkpoints()
    if versions:
        weights_dir = get_weights_dir()
        print("Available checkpoints:")
        for version in versions:
            path = weights_dir / f"nnue_weights_v{version}.pt"
            if not path.exists():
                # Possible if the file vanished between listing and printing.
                print(f" v{version} (file not found)")
                continue
            megabytes = path.stat().st_size / (1024**2)
            print(f" v{version} ({megabytes:.1f} MB)")
    else:
        print("No checkpoints found")
    return True
def main():
    """Build the CLI, dispatch to the chosen sub-command, and return an exit code.

    Returns:
        int: 0 on success or when no sub-command is given (help is printed);
        1 when the sub-command handler reports failure.
    """
    parser = argparse.ArgumentParser(
        description="NNUE pipeline CLI for training and exporting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Train with 500k random positions
python nnue.py train
# Train from checkpoint v2
python nnue.py train --from-checkpoint 2
# Train with custom positions file
python nnue.py train --positions-file my_positions.txt
# Train with 200k games
python nnue.py train --games 200000
# Export specific weights version
python nnue.py export 2
# Export with full filename
python nnue.py export nnue_weights_v3.pt
# List available checkpoints
python nnue.py list
"""
    )
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # `train`: options are registered in this order so --help output is unchanged.
    train = subparsers.add_parser("train", help="Train NNUE model")
    train_options = (
        ("--from-checkpoint", {"type": int, "help": "Start training from checkpoint version (e.g., 2)"}),
        ("--games", {"type": int, "help": "Number of games to generate (default: 500000)"}),
        ("--positions-file", {"help": "Use existing positions file instead of generating"}),
        ("--stockfish", {"help": "Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)"}),
    )
    for flag, options in train_options:
        train.add_argument(flag, **options)
    train.set_defaults(func=cmd_train)

    # `export`
    export = subparsers.add_parser("export", help="Export weights to Scala")
    export.add_argument(
        "weights",
        help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)",
    )
    export.set_defaults(func=cmd_export)

    # `list`
    subparsers.add_parser("list", help="List available checkpoints").set_defaults(func=cmd_list)

    args = parser.parse_args()
    if args.command:
        return 0 if args.func(args) else 1
    parser.print_help()
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer result as the process exit status.
    raise SystemExit(main())