#!/usr/bin/env python3
|
||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||
|
||
import os
|
||
import shutil
|
||
import sys
|
||
from pathlib import Path
|
||
from rich.console import Console
|
||
from rich.table import Table
|
||
from rich.panel import Panel
|
||
from rich.prompt import Prompt, Confirm
|
||
from rich import print as rprint
|
||
|
||
# Add src directory to path so we can import modules
|
||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||
|
||
from generate import play_random_game_and_collect_positions
|
||
from label import label_positions_with_stockfish
|
||
from train import train_nnue
|
||
from export import export_weights_to_binary
|
||
from tactical_positions_extractor import (
|
||
download_and_extract_puzzle_db,
|
||
interactive_merge_positions
|
||
)
|
||
|
||
def get_data_dir():
    """Return the ``data`` directory beside this script, creating it if absent."""
    path = Path(__file__).parent / "data"
    path.mkdir(exist_ok=True)
    return path
|
||
|
||
def get_tactical_data_dir():
    """Return the ``tactical_data`` directory beside this script, creating it if absent.

    Distinct from :func:`get_data_dir` (the original docstring was a
    copy-paste saying "data directory", which was misleading).
    """
    data_dir = Path(__file__).parent / "tactical_data"
    data_dir.mkdir(exist_ok=True)
    return data_dir
|
||
|
||
def get_weights_dir():
    """Return the ``weights`` directory beside this script, creating it if absent."""
    target = Path(__file__).parent / "weights"
    target.mkdir(exist_ok=True)
    return target
|
||
|
||
def list_checkpoints():
    """Return checkpoint version numbers found in the weights dir, sorted ascending.

    Scans for files named ``nnue_weights_v<N>.pt``. Fixes two defects of the
    original: versions were returned in lexicographic glob order (so v10
    sorted before v2), and a stray file such as ``nnue_weights_vbackup.pt``
    raised ValueError and crashed the TUI — such files are now skipped.
    """
    versions = []
    for cp in get_weights_dir().glob("nnue_weights_v*.pt"):
        suffix = cp.stem.split("_v")[-1]
        try:
            versions.append(int(suffix))
        except ValueError:
            # Not a versioned checkpoint; ignore rather than crash.
            continue
    return sorted(versions)
|
||
|
||
def show_header():
    """Clear the terminal and render the application banner panel."""
    banner = (
        "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
        "[dim]Neural Network Utility Evaluation - Model Management[/dim]"
    )
    console = Console()
    console.clear()
    console.print(Panel(banner, border_style="cyan", padding=(1, 2)))
|
||
|
||
def show_checkpoints_table():
    """Render a table of checkpoint versions with file size and availability."""
    console = Console()
    versions = list_checkpoints()

    if not versions:
        console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
        return

    table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
    table.add_column("Version", style="dim")
    table.add_column("File Size", justify="right")
    table.add_column("Status", justify="center")

    weights_dir = get_weights_dir()
    for version in sorted(versions):
        path = weights_dir / f"nnue_weights_v{version}.pt"
        if path.exists():
            # st_size is bytes; display as MiB with one decimal place.
            megabytes = path.stat().st_size / (1024**2)
            table.add_row(f"v{version}", f"{megabytes:.1f} MB", "✓ Ready")
        else:
            # File vanished between listing and stat'ing — flag it.
            table.add_row(f"v{version}", "?", "[red]✗ Missing[/red]")

    console.print(table)
|
||
|
||
def show_main_menu():
    """Run the top-level menu loop; returns when the user chooses Exit."""
    console = Console()

    options = [
        "[cyan]1[/cyan] - Train NNUE Model",
        "[cyan]2[/cyan] - Export Weights to Scala",
        "[cyan]3[/cyan] - Extract Tactical Positions",
        "[cyan]4[/cyan] - View Checkpoints",
        "[cyan]5[/cyan] - Exit",
    ]

    while True:
        show_header()
        show_checkpoints_table()

        console.print("\n[bold]What would you like to do?[/bold]")
        for line in options:
            console.print(line)

        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])

        if choice == "1":
            train_interactive()
        elif choice == "2":
            export_interactive()
        elif choice == "3":
            extract_tactical_interactive()
        elif choice == "4":
            # Redraw the checkpoint overview and pause until acknowledged.
            show_header()
            show_checkpoints_table()
            Prompt.ask("\nPress Enter to continue")
        else:  # choice == "5"
            console.print("[yellow]👋 Goodbye![/yellow]")
            return
|
||
|
||
def train_interactive():
    """Interactive training workflow: generate/load positions, label, train.

    Bug fixed vs. the original: when the user chose "use existing positions
    file" together with fresh labeling, the labeling step unconditionally
    overwrote ``positions_file`` with ``data/positions.txt``, silently
    ignoring the user-supplied path. The selected file is now honored.
    Additionally, the Stockfish path is only prompted for when labeling is
    actually needed (it was unused when existing labels were loaded).
    """
    console = Console()
    show_header()

    console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")

    # --- Checkpoint selection -------------------------------------------
    available = list_checkpoints()
    use_checkpoint = False
    checkpoint_version = None

    if available:
        console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
        use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
        if use_checkpoint:
            checkpoint_version = Prompt.ask(
                "Enter checkpoint version",
                default=str(max(available))
            )

    # --- Positions source ----------------------------------------------
    use_existing = Confirm.ask("Use existing positions file?", default=False)
    positions_file = None
    # Defaults only matter on the generation path; keep them in sync with
    # the prompt defaults below (the original initialized num_games=500000
    # while prompting a default of 5000).
    num_games = 5000
    samples_per_game = 1
    min_move = 1
    max_move = 50

    if use_existing:
        positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
    else:
        num_games = int(Prompt.ask("Number of games to generate", default="5000"))
        samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
        min_move = int(Prompt.ask("Minimum move number", default="1"))
        max_move = int(Prompt.ask("Maximum move number", default="50"))

    # --- Labels source --------------------------------------------------
    use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
    labels_file = None
    if use_existing_labels:
        labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))

    # Stockfish configuration is only relevant when we label ourselves.
    default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
    stockfish_path = default_stockfish
    stockfish_depth = 12
    num_workers = 1
    if not use_existing_labels:
        stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
        stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
        num_workers = int(Prompt.ask("Number of parallel workers", default="1"))

    # --- Training parameters -------------------------------------------
    epochs = int(Prompt.ask("Number of epochs", default="100"))
    batch_size = int(Prompt.ask("Batch size", default="16384"))
    early_stopping = None
    if Confirm.ask("Enable early stopping?", default=False):
        early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))

    # --- Summary and confirmation --------------------------------------
    console.print("\n[bold]Configuration Summary:[/bold]")
    if use_checkpoint:
        console.print(f" Checkpoint: v{checkpoint_version}")
    else:
        console.print(" Checkpoint: None (training from scratch)")
    if not use_existing:
        console.print(f" Games: {num_games:,}")
        console.print(f" Samples per game: {samples_per_game}")
        console.print(f" Move range: {min_move}-{max_move}")
    else:
        console.print(f" Positions file: {positions_file}")
    console.print(f" Epochs: {epochs}")
    console.print(f" Batch size: {batch_size}")
    if early_stopping:
        console.print(f" Early stopping: Yes (patience: {early_stopping})")
    else:
        console.print(" Early stopping: No")
    if not use_existing_labels:
        console.print(f" Stockfish depth: {stockfish_depth}")
        console.print(f" Workers: {num_workers}")
        console.print(f" Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        return

    # --- Execute pipeline ----------------------------------------------
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()

    try:
        # Step 1: generate positions, or validate the user-supplied file.
        if not use_existing:
            console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
            positions_file = data_dir / "positions.txt"
            count = play_random_game_and_collect_positions(
                str(positions_file),
                total_games=num_games,
                samples_per_game=samples_per_game,
                min_move=min_move,
                max_move=max_move
            )
            if count == 0:
                console.print("[red]✗ No valid positions generated[/red]")
                return
            console.print(f"[green]✓ Generated {count:,} positions[/green]")
        else:
            if not Path(positions_file).exists():
                console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
                return

        # Step 2: label positions, or validate the existing labels file.
        if not use_existing_labels:
            console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
            # FIX: label the positions file chosen above (user-provided or
            # freshly generated) rather than always data/positions.txt.
            output_file = data_dir / "training_data.jsonl"
            success = label_positions_with_stockfish(
                str(positions_file),
                str(output_file),
                stockfish_path,
                depth=stockfish_depth,
                num_workers=num_workers
            )
            if not success:
                console.print("[red]✗ Position labeling failed[/red]")
                return
            console.print("[green]✓ Positions labeled[/green]")
        else:
            console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
            output_file = labels_file
            if not Path(output_file).exists():
                console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
                return

        # Step 3: train the model (train_nnue handles versioned saving).
        console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
        checkpoint = None
        if use_checkpoint:
            checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")

        train_nnue(
            data_file=str(output_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
            checkpoint=checkpoint,
            use_versioning=True,
            early_stopping_patience=early_stopping
        )
        console.print("[green]✓ Training complete[/green]")

        # Report the newly created checkpoint version.
        available = list_checkpoints()
        new_version = max(available) if available else 1
        console.print("\n[bold green]✓ Training successful![/bold green]")
        console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
        Prompt.ask("Press Enter to continue")

    except Exception as e:
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
|
||
|
||
def export_interactive():
    """Interactive menu for exporting a checkpoint to the Scala binary format."""
    console = Console()
    show_header()

    console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")

    # Nothing to export without at least one trained checkpoint.
    versions = list_checkpoints()
    if not versions:
        console.print("[red]✗ No checkpoints available to export[/red]")
        Prompt.ask("Press Enter to continue")
        return

    console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(versions)])}[/dim]")
    version = Prompt.ask("Enter version to export (e.g., 2)")

    weights_file = f"nnue_weights_v{version}.pt"
    # Destination lives in the sibling Scala project's resources directory.
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")

    console.print(f"\n[bold]Export Configuration:[/bold]")
    console.print(f" Source: {weights_file}")
    console.print(f" Destination: {output_file}")

    if not Confirm.ask("\nExport weights?", default=True):
        console.print("[yellow]Export cancelled[/yellow]")
        return

    try:
        weights_path = get_weights_dir() / weights_file

        if not weights_path.exists():
            console.print(f"[red]✗ {weights_file} not found[/red]")
            return

        console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
        export_weights_to_binary(str(weights_path), output_file)
        console.print(f"\n[green]✓ Export complete![/green]")
        console.print(f"[bold]Weights saved to:[/bold] {output_file}")
        Prompt.ask("Press Enter to continue")

    except Exception as exc:
        console.print(f"[red]✗ Error: {exc}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
|
||
|
||
def extract_tactical_interactive():
    """Interactive tactical positions extraction and merge menu.

    Downloads the Lichess puzzle database, extracts it, and merges puzzle
    positions into the training positions file.

    Bug fixed vs. the original: the default merge output path was
    ``data/position.txt`` (singular), which the rest of the pipeline never
    reads — training and labeling consistently use ``data/positions.txt``,
    so merged tactical positions were silently ignored by default.
    """
    console = Console()
    show_header()

    console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")

    # Download and extraction parameters.
    console.print("\n[bold]Lichess Puzzle Database:[/bold]")
    download_url = Prompt.ask(
        "Download URL",
        default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
    )

    output_dir = Prompt.ask(
        "Extract to directory",
        default=str(Path(__file__).parent / "trainingdata")
    )

    max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))

    # Confirm before starting the (large) download.
    console.print("\n[bold]Configuration:[/bold]")
    console.print(f" Download URL: {download_url}")
    console.print(f" Extract directory: {output_dir}")
    console.print(f" Max puzzles: {max_puzzles:,}")

    if not Confirm.ask("\nProceed?", default=True):
        console.print("[yellow]Cancelled[/yellow]")
        Prompt.ask("Press Enter to continue")
        return

    try:
        console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
        csv_path = download_and_extract_puzzle_db(download_url, output_dir)

        if not csv_path:
            console.print("[red]✗ Failed to download/extract[/red]")
            Prompt.ask("Press Enter to continue")
            return

        console.print(f"[green]✓ Ready: {csv_path}[/green]")

        # Interactive merge into the shared training positions file.
        console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
        output_file = Prompt.ask(
            "Output file path",
            # FIX: default now matches the pipeline's positions file
            # ("positions.txt"; was the typo "position.txt"). get_data_dir()
            # also ensures the directory exists before the merge writes.
            default=str(get_data_dir() / "positions.txt")
        )

        interactive_merge_positions(csv_path, output_file, max_puzzles)
        console.print(f"\n[green]✓ Complete![/green]")
        Prompt.ask("Press Enter to continue")

    except Exception as e:
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
|
||
|
||
def main():
    """Entry point: run the menu loop and map outcomes to process exit codes."""
    try:
        show_main_menu()
    except KeyboardInterrupt:
        # Graceful Ctrl-C: report and exit non-zero.
        Console().print("\n[yellow]Interrupted by user[/yellow]")
        return 1
    except Exception as e:
        Console().print(f"[red]Error:[/red] {e}")
        return 1
    return 0
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main())
|