diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index ec19ca2..ccb9585 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -1,21 +1,22 @@ #!/usr/bin/env python3 -"""Central NNUE pipeline CLI for training and exporting models.""" +"""Central NNUE pipeline TUI for training and exporting models.""" -import argparse import os import sys -import subprocess from pathlib import Path +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.prompt import Prompt, Confirm +from rich import print as rprint -def get_python_cmd(): - """Get available Python command.""" - if os.name == 'nt': - return "python" - return "python3" if os.popen("which python3 2>/dev/null").read() else "python" +# Add src directory to path so we can import modules +sys.path.insert(0, str(Path(__file__).parent / "src")) -def get_src_module(module_name): - """Get path to module in src/ directory.""" - return Path(__file__).parent / "src" / f"{module_name}.py" +from generate import play_random_game_and_collect_positions +from label import label_positions_with_stockfish +from train import train_nnue +from export import export_weights_to_binary def get_data_dir(): """Get/create data directory.""" @@ -37,240 +38,252 @@ def list_checkpoints(): return [] return [int(cp.stem.split("_v")[1]) for cp in checkpoints] -def run_generate_positions(num_games): - """Generate random positions.""" - data_dir = get_data_dir() - positions_file = data_dir / "positions.txt" - print(f"Generating {num_games} positions...") - result = subprocess.run( - [get_python_cmd(), str(get_src_module("generate")), str(positions_file), "--games", str(num_games)], - capture_output=False +def show_header(): + """Display application header.""" + console = Console() + console.clear() + console.print( + Panel( + "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n" + "[dim]Neural Network Utility Evaluation - Model Management[/dim]", + border_style="cyan", + padding=(1, 2), + ) ) - if result.returncode != 0: - print("ERROR: Position generation failed") - return False - return positions_file.exists() -def run_label_positions(stockfish_path): - """Label positions with Stockfish.""" - data_dir = get_data_dir() - positions_file = data_dir / "positions.txt" - output_file = data_dir / "training_data.jsonl" - - if not positions_file.exists(): - print("ERROR: positions.txt not found") - return False - - print("Labeling positions with Stockfish...") - result = subprocess.run( - [get_python_cmd(), str(get_src_module("label")), str(positions_file), str(output_file), stockfish_path], - capture_output=False - ) - if result.returncode != 0: - print("ERROR: Position labeling failed") - return False - return output_file.exists() - -def run_train(positions_file, output_weights, from_checkpoint=None): - """Train NNUE model.""" - if not Path(positions_file).exists(): - print(f"ERROR: {positions_file} not found") - return False - - weights_dir = get_weights_dir() - print(f"Training model (output: {output_weights})...") - if from_checkpoint: - print(f" Starting from checkpoint: {from_checkpoint}") - - cmd = [get_python_cmd(), str(get_src_module("train")), str(positions_file), str(output_weights)] - if from_checkpoint: - cmd.extend(["--checkpoint", str(from_checkpoint)]) - - # Run from weights directory so outputs save there - result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False) - if result.returncode != 0: - print("ERROR: Training failed") - return False - return True - -def run_export(weights_file, output_file): - """Export weights to Scala.""" - weights_dir = get_weights_dir() - weights_path = weights_dir / Path(weights_file).name - - if not weights_path.exists(): - print(f"ERROR: {weights_file} not found in {weights_dir}") - return False - - print(f"Exporting {weights_file} to Scala...") - result = subprocess.run( - [get_python_cmd(), str(get_src_module("export")), str(weights_path), output_file], - capture_output=False - ) - if result.returncode != 0: - print("ERROR: Export failed") - return False - return Path(output_file).exists() - -def cmd_train(args): - """Handle train command.""" - stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish") - data_dir = get_data_dir() - weights_dir = get_weights_dir() - - # Determine checkpoint - checkpoint = None - if args.from_checkpoint: - checkpoint_version = args.from_checkpoint - checkpoint = f"nnue_weights_v{checkpoint_version}.pt" - checkpoint_path = weights_dir / checkpoint - if not checkpoint_path.exists(): - print(f"ERROR: Checkpoint {checkpoint} not found") - return False - else: - available = list_checkpoints() - if available: - latest = max(available) - checkpoint = f"nnue_weights_v{latest}.pt" - print(f"No checkpoint specified, using latest: v{latest}") - - # Generate or use existing positions - if args.positions_file: - positions_file = Path(args.positions_file) - if not positions_file.exists(): - print(f"ERROR: {args.positions_file} not found") - return False - else: - positions_file = data_dir / "positions.txt" - num_games = args.games or 500000 - if not run_generate_positions(num_games): - return False - - # Label positions - if not run_label_positions(stockfish_path): - return False - - print("\nStarting training...") - - # Train with absolute path to data, checkpoint is relative to weights dir - training_data = str(data_dir / "training_data.jsonl") - if not run_train(training_data, "nnue_weights.pt", checkpoint): - return False - - # Show created version +def show_checkpoints_table(): + """Display available checkpoints in a table.""" + console = Console() available = list_checkpoints() - new_version = max(available) if available else 1 - print(f"\nāœ“ Training complete: nnue_weights_v{new_version}.pt") - return True -def cmd_export(args): - """Handle export command.""" - weights_file = args.weights - - # Auto-detect if version is specified - if not weights_file.endswith(".pt"): - weights_file = f"nnue_weights_v{weights_file}.pt" - - # Output to resources directory as binary format - output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin") - - if not run_export(weights_file, output_file): - return False - - print(f"āœ“ Export complete: {output_file}") - return True - -def cmd_list(args): - """List available checkpoints.""" - available = list_checkpoints() if not available: - print("No checkpoints found") - return True + console.print("[yellow]ℹ No checkpoints found yet[/yellow]") + return + + table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan") + table.add_column("Version", style="dim") + table.add_column("File Size", justify="right") + table.add_column("Status", justify="center") weights_dir = get_weights_dir() - print("Available checkpoints:") - for v in available: + for v in sorted(available): weights_file = weights_dir / f"nnue_weights_v{v}.pt" if weights_file.exists(): - size = weights_file.stat().st_size / (1024**2) # MB - print(f" v{v} ({size:.1f} MB)") + size = weights_file.stat().st_size / (1024**2) + table.add_row(f"v{v}", f"{size:.1f} MB", "āœ“ Ready") else: - print(f" v{v} (file not found)") - return True + table.add_row(f"v{v}", "?", "[red]āœ— Missing[/red]") + + console.print(table) + +def show_main_menu(): + """Display and handle main menu.""" + console = Console() + + while True: + show_header() + show_checkpoints_table() + + console.print("\n[bold]What would you like to do?[/bold]") + console.print("[cyan]1[/cyan] - Train NNUE Model") + console.print("[cyan]2[/cyan] - Export Weights to Scala") + console.print("[cyan]3[/cyan] - View Checkpoints") + console.print("[cyan]4[/cyan] - Exit") + + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"]) + + if choice == "1": + train_interactive() + elif choice == "2": + export_interactive() + elif choice == "3": + show_header() + show_checkpoints_table() + Prompt.ask("\nPress Enter to continue") + elif choice == "4": + console.print("[yellow]šŸ‘‹ Goodbye![/yellow]") + return + +def train_interactive(): + """Interactive training menu.""" + console = Console() + show_header() + + console.print("\n[bold cyan]šŸ“š Training Configuration[/bold cyan]") + + # Checkpoint selection + available = list_checkpoints() + use_checkpoint = False + checkpoint_version = None + + if available: + console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]") + use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False) + if use_checkpoint: + checkpoint_version = Prompt.ask( + "Enter checkpoint version", + default=str(max(available)) + ) + + # Positions source + use_existing = Confirm.ask("Use existing positions file?", default=False) + positions_file = None + num_games = 500000 + + if use_existing: + positions_file = Prompt.ask("Enter path to positions file") + else: + num_games = int(Prompt.ask("Number of games to generate", default="500000")) + + # Stockfish path + default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish") + stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish) + + # Training parameters + epochs = int(Prompt.ask("Number of epochs", default="20")) + batch_size = int(Prompt.ask("Batch size", default="4096")) + + # Confirm and start + console.print("\n[bold]Configuration Summary:[/bold]") + if use_checkpoint: + console.print(f" Checkpoint: v{checkpoint_version}") + else: + console.print(" Checkpoint: None (training from scratch)") + console.print(f" Games: {num_games:,}") + console.print(f" Epochs: {epochs}") + console.print(f" Batch size: {batch_size}") + console.print(f" Stockfish: {stockfish_path}") + + if not Confirm.ask("\nStart training?", default=True): + console.print("[yellow]Training cancelled[/yellow]") + return + + # Execute training + data_dir = get_data_dir() + weights_dir = get_weights_dir() + + try: + # Generate positions + if not use_existing: + console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]") + count = play_random_game_and_collect_positions( + str(data_dir / "positions.txt"), + total_games=num_games, + filter_captures=True + ) + if count == 0: + console.print("[red]āœ— No valid positions generated[/red]") + return + console.print(f"[green]āœ“ Generated {count:,} positions[/green]") + else: + if not Path(positions_file).exists(): + console.print(f"[red]āœ— Positions file not found: {positions_file}[/red]") + return + + # Label positions + console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") + positions_file = data_dir / "positions.txt" + output_file = data_dir / "training_data.jsonl" + success = label_positions_with_stockfish( + str(positions_file), + str(output_file), + stockfish_path, + depth=12 + ) + if not success: + console.print("[red]āœ— Position labeling failed[/red]") + return + console.print(f"[green]āœ“ Positions labeled[/green]") + + # Train model + console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]") + checkpoint = None + if use_checkpoint: + checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt") + + train_nnue( + data_file=str(output_file), + output_file=str(weights_dir / "nnue_weights.pt"), + epochs=epochs, + batch_size=batch_size, + checkpoint=checkpoint, + use_versioning=True + ) + console.print("[green]āœ“ Training complete[/green]") + + # Show result + available = list_checkpoints() + new_version = max(available) if available else 1 + console.print(f"\n[bold green]āœ“ Training successful![/bold green]") + console.print(f"[bold]New checkpoint: v{new_version}[/bold]") + Prompt.ask("Press Enter to continue") + + except Exception as e: + console.print(f"[red]āœ— Error: {e}[/red]") + import traceback + traceback.print_exc() + Prompt.ask("Press Enter to continue") + +def export_interactive(): + """Interactive export menu.""" + console = Console() + show_header() + + console.print("\n[bold cyan]šŸ“¦ Export Configuration[/bold cyan]") + + # Select weights version + available = list_checkpoints() + if not available: + console.print("[red]āœ— No checkpoints available to export[/red]") + Prompt.ask("Press Enter to continue") + return + + console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]") + version = Prompt.ask("Enter version to export (e.g., 2)") + + weights_file = f"nnue_weights_v{version}.pt" + output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin") + + console.print(f"\n[bold]Export Configuration:[/bold]") + console.print(f" Source: {weights_file}") + console.print(f" Destination: {output_file}") + + if not Confirm.ask("\nExport weights?", default=True): + console.print("[yellow]Export cancelled[/yellow]") + return + + try: + weights_dir = get_weights_dir() + weights_path = weights_dir / weights_file + + if not weights_path.exists(): + console.print(f"[red]āœ— {weights_file} not found[/red]") + return + + console.print("\n[bold cyan]Exporting Weights[/bold cyan]") + export_weights_to_binary(str(weights_path), output_file) + console.print(f"\n[green]āœ“ Export complete![/green]") + console.print(f"[bold]Weights saved to:[/bold] {output_file}") + Prompt.ask("Press Enter to continue") + + except Exception as e: + console.print(f"[red]āœ— Error: {e}[/red]") + import traceback + traceback.print_exc() + Prompt.ask("Press Enter to continue") def main(): - parser = argparse.ArgumentParser( - description="NNUE pipeline CLI for training and exporting models", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Train with 500k random positions - python nnue.py train - - # Train from checkpoint v2 - python nnue.py train --from-checkpoint 2 - - # Train with custom positions file - python nnue.py train --positions-file my_positions.txt - - # Train with 200k games - python nnue.py train --games 200000 - - # Export specific weights version - python nnue.py export 2 - - # Export with full filename - python nnue.py export nnue_weights_v3.pt - - # List available checkpoints - python nnue.py list - """ - ) - - subparsers = parser.add_subparsers(dest="command", help="Command to run") - - # Train subcommand - train_parser = subparsers.add_parser("train", help="Train NNUE model") - train_parser.add_argument( - "--from-checkpoint", - type=int, - help="Start training from checkpoint version (e.g., 2)" - ) - train_parser.add_argument( - "--games", - type=int, - help="Number of games to generate (default: 500000)" - ) - train_parser.add_argument( - "--positions-file", - help="Use existing positions file instead of generating" - ) - train_parser.add_argument( - "--stockfish", - help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)" - ) - train_parser.set_defaults(func=cmd_train) - - # Export subcommand - export_parser = subparsers.add_parser("export", help="Export weights to Scala") - export_parser.add_argument( - "weights", - help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)" - ) - export_parser.set_defaults(func=cmd_export) - - # List subcommand - list_parser = subparsers.add_parser("list", help="List available checkpoints") - list_parser.set_defaults(func=cmd_list) - - args = parser.parse_args() - - if not args.command: - parser.print_help() + try: + show_main_menu() return 0 - - success = args.func(args) - return 0 if success else 1 + except KeyboardInterrupt: + console = Console() + console.print("\n[yellow]Interrupted by user[/yellow]") + return 1 + except Exception as e: + console = Console() + console.print(f"[red]Error:[/red] {e}") + return 1 if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/modules/bot/python/requirements.txt b/modules/bot/python/requirements.txt index 9a1ed39..b0bf304 100644 --- a/modules/bot/python/requirements.txt +++ b/modules/bot/python/requirements.txt @@ -1,4 +1,5 @@ chess==1.11.2 torch==2.11.0 tqdm==4.67.3 -numpy==2.4.4 \ No newline at end of file +numpy==2.4.4 +rich==13.7.0 \ No newline at end of file diff --git a/modules/bot/python/src/generate.py b/modules/bot/python/src/generate.py index 8397fae..8188638 100644 --- a/modules/bot/python/src/generate.py +++ b/modules/bot/python/src/generate.py @@ -7,6 +7,36 @@ import sys from pathlib import Path from tqdm import tqdm +# Standard piece values for capture filtering +PIECE_VALUES = { + chess.PAWN: 1, + chess.KNIGHT: 3, + chess.BISHOP: 3, + chess.ROOK: 5, + chess.QUEEN: 9, +} + +def has_winning_or_equal_capture(board): + """Check if position has a capture where victim >= attacker (winning or equal trade). + + Returns True only if there's at least one favorable capture. + Positions with only losing captures return False (are kept). + """ + for move in board.legal_moves: + if board.is_capture(move): + attacker_piece = board.piece_at(move.from_square) + victim_piece = board.piece_at(move.to_square) + + if attacker_piece and victim_piece: + attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0) + victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0) + + # If victim >= attacker, it's a winning or equal capture + if victim_value >= attacker_value: + return True + + return False + def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True): """Play random games and save positions after 8-20 random moves. @@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt pbar.update(1) continue - # Check if any captures are available (if filtering enabled) + # Check if there are winning or equal captures (if filtering enabled) if filter_captures: - has_captures = any(board.is_capture(move) for move in board.legal_moves) - if has_captures: + if has_winning_or_equal_capture(board): filtered_captures += 1 pbar.update(1) continue @@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt print("=" * 60) print("POSITION GENERATION SUMMARY") print("=" * 60) + total_filtered = filtered_check + filtered_captures + filtered_game_over print(f"Total games: {total_games}") print(f"Saved positions: {positions_count}") - print(f"Filtered (check): {filtered_check}") - print(f"Filtered (captures): {filtered_captures}") + print(f"Filtered (in check): {filtered_check}") + print(f"Filtered (winning+ cap): {filtered_captures}") print(f"Filtered (game over): {filtered_game_over}") - print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}") + print(f"Total filtered: {total_filtered}") print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%") + print(f"(Keeps positions with only losing/bad captures)") print("=" * 60) print() if positions_count == 0: print("WARNING: No valid positions were generated!") print("This might happen if:") - print(" - The filter criteria are too strict (captures, checks)") - print(" - Try using: --no-filter-captures to accept positions with captures") + print(" - Most positions have checks or game-over states") + print(" - Try using: --no-filter-captures to accept positions with winning captures") return 0 return positions_count @@ -97,7 +128,7 @@ if __name__ == "__main__": parser.add_argument("--games", type=int, default=5000, help="Number of games to play (default: 500000)") parser.add_argument("--no-filter-captures", action="store_true", - help="Include positions with available captures (increases output)") + help="Include positions with winning/equal captures (increases output)") args = parser.parse_args() diff --git a/modules/bot/python/weights/nnue_weights_v2.pt b/modules/bot/python/weights/nnue_weights_v2.pt new file mode 100644 index 0000000..b40b218 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v2.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v2_metadata.json b/modules/bot/python/weights/nnue_weights_v2_metadata.json new file mode 100644 index 0000000..40d14b8 --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v2_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 2, + "date": "2026-04-07T23:50:05.390402", + "num_positions": 6886, + "stockfish_depth": 12, + "epochs": 100, + "batch_size": 4096, + "learning_rate": 0.001, + "final_val_loss": 0.007848377339541912, + "device": "cuda", + "checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file