feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options

This commit is contained in:
2026-04-10 02:16:01 +02:00
parent 5d4cf5f13c
commit 01c8a0f8fe
14 changed files with 465 additions and 134 deletions
+92 -10
View File
@@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from train import train_nnue, burst_train
from export import export_weights_to_binary
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
@@ -97,24 +97,27 @@ def show_main_menu():
console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Export Weights to Scala")
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
console.print("[cyan]4[/cyan] - View Checkpoints")
console.print("[cyan]5[/cyan] - Exit")
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
console.print("[cyan]3[/cyan] - Export Weights to Scala")
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
console.print("[cyan]5[/cyan] - View Checkpoints")
console.print("[cyan]6[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
if choice == "1":
train_interactive()
elif choice == "2":
export_interactive()
burst_train_interactive()
elif choice == "3":
extract_tactical_interactive()
export_interactive()
elif choice == "4":
extract_tactical_interactive()
elif choice == "5":
show_header()
show_checkpoints_table()
Prompt.ask("\nPress Enter to continue")
elif choice == "5":
elif choice == "6":
console.print("[yellow]👋 Goodbye![/yellow]")
return
@@ -172,6 +175,7 @@ def train_interactive():
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -190,6 +194,7 @@ def train_interactive():
console.print(f" Positions file: {positions_file}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if early_stopping:
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
@@ -263,7 +268,8 @@ def train_interactive():
batch_size=batch_size,
checkpoint=checkpoint,
use_versioning=True,
early_stopping_patience=early_stopping
early_stopping_patience=early_stopping,
subsample_ratio=subsample_ratio
)
console.print("[green]✓ Training complete[/green]")
@@ -280,6 +286,82 @@ def train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def burst_train_interactive():
    """Interactively configure and launch a burst-training run.

    Collects a wall-clock budget, per-season epoch cap, early-stopping
    patience, the labeled-data path, an optional starting checkpoint and
    hyperparameters, echoes a summary, then calls ``burst_train`` after
    confirmation. All results are reported on the console; errors print a
    full traceback and return to the menu.
    """
    ui = Console()
    show_header()
    ui.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
    ui.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")

    # Schedule parameters.
    budget_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
    season_epochs = int(Prompt.ask("Max epochs per season", default="50"))
    patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))

    # Labeled training data file.
    data_path = Prompt.ask(
        "Path to labeled data file (.jsonl)",
        default=str(get_data_dir() / "training_data.jsonl"),
    )

    # Optionally resume from a previously saved checkpoint version.
    versions = list_checkpoints()
    start_checkpoint = None
    if versions:
        listing = ', '.join([f'v{v}' for v in sorted(versions)])
        ui.print(f"\n[dim]Available checkpoints: {listing}[/dim]")
        if Confirm.ask("Start from an existing checkpoint?", default=False):
            chosen = Prompt.ask("Enter checkpoint version", default=str(max(versions)))
            start_checkpoint = str(get_weights_dir() / f"nnue_weights_v{chosen}.pt")

    # Remaining training hyperparameters.
    batch = int(Prompt.ask("Batch size", default="16384"))
    ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))

    # Echo the full configuration before committing to a long run.
    ui.print("\n[bold]Configuration Summary:[/bold]")
    ui.print(f" Duration: {budget_minutes:.0f} minutes")
    ui.print(f" Epochs per season: {season_epochs}")
    ui.print(f" Patience: {patience}")
    ui.print(f" Data file: {data_path}")
    ui.print(f" Checkpoint: {start_checkpoint or 'None (from scratch)'}")
    ui.print(f" Batch size: {batch}")
    ui.print(f" Subsample ratio: {ratio:.0%}")

    if not Confirm.ask("\nStart burst training?", default=True):
        ui.print("[yellow]Burst training cancelled[/yellow]")
        return

    weights_dir = get_weights_dir()
    try:
        # Bail out early if the data file is missing rather than failing mid-run.
        if not Path(data_path).exists():
            ui.print(f"[red]✗ Data file not found: {data_path}[/red]")
            Prompt.ask("Press Enter to continue")
            return

        ui.print("\n[bold cyan]Burst Training[/bold cyan]")
        burst_train(
            data_file=data_path,
            output_file=str(weights_dir / "nnue_weights.pt"),
            duration_minutes=budget_minutes,
            epochs_per_season=season_epochs,
            early_stopping_patience=patience,
            batch_size=batch,
            initial_checkpoint=start_checkpoint,
            use_versioning=True,
            subsample_ratio=ratio,
        )
        ui.print("[green]✓ Burst training complete[/green]")

        # Report the newest checkpoint produced by the run, if any exist.
        versions = list_checkpoints()
        if versions:
            ui.print(f"[bold]Latest checkpoint: v{max(versions)}[/bold]")
        Prompt.ask("Press Enter to continue")
    except Exception as e:
        # Surface the full traceback so interactive users can report the failure.
        ui.print(f"[red]✗ Error: {e}[/red]")
        import traceback
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")
def export_interactive():
"""Interactive export menu."""
console = Console()