feat: add rich console interface for NNUE training pipeline and update requirements

This commit is contained in:
2026-04-07 23:59:17 +02:00
parent b25be99dcf
commit adc2de23bc
5 changed files with 301 additions and 243 deletions
+246 -233
View File
@@ -1,21 +1,22 @@
#!/usr/bin/env python3
"""Central NNUE pipeline CLI for training and exporting models."""
"""Central NNUE pipeline TUI for training and exporting models."""
import argparse
import os
import sys
import subprocess
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich import print as rprint
def get_python_cmd():
    """Return the command name of a Python interpreter for launching child scripts.

    Returns:
        ``"python"`` on Windows (``os.name == 'nt'``). Elsewhere ``"python3"``
        when that executable is found on ``PATH``, otherwise ``"python"``.
    """
    # Imported locally so this function stays self-contained.
    import shutil

    if os.name == 'nt':
        return "python"
    # shutil.which searches PATH in-process: no shell is spawned and no pipe
    # object is left unclosed, unlike os.popen("which ...").
    return "python3" if shutil.which("python3") else "python"
# Add src directory to path so we can import modules
sys.path.insert(0, str(Path(__file__).parent / "src"))
def get_src_module(module_name):
    """Build the path of ``<module_name>.py`` inside the ``src/`` folder beside this script."""
    src_dir = Path(__file__).parent / "src"
    script_name = f"{module_name}.py"
    return src_dir / script_name
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary
def get_data_dir():
"""Get/create data directory."""
@@ -37,240 +38,252 @@ def list_checkpoints():
return []
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
def run_generate_positions(num_games):
"""Generate random positions."""
data_dir = get_data_dir()
positions_file = data_dir / "positions.txt"
print(f"Generating {num_games} positions...")
result = subprocess.run(
[get_python_cmd(), str(get_src_module("generate")), str(positions_file), "--games", str(num_games)],
capture_output=False
def show_header():
    """Clear the terminal and print the application banner panel."""
    console = Console()
    console.clear()
    # NNUE stands for "Efficiently Updatable Neural Network" (the acronym is
    # the reversed spelling), so the subtitle spells that out.
    banner = Panel(
        "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
        "[dim]Efficiently Updatable Neural Network - Model Management[/dim]",
        border_style="cyan",
        padding=(1, 2),
    )
    console.print(banner)
if result.returncode != 0:
print("ERROR: Position generation failed")
return False
return positions_file.exists()
def run_label_positions(stockfish_path):
    """Run the ``label`` script on ``positions.txt`` to produce ``training_data.jsonl``.

    Args:
        stockfish_path: Forwarded to the label script as its third argument.

    Returns:
        False when ``positions.txt`` is missing or the script exits non-zero;
        otherwise whether the output file exists afterwards.
    """
    base_dir = get_data_dir()
    source = base_dir / "positions.txt"
    target = base_dir / "training_data.jsonl"

    if not source.exists():
        print("ERROR: positions.txt not found")
        return False

    print("Labeling positions with Stockfish...")
    command = [
        get_python_cmd(),
        str(get_src_module("label")),
        str(source),
        str(target),
        stockfish_path,
    ]
    # Output is not captured, so the child's progress goes straight to the terminal.
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode == 0:
        return target.exists()

    print("ERROR: Position labeling failed")
    return False
def run_train(positions_file, output_weights, from_checkpoint=None):
    """Train the NNUE model by running the ``train`` script inside the weights directory.

    Args:
        positions_file: Path to the training data file. It may be relative to
            the caller's working directory; it is made absolute before being
            handed to the child process.
        output_weights: Output weights name, passed through unchanged. The
            child runs with the weights directory as its cwd, so a relative
            name is resolved there.
        from_checkpoint: Optional checkpoint, forwarded as ``--checkpoint``
            and passed through unchanged.

    Returns:
        True if the training script exits with status 0; False if the data
        file does not exist or the script fails.
    """
    # The child is started with cwd=weights_dir, so a relative data path that
    # exists here would not resolve there. Make it absolute first.
    data_path = Path(positions_file).resolve()
    if not data_path.exists():
        print(f"ERROR: {positions_file} not found")
        return False

    weights_dir = get_weights_dir()
    cmd = [get_python_cmd(), str(get_src_module("train")), str(data_path), str(output_weights)]

    print(f"Training model (output: {output_weights})...")
    if from_checkpoint:
        print(f" Starting from checkpoint: {from_checkpoint}")
        cmd.extend(["--checkpoint", str(from_checkpoint)])

    # Run from the weights directory so outputs are saved there.
    result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
    if result.returncode != 0:
        print("ERROR: Training failed")
        return False
    return True
def run_export(weights_file, output_file):
    """Run the ``export`` script on a weights file stored in the weights directory.

    Only the file-name component of ``weights_file`` is used; it is looked up
    inside the weights directory.

    Args:
        weights_file: Name (or path) of the weights file to export.
        output_file: Destination path handed to the export script.

    Returns:
        False when the weights file is missing or the script exits non-zero;
        otherwise whether ``output_file`` exists afterwards.
    """
    source_dir = get_weights_dir()
    source = source_dir / Path(weights_file).name

    if not source.exists():
        print(f"ERROR: {weights_file} not found in {source_dir}")
        return False

    print(f"Exporting {weights_file} to Scala...")
    command = [get_python_cmd(), str(get_src_module("export")), str(source), output_file]
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode != 0:
        print("ERROR: Export failed")
        return False

    return Path(output_file).exists()
def cmd_train(args):
"""Handle train command."""
stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
data_dir = get_data_dir()
weights_dir = get_weights_dir()
# Determine checkpoint
checkpoint = None
if args.from_checkpoint:
checkpoint_version = args.from_checkpoint
checkpoint = f"nnue_weights_v{checkpoint_version}.pt"
checkpoint_path = weights_dir / checkpoint
if not checkpoint_path.exists():
print(f"ERROR: Checkpoint {checkpoint} not found")
return False
else:
available = list_checkpoints()
if available:
latest = max(available)
checkpoint = f"nnue_weights_v{latest}.pt"
print(f"No checkpoint specified, using latest: v{latest}")
# Generate or use existing positions
if args.positions_file:
positions_file = Path(args.positions_file)
if not positions_file.exists():
print(f"ERROR: {args.positions_file} not found")
return False
else:
positions_file = data_dir / "positions.txt"
num_games = args.games or 500000
if not run_generate_positions(num_games):
return False
# Label positions
if not run_label_positions(stockfish_path):
return False
print("\nStarting training...")
# Train with absolute path to data, checkpoint is relative to weights dir
training_data = str(data_dir / "training_data.jsonl")
if not run_train(training_data, "nnue_weights.pt", checkpoint):
return False
# Show created version
def show_checkpoints_table():
"""Display available checkpoints in a table."""
console = Console()
available = list_checkpoints()
new_version = max(available) if available else 1
print(f"\n✓ Training complete: nnue_weights_v{new_version}.pt")
return True
def cmd_export(args):
    """Handle the ``export`` sub-command.

    ``args.weights`` is either a full ``*.pt`` file name or a bare version,
    which is expanded to ``nnue_weights_v<version>.pt``. The result is written
    as ``nnue_weights.bin`` under ``src/main/resources`` one level above this
    script's directory.

    Returns:
        True when the export succeeded, False otherwise.
    """
    requested = args.weights
    # Anything that is not already a .pt file name is treated as a version.
    weights_name = requested if requested.endswith(".pt") else f"nnue_weights_v{requested}.pt"

    resources_dir = Path(__file__).parent.parent / "src" / "main" / "resources"
    destination = str(resources_dir / "nnue_weights.bin")

    if run_export(weights_name, destination):
        print(f"✓ Export complete: {destination}")
        return True
    return False
def cmd_list(args):
"""List available checkpoints."""
available = list_checkpoints()
if not available:
print("No checkpoints found")
return True
console.print("[yellow] No checkpoints found yet[/yellow]")
return
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("File Size", justify="right")
table.add_column("Status", justify="center")
weights_dir = get_weights_dir()
print("Available checkpoints:")
for v in available:
for v in sorted(available):
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
if weights_file.exists():
size = weights_file.stat().st_size / (1024**2) # MB
print(f" v{v} ({size:.1f} MB)")
size = weights_file.stat().st_size / (1024**2)
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
else:
print(f" v{v} (file not found)")
return True
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
console.print(table)
def show_main_menu():
    """Run the top-level menu loop until the user picks "Exit".

    Each iteration redraws the header and the checkpoint table, prints the
    menu, and dispatches on the selected option.
    """
    console = Console()
    menu_lines = (
        "\n[bold]What would you like to do?[/bold]",
        "[cyan]1[/cyan] - Train NNUE Model",
        "[cyan]2[/cyan] - Export Weights to Scala",
        "[cyan]3[/cyan] - View Checkpoints",
        "[cyan]4[/cyan] - Exit",
    )

    while True:
        show_header()
        show_checkpoints_table()
        for line in menu_lines:
            console.print(line)

        selection = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])

        if selection == "4":
            console.print("[yellow]👋 Goodbye![/yellow]")
            return
        if selection == "1":
            train_interactive()
        elif selection == "2":
            export_interactive()
        elif selection == "3":
            # Redraw the table on a fresh screen and wait for the user.
            show_header()
            show_checkpoints_table()
            Prompt.ask("\nPress Enter to continue")
def train_interactive():
    """Interactively configure and run the generate -> label -> train pipeline.

    Prompts for an optional starting checkpoint, the positions source (an
    existing file or a number of random games to generate), the Stockfish
    path, and the epoch/batch-size settings. After confirmation it runs:

    1. position generation (skipped when an existing file is used),
    2. Stockfish labeling into ``training_data.jsonl`` in the data directory,
    3. training, writing to ``nnue_weights.pt`` in the weights directory.

    Every failure path pauses for Enter before returning, because the menu
    clears the screen on its next redraw. Exceptions raised by the pipeline
    steps are reported with a traceback rather than propagated.
    """
    # Imported locally: the module-level rich.prompt import only brings in
    # Prompt and Confirm.
    from rich.prompt import IntPrompt

    console = Console()

    def _fail(message):
        # Show the error and wait, so the next screen clear does not wipe it.
        console.print(f"[red]✗ {message}[/red]")
        Prompt.ask("Press Enter to continue")

    show_header()
    console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")

    # Checkpoint selection
    available = list_checkpoints()
    use_checkpoint = False
    checkpoint_version = None
    if available:
        versions = [str(v) for v in sorted(available)]
        listing = ", ".join(f"v{v}" for v in versions)
        console.print(f"\n[dim]Available checkpoints: {listing}[/dim]")
        use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
        if use_checkpoint:
            # Restrict the answer to versions that actually exist.
            checkpoint_version = Prompt.ask(
                "Enter checkpoint version",
                choices=versions,
                default=versions[-1],
            )

    # Positions source
    use_existing = Confirm.ask("Use existing positions file?", default=False)
    custom_positions = None
    num_games = 500000
    if use_existing:
        custom_positions = Prompt.ask("Enter path to positions file")
    else:
        # IntPrompt re-asks on non-numeric input instead of raising ValueError.
        num_games = IntPrompt.ask("Number of games to generate", default=500000)

    # Stockfish path
    default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
    stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)

    # Training parameters
    epochs = IntPrompt.ask("Number of epochs", default=20)
    batch_size = IntPrompt.ask("Batch size", default=4096)

    # Confirm and start
    console.print("\n[bold]Configuration Summary:[/bold]")
    if use_checkpoint:
        console.print(f" Checkpoint: v{checkpoint_version}")
    else:
        console.print(" Checkpoint: None (training from scratch)")
    if use_existing:
        console.print(f" Positions file: {custom_positions}")
    else:
        console.print(f" Games: {num_games:,}")
    console.print(f" Epochs: {epochs}")
    console.print(f" Batch size: {batch_size}")
    console.print(f" Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        return

    # Execute training
    data_dir = get_data_dir()
    weights_dir = get_weights_dir()
    try:
        if use_existing:
            # Label the file the user named, not the default positions.txt.
            positions_file = Path(custom_positions)
            if not positions_file.is_file():
                _fail(f"Positions file not found: {custom_positions}")
                return
        else:
            console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
            positions_file = data_dir / "positions.txt"
            count = play_random_game_and_collect_positions(
                str(positions_file),
                total_games=num_games,
                filter_captures=True,
            )
            if count == 0:
                _fail("No valid positions generated")
                return
            console.print(f"[green]✓ Generated {count:,} positions[/green]")

        # Label positions
        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
        output_file = data_dir / "training_data.jsonl"
        success = label_positions_with_stockfish(
            str(positions_file),
            str(output_file),
            stockfish_path,
            depth=12,
        )
        if not success:
            _fail("Position labeling failed")
            return
        console.print("[green]✓ Positions labeled[/green]")

        # Train model
        console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
        checkpoint = None
        if use_checkpoint:
            checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
        train_nnue(
            data_file=str(output_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
            checkpoint=checkpoint,
            use_versioning=True,
        )
        console.print("[green]✓ Training complete[/green]")

        # Show result: the highest version on disk is reported as the new one.
        available = list_checkpoints()
        new_version = max(available) if available else 1
        console.print("\n[bold green]✓ Training successful![/bold green]")
        console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        # Top-level boundary of the interactive flow: report and return to the menu.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def export_interactive():
    """Interactively export one checkpoint to the binary weights resource.

    Lists the available checkpoint versions, asks which one to export
    (restricted to the listed versions, defaulting to the highest), and after
    confirmation writes ``nnue_weights.bin`` under ``src/main/resources`` one
    level above this script's directory. Failure paths pause for Enter before
    returning, because the menu clears the screen on its next redraw.
    Exceptions raised by the export are reported with a traceback rather than
    propagated.
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")

    # Select weights version
    available = list_checkpoints()
    if not available:
        console.print("[red]✗ No checkpoints available to export[/red]")
        Prompt.ask("Press Enter to continue")
        return

    versions = [str(v) for v in sorted(available)]
    listing = ", ".join(f"v{v}" for v in versions)
    console.print(f"[dim]Available versions: {listing}[/dim]")
    # Restricting to known versions rejects typos before the confirmation step.
    version = Prompt.ask(
        "Enter version to export (e.g., 2)",
        choices=versions,
        default=versions[-1],
    )

    weights_file = f"nnue_weights_v{version}.pt"
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")

    console.print("\n[bold]Export Configuration:[/bold]")
    console.print(f" Source: {weights_file}")
    console.print(f" Destination: {output_file}")
    if not Confirm.ask("\nExport weights?", default=True):
        console.print("[yellow]Export cancelled[/yellow]")
        return

    try:
        weights_path = get_weights_dir() / weights_file
        if not weights_path.exists():
            console.print(f"[red]✗ {weights_file} not found[/red]")
            # Pause so the message is not wiped by the next screen clear.
            Prompt.ask("Press Enter to continue")
            return

        console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
        export_weights_to_binary(str(weights_path), output_file)
        console.print("\n[green]✓ Export complete![/green]")
        console.print(f"[bold]Weights saved to:[/bold] {output_file}")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        # Top-level boundary of the interactive flow: report and return to the menu.
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def main():
parser = argparse.ArgumentParser(
description="NNUE pipeline CLI for training and exporting models",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Train with 500k random positions
python nnue.py train
# Train from checkpoint v2
python nnue.py train --from-checkpoint 2
# Train with custom positions file
python nnue.py train --positions-file my_positions.txt
# Train with 200k games
python nnue.py train --games 200000
# Export specific weights version
python nnue.py export 2
# Export with full filename
python nnue.py export nnue_weights_v3.pt
# List available checkpoints
python nnue.py list
"""
)
subparsers = parser.add_subparsers(dest="command", help="Command to run")
# Train subcommand
train_parser = subparsers.add_parser("train", help="Train NNUE model")
train_parser.add_argument(
"--from-checkpoint",
type=int,
help="Start training from checkpoint version (e.g., 2)"
)
train_parser.add_argument(
"--games",
type=int,
help="Number of games to generate (default: 500000)"
)
train_parser.add_argument(
"--positions-file",
help="Use existing positions file instead of generating"
)
train_parser.add_argument(
"--stockfish",
help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)"
)
train_parser.set_defaults(func=cmd_train)
# Export subcommand
export_parser = subparsers.add_parser("export", help="Export weights to Scala")
export_parser.add_argument(
"weights",
help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)"
)
export_parser.set_defaults(func=cmd_export)
# List subcommand
list_parser = subparsers.add_parser("list", help="List available checkpoints")
list_parser.set_defaults(func=cmd_list)
args = parser.parse_args()
if not args.command:
parser.print_help()
try:
show_main_menu()
return 0
success = args.func(args)
return 0 if success else 1
except KeyboardInterrupt:
console = Console()
console.print("\n[yellow]Interrupted by user[/yellow]")
return 1
except Exception as e:
console = Console()
console.print(f"[red]Error:[/red] {e}")
return 1
if __name__ == "__main__":
sys.exit(main())
sys.exit(main())
+2 -1
View File
@@ -1,4 +1,5 @@
chess==1.11.2
torch==2.11.0
tqdm==4.67.3
numpy==2.4.4
numpy==2.4.4
rich==13.7.0
+40 -9
View File
@@ -7,6 +7,36 @@ import sys
from pathlib import Path
from tqdm import tqdm
# Standard piece values for capture filtering.
# Maps a python-chess piece type to its value in pawns; used by
# has_winning_or_equal_capture to compare the attacker with the victim.
# The king has no entry, so lookups done with PIECE_VALUES.get(..., 0)
# treat it as 0.
PIECE_VALUES = {
chess.PAWN: 1,
chess.KNIGHT: 3,
chess.BISHOP: 3,
chess.ROOK: 5,
chess.QUEEN: 9,
}
def has_winning_or_equal_capture(board):
    """Return True if any legal capture takes a victim worth at least the attacker.

    Values come from ``PIECE_VALUES``; piece types without an entry count
    as 0. Positions whose only captures lose material (victim worth less than
    the attacker) return False, so callers keep them.

    Args:
        board: Position to inspect. Must provide ``legal_moves``,
            ``is_capture``, ``is_en_passant`` and ``piece_at``
            (presumably a ``chess.Board`` — confirm against callers).

    Returns:
        True when at least one winning or equal capture exists, else False.
    """
    for move in board.legal_moves:
        if not board.is_capture(move):
            continue

        # En passant: the captured pawn does not stand on the destination
        # square, so piece_at(to_square) is None. Pawn takes pawn is an equal
        # trade and must count.
        if board.is_en_passant(move):
            return True

        attacker = board.piece_at(move.from_square)
        victim = board.piece_at(move.to_square)
        if attacker is None or victim is None:
            continue

        attacker_value = PIECE_VALUES.get(attacker.piece_type, 0)
        victim_value = PIECE_VALUES.get(victim.piece_type, 0)
        # victim >= attacker means a winning or equal capture.
        if victim_value >= attacker_value:
            return True

    return False
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
"""Play random games and save positions after 8-20 random moves.
@@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
pbar.update(1)
continue
# Check if any captures are available (if filtering enabled)
# Check if there are winning or equal captures (if filtering enabled)
if filter_captures:
has_captures = any(board.is_capture(move) for move in board.legal_moves)
if has_captures:
if has_winning_or_equal_capture(board):
filtered_captures += 1
pbar.update(1)
continue
@@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
total_filtered = filtered_check + filtered_captures + filtered_game_over
print(f"Total games: {total_games}")
print(f"Saved positions: {positions_count}")
print(f"Filtered (check): {filtered_check}")
print(f"Filtered (captures): {filtered_captures}")
print(f"Filtered (in check): {filtered_check}")
print(f"Filtered (winning+ cap): {filtered_captures}")
print(f"Filtered (game over): {filtered_game_over}")
print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
print(f"Total filtered: {total_filtered}")
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
print(f"(Keeps positions with only losing/bad captures)")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
print("This might happen if:")
print(" - The filter criteria are too strict (captures, checks)")
print(" - Try using: --no-filter-captures to accept positions with captures")
print(" - Most positions have checks or game-over states")
print(" - Try using: --no-filter-captures to accept positions with winning captures")
return 0
return positions_count
@@ -97,7 +128,7 @@ if __name__ == "__main__":
parser.add_argument("--games", type=int, default=5000,
help="Number of games to play (default: 500000)")
parser.add_argument("--no-filter-captures", action="store_true",
help="Include positions with available captures (increases output)")
help="Include positions with winning/equal captures (increases output)")
args = parser.parse_args()
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 2,
"date": "2026-04-07T23:50:05.390402",
"num_positions": 6886,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 0.007848377339541912,
"device": "cuda",
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}