diff --git a/modules/bot/python/.gitignore b/modules/bot/python/.gitignore index 54e198c..cf3ffd9 100644 --- a/modules/bot/python/.gitignore +++ b/modules/bot/python/.gitignore @@ -19,3 +19,4 @@ ENV/ *.swo tactical_data/ trainingdata/ +/datasets/ diff --git a/modules/bot/python/DATASETS.md b/modules/bot/python/DATASETS.md new file mode 100644 index 0000000..627634b --- /dev/null +++ b/modules/bot/python/DATASETS.md @@ -0,0 +1,173 @@ +# Training Dataset Management + +The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations. + +## Directory Structure + +``` +datasets/ + ds_v1/ + labeled.jsonl # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150} + metadata.json # Version info and composition + ds_v2/ + labeled.jsonl + metadata.json +``` + +## Metadata Schema + +Each dataset has a `metadata.json` file tracking its composition: + +```json +{ + "version": 1, + "created": "2026-04-13T15:30:45.123456", + "total_positions": 1000000, + "stockfish_depth": 12, + "sources": [ + { + "type": "generated", + "count": 500000, + "params": { + "num_positions": 500000, + "min_move": 1, + "max_move": 50 + } + }, + { + "type": "tactical", + "count": 300000, + "max_puzzles": 300000 + }, + { + "type": "file_import", + "count": 200000, + "path": "/path/to/original_file.txt" + } + ] +} +``` + +## TUI Workflow + +### Main Menu +``` +1 - Manage Training Data +2 - Train Model +3 - Export Model +4 - Exit +``` + +### Training Data Management Submenu +``` +1 - Create new dataset +2 - Extend existing dataset +3 - View all datasets +4 - Delete dataset +5 - Back +``` + +## Creating a Dataset + +Use the "Create new dataset" option to add data from one or more sources: + +1. **Generate random positions** — Play random games and sample positions + - Number of positions + - Move range (min/max move number to sample from) + - Number of worker threads + +2. **Import from file** — Load positions from a FEN file + - File must contain one FEN string per line + - Duplicates are automatically removed + +3. **Extract tactical puzzles** — Download and extract Lichess puzzle database + - Maximum number of puzzles to include + - Automatically filters for tactical themes (forks, pins, mates, etc.) + +You can combine multiple sources in a single dataset creation session. All positions are: +- Deduplicated (only unique FENs are kept) +- Labeled with Stockfish evaluations +- Saved to `datasets/ds_vN/labeled.jsonl` + +## Extending a Dataset + +Use "Extend existing dataset" to add more positions to an existing dataset: + +1. Select the dataset version to extend +2. Choose data sources (same options as creation) +3. Confirm labeling parameters +4. New positions are: + - Labeled with Stockfish + - Deduplicated against the target dataset (preventing duplicates) + - Merged into the existing `labeled.jsonl` + - Metadata is updated with the new source entry + +## Training with a Dataset + +When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with: +- Version number +- Total number of positions +- Source types (generated, tactical, imported) +- Stockfish depth used +- Creation date + +## Legacy Data Migration + +If you have existing labeled data in `data/training_data.jsonl` from before this update: + +1. Open the "Manage Training Data" menu +2. Choose "Create new dataset" +3. Select "Import from file" +4. 
Point to `data/training_data.jsonl` +5. Complete the dataset creation + +Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file. + +## Viewing Dataset Details + +Use "View all datasets" to see a table of all datasets with: +- Version number +- Position count +- Source composition +- Stockfish depth +- Creation date + +## Deleting a Dataset + +Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.** + +⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly. + +## Technical Details + +### Deduplication Strategy + +When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired. + +When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling. + +### Labeled Position Format + +Each line in `labeled.jsonl` is a JSON object: +```json +{ + "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + "eval": 0.0, + "eval_raw": 0 +} +``` + +- `fen`: The position in Forsyth-Edwards Notation +- `eval`: Normalized evaluation ([-1, 1] range using tanh) +- `eval_raw`: Raw Stockfish evaluation in centipawns + +### Storage Location + +Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system. + +## Performance Tips + +- **Smaller datasets train faster** — Start with 100k-500k positions +- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data +- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed +- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index bd5dc47..588c7c4 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -4,6 +4,7 @@ import os import shutil import sys +import tempfile from pathlib import Path from rich.console import Console from rich.table import Table @@ -17,23 +18,23 @@ sys.path.insert(0, str(Path(__file__).parent / "src")) from generate import play_random_game_and_collect_positions from label import label_positions_with_stockfish from train import train_nnue, burst_train -from export import export_weights_to_binary +from export import export_to_nbai from tactical_positions_extractor import ( download_and_extract_puzzle_db, - interactive_merge_positions + extract_tactical_only +) +from dataset import ( + get_datasets_dir, + list_datasets, + next_dataset_version, + load_dataset_metadata, + create_dataset, + extend_dataset, + get_dataset_labeled_path, + delete_dataset, + show_datasets_table ) -def get_data_dir(): - """Get/create data directory.""" - data_dir = Path(__file__).parent / "data" - data_dir.mkdir(exist_ok=True) - return data_dir - -def get_tactical_data_dir(): - """Get/create data directory.""" - data_dir = Path(__file__).parent / "tactical_data" - data_dir.mkdir(exist_ok=True) - return data_dir def get_weights_dir(): """Get/create weights directory.""" @@ -41,6 +42,14 @@ def get_weights_dir(): weights_dir.mkdir(exist_ok=True) return weights_dir + +def get_data_dir(): + """Get/create legacy data directory (for migration).""" + data_dir = Path(__file__).parent / "data" + data_dir.mkdir(exist_ok=True) + return data_dir + + def 
list_checkpoints(): """List available checkpoint versions.""" weights_dir = get_weights_dir() @@ -49,6 +58,20 @@ def list_checkpoints(): return [] return [int(cp.stem.split("_v")[1]) for cp in checkpoints] + +def migrate_legacy_data(): + """On first run, offer to import existing data/training_data.jsonl as ds_v1.""" + console = Console() + data_dir = get_data_dir() + legacy_file = data_dir / "training_data.jsonl" + datasets = list_datasets() + + # Only migrate if legacy data exists and no datasets exist yet + if legacy_file.exists() and not datasets: + console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]") + console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]") + + def show_header(): """Display application header.""" console = Console() @@ -56,22 +79,23 @@ def show_header(): console.print( Panel( "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n" - "[dim]Neural Network Utility Evaluation - Model Management[/dim]", + "[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]", border_style="cyan", padding=(1, 2), ) ) + def show_checkpoints_table(): """Display available checkpoints in a table.""" console = Console() available = list_checkpoints() if not available: - console.print("[yellow]ℹ No checkpoints found yet[/yellow]") + console.print("[yellow]ℹ No model checkpoints found yet[/yellow]") return - table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan") + table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan") table.add_column("Version", style="dim") table.add_column("File Size", justify="right") table.add_column("Status", justify="center") @@ -87,46 +111,500 @@ def show_checkpoints_table(): console.print(table) + def show_main_menu(): """Display and handle main menu.""" console = Console() + # Migrate legacy data on first run + migrate_legacy_data() + while True: show_header() show_checkpoints_table() console.print("\n[bold]What would you like to do?[/bold]") - console.print("[cyan]1[/cyan] - Train NNUE Model") - console.print("[cyan]2[/cyan] - Burst Train NNUE Model") - console.print("[cyan]3[/cyan] - Export Weights to Scala") - console.print("[cyan]4[/cyan] - Extract Tactical Positions") - console.print("[cyan]5[/cyan] - View Checkpoints") - console.print("[cyan]6[/cyan] - Exit") + console.print("[cyan]1[/cyan] - Manage Training Data") + console.print("[cyan]2[/cyan] - Train Model") + console.print("[cyan]3[/cyan] - Export Model") + console.print("[cyan]4[/cyan] - Exit") - choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"]) + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"]) + + if choice == "1": + datasets_menu() + elif choice == "2": + training_menu() + elif choice == "3": + export_interactive() + elif choice == "4": + console.print("[yellow]👋 Goodbye![/yellow]") + return + + +def datasets_menu(): + """Dataset management submenu.""" + console = Console() + + while True: + show_header() + show_datasets_table(console) + + console.print("\n[bold]Training Data Management[/bold]") + console.print("[cyan]1[/cyan] - Create new dataset") + console.print("[cyan]2[/cyan] - Extend existing dataset") + console.print("[cyan]3[/cyan] - View all datasets") + console.print("[cyan]4[/cyan] - Delete dataset") + console.print("[cyan]5[/cyan] - Back") + + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"]) + + if choice == "1": + create_dataset_interactive() + elif choice == "2": + 
extend_dataset_interactive() + elif choice == "3": + show_header() + show_datasets_table(console) + Prompt.ask("\nPress Enter to continue") + elif choice == "4": + delete_dataset_interactive() + elif choice == "5": + return + + +def create_dataset_interactive(): + """Interactive dataset creation flow.""" + console = Console() + show_header() + + console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]") + + sources = [] + combined_count = 0 + + # Allow user to add multiple sources + while True: + console.print("\n[bold]Add data source (repeat until done):[/bold]") + console.print("[cyan]a[/cyan] - Generate random positions") + console.print("[cyan]b[/cyan] - Import from file") + console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles") + console.print("[cyan]d[/cyan] - Done adding sources") + + choice = Prompt.ask("Select", choices=["a", "b", "c", "d"]) + + if choice == "a": + num_positions = int(Prompt.ask("Number of positions to generate", default="100000")) + min_move = int(Prompt.ask("Minimum move number", default="1")) + max_move = int(Prompt.ask("Maximum move number", default="50")) + num_workers = int(Prompt.ask("Number of workers", default="8")) + + console.print("[dim]Generating positions...[/dim]") + temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt" + count = play_random_game_and_collect_positions( + str(temp_file), + total_positions=num_positions, + samples_per_game=1, + min_move=min_move, + max_move=max_move, + num_workers=num_workers + ) + if count > 0: + sources.append({ + "type": "generated", + "count": count, + "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move} + }) + combined_count += count + console.print(f"[green]✓ {count:,} positions generated[/green]") + else: + console.print("[red]✗ Generation failed[/red]") + + elif choice == "b": + file_path = Prompt.ask("Path to FEN file") + try: + with open(file_path, 'r') as f: + count = sum(1 for _ in f) + sources.append({"type": "file_import", "count": count, "path": file_path}) + combined_count += count + console.print(f"[green]✓ {count:,} positions from file[/green]") + except FileNotFoundError: + console.print(f"[red]✗ File not found: {file_path}[/red]") + + elif choice == "c": + max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000")) + console.print("[dim]Extracting tactical positions...[/dim]") + temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt" + try: + csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data")) + if csv_path: + count = extract_tactical_only(csv_path, str(temp_file), max_puzzles) + sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles}) + combined_count += count + console.print(f"[green]✓ {count:,} tactical positions extracted[/green]") + except Exception as e: + console.print(f"[red]✗ Tactical extraction failed: {e}[/red]") + + elif choice == "d": + if not sources: + console.print("[yellow]⚠ No sources added yet[/yellow]") + continue + break + + if not sources: + console.print("[yellow]Dataset creation cancelled[/yellow]") + return + + # Stockfish labeling parameters + console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]") + stockfish_path = Prompt.ask( + "Stockfish path", + default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish" + ) + stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12")) + num_workers = int(Prompt.ask("Number of parallel workers", default="1")) + + # Summary 
and confirm + console.print("\n[bold]Dataset Summary:[/bold]") + console.print(f" Total positions: {combined_count:,}") + for source in sources: + console.print(f" - {source['type']}: {source['count']:,}") + console.print(f" Stockfish depth: {stockfish_depth}") + + if not Confirm.ask("\nProceed to label and create dataset?", default=True): + console.print("[yellow]Cancelled[/yellow]") + return + + try: + # Combine all sources into one FEN file + console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]") + combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt" + all_fens = set() + + for source in sources: + if source['type'] == 'generated': + temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt" + elif source['type'] == 'file_import': + temp_file = Path(source['path']) + elif source['type'] == 'tactical': + temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt" + + if temp_file.exists(): + with open(temp_file, 'r') as f: + for line in f: + fen = line.strip() + if fen: + all_fens.add(fen) + + with open(combined_fen_file, 'w') as f: + for fen in all_fens: + f.write(fen + '\n') + console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]") + + # Label positions + console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") + labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl" + success = label_positions_with_stockfish( + str(combined_fen_file), + str(labeled_file), + stockfish_path, + depth=stockfish_depth, + num_workers=num_workers + ) + + if not success: + console.print("[red]✗ Labeling failed[/red]") + return + + console.print("[green]✓ Positions labeled[/green]") + + # Save dataset + console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]") + version = next_dataset_version() + create_dataset( + version=version, + labeled_jsonl_path=str(labeled_file), + sources=sources, + stockfish_depth=stockfish_depth + ) + console.print(f"[green]✓ Dataset created: ds_v{version}[/green]") + console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]") + + Prompt.ask("\nPress Enter to continue") + + except Exception as e: + console.print(f"[red]✗ Error: {e}[/red]") + import traceback + traceback.print_exc() + Prompt.ask("Press Enter to continue") + + +def extend_dataset_interactive(): + """Interactive dataset extension flow.""" + console = Console() + show_header() + + console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]") + + datasets = list_datasets() + if not datasets: + console.print("[yellow]ℹ No datasets available to extend[/yellow]") + Prompt.ask("Press Enter to continue") + return + + show_datasets_table(console) + version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)")) + + if not any(v == version for v, _ in datasets): + console.print("[red]✗ Dataset not found[/red]") + return + + sources = [] + combined_count = 0 + + # Allow user to add sources + while True: + console.print("\n[bold]Add data source:[/bold]") + console.print("[cyan]a[/cyan] - Generate random positions") + console.print("[cyan]b[/cyan] - Import from file") + console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles") + console.print("[cyan]d[/cyan] - Done adding sources") + + choice = Prompt.ask("Select", choices=["a", "b", "c", "d"]) + + if choice == "a": + num_positions = int(Prompt.ask("Number of positions to generate", default="100000")) + min_move = int(Prompt.ask("Minimum move number", default="1")) + max_move = int(Prompt.ask("Maximum move number", default="50")) + num_workers = 
int(Prompt.ask("Number of workers", default="8")) + + console.print("[dim]Generating positions...[/dim]") + temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt" + count = play_random_game_and_collect_positions( + str(temp_file), + total_positions=num_positions, + samples_per_game=1, + min_move=min_move, + max_move=max_move, + num_workers=num_workers + ) + if count > 0: + sources.append({ + "type": "generated", + "count": count, + "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move} + }) + combined_count += count + console.print(f"[green]✓ {count:,} positions generated[/green]") + + elif choice == "b": + file_path = Prompt.ask("Path to FEN file") + try: + with open(file_path, 'r') as f: + count = sum(1 for _ in f) + sources.append({"type": "file_import", "count": count, "path": file_path}) + combined_count += count + console.print(f"[green]✓ {count:,} positions from file[/green]") + except FileNotFoundError: + console.print(f"[red]✗ File not found: {file_path}[/red]") + + elif choice == "c": + max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000")) + console.print("[dim]Extracting tactical positions...[/dim]") + temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt" + try: + csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data")) + if csv_path: + count = extract_tactical_only(csv_path, str(temp_file), max_puzzles) + sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles}) + combined_count += count + console.print(f"[green]✓ {count:,} tactical positions extracted[/green]") + except Exception as e: + console.print(f"[red]✗ Extraction failed: {e}[/red]") + + elif choice == "d": + if not sources: + console.print("[yellow]⚠ No sources added yet[/yellow]") + continue + break + + if not sources: + console.print("[yellow]Extension cancelled[/yellow]") + return + + # Stockfish labeling parameters + console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]") + stockfish_path = Prompt.ask( + "Stockfish path", + default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish" + ) + stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12")) + num_workers = int(Prompt.ask("Number of parallel workers", default="1")) + + # Summary and confirm + console.print("\n[bold]Extension Summary:[/bold]") + console.print(f" Target dataset: ds_v{version}") + console.print(f" New positions: {combined_count:,}") + for source in sources: + console.print(f" - {source['type']}: {source['count']:,}") + + if not Confirm.ask("\nProceed to label and extend dataset?", default=True): + console.print("[yellow]Cancelled[/yellow]") + return + + try: + # Combine all sources into one FEN file + console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]") + combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt" + all_fens = set() + + for source in sources: + if source['type'] == 'generated': + temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt" + elif source['type'] == 'file_import': + temp_file = Path(source['path']) + elif source['type'] == 'tactical': + temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt" + + if temp_file.exists(): + with open(temp_file, 'r') as f: + for line in f: + fen = line.strip() + if fen: + all_fens.add(fen) + + with open(combined_fen_file, 'w') as f: + for fen in all_fens: + f.write(fen + '\n') + console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]") 
+ + # Label positions + console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") + labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl" + success = label_positions_with_stockfish( + str(combined_fen_file), + str(labeled_file), + stockfish_path, + depth=stockfish_depth, + num_workers=num_workers + ) + + if not success: + console.print("[red]✗ Labeling failed[/red]") + return + + console.print("[green]✓ Positions labeled[/green]") + + # Extend dataset + console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]") + success = extend_dataset( + version=version, + new_labeled_path=str(labeled_file), + new_source_entry={ + "type": "merged_sources", + "count": len(all_fens), + "sources": sources + } + ) + + if success: + metadata = load_dataset_metadata(version) + console.print(f"[green]✓ Dataset extended[/green]") + console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]") + else: + console.print("[red]✗ Extension failed[/red]") + + Prompt.ask("\nPress Enter to continue") + + except Exception as e: + console.print(f"[red]✗ Error: {e}[/red]") + import traceback + traceback.print_exc() + Prompt.ask("Press Enter to continue") + + +def delete_dataset_interactive(): + """Interactive dataset deletion.""" + console = Console() + show_header() + + console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]") + + datasets = list_datasets() + if not datasets: + console.print("[yellow]ℹ No datasets to delete[/yellow]") + Prompt.ask("Press Enter to continue") + return + + show_datasets_table(console) + version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)")) + + if not any(v == version for v, _ in datasets): + console.print("[red]✗ Dataset not found[/red]") + return + + if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False): + if delete_dataset(version): + console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]") + else: + console.print("[red]✗ Deletion failed[/red]") + + Prompt.ask("Press Enter to continue") + + +def training_menu(): + """Training submenu.""" + console = Console() + + while True: + show_header() + + console.print("\n[bold]Training[/bold]") + console.print("[cyan]1[/cyan] - Standard Training") + console.print("[cyan]2[/cyan] - Burst Training") + console.print("[cyan]3[/cyan] - View Model Checkpoints") + console.print("[cyan]4[/cyan] - Back") + + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"]) if choice == "1": train_interactive() elif choice == "2": burst_train_interactive() elif choice == "3": - export_interactive() - elif choice == "4": - extract_tactical_interactive() - elif choice == "5": show_header() show_checkpoints_table() Prompt.ask("\nPress Enter to continue") - elif choice == "6": - console.print("[yellow]👋 Goodbye![/yellow]") + elif choice == "4": return + def train_interactive(): """Interactive training menu.""" console = Console() show_header() - console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]") + console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]") + + # Dataset selection + datasets = list_datasets() + if not datasets: + console.print("[red]✗ No datasets available. 
Create one first.[/red]") + Prompt.ask("Press Enter to continue") + return + + console.print("\n[bold]Available Datasets:[/bold]") + show_datasets_table(console) + dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)")) + + if not any(v == dataset_version for v, _ in datasets): + console.print("[red]✗ Dataset not found[/red]") + return + + labeled_file = get_dataset_labeled_path(dataset_version) + if not labeled_file: + console.print("[red]✗ Dataset labeled.jsonl not found[/red]") + return # Checkpoint selection available = list_checkpoints() @@ -142,36 +620,6 @@ def train_interactive(): default=str(max(available)) ) - # Positions source - use_existing = Confirm.ask("Use existing positions file?", default=False) - positions_file = None - num_games = 500000 - samples_per_game = 1 - min_move = 1 - max_move = 50 - - if use_existing: - positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt")) - else: - num_games = int(Prompt.ask("Number of games to generate", default="5000")) - samples_per_game = int(Prompt.ask("Positions to sample per game", default="1")) - min_move = int(Prompt.ask("Minimum move number", default="1")) - max_move = int(Prompt.ask("Maximum move number", default="50")) - - use_existing_labels = Confirm.ask("Use existing labels file?", default=False) - labels_file = None - if use_existing_labels: - labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl")) - - # Stockfish path and labeling parameters - default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish" - stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish) - stockfish_depth = 12 - num_workers = 1 - if not use_existing_labels: - stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12")) - num_workers = int(Prompt.ask("Number of parallel workers", default="1")) - # Training parameters epochs = int(Prompt.ask("Number of epochs", default="100")) batch_size = int(Prompt.ask("Batch size", default="16384")) @@ -182,16 +630,7 @@ def train_interactive(): # Confirm and start console.print("\n[bold]Configuration Summary:[/bold]") - if use_checkpoint: - console.print(f" Checkpoint: v{checkpoint_version}") - else: - console.print(" Checkpoint: None (training from scratch)") - if not use_existing: - console.print(f" Games: {num_games:,}") - console.print(f" Samples per game: {samples_per_game}") - console.print(f" Move range: {min_move}-{max_move}") - else: - console.print(f" Positions file: {positions_file}") + console.print(f" Dataset: ds_v{dataset_version}") console.print(f" Epochs: {epochs}") console.print(f" Batch size: {batch_size}") console.print(f" Subsample ratio: {subsample_ratio:.0%}") @@ -199,70 +638,27 @@ def train_interactive(): console.print(f" Early stopping: Yes (patience: {early_stopping})") else: console.print(f" Early stopping: No") - if not use_existing_labels: - console.print(f" Stockfish depth: {stockfish_depth}") - console.print(f" Workers: {num_workers}") - console.print(f" Stockfish: {stockfish_path}") + if use_checkpoint: + console.print(f" Checkpoint: v{checkpoint_version}") + else: + console.print(f" Checkpoint: None (training from scratch)") if not Confirm.ask("\nStart training?", default=True): console.print("[yellow]Training cancelled[/yellow]") + Prompt.ask("Press Enter to continue") return # Execute training - data_dir = get_data_dir() weights_dir = get_weights_dir() try: - # Generate positions - 
if not use_existing: - console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]") - count = play_random_game_and_collect_positions( - str(data_dir / "positions.txt"), - total_games=num_games, - samples_per_game=samples_per_game, - min_move=min_move, - max_move=max_move - ) - if count == 0: - console.print("[red]✗ No valid positions generated[/red]") - return - console.print(f"[green]✓ Generated {count:,} positions[/green]") - else: - if not Path(positions_file).exists(): - console.print(f"[red]✗ Positions file not found: {positions_file}[/red]") - return - - if not use_existing_labels: - # Label positions - console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]") - positions_file = data_dir / "positions.txt" - output_file = data_dir / "training_data.jsonl" - success = label_positions_with_stockfish( - str(positions_file), - str(output_file), - stockfish_path, - depth=stockfish_depth, - num_workers=num_workers - ) - if not success: - console.print("[red]✗ Position labeling failed[/red]") - return - console.print(f"[green]✓ Positions labeled[/green]") - else: - console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]") - output_file = labels_file - if not Path(output_file).exists(): - console.print(f"[red]✗ Labels file not found: {output_file}[/red]") - return - - # Train model - console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]") + console.print("\n[bold cyan]Training Model[/bold cyan]") checkpoint = None if use_checkpoint: checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt") train_nnue( - data_file=str(output_file), + data_file=str(labeled_file), output_file=str(weights_dir / "nnue_weights.pt"), epochs=epochs, batch_size=batch_size, @@ -286,6 +682,7 @@ def train_interactive(): traceback.print_exc() Prompt.ask("Press Enter to continue") + def burst_train_interactive(): """Interactive burst training menu.""" console = Console() @@ -294,14 +691,30 @@ def burst_train_interactive(): console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]") console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n") + # Dataset selection + datasets = list_datasets() + if not datasets: + console.print("[red]✗ No datasets available. 
Create one first.[/red]") + Prompt.ask("Press Enter to continue") + return + + console.print("[bold]Available Datasets:[/bold]") + show_datasets_table(console) + dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)")) + + if not any(v == dataset_version for v, _ in datasets): + console.print("[red]✗ Dataset not found[/red]") + return + + labeled_file = get_dataset_labeled_path(dataset_version) + if not labeled_file: + console.print("[red]✗ Dataset labeled.jsonl not found[/red]") + return + duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60")) epochs_per_season = int(Prompt.ask("Max epochs per season", default="50")) early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10")) - # Data file - default_labels = str(get_data_dir() / "training_data.jsonl") - labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels) - # Optional initial checkpoint available = list_checkpoints() checkpoint = None @@ -317,29 +730,25 @@ def burst_train_interactive(): # Summary console.print("\n[bold]Configuration Summary:[/bold]") + console.print(f" Dataset: ds_v{dataset_version}") console.print(f" Duration: {duration_minutes:.0f} minutes") console.print(f" Epochs per season: {epochs_per_season}") console.print(f" Patience: {early_stopping_patience}") - console.print(f" Data file: {labels_file}") - console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}") console.print(f" Batch size: {batch_size}") console.print(f" Subsample ratio: {subsample_ratio:.0%}") + console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}") if not Confirm.ask("\nStart burst training?", default=True): console.print("[yellow]Burst training cancelled[/yellow]") + Prompt.ask("Press Enter to continue") return weights_dir = get_weights_dir() try: - if not Path(labels_file).exists(): - console.print(f"[red]✗ Data file not found: {labels_file}[/red]") - Prompt.ask("Press Enter to continue") - return - console.print("\n[bold cyan]Burst Training[/bold cyan]") burst_train( - data_file=labels_file, + data_file=str(labeled_file), output_file=str(weights_dir / "nnue_weights.pt"), duration_minutes=duration_minutes, epochs_per_season=epochs_per_season, @@ -362,6 +771,7 @@ def burst_train_interactive(): traceback.print_exc() Prompt.ask("Press Enter to continue") + def export_interactive(): """Interactive export menu.""" console = Console() @@ -380,7 +790,7 @@ def export_interactive(): version = Prompt.ask("Enter version to export (e.g., 2)") weights_file = f"nnue_weights_v{version}.pt" - output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin") + output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai") console.print(f"\n[bold]Export Configuration:[/bold]") console.print(f" Source: {weights_file}") @@ -399,7 +809,7 @@ def export_interactive(): return console.print("\n[bold cyan]Exporting Weights[/bold cyan]") - export_weights_to_binary(str(weights_path), output_file) + export_to_nbai(str(weights_path), output_file) console.print(f"\n[green]✓ Export complete![/green]") console.print(f"[bold]Weights saved to:[/bold] {output_file}") Prompt.ask("Press Enter to continue") @@ -410,65 +820,6 @@ def export_interactive(): traceback.print_exc() Prompt.ask("Press Enter to continue") -def extract_tactical_interactive(): - """Interactive tactical positions extraction and merge menu.""" - console = Console() - show_header() - - console.print("\n[bold 
cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]") - - # Download and extract options - console.print("\n[bold]Lichess Puzzle Database:[/bold]") - download_url = Prompt.ask( - "Download URL", - default="https://database.lichess.org/lichess_db_puzzle.csv.zst" - ) - - output_dir = Prompt.ask( - "Extract to directory", - default=str(Path(__file__).parent / "trainingdata") - ) - - max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000")) - - # Confirm and download - console.print("\n[bold]Configuration:[/bold]") - console.print(f" Download URL: {download_url}") - console.print(f" Extract directory: {output_dir}") - console.print(f" Max puzzles: {max_puzzles:,}") - - if not Confirm.ask("\nProceed?", default=True): - console.print("[yellow]Cancelled[/yellow]") - Prompt.ask("Press Enter to continue") - return - - try: - console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]") - csv_path = download_and_extract_puzzle_db(download_url, output_dir) - - if not csv_path: - console.print("[red]✗ Failed to download/extract[/red]") - Prompt.ask("Press Enter to continue") - return - - console.print(f"[green]✓ Ready: {csv_path}[/green]") - - # Interactive merge - console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]") - output_file = Prompt.ask( - "Output file path", - default=str(Path(__file__).parent / "data" / "position.txt") - ) - - interactive_merge_positions(csv_path, output_file, max_puzzles) - console.print(f"\n[green]✓ Complete![/green]") - Prompt.ask("Press Enter to continue") - - except Exception as e: - console.print(f"[red]✗ Error: {e}[/red]") - import traceback - traceback.print_exc() - Prompt.ask("Press Enter to continue") def main(): try: @@ -481,7 +832,10 @@ def main(): except Exception as e: console = Console() console.print(f"[red]Error:[/red] {e}") + import traceback + traceback.print_exc() return 1 + if __name__ == "__main__": sys.exit(main()) diff --git a/modules/bot/python/src/dataset.py b/modules/bot/python/src/dataset.py new file mode 100644 index 0000000..dc89807 --- /dev/null +++ b/modules/bot/python/src/dataset.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Dataset versioning and management for NNUE training data.""" + +import json +from pathlib import Path +from datetime import datetime +from typing import Optional, Dict, List, Tuple +from rich.console import Console +from rich.table import Table + + +def get_datasets_dir() -> Path: + """Get/create datasets directory.""" + datasets_dir = Path(__file__).parent.parent / "datasets" + datasets_dir.mkdir(exist_ok=True) + return datasets_dir + + +def next_dataset_version() -> int: + """Find the next available dataset version number.""" + datasets_dir = get_datasets_dir() + versions = [] + + for d in datasets_dir.iterdir(): + if d.is_dir() and d.name.startswith("ds_v"): + try: + v = int(d.name.split("_v")[1]) + versions.append(v) + except (ValueError, IndexError): + pass + + return max(versions) + 1 if versions else 1 + + +def list_datasets() -> List[Tuple[int, Dict]]: + """List all datasets with their metadata. + + Returns: + List of (version, metadata_dict) tuples, sorted by version. 
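+        Directories whose metadata.json is missing or unparsable are skipped.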
+ """ + datasets_dir = get_datasets_dir() + datasets = [] + + for d in datasets_dir.iterdir(): + if d.is_dir() and d.name.startswith("ds_v"): + try: + v = int(d.name.split("_v")[1]) + metadata_file = d / "metadata.json" + if metadata_file.exists(): + with open(metadata_file, 'r') as f: + metadata = json.load(f) + datasets.append((v, metadata)) + except (ValueError, IndexError, json.JSONDecodeError): + pass + + return sorted(datasets, key=lambda x: x[0]) + + +def load_dataset_metadata(version: int) -> Optional[Dict]: + """Load metadata for a specific dataset version. + + Returns: + Metadata dict or None if not found. + """ + datasets_dir = get_datasets_dir() + metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json" + + if not metadata_file.exists(): + return None + + with open(metadata_file, 'r') as f: + return json.load(f) + + +def save_dataset_metadata(version: int, metadata: Dict) -> None: + """Save metadata for a dataset version.""" + datasets_dir = get_datasets_dir() + dataset_dir = datasets_dir / f"ds_v{version}" + dataset_dir.mkdir(exist_ok=True) + + metadata_file = dataset_dir / "metadata.json" + with open(metadata_file, 'w') as f: + json.dump(metadata, f, indent=2, default=str) + + +def create_dataset( + version: int, + labeled_jsonl_path: str, + sources: List[Dict], + stockfish_depth: int = 12 +) -> Path: + """Create a new versioned dataset. + + Args: + version: Dataset version number + labeled_jsonl_path: Path to labeled.jsonl to copy + sources: List of source dicts (see plan for schema) + stockfish_depth: Depth used for labeling + + Returns: + Path to the created dataset directory. + """ + datasets_dir = get_datasets_dir() + dataset_dir = datasets_dir / f"ds_v{version}" + dataset_dir.mkdir(exist_ok=True) + + # Copy labeled data with deduplication (in case source has duplicates) + source_path = Path(labeled_jsonl_path) + if source_path.exists(): + dest_path = dataset_dir / "labeled.jsonl" + seen_fens = set() + unique_count = 0 + + with open(source_path, 'r') as src, open(dest_path, 'w') as dst: + for line in src: + try: + data = json.loads(line) + fen = data.get('fen') + if fen and fen not in seen_fens: + dst.write(line) + seen_fens.add(fen) + unique_count += 1 + except json.JSONDecodeError: + # Skip malformed lines + pass + + # Count positions + total_positions = 0 + if (dataset_dir / "labeled.jsonl").exists(): + with open(dataset_dir / "labeled.jsonl", 'r') as f: + total_positions = sum(1 for _ in f) + + # Create metadata + metadata = { + "version": version, + "created": datetime.now().isoformat(), + "total_positions": total_positions, + "stockfish_depth": stockfish_depth, + "sources": sources + } + + save_dataset_metadata(version, metadata) + return dataset_dir + + +def extend_dataset( + version: int, + new_labeled_path: str, + new_source_entry: Dict +) -> bool: + """Extend an existing dataset with new labeled positions (with deduplication). + + Args: + version: Dataset version to extend + new_labeled_path: Path to new labeled.jsonl to merge + new_source_entry: Source entry to add to metadata + + Returns: + True if successful, False otherwise. 
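+        (False when the dataset directory or the new labeled file does not exist.)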
+ """ + datasets_dir = get_datasets_dir() + dataset_dir = datasets_dir / f"ds_v{version}" + + if not dataset_dir.exists(): + return False + + labeled_file = dataset_dir / "labeled.jsonl" + new_labeled_file = Path(new_labeled_path) + + if not new_labeled_file.exists(): + return False + + # Load existing FENs (dedup set) — must load entire file to avoid duplicates + existing_fens = set() + if labeled_file.exists(): + with open(labeled_file, 'r') as f: + for line in f: + try: + data = json.loads(line) + fen = data.get('fen') + if fen: + existing_fens.add(fen) + except json.JSONDecodeError: + pass + + # Merge new positions, skipping duplicates + new_count = 0 + new_lines = [] + with open(new_labeled_file, 'r') as f_new: + for line in f_new: + try: + data = json.loads(line) + fen = data.get('fen') + if fen and fen not in existing_fens: + new_lines.append(line) + existing_fens.add(fen) + new_count += 1 + except json.JSONDecodeError: + pass + + # Append only the new, unique positions + if new_lines: + with open(labeled_file, 'a') as f_append: + for line in new_lines: + f_append.write(line) + + # Update metadata + metadata = load_dataset_metadata(version) + if metadata: + # Count total positions + total_positions = 0 + with open(labeled_file, 'r') as f: + total_positions = sum(1 for _ in f) + + metadata['total_positions'] = total_positions + # Update the source entry with actual count of new positions added + new_source_entry['actual_count'] = new_count + metadata['sources'].append(new_source_entry) + save_dataset_metadata(version, metadata) + + return True + + +def get_dataset_labeled_path(version: int) -> Optional[Path]: + """Get the path to a dataset's labeled.jsonl file. + + Returns: + Path to labeled.jsonl or None if dataset doesn't exist. + """ + datasets_dir = get_datasets_dir() + labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl" + + if labeled_file.exists(): + return labeled_file + return None + + +def delete_dataset(version: int) -> bool: + """Delete a dataset (recursively removes directory). + + Args: + version: Dataset version to delete + + Returns: + True if successful. 
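+        False if the dataset directory does not exist.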
+ """ + datasets_dir = get_datasets_dir() + dataset_dir = datasets_dir / f"ds_v{version}" + + if not dataset_dir.exists(): + return False + + import shutil + shutil.rmtree(dataset_dir) + return True + + +def show_datasets_table(console: Console = None) -> None: + """Display all datasets in a Rich table.""" + if console is None: + console = Console() + + datasets = list_datasets() + + if not datasets: + console.print("[yellow]ℹ No datasets found yet[/yellow]") + return + + table = Table(title="Available Datasets", show_header=True, header_style="bold cyan") + table.add_column("Version", style="dim") + table.add_column("Positions", justify="right") + table.add_column("Sources", justify="left") + table.add_column("Depth", justify="center") + table.add_column("Created", justify="left") + + for v, metadata in datasets: + positions = metadata.get('total_positions', 0) + sources = metadata.get('sources', []) + source_str = ", ".join([s.get('type', '?') for s in sources]) + depth = metadata.get('stockfish_depth', '?') + created = metadata.get('created', '?') + if created != '?': + created = created.split('T')[0] # Just the date + + table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created) + + console.print(table) diff --git a/modules/bot/python/src/export.py b/modules/bot/python/src/export.py index f20370b..864c8d8 100644 --- a/modules/bot/python/src/export.py +++ b/modules/bot/python/src/export.py @@ -1,67 +1,137 @@ #!/usr/bin/env python3 -"""Export NNUE weights to binary format for runtime loading.""" +"""Export NNUE weights to .nbai format for runtime loading.""" -import torch +import json import struct import sys +from datetime import datetime from pathlib import Path -def export_weights_to_binary(weights_file, output_file): - """Load PyTorch weights and export as binary file.""" +import torch +MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32 +VERSION = 1 + + +def _read_sidecar(weights_file: str) -> dict: + sidecar = weights_file.replace(".pt", "_metadata.json") + if Path(sidecar).exists(): + with open(sidecar) as f: + return json.load(f) + return {} + + +def _infer_layers(state_dict: dict) -> list[dict]: + """Derive layer descriptors from state_dict weight shapes. + + Assumes layers named l1, l2, ..., lN. + All hidden layers get activation 'relu'; the last gets 'linear'. 
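+    Names are sorted numerically by index (l1, l2, ..., l10), not lexicographically.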
+ """ + names = sorted( + {k.split(".")[0] for k in state_dict if k.endswith(".weight")}, + key=lambda n: int(n[1:]), + ) + layers = [] + for i, name in enumerate(names): + out_size, in_size = state_dict[f"{name}.weight"].shape + activation = "linear" if i == len(names) - 1 else "relu" + layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)}) + return layers + + +def _write_floats(f, tensor): + data = tensor.float().flatten().cpu().numpy() + f.write(struct.pack(" {l['outputSize']} [{l['activation']}]") - tensor = state_dict[layer_name] - # Convert to float32 and flatten - data = tensor.float().flatten().cpu().numpy() + Path(output_file).parent.mkdir(parents=True, exist_ok=True) - # Write shape (allows validation on load) - shape = list(tensor.shape) - f.write(struct.pack(' 1: weights_file = sys.argv[1] if len(sys.argv) > 2: output_file = sys.argv[2] + if len(sys.argv) > 3: + trained_by = sys.argv[3] + if len(sys.argv) > 4: + train_loss = float(sys.argv[4]) - export_weights_to_binary(weights_file, output_file) + export_to_nbai(weights_file, output_file, trained_by, train_loss) diff --git a/modules/bot/python/src/label.py b/modules/bot/python/src/label.py index 67b832b..5fa3ebb 100644 --- a/modules/bot/python/src/label.py +++ b/modules/bot/python/src/label.py @@ -125,6 +125,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, # Load all FENs that need evaluation fens_to_evaluate = [] + fens_seen_in_batch = set() # Track duplicates within current batch skipped_invalid = 0 skipped_duplicate = 0 @@ -140,7 +141,12 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, skipped_duplicate += 1 continue + if fen in fens_seen_in_batch: + skipped_duplicate += 1 + continue + fens_to_evaluate.append(fen) + fens_seen_in_batch.add(fen) total_to_evaluate = len(fens_to_evaluate) total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate @@ -178,8 +184,13 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, with open(output_file, 'a') as out: for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)): for fen, eval_normalized, eval_cp in batch_results: + # Skip if already evaluated in output file during this run + if fen in evaluated_fens: + continue + data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp} out.write(json.dumps(data) + '\n') + evaluated_fens.add(fen) # Track as evaluated evaluated += 1 raw_evals.append(eval_cp) normalized_evals.append(eval_normalized) @@ -287,7 +298,7 @@ if __name__ == "__main__": help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')") parser.add_argument("--depth", type=int, default=12, help="Stockfish depth (default: 12)") - parser.add_argument("--batch-size", type=int, default=20, + parser.add_argument("--batch-size", type=int, default=1000, help="Batch size for processing (default: 1000)") parser.add_argument("--no-normalize", action="store_true", help="Disable evaluation normalization (keep raw centipawns)") diff --git a/modules/bot/python/src/tactical_positions_extractor.py b/modules/bot/python/src/tactical_positions_extractor.py index fd30efd..476e9b1 100644 --- a/modules/bot/python/src/tactical_positions_extractor.py +++ b/modules/bot/python/src/tactical_positions_extractor.py @@ -17,7 +17,7 @@ from generate import play_random_game_and_collect_positions def download_and_extract_puzzle_db( url: str = 
'https://database.lichess.org/lichess_db_puzzle.csv.zst', - output_dir: str = 'trainingdata' + output_dir: str = 'tactical_data' ): """Download and extract the Lichess puzzle database.""" output_path = Path(output_dir) @@ -141,6 +141,31 @@ def merge_positions( print(f"{'='*60}\n") +def extract_tactical_only( + puzzle_csv: str, + output_file: str, + max_puzzles: int = 300_000 +) -> int: + """Extract tactical positions and save to file (no merge prompts). + + Args: + puzzle_csv: Path to Lichess puzzle CSV + output_file: Where to save the FEN positions + max_puzzles: Maximum puzzles to extract + + Returns: + Number of positions extracted + """ + print("Extracting tactical positions from puzzle database...") + tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles) + + with open(output_file, 'w') as f: + for fen in tactical_positions: + f.write(fen + '\n') + + return len(tactical_positions) + + def interactive_merge_positions( puzzle_csv: str, output_file: str = 'position.txt', diff --git a/modules/bot/python/weights/nnue_weights_best_snapshot.pt b/modules/bot/python/weights/nnue_weights_best_snapshot.pt new file mode 100644 index 0000000..4e3ff54 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_best_snapshot.pt differ diff --git a/modules/bot/python/weights/nnue_weights_checkpoint.pt b/modules/bot/python/weights/nnue_weights_checkpoint.pt index ad5b4e4..e714522 100644 Binary files a/modules/bot/python/weights/nnue_weights_checkpoint.pt and b/modules/bot/python/weights/nnue_weights_checkpoint.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v9.pt b/modules/bot/python/weights/nnue_weights_v9.pt new file mode 100644 index 0000000..56bff56 Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v9.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v9_metadata.json b/modules/bot/python/weights/nnue_weights_v9_metadata.json new file mode 100644 index 0000000..8dc8fa8 --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v9_metadata.json @@ -0,0 +1,17 @@ +{ + "version": 9, + "date": "2026-04-13T20:19:08.123315", + "num_positions": 2522562, + "stockfish_depth": 12, + "final_val_loss": 6.994176222619626e-05, + "device": "cuda", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)", + "mode": "burst", + "duration_minutes": 30.0, + "epochs_per_season": 50, + "early_stopping_patience": 10, + "seasons_completed": 3, + "batch_size": 16384, + "learning_rate": 0.001, + "initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt" +} \ No newline at end of file diff --git a/modules/bot/src/main/resources/nnue_weights.bin b/modules/bot/src/main/resources/nnue_weights.nbai similarity index 66% rename from modules/bot/src/main/resources/nnue_weights.bin rename to modules/bot/src/main/resources/nnue_weights.nbai index 5d44bc0..c90ed1d 100644 Binary files a/modules/bot/src/main/resources/nnue_weights.bin and b/modules/bot/src/main/resources/nnue_weights.nbai differ diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala index 57fe624..478a021 100644 --- a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala +++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala @@ -6,7 +6,7 @@ import de.nowchess.bot.ai.Evaluation object EvaluationNNUE extends Evaluation: - private val nnue = NNUE() + private val nnue = 
NNUE(NbaiLoader.loadDefault()) val CHECKMATE_SCORE: Int = 10_000_000 val DRAW_SCORE: Int = 0 diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala index c4c17ca..3a23d62 100644 --- a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala +++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala @@ -3,84 +3,31 @@ package de.nowchess.bot.bots.nnue import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square} import de.nowchess.api.game.GameContext import de.nowchess.api.move.{Move, MoveType, PromotionPiece} -import java.nio.ByteBuffer -import java.nio.ByteOrder -class NNUE: +class NNUE(model: NbaiModel): - private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = - loadWeights() + private val featureSize = model.layers(0).inputSize + private val accSize = model.layers(0).outputSize // Column-major L1 weights for cache-friendly sparse & incremental updates. - // l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx) + // l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx) private val l1WeightsT: Array[Float] = - val t = new Array[Float](768 * 1536) - for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j) + val w = model.weights(0).weights + val t = new Array[Float](featureSize * accSize) + for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j) t - private def loadWeights(): ( - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - Array[Float], - ) = - val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin")) - .getOrElse(sys.error("NNUE weights file not found in resources")) - - try - val bytes = stream.readAllBytes() - val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) - - val magic = buffer.getInt() - if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}") - - val version = buffer.getInt() - if version != 1 then sys.error(s"Unsupported weight version: $version") - - val l1w = readTensor(buffer) - val l1b = readTensor(buffer) - val l2w = readTensor(buffer) - val l2b = readTensor(buffer) - val l3w = readTensor(buffer) - val l3b = readTensor(buffer) - val l4w = readTensor(buffer) - val l4b = readTensor(buffer) - val l5w = readTensor(buffer) - val l5b = readTensor(buffer) - - (l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b) - finally stream.close() - - private def readTensor(buffer: ByteBuffer): Array[Float] = - val shapeLen = buffer.getInt() - val shape = Array.ofDim[Int](shapeLen) - for i <- 0 until shapeLen do shape(i) = buffer.getInt() - val totalElements = shape.product - val floats = Array.ofDim[Float](totalElements) - for i <- 0 until totalElements do floats(i) = buffer.getFloat() - floats - // ── Accumulator stack ──────────────────────────────────────────────────── - // l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply. - // Initialised once at root; each child ply is derived incrementally. private val MAX_PLY = 128 - private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536)) + private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize)) - // Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant). 
- private val l1ReLU = new Array[Float](1536) - private val l2Output = new Array[Float](1024) - private val l3Output = new Array[Float](512) - private val l4Output = new Array[Float](256) + // Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer). + private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize)) // ── Eval cache ─────────────────────────────────────────────────────────── - private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB + + private val EVAL_CACHE_MASK = (1 << 18) - 1L private val evalCacheHashes = new Array[Long](1 << 18) private val evalCacheScores = new Array[Int](1 << 18) @@ -93,35 +40,32 @@ class NNUE: (colorOffset + piece.pieceType.ordinal) * 64 + sqNum private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit = - val offset = featureIdx * 1536 - for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i) + val offset = featureIdx * accSize + for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i) private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit = - val offset = featureIdx * 1536 - for i <- 0 until 1536 do l1Pre(i) -= l1WeightsT(offset + i) + val offset = featureIdx * accSize + for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i) // ── Accumulator init ───────────────────────────────────────────────────── - /** Initialise l1Stack(0) from scratch using sparse active features. */ def initAccumulator(board: Board): Unit = - System.arraycopy(l1Bias, 0, l1Stack(0), 0, 1536) + System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize) for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq))) // ── Accumulator push (incremental updates) ─────────────────────────────── - /** Copy parent ply's pre-activations to childPly, then apply move deltas. */ def pushAccumulator(childPly: Int, move: Move, board: Board): Unit = - System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, 1536) + System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize) val l1 = l1Stack(childPly) move.moveType match - case MoveType.Normal(_) => applyNormalDelta(l1, move, board) - case MoveType.EnPassant => applyEnPassantDelta(l1, move, board) - case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board) - case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board) + case MoveType.Normal(_) => applyNormalDelta(l1, move, board) + case MoveType.EnPassant => applyEnPassantDelta(l1, move, board) + case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board) + case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board) - /** Copy pre-activations from parentPly to childPly without any move delta (null-move). */ def copyAccumulator(parentPly: Int, childPly: Int): Unit = - System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, 1536) + System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize) private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit = board.pieceAt(move.from).foreach { mover => @@ -170,9 +114,6 @@ class NNUE: // ── Evaluation from accumulator ────────────────────────────────────────── - /** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after - * computation. 
-    */
   def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
     val idx = (hash & EVAL_CACHE_MASK).toInt
     if evalCacheHashes(idx) == hash then evalCacheScores(idx)
@@ -183,11 +124,19 @@ class NNUE:
       score

   private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
-    for i <- 0 until 1536 do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
-    runDenseReLU(l1ReLU, 1536, l2Weights, l2Bias, l2Output, 1024)
-    runDenseReLU(l2Output, 1024, l3Weights, l3Bias, l3Output, 512)
-    runDenseReLU(l3Output, 512, l4Weights, l4Bias, l4Output, 256)
-    val output = runOutputLayer(l4Output, 256)
+    val l1ReLU = evalBuffers(0)
+    for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
+
+    var input = l1ReLU
+    for i <- 1 until model.layers.length - 1 do
+      val lw = model.weights(i)
+      val out = evalBuffers(i)
+      val ld = model.layers(i)
+      runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
+      input = out
+
+    val lastIdx = model.layers.length - 1
+    val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
     scoreFromOutput(output, turn)

   private def runDenseReLU(
@@ -202,8 +151,8 @@ class NNUE:
       val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
       output(i) = if sum > 0f then sum else 0f

-  private def runOutputLayer(input: Array[Float], inSize: Int): Float =
-    (0 until inSize).foldLeft(l5Bias(0))((sum, j) => sum + input(j) * l5Weights(j))
+  private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
+    (0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))

   private def scoreFromOutput(output: Float, turn: Color): Int =
     val cp =
@@ -214,21 +163,15 @@ class NNUE:
     val cpFromTurn = if turn == Color.Black then -cp else cp
     math.max(-20000, math.min(20000, cpFromTurn))

-  // ── Legacy full-board evaluate (kept for Evaluation.evaluate compatibility) ──
+  // ── Legacy full-board evaluate ────────────────────────────────────────────

-  // Pre-allocated buffers used only by the legacy evaluate path.
-  private val features = new Array[Float](768)
-  private val legacyL1 = new Array[Float](1536)
+  private val legacyL1 = new Array[Float](accSize)

-  /** Evaluate using full board scan (sparse over active features). Layout: black pieces at indices 0-5, white at 6-11.
-    */
   def evaluate(context: GameContext): Int =
-    val l1Pre = legacyL1
-    System.arraycopy(l1Bias, 0, l1Pre, 0, 1536)
-    for (sq, piece) <- context.board.pieces do addColumn(l1Pre, featureIndex(piece, squareNum(sq)))
-    runL2toOutput(l1Pre, context.turn)
+    System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
+    for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
+    runL2toOutput(legacyL1, context.turn)

-  /** Benchmark: time 1M evaluations and report ns/eval. */
   def benchmark(): Unit =
     val context = GameContext.initial
     val iterations = 1_000_000
diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala
new file mode 100644
index 0000000..e70f850
--- /dev/null
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala
@@ -0,0 +1,50 @@
+package de.nowchess.bot.bots.nnue
+
+import java.io.InputStream
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.charset.StandardCharsets
+
+object NbaiLoader:
+
+  /** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
+  val MAGIC: Int = 0x4942_414e
+
+  def load(stream: InputStream): NbaiModel =
+    val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
+    checkHeader(buf)
+    val metadata = readMetadata(buf)
+    val descs = readLayerDescriptors(buf)
+    val weights = descs.map(_ => readLayerWeights(buf))
+    NbaiModel(metadata, descs, weights)
+
+  /** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
+  def loadDefault(): NbaiModel =
+    Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
+      case Some(s) => try load(s) finally s.close()
+      case None => NbaiMigrator.migrateFromBin()
+
+  private def checkHeader(buf: ByteBuffer): Unit =
+    val magic = buf.getInt()
+    if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
+    val version = buf.getShort() & 0xffff
+    if version != 1 then sys.error(s"Unsupported NBAI version: $version")
+
+  private def readMetadata(buf: ByteBuffer): NbaiMetadata =
+    val bytes = new Array[Byte](buf.getInt())
+    buf.get(bytes)
+    NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
+
+  private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
+    Array.tabulate(buf.getShort() & 0xffff) { _ =>
+      val nameBytes = new Array[Byte](buf.get() & 0xff)
+      buf.get(nameBytes)
+      LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
+    }
+
+  private def readLayerWeights(buf: ByteBuffer): LayerWeights =
+    LayerWeights(readFloats(buf), readFloats(buf))
+
+  private def readFloats(buf: ByteBuffer): Array[Float] =
+    val arr = new Array[Float](buf.getInt())
+    for i <- arr.indices do arr(i) = buf.getFloat()
+    arr
diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala
new file mode 100644
index 0000000..bf19bce
--- /dev/null
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala
@@ -0,0 +1,43 @@
+package de.nowchess.bot.bots.nnue
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
+object NbaiMigrator:
+
+  private val BinMagic = 0x4555_4e4e
+  private val BinVersion = 1
+
+  private val DefaultLayers: Array[LayerDescriptor] = Array(
+    LayerDescriptor("relu", 768, 1536),
+    LayerDescriptor("relu", 1536, 1024),
+    LayerDescriptor("relu", 1024, 512),
+    LayerDescriptor("relu", 512, 256),
+    LayerDescriptor("linear", 256, 1),
+  )
+
+  private val UnknownMetadata: NbaiMetadata =
+    NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
+
+  def migrateFromBin(): NbaiModel =
+    val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
+      .getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
+    try
+      val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
+      checkBinHeader(buf)
+      val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
+      NbaiModel(UnknownMetadata, DefaultLayers, weights)
+    finally stream.close()
+
+  private def checkBinHeader(buf: ByteBuffer): Unit =
+    val magic = buf.getInt()
+    if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
+    val version = buf.getInt()
+    if version != BinVersion then sys.error(s"Unsupported bin version: $version")
+
+  private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
+    LayerWeights(readBinTensor(buf), readBinTensor(buf))
+
+  private def readBinTensor(buf: ByteBuffer): Array[Float] =
+    val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
+    Array.tabulate(shape.product)(_ => buf.getFloat())
diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala
new file mode 100644
index 0000000..e9487b0
--- /dev/null
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala
@@ -0,0 +1,39 @@
+package de.nowchess.bot.bots.nnue
+
+/** Descriptor for a single dense layer stored in a .nbai file. */
+case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
+
+/** Training metadata embedded in every .nbai file. */
+case class NbaiMetadata(
+  trainedBy: String,
+  trainedAt: String,
+  trainingDataCount: Long,
+  valLoss: Double,
+  trainLoss: Double,
+):
+  def toJson: String =
+    s"""{
+       |  "trainedBy": "$trainedBy",
+       |  "trainedAt": "$trainedAt",
+       |  "trainingDataCount": $trainingDataCount,
+       |  "valLoss": $valLoss,
+       |  "trainLoss": $trainLoss
+       |}""".stripMargin
+
+object NbaiMetadata:
+  def fromJson(json: String): NbaiMetadata =
+    def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
+    def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
+    NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
+
+/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
+case class LayerWeights(weights: Array[Float], bias: Array[Float])
+
+/** A fully deserialized .nbai model ready to initialize NNUE. */
+case class NbaiModel(
+  metadata: NbaiMetadata,
+  layers: Array[LayerDescriptor],
+  weights: Array[LayerWeights],
+):
+  require(layers.length == weights.length, "Layer count must match weight count")
+  require(layers.length >= 2, "Model must have at least 2 layers")
diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala
new file mode 100644
index 0000000..d7e0a1b
--- /dev/null
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala
@@ -0,0 +1,51 @@
+package de.nowchess.bot.bots.nnue
+
+import java.io.{ByteArrayOutputStream, OutputStream}
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.charset.StandardCharsets
+
+object NbaiWriter:
+
+  def write(model: NbaiModel, out: OutputStream): Unit =
+    val acc = new ByteArrayOutputStream()
+    writeHeader(acc)
+    writeMetadata(acc, model.metadata)
+    writeLayerDescriptors(acc, model.layers)
+    model.weights.foreach(lw => writeLayerWeights(acc, lw))
+    out.write(acc.toByteArray)
+
+  private def writeHeader(out: ByteArrayOutputStream): Unit =
+    val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
+    buf.putInt(NbaiLoader.MAGIC)
+    buf.putShort(1.toShort)
+    out.write(buf.array())
+
+  private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
+    val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
+    val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
+    buf.putInt(json.length)
+    buf.put(json)
+    out.write(buf.array())
+
+  private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
+    val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
+    val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
+    val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
+    buf.putShort(layers.length.toShort)
+    layers.zip(nameBytes).foreach { (l, nb) =>
+      buf.put(nb.length.toByte)
+      buf.put(nb)
+      buf.putInt(l.inputSize)
+      buf.putInt(l.outputSize)
+    }
+    out.write(buf.array())
+
+  private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
+    writeFloats(out, lw.weights)
+    writeFloats(out, lw.bias)
+
+  private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
+    val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
+    buf.putInt(floats.length)
+    floats.foreach(buf.putFloat)
+    out.write(buf.array())
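
Reviewer note (not part of the patch): the writer and loader above agree on the .nbai layout — magic int, u16 version, length-prefixed JSON metadata, u16 layer count with length-prefixed activation names and sizes, then length-prefixed float arrays per layer. A minimal round-trip sketch of that contract is below; it uses only the API introduced in this diff, and the entry-point name, the placeholder layer sizes, and the dummy metadata values are made up for illustration.

```scala
package de.nowchess.bot.bots.nnue

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

// Hypothetical smoke test: serialize a tiny two-layer model with NbaiWriter
// and read it back with NbaiLoader, checking that the shapes survive.
@main def nbaiRoundTripSketch(): Unit =
  val layers = Array(
    LayerDescriptor("relu", 4, 8),   // placeholder sizes, not the real 768/1536 network
    LayerDescriptor("linear", 8, 1), // scalar output layer
  )
  val weights = layers.map(l => LayerWeights(new Array[Float](l.inputSize * l.outputSize), new Array[Float](l.outputSize)))
  val meta = NbaiMetadata("reviewer", "unknown", 0L, 0.0, 0.0)
  val model = NbaiModel(meta, layers, weights)

  val out = new ByteArrayOutputStream()
  NbaiWriter.write(model, out)

  val reloaded = NbaiLoader.load(new ByteArrayInputStream(out.toByteArray))
  assert(reloaded.layers.map(_.outputSize).sameElements(model.layers.map(_.outputSize)))
  println(s"round-trip ok, trainedBy=${reloaded.metadata.trainedBy}")
```

Because NNUE now reads featureSize and accSize from the first layer descriptor, any model that passes this kind of round trip (and has at least two layers, per the NbaiModel requires) should load into the evaluator without code changes.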