feat: Implement dataset versioning and management for NNUE training data

2026-04-13 21:19:26 +02:00
parent 4b52199754
commit 8fb872e958
18 changed files with 1399 additions and 335 deletions
+1
@@ -19,3 +19,4 @@ ENV/
*.swo
tactical_data/
trainingdata/
/datasets/
+173
@@ -0,0 +1,173 @@
# Training Dataset Management
The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.
## Directory Structure
```
datasets/
  ds_v1/
    labeled.jsonl   # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
    metadata.json   # Version info and composition
  ds_v2/
    labeled.jsonl
    metadata.json
```
## Metadata Schema
Each dataset has a `metadata.json` file tracking its composition:
```json
{
  "version": 1,
  "created": "2026-04-13T15:30:45.123456",
  "total_positions": 1000000,
  "stockfish_depth": 12,
  "sources": [
    {
      "type": "generated",
      "count": 500000,
      "params": {
        "num_positions": 500000,
        "min_move": 1,
        "max_move": 50
      }
    },
    {
      "type": "tactical",
      "count": 300000,
      "max_puzzles": 300000
    },
    {
      "type": "file_import",
      "count": 200000,
      "path": "/path/to/original_file.txt"
    }
  ]
}
```
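For scripted access, the same metadata can be read through the new `dataset` module (a minimal sketch; it assumes `src/` is importable, which the TUI arranges via `sys.path`):
```python
import sys
from pathlib import Path

# The TUI performs the same path setup before importing from src/.
sys.path.insert(0, str(Path(__file__).parent / "src"))

from dataset import list_datasets, load_dataset_metadata

# One summary line per dataset version.
for version, meta in list_datasets():
    srcs = ", ".join(s.get("type", "?") for s in meta.get("sources", []))
    print(f"ds_v{version}: {meta['total_positions']:,} positions "
          f"(depth {meta['stockfish_depth']}; sources: {srcs})")

# Fetch a single version; returns None if it does not exist.
meta = load_dataset_metadata(1)
```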
## TUI Workflow
### Main Menu
```
1 - Manage Training Data
2 - Train Model
3 - Export Model
4 - Exit
```
### Training Data Management Submenu
```
1 - Create new dataset
2 - Extend existing dataset
3 - View all datasets
4 - Delete dataset
5 - Back
```
## Creating a Dataset
Use the "Create new dataset" option to add data from one or more sources:
1. **Generate random positions** — Play random games and sample positions
   - Number of positions
   - Move range (min/max move number to sample from)
   - Number of worker threads
2. **Import from file** — Load positions from a FEN file
   - File must contain one FEN string per line
   - Duplicates are automatically removed
3. **Extract tactical puzzles** — Download and extract the Lichess puzzle database
   - Maximum number of puzzles to include
   - Automatically filters for tactical themes (forks, pins, mates, etc.)
You can combine multiple sources in a single dataset creation session. All positions are:
- Deduplicated (only unique FENs are kept)
- Labeled with Stockfish evaluations
- Saved to `datasets/ds_vN/labeled.jsonl`
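The TUI drives exactly this flow; outside the TUI it can be scripted with the underlying helpers (a sketch — the Stockfish path, worker counts, and temp locations are illustrative):
```python
import tempfile
from pathlib import Path

from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from dataset import next_dataset_version, create_dataset

tmp = Path(tempfile.gettempdir())
fens = tmp / "positions.txt"
labeled = tmp / "labeled.jsonl"

# 1. Generate raw FENs, one per line.
count = play_random_game_and_collect_positions(
    str(fens),
    total_positions=100_000,
    samples_per_game=1,
    min_move=1,
    max_move=50,
    num_workers=8,
)

# 2. Label them with Stockfish (writes JSONL with fen/eval/eval_raw).
success = label_positions_with_stockfish(
    str(fens), str(labeled), "/usr/bin/stockfish", depth=12, num_workers=4
)

# 3. Register the result as the next dataset version. create_dataset
#    copies the file into datasets/ds_vN/ and deduplicates by FEN.
if success:
    version = next_dataset_version()
    create_dataset(
        version=version,
        labeled_jsonl_path=str(labeled),
        sources=[{
            "type": "generated",
            "count": count,
            "params": {"num_positions": 100_000, "min_move": 1, "max_move": 50},
        }],
        stockfish_depth=12,
    )
```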
## Extending a Dataset
Use "Extend existing dataset" to add more positions to an existing dataset:
1. Select the dataset version to extend
2. Choose data sources (same options as creation)
3. Confirm labeling parameters
4. New positions are:
   - Labeled with Stockfish
   - Deduplicated against the target dataset, so only genuinely new positions are merged
   - Appended to the existing `labeled.jsonl`
   - Recorded in the metadata as a new source entry
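Programmatically, the extension step maps onto `extend_dataset` (a sketch; the path and count are placeholders):
```python
from dataset import extend_dataset, load_dataset_metadata

# Merge a freshly labeled JSONL into ds_v1. Positions whose FEN already
# exists in the dataset are skipped; metadata gains a new source entry.
ok = extend_dataset(
    version=1,
    new_labeled_path="/tmp/new_labeled.jsonl",
    new_source_entry={
        "type": "file_import",
        "count": 12_345,
        "path": "/tmp/new_labeled.jsonl",
    },
)
if ok:
    print(load_dataset_metadata(1)["total_positions"])
```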
## Training with a Dataset
When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
- Version number
- Total number of positions
- Source types (generated, tactical, imported)
- Stockfish depth used
- Creation date
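Under the hood, the selected version is resolved to the dataset's `labeled.jsonl` via `get_dataset_labeled_path` before training starts (a sketch; `train_nnue` accepts further parameters not shown here, and the output path is illustrative):
```python
from dataset import get_dataset_labeled_path
from train import train_nnue

labeled = get_dataset_labeled_path(2)  # Path, or None if ds_v2 is missing
if labeled is not None:
    train_nnue(
        data_file=str(labeled),
        output_file="weights/nnue_weights.pt",
        epochs=100,
        batch_size=16384,
    )
```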
## Legacy Data Migration
If you have existing labeled data in `data/training_data.jsonl` from before this update:
1. Open the "Manage Training Data" menu
2. Choose "Create new dataset"
3. Select "Import from file"
4. Point to `data/training_data.jsonl`
5. Complete the dataset creation
Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
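The manual route looks roughly like this (a sketch; it assumes the legacy file was labeled at depth 12 — record whatever settings actually produced it). Note that a manual copy skips the deduplication the "Create new dataset" flow performs:
```python
import json
import shutil
from datetime import datetime
from pathlib import Path

src = Path("data/training_data.jsonl")
ds_dir = Path("datasets/ds_v1")
ds_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(src, ds_dir / "labeled.jsonl")

# Count positions; depth 12 is a guess -- use the depth the legacy
# file was actually labeled at.
with open(ds_dir / "labeled.jsonl") as f:
    total = sum(1 for _ in f)

metadata = {
    "version": 1,
    "created": datetime.now().isoformat(),
    "total_positions": total,
    "stockfish_depth": 12,
    "sources": [{"type": "file_import", "count": total, "path": str(src)}],
}
(ds_dir / "metadata.json").write_text(json.dumps(metadata, indent=2))
```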
## Viewing Dataset Details
Use "View all datasets" to see a table of all datasets with:
- Version number
- Position count
- Source composition
- Stockfish depth
- Creation date
## Deleting a Dataset
Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**
⚠️ The system does not stop you from deleting a dataset that existing model checkpoints were trained on. Plan accordingly.
## Technical Details
### Deduplication Strategy
When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.
When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
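Concretely, the dedup key is the raw FEN string; this sketch mirrors the logic in `dataset.py`:
```python
import json

def unique_fens(jsonl_path: str) -> set[str]:
    """Collect the FEN keys a dataset already contains."""
    fens = set()
    with open(jsonl_path) as f:
        for line in f:
            try:
                fen = json.loads(line).get("fen")
            except json.JSONDecodeError:
                continue  # skip malformed lines, as the real code does
            if fen:
                fens.add(fen)
    return fens

# Extending ds_v2 never consults ds_v1's FEN set, so two datasets may
# intentionally overlap.
```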
### Labeled Position Format
Each line in `labeled.jsonl` is a JSON object:
```json
{
  "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
  "eval": 0.0,
  "eval_raw": 0
}
```
- `fen`: The position in Forsyth-Edwards Notation
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
- `eval_raw`: Raw Stockfish evaluation in centipawns
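The normalization is a tanh squash of the centipawn score (a sketch; `SCALE` here is a hypothetical constant — the real value lives in the labeling code):
```python
import math

SCALE = 400.0  # hypothetical; see label.py for the actual scaling constant

def normalize_eval(eval_cp: int) -> float:
    """Map a raw centipawn score into the [-1, 1] range stored as 'eval'."""
    return math.tanh(eval_cp / SCALE)
```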
### Storage Location
Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.
## Performance Tips
- **Smaller datasets train faster** — Start with 100k-500k positions
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
- **Workers** — Use 4-8 labeling workers if your machine supports it; more workers finish faster but consume more CPU and memory
+547 -193
@@ -4,6 +4,7 @@
import os
import shutil
import sys
import tempfile
from pathlib import Path
from rich.console import Console
from rich.table import Table
@@ -17,23 +18,23 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue, burst_train
from export import export_weights_to_binary
from export import export_to_nbai
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
interactive_merge_positions
extract_tactical_only
)
from dataset import (
get_datasets_dir,
list_datasets,
next_dataset_version,
load_dataset_metadata,
create_dataset,
extend_dataset,
get_dataset_labeled_path,
delete_dataset,
show_datasets_table
)
def get_data_dir():
"""Get/create data directory."""
data_dir = Path(__file__).parent / "data"
data_dir.mkdir(exist_ok=True)
return data_dir
def get_tactical_data_dir():
"""Get/create data directory."""
data_dir = Path(__file__).parent / "tactical_data"
data_dir.mkdir(exist_ok=True)
return data_dir
def get_weights_dir():
"""Get/create weights directory."""
@@ -41,6 +42,14 @@ def get_weights_dir():
weights_dir.mkdir(exist_ok=True)
return weights_dir
def get_data_dir():
"""Get/create legacy data directory (for migration)."""
data_dir = Path(__file__).parent / "data"
data_dir.mkdir(exist_ok=True)
return data_dir
def list_checkpoints():
"""List available checkpoint versions."""
weights_dir = get_weights_dir()
@@ -49,6 +58,20 @@ def list_checkpoints():
return []
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
def migrate_legacy_data():
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
console = Console()
data_dir = get_data_dir()
legacy_file = data_dir / "training_data.jsonl"
datasets = list_datasets()
# Only migrate if legacy data exists and no datasets exist yet
if legacy_file.exists() and not datasets:
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
def show_header():
"""Display application header."""
console = Console()
@@ -56,22 +79,23 @@ def show_header():
console.print(
Panel(
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
"[dim]Neural Network Utility Evaluation - Model Management[/dim]",
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
border_style="cyan",
padding=(1, 2),
)
)
def show_checkpoints_table():
"""Display available checkpoints in a table."""
console = Console()
available = list_checkpoints()
if not available:
console.print("[yellow] No checkpoints found yet[/yellow]")
console.print("[yellow] No model checkpoints found yet[/yellow]")
return
table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("File Size", justify="right")
table.add_column("Status", justify="center")
@@ -87,46 +111,500 @@ def show_checkpoints_table():
console.print(table)
def show_main_menu():
"""Display and handle main menu."""
console = Console()
# Migrate legacy data on first run
migrate_legacy_data()
while True:
show_header()
show_checkpoints_table()
console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
console.print("[cyan]3[/cyan] - Export Weights to Scala")
console.print("[cyan]4[/cyan] - Extract Tactical Positions")
console.print("[cyan]5[/cyan] - View Checkpoints")
console.print("[cyan]6[/cyan] - Exit")
console.print("[cyan]1[/cyan] - Manage Training Data")
console.print("[cyan]2[/cyan] - Train Model")
console.print("[cyan]3[/cyan] - Export Model")
console.print("[cyan]4[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
if choice == "1":
datasets_menu()
elif choice == "2":
training_menu()
elif choice == "3":
export_interactive()
elif choice == "4":
console.print("[yellow]👋 Goodbye![/yellow]")
return
def datasets_menu():
"""Dataset management submenu."""
console = Console()
while True:
show_header()
show_datasets_table(console)
console.print("\n[bold]Training Data Management[/bold]")
console.print("[cyan]1[/cyan] - Create new dataset")
console.print("[cyan]2[/cyan] - Extend existing dataset")
console.print("[cyan]3[/cyan] - View all datasets")
console.print("[cyan]4[/cyan] - Delete dataset")
console.print("[cyan]5[/cyan] - Back")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
if choice == "1":
create_dataset_interactive()
elif choice == "2":
extend_dataset_interactive()
elif choice == "3":
show_header()
show_datasets_table(console)
Prompt.ask("\nPress Enter to continue")
elif choice == "4":
delete_dataset_interactive()
elif choice == "5":
return
def create_dataset_interactive():
"""Interactive dataset creation flow."""
console = Console()
show_header()
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
sources = []
combined_count = 0
# Allow user to add multiple sources
while True:
console.print("\n[bold]Add data source (repeat until done):[/bold]")
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
num_workers = int(Prompt.ask("Number of workers", default="8"))
console.print("[dim]Generating positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
count = play_random_game_and_collect_positions(
str(temp_file),
total_positions=num_positions,
samples_per_game=1,
min_move=min_move,
max_move=max_move,
num_workers=num_workers
)
if count > 0:
sources.append({
"type": "generated",
"count": count,
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
})
combined_count += count
console.print(f"[green]✓ {count:,} positions generated[/green]")
else:
console.print("[red]✗ Generation failed[/red]")
elif choice == "b":
file_path = Prompt.ask("Path to FEN file")
try:
with open(file_path, 'r') as f:
count = sum(1 for _ in f)
sources.append({"type": "file_import", "count": count, "path": file_path})
combined_count += count
console.print(f"[green]✓ {count:,} positions from file[/green]")
except FileNotFoundError:
console.print(f"[red]✗ File not found: {file_path}[/red]")
elif choice == "c":
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
console.print("[dim]Extracting tactical positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
try:
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
if csv_path:
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
combined_count += count
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
except Exception as e:
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
elif choice == "d":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
break
if not sources:
console.print("[yellow]Dataset creation cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Dataset Summary:[/bold]")
console.print(f" Total positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# Save dataset
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
version = next_dataset_version()
create_dataset(
version=version,
labeled_jsonl_path=str(labeled_file),
sources=sources,
stockfish_depth=stockfish_depth
)
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
Prompt.ask("\nPress Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def extend_dataset_interactive():
"""Interactive dataset extension flow."""
console = Console()
show_header()
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets available to extend[/yellow]")
Prompt.ask("Press Enter to continue")
return
show_datasets_table(console)
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
if not any(v == version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
sources = []
combined_count = 0
# Allow user to add sources
while True:
console.print("\n[bold]Add data source:[/bold]")
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
num_workers = int(Prompt.ask("Number of workers", default="8"))
console.print("[dim]Generating positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
count = play_random_game_and_collect_positions(
str(temp_file),
total_positions=num_positions,
samples_per_game=1,
min_move=min_move,
max_move=max_move,
num_workers=num_workers
)
if count > 0:
sources.append({
"type": "generated",
"count": count,
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
})
combined_count += count
console.print(f"[green]✓ {count:,} positions generated[/green]")
elif choice == "b":
file_path = Prompt.ask("Path to FEN file")
try:
with open(file_path, 'r') as f:
count = sum(1 for _ in f)
sources.append({"type": "file_import", "count": count, "path": file_path})
combined_count += count
console.print(f"[green]✓ {count:,} positions from file[/green]")
except FileNotFoundError:
console.print(f"[red]✗ File not found: {file_path}[/red]")
elif choice == "c":
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
console.print("[dim]Extracting tactical positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
try:
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
if csv_path:
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
combined_count += count
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
except Exception as e:
console.print(f"[red]✗ Extraction failed: {e}[/red]")
elif choice == "d":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
break
if not sources:
console.print("[yellow]Extension cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Extension Summary:[/bold]")
console.print(f" Target dataset: ds_v{version}")
console.print(f" New positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# Extend dataset
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
success = extend_dataset(
version=version,
new_labeled_path=str(labeled_file),
new_source_entry={
"type": "merged_sources",
"count": len(all_fens),
"sources": sources
}
)
if success:
metadata = load_dataset_metadata(version)
console.print(f"[green]✓ Dataset extended[/green]")
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
else:
console.print("[red]✗ Extension failed[/red]")
Prompt.ask("\nPress Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def delete_dataset_interactive():
"""Interactive dataset deletion."""
console = Console()
show_header()
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets to delete[/yellow]")
Prompt.ask("Press Enter to continue")
return
show_datasets_table(console)
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
if not any(v == version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
if delete_dataset(version):
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
else:
console.print("[red]✗ Deletion failed[/red]")
Prompt.ask("Press Enter to continue")
def training_menu():
"""Training submenu."""
console = Console()
while True:
show_header()
console.print("\n[bold]Training[/bold]")
console.print("[cyan]1[/cyan] - Standard Training")
console.print("[cyan]2[/cyan] - Burst Training")
console.print("[cyan]3[/cyan] - View Model Checkpoints")
console.print("[cyan]4[/cyan] - Back")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
if choice == "1":
train_interactive()
elif choice == "2":
burst_train_interactive()
elif choice == "3":
export_interactive()
elif choice == "4":
extract_tactical_interactive()
elif choice == "5":
show_header()
show_checkpoints_table()
Prompt.ask("\nPress Enter to continue")
elif choice == "6":
console.print("[yellow]👋 Goodbye![/yellow]")
elif choice == "4":
return
def train_interactive():
"""Interactive training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]📚 Training Configuration[/bold cyan]")
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
# Checkpoint selection
available = list_checkpoints()
@@ -142,36 +620,6 @@ def train_interactive():
default=str(max(available))
)
# Positions source
use_existing = Confirm.ask("Use existing positions file?", default=False)
positions_file = None
num_games = 500000
samples_per_game = 1
min_move = 1
max_move = 50
if use_existing:
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
else:
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
labels_file = None
if use_existing_labels:
labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
# Stockfish path and labeling parameters
default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
stockfish_depth = 12
num_workers = 1
if not use_existing_labels:
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
@@ -182,16 +630,7 @@ def train_interactive():
# Confirm and start
console.print("\n[bold]Configuration Summary:[/bold]")
if use_checkpoint:
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(" Checkpoint: None (training from scratch)")
if not use_existing:
console.print(f" Games: {num_games:,}")
console.print(f" Samples per game: {samples_per_game}")
console.print(f" Move range: {min_move}-{max_move}")
else:
console.print(f" Positions file: {positions_file}")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
@@ -199,70 +638,27 @@ def train_interactive():
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
console.print(f" Early stopping: No")
if not use_existing_labels:
console.print(f" Stockfish depth: {stockfish_depth}")
console.print(f" Workers: {num_workers}")
console.print(f" Stockfish: {stockfish_path}")
if use_checkpoint:
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(f" Checkpoint: None (training from scratch)")
if not Confirm.ask("\nStart training?", default=True):
console.print("[yellow]Training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
# Execute training
data_dir = get_data_dir()
weights_dir = get_weights_dir()
try:
# Generate positions
if not use_existing:
console.print("\n[bold cyan]Step 1: Generating Positions[/bold cyan]")
count = play_random_game_and_collect_positions(
str(data_dir / "positions.txt"),
total_games=num_games,
samples_per_game=samples_per_game,
min_move=min_move,
max_move=max_move
)
if count == 0:
console.print("[red]✗ No valid positions generated[/red]")
return
console.print(f"[green]✓ Generated {count:,} positions[/green]")
else:
if not Path(positions_file).exists():
console.print(f"[red]✗ Positions file not found: {positions_file}[/red]")
return
if not use_existing_labels:
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
positions_file = data_dir / "positions.txt"
output_file = data_dir / "training_data.jsonl"
success = label_positions_with_stockfish(
str(positions_file),
str(output_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
if not success:
console.print("[red]✗ Position labeling failed[/red]")
return
console.print(f"[green]✓ Positions labeled[/green]")
else:
console.print("\n[bold cyan]Step 2: Loading Existing Labels[/bold cyan]")
output_file = labels_file
if not Path(output_file).exists():
console.print(f"[red]✗ Labels file not found: {output_file}[/red]")
return
# Train model
console.print("\n[bold cyan]Step 3: Training Model[/bold cyan]")
console.print("\n[bold cyan]Training Model[/bold cyan]")
checkpoint = None
if use_checkpoint:
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
train_nnue(
data_file=str(output_file),
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
epochs=epochs,
batch_size=batch_size,
@@ -286,6 +682,7 @@ def train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def burst_train_interactive():
"""Interactive burst training menu."""
console = Console()
@@ -294,14 +691,30 @@ def burst_train_interactive():
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
# Data file
default_labels = str(get_data_dir() / "training_data.jsonl")
labels_file = Prompt.ask("Path to labeled data file (.jsonl)", default=default_labels)
# Optional initial checkpoint
available = list_checkpoints()
checkpoint = None
@@ -317,29 +730,25 @@ def burst_train_interactive():
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
console.print(f" Data file: {labels_file}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
if not Confirm.ask("\nStart burst training?", default=True):
console.print("[yellow]Burst training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
weights_dir = get_weights_dir()
try:
if not Path(labels_file).exists():
console.print(f"[red]✗ Data file not found: {labels_file}[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold cyan]Burst Training[/bold cyan]")
burst_train(
data_file=labels_file,
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
duration_minutes=duration_minutes,
epochs_per_season=epochs_per_season,
@@ -362,6 +771,7 @@ def burst_train_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def export_interactive():
"""Interactive export menu."""
console = Console()
@@ -380,7 +790,7 @@ def export_interactive():
version = Prompt.ask("Enter version to export (e.g., 2)")
weights_file = f"nnue_weights_v{version}.pt"
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
console.print(f"\n[bold]Export Configuration:[/bold]")
console.print(f" Source: {weights_file}")
@@ -399,7 +809,7 @@ def export_interactive():
return
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
export_weights_to_binary(str(weights_path), output_file)
export_to_nbai(str(weights_path), output_file)
console.print(f"\n[green]✓ Export complete![/green]")
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
Prompt.ask("Press Enter to continue")
@@ -410,65 +820,6 @@ def export_interactive():
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def extract_tactical_interactive():
"""Interactive tactical positions extraction and merge menu."""
console = Console()
show_header()
console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
# Download and extract options
console.print("\n[bold]Lichess Puzzle Database:[/bold]")
download_url = Prompt.ask(
"Download URL",
default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
)
output_dir = Prompt.ask(
"Extract to directory",
default=str(Path(__file__).parent / "trainingdata")
)
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
# Confirm and download
console.print("\n[bold]Configuration:[/bold]")
console.print(f" Download URL: {download_url}")
console.print(f" Extract directory: {output_dir}")
console.print(f" Max puzzles: {max_puzzles:,}")
if not Confirm.ask("\nProceed?", default=True):
console.print("[yellow]Cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
try:
console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
csv_path = download_and_extract_puzzle_db(download_url, output_dir)
if not csv_path:
console.print("[red]✗ Failed to download/extract[/red]")
Prompt.ask("Press Enter to continue")
return
console.print(f"[green]✓ Ready: {csv_path}[/green]")
# Interactive merge
console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
output_file = Prompt.ask(
"Output file path",
default=str(Path(__file__).parent / "data" / "position.txt")
)
interactive_merge_positions(csv_path, output_file, max_puzzles)
console.print(f"\n[green]✓ Complete![/green]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def main():
try:
@@ -481,7 +832,10 @@ def main():
except Exception as e:
console = Console()
console.print(f"[red]Error:[/red] {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(main())
+287
@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""Dataset versioning and management for NNUE training data."""
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
from rich.console import Console
from rich.table import Table
def get_datasets_dir() -> Path:
"""Get/create datasets directory."""
datasets_dir = Path(__file__).parent.parent / "datasets"
datasets_dir.mkdir(exist_ok=True)
return datasets_dir
def next_dataset_version() -> int:
"""Find the next available dataset version number."""
datasets_dir = get_datasets_dir()
versions = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
versions.append(v)
except (ValueError, IndexError):
pass
return max(versions) + 1 if versions else 1
def list_datasets() -> List[Tuple[int, Dict]]:
"""List all datasets with their metadata.
Returns:
List of (version, metadata_dict) tuples, sorted by version.
"""
datasets_dir = get_datasets_dir()
datasets = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
metadata_file = d / "metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
datasets.append((v, metadata))
except (ValueError, IndexError, json.JSONDecodeError):
pass
return sorted(datasets, key=lambda x: x[0])
def load_dataset_metadata(version: int) -> Optional[Dict]:
"""Load metadata for a specific dataset version.
Returns:
Metadata dict or None if not found.
"""
datasets_dir = get_datasets_dir()
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
if not metadata_file.exists():
return None
with open(metadata_file, 'r') as f:
return json.load(f)
def save_dataset_metadata(version: int, metadata: Dict) -> None:
"""Save metadata for a dataset version."""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
metadata_file = dataset_dir / "metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
def create_dataset(
version: int,
labeled_jsonl_path: str,
sources: List[Dict],
stockfish_depth: int = 12
) -> Path:
"""Create a new versioned dataset.
Args:
version: Dataset version number
labeled_jsonl_path: Path to labeled.jsonl to copy
sources: List of source dicts (see plan for schema)
stockfish_depth: Depth used for labeling
Returns:
Path to the created dataset directory.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
# Copy labeled data with deduplication (in case source has duplicates)
source_path = Path(labeled_jsonl_path)
if source_path.exists():
dest_path = dataset_dir / "labeled.jsonl"
seen_fens = set()
unique_count = 0
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
for line in src:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in seen_fens:
dst.write(line)
seen_fens.add(fen)
unique_count += 1
except json.JSONDecodeError:
# Skip malformed lines
pass
# Count positions
total_positions = 0
if (dataset_dir / "labeled.jsonl").exists():
with open(dataset_dir / "labeled.jsonl", 'r') as f:
total_positions = sum(1 for _ in f)
# Create metadata
metadata = {
"version": version,
"created": datetime.now().isoformat(),
"total_positions": total_positions,
"stockfish_depth": stockfish_depth,
"sources": sources
}
save_dataset_metadata(version, metadata)
return dataset_dir
def extend_dataset(
version: int,
new_labeled_path: str,
new_source_entry: Dict
) -> bool:
"""Extend an existing dataset with new labeled positions (with deduplication).
Args:
version: Dataset version to extend
new_labeled_path: Path to new labeled.jsonl to merge
new_source_entry: Source entry to add to metadata
Returns:
True if successful, False otherwise.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
labeled_file = dataset_dir / "labeled.jsonl"
new_labeled_file = Path(new_labeled_path)
if not new_labeled_file.exists():
return False
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
existing_fens = set()
if labeled_file.exists():
with open(labeled_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data.get('fen')
if fen:
existing_fens.add(fen)
except json.JSONDecodeError:
pass
# Merge new positions, skipping duplicates
new_count = 0
new_lines = []
with open(new_labeled_file, 'r') as f_new:
for line in f_new:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in existing_fens:
new_lines.append(line)
existing_fens.add(fen)
new_count += 1
except json.JSONDecodeError:
pass
# Append only the new, unique positions
if new_lines:
with open(labeled_file, 'a') as f_append:
for line in new_lines:
f_append.write(line)
# Update metadata
metadata = load_dataset_metadata(version)
if metadata:
# Count total positions
total_positions = 0
with open(labeled_file, 'r') as f:
total_positions = sum(1 for _ in f)
metadata['total_positions'] = total_positions
# Update the source entry with actual count of new positions added
new_source_entry['actual_count'] = new_count
metadata['sources'].append(new_source_entry)
save_dataset_metadata(version, metadata)
return True
def get_dataset_labeled_path(version: int) -> Optional[Path]:
"""Get the path to a dataset's labeled.jsonl file.
Returns:
Path to labeled.jsonl or None if dataset doesn't exist.
"""
datasets_dir = get_datasets_dir()
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
if labeled_file.exists():
return labeled_file
return None
def delete_dataset(version: int) -> bool:
"""Delete a dataset (recursively removes directory).
Args:
version: Dataset version to delete
Returns:
True if successful.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
import shutil
shutil.rmtree(dataset_dir)
return True
def show_datasets_table(console: Console = None) -> None:
"""Display all datasets in a Rich table."""
if console is None:
console = Console()
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets found yet[/yellow]")
return
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("Positions", justify="right")
table.add_column("Sources", justify="left")
table.add_column("Depth", justify="center")
table.add_column("Created", justify="left")
for v, metadata in datasets:
positions = metadata.get('total_positions', 0)
sources = metadata.get('sources', [])
source_str = ", ".join([s.get('type', '?') for s in sources])
depth = metadata.get('stockfish_depth', '?')
created = metadata.get('created', '?')
if created != '?':
created = created.split('T')[0] # Just the date
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
console.print(table)
+109 -39
@@ -1,67 +1,137 @@
#!/usr/bin/env python3
"""Export NNUE weights to binary format for runtime loading."""
"""Export NNUE weights to .nbai format for runtime loading."""
import torch
import json
import struct
import sys
from datetime import datetime
from pathlib import Path
def export_weights_to_binary(weights_file, output_file):
"""Load PyTorch weights and export as binary file."""
import torch
MAGIC = 0x4942_414E # file magic; written little-endian, so the file begins with bytes 'N','A','B','I'
VERSION = 1
def _read_sidecar(weights_file: str) -> dict:
sidecar = weights_file.replace(".pt", "_metadata.json")
if Path(sidecar).exists():
with open(sidecar) as f:
return json.load(f)
return {}
def _infer_layers(state_dict: dict) -> list[dict]:
"""Derive layer descriptors from state_dict weight shapes.
Assumes layers named l1, l2, ..., lN.
All hidden layers get activation 'relu'; the last gets 'linear'.
"""
names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
layers = []
for i, name in enumerate(names):
out_size, in_size = state_dict[f"{name}.weight"].shape
activation = "linear" if i == len(names) - 1 else "relu"
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
return layers
def _write_floats(f, tensor):
data = tensor.float().flatten().cpu().numpy()
f.write(struct.pack("<I", len(data)))
f.write(struct.pack(f"<{len(data)}f", *data))
def export_to_nbai(
weights_file: str,
output_file: str,
trained_by: str = "unknown",
train_loss: float = 0.0,
):
if not Path(weights_file).exists():
print(f"Error: Weights file not found at {weights_file}")
print(f"Error: weights file not found at {weights_file}")
sys.exit(1)
# Load weights — handle both raw state dicts and full training checkpoints
loaded = torch.load(weights_file, map_location='cpu')
state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded
loaded = torch.load(weights_file, map_location="cpu")
state_dict = (
loaded["model_state_dict"]
if isinstance(loaded, dict) and "model_state_dict" in loaded
else loaded
)
# Debug: print available layers
print(f"Available layers in {weights_file}:")
for key in sorted(state_dict.keys()):
print(f" {key}: {state_dict[key].shape}")
sidecar = _read_sidecar(weights_file)
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
trained_at = sidecar.get("date", datetime.now().isoformat())
training_data_count = int(sidecar.get("num_positions", 0))
# Create output directory if needed
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
metadata = {
"trainedBy": trained_by,
"trainedAt": trained_at,
"trainingDataCount": training_data_count,
"valLoss": val_loss,
"trainLoss": train_loss,
}
with open(output_file, 'wb') as f:
# Write magic number and version
f.write(b'NNUE')
f.write(struct.pack('<I', 1)) # version 1
layers = _infer_layers(state_dict)
layer_names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
# Write each weight tensor in order
for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias']:
if layer_name not in state_dict:
print(f"Error: Missing layer {layer_name}")
sys.exit(1)
print(f"Architecture ({len(layers)} layers):")
for i, l in enumerate(layers):
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
tensor = state_dict[layer_name]
# Convert to float32 and flatten
data = tensor.float().flatten().cpu().numpy()
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
# Write shape (allows validation on load)
shape = list(tensor.shape)
f.write(struct.pack('<I', len(shape)))
for dim in shape:
f.write(struct.pack('<I', dim))
with open(output_file, "wb") as f:
# Header
f.write(struct.pack("<I", MAGIC))
f.write(struct.pack("<H", VERSION))
# Write flattened data as binary floats
f.write(struct.pack(f'<{len(data)}f', *data))
# Metadata (length-prefixed UTF-8 JSON)
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
f.write(struct.pack("<I", len(meta_bytes)))
f.write(meta_bytes)
print(f" {layer_name}: shape {shape}, {len(data)} floats")
# Layer descriptors
f.write(struct.pack("<H", len(layers)))
for layer in layers:
name_bytes = layer["activation"].encode("ascii")
f.write(struct.pack("<B", len(name_bytes)))
f.write(name_bytes)
f.write(struct.pack("<I", layer["inputSize"]))
f.write(struct.pack("<I", layer["outputSize"]))
# Weights: weight tensor then bias tensor per layer
for name in layer_names:
w = state_dict[f"{name}.weight"]
b = state_dict[f"{name}.bias"]
_write_floats(f, w)
_write_floats(f, b)
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
file_size_mb = output_path.stat().st_size / (1024**2)
print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
if __name__ == "__main__":
weights_file = "nnue_weights.pt"
output_file = "../src/main/resources/nnue_weights.bin"
output_file = "../src/main/resources/nnue_weights.nbai"
trained_by = "unknown"
train_loss = 0.0
if len(sys.argv) > 1:
weights_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
if len(sys.argv) > 3:
trained_by = sys.argv[3]
if len(sys.argv) > 4:
train_loss = float(sys.argv[4])
export_weights_to_binary(weights_file, output_file)
export_to_nbai(weights_file, output_file, trained_by, train_loss)
+12 -1
@@ -125,6 +125,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
# Load all FENs that need evaluation
fens_to_evaluate = []
fens_seen_in_batch = set() # Track duplicates within current batch
skipped_invalid = 0
skipped_duplicate = 0
@@ -140,7 +141,12 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
skipped_duplicate += 1
continue
if fen in fens_seen_in_batch:
skipped_duplicate += 1
continue
fens_to_evaluate.append(fen)
fens_seen_in_batch.add(fen)
total_to_evaluate = len(fens_to_evaluate)
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
@@ -178,8 +184,13 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
with open(output_file, 'a') as out:
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
for fen, eval_normalized, eval_cp in batch_results:
# Skip if already evaluated in output file during this run
if fen in evaluated_fens:
continue
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
out.write(json.dumps(data) + '\n')
evaluated_fens.add(fen) # Track as evaluated
evaluated += 1
raw_evals.append(eval_cp)
normalized_evals.append(eval_normalized)
@@ -287,7 +298,7 @@ if __name__ == "__main__":
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)")
parser.add_argument("--batch-size", type=int, default=20,
parser.add_argument("--batch-size", type=int, default=1000,
help="Batch size for processing (default: 1000)")
parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)")
@@ -17,7 +17,7 @@ from generate import play_random_game_and_collect_positions
def download_and_extract_puzzle_db(
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
output_dir: str = 'trainingdata'
output_dir: str = 'tactical_data'
):
"""Download and extract the Lichess puzzle database."""
output_path = Path(output_dir)
@@ -141,6 +141,31 @@ def merge_positions(
print(f"{'='*60}\n")
def extract_tactical_only(
puzzle_csv: str,
output_file: str,
max_puzzles: int = 300_000
) -> int:
"""Extract tactical positions and save to file (no merge prompts).
Args:
puzzle_csv: Path to Lichess puzzle CSV
output_file: Where to save the FEN positions
max_puzzles: Maximum puzzles to extract
Returns:
Number of positions extracted
"""
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
with open(output_file, 'w') as f:
for fen in tactical_positions:
f.write(fen + '\n')
return len(tactical_positions)
def interactive_merge_positions(
puzzle_csv: str,
output_file: str = 'position.txt',
Binary file not shown.
@@ -0,0 +1,17 @@
{
  "version": 9,
  "date": "2026-04-13T20:19:08.123315",
  "num_positions": 2522562,
  "stockfish_depth": 12,
  "final_val_loss": 6.994176222619626e-05,
  "device": "cuda",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)",
  "mode": "burst",
  "duration_minutes": 30.0,
  "epochs_per_season": 50,
  "early_stopping_patience": 10,
  "seasons_completed": 3,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
}
@@ -6,7 +6,7 @@ import de.nowchess.bot.ai.Evaluation
object EvaluationNNUE extends Evaluation:
private val nnue = NNUE()
private val nnue = NNUE(NbaiLoader.loadDefault())
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
@@ -3,84 +3,31 @@ package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import java.nio.ByteBuffer
import java.nio.ByteOrder
class NNUE:
class NNUE(model: NbaiModel):
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) =
loadWeights()
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
// Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
private val l1WeightsT: Array[Float] =
val t = new Array[Float](768 * 1536)
for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j)
val w = model.weights(0).weights
val t = new Array[Float](featureSize * accSize)
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
t
private def loadWeights(): (
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
) =
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
.getOrElse(sys.error("NNUE weights file not found in resources"))
try
val bytes = stream.readAllBytes()
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
val magic = buffer.getInt()
if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}")
val version = buffer.getInt()
if version != 1 then sys.error(s"Unsupported weight version: $version")
val l1w = readTensor(buffer)
val l1b = readTensor(buffer)
val l2w = readTensor(buffer)
val l2b = readTensor(buffer)
val l3w = readTensor(buffer)
val l3b = readTensor(buffer)
val l4w = readTensor(buffer)
val l4b = readTensor(buffer)
val l5w = readTensor(buffer)
val l5b = readTensor(buffer)
(l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
finally stream.close()
private def readTensor(buffer: ByteBuffer): Array[Float] =
val shapeLen = buffer.getInt()
val shape = Array.ofDim[Int](shapeLen)
for i <- 0 until shapeLen do shape(i) = buffer.getInt()
val totalElements = shape.product
val floats = Array.ofDim[Float](totalElements)
for i <- 0 until totalElements do floats(i) = buffer.getFloat()
floats
// ── Accumulator stack ────────────────────────────────────────────────────
// l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply.
// Initialised once at root; each child ply is derived incrementally.
private val MAX_PLY = 128
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536))
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
// Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant).
private val l1ReLU = new Array[Float](1536)
private val l2Output = new Array[Float](1024)
private val l3Output = new Array[Float](512)
private val l4Output = new Array[Float](256)
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
// ── Eval cache ───────────────────────────────────────────────────────────
private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB
private val EVAL_CACHE_MASK = (1 << 18) - 1L
private val evalCacheHashes = new Array[Long](1 << 18)
private val evalCacheScores = new Array[Int](1 << 18)
@@ -93,35 +40,32 @@ class NNUE:
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
val offset = featureIdx * 1536
for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i)
val offset = featureIdx * accSize
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
val offset = featureIdx * 1536
for i <- 0 until 1536 do l1Pre(i) -= l1WeightsT(offset + i)
val offset = featureIdx * accSize
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
// ── Accumulator init ─────────────────────────────────────────────────────
/** Initialise l1Stack(0) from scratch using sparse active features. */
def initAccumulator(board: Board): Unit =
System.arraycopy(l1Bias, 0, l1Stack(0), 0, 1536)
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
// ── Accumulator push (incremental updates) ───────────────────────────────
/** Copy parent ply's pre-activations to childPly, then apply move deltas. */
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, 1536)
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly)
move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
/** Copy pre-activations from parentPly to childPly without any move delta (null-move). */
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, 1536)
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { mover =>
@@ -170,9 +114,6 @@ class NNUE:
// ── Evaluation from accumulator ──────────────────────────────────────────
/** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after
* computation.
*/
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
val idx = (hash & EVAL_CACHE_MASK).toInt
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
@@ -183,11 +124,19 @@ class NNUE:
score
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
var input = l1ReLU
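// Walk the hidden layers generically; with the default architecture (see
// NbaiMigrator.DefaultLayers) this is 1536 -> 1024 -> 512 -> 256 before the linear output.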
for i <- 1 until model.layers.length - 1 do
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
input = out
val lastIdx = model.layers.length - 1
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn)
private def runDenseReLU(
@@ -202,8 +151,8 @@ class NNUE:
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
output(i) = if sum > 0f then sum else 0f
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
private def scoreFromOutput(output: Float, turn: Color): Int =
val cp =
@@ -214,21 +163,15 @@ class NNUE:
val cpFromTurn = if turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
// ── Legacy full-board evaluate ────────────────────────────────────────────
// Pre-allocated buffers used only by the legacy evaluate path.
private val features = new Array[Float](768)
private val legacyL1 = new Array[Float](accSize)
/** Evaluate using full board scan (sparse over active features). Layout: black pieces at indices 0-5, white at 6-11.
*/
def evaluate(context: GameContext): Int =
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
runL2toOutput(legacyL1, context.turn)
/** Benchmark: time 1M evaluations and report ns/eval. */
def benchmark(): Unit =
val context = GameContext.initial
val iterations = 1_000_000
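// Sketch of the rest of the body (not shown in this diff; assumed from the doc comment above):
val start = System.nanoTime()
var i = 0
while i < iterations do
  evaluate(context)
  i += 1
println(s"${(System.nanoTime() - start) / iterations} ns/eval")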
@@ -0,0 +1,50 @@
package de.nowchess.bot.bots.nnue
import java.io.InputStream
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiLoader:
/** Little-endian encoding of ASCII bytes 'N','B','A','I' (on-disk byte order 0x4E 0x42 0x41 0x49). */
val MAGIC: Int = 0x4941_424e
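// .nbai on-disk layout (all multi-byte values little-endian), as read below:
//   int32  magic, uint16 version
//   int32  metadataLen, then that many bytes of UTF-8 metadata JSON
//   uint16 layerCount, then per layer: uint8 nameLen, ASCII activation name, int32 inputSize, int32 outputSize
//   per layer: weights then bias, each as an int32 count followed by that many float32 values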
def load(stream: InputStream): NbaiModel =
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkHeader(buf)
val metadata = readMetadata(buf)
val descs = readLayerDescriptors(buf)
val weights = descs.map(_ => readLayerWeights(buf))
NbaiModel(metadata, descs, weights)
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) => try load(s) finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
val version = buf.getShort() & 0xffff
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
val bytes = new Array[Byte](buf.getInt())
buf.get(bytes)
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
Array.tabulate(buf.getShort() & 0xffff) { _ =>
val nameBytes = new Array[Byte](buf.get() & 0xff)
buf.get(nameBytes)
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
}
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readFloats(buf), readFloats(buf))
private def readFloats(buf: ByteBuffer): Array[Float] =
val arr = new Array[Float](buf.getInt())
for i <- arr.indices do arr(i) = buf.getFloat()
arr
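/** Hypothetical usage sketch (not part of this commit): load the default model and print its
  * architecture. All names used here exist in this file and in the model classes below. */
@main def printNbaiInfo(): Unit =
  val model = NbaiLoader.loadDefault()
  val shape = model.layers.map(l => s"${l.inputSize}x${l.outputSize}").mkString(" -> ")
  println(s"trained by ${model.metadata.trainedBy}, layers: $shape")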
@@ -0,0 +1,43 @@
package de.nowchess.bot.bots.nnue
import java.nio.{ByteBuffer, ByteOrder}
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
object NbaiMigrator:
private val BinMagic = 0x4555_4e4e
private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
)
private val UnknownMetadata: NbaiMetadata =
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
def migrateFromBin(): NbaiModel =
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
try
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkBinHeader(buf)
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
NbaiModel(UnknownMetadata, DefaultLayers, weights)
finally stream.close()
private def checkBinHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
val version = buf.getInt()
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readBinTensor(buf), readBinTensor(buf))
private def readBinTensor(buf: ByteBuffer): Array[Float] =
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
Array.tabulate(shape.product)(_ => buf.getFloat())
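/** Hypothetical one-off migration sketch (not part of this commit): convert the legacy .bin
  * resource into a .nbai file in the working directory. The output path is an assumption. */
@main def migrateBinToNbai(): Unit =
  val model = NbaiMigrator.migrateFromBin()
  val out = new java.io.FileOutputStream("nnue_weights.nbai")
  try NbaiWriter.write(model, out) finally out.close()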
@@ -0,0 +1,39 @@
package de.nowchess.bot.bots.nnue
/** Descriptor for a single dense layer stored in a .nbai file. */
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
/** Training metadata embedded in every .nbai file. */
case class NbaiMetadata(
trainedBy: String,
trainedAt: String,
trainingDataCount: Long,
valLoss: Double,
trainLoss: Double,
):
def toJson: String =
s"""{
| "trainedBy": "$trainedBy",
| "trainedAt": "$trainedAt",
| "trainingDataCount": $trainingDataCount,
| "valLoss": $valLoss,
| "trainLoss": $trainLoss
|}""".stripMargin
object NbaiMetadata:
def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
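// Note: this regex-based parser (and toJson above) handles only the flat, unescaped metadata
// object written by toJson; escaped quotes or nested objects are not supported.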
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])
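// Shape example (illustrative sizes): for LayerDescriptor("relu", 768, 8) the matching
// LayerWeights needs weights.length == 8 * 768 == 6144 (row-major) and bias.length == 8.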
/** A fully deserialized .nbai model ready to initialize NNUE. */
case class NbaiModel(
metadata: NbaiMetadata,
layers: Array[LayerDescriptor],
weights: Array[LayerWeights],
):
require(layers.length == weights.length, "Layer count must match weight count")
require(layers.length >= 2, "Model must have at least 2 layers")
@@ -0,0 +1,51 @@
package de.nowchess.bot.bots.nnue
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiWriter:
def write(model: NbaiModel, out: OutputStream): Unit =
val acc = new ByteArrayOutputStream()
writeHeader(acc)
writeMetadata(acc, model.metadata)
writeLayerDescriptors(acc, model.layers)
model.weights.foreach(lw => writeLayerWeights(acc, lw))
out.write(acc.toByteArray)
private def writeHeader(out: ByteArrayOutputStream): Unit =
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(NbaiLoader.MAGIC)
buf.putShort(1.toShort)
out.write(buf.array())
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(json.length)
buf.put(json)
out.write(buf.array())
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
buf.putShort(layers.length.toShort)
layers.zip(nameBytes).foreach { (l, nb) =>
buf.put(nb.length.toByte)
buf.put(nb)
buf.putInt(l.inputSize)
buf.putInt(l.outputSize)
}
out.write(buf.array())
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
writeFloats(out, lw.weights)
writeFloats(out, lw.bias)
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(floats.length)
floats.foreach(buf.putFloat)
out.write(buf.array())
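/** Hypothetical round-trip sketch (not part of this commit): serialize the default model to
  * memory and read it back, checking that the layer count survives the round trip. */
@main def nbaiRoundTrip(): Unit =
  val model = NbaiLoader.loadDefault()
  val bytes = new java.io.ByteArrayOutputStream()
  NbaiWriter.write(model, bytes)
  val reloaded = NbaiLoader.load(new java.io.ByteArrayInputStream(bytes.toByteArray))
  assert(reloaded.layers.length == model.layers.length, "layer count mismatch after round trip")
  println(s"round-trip ok: ${reloaded.layers.length} layers")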