feat: add rich console interface for NNUE training pipeline and update requirements
This commit is contained in:
+246
-233
@@ -1,21 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Central NNUE pipeline CLI for training and exporting models."""
|
||||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich import print as rprint
|
||||
|
||||
def get_python_cmd():
    """Return the name of a Python interpreter command to launch.

    On Windows (``os.name == 'nt'``) this is always ``"python"``.  On other
    platforms ``"python3"`` is preferred when it can be found on ``PATH``,
    with ``"python"`` as the fallback.

    Returns:
        str: Either ``"python3"`` or ``"python"``.
    """
    # Imported locally because shutil is not among the file's top-level imports.
    import shutil

    if os.name == 'nt':
        return "python"
    # shutil.which does the PATH lookup in-process; the previous
    # os.popen("which ...") spawned a shell and never closed the pipe.
    return "python3" if shutil.which("python3") else "python"
|
||||
# Add src directory to path so we can import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
def get_src_module(module_name):
    """Return the absolute path of ``<module_name>.py`` in the ``src/`` directory.

    The ``src/`` directory is taken to be a sibling of this file.  The path is
    resolved to an absolute one because callers hand it to subprocesses that
    may run with a different working directory, where a relative path would
    no longer point at the script.

    Args:
        module_name: Module name without the ``.py`` suffix (e.g. ``"train"``).

    Returns:
        pathlib.Path: Absolute path to the module file.  The file is not
        checked for existence.
    """
    return Path(__file__).resolve().parent / "src" / f"{module_name}.py"
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue
|
||||
from export import export_weights_to_binary
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create data directory."""
|
||||
@@ -37,240 +38,252 @@ def list_checkpoints():
|
||||
return []
|
||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||
|
||||
def run_generate_positions(num_games):
|
||||
"""Generate random positions."""
|
||||
data_dir = get_data_dir()
|
||||
positions_file = data_dir / "positions.txt"
|
||||
print(f"Generating {num_games} positions...")
|
||||
result = subprocess.run(
|
||||
[get_python_cmd(), str(get_src_module("generate")), str(positions_file), "--games", str(num_games)],
|
||||
capture_output=False
|
||||
def show_header():
|
||||
"""Display application header."""
|
||||
console = Console()
|
||||
console.clear()
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Position generation failed")
|
||||
return False
|
||||
return positions_file.exists()
|
||||
|
||||
def run_label_positions(stockfish_path):
    """Run the ``label`` script from src/ over ``positions.txt``.

    Input and output both live in the directory returned by
    ``get_data_dir()``: ``positions.txt`` is read and
    ``training_data.jsonl`` is the expected output.

    Args:
        stockfish_path: Value passed through as the last argument of the
            child command.

    Returns:
        bool: False when ``positions.txt`` is missing or the child process
        exits with a non-zero status; otherwise whether
        ``training_data.jsonl`` exists afterwards.
    """
    workspace = get_data_dir()
    unlabeled = workspace / "positions.txt"
    labeled = workspace / "training_data.jsonl"

    if not unlabeled.exists():
        print("ERROR: positions.txt not found")
        return False

    print("Labeling positions with Stockfish...")
    command = [
        get_python_cmd(),
        str(get_src_module("label")),
        str(unlabeled),
        str(labeled),
        stockfish_path,
    ]
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode == 0:
        return labeled.exists()

    print("ERROR: Position labeling failed")
    return False
|
||||
|
||||
def run_train(positions_file, output_weights, from_checkpoint=None):
    """Run the ``train`` script from src/ as a subprocess.

    The child process is started with the weights directory (from
    ``get_weights_dir()``) as its working directory, so relative output and
    checkpoint names are interpreted there.

    Args:
        positions_file: Path to the training data file.  A relative path is
            interpreted against the caller's current working directory.
        output_weights: Output weights name or path passed to the script.
        from_checkpoint: Optional checkpoint passed via ``--checkpoint``.

    Returns:
        bool: False when the data file is missing or the child exits with a
        non-zero status, True otherwise.
    """
    # Resolve before the existence check and before building the command:
    # the child runs with cwd=weights_dir, where a relative path would point
    # somewhere else than the file that was just checked.
    data_path = Path(positions_file).resolve()
    if not data_path.exists():
        print(f"ERROR: {positions_file} not found")
        return False

    weights_dir = get_weights_dir()
    print(f"Training model (output: {output_weights})...")

    # The script path is resolved for the same cwd reason as the data path.
    script_path = Path(get_src_module("train")).resolve()
    cmd = [get_python_cmd(), str(script_path), str(data_path), str(output_weights)]
    if from_checkpoint:
        print(f" Starting from checkpoint: {from_checkpoint}")
        cmd.extend(["--checkpoint", str(from_checkpoint)])

    # Run from weights directory so outputs save there
    result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
    if result.returncode != 0:
        print("ERROR: Training failed")
        return False
    return True
|
||||
|
||||
def run_export(weights_file, output_file):
    """Run the ``export`` script from src/ on a checkpoint in the weights directory.

    Only the file-name component of ``weights_file`` is used; it is looked up
    inside the directory returned by ``get_weights_dir()``.

    Args:
        weights_file: Checkpoint file name (any directory part is dropped).
        output_file: Destination path passed to the export script.

    Returns:
        bool: False when the checkpoint is missing or the child process exits
        with a non-zero status; otherwise whether ``output_file`` exists
        afterwards.
    """
    source_dir = get_weights_dir()
    checkpoint = source_dir / Path(weights_file).name

    if not checkpoint.exists():
        print(f"ERROR: {weights_file} not found in {source_dir}")
        return False

    print(f"Exporting {weights_file} to Scala...")
    command = [
        get_python_cmd(),
        str(get_src_module("export")),
        str(checkpoint),
        output_file,
    ]
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode == 0:
        return Path(output_file).exists()

    print("ERROR: Export failed")
    return False
|
||||
|
||||
def cmd_train(args):
|
||||
"""Handle train command."""
|
||||
stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
data_dir = get_data_dir()
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
# Determine checkpoint
|
||||
checkpoint = None
|
||||
if args.from_checkpoint:
|
||||
checkpoint_version = args.from_checkpoint
|
||||
checkpoint = f"nnue_weights_v{checkpoint_version}.pt"
|
||||
checkpoint_path = weights_dir / checkpoint
|
||||
if not checkpoint_path.exists():
|
||||
print(f"ERROR: Checkpoint {checkpoint} not found")
|
||||
return False
|
||||
else:
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
latest = max(available)
|
||||
checkpoint = f"nnue_weights_v{latest}.pt"
|
||||
print(f"No checkpoint specified, using latest: v{latest}")
|
||||
|
||||
# Generate or use existing positions
|
||||
if args.positions_file:
|
||||
positions_file = Path(args.positions_file)
|
||||
if not positions_file.exists():
|
||||
print(f"ERROR: {args.positions_file} not found")
|
||||
return False
|
||||
else:
|
||||
positions_file = data_dir / "positions.txt"
|
||||
num_games = args.games or 500000
|
||||
if not run_generate_positions(num_games):
|
||||
return False
|
||||
|
||||
# Label positions
|
||||
if not run_label_positions(stockfish_path):
|
||||
return False
|
||||
|
||||
print("\nStarting training...")
|
||||
|
||||
# Train with absolute path to data, checkpoint is relative to weights dir
|
||||
training_data = str(data_dir / "training_data.jsonl")
|
||||
if not run_train(training_data, "nnue_weights.pt", checkpoint):
|
||||
return False
|
||||
|
||||
# Show created version
|
||||
def show_checkpoints_table():
|
||||
"""Display available checkpoints in a table."""
|
||||
console = Console()
|
||||
available = list_checkpoints()
|
||||
new_version = max(available) if available else 1
|
||||
print(f"\n✓ Training complete: nnue_weights_v{new_version}.pt")
|
||||
return True
|
||||
|
||||
def cmd_export(args):
    """Handle the ``export`` subcommand.

    ``args.weights`` may be a full ``.pt`` file name or a bare version; a
    value that does not end in ``.pt`` is expanded to
    ``nnue_weights_v<value>.pt``.  The export target is
    ``src/main/resources/nnue_weights.bin`` under the parent of this file's
    directory.

    Returns:
        bool: True when ``run_export`` succeeds, False otherwise.
    """
    requested = args.weights
    # A bare version such as "2" becomes the versioned checkpoint file name.
    checkpoint_name = requested if requested.endswith(".pt") else f"nnue_weights_v{requested}.pt"

    # Output to resources directory as binary format
    resources_dir = Path(__file__).parent.parent.joinpath("src", "main", "resources")
    output_file = str(resources_dir / "nnue_weights.bin")

    if run_export(checkpoint_name, output_file):
        print(f"✓ Export complete: {output_file}")
        return True
    return False
|
||||
|
||||
def cmd_list(args):
|
||||
"""List available checkpoints."""
|
||||
available = list_checkpoints()
|
||||
if not available:
|
||||
print("No checkpoints found")
|
||||
return True
|
||||
console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("File Size", justify="right")
|
||||
table.add_column("Status", justify="center")
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
print("Available checkpoints:")
|
||||
for v in available:
|
||||
for v in sorted(available):
|
||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||
if weights_file.exists():
|
||||
size = weights_file.stat().st_size / (1024**2) # MB
|
||||
print(f" v{v} ({size:.1f} MB)")
|
||||
size = weights_file.stat().st_size / (1024**2)
|
||||
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||
else:
|
||||
print(f" v{v} (file not found)")
|
||||
return True
|
||||
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||
|
||||
console.print(table)
|
||||
|
||||
def show_main_menu():
    """Run the main menu loop until the user chooses to exit.

    Each iteration redraws the header and the checkpoint table, prints the
    four options and dispatches on the answer: 1 starts the interactive
    training flow, 2 the interactive export flow, 3 redraws the checkpoint
    view and waits for Enter, 4 prints a goodbye message and returns.
    """
    console = Console()
    menu_lines = (
        "\n[bold]What would you like to do?[/bold]",
        "[cyan]1[/cyan] - Train NNUE Model",
        "[cyan]2[/cyan] - Export Weights to Scala",
        "[cyan]3[/cyan] - View Checkpoints",
        "[cyan]4[/cyan] - Exit",
    )

    while True:
        show_header()
        show_checkpoints_table()

        for line in menu_lines:
            console.print(line)

        selection = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])

        # Handle the exit case first so the remaining branches only loop.
        if selection == "4":
            console.print("[yellow]👋 Goodbye![/yellow]")
            return

        if selection == "1":
            train_interactive()
        elif selection == "2":
            export_interactive()
        elif selection == "3":
            show_header()
            show_checkpoints_table()
            Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
def _prompt_positive_int(console, label, default):
    """Ask for a positive whole number, re-prompting until the answer is valid."""
    while True:
        answer = Prompt.ask(label, default=str(default))
        try:
            value = int(answer)
        except ValueError:
            console.print("[red]✗ Please enter a whole number[/red]")
            continue
        if value > 0:
            return value
        console.print("[red]✗ Please enter a value greater than zero[/red]")


def _report_training_failure(console, message):
    """Print an error and wait for Enter so the next screen clear does not erase it."""
    console.print(f"[red]✗ {message}[/red]")
    Prompt.ask("Press Enter to continue")


def train_interactive():
    """Interactively collect training options and run the training pipeline.

    The user is asked for an optional starting checkpoint, the positions
    source (an existing file, or a number of random games to generate), the
    Stockfish path, the epoch count and the batch size.  After confirmation
    the pipeline runs in three steps: position generation (skipped when an
    existing file is used), labeling with Stockfish at ``depth=12``, and
    training with versioned output into the weights directory.

    Errors raised by the pipeline steps are caught, printed together with a
    traceback, and the function returns to the caller after the user presses
    Enter.  Nothing is returned.
    """
    console = Console()
    show_header()

    console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")

    # Checkpoint selection
    available = list_checkpoints()
    use_checkpoint = False
    checkpoint_version = None

    if available:
        versions = ', '.join(f'v{v}' for v in sorted(available))
        console.print(f"\n[dim]Available checkpoints: {versions}[/dim]")
        use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
        if use_checkpoint:
            checkpoint_version = Prompt.ask(
                "Enter checkpoint version",
                default=str(max(available))
            )
            # Accept "v3" as well as "3", matching how versions are displayed.
            checkpoint_version = checkpoint_version.strip().lstrip("vV")

    # Positions source
    use_existing = Confirm.ask("Use existing positions file?", default=False)
    positions_file = None
    num_games = 500000

    if use_existing:
        positions_file = Prompt.ask("Enter path to positions file")
    else:
        num_games = _prompt_positive_int(console, "Number of games to generate", 500000)

    # Stockfish path
    default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
    stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)

    # Training parameters
    epochs = _prompt_positive_int(console, "Number of epochs", 20)
    batch_size = _prompt_positive_int(console, "Batch size", 4096)

    # Confirm and start
    console.print("\n[bold]Configuration Summary:[/bold]")
    if use_checkpoint:
        console.print(f" Checkpoint: v{checkpoint_version}")
    else:
        console.print(" Checkpoint: None (training from scratch)")
    if use_existing:
        # No games are generated in this mode, so show the file instead.
        console.print(f" Positions file: {positions_file}")
    else:
        console.print(f" Games: {num_games:,}")
    console.print(f" Epochs: {epochs}")
    console.print(f" Batch size: {batch_size}")
    console.print(f" Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        return

    # Execute training
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()

    try:
        # Check the checkpoint up front so a typo does not surface only after
        # generation and labeling have already run.
        checkpoint = None
        if use_checkpoint:
            checkpoint_path = weights_dir / f"nnue_weights_v{checkpoint_version}.pt"
            if not checkpoint_path.exists():
                _report_training_failure(console, f"Checkpoint not found: {checkpoint_path.name}")
                return
            checkpoint = str(checkpoint_path)

        # Generate positions, or validate the file the user pointed at
        if use_existing:
            positions_path = Path(positions_file).expanduser()
            if not positions_path.is_file():
                _report_training_failure(console, f"Positions file not found: {positions_file}")
                return
        else:
            console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
            positions_path = data_dir / "positions.txt"
            count = play_random_game_and_collect_positions(
                str(positions_path),
                total_games=num_games,
                filter_captures=True
            )
            if count == 0:
                _report_training_failure(console, "No valid positions generated")
                return
            console.print(f"[green]✓ Generated {count:,} positions[/green]")

        # Label positions.  positions_path is whichever file was selected
        # above; it must not be reset to the default location here, or a
        # user-supplied file would be silently ignored.
        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
        output_file = data_dir / "training_data.jsonl"
        success = label_positions_with_stockfish(
            str(positions_path),
            str(output_file),
            stockfish_path,
            depth=12
        )
        if not success:
            _report_training_failure(console, "Position labeling failed")
            return
        console.print("[green]✓ Positions labeled[/green]")

        # Train model
        console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
        train_nnue(
            data_file=str(output_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
            checkpoint=checkpoint,
            use_versioning=True
        )
        console.print("[green]✓ Training complete[/green]")

        # Show result
        available = list_checkpoints()
        new_version = max(available) if available else 1
        console.print("\n[bold green]✓ Training successful![/bold green]")
        console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
        Prompt.ask("Press Enter to continue")

    except Exception as e:
        # Boundary of the interactive flow: report and return to the menu
        # instead of letting one failed run terminate the application.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
|
||||
|
||||
def export_interactive():
    """Interactively pick a checkpoint version and export it to binary.

    The available versions are listed and the user is asked which one to
    export (the highest version is the default).  After confirmation
    ``nnue_weights_v<version>.pt`` from the weights directory is exported to
    ``src/main/resources/nnue_weights.bin`` under the parent of this file's
    directory.

    Every outcome that prints a result or an error waits for Enter before
    returning, so a caller that clears the screen does not erase the
    message.  Errors raised during export are caught and printed with a
    traceback.  Nothing is returned.
    """
    console = Console()
    show_header()

    console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")

    # Select weights version
    available = list_checkpoints()
    if not available:
        console.print("[red]✗ No checkpoints available to export[/red]")
        Prompt.ask("Press Enter to continue")
        return

    versions = ', '.join(f'v{v}' for v in sorted(available))
    console.print(f"[dim]Available versions: {versions}[/dim]")
    version = Prompt.ask(
        "Enter version to export (e.g., 2)",
        default=str(max(available))
    )
    # Accept "v2" as well as "2", matching how versions are displayed above.
    version = version.strip().lstrip("vV")

    weights_file = f"nnue_weights_v{version}.pt"
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")

    console.print("\n[bold]Export Configuration:[/bold]")
    console.print(f" Source: {weights_file}")
    console.print(f" Destination: {output_file}")

    if not Confirm.ask("\nExport weights?", default=True):
        console.print("[yellow]Export cancelled[/yellow]")
        return

    try:
        weights_dir = get_weights_dir()
        weights_path = weights_dir / weights_file

        if not weights_path.exists():
            console.print(f"[red]✗ {weights_file} not found[/red]")
            # Pause here too; returning immediately would let the menu's
            # screen clear wipe the error before it can be read.
            Prompt.ask("Press Enter to continue")
            return

        console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
        export_weights_to_binary(str(weights_path), output_file)
        console.print("\n[green]✓ Export complete![/green]")
        console.print(f"[bold]Weights saved to:[/bold] {output_file}")
        Prompt.ask("Press Enter to continue")

    except Exception as e:
        # Boundary of the interactive flow: report and return to the menu.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NNUE pipeline CLI for training and exporting models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Train with 500k random positions
|
||||
python nnue.py train
|
||||
|
||||
# Train from checkpoint v2
|
||||
python nnue.py train --from-checkpoint 2
|
||||
|
||||
# Train with custom positions file
|
||||
python nnue.py train --positions-file my_positions.txt
|
||||
|
||||
# Train with 200k games
|
||||
python nnue.py train --games 200000
|
||||
|
||||
# Export specific weights version
|
||||
python nnue.py export 2
|
||||
|
||||
# Export with full filename
|
||||
python nnue.py export nnue_weights_v3.pt
|
||||
|
||||
# List available checkpoints
|
||||
python nnue.py list
|
||||
"""
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# Train subcommand
|
||||
train_parser = subparsers.add_parser("train", help="Train NNUE model")
|
||||
train_parser.add_argument(
|
||||
"--from-checkpoint",
|
||||
type=int,
|
||||
help="Start training from checkpoint version (e.g., 2)"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--games",
|
||||
type=int,
|
||||
help="Number of games to generate (default: 500000)"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--positions-file",
|
||||
help="Use existing positions file instead of generating"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--stockfish",
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)"
|
||||
)
|
||||
train_parser.set_defaults(func=cmd_train)
|
||||
|
||||
# Export subcommand
|
||||
export_parser = subparsers.add_parser("export", help="Export weights to Scala")
|
||||
export_parser.add_argument(
|
||||
"weights",
|
||||
help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)"
|
||||
)
|
||||
export_parser.set_defaults(func=cmd_export)
|
||||
|
||||
# List subcommand
|
||||
list_parser = subparsers.add_parser("list", help="List available checkpoints")
|
||||
list_parser.set_defaults(func=cmd_list)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
try:
|
||||
show_main_menu()
|
||||
return 0
|
||||
|
||||
success = args.func(args)
|
||||
return 0 if success else 1
|
||||
except KeyboardInterrupt:
|
||||
console = Console()
|
||||
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console = Console()
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
Reference in New Issue
Block a user