feat: Implement dataset versioning and management for NNUE training data
This commit is contained in:
@@ -19,3 +19,4 @@ ENV/
|
||||
*.swo
|
||||
tactical_data/
|
||||
trainingdata/
|
||||
/datasets/
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
# Training Dataset Management
|
||||
|
||||
The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
datasets/
|
||||
ds_v1/
|
||||
labeled.jsonl # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
|
||||
metadata.json # Version info and composition
|
||||
ds_v2/
|
||||
labeled.jsonl
|
||||
metadata.json
|
||||
```
|
||||
|
||||
## Metadata Schema
|
||||
|
||||
Each dataset has a `metadata.json` file tracking its composition:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"created": "2026-04-13T15:30:45.123456",
|
||||
"total_positions": 1000000,
|
||||
"stockfish_depth": 12,
|
||||
"sources": [
|
||||
{
|
||||
"type": "generated",
|
||||
"count": 500000,
|
||||
"params": {
|
||||
"num_positions": 500000,
|
||||
"min_move": 1,
|
||||
"max_move": 50
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "tactical",
|
||||
"count": 300000,
|
||||
"max_puzzles": 300000
|
||||
},
|
||||
{
|
||||
"type": "file_import",
|
||||
"count": 200000,
|
||||
"path": "/path/to/original_file.txt"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## TUI Workflow
|
||||
|
||||
### Main Menu
|
||||
```
|
||||
1 - Manage Training Data
|
||||
2 - Train Model
|
||||
3 - Export Model
|
||||
4 - Exit
|
||||
```
|
||||
|
||||
### Training Data Management Submenu
|
||||
```
|
||||
1 - Create new dataset
|
||||
2 - Extend existing dataset
|
||||
3 - View all datasets
|
||||
4 - Delete dataset
|
||||
5 - Back
|
||||
```
|
||||
|
||||
## Creating a Dataset
|
||||
|
||||
Use the "Create new dataset" option to add data from one or more sources:
|
||||
|
||||
1. **Generate random positions** — Play random games and sample positions
|
||||
- Number of positions
|
||||
- Move range (min/max move number to sample from)
|
||||
- Number of worker threads
|
||||
|
||||
2. **Import from file** — Load positions from a FEN file
|
||||
- File must contain one FEN string per line
|
||||
- Duplicates are automatically removed
|
||||
|
||||
3. **Extract tactical puzzles** — Download and extract Lichess puzzle database
|
||||
- Maximum number of puzzles to include
|
||||
- Automatically filters for tactical themes (forks, pins, mates, etc.)
|
||||
|
||||
You can combine multiple sources in a single dataset creation session. All positions are:
|
||||
- Deduplicated (only unique FENs are kept)
|
||||
- Labeled with Stockfish evaluations
|
||||
- Saved to `datasets/ds_vN/labeled.jsonl`
|
||||
|
||||
## Extending a Dataset
|
||||
|
||||
Use "Extend existing dataset" to add more positions to an existing dataset:
|
||||
|
||||
1. Select the dataset version to extend
|
||||
2. Choose data sources (same options as creation)
|
||||
3. Confirm labeling parameters
|
||||
4. New positions are:
|
||||
- Labeled with Stockfish
|
||||
- Deduplicated against the target dataset (preventing duplicates)
|
||||
- Merged into the existing `labeled.jsonl`
|
||||
- Metadata is updated with the new source entry
|
||||
|
||||
## Training with a Dataset
|
||||
|
||||
When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
|
||||
- Version number
|
||||
- Total number of positions
|
||||
- Source types (generated, tactical, imported)
|
||||
- Stockfish depth used
|
||||
- Creation date
|
||||
|
||||
## Legacy Data Migration
|
||||
|
||||
If you have existing labeled data in `data/training_data.jsonl` from before this update:
|
||||
|
||||
1. Open the "Manage Training Data" menu
|
||||
2. Choose "Create new dataset"
|
||||
3. Select "Import from file"
|
||||
4. Point to `data/training_data.jsonl`
|
||||
5. Complete the dataset creation
|
||||
|
||||
Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
|
||||
|
||||
## Viewing Dataset Details
|
||||
|
||||
Use "View all datasets" to see a table of all datasets with:
|
||||
- Version number
|
||||
- Position count
|
||||
- Source composition
|
||||
- Stockfish depth
|
||||
- Creation date
|
||||
|
||||
## Deleting a Dataset
|
||||
|
||||
Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**
|
||||
|
||||
⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Deduplication Strategy
|
||||
|
||||
When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.
|
||||
|
||||
When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
|
||||
|
||||
### Labeled Position Format
|
||||
|
||||
Each line in `labeled.jsonl` is a JSON object:
|
||||
```json
|
||||
{
|
||||
"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
|
||||
"eval": 0.0,
|
||||
"eval_raw": 0
|
||||
}
|
||||
```
|
||||
|
||||
- `fen`: The position in Forsyth-Edwards Notation
|
||||
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
|
||||
- `eval_raw`: Raw Stockfish evaluation in centipawns
|
||||
|
||||
### Storage Location
|
||||
|
||||
Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.
|
||||
|
||||
## Performance Tips
|
||||
|
||||
- **Smaller datasets train faster** — Start with 100k-500k positions
|
||||
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
|
||||
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
|
||||
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory
|
||||
+547
-193
@@ -4,6 +4,7 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
@@ -17,23 +18,23 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue, burst_train
|
||||
from export import export_weights_to_binary
|
||||
from export import export_to_nbai
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
interactive_merge_positions
|
||||
extract_tactical_only
|
||||
)
|
||||
from dataset import (
|
||||
get_datasets_dir,
|
||||
list_datasets,
|
||||
next_dataset_version,
|
||||
load_dataset_metadata,
|
||||
create_dataset,
|
||||
extend_dataset,
|
||||
get_dataset_labeled_path,
|
||||
delete_dataset,
|
||||
show_datasets_table
|
||||
)
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create data directory."""
|
||||
data_dir = Path(__file__).parent / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
def get_tactical_data_dir():
|
||||
"""Get/create data directory."""
|
||||
data_dir = Path(__file__).parent / "tactical_data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
def get_weights_dir():
|
||||
"""Get/create weights directory."""
|
||||
@@ -41,6 +42,14 @@ def get_weights_dir():
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
return weights_dir
|
||||
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create legacy data directory (for migration)."""
|
||||
data_dir = Path(__file__).parent / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
|
||||
def list_checkpoints():
|
||||
"""List available checkpoint versions."""
|
||||
weights_dir = get_weights_dir()
|
||||
@@ -49,6 +58,20 @@ def list_checkpoints():
|
||||
return []
|
||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||
|
||||
|
||||
def migrate_legacy_data():
|
||||
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
|
||||
console = Console()
|
||||
data_dir = get_data_dir()
|
||||
legacy_file = data_dir / "training_data.jsonl"
|
||||
datasets = list_datasets()
|
||||
|
||||
# Only migrate if legacy data exists and no datasets exist yet
|
||||
if legacy_file.exists() and not datasets:
|
||||
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
|
||||
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
|
||||
|
||||
|
||||
def show_header():
|
||||
"""Display application header."""
|
||||
console = Console()
|
||||
@@ -56,22 +79,23 @@ def show_header():
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
|
||||
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def show_checkpoints_table():
|
||||
"""Display available checkpoints in a table."""
|
||||
console = Console()
|
||||
available = list_checkpoints()
|
||||
|
||||
if not available:
|
||||
console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
|
||||
console.print("[yellow]ℹ No model checkpoints found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("File Size", justify="right")
|
||||
table.add_column("Status", justify="center")
|
||||
@@ -87,46 +111,500 @@ def show_checkpoints_table():
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def show_main_menu():
|
||||
"""Display and handle main menu."""
|
||||
console = Console()
|
||||
|
||||
# Migrate legacy data on first run
|
||||
migrate_legacy_data()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Train NNUE Model")
|
||||
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
|
||||
console.print("[cyan]3[/cyan] - Export Weights to Scala")
|
||||
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
|
||||
console.print("[cyan]5[/cyan] - View Checkpoints")
|
||||
console.print("[cyan]6[/cyan] - Exit")
|
||||
console.print("[cyan]1[/cyan] - Manage Training Data")
|
||||
console.print("[cyan]2[/cyan] - Train Model")
|
||||
console.print("[cyan]3[/cyan] - Export Model")
|
||||
console.print("[cyan]4[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
datasets_menu()
|
||||
elif choice == "2":
|
||||
training_menu()
|
||||
elif choice == "3":
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
|
||||
def datasets_menu():
|
||||
"""Dataset management submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
|
||||
console.print("\n[bold]Training Data Management[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Create new dataset")
|
||||
console.print("[cyan]2[/cyan] - Extend existing dataset")
|
||||
console.print("[cyan]3[/cyan] - View all datasets")
|
||||
console.print("[cyan]4[/cyan] - Delete dataset")
|
||||
console.print("[cyan]5[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
|
||||
if choice == "1":
|
||||
create_dataset_interactive()
|
||||
elif choice == "2":
|
||||
extend_dataset_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
delete_dataset_interactive()
|
||||
elif choice == "5":
|
||||
return
|
||||
|
||||
|
||||
def create_dataset_interactive():
|
||||
"""Interactive dataset creation flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add multiple sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source (repeat until done):[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Generation failed[/red]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||
console.print(f" Total positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Save dataset
|
||||
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||
version = next_dataset_version()
|
||||
create_dataset(
|
||||
version=version,
|
||||
labeled_jsonl_path=str(labeled_file),
|
||||
sources=sources,
|
||||
stockfish_depth=stockfish_depth
|
||||
)
|
||||
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def extend_dataset_interactive():
|
||||
"""Interactive dataset extension flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets available to extend[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source:[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Extension cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Extension Summary:[/bold]")
|
||||
console.print(f" Target dataset: ds_v{version}")
|
||||
console.print(f" New positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Extend dataset
|
||||
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||
success = extend_dataset(
|
||||
version=version,
|
||||
new_labeled_path=str(labeled_file),
|
||||
new_source_entry={
|
||||
"type": "merged_sources",
|
||||
"count": len(all_fens),
|
||||
"sources": sources
|
||||
}
|
||||
)
|
||||
|
||||
if success:
|
||||
metadata = load_dataset_metadata(version)
|
||||
console.print(f"[green]✓ Dataset extended[/green]")
|
||||
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
|
||||
else:
|
||||
console.print("[red]✗ Extension failed[/red]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def delete_dataset_interactive():
|
||||
"""Interactive dataset deletion."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets to delete[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
|
||||
if delete_dataset(version):
|
||||
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Deletion failed[/red]")
|
||||
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def training_menu():
|
||||
"""Training submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold]Training[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Standard Training")
|
||||
console.print("[cyan]2[/cyan] - Burst Training")
|
||||
console.print("[cyan]3[/cyan] - View Model Checkpoints")
|
||||
console.print("[cyan]4[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
burst_train_interactive()
|
||||
elif choice == "3":
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
extract_tactical_interactive()
|
||||
elif choice == "5":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "6":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
elif choice == "4":
|
||||
return
|
||||
|
||||
|
||||
def train_interactive():
|
||||
"""Interactive training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")
|
||||
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
# Checkpoint selection
|
||||
available = list_checkpoints()
|
||||
@@ -142,36 +620,6 @@ def train_interactive():
|
||||
default=str(max(available))
|
||||
)
|
||||
|
||||
# Positions source
|
||||
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||
positions_file = None
|
||||
num_games = 500000
|
||||
samples_per_game = 1
|
||||
min_move = 1
|
||||
max_move = 50
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
|
||||
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
|
||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||
labels_file = None
|
||||
if use_existing_labels:
|
||||
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
|
||||
|
||||
# Stockfish path and labeling parameters
|
||||
default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
|
||||
stockfish_depth = 12
|
||||
num_workers = 1
|
||||
if not use_existing_labels:
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
@@ -182,16 +630,7 @@ def train_interactive():
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
if use_checkpoint:
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(" Checkpoint: None (training from scratch)")
|
||||
if not use_existing:
|
||||
console.print(f" Games: {num_games:,}")
|
||||
console.print(f" Samples per game: {samples_per_game}")
|
||||
console.print(f" Move range: {min_move}-{max_move}")
|
||||
else:
|
||||
console.print(f" Positions file: {positions_file}")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
@@ -199,70 +638,27 @@ def train_interactive():
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
if not use_existing_labels:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
console.print(f" Workers: {num_workers}")
|
||||
console.print(f" Stockfish: {stockfish_path}")
|
||||
if use_checkpoint:
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(f" Checkpoint: None (training from scratch)")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
console.print("[yellow]Training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
# Execute training
|
||||
data_dir = get_data_dir()
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
# Generate positions
|
||||
if not use_existing:
|
||||
console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(data_dir / "positions.txt"),
|
||||
total_games=num_games,
|
||||
samples_per_game=samples_per_game,
|
||||
min_move=min_move,
|
||||
max_move=max_move
|
||||
)
|
||||
if count == 0:
|
||||
console.print("[red]✗ No valid positions generated[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Generated {count:,} positions[/green]")
|
||||
else:
|
||||
if not Path(positions_file).exists():
|
||||
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
|
||||
return
|
||||
|
||||
if not use_existing_labels:
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
positions_file = data_dir / "positions.txt"
|
||||
output_file = data_dir / "training_data.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(positions_file),
|
||||
str(output_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Position labeling failed[/red]")
|
||||
return
|
||||
console.print(f"[green]✓ Positions labeled[/green]")
|
||||
else:
|
||||
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
|
||||
output_file = labels_file
|
||||
if not Path(output_file).exists():
|
||||
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
|
||||
return
|
||||
|
||||
# Train model
|
||||
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
|
||||
console.print("\n[bold cyan]Training Model[/bold cyan]")
|
||||
checkpoint = None
|
||||
if use_checkpoint:
|
||||
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||
|
||||
train_nnue(
|
||||
data_file=str(output_file),
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
@@ -286,6 +682,7 @@ def train_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def burst_train_interactive():
|
||||
"""Interactive burst training menu."""
|
||||
console = Console()
|
||||
@@ -294,14 +691,30 @@ def burst_train_interactive():
|
||||
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||
|
||||
# Data file
|
||||
default_labels = str(get_data_dir() / "training_data.jsonl")
|
||||
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
|
||||
|
||||
# Optional initial checkpoint
|
||||
available = list_checkpoints()
|
||||
checkpoint = None
|
||||
@@ -317,29 +730,25 @@ def burst_train_interactive():
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
console.print(f" Data file: {labels_file}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
|
||||
if not Confirm.ask("\nStart burst training?", default=True):
|
||||
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
if not Path(labels_file).exists():
|
||||
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||
burst_train(
|
||||
data_file=labels_file,
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
duration_minutes=duration_minutes,
|
||||
epochs_per_season=epochs_per_season,
|
||||
@@ -362,6 +771,7 @@ def burst_train_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
@@ -380,7 +790,7 @@ def export_interactive():
|
||||
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||
|
||||
weights_file = f"nnue_weights_v{version}.pt"
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
|
||||
|
||||
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||
console.print(f" Source: {weights_file}")
|
||||
@@ -399,7 +809,7 @@ def export_interactive():
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||
export_weights_to_binary(str(weights_path), output_file)
|
||||
export_to_nbai(str(weights_path), output_file)
|
||||
console.print(f"\n[green]✓ Export complete![/green]")
|
||||
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
@@ -410,65 +820,6 @@ def export_interactive():
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def extract_tactical_interactive():
|
||||
"""Interactive tactical positions extraction and merge menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
|
||||
|
||||
# Download and extract options
|
||||
console.print("\n[bold]Lichess Puzzle Database:[/bold]")
|
||||
download_url = Prompt.ask(
|
||||
"Download URL",
|
||||
default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
|
||||
)
|
||||
|
||||
output_dir = Prompt.ask(
|
||||
"Extract to directory",
|
||||
default=str(Path(__file__).parent / "trainingdata")
|
||||
)
|
||||
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
|
||||
# Confirm and download
|
||||
console.print("\n[bold]Configuration:[/bold]")
|
||||
console.print(f" Download URL: {download_url}")
|
||||
console.print(f" Extract directory: {output_dir}")
|
||||
console.print(f" Max puzzles: {max_puzzles:,}")
|
||||
|
||||
if not Confirm.ask("\nProceed?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
|
||||
csv_path = download_and_extract_puzzle_db(download_url, output_dir)
|
||||
|
||||
if not csv_path:
|
||||
console.print("[red]✗ Failed to download/extract[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print(f"[green]✓ Ready: {csv_path}[/green]")
|
||||
|
||||
# Interactive merge
|
||||
console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
|
||||
output_file = Prompt.ask(
|
||||
"Output file path",
|
||||
default=str(Path(__file__).parent / "data" / "position.txt")
|
||||
)
|
||||
|
||||
interactive_merge_positions(csv_path, output_file, max_puzzles)
|
||||
console.print(f"\n[green]✓ Complete![/green]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
def main():
|
||||
try:
|
||||
@@ -481,7 +832,10 @@ def main():
|
||||
except Exception as e:
|
||||
console = Console()
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Dataset versioning and management for NNUE training data."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def get_datasets_dir() -> Path:
|
||||
"""Get/create datasets directory."""
|
||||
datasets_dir = Path(__file__).parent.parent / "datasets"
|
||||
datasets_dir.mkdir(exist_ok=True)
|
||||
return datasets_dir
|
||||
|
||||
|
||||
def next_dataset_version() -> int:
|
||||
"""Find the next available dataset version number."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
versions = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
versions.append(v)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return max(versions) + 1 if versions else 1
|
||||
|
||||
|
||||
def list_datasets() -> List[Tuple[int, Dict]]:
|
||||
"""List all datasets with their metadata.
|
||||
|
||||
Returns:
|
||||
List of (version, metadata_dict) tuples, sorted by version.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
datasets = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
metadata_file = d / "metadata.json"
|
||||
if metadata_file.exists():
|
||||
with open(metadata_file, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
datasets.append((v, metadata))
|
||||
except (ValueError, IndexError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
return sorted(datasets, key=lambda x: x[0])
|
||||
|
||||
|
||||
def load_dataset_metadata(version: int) -> Optional[Dict]:
|
||||
"""Load metadata for a specific dataset version.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None if not found.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
|
||||
|
||||
if not metadata_file.exists():
|
||||
return None
|
||||
|
||||
with open(metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_dataset_metadata(version: int, metadata: Dict) -> None:
|
||||
"""Save metadata for a dataset version."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
metadata_file = dataset_dir / "metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
|
||||
def create_dataset(
|
||||
version: int,
|
||||
labeled_jsonl_path: str,
|
||||
sources: List[Dict],
|
||||
stockfish_depth: int = 12
|
||||
) -> Path:
|
||||
"""Create a new versioned dataset.
|
||||
|
||||
Args:
|
||||
version: Dataset version number
|
||||
labeled_jsonl_path: Path to labeled.jsonl to copy
|
||||
sources: List of source dicts (see plan for schema)
|
||||
stockfish_depth: Depth used for labeling
|
||||
|
||||
Returns:
|
||||
Path to the created dataset directory.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Copy labeled data with deduplication (in case source has duplicates)
|
||||
source_path = Path(labeled_jsonl_path)
|
||||
if source_path.exists():
|
||||
dest_path = dataset_dir / "labeled.jsonl"
|
||||
seen_fens = set()
|
||||
unique_count = 0
|
||||
|
||||
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
|
||||
for line in src:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in seen_fens:
|
||||
dst.write(line)
|
||||
seen_fens.add(fen)
|
||||
unique_count += 1
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed lines
|
||||
pass
|
||||
|
||||
# Count positions
|
||||
total_positions = 0
|
||||
if (dataset_dir / "labeled.jsonl").exists():
|
||||
with open(dataset_dir / "labeled.jsonl", 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"version": version,
|
||||
"created": datetime.now().isoformat(),
|
||||
"total_positions": total_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"sources": sources
|
||||
}
|
||||
|
||||
save_dataset_metadata(version, metadata)
|
||||
return dataset_dir
|
||||
|
||||
|
||||
def extend_dataset(
|
||||
version: int,
|
||||
new_labeled_path: str,
|
||||
new_source_entry: Dict
|
||||
) -> bool:
|
||||
"""Extend an existing dataset with new labeled positions (with deduplication).
|
||||
|
||||
Args:
|
||||
version: Dataset version to extend
|
||||
new_labeled_path: Path to new labeled.jsonl to merge
|
||||
new_source_entry: Source entry to add to metadata
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
labeled_file = dataset_dir / "labeled.jsonl"
|
||||
new_labeled_file = Path(new_labeled_path)
|
||||
|
||||
if not new_labeled_file.exists():
|
||||
return False
|
||||
|
||||
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
|
||||
existing_fens = set()
|
||||
if labeled_file.exists():
|
||||
with open(labeled_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen:
|
||||
existing_fens.add(fen)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Merge new positions, skipping duplicates
|
||||
new_count = 0
|
||||
new_lines = []
|
||||
with open(new_labeled_file, 'r') as f_new:
|
||||
for line in f_new:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in existing_fens:
|
||||
new_lines.append(line)
|
||||
existing_fens.add(fen)
|
||||
new_count += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Append only the new, unique positions
|
||||
if new_lines:
|
||||
with open(labeled_file, 'a') as f_append:
|
||||
for line in new_lines:
|
||||
f_append.write(line)
|
||||
|
||||
# Update metadata
|
||||
metadata = load_dataset_metadata(version)
|
||||
if metadata:
|
||||
# Count total positions
|
||||
total_positions = 0
|
||||
with open(labeled_file, 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
metadata['total_positions'] = total_positions
|
||||
# Update the source entry with actual count of new positions added
|
||||
new_source_entry['actual_count'] = new_count
|
||||
metadata['sources'].append(new_source_entry)
|
||||
save_dataset_metadata(version, metadata)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_dataset_labeled_path(version: int) -> Optional[Path]:
|
||||
"""Get the path to a dataset's labeled.jsonl file.
|
||||
|
||||
Returns:
|
||||
Path to labeled.jsonl or None if dataset doesn't exist.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
|
||||
|
||||
if labeled_file.exists():
|
||||
return labeled_file
|
||||
return None
|
||||
|
||||
|
||||
def delete_dataset(version: int) -> bool:
|
||||
"""Delete a dataset (recursively removes directory).
|
||||
|
||||
Args:
|
||||
version: Dataset version to delete
|
||||
|
||||
Returns:
|
||||
True if successful.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
import shutil
|
||||
shutil.rmtree(dataset_dir)
|
||||
return True
|
||||
|
||||
|
||||
def show_datasets_table(console: Console = None) -> None:
|
||||
"""Display all datasets in a Rich table."""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
datasets = list_datasets()
|
||||
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Positions", justify="right")
|
||||
table.add_column("Sources", justify="left")
|
||||
table.add_column("Depth", justify="center")
|
||||
table.add_column("Created", justify="left")
|
||||
|
||||
for v, metadata in datasets:
|
||||
positions = metadata.get('total_positions', 0)
|
||||
sources = metadata.get('sources', [])
|
||||
source_str = ", ".join([s.get('type', '?') for s in sources])
|
||||
depth = metadata.get('stockfish_depth', '?')
|
||||
created = metadata.get('created', '?')
|
||||
if created != '?':
|
||||
created = created.split('T')[0] # Just the date
|
||||
|
||||
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
|
||||
|
||||
console.print(table)
|
||||
@@ -1,67 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to binary format for runtime loading."""
|
||||
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
def export_weights_to_binary(weights_file, output_file):
|
||||
"""Load PyTorch weights and export as binary file."""
|
||||
import torch
|
||||
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: Weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
|
||||
VERSION = 1
|
||||
|
||||
# Load weights — handle both raw state dicts and full training checkpoints
|
||||
loaded = torch.load(weights_file, map_location='cpu')
|
||||
state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded
|
||||
|
||||
# Debug: print available layers
|
||||
print(f"Available layers in {weights_file}:")
|
||||
for key in sorted(state_dict.keys()):
|
||||
print(f" {key}: {state_dict[key].shape}")
|
||||
def _read_sidecar(weights_file: str) -> dict:
|
||||
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||
if Path(sidecar).exists():
|
||||
with open(sidecar) as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
# Create output directory if needed
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, 'wb') as f:
|
||||
# Write magic number and version
|
||||
f.write(b'NNUE')
|
||||
f.write(struct.pack('<I', 1)) # version 1
|
||||
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||
"""Derive layer descriptors from state_dict weight shapes.
|
||||
|
||||
# Write each weight tensor in order
|
||||
for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias']:
|
||||
if layer_name not in state_dict:
|
||||
print(f"Error: Missing layer {layer_name}")
|
||||
sys.exit(1)
|
||||
Assumes layers named l1, l2, ..., lN.
|
||||
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||
"""
|
||||
names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
layers = []
|
||||
for i, name in enumerate(names):
|
||||
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||
activation = "linear" if i == len(names) - 1 else "relu"
|
||||
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||
return layers
|
||||
|
||||
tensor = state_dict[layer_name]
|
||||
# Convert to float32 and flatten
|
||||
|
||||
def _write_floats(f, tensor):
|
||||
data = tensor.float().flatten().cpu().numpy()
|
||||
f.write(struct.pack("<I", len(data)))
|
||||
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||
|
||||
# Write shape (allows validation on load)
|
||||
shape = list(tensor.shape)
|
||||
f.write(struct.pack('<I', len(shape)))
|
||||
for dim in shape:
|
||||
f.write(struct.pack('<I', dim))
|
||||
|
||||
# Write flattened data as binary floats
|
||||
f.write(struct.pack(f'<{len(data)}f', *data))
|
||||
def export_to_nbai(
|
||||
weights_file: str,
|
||||
output_file: str,
|
||||
trained_by: str = "unknown",
|
||||
train_loss: float = 0.0,
|
||||
):
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f" {layer_name}: shape {shape}, {len(data)} floats")
|
||||
loaded = torch.load(weights_file, map_location="cpu")
|
||||
state_dict = (
|
||||
loaded["model_state_dict"]
|
||||
if isinstance(loaded, dict) and "model_state_dict" in loaded
|
||||
else loaded
|
||||
)
|
||||
|
||||
sidecar = _read_sidecar(weights_file)
|
||||
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
|
||||
trained_at = sidecar.get("date", datetime.now().isoformat())
|
||||
training_data_count = int(sidecar.get("num_positions", 0))
|
||||
|
||||
metadata = {
|
||||
"trainedBy": trained_by,
|
||||
"trainedAt": trained_at,
|
||||
"trainingDataCount": training_data_count,
|
||||
"valLoss": val_loss,
|
||||
"trainLoss": train_loss,
|
||||
}
|
||||
|
||||
layers = _infer_layers(state_dict)
|
||||
layer_names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
|
||||
print(f"Architecture ({len(layers)} layers):")
|
||||
for i, l in enumerate(layers):
|
||||
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
|
||||
|
||||
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
# Header
|
||||
f.write(struct.pack("<I", MAGIC))
|
||||
f.write(struct.pack("<H", VERSION))
|
||||
|
||||
# Metadata (length-prefixed UTF-8 JSON)
|
||||
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
|
||||
f.write(struct.pack("<I", len(meta_bytes)))
|
||||
f.write(meta_bytes)
|
||||
|
||||
# Layer descriptors
|
||||
f.write(struct.pack("<H", len(layers)))
|
||||
for layer in layers:
|
||||
name_bytes = layer["activation"].encode("ascii")
|
||||
f.write(struct.pack("<B", len(name_bytes)))
|
||||
f.write(name_bytes)
|
||||
f.write(struct.pack("<I", layer["inputSize"]))
|
||||
f.write(struct.pack("<I", layer["outputSize"]))
|
||||
|
||||
# Weights: weight tensor then bias tensor per layer
|
||||
for name in layer_names:
|
||||
w = state_dict[f"{name}.weight"]
|
||||
b = state_dict[f"{name}.bias"]
|
||||
_write_floats(f, w)
|
||||
_write_floats(f, b)
|
||||
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
|
||||
|
||||
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
|
||||
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
file_size_mb = output_path.stat().st_size / (1024**2)
|
||||
print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights_file = "nnue_weights.pt"
|
||||
output_file = "../src/main/resources/nnue_weights.bin"
|
||||
output_file = "../src/main/resources/nnue_weights.nbai"
|
||||
trained_by = "unknown"
|
||||
train_loss = 0.0
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
weights_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
trained_by = sys.argv[3]
|
||||
if len(sys.argv) > 4:
|
||||
train_loss = float(sys.argv[4])
|
||||
|
||||
export_weights_to_binary(weights_file, output_file)
|
||||
export_to_nbai(weights_file, output_file, trained_by, train_loss)
|
||||
|
||||
@@ -125,6 +125,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
|
||||
# Load all FENs that need evaluation
|
||||
fens_to_evaluate = []
|
||||
fens_seen_in_batch = set() # Track duplicates within current batch
|
||||
skipped_invalid = 0
|
||||
skipped_duplicate = 0
|
||||
|
||||
@@ -140,7 +141,12 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
if fen in fens_seen_in_batch:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
fens_to_evaluate.append(fen)
|
||||
fens_seen_in_batch.add(fen)
|
||||
|
||||
total_to_evaluate = len(fens_to_evaluate)
|
||||
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
|
||||
@@ -178,8 +184,13 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
with open(output_file, 'a') as out:
|
||||
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
|
||||
for fen, eval_normalized, eval_cp in batch_results:
|
||||
# Skip if already evaluated in output file during this run
|
||||
if fen in evaluated_fens:
|
||||
continue
|
||||
|
||||
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||
out.write(json.dumps(data) + '\n')
|
||||
evaluated_fens.add(fen) # Track as evaluated
|
||||
evaluated += 1
|
||||
raw_evals.append(eval_cp)
|
||||
normalized_evals.append(eval_normalized)
|
||||
@@ -287,7 +298,7 @@ if __name__ == "__main__":
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||
parser.add_argument("--depth", type=int, default=12,
|
||||
help="Stockfish depth (default: 12)")
|
||||
parser.add_argument("--batch-size", type=int, default=20,
|
||||
parser.add_argument("--batch-size", type=int, default=1000,
|
||||
help="Batch size for processing (default: 1000)")
|
||||
parser.add_argument("--no-normalize", action="store_true",
|
||||
help="Disable evaluation normalization (keep raw centipawns)")
|
||||
|
||||
@@ -17,7 +17,7 @@ from generate import play_random_game_and_collect_positions
|
||||
|
||||
def download_and_extract_puzzle_db(
|
||||
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||
output_dir: str = 'trainingdata'
|
||||
output_dir: str = 'tactical_data'
|
||||
):
|
||||
"""Download and extract the Lichess puzzle database."""
|
||||
output_path = Path(output_dir)
|
||||
@@ -141,6 +141,31 @@ def merge_positions(
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def extract_tactical_only(
|
||||
puzzle_csv: str,
|
||||
output_file: str,
|
||||
max_puzzles: int = 300_000
|
||||
) -> int:
|
||||
"""Extract tactical positions and save to file (no merge prompts).
|
||||
|
||||
Args:
|
||||
puzzle_csv: Path to Lichess puzzle CSV
|
||||
output_file: Where to save the FEN positions
|
||||
max_puzzles: Maximum puzzles to extract
|
||||
|
||||
Returns:
|
||||
Number of positions extracted
|
||||
"""
|
||||
print("Extracting tactical positions from puzzle database...")
|
||||
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in tactical_positions:
|
||||
f.write(fen + '\n')
|
||||
|
||||
return len(tactical_positions)
|
||||
|
||||
|
||||
def interactive_merge_positions(
|
||||
puzzle_csv: str,
|
||||
output_file: str = 'position.txt',
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"version": 9,
|
||||
"date": "2026-04-13T20:19:08.123315",
|
||||
"num_positions": 2522562,
|
||||
"stockfish_depth": 12,
|
||||
"final_val_loss": 6.994176222619626e-05,
|
||||
"device": "cuda",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||
"mode": "burst",
|
||||
"duration_minutes": 30.0,
|
||||
"epochs_per_season": 50,
|
||||
"early_stopping_patience": 10,
|
||||
"seasons_completed": 3,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
|
||||
}
|
||||
BIN
Binary file not shown.
@@ -6,7 +6,7 @@ import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationNNUE extends Evaluation:
|
||||
|
||||
private val nnue = NNUE()
|
||||
private val nnue = NNUE(NbaiLoader.loadDefault())
|
||||
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
|
||||
@@ -3,84 +3,31 @@ package de.nowchess.bot.bots.nnue
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
class NNUE:
|
||||
class NNUE(model: NbaiModel):
|
||||
|
||||
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) =
|
||||
loadWeights()
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
|
||||
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||
// l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
|
||||
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||
private val l1WeightsT: Array[Float] =
|
||||
val t = new Array[Float](768 * 1536)
|
||||
for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j)
|
||||
val w = model.weights(0).weights
|
||||
val t = new Array[Float](featureSize * accSize)
|
||||
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
|
||||
t
|
||||
|
||||
private def loadWeights(): (
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
) =
|
||||
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||
.getOrElse(sys.error("NNUE weights file not found in resources"))
|
||||
|
||||
try
|
||||
val bytes = stream.readAllBytes()
|
||||
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
val magic = buffer.getInt()
|
||||
if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}")
|
||||
|
||||
val version = buffer.getInt()
|
||||
if version != 1 then sys.error(s"Unsupported weight version: $version")
|
||||
|
||||
val l1w = readTensor(buffer)
|
||||
val l1b = readTensor(buffer)
|
||||
val l2w = readTensor(buffer)
|
||||
val l2b = readTensor(buffer)
|
||||
val l3w = readTensor(buffer)
|
||||
val l3b = readTensor(buffer)
|
||||
val l4w = readTensor(buffer)
|
||||
val l4b = readTensor(buffer)
|
||||
val l5w = readTensor(buffer)
|
||||
val l5b = readTensor(buffer)
|
||||
|
||||
(l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
|
||||
finally stream.close()
|
||||
|
||||
private def readTensor(buffer: ByteBuffer): Array[Float] =
|
||||
val shapeLen = buffer.getInt()
|
||||
val shape = Array.ofDim[Int](shapeLen)
|
||||
for i <- 0 until shapeLen do shape(i) = buffer.getInt()
|
||||
val totalElements = shape.product
|
||||
val floats = Array.ofDim[Float](totalElements)
|
||||
for i <- 0 until totalElements do floats(i) = buffer.getFloat()
|
||||
floats
|
||||
|
||||
// ── Accumulator stack ────────────────────────────────────────────────────
|
||||
// l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply.
|
||||
// Initialised once at root; each child ply is derived incrementally.
|
||||
|
||||
private val MAX_PLY = 128
|
||||
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536))
|
||||
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
|
||||
|
||||
// Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant).
|
||||
private val l1ReLU = new Array[Float](1536)
|
||||
private val l2Output = new Array[Float](1024)
|
||||
private val l3Output = new Array[Float](512)
|
||||
private val l4Output = new Array[Float](256)
|
||||
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
|
||||
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
|
||||
|
||||
// ── Eval cache ───────────────────────────────────────────────────────────
|
||||
private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB
|
||||
|
||||
private val EVAL_CACHE_MASK = (1 << 18) - 1L
|
||||
private val evalCacheHashes = new Array[Long](1 << 18)
|
||||
private val evalCacheScores = new Array[Int](1 << 18)
|
||||
|
||||
@@ -93,25 +40,23 @@ class NNUE:
|
||||
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
|
||||
|
||||
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * 1536
|
||||
for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i)
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
|
||||
|
||||
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * 1536
|
||||
for i <- 0 until 1536 do l1Pre(i) -= l1WeightsT(offset + i)
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
|
||||
|
||||
// ── Accumulator init ─────────────────────────────────────────────────────
|
||||
|
||||
/** Initialise l1Stack(0) from scratch using sparse active features. */
|
||||
def initAccumulator(board: Board): Unit =
|
||||
System.arraycopy(l1Bias, 0, l1Stack(0), 0, 1536)
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// ── Accumulator push (incremental updates) ───────────────────────────────
|
||||
|
||||
/** Copy parent ply's pre-activations to childPly, then apply move deltas. */
|
||||
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, 1536)
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
|
||||
val l1 = l1Stack(childPly)
|
||||
move.moveType match
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
@@ -119,9 +64,8 @@ class NNUE:
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
|
||||
/** Copy pre-activations from parentPly to childPly without any move delta (null-move). */
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, 1536)
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { mover =>
|
||||
@@ -170,9 +114,6 @@ class NNUE:
|
||||
|
||||
// ── Evaluation from accumulator ──────────────────────────────────────────
|
||||
|
||||
/** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after
|
||||
* computation.
|
||||
*/
|
||||
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
|
||||
val idx = (hash & EVAL_CACHE_MASK).toInt
|
||||
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
|
||||
@@ -183,11 +124,19 @@ class NNUE:
|
||||
score
|
||||
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
for i <- 0 until 1536 do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
runDenseReLU(l1ReLU, 1536, l2Weights, l2Bias, l2Output, 1024)
|
||||
runDenseReLU(l2Output, 1024, l3Weights, l3Bias, l3Output, 512)
|
||||
runDenseReLU(l3Output, 512, l4Weights, l4Bias, l4Output, 256)
|
||||
val output = runOutputLayer(l4Output, 256)
|
||||
val l1ReLU = evalBuffers(0)
|
||||
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
|
||||
var input = l1ReLU
|
||||
for i <- 1 until model.layers.length - 1 do
|
||||
val lw = model.weights(i)
|
||||
val out = evalBuffers(i)
|
||||
val ld = model.layers(i)
|
||||
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||
input = out
|
||||
|
||||
val lastIdx = model.layers.length - 1
|
||||
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||
scoreFromOutput(output, turn)
|
||||
|
||||
private def runDenseReLU(
|
||||
@@ -202,8 +151,8 @@ class NNUE:
|
||||
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
|
||||
output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
private def runOutputLayer(input: Array[Float], inSize: Int): Float =
|
||||
(0 until inSize).foldLeft(l5Bias(0))((sum, j) => sum + input(j) * l5Weights(j))
|
||||
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
|
||||
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
|
||||
|
||||
private def scoreFromOutput(output: Float, turn: Color): Int =
|
||||
val cp =
|
||||
@@ -214,21 +163,15 @@ class NNUE:
|
||||
val cpFromTurn = if turn == Color.Black then -cp else cp
|
||||
math.max(-20000, math.min(20000, cpFromTurn))
|
||||
|
||||
// ── Legacy full-board evaluate (kept for Evaluation.evaluate compatibility) ──
|
||||
// ── Legacy full-board evaluate ────────────────────────────────────────────
|
||||
|
||||
// Pre-allocated buffers used only by the legacy evaluate path.
|
||||
private val features = new Array[Float](768)
|
||||
private val legacyL1 = new Array[Float](1536)
|
||||
private val legacyL1 = new Array[Float](accSize)
|
||||
|
||||
/** Evaluate using full board scan (sparse over active features). Layout: black pieces at indices 0-5, white at 6-11.
|
||||
*/
|
||||
def evaluate(context: GameContext): Int =
|
||||
val l1Pre = legacyL1
|
||||
System.arraycopy(l1Bias, 0, l1Pre, 0, 1536)
|
||||
for (sq, piece) <- context.board.pieces do addColumn(l1Pre, featureIndex(piece, squareNum(sq)))
|
||||
runL2toOutput(l1Pre, context.turn)
|
||||
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
|
||||
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
|
||||
runL2toOutput(legacyL1, context.turn)
|
||||
|
||||
/** Benchmark: time 1M evaluations and report ns/eval. */
|
||||
def benchmark(): Unit =
|
||||
val context = GameContext.initial
|
||||
val iterations = 1_000_000
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.InputStream
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiLoader:
|
||||
|
||||
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
|
||||
val MAGIC: Int = 0x4942_414e
|
||||
|
||||
def load(stream: InputStream): NbaiModel =
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkHeader(buf)
|
||||
val metadata = readMetadata(buf)
|
||||
val descs = readLayerDescriptors(buf)
|
||||
val weights = descs.map(_ => readLayerWeights(buf))
|
||||
NbaiModel(metadata, descs, weights)
|
||||
|
||||
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
|
||||
def loadDefault(): NbaiModel =
|
||||
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
|
||||
case Some(s) => try load(s) finally s.close()
|
||||
case None => NbaiMigrator.migrateFromBin()
|
||||
|
||||
private def checkHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
|
||||
val version = buf.getShort() & 0xffff
|
||||
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
|
||||
|
||||
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
|
||||
val bytes = new Array[Byte](buf.getInt())
|
||||
buf.get(bytes)
|
||||
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
|
||||
|
||||
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
|
||||
Array.tabulate(buf.getShort() & 0xffff) { _ =>
|
||||
val nameBytes = new Array[Byte](buf.get() & 0xff)
|
||||
buf.get(nameBytes)
|
||||
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
|
||||
}
|
||||
|
||||
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readFloats(buf), readFloats(buf))
|
||||
|
||||
private def readFloats(buf: ByteBuffer): Array[Float] =
|
||||
val arr = new Array[Float](buf.getInt())
|
||||
for i <- arr.indices do arr(i) = buf.getFloat()
|
||||
arr
|
||||
@@ -0,0 +1,43 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
|
||||
object NbaiMigrator:
|
||||
|
||||
private val BinMagic = 0x4555_4e4e
|
||||
private val BinVersion = 1
|
||||
|
||||
private val DefaultLayers: Array[LayerDescriptor] = Array(
|
||||
LayerDescriptor("relu", 768, 1536),
|
||||
LayerDescriptor("relu", 1536, 1024),
|
||||
LayerDescriptor("relu", 1024, 512),
|
||||
LayerDescriptor("relu", 512, 256),
|
||||
LayerDescriptor("linear", 256, 1),
|
||||
)
|
||||
|
||||
private val UnknownMetadata: NbaiMetadata =
|
||||
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
|
||||
|
||||
def migrateFromBin(): NbaiModel =
|
||||
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
|
||||
try
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkBinHeader(buf)
|
||||
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
|
||||
NbaiModel(UnknownMetadata, DefaultLayers, weights)
|
||||
finally stream.close()
|
||||
|
||||
private def checkBinHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
|
||||
val version = buf.getInt()
|
||||
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
|
||||
|
||||
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readBinTensor(buf), readBinTensor(buf))
|
||||
|
||||
private def readBinTensor(buf: ByteBuffer): Array[Float] =
|
||||
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
|
||||
Array.tabulate(shape.product)(_ => buf.getFloat())
|
||||
@@ -0,0 +1,39 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
/** Descriptor for a single dense layer stored in a .nbai file. */
|
||||
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
|
||||
|
||||
/** Training metadata embedded in every .nbai file. */
|
||||
case class NbaiMetadata(
|
||||
trainedBy: String,
|
||||
trainedAt: String,
|
||||
trainingDataCount: Long,
|
||||
valLoss: Double,
|
||||
trainLoss: Double,
|
||||
):
|
||||
def toJson: String =
|
||||
s"""{
|
||||
| "trainedBy": "$trainedBy",
|
||||
| "trainedAt": "$trainedAt",
|
||||
| "trainingDataCount": $trainingDataCount,
|
||||
| "valLoss": $valLoss,
|
||||
| "trainLoss": $trainLoss
|
||||
|}""".stripMargin
|
||||
|
||||
object NbaiMetadata:
|
||||
def fromJson(json: String): NbaiMetadata =
|
||||
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
|
||||
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
|
||||
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
|
||||
|
||||
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
|
||||
case class LayerWeights(weights: Array[Float], bias: Array[Float])
|
||||
|
||||
/** A fully deserialized .nbai model ready to initialize NNUE. */
|
||||
case class NbaiModel(
|
||||
metadata: NbaiMetadata,
|
||||
layers: Array[LayerDescriptor],
|
||||
weights: Array[LayerWeights],
|
||||
):
|
||||
require(layers.length == weights.length, "Layer count must match weight count")
|
||||
require(layers.length >= 2, "Model must have at least 2 layers")
|
||||
@@ -0,0 +1,51 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.{ByteArrayOutputStream, OutputStream}
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiWriter:
|
||||
|
||||
def write(model: NbaiModel, out: OutputStream): Unit =
|
||||
val acc = new ByteArrayOutputStream()
|
||||
writeHeader(acc)
|
||||
writeMetadata(acc, model.metadata)
|
||||
writeLayerDescriptors(acc, model.layers)
|
||||
model.weights.foreach(lw => writeLayerWeights(acc, lw))
|
||||
out.write(acc.toByteArray)
|
||||
|
||||
private def writeHeader(out: ByteArrayOutputStream): Unit =
|
||||
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(NbaiLoader.MAGIC)
|
||||
buf.putShort(1.toShort)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
|
||||
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
|
||||
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(json.length)
|
||||
buf.put(json)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
|
||||
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
|
||||
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
|
||||
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putShort(layers.length.toShort)
|
||||
layers.zip(nameBytes).foreach { (l, nb) =>
|
||||
buf.put(nb.length.toByte)
|
||||
buf.put(nb)
|
||||
buf.putInt(l.inputSize)
|
||||
buf.putInt(l.outputSize)
|
||||
}
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
|
||||
writeFloats(out, lw.weights)
|
||||
writeFloats(out, lw.bias)
|
||||
|
||||
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
|
||||
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(floats.length)
|
||||
floats.foreach(buf.putFloat)
|
||||
out.write(buf.array())
|
||||
Reference in New Issue
Block a user