feat: Implement dataset versioning and management for NNUE training data
@@ -19,3 +19,4 @@ ENV/
 *.swo
 tactical_data/
 trainingdata/
+/datasets/
@@ -0,0 +1,173 @@
# Training Dataset Management

The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.

## Directory Structure

```
datasets/
    ds_v1/
        labeled.jsonl   # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
        metadata.json   # Version info and composition
    ds_v2/
        labeled.jsonl
        metadata.json
```
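This layout is enough to discover datasets by scanning for `ds_vN` subdirectories. A minimal sketch of that scan; the helper name is illustrative, not the pipeline's actual API:

```python
from pathlib import Path

def find_dataset_versions(datasets_dir):
    """Return the sorted version numbers found as ds_vN subdirectories."""
    root = Path(datasets_dir)
    if not root.exists():
        return []
    versions = []
    for child in root.iterdir():
        # Only directories named ds_v<digits> count as datasets
        if child.is_dir() and child.name.startswith("ds_v") and child.name[4:].isdigit():
            versions.append(int(child.name[4:]))
    return sorted(versions)
```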
## Metadata Schema

Each dataset has a `metadata.json` file tracking its composition:

```json
{
  "version": 1,
  "created": "2026-04-13T15:30:45.123456",
  "total_positions": 1000000,
  "stockfish_depth": 12,
  "sources": [
    {
      "type": "generated",
      "count": 500000,
      "params": {
        "num_positions": 500000,
        "min_move": 1,
        "max_move": 50
      }
    },
    {
      "type": "tactical",
      "count": 300000,
      "max_puzzles": 300000
    },
    {
      "type": "file_import",
      "count": 200000,
      "path": "/path/to/original_file.txt"
    }
  ]
}
```
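A sketch of consuming this schema, for example to summarize a dataset's composition. `summarize_metadata` is a hypothetical helper written against the fields shown above, not part of the pipeline's `dataset` module:

```python
import json
from pathlib import Path

def summarize_metadata(path):
    """Load a dataset's metadata.json and return a short composition summary."""
    meta = json.loads(Path(path).read_text())
    # Collapse the sources list into a {type: count} mapping
    by_type = {s["type"]: s["count"] for s in meta["sources"]}
    return {
        "version": meta["version"],
        "total_positions": meta["total_positions"],
        "by_type": by_type,
    }
```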
## TUI Workflow

### Main Menu

```
1 - Manage Training Data
2 - Train Model
3 - Export Model
4 - Exit
```

### Training Data Management Submenu

```
1 - Create new dataset
2 - Extend existing dataset
3 - View all datasets
4 - Delete dataset
5 - Back
```

## Creating a Dataset

Use the "Create new dataset" option to add data from one or more sources:

1. **Generate random positions** — Play random games and sample positions
   - Number of positions
   - Move range (min/max move number to sample from)
   - Number of worker threads

2. **Import from file** — Load positions from a FEN file
   - File must contain one FEN string per line
   - Duplicates are automatically removed

3. **Extract tactical puzzles** — Download and extract the Lichess puzzle database
   - Maximum number of puzzles to include
   - Automatically filters for tactical themes (forks, pins, mates, etc.)

You can combine multiple sources in a single dataset creation session. All positions are:

- Deduplicated (only unique FENs are kept)
- Labeled with Stockfish evaluations
- Saved to `datasets/ds_vN/labeled.jsonl`
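The combine-and-deduplicate step can be sketched as a set union over the source files; the function name is illustrative:

```python
def combine_fen_files(paths):
    """Union the FEN lines from several source files, dropping blanks and duplicates."""
    fens = set()
    for path in paths:
        with open(path) as f:
            for line in f:
                fen = line.strip()
                if fen:          # skip blank lines
                    fens.add(fen)
    return fens
```

Because a `set` keyed on the full FEN string is used, identical positions that differ only in move counters would still be treated as distinct; that matches a line-level dedup, not a position-level one.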
## Extending a Dataset

Use "Extend existing dataset" to add more positions to an existing dataset:

1. Select the dataset version to extend
2. Choose data sources (same options as creation)
3. Confirm labeling parameters
4. New positions are:
   - Labeled with Stockfish
   - Deduplicated against the target dataset
   - Merged into the existing `labeled.jsonl`
5. Metadata is updated with the new source entry
## Training with a Dataset

When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:

- Version number
- Total number of positions
- Source types (generated, tactical, imported)
- Stockfish depth used
- Creation date

## Legacy Data Migration

If you have existing labeled data in `data/training_data.jsonl` from before this update:

1. Open the "Manage Training Data" menu
2. Choose "Create new dataset"
3. Select "Import from file"
4. Point to `data/training_data.jsonl`
5. Complete the dataset creation

Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
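For the manual route, a `metadata.json` matching the schema documented above can be generated with a few lines. This is a hedged sketch: the helper name is made up, and the field values are placeholders you should adjust to your data:

```python
import json
from datetime import datetime
from pathlib import Path

def write_manual_metadata(dataset_dir, total_positions, stockfish_depth, source_path):
    """Write a minimal metadata.json for a manually migrated dataset."""
    meta = {
        "version": 1,
        "created": datetime.now().isoformat(),
        "total_positions": total_positions,
        "stockfish_depth": stockfish_depth,
        # Record the legacy file as a single file_import source
        "sources": [
            {"type": "file_import", "count": total_positions, "path": source_path}
        ],
    }
    out = Path(dataset_dir) / "metadata.json"
    out.write_text(json.dumps(meta, indent=2))
    return out
```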
## Viewing Dataset Details

Use "View all datasets" to see a table of all datasets with:

- Version number
- Position count
- Source composition
- Stockfish depth
- Creation date

## Deleting a Dataset

Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**

⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.

## Technical Details

### Deduplication Strategy

When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.

When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
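The extend-time dedup amounts to filtering new labeled lines against the FENs already in the target dataset. A sketch under the assumption that lines follow the labeled-position format below; the function is illustrative, not the actual `extend_dataset` implementation:

```python
import json

def filter_new_positions(existing_jsonl_lines, new_jsonl_lines):
    """Keep only new labeled positions whose FEN is not already in the dataset."""
    seen = {json.loads(line)["fen"] for line in existing_jsonl_lines if line.strip()}
    fresh = []
    for line in new_jsonl_lines:
        if not line.strip():
            continue
        fen = json.loads(line)["fen"]
        if fen not in seen:
            seen.add(fen)       # also dedupe within the new batch itself
            fresh.append(line)
    return fresh
```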
### Labeled Position Format

Each line in `labeled.jsonl` is a JSON object:

```json
{
  "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
  "eval": 0.0,
  "eval_raw": 0
}
```

- `fen`: The position in Forsyth-Edwards Notation
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
- `eval_raw`: Raw Stockfish evaluation in centipawns
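The `eval`/`eval_raw` relationship can be reproduced with a tanh squash over the centipawn score. Note that the scale constant below is an assumption for illustration; the pipeline's actual constant is not stated in this document:

```python
import math

def normalize_eval(eval_raw_cp, scale=400.0):
    """Map a raw centipawn score into [-1, 1] via tanh (scale is a guess)."""
    return math.tanh(eval_raw_cp / scale)
```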
### Storage Location

Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.

## Performance Tips

- **Smaller datasets train faster** — Start with 100k-500k positions
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers are faster but use more CPU and memory
@@ -4,6 +4,7 @@
 import os
 import shutil
 import sys
+import tempfile
 from pathlib import Path
 from rich.console import Console
 from rich.table import Table
@@ -17,23 +18,23 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
 from generate import play_random_game_and_collect_positions
 from label import label_positions_with_stockfish
 from train import train_nnue, burst_train
-from export import export_weights_to_binary
+from export import export_to_nbai
 from tactical_positions_extractor import (
     download_and_extract_puzzle_db,
-    interactive_merge_positions
+    extract_tactical_only
+)
+from dataset import (
+    get_datasets_dir,
+    list_datasets,
+    next_dataset_version,
+    load_dataset_metadata,
+    create_dataset,
+    extend_dataset,
+    get_dataset_labeled_path,
+    delete_dataset,
+    show_datasets_table
 )
 
-
-def get_data_dir():
-    """Get/create data directory."""
-    data_dir = Path(__file__).parent / "data"
-    data_dir.mkdir(exist_ok=True)
-    return data_dir
-
-
-def get_tactical_data_dir():
-    """Get/create data directory."""
-    data_dir = Path(__file__).parent / "tactical_data"
-    data_dir.mkdir(exist_ok=True)
-    return data_dir
-
-
 def get_weights_dir():
     """Get/create weights directory."""
@@ -41,6 +42,14 @@ def get_weights_dir():
     weights_dir.mkdir(exist_ok=True)
     return weights_dir
 
 
+def get_data_dir():
+    """Get/create legacy data directory (for migration)."""
+    data_dir = Path(__file__).parent / "data"
+    data_dir.mkdir(exist_ok=True)
+    return data_dir
+
+
 def list_checkpoints():
     """List available checkpoint versions."""
     weights_dir = get_weights_dir()
@@ -49,6 +58,20 @@ def list_checkpoints():
         return []
     return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
 
 
+def migrate_legacy_data():
+    """On first run, offer to import existing data/training_data.jsonl as ds_v1."""
+    console = Console()
+    data_dir = get_data_dir()
+    legacy_file = data_dir / "training_data.jsonl"
+    datasets = list_datasets()
+
+    # Only migrate if legacy data exists and no datasets exist yet
+    if legacy_file.exists() and not datasets:
+        console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
+        console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
+
+
 def show_header():
     """Display application header."""
     console = Console()
@@ -56,22 +79,23 @@ def show_header():
     console.print(
         Panel(
             "[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
-            "[dim]Neural Network Utility Evaluation - Model Management[/dim]",
+            "[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
             border_style="cyan",
             padding=(1, 2),
         )
     )
 
 
 def show_checkpoints_table():
     """Display available checkpoints in a table."""
     console = Console()
     available = list_checkpoints()
 
     if not available:
-        console.print("[yellow]ℹ No checkpoints found yet[/yellow]")
+        console.print("[yellow]ℹ No model checkpoints found yet[/yellow]")
         return
 
-    table = Table(title="Available Checkpoints", show_header=True, header_style="bold cyan")
+    table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
     table.add_column("Version", style="dim")
     table.add_column("File Size", justify="right")
     table.add_column("Status", justify="center")
@@ -87,46 +111,500 @@ def show_checkpoints_table():
 
     console.print(table)
 
 
 def show_main_menu():
     """Display and handle main menu."""
     console = Console()
 
+    # Migrate legacy data on first run
+    migrate_legacy_data()
+
     while True:
         show_header()
         show_checkpoints_table()
 
         console.print("\n[bold]What would you like to do?[/bold]")
-        console.print("[cyan]1[/cyan] - Train NNUE Model")
-        console.print("[cyan]2[/cyan] - Burst Train NNUE Model")
-        console.print("[cyan]3[/cyan] - Export Weights to Scala")
-        console.print("[cyan]4[/cyan] - Extract Tactical Positions")
-        console.print("[cyan]5[/cyan] - View Checkpoints")
-        console.print("[cyan]6[/cyan] - Exit")
+        console.print("[cyan]1[/cyan] - Manage Training Data")
+        console.print("[cyan]2[/cyan] - Train Model")
+        console.print("[cyan]3[/cyan] - Export Model")
+        console.print("[cyan]4[/cyan] - Exit")
 
-        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5", "6"])
+        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
 
+        if choice == "1":
+            datasets_menu()
+        elif choice == "2":
+            training_menu()
+        elif choice == "3":
+            export_interactive()
+        elif choice == "4":
+            console.print("[yellow]👋 Goodbye![/yellow]")
+            return
+
+
+def datasets_menu():
+    """Dataset management submenu."""
+    console = Console()
+
+    while True:
+        show_header()
+        show_datasets_table(console)
+
+        console.print("\n[bold]Training Data Management[/bold]")
+        console.print("[cyan]1[/cyan] - Create new dataset")
+        console.print("[cyan]2[/cyan] - Extend existing dataset")
+        console.print("[cyan]3[/cyan] - View all datasets")
+        console.print("[cyan]4[/cyan] - Delete dataset")
+        console.print("[cyan]5[/cyan] - Back")
+
+        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
+
+        if choice == "1":
+            create_dataset_interactive()
+        elif choice == "2":
+            extend_dataset_interactive()
+        elif choice == "3":
+            show_header()
+            show_datasets_table(console)
+            Prompt.ask("\nPress Enter to continue")
+        elif choice == "4":
+            delete_dataset_interactive()
+        elif choice == "5":
+            return
+
+
+def create_dataset_interactive():
+    """Interactive dataset creation flow."""
+    console = Console()
+    show_header()
+
+    console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
+
+    sources = []
+    combined_count = 0
+
+    # Allow user to add multiple sources
+    while True:
+        console.print("\n[bold]Add data source (repeat until done):[/bold]")
+        console.print("[cyan]a[/cyan] - Generate random positions")
+        console.print("[cyan]b[/cyan] - Import from file")
+        console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
+        console.print("[cyan]d[/cyan] - Done adding sources")
+
+        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
+
+        if choice == "a":
+            num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
+            min_move = int(Prompt.ask("Minimum move number", default="1"))
+            max_move = int(Prompt.ask("Maximum move number", default="50"))
+            num_workers = int(Prompt.ask("Number of workers", default="8"))
+
+            console.print("[dim]Generating positions...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+            count = play_random_game_and_collect_positions(
+                str(temp_file),
+                total_positions=num_positions,
+                samples_per_game=1,
+                min_move=min_move,
+                max_move=max_move,
+                num_workers=num_workers
+            )
+            if count > 0:
+                sources.append({
+                    "type": "generated",
+                    "count": count,
+                    "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
+                })
+                combined_count += count
+                console.print(f"[green]✓ {count:,} positions generated[/green]")
+            else:
+                console.print("[red]✗ Generation failed[/red]")
+
+        elif choice == "b":
+            file_path = Prompt.ask("Path to FEN file")
+            try:
+                with open(file_path, 'r') as f:
+                    count = sum(1 for _ in f)
+                sources.append({"type": "file_import", "count": count, "path": file_path})
+                combined_count += count
+                console.print(f"[green]✓ {count:,} positions from file[/green]")
+            except FileNotFoundError:
+                console.print(f"[red]✗ File not found: {file_path}[/red]")
+
+        elif choice == "c":
+            max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
+            console.print("[dim]Extracting tactical positions...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+            try:
+                csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
+                if csv_path:
+                    count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
+                    sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
+                    combined_count += count
+                    console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
+            except Exception as e:
+                console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
+
+        elif choice == "d":
+            if not sources:
+                console.print("[yellow]⚠ No sources added yet[/yellow]")
+                continue
+            break
+
+    if not sources:
+        console.print("[yellow]Dataset creation cancelled[/yellow]")
+        return
+
+    # Stockfish labeling parameters
+    console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
+    stockfish_path = Prompt.ask(
+        "Stockfish path",
+        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
+    )
+    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
+
+    # Summary and confirm
+    console.print("\n[bold]Dataset Summary:[/bold]")
+    console.print(f"  Total positions: {combined_count:,}")
+    for source in sources:
+        console.print(f"  - {source['type']}: {source['count']:,}")
+    console.print(f"  Stockfish depth: {stockfish_depth}")
+
+    if not Confirm.ask("\nProceed to label and create dataset?", default=True):
+        console.print("[yellow]Cancelled[/yellow]")
+        return
+
+    try:
+        # Combine all sources into one FEN file
+        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
+        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
+        all_fens = set()
+
+        for source in sources:
+            if source['type'] == 'generated':
+                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+            elif source['type'] == 'file_import':
+                temp_file = Path(source['path'])
+            elif source['type'] == 'tactical':
+                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+
+            if temp_file.exists():
+                with open(temp_file, 'r') as f:
+                    for line in f:
+                        fen = line.strip()
+                        if fen:
+                            all_fens.add(fen)
+
+        with open(combined_fen_file, 'w') as f:
+            for fen in all_fens:
+                f.write(fen + '\n')
+        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
+
+        # Label positions
+        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
+        labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
+        success = label_positions_with_stockfish(
+            str(combined_fen_file),
+            str(labeled_file),
+            stockfish_path,
+            depth=stockfish_depth,
+            num_workers=num_workers
+        )
+
+        if not success:
+            console.print("[red]✗ Labeling failed[/red]")
+            return
+
+        console.print("[green]✓ Positions labeled[/green]")
+
+        # Save dataset
+        console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
+        version = next_dataset_version()
+        create_dataset(
+            version=version,
+            labeled_jsonl_path=str(labeled_file),
+            sources=sources,
+            stockfish_depth=stockfish_depth
+        )
+        console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
+        console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
+
+        Prompt.ask("\nPress Enter to continue")
+
+    except Exception as e:
+        console.print(f"[red]✗ Error: {e}[/red]")
+        import traceback
+        traceback.print_exc()
+        Prompt.ask("Press Enter to continue")
+
+
+def extend_dataset_interactive():
+    """Interactive dataset extension flow."""
+    console = Console()
+    show_header()
+
+    console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
+
+    datasets = list_datasets()
+    if not datasets:
+        console.print("[yellow]ℹ No datasets available to extend[/yellow]")
+        Prompt.ask("Press Enter to continue")
+        return
+
+    show_datasets_table(console)
+    version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
+
+    if not any(v == version for v, _ in datasets):
+        console.print("[red]✗ Dataset not found[/red]")
+        return
+
+    sources = []
+    combined_count = 0
+
+    # Allow user to add sources
+    while True:
+        console.print("\n[bold]Add data source:[/bold]")
+        console.print("[cyan]a[/cyan] - Generate random positions")
+        console.print("[cyan]b[/cyan] - Import from file")
+        console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
+        console.print("[cyan]d[/cyan] - Done adding sources")
+
+        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
+
+        if choice == "a":
+            num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
+            min_move = int(Prompt.ask("Minimum move number", default="1"))
+            max_move = int(Prompt.ask("Maximum move number", default="50"))
+            num_workers = int(Prompt.ask("Number of workers", default="8"))
+
+            console.print("[dim]Generating positions...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+            count = play_random_game_and_collect_positions(
+                str(temp_file),
+                total_positions=num_positions,
+                samples_per_game=1,
+                min_move=min_move,
+                max_move=max_move,
+                num_workers=num_workers
+            )
+            if count > 0:
+                sources.append({
+                    "type": "generated",
+                    "count": count,
+                    "params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
+                })
+                combined_count += count
+                console.print(f"[green]✓ {count:,} positions generated[/green]")
+
+        elif choice == "b":
+            file_path = Prompt.ask("Path to FEN file")
+            try:
+                with open(file_path, 'r') as f:
+                    count = sum(1 for _ in f)
+                sources.append({"type": "file_import", "count": count, "path": file_path})
+                combined_count += count
+                console.print(f"[green]✓ {count:,} positions from file[/green]")
+            except FileNotFoundError:
+                console.print(f"[red]✗ File not found: {file_path}[/red]")
+
+        elif choice == "c":
+            max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
+            console.print("[dim]Extracting tactical positions...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+            try:
+                csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
+                if csv_path:
+                    count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
+                    sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
+                    combined_count += count
+                    console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
+            except Exception as e:
+                console.print(f"[red]✗ Extraction failed: {e}[/red]")
+
+        elif choice == "d":
+            if not sources:
+                console.print("[yellow]⚠ No sources added yet[/yellow]")
+                continue
+            break
+
+    if not sources:
+        console.print("[yellow]Extension cancelled[/yellow]")
+        return
+
+    # Stockfish labeling parameters
+    console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
+    stockfish_path = Prompt.ask(
+        "Stockfish path",
+        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
+    )
+    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
+
+    # Summary and confirm
+    console.print("\n[bold]Extension Summary:[/bold]")
+    console.print(f"  Target dataset: ds_v{version}")
+    console.print(f"  New positions: {combined_count:,}")
+    for source in sources:
+        console.print(f"  - {source['type']}: {source['count']:,}")
+
+    if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
+        console.print("[yellow]Cancelled[/yellow]")
+        return
+
+    try:
+        # Combine all sources into one FEN file
+        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
+        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
+        all_fens = set()
+
+        for source in sources:
+            if source['type'] == 'generated':
+                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+            elif source['type'] == 'file_import':
+                temp_file = Path(source['path'])
+            elif source['type'] == 'tactical':
+                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+
+            if temp_file.exists():
+                with open(temp_file, 'r') as f:
+                    for line in f:
+                        fen = line.strip()
+                        if fen:
+                            all_fens.add(fen)
+
+        with open(combined_fen_file, 'w') as f:
+            for fen in all_fens:
+                f.write(fen + '\n')
+        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
+
+        # Label positions
+        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
+        labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
+        success = label_positions_with_stockfish(
+            str(combined_fen_file),
+            str(labeled_file),
+            stockfish_path,
+            depth=stockfish_depth,
+            num_workers=num_workers
+        )
+
+        if not success:
+            console.print("[red]✗ Labeling failed[/red]")
+            return
+
+        console.print("[green]✓ Positions labeled[/green]")
+
+        # Extend dataset
+        console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
+        success = extend_dataset(
+            version=version,
+            new_labeled_path=str(labeled_file),
+            new_source_entry={
+                "type": "merged_sources",
+                "count": len(all_fens),
+                "sources": sources
+            }
+        )
+
+        if success:
+            metadata = load_dataset_metadata(version)
+            console.print(f"[green]✓ Dataset extended[/green]")
+            console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
+        else:
+            console.print("[red]✗ Extension failed[/red]")
+
+        Prompt.ask("\nPress Enter to continue")
+
+    except Exception as e:
+        console.print(f"[red]✗ Error: {e}[/red]")
+        import traceback
+        traceback.print_exc()
+        Prompt.ask("Press Enter to continue")
+
+
+def delete_dataset_interactive():
+    """Interactive dataset deletion."""
+    console = Console()
+    show_header()
+
+    console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
+
+    datasets = list_datasets()
+    if not datasets:
+        console.print("[yellow]ℹ No datasets to delete[/yellow]")
+        Prompt.ask("Press Enter to continue")
+        return
+
+    show_datasets_table(console)
+    version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
+
+    if not any(v == version for v, _ in datasets):
+        console.print("[red]✗ Dataset not found[/red]")
+        return
+
+    if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
+        if delete_dataset(version):
+            console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
+        else:
+            console.print("[red]✗ Deletion failed[/red]")
+
+    Prompt.ask("Press Enter to continue")
+
+
+def training_menu():
+    """Training submenu."""
+    console = Console()
+
+    while True:
+        show_header()
+
+        console.print("\n[bold]Training[/bold]")
+        console.print("[cyan]1[/cyan] - Standard Training")
+        console.print("[cyan]2[/cyan] - Burst Training")
+        console.print("[cyan]3[/cyan] - View Model Checkpoints")
+        console.print("[cyan]4[/cyan] - Back")
+
+        choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
+
         if choice == "1":
             train_interactive()
         elif choice == "2":
             burst_train_interactive()
         elif choice == "3":
-            export_interactive()
-        elif choice == "4":
-            extract_tactical_interactive()
-        elif choice == "5":
             show_header()
             show_checkpoints_table()
             Prompt.ask("\nPress Enter to continue")
-        elif choice == "6":
-            console.print("[yellow]👋 Goodbye![/yellow]")
+        elif choice == "4":
             return
def train_interactive():
    """Interactive training menu."""
    console = Console()
    show_header()

    console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")

    # Dataset selection
    datasets = list_datasets()
    if not datasets:
        console.print("[red]✗ No datasets available. Create one first.[/red]")
        Prompt.ask("Press Enter to continue")
        return

    console.print("\n[bold]Available Datasets:[/bold]")
    show_datasets_table(console)
    dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))

    if not any(v == dataset_version for v, _ in datasets):
        console.print("[red]✗ Dataset not found[/red]")
        return

    labeled_file = get_dataset_labeled_path(dataset_version)
    if not labeled_file:
        console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
        return

    # Checkpoint selection
    available = list_checkpoints()
@@ -142,36 +620,6 @@ def train_interactive():
            default=str(max(available))
        )

    # Training parameters
    epochs = int(Prompt.ask("Number of epochs", default="100"))
    batch_size = int(Prompt.ask("Batch size", default="16384"))
@@ -182,16 +630,7 @@ def train_interactive():
    # Confirm and start
    console.print("\n[bold]Configuration Summary:[/bold]")
    console.print(f" Dataset: ds_v{dataset_version}")
    console.print(f" Epochs: {epochs}")
    console.print(f" Batch size: {batch_size}")
    console.print(f" Subsample ratio: {subsample_ratio:.0%}")
@@ -199,70 +638,27 @@ def train_interactive():
        console.print(f" Early stopping: Yes (patience: {early_stopping})")
    else:
        console.print(" Early stopping: No")
    if use_checkpoint:
        console.print(f" Checkpoint: v{checkpoint_version}")
    else:
        console.print(" Checkpoint: None (training from scratch)")

    if not Confirm.ask("\nStart training?", default=True):
        console.print("[yellow]Training cancelled[/yellow]")
        Prompt.ask("Press Enter to continue")
        return

    # Execute training
    weights_dir = get_weights_dir()

    try:
        console.print("\n[bold cyan]Training Model[/bold cyan]")
        checkpoint = None
        if use_checkpoint:
            checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")

        train_nnue(
            data_file=str(labeled_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            epochs=epochs,
            batch_size=batch_size,
@@ -286,6 +682,7 @@ def train_interactive():
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")

def burst_train_interactive():
    """Interactive burst training menu."""
    console = Console()
@@ -294,14 +691,30 @@ def burst_train_interactive():
    console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
    console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")

    # Dataset selection
    datasets = list_datasets()
    if not datasets:
        console.print("[red]✗ No datasets available. Create one first.[/red]")
        Prompt.ask("Press Enter to continue")
        return

    console.print("[bold]Available Datasets:[/bold]")
    show_datasets_table(console)
    dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))

    if not any(v == dataset_version for v, _ in datasets):
        console.print("[red]✗ Dataset not found[/red]")
        return

    labeled_file = get_dataset_labeled_path(dataset_version)
    if not labeled_file:
        console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
        return

    duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
    epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
    early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))

    # Optional initial checkpoint
    available = list_checkpoints()
    checkpoint = None
@@ -317,29 +730,25 @@ def burst_train_interactive():
    # Summary
    console.print("\n[bold]Configuration Summary:[/bold]")
    console.print(f" Dataset: ds_v{dataset_version}")
    console.print(f" Duration: {duration_minutes:.0f} minutes")
    console.print(f" Epochs per season: {epochs_per_season}")
    console.print(f" Patience: {early_stopping_patience}")
    console.print(f" Batch size: {batch_size}")
    console.print(f" Subsample ratio: {subsample_ratio:.0%}")
    console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")

    if not Confirm.ask("\nStart burst training?", default=True):
        console.print("[yellow]Burst training cancelled[/yellow]")
        Prompt.ask("Press Enter to continue")
        return

    weights_dir = get_weights_dir()

    try:
        console.print("\n[bold cyan]Burst Training[/bold cyan]")
        burst_train(
            data_file=str(labeled_file),
            output_file=str(weights_dir / "nnue_weights.pt"),
            duration_minutes=duration_minutes,
            epochs_per_season=epochs_per_season,
@@ -362,6 +771,7 @@ def burst_train_interactive():
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")

def export_interactive():
    """Interactive export menu."""
    console = Console()
@@ -380,7 +790,7 @@ def export_interactive():
    version = Prompt.ask("Enter version to export (e.g., 2)")

    weights_file = f"nnue_weights_v{version}.pt"
    output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")

    console.print("\n[bold]Export Configuration:[/bold]")
    console.print(f" Source: {weights_file}")
@@ -399,7 +809,7 @@ def export_interactive():
        return

    console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
    export_to_nbai(str(weights_path), output_file)
    console.print("\n[green]✓ Export complete![/green]")
    console.print(f"[bold]Weights saved to:[/bold] {output_file}")
    Prompt.ask("Press Enter to continue")
@@ -410,65 +820,6 @@ def export_interactive():
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")

def main():
    try:
@@ -481,7 +832,10 @@ def main():
    except Exception as e:
        console = Console()
        console.print(f"[red]Error:[/red] {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())

@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""Dataset versioning and management for NNUE training data."""

import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
from rich.console import Console
from rich.table import Table


def get_datasets_dir() -> Path:
    """Get/create datasets directory."""
    datasets_dir = Path(__file__).parent.parent / "datasets"
    datasets_dir.mkdir(exist_ok=True)
    return datasets_dir


def next_dataset_version() -> int:
    """Find the next available dataset version number."""
    datasets_dir = get_datasets_dir()
    versions = []

    for d in datasets_dir.iterdir():
        if d.is_dir() and d.name.startswith("ds_v"):
            try:
                v = int(d.name.split("_v")[1])
                versions.append(v)
            except (ValueError, IndexError):
                pass

    return max(versions) + 1 if versions else 1
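The `ds_v<N>` scan above tolerates unrelated directories and malformed suffixes. The same logic can be exercised on bare names; `next_version_from_names` is an illustrative standalone mirror, not part of the module:

```python
def next_version_from_names(names: list[str]) -> int:
    """Mirror of next_dataset_version(), operating on bare directory names."""
    versions = []
    for name in names:
        if name.startswith("ds_v"):
            try:
                versions.append(int(name.split("_v")[1]))
            except (ValueError, IndexError):
                pass  # ignore anything that is not ds_v<int>
    return max(versions) + 1 if versions else 1


# Gaps are not reused: with ds_v2 missing the next version is still 4.
print(next_version_from_names(["ds_v1", "ds_v3", "notes", "ds_vX"]))  # → 4
```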


def list_datasets() -> List[Tuple[int, Dict]]:
    """List all datasets with their metadata.

    Returns:
        List of (version, metadata_dict) tuples, sorted by version.
    """
    datasets_dir = get_datasets_dir()
    datasets = []

    for d in datasets_dir.iterdir():
        if d.is_dir() and d.name.startswith("ds_v"):
            try:
                v = int(d.name.split("_v")[1])
                metadata_file = d / "metadata.json"
                if metadata_file.exists():
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)
                    datasets.append((v, metadata))
            except (ValueError, IndexError, json.JSONDecodeError):
                pass

    return sorted(datasets, key=lambda x: x[0])


def load_dataset_metadata(version: int) -> Optional[Dict]:
    """Load metadata for a specific dataset version.

    Returns:
        Metadata dict or None if not found.
    """
    datasets_dir = get_datasets_dir()
    metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"

    if not metadata_file.exists():
        return None

    with open(metadata_file, 'r') as f:
        return json.load(f)


def save_dataset_metadata(version: int, metadata: Dict) -> None:
    """Save metadata for a dataset version."""
    datasets_dir = get_datasets_dir()
    dataset_dir = datasets_dir / f"ds_v{version}"
    dataset_dir.mkdir(exist_ok=True)

    metadata_file = dataset_dir / "metadata.json"
    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2, default=str)
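The metadata schema round-trips through plain JSON. A minimal sketch of the save/load pair, using a temporary directory instead of `get_datasets_dir()` (the field values here are illustrative):

```python
import json
import tempfile
from pathlib import Path

# Field names mirror metadata.json: version, created, total_positions,
# stockfish_depth, sources.
meta = {
    "version": 1,
    "created": "2026-04-13T15:30:45.123456",
    "total_positions": 3,
    "stockfish_depth": 12,
    "sources": [{"type": "generated", "count": 3}],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "ds_v1" / "metadata.json"
    path.parent.mkdir(parents=True)
    path.write_text(json.dumps(meta, indent=2))
    loaded = json.loads(path.read_text())

assert loaded == meta  # ints, strings and nested lists survive the round trip
```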
def create_dataset(
    version: int,
    labeled_jsonl_path: str,
    sources: List[Dict],
    stockfish_depth: int = 12
) -> Path:
    """Create a new versioned dataset.

    Args:
        version: Dataset version number
        labeled_jsonl_path: Path to labeled.jsonl to copy
        sources: List of source dicts (see plan for schema)
        stockfish_depth: Depth used for labeling

    Returns:
        Path to the created dataset directory.
    """
    datasets_dir = get_datasets_dir()
    dataset_dir = datasets_dir / f"ds_v{version}"
    dataset_dir.mkdir(exist_ok=True)

    # Copy labeled data with deduplication (in case the source has duplicates)
    source_path = Path(labeled_jsonl_path)
    if source_path.exists():
        dest_path = dataset_dir / "labeled.jsonl"
        seen_fens = set()
        unique_count = 0

        with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
            for line in src:
                try:
                    data = json.loads(line)
                    fen = data.get('fen')
                    if fen and fen not in seen_fens:
                        dst.write(line)
                        seen_fens.add(fen)
                        unique_count += 1
                except json.JSONDecodeError:
                    # Skip malformed lines
                    pass

    # Count positions
    total_positions = 0
    if (dataset_dir / "labeled.jsonl").exists():
        with open(dataset_dir / "labeled.jsonl", 'r') as f:
            total_positions = sum(1 for _ in f)

    # Create metadata
    metadata = {
        "version": version,
        "created": datetime.now().isoformat(),
        "total_positions": total_positions,
        "stockfish_depth": stockfish_depth,
        "sources": sources
    }

    save_dataset_metadata(version, metadata)
    return dataset_dir
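The dedup-on-copy loop can be exercised in isolation. A tempfile-based sketch mirroring `create_dataset()`'s behavior (`dedup_jsonl` and the file names are illustrative):

```python
import json
import tempfile
from pathlib import Path


def dedup_jsonl(src_path: Path, dst_path: Path) -> int:
    """Copy a labeled .jsonl file, keeping only the first record per FEN."""
    seen, kept = set(), 0
    with open(src_path) as src, open(dst_path, "w") as dst:
        for line in src:
            try:
                fen = json.loads(line).get("fen")
            except json.JSONDecodeError:
                continue  # skip malformed lines, as create_dataset() does
            if fen and fen not in seen:
                dst.write(line)
                seen.add(fen)
                kept += 1
    return kept


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "raw.jsonl"
    dst = Path(tmp) / "labeled.jsonl"
    rows = [
        {"fen": "startpos", "eval": 0.0},
        {"fen": "startpos", "eval": 0.1},  # duplicate FEN: dropped
        {"fen": "other", "eval": 0.5},
    ]
    src.write_text("".join(json.dumps(r) + "\n" for r in rows))
    kept = dedup_jsonl(src, dst)
    out_lines = dst.read_text().splitlines()

print(kept)  # → 2
```

The first record per FEN wins, so re-labeling the same position later never silently overwrites an earlier evaluation.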
def extend_dataset(
    version: int,
    new_labeled_path: str,
    new_source_entry: Dict
) -> bool:
    """Extend an existing dataset with new labeled positions (with deduplication).

    Args:
        version: Dataset version to extend
        new_labeled_path: Path to new labeled.jsonl to merge
        new_source_entry: Source entry to add to metadata

    Returns:
        True if successful, False otherwise.
    """
    datasets_dir = get_datasets_dir()
    dataset_dir = datasets_dir / f"ds_v{version}"

    if not dataset_dir.exists():
        return False

    labeled_file = dataset_dir / "labeled.jsonl"
    new_labeled_file = Path(new_labeled_path)

    if not new_labeled_file.exists():
        return False

    # Load existing FENs (dedup set); the whole file must be scanned to avoid duplicates
    existing_fens = set()
    if labeled_file.exists():
        with open(labeled_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    fen = data.get('fen')
                    if fen:
                        existing_fens.add(fen)
                except json.JSONDecodeError:
                    pass

    # Merge new positions, skipping duplicates
    new_count = 0
    new_lines = []
    with open(new_labeled_file, 'r') as f_new:
        for line in f_new:
            try:
                data = json.loads(line)
                fen = data.get('fen')
                if fen and fen not in existing_fens:
                    new_lines.append(line)
                    existing_fens.add(fen)
                    new_count += 1
            except json.JSONDecodeError:
                pass

    # Append only the new, unique positions
    if new_lines:
        with open(labeled_file, 'a') as f_append:
            for line in new_lines:
                f_append.write(line)

    # Update metadata
    metadata = load_dataset_metadata(version)
    if metadata:
        # Count total positions
        total_positions = 0
        with open(labeled_file, 'r') as f:
            total_positions = sum(1 for _ in f)

        metadata['total_positions'] = total_positions
        # Record the actual number of new positions added after deduplication
        new_source_entry['actual_count'] = new_count
        metadata['sources'].append(new_source_entry)
        save_dataset_metadata(version, metadata)

    return True
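The merge in `extend_dataset()` differs from copy-time dedup in one way: the seen-set is seeded from the existing file, so `actual_count` can be smaller than the number of lines offered. An in-memory sketch of that semantics (`merge_new_lines` is illustrative, not part of the module):

```python
import json


def merge_new_lines(existing_lines: list[str], new_lines: list[str]) -> tuple[list[str], int]:
    """Mirror of extend_dataset()'s merge: keep only FENs not already present."""
    seen = set()
    for line in existing_lines:
        try:
            fen = json.loads(line).get("fen")
        except json.JSONDecodeError:
            continue
        if fen:
            seen.add(fen)

    added, count = [], 0
    for line in new_lines:
        try:
            fen = json.loads(line).get("fen")
        except json.JSONDecodeError:
            continue
        if fen and fen not in seen:
            added.append(line)
            seen.add(fen)
            count += 1
    return added, count


existing = ['{"fen": "a", "eval": 0.1}\n']
new = ['{"fen": "a", "eval": 0.2}\n', '{"fen": "b", "eval": 0.3}\n']
added, count = merge_new_lines(existing, new)
print(count)  # → 1  (only "b" is new; its count becomes actual_count)
```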
def get_dataset_labeled_path(version: int) -> Optional[Path]:
    """Get the path to a dataset's labeled.jsonl file.

    Returns:
        Path to labeled.jsonl or None if the dataset doesn't exist.
    """
    datasets_dir = get_datasets_dir()
    labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"

    if labeled_file.exists():
        return labeled_file
    return None


def delete_dataset(version: int) -> bool:
    """Delete a dataset (recursively removes its directory).

    Args:
        version: Dataset version to delete

    Returns:
        True if successful.
    """
    datasets_dir = get_datasets_dir()
    dataset_dir = datasets_dir / f"ds_v{version}"

    if not dataset_dir.exists():
        return False

    import shutil
    shutil.rmtree(dataset_dir)
    return True
def show_datasets_table(console: Optional[Console] = None) -> None:
    """Display all datasets in a Rich table."""
    if console is None:
        console = Console()

    datasets = list_datasets()

    if not datasets:
        console.print("[yellow]ℹ No datasets found yet[/yellow]")
        return

    table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
    table.add_column("Version", style="dim")
    table.add_column("Positions", justify="right")
    table.add_column("Sources", justify="left")
    table.add_column("Depth", justify="center")
    table.add_column("Created", justify="left")

    for v, metadata in datasets:
        positions = metadata.get('total_positions', 0)
        sources = metadata.get('sources', [])
        source_str = ", ".join(s.get('type', '?') for s in sources)
        depth = metadata.get('stockfish_depth', '?')
        created = metadata.get('created', '?')
        if created != '?':
            created = created.split('T')[0]  # just the date

        table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)

    console.print(table)
@@ -1,67 +1,137 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Export NNUE weights to binary format for runtime loading."""
|
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||||
|
|
||||||
import torch
|
import json
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def export_weights_to_binary(weights_file, output_file):
|
import torch
|
||||||
"""Load PyTorch weights and export as binary file."""
|
|
||||||
|
|
||||||
|
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _read_sidecar(weights_file: str) -> dict:
|
||||||
|
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||||
|
if Path(sidecar).exists():
|
||||||
|
with open(sidecar) as f:
|
||||||
|
return json.load(f)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||||
|
"""Derive layer descriptors from state_dict weight shapes.
|
||||||
|
|
||||||
|
Assumes layers named l1, l2, ..., lN.
|
||||||
|
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||||
|
"""
|
||||||
|
names = sorted(
|
||||||
|
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||||
|
key=lambda n: int(n[1:]),
|
||||||
|
)
|
||||||
|
layers = []
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||||
|
activation = "linear" if i == len(names) - 1 else "relu"
|
||||||
|
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||||
|
return layers
|
||||||
|
|
||||||
|
|
||||||
|
def _write_floats(f, tensor):
|
||||||
|
data = tensor.float().flatten().cpu().numpy()
|
||||||
|
f.write(struct.pack("<I", len(data)))
|
||||||
|
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||||
|
|
||||||
|
|
||||||
|
def export_to_nbai(
    weights_file: str,
    output_file: str,
    trained_by: str = "unknown",
    train_loss: float = 0.0,
):
    if not Path(weights_file).exists():
        print(f"Error: weights file not found at {weights_file}")
        sys.exit(1)

    # Load weights — handle both raw state dicts and full training checkpoints
    loaded = torch.load(weights_file, map_location="cpu")
    state_dict = (
        loaded["model_state_dict"]
        if isinstance(loaded, dict) and "model_state_dict" in loaded
        else loaded
    )

    sidecar = _read_sidecar(weights_file)
    val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
    trained_at = sidecar.get("date", datetime.now().isoformat())
    training_data_count = int(sidecar.get("num_positions", 0))

    metadata = {
        "trainedBy": trained_by,
        "trainedAt": trained_at,
        "trainingDataCount": training_data_count,
        "valLoss": val_loss,
        "trainLoss": train_loss,
    }

    layers = _infer_layers(state_dict)
    layer_names = sorted(
        {k.split(".")[0] for k in state_dict if k.endswith(".weight")},
        key=lambda n: int(n[1:]),
    )

    print(f"Architecture ({len(layers)} layers):")
    for i, l in enumerate(layers):
        print(f"  l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "wb") as f:
        # Header
        f.write(struct.pack("<I", MAGIC))
        f.write(struct.pack("<H", VERSION))

        # Metadata (length-prefixed UTF-8 JSON)
        meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
        f.write(struct.pack("<I", len(meta_bytes)))
        f.write(meta_bytes)

        # Layer descriptors
        f.write(struct.pack("<H", len(layers)))
        for layer in layers:
            name_bytes = layer["activation"].encode("ascii")
            f.write(struct.pack("<B", len(name_bytes)))
            f.write(name_bytes)
            f.write(struct.pack("<I", layer["inputSize"]))
            f.write(struct.pack("<I", layer["outputSize"]))

        # Weights: weight tensor then bias tensor per layer
        for name in layer_names:
            w = state_dict[f"{name}.weight"]
            b = state_dict[f"{name}.bias"]
            _write_floats(f, w)
            _write_floats(f, b)
            print(f"  Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")

    size_mb = Path(output_file).stat().st_size / (1024 ** 2)
    print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
    print(f"Metadata: {json.dumps(metadata, indent=2)}")


if __name__ == "__main__":
    weights_file = "nnue_weights.pt"
    output_file = "../src/main/resources/nnue_weights.nbai"
    trained_by = "unknown"
    train_loss = 0.0

    if len(sys.argv) > 1:
        weights_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    if len(sys.argv) > 3:
        trained_by = sys.argv[3]
    if len(sys.argv) > 4:
        train_loss = float(sys.argv[4])

    export_to_nbai(weights_file, output_file, trained_by, train_loss)
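The exporter above implies a compact layout for `.nbai` files: a little-endian header (magic, version), a length-prefixed UTF-8 JSON metadata block, layer descriptors, then a length-prefixed float block for the weight tensor and one for the bias of each layer. A minimal pure-Python round-trip sketch of that layout — the function names `write_nbai` and `read_nbai` are illustrative, not part of the repository:

```python
import io
import json
import struct

MAGIC = int.from_bytes(b"NBAI", "little")  # little-endian encoding of the ASCII tag
VERSION = 1

def write_nbai(f, metadata, layers, tensors):
    """Write a minimal .nbai stream: header, JSON metadata,
    layer descriptors, then (weights, bias) float blocks per layer."""
    f.write(struct.pack("<I", MAGIC))
    f.write(struct.pack("<H", VERSION))
    meta = json.dumps(metadata).encode("utf-8")
    f.write(struct.pack("<I", len(meta)))
    f.write(meta)
    f.write(struct.pack("<H", len(layers)))
    for activation, in_size, out_size in layers:
        name = activation.encode("ascii")
        f.write(struct.pack("<B", len(name)))
        f.write(name)
        f.write(struct.pack("<II", in_size, out_size))
    for w, b in tensors:
        for block in (w, b):
            f.write(struct.pack("<I", len(block)))
            f.write(struct.pack(f"<{len(block)}f", *block))

def read_nbai(f):
    """Parse the stream back; returns (metadata, layers, tensors)."""
    magic, version = struct.unpack("<IH", f.read(6))
    assert magic == MAGIC and version == VERSION
    (meta_len,) = struct.unpack("<I", f.read(4))
    metadata = json.loads(f.read(meta_len).decode("utf-8"))
    (n_layers,) = struct.unpack("<H", f.read(2))
    layers = []
    for _ in range(n_layers):
        (name_len,) = struct.unpack("<B", f.read(1))
        activation = f.read(name_len).decode("ascii")
        in_size, out_size = struct.unpack("<II", f.read(8))
        layers.append((activation, in_size, out_size))
    tensors = []
    for _ in layers:
        blocks = []
        for _ in range(2):
            (n,) = struct.unpack("<I", f.read(4))
            blocks.append(list(struct.unpack(f"<{n}f", f.read(4 * n))))
        tensors.append(tuple(blocks))
    return metadata, layers, tensors
```

Round-tripping a toy model through an `io.BytesIO` buffer this way is a quick integrity check of the format before shipping the real resource to the Scala loader.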
@@ -125,6 +125,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,

    # Load all FENs that need evaluation
    fens_to_evaluate = []
    fens_seen_in_batch = set()  # Track duplicates within current batch
    skipped_invalid = 0
    skipped_duplicate = 0
@@ -140,7 +141,12 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
            skipped_duplicate += 1
            continue

        if fen in fens_seen_in_batch:
            skipped_duplicate += 1
            continue

        fens_to_evaluate.append(fen)
        fens_seen_in_batch.add(fen)

    total_to_evaluate = len(fens_to_evaluate)
    total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
@@ -178,8 +184,13 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
    with open(output_file, 'a') as out:
        for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
            for fen, eval_normalized, eval_cp in batch_results:
                # Skip if already evaluated in output file during this run
                if fen in evaluated_fens:
                    continue

                data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
                out.write(json.dumps(data) + '\n')
                evaluated_fens.add(fen)  # Track as evaluated
                evaluated += 1
                raw_evals.append(eval_cp)
                normalized_evals.append(eval_normalized)
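The two deduplication layers in this hunk — the within-batch `fens_seen_in_batch` set on the input side and the `evaluated_fens` set guarding the output file — boil down to one small pattern. A sketch with illustrative names, not lifted from the pipeline:

```python
def unique_fens(fens, already_evaluated):
    """Yield each FEN at most once, skipping positions already written to
    the output file (already_evaluated) and repeats within this stream."""
    seen_in_batch = set()
    for fen in fens:
        if fen in already_evaluated or fen in seen_in_batch:
            continue
        seen_in_batch.add(fen)
        yield fen
```

Filtering before batching keeps Stockfish from re-scoring the same position, which matters at depth 12 where each evaluation is expensive.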
@@ -287,7 +298,7 @@ if __name__ == "__main__":
                        help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    parser.add_argument("--depth", type=int, default=12,
                        help="Stockfish depth (default: 12)")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="Batch size for processing (default: 1000)")
    parser.add_argument("--no-normalize", action="store_true",
                        help="Disable evaluation normalization (keep raw centipawns)")
@@ -17,7 +17,7 @@ from generate import play_random_game_and_collect_positions

def download_and_extract_puzzle_db(
    url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
    output_dir: str = 'tactical_data'
):
    """Download and extract the Lichess puzzle database."""
    output_path = Path(output_dir)
@@ -141,6 +141,31 @@ def merge_positions(
    print(f"{'='*60}\n")


def extract_tactical_only(
    puzzle_csv: str,
    output_file: str,
    max_puzzles: int = 300_000
) -> int:
    """Extract tactical positions and save to file (no merge prompts).

    Args:
        puzzle_csv: Path to Lichess puzzle CSV
        output_file: Where to save the FEN positions
        max_puzzles: Maximum puzzles to extract

    Returns:
        Number of positions extracted
    """
    print("Extracting tactical positions from puzzle database...")
    tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)

    with open(output_file, 'w') as f:
        for fen in tactical_positions:
            f.write(fen + '\n')

    return len(tactical_positions)


def interactive_merge_positions(
    puzzle_csv: str,
    output_file: str = 'position.txt',
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,17 @@
{
  "version": 9,
  "date": "2026-04-13T20:19:08.123315",
  "num_positions": 2522562,
  "stockfish_depth": 12,
  "final_val_loss": 6.994176222619626e-05,
  "device": "cuda",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)",
  "mode": "burst",
  "duration_minutes": 30.0,
  "epochs_per_season": 50,
  "early_stopping_patience": 10,
  "seasons_completed": 3,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
}
Binary file not shown.
@@ -6,7 +6,7 @@ import de.nowchess.bot.ai.Evaluation

object EvaluationNNUE extends Evaluation:

  private val nnue = NNUE(NbaiLoader.loadDefault())

  val CHECKMATE_SCORE: Int = 10_000_000
  val DRAW_SCORE: Int = 0
@@ -3,84 +3,31 @@ package de.nowchess.bot.bots.nnue

import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}

class NNUE(model: NbaiModel):

  private val featureSize = model.layers(0).inputSize
  private val accSize = model.layers(0).outputSize

  // Column-major L1 weights for cache-friendly sparse & incremental updates.
  // l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
  private val l1WeightsT: Array[Float] =
    val w = model.weights(0).weights
    val t = new Array[Float](featureSize * accSize)
    for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
    t

  // ── Accumulator stack ────────────────────────────────────────────────────

  private val MAX_PLY = 128
  private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))

  // Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
  private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))

  // ── Eval cache ───────────────────────────────────────────────────────────
  private val EVAL_CACHE_MASK = (1 << 18) - 1L
  private val evalCacheHashes = new Array[Long](1 << 18)
  private val evalCacheScores = new Array[Int](1 << 18)

@@ -93,35 +40,32 @@ class NNUE:
    (colorOffset + piece.pieceType.ordinal) * 64 + sqNum

  private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
    val offset = featureIdx * accSize
    for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)

  private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
    val offset = featureIdx * accSize
    for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)

  // ── Accumulator init ─────────────────────────────────────────────────────

  def initAccumulator(board: Board): Unit =
    System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
    for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))

  // ── Accumulator push (incremental updates) ───────────────────────────────

  def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
    System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
    val l1 = l1Stack(childPly)
    move.moveType match
      case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
      case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
      case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
      case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)

  def copyAccumulator(parentPly: Int, childPly: Int): Unit =
    System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)

  private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
    board.pieceAt(move.from).foreach { mover =>
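The accumulator logic above relies on L1 being linear: pushing a move only adds and subtracts a handful of weight columns instead of recomputing the full matrix product. A small Python sketch of that invariant, using feature-major transposed weights as in `l1WeightsT` (function names are illustrative):

```python
def accumulate_full(w_t, acc_size, active_features, bias):
    """Recompute L1 pre-activations from scratch: bias plus one weight
    column per active feature (w_t is feature-major: feature * acc_size + i)."""
    acc = list(bias)
    for f in active_features:
        for i in range(acc_size):
            acc[i] += w_t[f * acc_size + i]
    return acc

def apply_delta(acc, w_t, acc_size, added, removed):
    """Incremental update: add columns for features that appeared,
    subtract columns for features that disappeared."""
    for f in added:
        for i in range(acc_size):
            acc[i] += w_t[f * acc_size + i]
    for f in removed:
        for i in range(acc_size):
            acc[i] -= w_t[f * acc_size + i]
    return acc
```

Because addition is commutative, the incrementally updated accumulator is exactly equal to a from-scratch recomputation over the child position's feature set, which is what makes the per-ply stack cheap to maintain during search.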
@@ -170,9 +114,6 @@ class NNUE:

  // ── Evaluation from accumulator ──────────────────────────────────────────

  def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
    val idx = (hash & EVAL_CACHE_MASK).toInt
    if evalCacheHashes(idx) == hash then evalCacheScores(idx)
@@ -183,11 +124,19 @@ class NNUE:
    score

  private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
    val l1ReLU = evalBuffers(0)
    for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f

    var input = l1ReLU
    for i <- 1 until model.layers.length - 1 do
      val lw = model.weights(i)
      val out = evalBuffers(i)
      val ld = model.layers(i)
      runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
      input = out

    val lastIdx = model.layers.length - 1
    val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
    scoreFromOutput(output, turn)

  private def runDenseReLU(
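`runL2toOutput` now walks the layer descriptors instead of calling hard-coded l2–l5 helpers, so the Scala side works for any architecture the `.nbai` file declares. The same generic pass, sketched in Python with row-major flat weights as in `LayerWeights` (ReLU hidden layers, linear scalar-capable output; names are illustrative):

```python
def forward(layers, weights, features):
    """Run a dense network described by (activation, in_size, out_size)
    descriptors; weights[i] is (row-major flat weight matrix, bias)."""
    x = features
    for (activation, in_size, out_size), (w, b) in zip(layers, weights):
        y = []
        for i in range(out_size):
            # Dot product of row i of the weight matrix with the input
            s = b[i]
            for j in range(in_size):
                s += w[i * in_size + j] * x[j]
            y.append(max(s, 0.0) if activation == "relu" else s)
        x = y
    return x
```

Driving the loop from descriptors means changing the network shape only requires retraining and re-exporting; no engine-side code changes.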
@@ -202,8 +151,8 @@ class NNUE:
      val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
      output(i) = if sum > 0f then sum else 0f

  private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
    (0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))

  private def scoreFromOutput(output: Float, turn: Color): Int =
    val cp =
@@ -214,21 +163,15 @@ class NNUE:
    val cpFromTurn = if turn == Color.Black then -cp else cp
    math.max(-20000, math.min(20000, cpFromTurn))

  // ── Legacy full-board evaluate ────────────────────────────────────────────

  private val legacyL1 = new Array[Float](accSize)

  def evaluate(context: GameContext): Int =
    System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
    for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
    runL2toOutput(legacyL1, context.turn)

  def benchmark(): Unit =
    val context = GameContext.initial
    val iterations = 1_000_000
@@ -0,0 +1,50 @@
package de.nowchess.bot.bots.nnue

import java.io.InputStream
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets

object NbaiLoader:

  /** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
  val MAGIC: Int = 0x4941_424e

  def load(stream: InputStream): NbaiModel =
    val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
    checkHeader(buf)
    val metadata = readMetadata(buf)
    val descs = readLayerDescriptors(buf)
    val weights = descs.map(_ => readLayerWeights(buf))
    NbaiModel(metadata, descs, weights)

  /** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
  def loadDefault(): NbaiModel =
    Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
      case Some(s) => try load(s) finally s.close()
      case None => NbaiMigrator.migrateFromBin()

  private def checkHeader(buf: ByteBuffer): Unit =
    val magic = buf.getInt()
    if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
    val version = buf.getShort() & 0xffff
    if version != 1 then sys.error(s"Unsupported NBAI version: $version")

  private def readMetadata(buf: ByteBuffer): NbaiMetadata =
    val bytes = new Array[Byte](buf.getInt())
    buf.get(bytes)
    NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))

  private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
    Array.tabulate(buf.getShort() & 0xffff) { _ =>
      val nameBytes = new Array[Byte](buf.get() & 0xff)
      buf.get(nameBytes)
      LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
    }

  private def readLayerWeights(buf: ByteBuffer): LayerWeights =
    LayerWeights(readFloats(buf), readFloats(buf))

  private def readFloats(buf: ByteBuffer): Array[Float] =
    val arr = new Array[Float](buf.getInt())
    for i <- arr.indices do arr(i) = buf.getFloat()
    arr
@@ -0,0 +1,43 @@
package de.nowchess.bot.bots.nnue

import java.nio.{ByteBuffer, ByteOrder}

/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as a fallback when no .nbai file exists. */
object NbaiMigrator:

  private val BinMagic = 0x4555_4e4e
  private val BinVersion = 1

  private val DefaultLayers: Array[LayerDescriptor] = Array(
    LayerDescriptor("relu", 768, 1536),
    LayerDescriptor("relu", 1536, 1024),
    LayerDescriptor("relu", 1024, 512),
    LayerDescriptor("relu", 512, 256),
    LayerDescriptor("linear", 256, 1),
  )

  private val UnknownMetadata: NbaiMetadata =
    NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)

  def migrateFromBin(): NbaiModel =
    val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
      .getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
    try
      val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
      checkBinHeader(buf)
      val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
      NbaiModel(UnknownMetadata, DefaultLayers, weights)
    finally stream.close()

  private def checkBinHeader(buf: ByteBuffer): Unit =
    val magic = buf.getInt()
    if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
    val version = buf.getInt()
    if version != BinVersion then sys.error(s"Unsupported bin version: $version")

  private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
    LayerWeights(readBinTensor(buf), readBinTensor(buf))

  private def readBinTensor(buf: ByteBuffer): Array[Float] =
    val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
    Array.tabulate(shape.product)(_ => buf.getFloat())
@@ -0,0 +1,39 @@
package de.nowchess.bot.bots.nnue

/** Descriptor for a single dense layer stored in a .nbai file. */
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)

/** Training metadata embedded in every .nbai file. */
case class NbaiMetadata(
  trainedBy: String,
  trainedAt: String,
  trainingDataCount: Long,
  valLoss: Double,
  trainLoss: Double,
):
  def toJson: String =
    s"""{
       |  "trainedBy": "$trainedBy",
       |  "trainedAt": "$trainedAt",
       |  "trainingDataCount": $trainingDataCount,
       |  "valLoss": $valLoss,
       |  "trainLoss": $trainLoss
       |}""".stripMargin

object NbaiMetadata:
  def fromJson(json: String): NbaiMetadata =
    def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
    def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
    NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)

/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])

/** A fully deserialized .nbai model ready to initialize NNUE. */
case class NbaiModel(
  metadata: NbaiMetadata,
  layers: Array[LayerDescriptor],
  weights: Array[LayerWeights],
):
  require(layers.length == weights.length, "Layer count must match weight count")
  require(layers.length >= 2, "Model must have at least 2 layers")
@@ -0,0 +1,51 @@
package de.nowchess.bot.bots.nnue

import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets

object NbaiWriter:

  def write(model: NbaiModel, out: OutputStream): Unit =
    val acc = new ByteArrayOutputStream()
    writeHeader(acc)
    writeMetadata(acc, model.metadata)
    writeLayerDescriptors(acc, model.layers)
    model.weights.foreach(lw => writeLayerWeights(acc, lw))
    out.write(acc.toByteArray)

  private def writeHeader(out: ByteArrayOutputStream): Unit =
    val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
    buf.putInt(NbaiLoader.MAGIC)
    buf.putShort(1.toShort)
    out.write(buf.array())

  private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
    val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
    val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
    buf.putInt(json.length)
    buf.put(json)
    out.write(buf.array())

  private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
    val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
    val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
    val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
    buf.putShort(layers.length.toShort)
    layers.zip(nameBytes).foreach { (l, nb) =>
      buf.put(nb.length.toByte)
      buf.put(nb)
      buf.putInt(l.inputSize)
      buf.putInt(l.outputSize)
    }
    out.write(buf.array())

  private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
    writeFloats(out, lw.weights)
    writeFloats(out, lw.bias)

  private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
    val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
    buf.putInt(floats.length)
    floats.foreach(buf.putFloat)
    out.write(buf.array())