#!/usr/bin/env python3
"""Central NNUE pipeline CLI for training and exporting models.

Subcommands:
    train   Generate (or reuse) positions, label them with Stockfish and
            train an NNUE model, optionally starting from a checkpoint.
    export  Export a ``.pt`` checkpoint to the Scala project's resources
            directory as ``nnue_weights.bin``.
    list    List the versioned checkpoints found in ``weights/``.

Each stage is delegated to a script in ``src/`` that is run as a subprocess.
"""

import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

# Number of games generated when ``train --games`` is not given.
DEFAULT_GAMES = 500000
# Stockfish binary used when neither --stockfish nor $STOCKFISH_PATH is set.
DEFAULT_STOCKFISH = "/usr/games/stockfish"
# Versioned checkpoint stems, e.g. "nnue_weights_v12". ASCII digits only and no
# leading zeros, so a parsed version formats back to the exact same filename
# (callers rebuild paths as f"nnue_weights_v{version}.pt").
_CHECKPOINT_STEM_RE = re.compile(r"nnue_weights_v(0|[1-9][0-9]*)")


def get_python_cmd():
    """Return the Python interpreter used to run the ``src/`` stage scripts.

    Prefers ``sys.executable`` so the stage scripts run under the same
    interpreter (and therefore the same virtualenv / installed packages) as
    this CLI. Falls back to a PATH lookup only when that is unavailable.
    """
    if sys.executable:
        return sys.executable
    if os.name == "nt":
        return "python"
    return "python3" if shutil.which("python3") else "python"


def get_src_module(module_name):
    """Return the path to ``src/<module_name>.py`` next to this file."""
    return Path(__file__).parent / "src" / f"{module_name}.py"


def get_data_dir():
    """Return the ``data/`` directory next to this file, creating it if needed."""
    data_dir = Path(__file__).parent / "data"
    data_dir.mkdir(exist_ok=True)
    return data_dir


def get_weights_dir():
    """Return the ``weights/`` directory next to this file, creating it if needed."""
    weights_dir = Path(__file__).parent / "weights"
    weights_dir.mkdir(exist_ok=True)
    return weights_dir


def checkpoint_version_from_path(path):
    """Return the integer version encoded in a checkpoint filename, or None.

    Only names of the exact form ``nnue_weights_v<N>.pt`` are accepted.
    Anything else (e.g. ``nnue_weights_v2_backup.pt``, ``nnue_weights_vx.pt``
    or ``nnue_weights_v02.pt``) yields None instead of raising.

    Args:
        path: Checkpoint path as a ``str`` or ``Path``; only the filename
            is inspected.
    """
    path = Path(path)
    if path.suffix != ".pt":
        return None
    match = _CHECKPOINT_STEM_RE.fullmatch(path.stem)
    return int(match.group(1)) if match else None


def list_checkpoints():
    """Return the versions of the checkpoints in ``weights/``, numerically ascending."""
    versions = (
        checkpoint_version_from_path(cp)
        for cp in get_weights_dir().glob("nnue_weights_v*.pt")
    )
    return sorted(v for v in versions if v is not None)


def resolve_weights_filename(weights):
    """Map an ``export`` argument to a checkpoint filename.

    ``"2"`` and ``"v2"`` both become ``"nnue_weights_v2.pt"``. A value ending
    in ``.pt`` is returned unchanged. Any other value is wrapped as
    ``nnue_weights_v<value>.pt``.
    """
    if weights.endswith(".pt"):
        return weights
    match = re.fullmatch(r"[vV]?([0-9]+)", weights)
    version = match.group(1) if match else weights
    return f"nnue_weights_v{version}.pt"


def run_generate_positions(num_games):
    """Generate positions into ``data/positions.txt`` via ``src/generate.py``.

    Args:
        num_games: Forwarded to the script as ``--games``.

    Returns:
        True when the script exits with status 0 and the output file exists.
    """
    positions_file = get_data_dir() / "positions.txt"

    print(f"Generating {num_games} positions...")
    result = subprocess.run(
        [get_python_cmd(), str(get_src_module("generate")), str(positions_file),
         "--games", str(num_games)],
        capture_output=False,
    )
    if result.returncode != 0:
        print("ERROR: Position generation failed")
        return False
    if not positions_file.exists():
        print(f"ERROR: {positions_file} was not created")
        return False
    return True


def run_label_positions(stockfish_path, positions_file=None):
    """Label positions with Stockfish into ``data/training_data.jsonl``.

    Args:
        stockfish_path: Stockfish binary path, forwarded to ``src/label.py``.
        positions_file: Positions file to label. Defaults to
            ``data/positions.txt``, the output of ``run_generate_positions``.

    Returns:
        True when the script exits with status 0 and the output file exists.
    """
    data_dir = get_data_dir()
    if positions_file is None:
        positions_file = data_dir / "positions.txt"
    positions_file = Path(positions_file)
    output_file = data_dir / "training_data.jsonl"

    if not positions_file.exists():
        print(f"ERROR: {positions_file} not found")
        return False

    print("Labeling positions with Stockfish...")
    result = subprocess.run(
        [get_python_cmd(), str(get_src_module("label")), str(positions_file),
         str(output_file), str(stockfish_path)],
        capture_output=False,
    )
    if result.returncode != 0:
        print("ERROR: Position labeling failed")
        return False
    return output_file.exists()


def run_train(positions_file, output_weights, from_checkpoint=None):
    """Train the NNUE model by running ``src/train.py``.

    Args:
        positions_file: Labeled training data file. Resolved to an absolute
            path because the script runs with ``weights/`` as its cwd.
        output_weights: Output weights filename, relative to ``weights/``.
        from_checkpoint: Optional checkpoint filename relative to
            ``weights/``, forwarded as ``--checkpoint``.

    Returns:
        True when the training script exits with status 0.
    """
    data_path = Path(positions_file)
    if not data_path.exists():
        print(f"ERROR: {positions_file} not found")
        return False

    weights_dir = get_weights_dir()

    print(f"Training model (output: {output_weights})...")
    if from_checkpoint:
        print(f" Starting from checkpoint: {from_checkpoint}")

    cmd = [get_python_cmd(), str(get_src_module("train")),
           str(data_path.resolve()), str(output_weights)]
    if from_checkpoint:
        cmd.extend(["--checkpoint", str(from_checkpoint)])

    # Run from weights/ so the script's relative outputs (and the relative
    # checkpoint name) refer to that directory.
    result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
    if result.returncode != 0:
        print("ERROR: Training failed")
        return False
    return True


def run_export(weights_file, output_file):
    """Export a checkpoint via ``src/export.py``.

    ``weights_file`` is first looked up by filename in ``weights/``. If it is
    not there but exists as given (e.g. an explicit path), that file is used.
    The parent directory of ``output_file`` is created if it is missing.

    Returns:
        True when the script exits with status 0 and ``output_file`` exists.
    """
    weights_dir = get_weights_dir()
    weights_path = weights_dir / Path(weights_file).name
    if not weights_path.exists():
        explicit_path = Path(weights_file)
        if not explicit_path.is_file():
            print(f"ERROR: {weights_file} not found in {weights_dir}")
            return False
        weights_path = explicit_path

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting {weights_file} to Scala...")
    result = subprocess.run(
        [get_python_cmd(), str(get_src_module("export")), str(weights_path),
         str(output_file)],
        capture_output=False,
    )
    if result.returncode != 0:
        print("ERROR: Export failed")
        return False
    return output_path.exists()


def cmd_train(args):
    """Handle the ``train`` subcommand: prepare data, label it, then train.

    Starts from ``--from-checkpoint`` when given. Otherwise it starts from the
    latest versioned checkpoint in ``weights/``, if any exists.
    """
    stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", DEFAULT_STOCKFISH)
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()
    existing_versions = list_checkpoints()

    # Determine checkpoint. Compare against None so that version 0 is honoured.
    checkpoint = None
    if args.from_checkpoint is not None:
        checkpoint = f"nnue_weights_v{args.from_checkpoint}.pt"
        if not (weights_dir / checkpoint).exists():
            print(f"ERROR: Checkpoint {checkpoint} not found")
            return False
    elif existing_versions:
        latest = max(existing_versions)
        checkpoint = f"nnue_weights_v{latest}.pt"
        print(f"No checkpoint specified, using latest: v{latest}")

    # Use the provided positions file, or generate a fresh one.
    if args.positions_file:
        positions_file = Path(args.positions_file)
        if not positions_file.exists():
            print(f"ERROR: {args.positions_file} not found")
            return False
    else:
        positions_file = data_dir / "positions.txt"
        num_games = args.games or DEFAULT_GAMES
        if not run_generate_positions(num_games):
            return False

    # Label the positions chosen above.
    if not run_label_positions(stockfish_path, positions_file):
        return False

    print("\nStarting training...")
    # Absolute path to the data; the checkpoint name is relative to weights/.
    training_data = str(data_dir / "training_data.jsonl")
    if not run_train(training_data, "nnue_weights.pt", checkpoint):
        return False

    # Report only a version that did not exist before this run.
    new_versions = [v for v in list_checkpoints() if v not in existing_versions]
    if new_versions:
        print(f"\n✓ Training complete: nnue_weights_v{max(new_versions)}.pt")
    else:
        print("\n✓ Training complete (no new versioned checkpoint found in weights/)")
    return True


def cmd_export(args):
    """Handle the ``export`` subcommand (accepts a version or a filename)."""
    weights_file = resolve_weights_filename(args.weights)

    # Binary weights go to the Scala project's resources directory, which sits
    # one level above this file.
    output_file = str(
        Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin"
    )
    if not run_export(weights_file, output_file):
        return False

    print(f"✓ Export complete: {output_file}")
    return True


def cmd_list(args):
    """Handle the ``list`` subcommand: print each checkpoint and its size in MB."""
    versions = list_checkpoints()
    if not versions:
        print("No checkpoints found")
        return True

    weights_dir = get_weights_dir()
    print("Available checkpoints:")
    for version in versions:
        path = weights_dir / f"nnue_weights_v{version}.pt"
        if path.exists():
            size_mb = path.stat().st_size / (1024 ** 2)
            print(f" v{version} ({size_mb:.1f} MB)")
        else:
            print(f" v{version} (file not found)")
    return True


def main():
    """Parse the command line and dispatch to a subcommand.

    Returns:
        Process exit status: 0 on success (or when no command is given),
        1 when the subcommand reports failure.
    """
    parser = argparse.ArgumentParser(
        description="NNUE pipeline CLI for training and exporting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with 500k random positions
  python nnue.py train

  # Train from checkpoint v2
  python nnue.py train --from-checkpoint 2

  # Train with custom positions file
  python nnue.py train --positions-file my_positions.txt

  # Train with 200k games
  python nnue.py train --games 200000

  # Export specific weights version
  python nnue.py export 2

  # Export with full filename
  python nnue.py export nnue_weights_v3.pt

  # List available checkpoints
  python nnue.py list
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Train subcommand
    train_parser = subparsers.add_parser("train", help="Train NNUE model")
    train_parser.add_argument(
        "--from-checkpoint", type=int,
        help="Start training from checkpoint version (e.g., 2)",
    )
    train_parser.add_argument(
        "--games", type=int,
        help=f"Number of games to generate (default: {DEFAULT_GAMES})",
    )
    train_parser.add_argument(
        "--positions-file",
        help="Use existing positions file instead of generating",
    )
    train_parser.add_argument(
        "--stockfish",
        help=f"Path to Stockfish binary (default: $STOCKFISH_PATH or {DEFAULT_STOCKFISH})",
    )
    train_parser.set_defaults(func=cmd_train)

    # Export subcommand
    export_parser = subparsers.add_parser("export", help="Export weights to Scala")
    export_parser.add_argument(
        "weights",
        help="Weights file or version (e.g., 2, v2 or nnue_weights_v2.pt)",
    )
    export_parser.set_defaults(func=cmd_export)

    # List subcommand
    list_parser = subparsers.add_parser("list", help="List available checkpoints")
    list_parser.set_defaults(func=cmd_list)

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return 0

    success = args.func(args)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())