Files
NowChessSystems/modules/bot/python/nnue.py
T

302 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Central NNUE pipeline TUI for training and exporting models."""
import os
import sys
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich import print as rprint
# Add src directory to path so we can import modules
sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary
def get_data_dir():
    """Return the ``data`` folder next to this script, creating it if absent."""
    script_dir = Path(__file__).parent
    target = script_dir.joinpath("data")
    # exist_ok keeps repeated calls harmless once the folder is there.
    target.mkdir(exist_ok=True)
    return target
def get_weights_dir():
    """Return the ``weights`` folder next to this script, creating it if absent."""
    script_dir = Path(__file__).parent
    target = script_dir.joinpath("weights")
    # exist_ok keeps repeated calls harmless once the folder is there.
    target.mkdir(exist_ok=True)
    return target
def list_checkpoints(weights_dir=None):
    """Return the version numbers of saved checkpoints in ascending order.

    A checkpoint is any file named ``nnue_weights_v<N>.pt`` where ``<N>`` is
    made of ASCII digits only. Files that match the glob but carry a
    non-numeric suffix (for example ``nnue_weights_vbest.pt``) are ignored
    instead of raising ``ValueError``.

    Args:
        weights_dir: Directory to scan. When ``None`` (the default) the
            directory returned by ``get_weights_dir()`` is used.

    Returns:
        A list of distinct ints sorted numerically (so 10 comes after 2);
        empty when no checkpoint is found.
    """
    if weights_dir is None:
        weights_dir = get_weights_dir()
    prefix = "nnue_weights_v"
    suffixes = (cp.stem[len(prefix):] for cp in Path(weights_dir).glob(f"{prefix}*.pt"))
    # isascii() guards against Unicode digits that isdigit() accepts but that
    # would not round-trip into the f"nnue_weights_v{v}.pt" file name.
    versions = {int(s) for s in suffixes if s.isascii() and s.isdigit()}
    return sorted(versions)
def show_header():
    """Clear the terminal and draw the application's title banner."""
    banner_text = (
        "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
        "[dim]Neural Network Utility Evaluation - Model Management[/dim]"
    )
    banner = Panel(banner_text, border_style="cyan", padding=(1, 2))

    screen = Console()
    # Wipe whatever the previous screen left behind before drawing the banner.
    screen.clear()
    screen.print(banner)
def show_checkpoints_table():
    """Print a table of known checkpoints with their size and status.

    Prints a short notice instead when ``list_checkpoints()`` reports nothing.
    """
    out = Console()
    versions = list_checkpoints()
    if not versions:
        out.print("[yellow] No checkpoints found yet[/yellow]")
        return

    listing = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
    column_specs = (
        ("Version", {"style": "dim"}),
        ("File Size", {"justify": "right"}),
        ("Status", {"justify": "center"}),
    )
    for heading, options in column_specs:
        listing.add_column(heading, **options)

    folder = get_weights_dir()
    for version in sorted(versions):
        label = f"v{version}"
        checkpoint = folder / f"nnue_weights_v{version}.pt"
        if not checkpoint.exists():
            listing.add_row(label, "?", "[red]✗ Missing[/red]")
            continue
        # Size is reported in mebibytes (bytes / 1024**2).
        megabytes = checkpoint.stat().st_size / (1024**2)
        listing.add_row(label, f"{megabytes:.1f} MB", "✓ Ready")

    out.print(listing)
def show_main_menu():
    """Run the main menu loop until the user picks the exit option."""
    screen = Console()
    menu_lines = (
        "\n[bold]What would you like to do?[/bold]",
        "[cyan]1[/cyan] - Train NNUE Model",
        "[cyan]2[/cyan] - Export Weights to Scala",
        "[cyan]3[/cyan] - View Checkpoints",
        "[cyan]4[/cyan] - Exit",
    )

    while True:
        # Every pass redraws the banner and the current checkpoint list.
        show_header()
        show_checkpoints_table()
        for line in menu_lines:
            screen.print(line)

        selection = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])

        if selection == "4":
            screen.print("[yellow]👋 Goodbye![/yellow]")
            return
        if selection == "3":
            show_header()
            show_checkpoints_table()
            Prompt.ask("\nPress Enter to continue")
            continue
        if selection == "1":
            train_interactive()
        elif selection == "2":
            export_interactive()
def _ask_int(label, default, minimum=1):
    """Prompt repeatedly until the user supplies an integer >= ``minimum``.

    Args:
        label: Prompt text shown to the user.
        default: Value used when the user just presses Enter.
        minimum: Smallest accepted value.

    Returns:
        The validated integer.
    """
    console = Console()
    while True:
        raw = Prompt.ask(label, default=str(default))
        try:
            value = int(raw)
        except ValueError:
            console.print("[red]✗ Please enter a whole number[/red]")
            continue
        if value >= minimum:
            return value
        console.print(f"[red]✗ Value must be at least {minimum}[/red]")


def _abort_with_pause(console, message):
    """Print an error and wait for Enter so the next screen clear cannot hide it."""
    console.print(f"[red]✗ {message}[/red]")
    Prompt.ask("Press Enter to continue")


def train_interactive():
    """Collect training settings interactively, then generate, label and train.

    The user is asked for an optional starting checkpoint, the source of
    positions (existing file or freshly generated games), the source of
    labels (existing file or Stockfish labeling), the Stockfish path, and the
    epoch / batch-size settings. After confirmation the three pipeline steps
    run in order. Any failure is reported and the function waits for Enter
    before returning, because the main menu clears the screen on redraw.
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")

    # Checkpoint selection
    available = sorted(list_checkpoints())
    use_checkpoint = False
    checkpoint_version = None
    if available:
        console.print(f"\n[dim]Available checkpoints: {', '.join(f'v{v}' for v in available)}[/dim]")
        use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
        if use_checkpoint:
            # Restrict the answer to versions that actually exist so a typo
            # cannot send a missing checkpoint path into training.
            checkpoint_version = Prompt.ask(
                "Enter checkpoint version",
                choices=[str(v) for v in available],
                default=str(available[-1]),
            )

    # Positions source
    use_existing = Confirm.ask("Use existing positions file?", default=False)
    positions_file = None
    num_games = 500000
    if use_existing:
        positions_file = Prompt.ask(
            "Enter path to positions file",
            default=str(get_data_dir() / "positions.txt"),
        )
    else:
        num_games = _ask_int("Number of games to generate", 500000)

    # Labels source
    use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
    labels_file = None
    if use_existing_labels:
        labels_file = Prompt.ask(
            "Enter path to labels file",
            default=str(get_data_dir() / "training_data.jsonl"),
        )

    # Stockfish path
    default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
    stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)

    # Training parameters
    epochs = _ask_int("Number of epochs", 20)
    batch_size = _ask_int("Batch size", 4096)

    # Confirm and start
    console.print("\n[bold]Configuration Summary:[/bold]")
    if use_checkpoint:
        console.print(f"  Checkpoint: v{checkpoint_version}")
    else:
        console.print("  Checkpoint: None (training from scratch)")
    if use_existing:
        # No games are generated in this mode, so show the file instead.
        console.print(f"  Positions: {positions_file}")
    else:
        console.print(f"  Games: {num_games:,}")
    if use_existing_labels:
        console.print(f"  Labels: {labels_file}")
    console.print(f"  Epochs: {epochs}")
    console.print(f"  Batch size: {batch_size}")
    console.print(f"  Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        return

    # Execute training
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()
    try:
        # Step 1: obtain a positions file (user-supplied or freshly generated).
        if use_existing:
            positions_path = Path(positions_file)
            if not positions_path.exists():
                _abort_with_pause(console, f"Positions file not found: {positions_file}")
                return
        else:
            positions_path = data_dir / "positions.txt"
            console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
            count = play_random_game_and_collect_positions(
                str(positions_path),
                total_games=num_games,
                filter_captures=True,
            )
            if count == 0:
                _abort_with_pause(console, "No valid positions generated")
                return
            console.print(f"[green]✓ Generated {count:,} positions[/green]")

        # Step 2: obtain labels (existing file or Stockfish labeling).
        if use_existing_labels:
            console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
            output_file = labels_file
            if not Path(output_file).exists():
                _abort_with_pause(console, f"Labels file not found: {output_file}")
                return
        else:
            console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
            output_file = data_dir / "training_data.jsonl"
            # Label whichever positions file step 1 settled on; this honors a
            # user-supplied path instead of silently using the default one.
            success = label_positions_with_stockfish(
                str(positions_path),
                str(output_file),
                stockfish_path,
                depth=12,
            )
            if not success:
                _abort_with_pause(console, "Position labeling failed")
                return
            console.print("[green]✓ Positions labeled[/green]")

        # Step 3: train the model.
        console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
        checkpoint = None
        if use_checkpoint:
            checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
        train_nnue(
            data_file=str(output_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
            checkpoint=checkpoint,
            use_versioning=True,
        )
        console.print("[green]✓ Training complete[/green]")

        # Show result: the highest version on disk after training.
        available = list_checkpoints()
        new_version = max(available) if available else 1
        console.print("\n[bold green]✓ Training successful![/bold green]")
        console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        # Top-level UI boundary: report, show the traceback, and return to the
        # menu rather than crashing the whole TUI.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def export_interactive():
    """Ask which checkpoint to export and write it as a binary weights file.

    The destination is ``src/main/resources/nnue_weights.bin`` under the
    parent of this script's directory. The version prompt only accepts
    versions reported by ``list_checkpoints()`` and defaults to the highest
    one. Every error path waits for Enter before returning, because the main
    menu clears the screen on redraw.
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")

    # Select weights version
    available = sorted(list_checkpoints())
    if not available:
        console.print("[red]✗ No checkpoints available to export[/red]")
        Prompt.ask("Press Enter to continue")
        return

    console.print(f"[dim]Available versions: {', '.join(f'v{v}' for v in available)}[/dim]")
    # Constrain the answer to existing versions so an empty or malformed
    # entry (e.g. "v2") cannot produce a bogus file name.
    version = Prompt.ask(
        "Enter version to export (e.g., 2)",
        choices=[str(v) for v in available],
        default=str(available[-1]),
    )
    weights_file = f"nnue_weights_v{version}.pt"
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")

    console.print("\n[bold]Export Configuration:[/bold]")
    console.print(f"  Source: {weights_file}")
    console.print(f"  Destination: {output_file}")

    if not Confirm.ask("\nExport weights?", default=True):
        console.print("[yellow]Export cancelled[/yellow]")
        return

    try:
        weights_path = get_weights_dir() / weights_file
        if not weights_path.exists():
            console.print(f"[red]✗ {weights_file} not found[/red]")
            # Pause so the message survives the menu's screen clear.
            Prompt.ask("Press Enter to continue")
            return

        console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
        export_weights_to_binary(str(weights_path), output_file)
        console.print("\n[green]✓ Export complete![/green]")
        console.print(f"[bold]Weights saved to:[/bold] {output_file}")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        # Top-level UI boundary: report, show the traceback, and return to the
        # menu rather than crashing the whole TUI.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def main():
    """Run the menu and return a process exit code (0 on a clean exit, else 1)."""
    exit_code = 1
    try:
        show_main_menu()
        exit_code = 0
    except KeyboardInterrupt:
        Console().print("\n[yellow]Interrupted by user[/yellow]")
    except Exception as err:
        Console().print(f"[red]Error:[/red] {err}")
    return exit_code
# Script entry point: hand main()'s return value to the interpreter as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())