feat: enhance nnue.py and train.py with options for existing labels and improved file handling
This commit is contained in:
+28
-16
@@ -131,12 +131,17 @@ def train_interactive():
|
||||
num_games = 500000
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file")
|
||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||
|
||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||
labels_file = None
|
||||
if use_existing_labels:
|
||||
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
|
||||
|
||||
# Stockfish path
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
|
||||
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||
|
||||
# Training parameters
|
||||
@@ -180,20 +185,27 @@ def train_interactive():
|
||||
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
|
||||
return
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
if not use_existing_labels:
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
else:
|
||||
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
|
||||
output_file = labels_file
|
||||
if not Path(output_file).exists():
|
||||
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
|
||||
return
|
||||
|
||||
# Train model
|
||||
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
|
||||
|
||||
Reference in New Issue
Block a user