diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index ccb9585..156a89c 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -131,12 +131,17 @@ def train_interactive(): num_games = 500000 if use_existing: - positions_file = Prompt.ask("Enter path to positions file") + positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt")) else: num_games = int(Prompt.ask("Number of games to generate", default="500000")) + use_existing_labels = Confirm.ask("Use existing labels file?", default=False) + labels_file = None + if use_existing_labels: + labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl")) + # Stockfish path - default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish") + default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish") stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish) # Training parameters @@ -180,20 +185,27 @@ def train_interactive(): console.print(f"[red]✗ Positions file not found: {positions_file}[/red]") return - # Label positions - console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") - positions_file = data_dir / "positions.txt" - output_file = data_dir / "training_data.jsonl" - success = label_positions_with_stockfish( - str(positions_file), - str(output_file), - stockfish_path, - depth=12 - ) - if not success: - console.print("[red]✗ Position labeling failed[/red]") - return - console.print(f"[green]✓ Positions labeled[/green]") + if not use_existing_labels: + # Label positions + console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") + positions_file = data_dir / "positions.txt" + output_file = data_dir / "training_data.jsonl" + success = label_positions_with_stockfish( + str(positions_file), + str(output_file), + stockfish_path, + depth=12 + ) + if not success: + console.print("[red]✗ Position labeling failed[/red]") + return + console.print(f"[green]✓ Positions labeled[/green]") + else: + console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]") + output_file = labels_file + if not Path(output_file).exists(): + console.print(f"[red]✗ Labels file not found: {output_file}[/red]") + return # Train model console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]") diff --git a/modules/bot/python/src/train.py b/modules/bot/python/src/train.py index 9f9a83d..5fe4176 100644 --- a/modules/bot/python/src/train.py +++ b/modules/bot/python/src/train.py @@ -92,10 +92,14 @@ def find_next_version(base_name="nnue_weights"): Looks for nnue_weights_v*.pt files and returns the next version number. If no versioned files exist, returns 1. """ - pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pt") + base_path = Path(base_name) + directory = base_path.parent + filename = base_path.name + + pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt") versions = [] - for file in Path(".").glob(f"{base_name}_v*.pt"): + for file in directory.glob(f"{filename}_v*.pt"): match = pattern.match(file.name) if match: versions.append(int(match.group(1)))