feat: Implement dataset versioning and management for NNUE training data

This commit is contained in:
2026-04-13 21:19:26 +02:00
parent 4b52199754
commit 8fb872e958
18 changed files with 1399 additions and 335 deletions
+547 -193
View File
@@ -4,6 +4,7 @@
import os
import shutil
import sys
import tempfile
from pathlib import Path
from rich.console import Console
from rich.table import Table
@@ -17,23 +18,23 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue, burst_train
from export import export_weights_to_binary
from export import export_to_nbai
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
interactive_merge_positions
extract_tactical_only
)
from dataset import (
get_datasets_dir,
list_datasets,
next_dataset_version,
load_dataset_metadata,
create_dataset,
extend_dataset,
get_dataset_labeled_path,
delete_dataset,
show_datasets_table
)
def get_data_dir():
    """Return the per-script data directory, creating it on first use."""
    path = Path(__file__).parent / "data"
    path.mkdir(exist_ok=True)
    return path
def get_tactical_data_dir():
    """Get/create the tactical-positions data directory.

    Fix: the docstring previously said "data directory" (copy-paste from
    get_data_dir), but this helper manages ``tactical_data``.
    """
    data_dir = Path(__file__).parent / "tactical_data"
    data_dir.mkdir(exist_ok=True)
    return data_dir
def get_weights_dir():
"""Get/create weights directory."""
@@ -41,6 +42,14 @@ def get_weights_dir():
weights_dir.mkdir(exist_ok=True)
return weights_dir
def get_data_dir():
    """Get/create legacy data directory (for migration)."""
    legacy_dir = Path(__file__).parent / "data"
    legacy_dir.mkdir(exist_ok=True)
    return legacy_dir
def list_checkpoints():
"""List available checkpoint versions."""
weights_dir = get_weights_dir()
@@ -49,6 +58,20 @@ def list_checkpoints():
return []
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
def migrate_legacy_data():
    """On first run, offer to import existing data/training_data.jsonl as ds_v1."""
    console = Console()
    legacy_path = get_data_dir() / "training_data.jsonl"
    existing_datasets = list_datasets()
    # Only suggest migration when legacy data exists and no versioned
    # datasets have been created yet.
    if legacy_path.exists() and not existing_datasets:
        console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
        console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
def show_header():
"""Display application header."""
console = Console()
@@ -56,22 +79,23 @@ def show_header():
console.print(
Panel(
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
border_style="cyan",
padding=(1, 2),
)
)
def show_checkpoints_table():
"""Display available checkpoints in a table."""
console = Console()
available = list_checkpoints()
if not available:
console.print("[yellow] No checkpoints found yet[/yellow]")
console.print("[yellow] No model checkpoints found yet[/yellow]")
return
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("File Size", justify="right")
table.add_column("Status", justify="center")
@@ -87,46 +111,500 @@ def show_checkpoints_table():
console.print(table)
def show_main_menu():
    """Display and handle the top-level main menu.

    Fix: the merged diff left both the old six-option menu (Train / Burst /
    Export / Tactical / Checkpoints / Exit, choices 1-6) and the new
    four-option menu in place. Only the four-option menu matching the
    handler branches and the ``choices=["1", "2", "3", "4"]`` prompt is kept.
    """
    console = Console()
    # Offer to migrate legacy data on first run.
    migrate_legacy_data()
    while True:
        show_header()
        show_checkpoints_table()
        console.print("\n[bold]What would you like to do?[/bold]")
        console.print("[cyan]1[/cyan] - Manage Training Data")
        console.print("[cyan]2[/cyan] - Train Model")
        console.print("[cyan]3[/cyan] - Export Model")
        console.print("[cyan]4[/cyan] - Exit")
        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
        if choice == "1":
            datasets_menu()
        elif choice == "2":
            training_menu()
        elif choice == "3":
            export_interactive()
        elif choice == "4":
            console.print("[yellow]👋 Goodbye![/yellow]")
            return
def datasets_menu():
    """Submenu for creating, extending, viewing and deleting datasets."""
    console = Console()
    while True:
        show_header()
        show_datasets_table(console)
        console.print("\n[bold]Training Data Management[/bold]")
        console.print("[cyan]1[/cyan] - Create new dataset")
        console.print("[cyan]2[/cyan] - Extend existing dataset")
        console.print("[cyan]3[/cyan] - View all datasets")
        console.print("[cyan]4[/cyan] - Delete dataset")
        console.print("[cyan]5[/cyan] - Back")
        selection = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
        if selection == "5":
            # Leave the submenu and fall back to the main menu.
            return
        if selection == "1":
            create_dataset_interactive()
        elif selection == "2":
            extend_dataset_interactive()
        elif selection == "3":
            # Read-only refresh of the datasets table.
            show_header()
            show_datasets_table(console)
            Prompt.ask("\nPress Enter to continue")
        elif selection == "4":
            delete_dataset_interactive()
def create_dataset_interactive():
    """Interactive dataset creation flow.

    Gathers one or more position sources (random generation, FEN file
    import, Lichess tactical puzzles), deduplicates and labels the combined
    positions with Stockfish, then stores the result as a new versioned
    dataset (ds_v<N>).
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
    sources = []  # one metadata dict per source added below
    combined_count = 0  # sum of per-source counts (before dedup)
    # Allow user to add multiple sources
    while True:
        console.print("\n[bold]Add data source (repeat until done):[/bold]")
        console.print("[cyan]a[/cyan] - Generate random positions")
        console.print("[cyan]b[/cyan] - Import from file")
        console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
        console.print("[cyan]d[/cyan] - Done adding sources")
        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
        if choice == "a":
            num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
            min_move = int(Prompt.ask("Minimum move number", default="1"))
            max_move = int(Prompt.ask("Maximum move number", default="50"))
            num_workers = int(Prompt.ask("Number of workers", default="8"))
            console.print("[dim]Generating positions...[/dim]")
            # Positions land in a fixed temp path; the combine step below
            # re-reads this same path. NOTE(review): choosing "a" twice
            # overwrites the previous temp file — confirm that is intended.
            temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
            count = play_random_game_and_collect_positions(
                str(temp_file),
                total_positions=num_positions,
                samples_per_game=1,
                min_move=min_move,
                max_move=max_move,
                num_workers=num_workers
            )
            if count > 0:
                sources.append({
                    "type": "generated",
                    "count": count,
                    "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
                })
                combined_count += count
                console.print(f"[green]✓ {count:,} positions generated[/green]")
            else:
                console.print("[red]✗ Generation failed[/red]")
        elif choice == "b":
            file_path = Prompt.ask("Path to FEN file")
            try:
                # Count lines to report how many positions the file holds.
                with open(file_path, 'r') as f:
                    count = sum(1 for _ in f)
                sources.append({"type": "file_import", "count": count, "path": file_path})
                combined_count += count
                console.print(f"[green]✓ {count:,} positions from file[/green]")
            except FileNotFoundError:
                console.print(f"[red]✗ File not found: {file_path}[/red]")
        elif choice == "c":
            max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
            console.print("[dim]Extracting tactical positions...[/dim]")
            temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
            try:
                csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
                # A falsy csv_path means the download/extract step failed;
                # nothing is appended and no error is shown in that case.
                if csv_path:
                    count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
                    sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
                    combined_count += count
                    console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
            except Exception as e:
                console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
        elif choice == "d":
            if not sources:
                console.print("[yellow]⚠ No sources added yet[/yellow]")
                continue
            break
    if not sources:
        console.print("[yellow]Dataset creation cancelled[/yellow]")
        return
    # Stockfish labeling parameters
    console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
    stockfish_path = Prompt.ask(
        "Stockfish path",
        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
    )
    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
    # Summary and confirm
    console.print("\n[bold]Dataset Summary:[/bold]")
    console.print(f" Total positions: {combined_count:,}")
    for source in sources:
        console.print(f" - {source['type']}: {source['count']:,}")
    console.print(f" Stockfish depth: {stockfish_depth}")
    if not Confirm.ask("\nProceed to label and create dataset?", default=True):
        console.print("[yellow]Cancelled[/yellow]")
        return
    try:
        # Combine all sources into one FEN file
        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
        all_fens = set()  # set-based dedup across all sources
        for source in sources:
            # Re-derive the temp path each source type wrote to above.
            if source['type'] == 'generated':
                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
            elif source['type'] == 'file_import':
                temp_file = Path(source['path'])
            elif source['type'] == 'tactical':
                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
            if temp_file.exists():
                with open(temp_file, 'r') as f:
                    for line in f:
                        fen = line.strip()
                        if fen:
                            all_fens.add(fen)
        with open(combined_fen_file, 'w') as f:
            for fen in all_fens:
                f.write(fen + '\n')
        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
        # Label positions
        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
        labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
        success = label_positions_with_stockfish(
            str(combined_fen_file),
            str(labeled_file),
            stockfish_path,
            depth=stockfish_depth,
            num_workers=num_workers
        )
        if not success:
            console.print("[red]✗ Labeling failed[/red]")
            return
        console.print("[green]✓ Positions labeled[/green]")
        # Save dataset
        console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
        version = next_dataset_version()
        create_dataset(
            version=version,
            labeled_jsonl_path=str(labeled_file),
            sources=sources,
            stockfish_depth=stockfish_depth
        )
        console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
        console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
        Prompt.ask("\nPress Enter to continue")
    except Exception as e:
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def extend_dataset_interactive():
    """Interactive dataset extension flow.

    Lets the user pick an existing dataset, add new position sources,
    label the new positions with Stockfish, and append them to the chosen
    dataset as a single "merged_sources" entry.
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
    datasets = list_datasets()
    if not datasets:
        console.print("[yellow] No datasets available to extend[/yellow]")
        Prompt.ask("Press Enter to continue")
        return
    show_datasets_table(console)
    version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
    # NOTE(review): assumes list_datasets() yields (version, ...) pairs —
    # confirm against dataset.py.
    if not any(v == version for v, _ in datasets):
        console.print("[red]✗ Dataset not found[/red]")
        return
    sources = []  # one metadata dict per source added below
    combined_count = 0  # sum of per-source counts (before dedup)
    # Allow user to add sources
    while True:
        console.print("\n[bold]Add data source:[/bold]")
        console.print("[cyan]a[/cyan] - Generate random positions")
        console.print("[cyan]b[/cyan] - Import from file")
        console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
        console.print("[cyan]d[/cyan] - Done adding sources")
        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
        if choice == "a":
            num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
            min_move = int(Prompt.ask("Minimum move number", default="1"))
            max_move = int(Prompt.ask("Maximum move number", default="50"))
            num_workers = int(Prompt.ask("Number of workers", default="8"))
            console.print("[dim]Generating positions...[/dim]")
            # Fixed temp path; the combine step below re-reads it.
            temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
            count = play_random_game_and_collect_positions(
                str(temp_file),
                total_positions=num_positions,
                samples_per_game=1,
                min_move=min_move,
                max_move=max_move,
                num_workers=num_workers
            )
            if count > 0:
                sources.append({
                    "type": "generated",
                    "count": count,
                    "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
                })
                combined_count += count
                console.print(f"[green]✓ {count:,} positions generated[/green]")
        elif choice == "b":
            file_path = Prompt.ask("Path to FEN file")
            try:
                # Count lines to report how many positions the file holds.
                with open(file_path, 'r') as f:
                    count = sum(1 for _ in f)
                sources.append({"type": "file_import", "count": count, "path": file_path})
                combined_count += count
                console.print(f"[green]✓ {count:,} positions from file[/green]")
            except FileNotFoundError:
                console.print(f"[red]✗ File not found: {file_path}[/red]")
        elif choice == "c":
            max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
            console.print("[dim]Extracting tactical positions...[/dim]")
            temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
            try:
                csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
                if csv_path:
                    count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
                    sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
                    combined_count += count
                    console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
            except Exception as e:
                console.print(f"[red]✗ Extraction failed: {e}[/red]")
        elif choice == "d":
            if not sources:
                console.print("[yellow]⚠ No sources added yet[/yellow]")
                continue
            break
    if not sources:
        console.print("[yellow]Extension cancelled[/yellow]")
        return
    # Stockfish labeling parameters
    console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
    stockfish_path = Prompt.ask(
        "Stockfish path",
        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
    )
    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
    # Summary and confirm
    console.print("\n[bold]Extension Summary:[/bold]")
    console.print(f" Target dataset: ds_v{version}")
    console.print(f" New positions: {combined_count:,}")
    for source in sources:
        console.print(f" - {source['type']}: {source['count']:,}")
    if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
        console.print("[yellow]Cancelled[/yellow]")
        return
    try:
        # Combine all sources into one FEN file
        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
        all_fens = set()  # set-based dedup across all sources
        for source in sources:
            # Re-derive the temp path each source type wrote to above.
            if source['type'] == 'generated':
                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
            elif source['type'] == 'file_import':
                temp_file = Path(source['path'])
            elif source['type'] == 'tactical':
                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
            if temp_file.exists():
                with open(temp_file, 'r') as f:
                    for line in f:
                        fen = line.strip()
                        if fen:
                            all_fens.add(fen)
        with open(combined_fen_file, 'w') as f:
            for fen in all_fens:
                f.write(fen + '\n')
        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
        # Label positions
        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
        labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
        success = label_positions_with_stockfish(
            str(combined_fen_file),
            str(labeled_file),
            stockfish_path,
            depth=stockfish_depth,
            num_workers=num_workers
        )
        if not success:
            console.print("[red]✗ Labeling failed[/red]")
            return
        console.print("[green]✓ Positions labeled[/green]")
        # Extend dataset
        console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
        success = extend_dataset(
            version=version,
            new_labeled_path=str(labeled_file),
            new_source_entry={
                "type": "merged_sources",
                "count": len(all_fens),
                "sources": sources
            }
        )
        if success:
            metadata = load_dataset_metadata(version)
            console.print(f"[green]✓ Dataset extended[/green]")
            console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
        else:
            console.print("[red]✗ Extension failed[/red]")
        Prompt.ask("\nPress Enter to continue")
    except Exception as e:
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def delete_dataset_interactive():
    """Prompt for a dataset version and delete it after confirmation."""
    console = Console()
    show_header()
    console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
    existing = list_datasets()
    if not existing:
        console.print("[yellow] No datasets to delete[/yellow]")
        Prompt.ask("Press Enter to continue")
        return
    show_datasets_table(console)
    target = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
    known_versions = {v for v, _ in existing}
    if target not in known_versions:
        console.print("[red]✗ Dataset not found[/red]")
        return
    # Deletion is destructive, so the confirmation defaults to "no".
    if Confirm.ask(f"Delete ds_v{target}? This cannot be undone.", default=False):
        deleted = delete_dataset(target)
        if deleted:
            console.print(f"[green]✓ Dataset ds_v{target} deleted[/green]")
        else:
            console.print("[red]✗ Deletion failed[/red]")
    Prompt.ask("Press Enter to continue")
def training_menu():
    """Training submenu: standard training, burst training, checkpoint view.

    Fix: the merged diff left stale handler branches from the old
    six-option main menu (``elif choice == "4": extract_tactical_interactive()``,
    dead ``"5"``/``"6"`` branches, and two contradictory ``"4"`` branches)
    alongside a four-option prompt. The handlers below match the displayed
    options and the ``choices=["1", "2", "3", "4"]`` prompt.
    """
    console = Console()
    while True:
        show_header()
        console.print("\n[bold]Training[/bold]")
        console.print("[cyan]1[/cyan] - Standard Training")
        console.print("[cyan]2[/cyan] - Burst Training")
        console.print("[cyan]3[/cyan] - View Model Checkpoints")
        console.print("[cyan]4[/cyan] - Back")
        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
        if choice == "1":
            train_interactive()
        elif choice == "2":
            burst_train_interactive()
        elif choice == "3":
            # Read-only view of saved model checkpoints.
            show_header()
            show_checkpoints_table()
            Prompt.ask("\nPress Enter to continue")
        elif choice == "4":
            return
def train_interactive():
"""Interactive training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
# Checkpoint selection
available = list_checkpoints()
@@ -142,36 +620,6 @@ def train_interactive():
default=str(max(available))
)
# Positions source
use_existing = Confirm.ask("Use existing positions file?", default=False)
positions_file = None
num_games = 500000
samples_per_game = 1
min_move = 1
max_move = 50
if use_existing:
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
else:
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
labels_file = None
if use_existing_labels:
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
# Stockfish path and labeling parameters
default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
stockfish_depth = 12
num_workers = 1
if not use_existing_labels:
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
@@ -182,16 +630,7 @@ def train_interactive():
# Confirm and start
console.print("\n[bold]Configuration Summary:[/bold]")
if use_checkpoint:
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(" Checkpoint: None (training from scratch)")
if not use_existing:
console.print(f" Games: {num_games:,}")
console.print(f" Samples per game: {samples_per_game}")
console.print(f" Move range: {min_move}-{max_move}")
else:
console.print(f" Positions file: {positions_file}")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
@@ -199,70 +638,27 @@ def train_interactive():
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
console.print(f" Early stopping: No")
if not use_existing_labels:
console.print(f" Stockfish depth: {stockfish_depth}")
console.print(f" Workers: {num_workers}")
console.print(f" Stockfish: {stockfish_path}")
if use_checkpoint:
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(f" Checkpoint: None (training from scratch)")
if not Confirm.ask("\nStart training?", default=True):
console.print("[yellow]Training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
# Execute training
data_dir = get_data_dir()
weights_dir = get_weights_dir()
try:
# Generate positions
if not use_existing:
console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
count = play_random_game_and_collect_positions(
str(data_dir / "positions.txt"),
total_games=num_games,
samples_per_game=samples_per_game,
min_move=min_move,
max_move=max_move
)
if count == 0:
console.print("[red]✗ No valid positions generated[/red]")
return
console.print(f"[green]✓ Generated {count:,} positions[/green]")
else:
if not Path(positions_file).exists():
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
return
if not use_existing_labels:
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
positions_file = data_dir / "positions.txt"
output_file = data_dir / "training_data.jsonl"
success = label_positions_with_stockfish(
str(positions_file),
str(output_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
if not success:
console.print("[red]✗ Position labeling failed[/red]")
return
console.print(f"[green]✓ Positions labeled[/green]")
else:
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
output_file = labels_file
if not Path(output_file).exists():
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
return
# Train model
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
console.print("\n[bold cyan]Training Model[/bold cyan]")
checkpoint = None
if use_checkpoint:
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
train_nnue(
data_file=str(output_file),
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
epochs=epochs,
batch_size=batch_size,
@@ -286,6 +682,7 @@ def train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def burst_train_interactive():
"""Interactive burst training menu."""
console = Console()
@@ -294,14 +691,30 @@ def burst_train_interactive():
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
# Data file
default_labels = str(get_data_dir() / "training_data.jsonl")
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
# Optional initial checkpoint
available = list_checkpoints()
checkpoint = None
@@ -317,29 +730,25 @@ def burst_train_interactive():
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
console.print(f" Data file: {labels_file}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
if not Confirm.ask("\nStart burst training?", default=True):
console.print("[yellow]Burst training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
weights_dir = get_weights_dir()
try:
if not Path(labels_file).exists():
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold cyan]Burst Training[/bold cyan]")
burst_train(
data_file=labels_file,
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
duration_minutes=duration_minutes,
epochs_per_season=epochs_per_season,
@@ -362,6 +771,7 @@ def burst_train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def export_interactive():
"""Interactive export menu."""
console = Console()
@@ -380,7 +790,7 @@ def export_interactive():
version = Prompt.ask("Enter version to export (e.g., 2)")
weights_file = f"nnue_weights_v{version}.pt"
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
console.print(f"\n[bold]Export Configuration:[/bold]")
console.print(f" Source: {weights_file}")
@@ -399,7 +809,7 @@ def export_interactive():
return
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
export_weights_to_binary(str(weights_path), output_file)
export_to_nbai(str(weights_path), output_file)
console.print(f"\n[green]✓ Export complete![/green]")
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
Prompt.ask("Press Enter to continue")
@@ -410,65 +820,6 @@ def export_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def extract_tactical_interactive():
    """Interactive tactical positions extraction and merge menu.

    Downloads the Lichess puzzle database, then runs an interactive merge
    of extracted puzzle positions into a plain-text positions file.
    """
    console = Console()
    show_header()
    console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
    # Download and extract options
    console.print("\n[bold]Lichess Puzzle Database:[/bold]")
    download_url = Prompt.ask(
        "Download URL",
        default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
    )
    output_dir = Prompt.ask(
        "Extract to directory",
        default=str(Path(__file__).parent / "trainingdata")
    )
    max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
    # Confirm and download
    console.print("\n[bold]Configuration:[/bold]")
    console.print(f" Download URL: {download_url}")
    console.print(f" Extract directory: {output_dir}")
    console.print(f" Max puzzles: {max_puzzles:,}")
    if not Confirm.ask("\nProceed?", default=True):
        console.print("[yellow]Cancelled[/yellow]")
        Prompt.ask("Press Enter to continue")
        return
    try:
        console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
        # A falsy return means the download or decompression failed.
        csv_path = download_and_extract_puzzle_db(download_url, output_dir)
        if not csv_path:
            console.print("[red]✗ Failed to download/extract[/red]")
            Prompt.ask("Press Enter to continue")
            return
        console.print(f"[green]✓ Ready: {csv_path}[/green]")
        # Interactive merge
        console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
        output_file = Prompt.ask(
            "Output file path",
            default=str(Path(__file__).parent / "data" / "position.txt")
        )
        interactive_merge_positions(csv_path, output_file, max_puzzles)
        console.print(f"\n[green]✓ Complete![/green]")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        console.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def main():
try:
@@ -481,7 +832,10 @@ def main():
except Exception as e:
console = Console()
console.print(f"[red]Error:[/red] {e}")
import traceback
traceback.print_exc()
return 1
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())