feat: add metadata files for NNUE weights versions and implement burst training functionality with subsampling options
This commit is contained in:
+92
-10
@@ -16,7 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue
|
||||
from train import train_nnue, burst_train
|
||||
from export import export_weights_to_binary
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
@@ -97,24 +97,27 @@ def show_main_menu():
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||
console.print("[cyan]2[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]3[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]4[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]5[/cyan] - Exit")
|
||||
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
|
||||
console.print("[cyan]3[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]5[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]6[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
export_interactive()
|
||||
burst_train_interactive()
|
||||
elif choice == "3":
|
||||
extract_tactical_interactive()
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
extract_tactical_interactive()
|
||||
elif choice == "5":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "5":
|
||||
elif choice == "6":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
@@ -172,6 +175,7 @@ def train_interactive():
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
@@ -190,6 +194,7 @@ def train_interactive():
|
||||
console.print(f" Positions file: {positions_file}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
@@ -263,7 +268,8 @@ def train_interactive():
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping
|
||||
early_stopping_patience=early_stopping,
|
||||
subsample_ratio=subsample_ratio
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
@@ -280,6 +286,82 @@ def train_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def burst_train_interactive():
|
||||
"""Interactive burst training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||
|
||||
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||
|
||||
# Data file
|
||||
default_labels = str(get_data_dir() / "training_data.jsonl")
|
||||
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
|
||||
|
||||
# Optional initial checkpoint
|
||||
available = list_checkpoints()
|
||||
checkpoint = None
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
if Confirm.ask("Start from an existing checkpoint?", default=False):
|
||||
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
|
||||
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
|
||||
|
||||
# Training hyperparameters
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
console.print(f" Data file: {labels_file}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
|
||||
if not Confirm.ask("\nStart burst training?", default=True):
|
||||
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||
return
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
if not Path(labels_file).exists():
|
||||
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||
burst_train(
|
||||
data_file=labels_file,
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
duration_minutes=duration_minutes,
|
||||
epochs_per_season=epochs_per_season,
|
||||
early_stopping_patience=early_stopping_patience,
|
||||
batch_size=batch_size,
|
||||
initial_checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
subsample_ratio=subsample_ratio,
|
||||
)
|
||||
console.print("[green]✓ Burst training complete[/green]")
|
||||
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
|
||||
Reference in New Issue
Block a user