feat: enhance training and evaluation processes with new parameters and normalization options
This commit is contained in:
@@ -129,11 +129,17 @@ def train_interactive():
|
||||
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||
positions_file = None
|
||||
num_games = 500000
|
||||
samples_per_game = 1
|
||||
min_move = 1
|
||||
max_move = 50
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
|
||||
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
|
||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||
labels_file = None
|
||||
@@ -147,6 +153,9 @@ def train_interactive():
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
@@ -154,9 +163,18 @@ def train_interactive():
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(" Checkpoint: None (training from scratch)")
|
||||
console.print(f" Games: {num_games:,}")
|
||||
if not use_existing:
|
||||
console.print(f" Games: {num_games:,}")
|
||||
console.print(f" Samples per game: {samples_per_game}")
|
||||
console.print(f" Move range: {min_move}-{max_move}")
|
||||
else:
|
||||
console.print(f" Positions file: {positions_file}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
console.print(f" Stockfish: {stockfish_path}")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
@@ -174,7 +192,9 @@ def train_interactive():
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(data_dir / "positions.txt"),
|
||||
total_games=num_games,
|
||||
filter_captures=True
|
||||
samples_per_game=samples_per_game,
|
||||
min_move=min_move,
|
||||
max_move=max_move
|
||||
)
|
||||
if count == 0:
|
||||
console.print("[red]✗ No valid positions generated[/red]")
|
||||
@@ -219,7 +239,8 @@ def train_interactive():
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user