#!/usr/bin/env python3
"""Central NNUE pipeline TUI for training and exporting models."""

import os
import sys
import traceback
from pathlib import Path

from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm, IntPrompt
from rich import print as rprint

# Add src directory to path so we can import modules
sys.path.insert(0, str(Path(__file__).parent / "src"))

from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary

# Versioned checkpoints live in weights/ as nnue_weights_v<N>.pt.
CHECKPOINT_PREFIX = "nnue_weights_v"
CHECKPOINT_SUFFIX = ".pt"
# Stockfish search depth passed to the labeling step.
LABEL_DEPTH = 12

# Single shared console so every screen renders with the same settings.
console = Console()


def get_data_dir():
    """Return the ``data`` directory next to this script, creating it if needed."""
    data_dir = Path(__file__).parent / "data"
    data_dir.mkdir(exist_ok=True)
    return data_dir


def get_weights_dir():
    """Return the ``weights`` directory next to this script, creating it if needed."""
    weights_dir = Path(__file__).parent / "weights"
    weights_dir.mkdir(exist_ok=True)
    return weights_dir


def checkpoint_path(version):
    """Return the path of the versioned checkpoint file for *version* (int or str)."""
    return get_weights_dir() / f"{CHECKPOINT_PREFIX}{version}{CHECKPOINT_SUFFIX}"


def list_checkpoints():
    """Return available checkpoint version numbers in ascending numeric order.

    Files matching ``nnue_weights_v*.pt`` whose suffix is not a plain decimal
    number (e.g. ``nnue_weights_vbest.pt``) are skipped instead of raising.
    """
    versions = []
    for cp in get_weights_dir().glob(f"{CHECKPOINT_PREFIX}*{CHECKPOINT_SUFFIX}"):
        suffix = cp.stem[len(CHECKPOINT_PREFIX):]
        # isdecimal() rejects "" and "1_0" (which int() would accept as 10).
        if suffix.isdecimal():
            versions.append(int(suffix))
    # Numeric sort: a filename sort would place v10 before v2.
    return sorted(versions)


def _format_versions(versions):
    """Render version numbers as a comma-separated ``v1, v2, ...`` string."""
    return ", ".join(f"v{v}" for v in versions)


def _pause(message="Press Enter to continue"):
    """Wait for Enter so messages are readable before the next screen clear."""
    Prompt.ask(message)


def _fail(message):
    """Print an error message and pause so the main menu does not wipe it."""
    console.print(f"[red]✗ {message}[/red]")
    _pause()


def _report_exception(exc):
    """Print *exc* with its traceback and pause; call only inside an ``except``."""
    console.print(f"[red]✗ Error: {exc}[/red]")
    traceback.print_exc()
    _pause()


def _ask_int(prompt, default, minimum=1):
    """Ask for an integer >= *minimum*, re-prompting on invalid or too-small input."""
    while True:
        # IntPrompt re-asks by itself on non-numeric input instead of raising.
        value = IntPrompt.ask(prompt, default=default)
        if value >= minimum:
            return value
        console.print(f"[red]Please enter a value of at least {minimum}[/red]")


def show_header():
    """Clear the screen and display the application header."""
    console.clear()
    console.print(
        Panel(
            "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
            "[dim]Efficiently Updatable Neural Network - Model Management[/dim]",
            border_style="cyan",
            padding=(1, 2),
        )
    )


def show_checkpoints_table():
    """Display available checkpoints with their file sizes in a table."""
    available = list_checkpoints()
    if not available:
        console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
        return

    table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
    table.add_column("Version", style="dim")
    table.add_column("File Size", justify="right")
    table.add_column("Status", justify="center")

    for v in available:
        weights_file = checkpoint_path(v)
        # The file may vanish between listing and stat-ing it.
        try:
            size_mb = weights_file.stat().st_size / (1024**2)
        except FileNotFoundError:
            table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
        else:
            table.add_row(f"v{v}", f"{size_mb:.1f} MB", "✓ Ready")

    console.print(table)


def show_main_menu():
    """Show the main menu in a loop and dispatch the chosen action until Exit."""
    while True:
        show_header()
        show_checkpoints_table()

        console.print("\n[bold]What would you like to do?[/bold]")
        console.print("[cyan]1[/cyan] - Train NNUE Model")
        console.print("[cyan]2[/cyan] - Export Weights to Scala")
        console.print("[cyan]3[/cyan] - View Checkpoints")
        console.print("[cyan]4[/cyan] - Exit")

        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])

        if choice == "1":
            train_interactive()
        elif choice == "2":
            export_interactive()
        elif choice == "3":
            show_header()
            show_checkpoints_table()
            _pause("\nPress Enter to continue")
        elif choice == "4":
            console.print("[yellow]👋 Goodbye![/yellow]")
            return


def train_interactive():
    """Collect training options, confirm them, then run the training pipeline.

    The pipeline has three steps:

    1. Generate positions, or use an existing positions file.
    2. Label them with Stockfish, or reuse an existing labels file.
    3. Train a model; ``train_nnue`` is called with ``use_versioning=True``.

    Positions and the Stockfish path are only asked for when labels actually
    have to be produced.
    """
    show_header()
    console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")

    # Checkpoint selection (restricted to existing versions).
    available = list_checkpoints()
    checkpoint_version = None
    if available:
        console.print(f"\n[dim]Available checkpoints: {_format_versions(available)}[/dim]")
        if Confirm.ask("Start from an existing checkpoint?", default=False):
            checkpoint_version = Prompt.ask(
                "Enter checkpoint version",
                choices=[str(v) for v in available],
                default=str(available[-1]),
            )

    # Labels source: reusing labels makes positions and Stockfish unnecessary.
    labels_file = None
    if Confirm.ask("Use existing labels file?", default=False):
        labels_file = Prompt.ask(
            "Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl")
        )

    positions_file = None
    num_games, samples_per_game, min_move, max_move = 5000, 1, 1, 50
    stockfish_path = None
    if labels_file is None:
        if Confirm.ask("Use existing positions file?", default=False):
            positions_file = Prompt.ask(
                "Enter path to positions file", default=str(get_data_dir() / "positions.txt")
            )
        else:
            num_games = _ask_int("Number of games to generate", 5000)
            samples_per_game = _ask_int("Positions to sample per game", 1)
            min_move = _ask_int("Minimum move number", 1, minimum=0)
            # Guarantee a non-empty move range.
            max_move = _ask_int("Maximum move number", max(50, min_move), minimum=min_move)

        default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
        stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)

    # Training parameters
    epochs = _ask_int("Number of epochs", 20)
    batch_size = _ask_int("Batch size", 4096)
    early_stopping = None
    if Confirm.ask("Enable early stopping?", default=False):
        early_stopping = _ask_int("Patience (epochs)", 5)

    # Summary
    console.print("\n[bold]Configuration Summary:[/bold]")
    if checkpoint_version is not None:
        console.print(f"  Checkpoint: v{checkpoint_version}")
    else:
        console.print("  Checkpoint: None (training from scratch)")
    if labels_file is not None:
        console.print(f"  Labels file: {labels_file}")
    elif positions_file is not None:
        console.print(f"  Positions file: {positions_file}")
    else:
        console.print(f"  Games: {num_games:,}")
        console.print(f"  Samples per game: {samples_per_game}")
        console.print(f"  Move range: {min_move}-{max_move}")
    console.print(f"  Epochs: {epochs}")
    console.print(f"  Batch size: {batch_size}")
    if early_stopping is not None:
        console.print(f"  Early stopping: Yes (patience: {early_stopping})")
    else:
        console.print("  Early stopping: No")
    if stockfish_path is not None:
        console.print(f"  Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        _pause()
        return

    _run_training(
        checkpoint_version=checkpoint_version,
        labels_file=labels_file,
        positions_file=positions_file,
        num_games=num_games,
        samples_per_game=samples_per_game,
        min_move=min_move,
        max_move=max_move,
        stockfish_path=stockfish_path,
        epochs=epochs,
        batch_size=batch_size,
        early_stopping=early_stopping,
    )


def _run_training(*, checkpoint_version, labels_file, positions_file, num_games,
                  samples_per_game, min_move, max_move, stockfish_path,
                  epochs, batch_size, early_stopping):
    """Execute the generate -> label -> train steps, pausing after any outcome."""
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()

    try:
        if labels_file is not None:
            console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
            labels_path = Path(labels_file)
            if not labels_path.is_file():
                _fail(f"Labels file not found: {labels_file}")
                return
        else:
            if positions_file is not None:
                console.print("\n[bold cyan]Step 1: Using Existing Positions[/bold cyan]")
                positions_path = Path(positions_file)
                if not positions_path.is_file():
                    _fail(f"Positions file not found: {positions_file}")
                    return
            else:
                console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
                positions_path = data_dir / "positions.txt"
                count = play_random_game_and_collect_positions(
                    str(positions_path),
                    total_games=num_games,
                    samples_per_game=samples_per_game,
                    min_move=min_move,
                    max_move=max_move,
                )
                if count == 0:
                    _fail("No valid positions generated")
                    return
                console.print(f"[green]✓ Generated {count:,} positions[/green]")

            console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
            labels_path = data_dir / "training_data.jsonl"
            # Label the positions selected above, not unconditionally data/positions.txt.
            success = label_positions_with_stockfish(
                str(positions_path), str(labels_path), stockfish_path, depth=LABEL_DEPTH
            )
            if not success:
                _fail("Position labeling failed")
                return
            console.print("[green]✓ Positions labeled[/green]")

        console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
        checkpoint = None
        if checkpoint_version is not None:
            checkpoint = str(checkpoint_path(checkpoint_version))

        # Diff version sets so we report the checkpoint this run actually created.
        versions_before = set(list_checkpoints())
        train_nnue(
            data_file=str(labels_path),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
            checkpoint=checkpoint,
            use_versioning=True,
            early_stopping_patience=early_stopping,
        )
        console.print("[green]✓ Training complete[/green]")

        new_versions = sorted(set(list_checkpoints()) - versions_before)
        console.print("\n[bold green]✓ Training successful![/bold green]")
        if new_versions:
            console.print(f"[bold]New checkpoint: v{new_versions[-1]}[/bold]")
        else:
            console.print(f"[yellow]⚠ No new versioned checkpoint found in {weights_dir}[/yellow]")
        _pause()

    except Exception as e:
        _report_exception(e)


def export_interactive():
    """Let the user pick a checkpoint version and export it as a binary weights file."""
    show_header()
    console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")

    available = list_checkpoints()
    if not available:
        _fail("No checkpoints available to export")
        return

    console.print(f"[dim]Available versions: {_format_versions(available)}[/dim]")
    # Restrict input to existing versions; defaults to the newest.
    version = Prompt.ask(
        "Enter version to export (e.g., 2)",
        choices=[str(v) for v in available],
        default=str(available[-1]),
    )
    weights_path = checkpoint_path(version)
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")

    console.print("\n[bold]Export Configuration:[/bold]")
    console.print(f"  Source: {weights_path.name}")
    console.print(f"  Destination: {output_file}")

    if not Confirm.ask("\nExport weights?", default=True):
        console.print("[yellow]Export cancelled[/yellow]")
        _pause()
        return

    if not weights_path.is_file():
        _fail(f"{weights_path.name} not found")
        return

    try:
        console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
        export_weights_to_binary(str(weights_path), output_file)
    except Exception as e:
        _report_exception(e)
        return

    console.print("\n[green]✓ Export complete![/green]")
    console.print(f"[bold]Weights saved to:[/bold] {output_file}")
    _pause()


def main():
    """Run the TUI; return 0 on normal exit and 1 on interrupt or unexpected error."""
    try:
        show_main_menu()
        return 0
    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted by user[/yellow]")
        return 1
    except Exception as e:
        console.print(f"[red]Error:[/red] {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())