#!/usr/bin/env python3
"""Central NNUE pipeline CLI for training and exporting models."""

import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from typing import List, Optional

# File names shared between the pipeline steps.
POSITIONS_FILE = "positions.txt"
TRAINING_DATA_FILE = "training_data.jsonl"
BASE_WEIGHTS_FILE = "nnue_weights.pt"

DEFAULT_NUM_GAMES = 500000
DEFAULT_STOCKFISH_PATH = "/usr/games/stockfish"
SCALA_OUTPUT_DIR = "../src/main/scala/de/nowchess/bot/bots/nnue"

# Stem of a versioned checkpoint file, e.g. "nnue_weights_v12".
_CHECKPOINT_STEM = re.compile(r"nnue_weights_v(\d+)")


def get_python_cmd() -> str:
    """Return the name of the Python command used to launch helper scripts.

    On Windows this is always ``"python"``. Elsewhere it is ``"python3"``
    when that executable is found on ``PATH`` and ``"python"`` otherwise.
    """
    if os.name == "nt":
        return "python"
    # shutil.which avoids spawning a shell (and leaking a pipe) just to
    # probe PATH.
    return "python3" if shutil.which("python3") else "python"


def _run_script(script: str, *script_args: str) -> bool:
    """Run a helper script with the detected interpreter; True on exit code 0."""
    cmd = [get_python_cmd(), script, *script_args]
    try:
        # Output is deliberately not captured so the child's progress is
        # shown live on the terminal.
        result = subprocess.run(cmd)
    except OSError as err:
        # Raised when the interpreter itself cannot be launched.
        print(f"ERROR: Could not launch {cmd[0]}: {err}")
        return False
    return result.returncode == 0


def list_checkpoints() -> List[int]:
    """Return the checkpoint versions found in the current directory.

    A checkpoint is a file named ``nnue_weights_v<N>.pt`` where ``<N>`` is a
    non-negative integer. Files that match the glob but carry a non-numeric
    suffix (e.g. ``nnue_weights_vfinal.pt``) are ignored.

    Returns:
        The version numbers in ascending numeric order (empty if none).
    """
    versions = []
    for checkpoint in Path(".").glob("nnue_weights_v*.pt"):
        match = _CHECKPOINT_STEM.fullmatch(checkpoint.stem)
        if match:
            versions.append(int(match.group(1)))
    # Sort numerically; sorting the paths would place v10 before v2.
    return sorted(versions)


def run_generate_positions(num_games) -> bool:
    """Generate random positions into ``positions.txt``.

    Args:
        num_games: Value passed to ``generate_positions.py`` via ``--games``.

    Returns:
        True if the script succeeded and the positions file exists.
    """
    print(f"Generating {num_games} positions...")
    if not _run_script(
        "generate_positions.py", POSITIONS_FILE, "--games", str(num_games)
    ):
        print("ERROR: Position generation failed")
        return False
    return Path(POSITIONS_FILE).exists()


def run_label_positions(stockfish_path, positions_file=POSITIONS_FILE) -> bool:
    """Label positions with Stockfish, writing ``training_data.jsonl``.

    Args:
        stockfish_path: Path to the Stockfish binary, forwarded to
            ``label_positions.py``.
        positions_file: Positions file to label (default ``positions.txt``).

    Returns:
        True if the script succeeded and the output file exists.
    """
    if not Path(positions_file).exists():
        print(f"ERROR: {positions_file} not found")
        return False

    print("Labeling positions with Stockfish...")
    if not _run_script(
        "label_positions.py",
        str(positions_file),
        TRAINING_DATA_FILE,
        stockfish_path,
    ):
        print("ERROR: Position labeling failed")
        return False
    return Path(TRAINING_DATA_FILE).exists()


def run_train(positions_file, output_weights, from_checkpoint=None) -> bool:
    """Train the NNUE model by invoking ``train_nnue.py``.

    Args:
        positions_file: Training data file passed to the trainer.
        output_weights: Base name of the weights file passed to the trainer.
        from_checkpoint: Optional checkpoint file, passed via ``--checkpoint``.

    Returns:
        True if the trainer exited successfully.
    """
    if not Path(positions_file).exists():
        print(f"ERROR: {positions_file} not found")
        return False

    print(f"Training model (output: {output_weights})...")
    script_args = [str(positions_file), str(output_weights)]
    if from_checkpoint:
        print(f"  Starting from checkpoint: {from_checkpoint}")
        script_args.extend(["--checkpoint", str(from_checkpoint)])

    if not _run_script("train_nnue.py", *script_args):
        print("ERROR: Training failed")
        return False
    # The existence of output_weights is not checked: per the original
    # author's note, train_nnue creates a versioned file, not the base name.
    return True


def run_export(weights_file, output_file) -> bool:
    """Export weights to Scala by invoking ``export_weights.py``.

    Args:
        weights_file: Existing weights file to export.
        output_file: Scala file to be written by the exporter.

    Returns:
        True if the exporter succeeded and the output file exists.
    """
    if not Path(weights_file).exists():
        print(f"ERROR: {weights_file} not found")
        return False

    print(f"Exporting {weights_file} to Scala...")
    if not _run_script("export_weights.py", str(weights_file), str(output_file)):
        print("ERROR: Export failed")
        return False
    return Path(output_file).exists()


def cmd_train(args) -> bool:
    """Handle the ``train`` command.

    Reads ``from_checkpoint``, ``positions_file``, ``games`` and
    ``stockfish`` from ``args``. Returns True on success, False otherwise.
    """
    stockfish_path = args.stockfish or os.environ.get(
        "STOCKFISH_PATH", DEFAULT_STOCKFISH_PATH
    )

    # Determine the checkpoint to start from.
    checkpoint: Optional[str] = None
    # Compare against None so that an explicit version 0 is honoured.
    if args.from_checkpoint is not None:
        checkpoint = f"nnue_weights_v{args.from_checkpoint}.pt"
        if not Path(checkpoint).exists():
            print(f"ERROR: Checkpoint {checkpoint} not found")
            return False
    else:
        available = list_checkpoints()
        if available:
            latest = max(available)
            checkpoint = f"nnue_weights_v{latest}.pt"
            print(f"No checkpoint specified, using latest: v{latest}")

    # Generate positions, or use the file the caller supplied.
    if args.positions_file:
        if not Path(args.positions_file).exists():
            print(f"ERROR: {args.positions_file} not found")
            return False
        positions_file = args.positions_file
    else:
        positions_file = POSITIONS_FILE
        num_games = args.games or DEFAULT_NUM_GAMES
        if not run_generate_positions(num_games):
            return False

    # Label the selected positions file (not unconditionally positions.txt,
    # otherwise --positions-file would be silently ignored).
    if not run_label_positions(stockfish_path, positions_file):
        return False

    print("\nStarting training...")

    # Train (train_nnue.py handles versioning internally).
    if not run_train(TRAINING_DATA_FILE, BASE_WEIGHTS_FILE, checkpoint):
        return False

    # Report the highest version now present on disk.
    available = list_checkpoints()
    new_version = max(available) if available else 1
    print(f"\n✓ Training complete: nnue_weights_v{new_version}.pt")
    return True


def cmd_export(args) -> bool:
    """Handle the ``export`` command.

    ``args.weights`` is either a ``.pt`` file name or a bare version, which
    is expanded to ``nnue_weights_v<version>.pt``. Returns True on success.
    """
    weights_file = args.weights

    # Anything not ending in ".pt" is treated as a version number.
    if not weights_file.endswith(".pt"):
        weights_file = f"nnue_weights_v{weights_file}.pt"

    if not Path(weights_file).exists():
        print(f"ERROR: {weights_file} not found")
        return False

    # Derive the version from the file stem. Check the stem rather than the
    # whole path, so a "_v" in a directory name cannot cause an IndexError.
    stem = Path(weights_file).stem
    version = stem.split("_v")[1] if "_v" in stem else "1"
    output_file = f"{SCALA_OUTPUT_DIR}/NNUEWeights_v{version}.scala"

    if not run_export(weights_file, output_file):
        return False

    print(f"✓ Export complete: {output_file}")
    return True


def cmd_list(args) -> bool:
    """Handle the ``list`` command: print each checkpoint and its size in MB.

    ``args`` is unused; it is accepted so all command handlers share one
    signature. Always returns True.
    """
    available = list_checkpoints()
    if not available:
        print("No checkpoints found")
        return True

    print("Available checkpoints:")
    for version in available:
        size_mb = Path(f"nnue_weights_v{version}.pt").stat().st_size / (1024**2)
        print(f"  v{version} ({size_mb:.1f} MB)")
    return True


def main(argv=None) -> int:
    """Parse the command line and dispatch to the selected command.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success (or when only help is shown), 1 when
        the selected command reports failure.
    """
    parser = argparse.ArgumentParser(
        description="NNUE pipeline CLI for training and exporting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with 500k random positions
  python nnue.py train

  # Train from checkpoint v2
  python nnue.py train --from-checkpoint 2

  # Train with custom positions file
  python nnue.py train --positions-file my_positions.txt

  # Train with 200k games
  python nnue.py train --games 200000

  # Export specific weights version
  python nnue.py export 2

  # Export with full filename
  python nnue.py export nnue_weights_v3.pt

  # List available checkpoints
  python nnue.py list
""",
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Train subcommand
    train_parser = subparsers.add_parser("train", help="Train NNUE model")
    train_parser.add_argument(
        "--from-checkpoint",
        type=int,
        help="Start training from checkpoint version (e.g., 2)",
    )
    train_parser.add_argument(
        "--games",
        type=int,
        help="Number of games to generate (default: 500000)",
    )
    train_parser.add_argument(
        "--positions-file",
        help="Use existing positions file instead of generating",
    )
    train_parser.add_argument(
        "--stockfish",
        help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)",
    )
    train_parser.set_defaults(func=cmd_train)

    # Export subcommand
    export_parser = subparsers.add_parser("export", help="Export weights to Scala")
    export_parser.add_argument(
        "weights",
        help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)",
    )
    export_parser.set_defaults(func=cmd_export)

    # List subcommand
    list_parser = subparsers.add_parser("list", help="List available checkpoints")
    list_parser.set_defaults(func=cmd_list)

    args = parser.parse_args(argv)

    # No subcommand given: show usage instead of failing.
    if not args.command:
        parser.print_help()
        return 0

    success = args.func(args)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())