feat: enhance nnue.py and train.py with options for existing labels and improved file handling
This commit is contained in:
+28
-16
@@ -131,12 +131,17 @@ def train_interactive():
|
||||
num_games = 500000
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file")
|
||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||
|
||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||
labels_file = None
|
||||
if use_existing_labels:
|
||||
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
|
||||
|
||||
# Stockfish path
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
|
||||
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||
|
||||
# Training parameters
|
||||
@@ -180,20 +185,27 @@ def train_interactive():
|
||||
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
|
||||
return
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
if not use_existing_labels:
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=12
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
else:
|
||||
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
|
||||
output_file = labels_file
|
||||
if not Path(output_file).exists():
|
||||
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
|
||||
return
|
||||
|
||||
# Train model
|
||||
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
|
||||
|
||||
@@ -92,10 +92,14 @@ def find_next_version(base_name="nnue_weights"):
|
||||
Looks for nnue_weights_v*.pt files and returns the next version number.
|
||||
If no versioned files exist, returns 1.
|
||||
"""
|
||||
pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt")
|
||||
base_path = Path(base_name)
|
||||
directory = base_path.parent
|
||||
filename = base_path.name
|
||||
|
||||
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
|
||||
versions = []
|
||||
|
||||
for file in Path(".").glob(f"{base_name}_v*.pt"):
|
||||
for file in directory.glob(f"{filename}_v*.pt"):
|
||||
match = pattern.match(file.name)
|
||||
if match:
|
||||
versions.append(int(match.group(1)))
|
||||
|
||||
Reference in New Issue
Block a user