feat: add rich console interface for NNUE training pipeline and update requirements
This commit is contained in:
+246
-233
@@ -1,21 +1,22 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Central NNUE pipeline CLI for training and exporting models."""
|
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.prompt import Prompt, Confirm
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
def get_python_cmd():
|
# Add src directory to path so we can import modules
|
||||||
"""Get available Python command."""
|
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||||
if os.name == 'nt':
|
|
||||||
return "python"
|
|
||||||
return "python3" if os.popen("which python3 2>/dev/null").read() else "python"
|
|
||||||
|
|
||||||
def get_src_module(module_name):
|
from generate import play_random_game_and_collect_positions
|
||||||
"""Get path to module in src/ directory."""
|
from label import label_positions_with_stockfish
|
||||||
return Path(__file__).parent / "src" / f"{module_name}.py"
|
from train import train_nnue
|
||||||
|
from export import export_weights_to_binary
|
||||||
|
|
||||||
def get_data_dir():
|
def get_data_dir():
|
||||||
"""Get/create data directory."""
|
"""Get/create data directory."""
|
||||||
@@ -37,240 +38,252 @@ def list_checkpoints():
|
|||||||
return []
|
return []
|
||||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||||
|
|
||||||
def run_generate_positions(num_games):
|
def show_header():
|
||||||
"""Generate random positions."""
|
"""Display application header."""
|
||||||
data_dir = get_data_dir()
|
console = Console()
|
||||||
positions_file = data_dir / "positions.txt"
|
console.clear()
|
||||||
print(f"Generating {num_games} positions...")
|
console.print(
|
||||||
result = subprocess.run(
|
Panel(
|
||||||
[get_python_cmd(), str(get_src_module("generate")), str(positions_file), "--games", str(num_games)],
|
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||||
capture_output=False
|
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
|
||||||
|
border_style="cyan",
|
||||||
|
padding=(1, 2),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if result.returncode != 0:
|
|
||||||
print("ERROR: Position generation failed")
|
|
||||||
return False
|
|
||||||
return positions_file.exists()
|
|
||||||
|
|
||||||
def run_label_positions(stockfish_path):
|
def show_checkpoints_table():
|
||||||
"""Label positions with Stockfish."""
|
"""Display available checkpoints in a table."""
|
||||||
data_dir = get_data_dir()
|
console = Console()
|
||||||
positions_file = data_dir / "positions.txt"
|
|
||||||
output_file = data_dir / "training_data.jsonl"
|
|
||||||
|
|
||||||
if not positions_file.exists():
|
|
||||||
print("ERROR: positions.txt not found")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("Labeling positions with Stockfish...")
|
|
||||||
result = subprocess.run(
|
|
||||||
[get_python_cmd(), str(get_src_module("label")), str(positions_file), str(output_file), stockfish_path],
|
|
||||||
capture_output=False
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print("ERROR: Position labeling failed")
|
|
||||||
return False
|
|
||||||
return output_file.exists()
|
|
||||||
|
|
||||||
def run_train(positions_file, output_weights, from_checkpoint=None):
|
|
||||||
"""Train NNUE model."""
|
|
||||||
if not Path(positions_file).exists():
|
|
||||||
print(f"ERROR: {positions_file} not found")
|
|
||||||
return False
|
|
||||||
|
|
||||||
weights_dir = get_weights_dir()
|
|
||||||
print(f"Training model (output: {output_weights})...")
|
|
||||||
if from_checkpoint:
|
|
||||||
print(f" Starting from checkpoint: {from_checkpoint}")
|
|
||||||
|
|
||||||
cmd = [get_python_cmd(), str(get_src_module("train")), str(positions_file), str(output_weights)]
|
|
||||||
if from_checkpoint:
|
|
||||||
cmd.extend(["--checkpoint", str(from_checkpoint)])
|
|
||||||
|
|
||||||
# Run from weights directory so outputs save there
|
|
||||||
result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print("ERROR: Training failed")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def run_export(weights_file, output_file):
|
|
||||||
"""Export weights to Scala."""
|
|
||||||
weights_dir = get_weights_dir()
|
|
||||||
weights_path = weights_dir / Path(weights_file).name
|
|
||||||
|
|
||||||
if not weights_path.exists():
|
|
||||||
print(f"ERROR: {weights_file} not found in {weights_dir}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"Exporting {weights_file} to Scala...")
|
|
||||||
result = subprocess.run(
|
|
||||||
[get_python_cmd(), str(get_src_module("export")), str(weights_path), output_file],
|
|
||||||
capture_output=False
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print("ERROR: Export failed")
|
|
||||||
return False
|
|
||||||
return Path(output_file).exists()
|
|
||||||
|
|
||||||
def cmd_train(args):
|
|
||||||
"""Handle train command."""
|
|
||||||
stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
|
||||||
data_dir = get_data_dir()
|
|
||||||
weights_dir = get_weights_dir()
|
|
||||||
|
|
||||||
# Determine checkpoint
|
|
||||||
checkpoint = None
|
|
||||||
if args.from_checkpoint:
|
|
||||||
checkpoint_version = args.from_checkpoint
|
|
||||||
checkpoint = f"nnue_weights_v{checkpoint_version}.pt"
|
|
||||||
checkpoint_path = weights_dir / checkpoint
|
|
||||||
if not checkpoint_path.exists():
|
|
||||||
print(f"ERROR: Checkpoint {checkpoint} not found")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
available = list_checkpoints()
|
|
||||||
if available:
|
|
||||||
latest = max(available)
|
|
||||||
checkpoint = f"nnue_weights_v{latest}.pt"
|
|
||||||
print(f"No checkpoint specified, using latest: v{latest}")
|
|
||||||
|
|
||||||
# Generate or use existing positions
|
|
||||||
if args.positions_file:
|
|
||||||
positions_file = Path(args.positions_file)
|
|
||||||
if not positions_file.exists():
|
|
||||||
print(f"ERROR: {args.positions_file} not found")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
positions_file = data_dir / "positions.txt"
|
|
||||||
num_games = args.games or 500000
|
|
||||||
if not run_generate_positions(num_games):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Label positions
|
|
||||||
if not run_label_positions(stockfish_path):
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("\nStarting training...")
|
|
||||||
|
|
||||||
# Train with absolute path to data, checkpoint is relative to weights dir
|
|
||||||
training_data = str(data_dir / "training_data.jsonl")
|
|
||||||
if not run_train(training_data, "nnue_weights.pt", checkpoint):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Show created version
|
|
||||||
available = list_checkpoints()
|
available = list_checkpoints()
|
||||||
new_version = max(available) if available else 1
|
|
||||||
print(f"\n✓ Training complete: nnue_weights_v{new_version}.pt")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def cmd_export(args):
|
|
||||||
"""Handle export command."""
|
|
||||||
weights_file = args.weights
|
|
||||||
|
|
||||||
# Auto-detect if version is specified
|
|
||||||
if not weights_file.endswith(".pt"):
|
|
||||||
weights_file = f"nnue_weights_v{weights_file}.pt"
|
|
||||||
|
|
||||||
# Output to resources directory as binary format
|
|
||||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
|
||||||
|
|
||||||
if not run_export(weights_file, output_file):
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✓ Export complete: {output_file}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def cmd_list(args):
|
|
||||||
"""List available checkpoints."""
|
|
||||||
available = list_checkpoints()
|
|
||||||
if not available:
|
if not available:
|
||||||
print("No checkpoints found")
|
console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
|
||||||
return True
|
return
|
||||||
|
|
||||||
|
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
|
||||||
|
table.add_column("Version", style="dim")
|
||||||
|
table.add_column("File Size", justify="right")
|
||||||
|
table.add_column("Status", justify="center")
|
||||||
|
|
||||||
weights_dir = get_weights_dir()
|
weights_dir = get_weights_dir()
|
||||||
print("Available checkpoints:")
|
for v in sorted(available):
|
||||||
for v in available:
|
|
||||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||||
if weights_file.exists():
|
if weights_file.exists():
|
||||||
size = weights_file.stat().st_size / (1024**2) # MB
|
size = weights_file.stat().st_size / (1024**2)
|
||||||
print(f" v{v} ({size:.1f} MB)")
|
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||||
else:
|
else:
|
||||||
print(f" v{v} (file not found)")
|
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||||
return True
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
def show_main_menu():
|
||||||
|
"""Display and handle main menu."""
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
show_header()
|
||||||
|
show_checkpoints_table()
|
||||||
|
|
||||||
|
console.print("\n[bold]What would you like to do?[/bold]")
|
||||||
|
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||||
|
console.print("[cyan]2[/cyan] - Export Weights to Scala")
|
||||||
|
console.print("[cyan]3[/cyan] - View Checkpoints")
|
||||||
|
console.print("[cyan]4[/cyan] - Exit")
|
||||||
|
|
||||||
|
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||||
|
|
||||||
|
if choice == "1":
|
||||||
|
train_interactive()
|
||||||
|
elif choice == "2":
|
||||||
|
export_interactive()
|
||||||
|
elif choice == "3":
|
||||||
|
show_header()
|
||||||
|
show_checkpoints_table()
|
||||||
|
Prompt.ask("\nPress Enter to continue")
|
||||||
|
elif choice == "4":
|
||||||
|
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
def train_interactive():
|
||||||
|
"""Interactive training menu."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")
|
||||||
|
|
||||||
|
# Checkpoint selection
|
||||||
|
available = list_checkpoints()
|
||||||
|
use_checkpoint = False
|
||||||
|
checkpoint_version = None
|
||||||
|
|
||||||
|
if available:
|
||||||
|
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||||
|
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
|
||||||
|
if use_checkpoint:
|
||||||
|
checkpoint_version = Prompt.ask(
|
||||||
|
"Enter checkpoint version",
|
||||||
|
default=str(max(available))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Positions source
|
||||||
|
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||||
|
positions_file = None
|
||||||
|
num_games = 500000
|
||||||
|
|
||||||
|
if use_existing:
|
||||||
|
positions_file = Prompt.ask("Enter path to positions file")
|
||||||
|
else:
|
||||||
|
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||||
|
|
||||||
|
# Stockfish path
|
||||||
|
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||||
|
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||||
|
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||||
|
|
||||||
|
# Confirm and start
|
||||||
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
|
if use_checkpoint:
|
||||||
|
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||||
|
else:
|
||||||
|
console.print(" Checkpoint: None (training from scratch)")
|
||||||
|
console.print(f" Games: {num_games:,}")
|
||||||
|
console.print(f" Epochs: {epochs}")
|
||||||
|
console.print(f" Batch size: {batch_size}")
|
||||||
|
console.print(f" Stockfish: {stockfish_path}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nStart training?", default=True):
|
||||||
|
console.print("[yellow]Training cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute training
|
||||||
|
data_dir = get_data_dir()
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate positions
|
||||||
|
if not use_existing:
|
||||||
|
console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
|
||||||
|
count = play_random_game_and_collect_positions(
|
||||||
|
str(data_dir / "positions.txt"),
|
||||||
|
total_games=num_games,
|
||||||
|
filter_captures=True
|
||||||
|
)
|
||||||
|
if count == 0:
|
||||||
|
console.print("[red]✗ No valid positions generated[/red]")
|
||||||
|
return
|
||||||
|
console.print(f"[green]✓ Generated {count:,} positions[/green]")
|
||||||
|
else:
|
||||||
|
if not Path(positions_file).exists():
|
||||||
|
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Label positions
|
||||||
|
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||||
|
positions_file = data_dir / "positions.txt"
|
||||||
|
output_file = data_dir / "training_data.jsonl"
|
||||||
|
success = label_positions_with_stockfish(
|
||||||
|
str(positions_file),
|
||||||
|
str(output_file),
|
||||||
|
stockfish_path,
|
||||||
|
depth=12
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
console.print("[red]✗ Position labeling failed[/red]")
|
||||||
|
return
|
||||||
|
console.print(f"[green]✓ Positions labeled[/green]")
|
||||||
|
|
||||||
|
# Train model
|
||||||
|
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
|
||||||
|
checkpoint = None
|
||||||
|
if use_checkpoint:
|
||||||
|
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||||
|
|
||||||
|
train_nnue(
|
||||||
|
data_file=str(output_file),
|
||||||
|
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||||
|
epochs=epochs,
|
||||||
|
batch_size=batch_size,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
use_versioning=True
|
||||||
|
)
|
||||||
|
console.print("[green]✓ Training complete[/green]")
|
||||||
|
|
||||||
|
# Show result
|
||||||
|
available = list_checkpoints()
|
||||||
|
new_version = max(available) if available else 1
|
||||||
|
console.print(f"\n[bold green]✓ Training successful![/bold green]")
|
||||||
|
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
def export_interactive():
|
||||||
|
"""Interactive export menu."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
|
||||||
|
|
||||||
|
# Select weights version
|
||||||
|
available = list_checkpoints()
|
||||||
|
if not available:
|
||||||
|
console.print("[red]✗ No checkpoints available to export[/red]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||||
|
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||||
|
|
||||||
|
weights_file = f"nnue_weights_v{version}.pt"
|
||||||
|
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
||||||
|
|
||||||
|
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||||
|
console.print(f" Source: {weights_file}")
|
||||||
|
console.print(f" Destination: {output_file}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nExport weights?", default=True):
|
||||||
|
console.print("[yellow]Export cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
weights_path = weights_dir / weights_file
|
||||||
|
|
||||||
|
if not weights_path.exists():
|
||||||
|
console.print(f"[red]✗ {weights_file} not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||||
|
export_weights_to_binary(str(weights_path), output_file)
|
||||||
|
console.print(f"\n[green]✓ Export complete![/green]")
|
||||||
|
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
try:
|
||||||
description="NNUE pipeline CLI for training and exporting models",
|
show_main_menu()
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
||||||
epilog="""
|
|
||||||
Examples:
|
|
||||||
# Train with 500k random positions
|
|
||||||
python nnue.py train
|
|
||||||
|
|
||||||
# Train from checkpoint v2
|
|
||||||
python nnue.py train --from-checkpoint 2
|
|
||||||
|
|
||||||
# Train with custom positions file
|
|
||||||
python nnue.py train --positions-file my_positions.txt
|
|
||||||
|
|
||||||
# Train with 200k games
|
|
||||||
python nnue.py train --games 200000
|
|
||||||
|
|
||||||
# Export specific weights version
|
|
||||||
python nnue.py export 2
|
|
||||||
|
|
||||||
# Export with full filename
|
|
||||||
python nnue.py export nnue_weights_v3.pt
|
|
||||||
|
|
||||||
# List available checkpoints
|
|
||||||
python nnue.py list
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
|
||||||
|
|
||||||
# Train subcommand
|
|
||||||
train_parser = subparsers.add_parser("train", help="Train NNUE model")
|
|
||||||
train_parser.add_argument(
|
|
||||||
"--from-checkpoint",
|
|
||||||
type=int,
|
|
||||||
help="Start training from checkpoint version (e.g., 2)"
|
|
||||||
)
|
|
||||||
train_parser.add_argument(
|
|
||||||
"--games",
|
|
||||||
type=int,
|
|
||||||
help="Number of games to generate (default: 500000)"
|
|
||||||
)
|
|
||||||
train_parser.add_argument(
|
|
||||||
"--positions-file",
|
|
||||||
help="Use existing positions file instead of generating"
|
|
||||||
)
|
|
||||||
train_parser.add_argument(
|
|
||||||
"--stockfish",
|
|
||||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)"
|
|
||||||
)
|
|
||||||
train_parser.set_defaults(func=cmd_train)
|
|
||||||
|
|
||||||
# Export subcommand
|
|
||||||
export_parser = subparsers.add_parser("export", help="Export weights to Scala")
|
|
||||||
export_parser.add_argument(
|
|
||||||
"weights",
|
|
||||||
help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)"
|
|
||||||
)
|
|
||||||
export_parser.set_defaults(func=cmd_export)
|
|
||||||
|
|
||||||
# List subcommand
|
|
||||||
list_parser = subparsers.add_parser("list", help="List available checkpoints")
|
|
||||||
list_parser.set_defaults(func=cmd_list)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args.command:
|
|
||||||
parser.print_help()
|
|
||||||
return 0
|
return 0
|
||||||
|
except KeyboardInterrupt:
|
||||||
success = args.func(args)
|
console = Console()
|
||||||
return 0 if success else 1
|
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
console = Console()
|
||||||
|
console.print(f"[red]Error:[/red] {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
chess==1.11.2
|
chess==1.11.2
|
||||||
torch==2.11.0
|
torch==2.11.0
|
||||||
tqdm==4.67.3
|
tqdm==4.67.3
|
||||||
numpy==2.4.4
|
numpy==2.4.4
|
||||||
|
rich==13.7.0
|
||||||
@@ -7,6 +7,36 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Standard piece values for capture filtering
|
||||||
|
PIECE_VALUES = {
|
||||||
|
chess.PAWN: 1,
|
||||||
|
chess.KNIGHT: 3,
|
||||||
|
chess.BISHOP: 3,
|
||||||
|
chess.ROOK: 5,
|
||||||
|
chess.QUEEN: 9,
|
||||||
|
}
|
||||||
|
|
||||||
|
def has_winning_or_equal_capture(board):
|
||||||
|
"""Check if position has a capture where victim >= attacker (winning or equal trade).
|
||||||
|
|
||||||
|
Returns True only if there's at least one favorable capture.
|
||||||
|
Positions with only losing captures return False (are kept).
|
||||||
|
"""
|
||||||
|
for move in board.legal_moves:
|
||||||
|
if board.is_capture(move):
|
||||||
|
attacker_piece = board.piece_at(move.from_square)
|
||||||
|
victim_piece = board.piece_at(move.to_square)
|
||||||
|
|
||||||
|
if attacker_piece and victim_piece:
|
||||||
|
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
|
||||||
|
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
|
||||||
|
|
||||||
|
# If victim >= attacker, it's a winning or equal capture
|
||||||
|
if victim_value >= attacker_value:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
||||||
"""Play random games and save positions after 8-20 random moves.
|
"""Play random games and save positions after 8-20 random moves.
|
||||||
|
|
||||||
@@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if any captures are available (if filtering enabled)
|
# Check if there are winning or equal captures (if filtering enabled)
|
||||||
if filter_captures:
|
if filter_captures:
|
||||||
has_captures = any(board.is_capture(move) for move in board.legal_moves)
|
if has_winning_or_equal_capture(board):
|
||||||
if has_captures:
|
|
||||||
filtered_captures += 1
|
filtered_captures += 1
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
@@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("POSITION GENERATION SUMMARY")
|
print("POSITION GENERATION SUMMARY")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
total_filtered = filtered_check + filtered_captures + filtered_game_over
|
||||||
print(f"Total games: {total_games}")
|
print(f"Total games: {total_games}")
|
||||||
print(f"Saved positions: {positions_count}")
|
print(f"Saved positions: {positions_count}")
|
||||||
print(f"Filtered (check): {filtered_check}")
|
print(f"Filtered (in check): {filtered_check}")
|
||||||
print(f"Filtered (captures): {filtered_captures}")
|
print(f"Filtered (winning+ cap): {filtered_captures}")
|
||||||
print(f"Filtered (game over): {filtered_game_over}")
|
print(f"Filtered (game over): {filtered_game_over}")
|
||||||
print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
|
print(f"Total filtered: {total_filtered}")
|
||||||
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
||||||
|
print(f"(Keeps positions with only losing/bad captures)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
if positions_count == 0:
|
if positions_count == 0:
|
||||||
print("WARNING: No valid positions were generated!")
|
print("WARNING: No valid positions were generated!")
|
||||||
print("This might happen if:")
|
print("This might happen if:")
|
||||||
print(" - The filter criteria are too strict (captures, checks)")
|
print(" - Most positions have checks or game-over states")
|
||||||
print(" - Try using: --no-filter-captures to accept positions with captures")
|
print(" - Try using: --no-filter-captures to accept positions with winning captures")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
return positions_count
|
return positions_count
|
||||||
@@ -97,7 +128,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--games", type=int, default=5000,
|
parser.add_argument("--games", type=int, default=5000,
|
||||||
help="Number of games to play (default: 500000)")
|
help="Number of games to play (default: 500000)")
|
||||||
parser.add_argument("--no-filter-captures", action="store_true",
|
parser.add_argument("--no-filter-captures", action="store_true",
|
||||||
help="Include positions with available captures (increases output)")
|
help="Include positions with winning/equal captures (increases output)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 2,
|
||||||
|
"date": "2026-04-07T23:50:05.390402",
|
||||||
|
"num_positions": 6886,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.007848377339541912,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user