feat: enhance nnue.py and train.py with options for existing labels and improved file handling

This commit is contained in:
2026-04-08 09:16:02 +02:00
committed by Janis
parent adc2de23bc
commit 34945e4fb8
2 changed files with 34 additions and 18 deletions
+28 -16
View File
@@ -131,12 +131,17 @@ def train_interactive():
num_games = 500000
if use_existing:
positions_file = Prompt.ask("Enter path to positions file")
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
else:
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
labels_file = None
if use_existing_labels:
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
# Stockfish path
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
# Training parameters
@@ -180,20 +185,27 @@ def train_interactive():
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
return
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
positions_file = data_dir / "positions.txt"
output_file = data_dir / "training_data.jsonl"
success = label_positions_with_stockfish(
str(positions_file),
str(output_file),
stockfish_path,
depth=12
)
if not success:
console.print("[red]✗ Position labeling failed[/red]")
return
console.print(f"[green]✓ Positions labeled[/green]")
if not use_existing_labels:
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
positions_file = data_dir / "positions.txt"
output_file = data_dir / "training_data.jsonl"
success = label_positions_with_stockfish(
str(positions_file),
str(output_file),
stockfish_path,
depth=12
)
if not success:
console.print("[red]✗ Position labeling failed[/red]")
return
console.print(f"[green]✓ Positions labeled[/green]")
else:
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
output_file = labels_file
if not Path(output_file).exists():
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
return
# Train model
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
+6 -2
View File
@@ -92,10 +92,14 @@ def find_next_version(base_name="nnue_weights"):
Looks for nnue_weights_v*.pt files and returns the next version number.
If no versioned files exist, returns 1.
"""
pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
base_path = Path(base_name)
directory = base_path.parent
filename = base_path.name
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
versions = []
for file in Path(".").glob(f"{base_name}_v*.pt"):
for file in directory.glob(f"{filename}_v*.pt"):
match = pattern.match(file.name)
if match:
versions.append(int(match.group(1)))