feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
@@ -17,6 +18,10 @@ from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue
|
||||
from export import export_weights_to_binary
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
interactive_merge_positions
|
||||
)
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create data directory."""
|
||||
@@ -24,6 +29,12 @@ def get_data_dir():
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
def get_tactical_data_dir():
|
||||
"""Get/create data directory."""
|
||||
data_dir = Path(__file__).parent / "tactical_data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
def get_weights_dir():
|
||||
"""Get/create weights directory."""
|
||||
weights_dir = Path(__file__).parent / "weights"
|
||||
@@ -87,20 +98,23 @@ def show_main_menu():
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||
console.print("[cyan]2[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]3[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]4[/cyan] - Exit")
|
||||
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]4[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]5[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
export_interactive()
|
||||
elif choice == "3":
|
||||
extract_tactical_interactive()
|
||||
elif choice == "4":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
elif choice == "5":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
@@ -146,13 +160,18 @@ def train_interactive():
|
||||
if use_existing_labels:
|
||||
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
|
||||
|
||||
# Stockfish path
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
|
||||
# Stockfish path and labeling parameters
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||
stockfish_depth = 12
|
||||
num_workers = 1
|
||||
if not use_existing_labels:
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
@@ -175,6 +194,9 @@ def train_interactive():
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
if not use_existing_labels:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
console.print(f" Workers: {num_workers}")
|
||||
console.print(f" Stockfish: {stockfish_path}")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
@@ -214,7 +236,8 @@ def train_interactive():
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
@@ -305,6 +328,66 @@ def export_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def extract_tactical_interactive():
|
||||
"""Interactive tactical positions extraction and merge menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
|
||||
|
||||
# Download and extract options
|
||||
console.print("\n[bold]Lichess Puzzle Database:[/bold]")
|
||||
download_url = Prompt.ask(
|
||||
"Download URL",
|
||||
default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
|
||||
)
|
||||
|
||||
output_dir = Prompt.ask(
|
||||
"Extract to directory",
|
||||
default=str(Path(__file__).parent / "trainingdata")
|
||||
)
|
||||
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
|
||||
# Confirm and download
|
||||
console.print("\n[bold]Configuration:[/bold]")
|
||||
console.print(f" Download URL: {download_url}")
|
||||
console.print(f" Extract directory: {output_dir}")
|
||||
console.print(f" Max puzzles: {max_puzzles:,}")
|
||||
|
||||
if not Confirm.ask("\nProceed?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
|
||||
csv_path = download_and_extract_puzzle_db(download_url, output_dir)
|
||||
|
||||
if not csv_path:
|
||||
console.print("[red]✗ Failed to download/extract[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print(f"[green]✓ Ready: {csv_path}[/green]")
|
||||
|
||||
# Interactive merge
|
||||
console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
|
||||
output_file = Prompt.ask(
|
||||
"Output file path",
|
||||
default=str(Path(__file__).parent / "data" / "position.txt")
|
||||
)
|
||||
|
||||
interactive_merge_positions(csv_path, output_file, max_puzzles)
|
||||
console.print(f"\n[green]✓ Complete![/green]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def main():
|
||||
try:
|
||||
show_main_menu()
|
||||
|
||||
Reference in New Issue
Block a user