feat: add rich console interface for NNUE training pipeline and update requirements
This commit is contained in:
+246
-233
@@ -1,21 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Central NNUE pipeline CLI for training and exporting models."""
|
||||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich import print as rprint
|
||||
|
||||
def get_python_cmd():
|
||||
"""Get available Python command."""
|
||||
if os.name == 'nt':
|
||||
return "python"
|
||||
return "python3" if os.popen("which python3 2>/dev/null").read() else "python"
|
||||
# Add src directory to path so we can import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
def get_src_module(module_name):
|
||||
"""Get path to module in src/ directory."""
|
||||
return Path(__file__).parent / "src" / f"{module_name}.py"
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue
|
||||
from export import export_weights_to_binary
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create data directory."""
|
||||
@@ -37,240 +38,252 @@ def list_checkpoints():
|
||||
return []
|
||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||
|
||||
def run_generate_positions(num_games):
|
||||
"""Generate random positions."""
|
||||
data_dir = get_data_dir()
|
||||
positions_file = data_dir / "positions.txt"
|
||||
print(f"Generating {num_games} positions...")
|
||||
result = subprocess.run(
|
||||
[get_python_cmd(), str(get_src_module("generate")), str(positions_file), "--games", str(num_games)],
|
||||
capture_output=False
|
||||
def show_header():
|
||||
"""Display application header."""
|
||||
console = Console()
|
||||
console.clear()
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Position generation failed")
|
||||
return False
|
||||
return positions_file.exists()
|
||||
|
||||
def run_label_positions(stockfish_path):
|
||||
"""Label positions with Stockfish."""
|
||||
data_dir = get_data_dir()
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
|
||||
if not positions_file.exists():
|
||||
print("ERROR: positions.txt not found")
|
||||
return False
|
||||
|
||||
print("Labeling positions with Stockfish...")
|
||||
result = subprocess.run(
|
||||
[get_python_cmd(), str(get_src_module("label")), str(positions_file), str(output_file), stockfish_path],
|
||||
capture_output=False
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Position labeling failed")
|
||||
return False
|
||||
return output_file.exists()
|
||||
|
||||
def run_train(positions_file, output_weights, from_checkpoint=None):
|
||||
"""Train NNUE model."""
|
||||
if not Path(positions_file).exists():
|
||||
print(f"ERROR: {positions_file} not found")
|
||||
return False
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
print(f"Training model (output: {output_weights})...")
|
||||
if from_checkpoint:
|
||||
print(f" Starting from checkpoint: {from_checkpoint}")
|
||||
|
||||
cmd = [get_python_cmd(), str(get_src_module("train")), str(positions_file), str(output_weights)]
|
||||
if from_checkpoint:
|
||||
cmd.extend(["--checkpoint", str(from_checkpoint)])
|
||||
|
||||
# Run from weights directory so outputs save there
|
||||
result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Training failed")
|
||||
return False
|
||||
return True
|
||||
|
||||
def run_export(weights_file, output_file):
|
||||
"""Export weights to Scala."""
|
||||
weights_dir = get_weights_dir()
|
||||
weights_path = weights_dir / Path(weights_file).name
|
||||
|
||||
if not weights_path.exists():
|
||||
print(f"ERROR: {weights_file} not found in {weights_dir}")
|
||||
return False
|
||||
|
||||
print(f"Exporting {weights_file} to Scala...")
|
||||
result = subprocess.run(
|
||||
[get_python_cmd(), str(get_src_module("export")), str(weights_path), output_file],
|
||||
capture_output=False
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Export failed")
|
||||
return False
|
||||
return Path(output_file).exists()
|
||||
|
||||
def cmd_train(args):
|
||||
"""Handle train command."""
|
||||
stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
data_dir = get_data_dir()
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
# Determine checkpoint
|
||||
checkpoint = None
|
||||
if args.from_checkpoint:
|
||||
checkpoint_version = args.from_checkpoint
|
||||
checkpoint = f"nnue_weights_v{checkpoint_version}.pt"
|
||||
checkpoint_path = weights_dir / checkpoint
|
||||
if not checkpoint_path.exists():
|
||||
print(f"ERROR: Checkpoint {checkpoint} not found")
|
||||
return False
|
||||
else:
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
latest = max(available)
|
||||
checkpoint = f"nnue_weights_v{latest}.pt"
|
||||
print(f"No checkpoint specified, using latest: v{latest}")
|
||||
|
||||
# Generate or use existing positions
|
||||
if args.positions_file:
|
||||
positions_file = Path(args.positions_file)
|
||||
if not positions_file.exists():
|
||||
print(f"ERROR: {args.positions_file} not found")
|
||||
return False
|
||||
else:
|
||||
positions_file = data_dir / "positions.txt"
|
||||
num_games = args.games or 500000
|
||||
if not run_generate_positions(num_games):
|
||||
return False
|
||||
|
||||
# Label positions
|
||||
if not run_label_positions(stockfish_path):
|
||||
return False
|
||||
|
||||
print("\nStarting training...")
|
||||
|
||||
# Train with absolute path to data, checkpoint is relative to weights dir
|
||||
training_data = str(data_dir / "training_data.jsonl")
|
||||
if not run_train(training_data, "nnue_weights.pt", checkpoint):
|
||||
return False
|
||||
|
||||
# Show created version
|
||||
def show_checkpoints_table():
|
||||
"""Display available checkpoints in a table."""
|
||||
console = Console()
|
||||
available = list_checkpoints()
|
||||
new_version = max(available) if available else 1
|
||||
print(f"\n✓ Training complete: nnue_weights_v{new_version}.pt")
|
||||
return True
|
||||
|
||||
def cmd_export(args):
|
||||
"""Handle export command."""
|
||||
weights_file = args.weights
|
||||
|
||||
# Auto-detect if version is specified
|
||||
if not weights_file.endswith(".pt"):
|
||||
weights_file = f"nnue_weights_v{weights_file}.pt"
|
||||
|
||||
# Output to resources directory as binary format
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
||||
|
||||
if not run_export(weights_file, output_file):
|
||||
return False
|
||||
|
||||
print(f"✓ Export complete: {output_file}")
|
||||
return True
|
||||
|
||||
def cmd_list(args):
|
||||
"""List available checkpoints."""
|
||||
available = list_checkpoints()
|
||||
if not available:
|
||||
print("No checkpoints found")
|
||||
return True
|
||||
console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("File Size", justify="right")
|
||||
table.add_column("Status", justify="center")
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
print("Available checkpoints:")
|
||||
for v in available:
|
||||
for v in sorted(available):
|
||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||
if weights_file.exists():
|
||||
size = weights_file.stat().st_size / (1024**2) # MB
|
||||
print(f" v{v} ({size:.1f} MB)")
|
||||
size = weights_file.stat().st_size / (1024**2)
|
||||
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||
else:
|
||||
print(f" v{v} (file not found)")
|
||||
return True
|
||||
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||
|
||||
console.print(table)
|
||||
|
||||
def show_main_menu():
|
||||
"""Display and handle main menu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||
console.print("[cyan]2[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]3[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]4[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
export_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
def train_interactive():
|
||||
"""Interactive training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")
|
||||
|
||||
# Checkpoint selection
|
||||
available = list_checkpoints()
|
||||
use_checkpoint = False
|
||||
checkpoint_version = None
|
||||
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
|
||||
if use_checkpoint:
|
||||
checkpoint_version = Prompt.ask(
|
||||
"Enter checkpoint version",
|
||||
default=str(max(available))
|
||||
)
|
||||
|
||||
# Positions source
|
||||
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||
positions_file = None
|
||||
num_games = 500000
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file")
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||
|
||||
# Stockfish path
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
if use_checkpoint:
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(" Checkpoint: None (training from scratch)")
|
||||
console.print(f" Games: {num_games:,}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Stockfish: {stockfish_path}")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
console.print("[yellow]Training cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Execute training
|
||||
data_dir = get_data_dir()
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
# Generate positions
|
||||
if not use_existing:
|
||||
console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(data_dir / "positions.txt"),
|
||||
total_games=num_games,
|
||||
filter_captures=True
|
||||
)
|
||||
if count == 0:
|
||||
console.print("[red]✗ No valid positions generated[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Generated {count:,} positions[/green]")
|
||||
else:
|
||||
if not Path(positions_file).exists():
|
||||
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
|
||||
return
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Train model
|
||||
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
|
||||
checkpoint = None
|
||||
if use_checkpoint:
|
||||
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||
|
||||
train_nnue(
|
||||
data_file=str(output_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
# Show result
|
||||
available = list_checkpoints()
|
||||
new_version = max(available) if available else 1
|
||||
console.print(f"\n[bold green]✓ Training successful![/bold green]")
|
||||
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
|
||||
|
||||
# Select weights version
|
||||
available = list_checkpoints()
|
||||
if not available:
|
||||
console.print("[red]✗ No checkpoints available to export[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||
|
||||
weights_file = f"nnue_weights_v{version}.pt"
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
||||
|
||||
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||
console.print(f" Source: {weights_file}")
|
||||
console.print(f" Destination: {output_file}")
|
||||
|
||||
if not Confirm.ask("\nExport weights?", default=True):
|
||||
console.print("[yellow]Export cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
weights_dir = get_weights_dir()
|
||||
weights_path = weights_dir / weights_file
|
||||
|
||||
if not weights_path.exists():
|
||||
console.print(f"[red]✗ {weights_file} not found[/red]")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||
export_weights_to_binary(str(weights_path), output_file)
|
||||
console.print(f"\n[green]✓ Export complete![/green]")
|
||||
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="NNUE pipeline CLI for training and exporting models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Train with 500k random positions
|
||||
python nnue.py train
|
||||
|
||||
# Train from checkpoint v2
|
||||
python nnue.py train --from-checkpoint 2
|
||||
|
||||
# Train with custom positions file
|
||||
python nnue.py train --positions-file my_positions.txt
|
||||
|
||||
# Train with 200k games
|
||||
python nnue.py train --games 200000
|
||||
|
||||
# Export specific weights version
|
||||
python nnue.py export 2
|
||||
|
||||
# Export with full filename
|
||||
python nnue.py export nnue_weights_v3.pt
|
||||
|
||||
# List available checkpoints
|
||||
python nnue.py list
|
||||
"""
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# Train subcommand
|
||||
train_parser = subparsers.add_parser("train", help="Train NNUE model")
|
||||
train_parser.add_argument(
|
||||
"--from-checkpoint",
|
||||
type=int,
|
||||
help="Start training from checkpoint version (e.g., 2)"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--games",
|
||||
type=int,
|
||||
help="Number of games to generate (default: 500000)"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--positions-file",
|
||||
help="Use existing positions file instead of generating"
|
||||
)
|
||||
train_parser.add_argument(
|
||||
"--stockfish",
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or /usr/games/stockfish)"
|
||||
)
|
||||
train_parser.set_defaults(func=cmd_train)
|
||||
|
||||
# Export subcommand
|
||||
export_parser = subparsers.add_parser("export", help="Export weights to Scala")
|
||||
export_parser.add_argument(
|
||||
"weights",
|
||||
help="Weights file or version (e.g., 2 or nnue_weights_v2.pt)"
|
||||
)
|
||||
export_parser.set_defaults(func=cmd_export)
|
||||
|
||||
# List subcommand
|
||||
list_parser = subparsers.add_parser("list", help="List available checkpoints")
|
||||
list_parser.set_defaults(func=cmd_list)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
try:
|
||||
show_main_menu()
|
||||
return 0
|
||||
|
||||
success = args.func(args)
|
||||
return 0 if success else 1
|
||||
except KeyboardInterrupt:
|
||||
console = Console()
|
||||
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console = Console()
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
sys.exit(main())
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
chess==1.11.2
|
||||
torch==2.11.0
|
||||
tqdm==4.67.3
|
||||
numpy==2.4.4
|
||||
numpy==2.4.4
|
||||
rich==13.7.0
|
||||
@@ -7,6 +7,36 @@ import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# Standard piece values for capture filtering
|
||||
PIECE_VALUES = {
|
||||
chess.PAWN: 1,
|
||||
chess.KNIGHT: 3,
|
||||
chess.BISHOP: 3,
|
||||
chess.ROOK: 5,
|
||||
chess.QUEEN: 9,
|
||||
}
|
||||
|
||||
def has_winning_or_equal_capture(board):
|
||||
"""Check if position has a capture where victim >= attacker (winning or equal trade).
|
||||
|
||||
Returns True only if there's at least one favorable capture.
|
||||
Positions with only losing captures return False (are kept).
|
||||
"""
|
||||
for move in board.legal_moves:
|
||||
if board.is_capture(move):
|
||||
attacker_piece = board.piece_at(move.from_square)
|
||||
victim_piece = board.piece_at(move.to_square)
|
||||
|
||||
if attacker_piece and victim_piece:
|
||||
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
|
||||
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
|
||||
|
||||
# If victim >= attacker, it's a winning or equal capture
|
||||
if victim_value >= attacker_value:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
||||
"""Play random games and save positions after 8-20 random moves.
|
||||
|
||||
@@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
# Check if any captures are available (if filtering enabled)
|
||||
# Check if there are winning or equal captures (if filtering enabled)
|
||||
if filter_captures:
|
||||
has_captures = any(board.is_capture(move) for move in board.legal_moves)
|
||||
if has_captures:
|
||||
if has_winning_or_equal_capture(board):
|
||||
filtered_captures += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
@@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
||||
print("=" * 60)
|
||||
print("POSITION GENERATION SUMMARY")
|
||||
print("=" * 60)
|
||||
total_filtered = filtered_check + filtered_captures + filtered_game_over
|
||||
print(f"Total games: {total_games}")
|
||||
print(f"Saved positions: {positions_count}")
|
||||
print(f"Filtered (check): {filtered_check}")
|
||||
print(f"Filtered (captures): {filtered_captures}")
|
||||
print(f"Filtered (in check): {filtered_check}")
|
||||
print(f"Filtered (winning+ cap): {filtered_captures}")
|
||||
print(f"Filtered (game over): {filtered_game_over}")
|
||||
print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
|
||||
print(f"Total filtered: {total_filtered}")
|
||||
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
||||
print(f"(Keeps positions with only losing/bad captures)")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if positions_count == 0:
|
||||
print("WARNING: No valid positions were generated!")
|
||||
print("This might happen if:")
|
||||
print(" - The filter criteria are too strict (captures, checks)")
|
||||
print(" - Try using: --no-filter-captures to accept positions with captures")
|
||||
print(" - Most positions have checks or game-over states")
|
||||
print(" - Try using: --no-filter-captures to accept positions with winning captures")
|
||||
return 0
|
||||
|
||||
return positions_count
|
||||
@@ -97,7 +128,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--games", type=int, default=5000,
|
||||
help="Number of games to play (default: 500000)")
|
||||
parser.add_argument("--no-filter-captures", action="store_true",
|
||||
help="Include positions with available captures (increases output)")
|
||||
help="Include positions with winning/equal captures (increases output)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 2,
|
||||
"date": "2026-04-07T23:50:05.390402",
|
||||
"num_positions": 6886,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.007848377339541912,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Reference in New Issue
Block a user