feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction

This commit is contained in:
2026-04-09 22:36:09 +02:00
parent 0bf1d52132
commit 5d4cf5f13c
16 changed files with 940 additions and 245 deletions
+92 -9
View File
@@ -2,6 +2,7 @@
"""Central NNUE pipeline TUI for training and exporting models."""
import os
import shutil
import sys
from pathlib import Path
from rich.console import Console
@@ -17,6 +18,10 @@ from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
interactive_merge_positions
)
def get_data_dir():
"""Get/create data directory."""
@@ -24,6 +29,12 @@ def get_data_dir():
data_dir.mkdir(exist_ok=True)
return data_dir
def get_tactical_data_dir():
"""Get/create data directory."""
data_dir = Path(__file__).parent / "tactical_data"
data_dir.mkdir(exist_ok=True)
return data_dir
def get_weights_dir():
"""Get/create weights directory."""
weights_dir = Path(__file__).parent / "weights"
@@ -87,20 +98,23 @@ def show_main_menu():
console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Export Weights to Scala")
console.print("[cyan]3[/cyan] - View Checkpoints")
console.print("[cyan]4[/cyan] - Exit")
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
console.print("[cyan]4[/cyan] - View Checkpoints")
console.print("[cyan]5[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
if choice == "1":
train_interactive()
elif choice == "2":
export_interactive()
elif choice == "3":
extract_tactical_interactive()
elif choice == "4":
show_header()
show_checkpoints_table()
Prompt.ask("\nPress Enter to continue")
elif choice == "4":
elif choice == "5":
console.print("[yellow]👋 Goodbye![/yellow]")
return
@@ -146,13 +160,18 @@ def train_interactive():
if use_existing_labels:
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
# Stockfish path
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
# Stockfish path and labeling parameters
default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
stockfish_depth = 12
num_workers = 1
if not use_existing_labels:
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="20"))
batch_size = int(Prompt.ask("Batch size", default="4096"))
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -175,6 +194,9 @@ def train_interactive():
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
console.print(f" Early stopping: No")
if not use_existing_labels:
console.print(f" Stockfish depth: {stockfish_depth}")
console.print(f" Workers: {num_workers}")
console.print(f" Stockfish: {stockfish_path}")
if not Confirm.ask("\nStart training?", default=True):
@@ -214,7 +236,8 @@ def train_interactive():
str(positions_file),
str(output_file),
stockfish_path,
depth=12
depth=stockfish_depth,
num_workers=num_workers
)
if not success:
console.print("[red]✗ Position labeling failed[/red]")
@@ -305,6 +328,66 @@ def export_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def extract_tactical_interactive():
"""Interactive tactical positions extraction and merge menu."""
console = Console()
show_header()
console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
# Download and extract options
console.print("\n[bold]Lichess Puzzle Database:[/bold]")
download_url = Prompt.ask(
"Download URL",
default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
)
output_dir = Prompt.ask(
"Extract to directory",
default=str(Path(__file__).parent / "trainingdata")
)
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
# Confirm and download
console.print("\n[bold]Configuration:[/bold]")
console.print(f" Download URL: {download_url}")
console.print(f" Extract directory: {output_dir}")
console.print(f" Max puzzles: {max_puzzles:,}")
if not Confirm.ask("\nProceed?", default=True):
console.print("[yellow]Cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
try:
console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
csv_path = download_and_extract_puzzle_db(download_url, output_dir)
if not csv_path:
console.print("[red]✗ Failed to download/extract[/red]")
Prompt.ask("Press Enter to continue")
return
console.print(f"[green]✓ Ready: {csv_path}[/green]")
# Interactive merge
console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
output_file = Prompt.ask(
"Output file path",
default=str(Path(__file__).parent / "data" / "position.txt")
)
interactive_merge_positions(csv_path, output_file, max_puzzles)
console.print(f"\n[green]✓ Complete![/green]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def main():
try:
show_main_menu()