249 lines
7.6 KiB
Python
249 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Central NNUE pipeline CLI for training and exporting models."""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import subprocess
|
|
from pathlib import Path
|
|
|
|
def get_python_cmd():
    """Return the interpreter command used to run the pipeline's helper scripts.

    Prefers ``sys.executable`` (the interpreter running this CLI), so the
    helper scripts run under the same interpreter and see the same installed
    packages. ``sys.executable`` may be empty or ``None`` in embedded
    interpreters. In that case this falls back to the previous heuristic:
    ``"python"`` on Windows; otherwise ``"python3"`` if it is found on
    ``PATH``, else ``"python"``.

    The lookup no longer spawns a shell (the old ``os.popen("which ...")``).
    """
    if sys.executable:
        return sys.executable
    if os.name == "nt":
        return "python"
    import shutil  # local import: only needed on this fallback path
    return "python3" if shutil.which("python3") else "python"
|
|
|
|
def list_checkpoints(directory="."):
    """Return the checkpoint version numbers found in *directory*, sorted numerically.

    A checkpoint is a file named ``nnue_weights_v<N>.pt``, where ``<N>`` is a
    plain decimal integer with no leading zeros.

    Two kinds of file match the glob but are ignored instead of raising
    ``ValueError``:

    * files whose suffix is not a number (e.g. ``nnue_weights_v3_backup.pt``);
    * files whose number has leading zeros (e.g. ``nnue_weights_v02.pt``).
      Callers rebuild the filename as ``f"nnue_weights_v{v}.pt"``, so only the
      canonical spelling round-trips.

    Args:
        directory: Directory to scan. Defaults to the current working directory.

    Returns:
        A list of ints in ascending numeric order (``[2, 10]``, not
        ``[10, 2]``), or ``[]`` if there are no checkpoints.
    """
    prefix = "nnue_weights_v"
    versions = []
    for cp in Path(directory).glob(f"{prefix}*.pt"):
        suffix = cp.stem[len(prefix):]
        # isdecimal() guarantees int() succeeds.
        # The str() round-trip rejects leading zeros.
        if suffix.isdecimal() and str(int(suffix)) == suffix:
            versions.append(int(suffix))
    return sorted(versions)
|
|
|
|
def run_generate_positions(num_games, positions_file="positions.txt"):
    """Generate training positions by running ``generate_positions.py``.

    Args:
        num_games: Value forwarded to the script's ``--games`` option.
        positions_file: Output file passed to the script. Defaults to
            ``positions.txt``, the file the labeling step reads by default.

    Returns:
        True if the script exited with status 0 and ``positions_file`` exists;
        False otherwise.
    """
    print(f"Generating {num_games} positions...")
    result = subprocess.run(
        [get_python_cmd(), "generate_positions.py", positions_file, "--games", str(num_games)],
    )
    if result.returncode != 0:
        print("ERROR: Position generation failed")
        return False
    return Path(positions_file).exists()
|
|
|
|
def run_label_positions(stockfish_path, positions_file="positions.txt",
                        output_file="training_data.jsonl"):
    """Label positions with Stockfish by running ``label_positions.py``.

    Args:
        stockfish_path: Path to the Stockfish binary, forwarded to the script.
        positions_file: Positions file to label. Defaults to ``positions.txt``,
            the file the generation step writes by default.
        output_file: Labeled training-data file the script is asked to write.

    Returns:
        True if the script exited with status 0 and ``output_file`` exists;
        False if ``positions_file`` is missing or the script fails.
    """
    if not Path(positions_file).exists():
        print(f"ERROR: {positions_file} not found")
        return False

    print("Labeling positions with Stockfish...")
    result = subprocess.run(
        [get_python_cmd(), "label_positions.py", positions_file, output_file, stockfish_path],
    )
    if result.returncode != 0:
        print("ERROR: Position labeling failed")
        return False
    return Path(output_file).exists()


def run_train(positions_file, output_weights, from_checkpoint=None):
    """Train the NNUE model by running ``train_nnue.py``.

    Args:
        positions_file: Training-data file passed to the script. ``cmd_train``
            passes the labeled ``.jsonl`` file.
        output_weights: Base weights name passed to the script.
        from_checkpoint: Optional checkpoint path, forwarded as ``--checkpoint``.

    Returns:
        True if the script exited with status 0; False if the input is missing
        or the script fails. ``output_weights`` is not checked for existence,
        because train_nnue.py writes a versioned ``nnue_weights_v<N>.pt``
        rather than the base name.
    """
    if not Path(positions_file).exists():
        print(f"ERROR: {positions_file} not found")
        return False

    print(f"Training model (output: {output_weights})...")
    cmd = [get_python_cmd(), "train_nnue.py", positions_file, output_weights]
    if from_checkpoint:
        print(f" Starting from checkpoint: {from_checkpoint}")
        cmd.extend(["--checkpoint", from_checkpoint])

    result = subprocess.run(cmd)
    if result.returncode != 0:
        print("ERROR: Training failed")
        return False
    return True


def run_export(weights_file, output_file):
    """Export a weights file to Scala source by running ``export_weights.py``.

    Args:
        weights_file: Path of the ``.pt`` weights to export.
        output_file: Destination Scala source path, forwarded to the script.

    Returns:
        True if the script exited with status 0 and ``output_file`` exists;
        False otherwise.
    """
    if not Path(weights_file).exists():
        print(f"ERROR: {weights_file} not found")
        return False

    print(f"Exporting {weights_file} to Scala...")
    result = subprocess.run(
        [get_python_cmd(), "export_weights.py", weights_file, output_file],
    )
    if result.returncode != 0:
        print("ERROR: Export failed")
        return False
    return Path(output_file).exists()


def cmd_train(args):
    """Handle the ``train`` sub-command: prepare data, label it, then train.

    Steps:
      1. Resolve the starting checkpoint. This is ``--from-checkpoint`` if
         given, otherwise the highest ``nnue_weights_v<N>.pt`` present, if any.
      2. Choose the positions: ``--positions-file`` if given, otherwise
         generate ``positions.txt``.
      3. Label the chosen positions file with Stockfish into
         ``training_data.jsonl``.
      4. Train on the labeled data.

    Args:
        args: Parsed namespace with ``stockfish``, ``from_checkpoint``,
            ``positions_file`` and ``games`` attributes.

    Returns:
        True on success; False as soon as any step fails.
    """
    stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
    training_data = "training_data.jsonl"

    # Resolve the starting checkpoint.
    # Compare against None so that an explicit ``--from-checkpoint 0`` is
    # honoured rather than treated as "not given".
    checkpoint = None
    if args.from_checkpoint is not None:
        checkpoint = f"nnue_weights_v{args.from_checkpoint}.pt"
        if not Path(checkpoint).exists():
            print(f"ERROR: Checkpoint {checkpoint} not found")
            return False
    else:
        available = list_checkpoints()
        if available:
            latest = max(available)
            checkpoint = f"nnue_weights_v{latest}.pt"
            print(f"No checkpoint specified, using latest: v{latest}")

    # Use the caller's positions file, or generate positions.txt.
    if args.positions_file:
        positions_file = args.positions_file
        if not Path(positions_file).exists():
            print(f"ERROR: {positions_file} not found")
            return False
    else:
        positions_file = "positions.txt"
        if not run_generate_positions(args.games or 500000):
            return False

    # Label whichever positions file was selected above.
    # This may be a user-supplied --positions-file, not positions.txt.
    if not run_label_positions(stockfish_path, positions_file, training_data):
        return False

    print("\nStarting training...")

    # train_nnue.py handles versioning internally.
    # "nnue_weights.pt" is only the base name.
    if not run_train(training_data, "nnue_weights.pt", checkpoint):
        return False

    # Report the newest versioned checkpoint.
    # Don't invent a v1 when none is found.
    available = list_checkpoints()
    if available:
        print(f"\n✓ Training complete: nnue_weights_v{max(available)}.pt")
    else:
        print("\n✓ Training complete (no versioned checkpoint found)")
    return True
|
|
|
|
def _export_version(weights_file):
    """Return the version label used to name the exported Scala file.

    The label is whatever follows the *last* ``_v`` in the file's stem, so
    ``nnue_weights_v3.pt`` gives ``"3"``. The label is ``"1"`` when the stem
    has no ``_v`` or nothing follows it. Only the stem is inspected, so a
    ``_v`` in a parent directory name (``runs_v2/model.pt``) is ignored.
    """
    _, sep, tail = Path(weights_file).stem.rpartition("_v")
    return tail if sep and tail else "1"


def cmd_export(args):
    """Handle the ``export`` sub-command.

    ``args.weights`` is either a path ending in ``.pt`` or a bare version
    (``"2"`` becomes ``nnue_weights_v2.pt``). The weights are exported to
    ``../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_v<version>.scala``.

    Returns:
        True on success; False if the weights file is missing or export fails.
    """
    weights_file = args.weights

    # A bare version number is expanded to the canonical checkpoint name.
    if not weights_file.endswith(".pt"):
        weights_file = f"nnue_weights_v{weights_file}.pt"

    if not Path(weights_file).exists():
        print(f"ERROR: {weights_file} not found")
        return False

    version = _export_version(weights_file)
    output_file = f"../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_v{version}.scala"

    if not run_export(weights_file, output_file):
        return False

    print(f"✓ Export complete: {output_file}")
    return True
|
|
|
|
def cmd_list(args):
    """Handle the ``list`` sub-command: print each checkpoint version and its size.

    Args:
        args: Parsed ``argparse`` namespace. It is unused, but kept so every
            sub-command handler has the same signature.

    Returns:
        Always True.
    """
    # Sort numerically so v2 is listed before v10, whatever order the helper returns.
    available = sorted(list_checkpoints())
    if not available:
        print("No checkpoints found")
        return True

    print("Available checkpoints:")
    for v in available:
        weights_path = Path(f"nnue_weights_v{v}.pt")
        try:
            size_mb = weights_path.stat().st_size / (1024 ** 2)
        except OSError:
            # The filename is rebuilt from the parsed integer and may not exist
            # (e.g. only nnue_weights_v02.pt is on disk).
            # Report it instead of crashing the whole listing.
            print(f" v{v} (size unavailable)")
            continue
        print(f" v{v} ({size_mb:.1f} MB)")
    return True
|
|
|
|
def main(argv=None):
    """Build the CLI parser, dispatch to the chosen sub-command, and return an exit code.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes ``argparse`` read ``sys.argv[1:]``, exactly as
            before. A list allows programmatic invocation.

    Returns:
        0 on success, or when no sub-command is given (help is printed).
        1 if the sub-command handler reports failure.
    """
    parser = argparse.ArgumentParser(
        description="NNUE pipeline CLI for training and exporting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Train with 500k random positions
python nnue.py train

# Train from checkpoint v2
python nnue.py train --from-checkpoint 2

# Train with custom positions file
python nnue.py train --positions-file my_positions.txt

# Train with 200k games
python nnue.py train --games 200000

# Export specific weights version
python nnue.py export 2

# Export with full filename
python nnue.py export nnue_weights_v3.pt

# List available checkpoints
python nnue.py list
""",
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # train: optional checkpoint, data source and Stockfish overrides.
    train_parser = subparsers.add_parser("train", help="Train NNUE model")
    train_parser.add_argument(
        "--from-checkpoint",
        type=int,
        help="Start training from checkpoint version (e.g., 2)",
    )
    train_parser.add_argument(
        "--games",
        type=int,
        help="Number of games to generate (default: 500000)",
    )
    train_parser.add_argument(
        "--positions-file",
        help="Use existing positions file instead of generating",
    )
    train_parser.add_argument(
        "--stockfish",
        help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)",
    )
    train_parser.set_defaults(func=cmd_train)

    # export: a single positional weights file or version number.
    export_parser = subparsers.add_parser("export", help="Export weights to Scala")
    export_parser.add_argument(
        "weights",
        help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)",
    )
    export_parser.set_defaults(func=cmd_export)

    # list: no arguments.
    list_parser = subparsers.add_parser("list", help="List available checkpoints")
    list_parser.set_defaults(func=cmd_list)

    args = parser.parse_args(argv)

    # Sub-commands are optional in argparse; with none given, show help.
    if not args.command:
        parser.print_help()
        return 0

    success = args.func(args)
    return 0 if success else 1
|
|
|
|
if __name__ == "__main__":
    # Raising SystemExit with main()'s return value is what sys.exit() does:
    # the process exit status becomes 0 or 1.
    raise SystemExit(main())