feat: NCS-41 Bot Platform (#33)
Co-authored-by: Janis <janis@nowchess.de>
Reviewed-on: #33
Co-authored-by: Janis <janis.e.20@gmx.de>
Co-committed-by: Janis <janis.e.20@gmx.de>
@@ -0,0 +1,22 @@
# Data and weights are local artifacts, not committed
data/
weights/

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
.venv

# IDE
.idea/
.vscode/
*.swp
*.swo
tactical_data/
trainingdata/
/datasets/
@@ -0,0 +1,173 @@
# Training Dataset Management

The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.

## Directory Structure

```
datasets/
  ds_v1/
    labeled.jsonl    # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
    metadata.json    # Version info and composition
  ds_v2/
    labeled.jsonl
    metadata.json
```

## Metadata Schema

Each dataset has a `metadata.json` file tracking its composition:

```json
{
  "version": 1,
  "created": "2026-04-13T15:30:45.123456",
  "total_positions": 1000000,
  "stockfish_depth": 12,
  "sources": [
    {
      "type": "generated",
      "count": 500000,
      "params": {
        "num_positions": 500000,
        "min_move": 1,
        "max_move": 50
      }
    },
    {
      "type": "tactical",
      "count": 300000,
      "max_puzzles": 300000
    },
    {
      "type": "file_import",
      "count": 200000,
      "path": "/path/to/original_file.txt"
    }
  ]
}
```
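
For scripted checks outside the TUI, the metadata file can be read with plain `json`. A minimal sketch; the `ds_v1` path is only an example:

```python
import json
from pathlib import Path

# Illustrative path; point this at any datasets/ds_vN directory.
metadata_path = Path("datasets/ds_v1/metadata.json")

with open(metadata_path) as f:
    metadata = json.load(f)

print(f"Dataset v{metadata['version']}: {metadata['total_positions']:,} positions "
      f"(Stockfish depth {metadata['stockfish_depth']})")
for source in metadata["sources"]:
    print(f"  - {source['type']}: {source['count']:,}")
```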

## TUI Workflow

### Main Menu
```
1 - Manage Training Data
2 - Train Model
3 - Export Model
4 - Exit
```

### Training Data Management Submenu
```
1 - Create new dataset
2 - Extend existing dataset
3 - View all datasets
4 - Delete dataset
5 - Back
```

## Creating a Dataset

Use the "Create new dataset" option to add data from one or more sources:

1. **Generate random positions** — Play random games and sample positions
   - Number of positions
   - Move range (min/max move number to sample from)
   - Number of worker threads

2. **Import from file** — Load positions from a FEN file
   - File must contain one FEN string per line
   - Duplicates are automatically removed

3. **Extract tactical puzzles** — Download and extract the Lichess puzzle database
   - Maximum number of puzzles to include
   - Automatically filters for tactical themes (forks, pins, mates, etc.)

4. **Import Lichess eval database** — Load pre-evaluated positions from a `lichess_db_eval.jsonl.zst` dump
   - Optional cap on the number of positions to import
   - Minimum evaluation depth to accept; these positions arrive already labeled and skip Stockfish labeling

You can combine multiple sources in a single dataset creation session. All positions are:
- Deduplicated (only unique FENs are kept)
- Labeled with Stockfish evaluations (Lichess eval imports arrive pre-labeled)
- Saved to `datasets/ds_vN/labeled.jsonl`

## Extending a Dataset

Use "Extend existing dataset" to add more positions to an existing dataset:

1. Select the dataset version to extend
2. Choose data sources (same options as creation)
3. Confirm labeling parameters
4. New positions are:
   - Labeled with Stockfish
   - Deduplicated against the target dataset
   - Merged into the existing `labeled.jsonl`
   - Recorded in the metadata as a new source entry
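
Both operations can also be driven programmatically through the `dataset` module; the TUI calls exactly these functions. A minimal sketch, with illustrative paths and counts:

```python
import sys
from pathlib import Path

# Run from the python/ directory; nnue.py puts src/ on the path the same way.
sys.path.insert(0, str(Path(__file__).parent / "src"))

from dataset import create_dataset, extend_dataset, next_dataset_version

# Create a new dataset from an already-labeled JSONL file.
version = next_dataset_version()
create_dataset(
    version=version,
    labeled_jsonl_path="/tmp/labeled.jsonl",
    sources=[{"type": "file_import", "count": 200000, "path": "/tmp/positions.txt"}],
    stockfish_depth=12,
)

# Later, merge more labeled positions into the same dataset (deduplicated by FEN).
extend_dataset(
    version=version,
    new_labeled_path="/tmp/more_labeled.jsonl",
    new_source_entry={"type": "file_import", "count": 50000, "path": "/tmp/more.txt"},
)
```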

## Training with a Dataset

When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
- Version number
- Total number of positions
- Source types (generated, tactical, imported)
- Stockfish depth used
- Creation date

## Legacy Data Migration

If you have existing labeled data in `data/training_data.jsonl` from before this update:

1. Open the "Manage Training Data" menu
2. Choose "Create new dataset"
3. Select "Import from file"
4. Point to `data/training_data.jsonl`
5. Complete the dataset creation

Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
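
A minimal sketch of that manual route, assuming the legacy file already holds one labeled JSON object per line; the depth value in the metadata is a placeholder you should adjust to whatever was actually used:

```python
import json
import shutil
from datetime import datetime
from pathlib import Path

src = Path("data/training_data.jsonl")   # legacy labeled data
dest_dir = Path("datasets/ds_v1")
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(src, dest_dir / "labeled.jsonl")

# Count positions and write a minimal metadata.json matching the schema above.
with open(dest_dir / "labeled.jsonl") as f:
    total = sum(1 for _ in f)

metadata = {
    "version": 1,
    "created": datetime.now().isoformat(),
    "total_positions": total,
    "stockfish_depth": 12,  # placeholder: depth used when the legacy data was labeled
    "sources": [{"type": "file_import", "count": total, "path": str(src)}],
}
with open(dest_dir / "metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)
```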

## Viewing Dataset Details

Use "View all datasets" to see a table of all datasets with:
- Version number
- Position count
- Source composition
- Stockfish depth
- Creation date

## Deleting a Dataset

Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**

⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.

## Technical Details

### Deduplication Strategy

When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.

When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
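
The dedup key is the FEN string. Conceptually, merging new labeled positions into an existing dataset works like this sketch, a simplified version of what `extend_dataset` does (the real code also tolerates malformed lines and updates the metadata):

```python
import json

def merge_labeled(existing_path: str, new_path: str) -> int:
    """Append positions from new_path whose FEN is not already in existing_path."""
    seen = set()
    with open(existing_path) as f:
        for line in f:
            seen.add(json.loads(line)["fen"])

    added = 0
    with open(new_path) as f_new, open(existing_path, "a") as f_out:
        for line in f_new:
            fen = json.loads(line)["fen"]
            if fen not in seen:
                f_out.write(line)
                seen.add(fen)
                added += 1
    return added
```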

### Labeled Position Format

Each line in `labeled.jsonl` is a JSON object:
```json
{
  "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
  "eval": 0.0,
  "eval_raw": 0
}
```

- `fen`: The position in Forsyth-Edwards Notation
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
- `eval_raw`: Raw Stockfish evaluation in centipawns
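
In other words, `eval` is a squashed version of `eval_raw`. The exact scaling constant lives in the training code; the sketch below assumes 400 centipawns purely for illustration:

```python
import math

CP_SCALE = 400.0  # assumption: illustrative constant; the real value is defined by the labeling/training code

def normalize_eval(eval_raw_cp: int) -> float:
    """Map a raw centipawn score onto [-1, 1] with tanh."""
    return math.tanh(eval_raw_cp / CP_SCALE)

print(normalize_eval(0))     # 0.0  -> a balanced position
print(normalize_eval(150))   # ~0.36 with this illustrative scale
print(normalize_eval(1000))  # ~0.99 -> clearly winning positions saturate toward 1
```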

### Storage Location

Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.

## Performance Tips

- **Smaller datasets train faster** — Start with 100k-500k positions
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory
@@ -0,0 +1,129 @@
# NNUE Python Pipeline

Central CLI for training and exporting chess evaluation neural networks (NNUE).

## Directory Structure

```
python/
├── nnue.py                 # Main CLI entry point
├── src/                    # Python modules
│   ├── generate.py         # Generate random chess positions
│   ├── label.py            # Label positions with Stockfish
│   ├── train.py            # Train NNUE model
│   └── export.py           # Export weights to Scala
├── data/                   # Training data (gitignored)
│   ├── positions.txt
│   └── training_data.jsonl
└── weights/                # Model weights (gitignored)
    ├── nnue_weights_v1.pt
    ├── nnue_weights_v1_metadata.json
    └── ...
```

## Quick Start

```bash
# Train a new model (500k positions, auto-detect checkpoint)
python nnue.py train

# Train from specific checkpoint
python nnue.py train --from-checkpoint 2

# Train with custom games count
python nnue.py train --games 200000

# Train with custom positions file
python nnue.py train --positions-file my_positions.txt

# Export specific version to Scala
python nnue.py export 2

# List all checkpoints
python nnue.py list
```

## CLI Commands

### `train` - Train NNUE model

```bash
python nnue.py train [OPTIONS]
```

**Options:**
- `--from-checkpoint N` - Resume from checkpoint version N (default: uses latest)
- `--games N` - Number of games to generate (default: 500000)
- `--positions-file FILE` - Use existing positions file instead of generating
- `--stockfish PATH` - Path to Stockfish binary (default: `$STOCKFISH_PATH` or `/usr/games/stockfish`)

**Examples:**
```bash
# Train with latest checkpoint
python nnue.py train

# Train from v2 with 100k games
python nnue.py train --from-checkpoint 2 --games 100000

# Train with custom positions
python nnue.py train --positions-file my_games.txt --stockfish /opt/stockfish/sf15
```

### `export` - Export weights to Scala

```bash
python nnue.py export WEIGHTS [output_path]
```

**Arguments:**
- `WEIGHTS` - Version number (e.g., `2`) or full filename (e.g., `nnue_weights_v2.pt`)

**Examples:**
```bash
# Export version 2
python nnue.py export 2

# Export with full filename
python nnue.py export nnue_weights_v3.pt
```

Output goes to `../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_vN.scala`

### `list` - List available checkpoints

```bash
python nnue.py list
```

Shows all available model versions with file sizes.

## Data Flow

1. **Generate** → `data/positions.txt`
   - Random chess positions from 8-20 move openings
   - Filters out checks, game-over states, and captures

2. **Label** → `data/training_data.jsonl`
   - Evaluates each position with Stockfish at depth 12
   - Stores FEN + evaluation in JSONL format

3. **Train** → `weights/nnue_weights_vN.pt`
   - Trains neural network on labeled positions
   - Auto-versioning (v1, v2, v3, etc.)
   - Saves metadata alongside weights

4. **Export** → `NNUEWeights_vN.scala`
   - Converts weights to Scala object
   - Ready for integration into bot
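
The source modules in `src/` can also drive these four steps directly from Python. A minimal sketch using the same entry points that `nnue.py` imports; paths, counts, and the Stockfish binary location are illustrative assumptions:

```python
import sys
from pathlib import Path

# Run from the python/ directory; nnue.py adds src/ to the path the same way.
sys.path.insert(0, str(Path(__file__).parent / "src"))

from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue, DEFAULT_HIDDEN_SIZES
from export import export_to_nbai

Path("data").mkdir(exist_ok=True)
Path("weights").mkdir(exist_ok=True)

# 1. Generate positions
play_random_game_and_collect_positions(
    "data/positions.txt",
    total_positions=100_000,
    samples_per_game=1,
    min_move=1,
    max_move=50,
    num_workers=8,
)

# 2. Label with Stockfish (binary path is an assumption; adjust for your machine)
label_positions_with_stockfish(
    "data/positions.txt",
    "data/training_data.jsonl",
    "/usr/games/stockfish",
    depth=12,
    num_workers=4,
)

# 3. Train (versioned checkpoints are written next to weights/nnue_weights.pt)
train_nnue(
    data_file="data/training_data.jsonl",
    output_file="weights/nnue_weights.pt",
    epochs=100,
    batch_size=16384,
    checkpoint=None,
    use_versioning=True,
    early_stopping_patience=None,
    subsample_ratio=1.0,
    hidden_sizes=DEFAULT_HIDDEN_SIZES,
)

# 4. Export the chosen checkpoint; this destination mirrors what the TUI uses
export_to_nbai("weights/nnue_weights_v1.pt", "../src/main/resources/nnue_weights.nbai")
```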

## Versioning

- Models are automatically versioned (v1, v2, v3, etc.)
- Each version gets a `_metadata.json` file with training info
- Training from checkpoint uses latest version unless specified with `--from-checkpoint`
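
Because the naming scheme is fixed (`nnue_weights_vN.pt`), checkpoint discovery is a simple glob. This sketch mirrors the `list_checkpoints` helper in `nnue.py`; the default `weights` directory is an assumption about where you run it from:

```python
from pathlib import Path

def list_checkpoint_versions(weights_dir: str = "weights") -> list[int]:
    """Return the checkpoint version numbers found in the weights directory."""
    files = sorted(Path(weights_dir).glob("nnue_weights_v*.pt"))
    return [int(f.stem.split("_v")[1]) for f in files]

print(list_checkpoint_versions())  # e.g. [1, 2, 3]
```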

## Files

- `data/` and `weights/` are gitignored (local artifacts)
- Documentation in `docs/` explains training, debugging, and incremental improvements
- Source modules in `src/` are independent and can be imported for custom workflows
@@ -0,0 +1,951 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich import print as rprint
|
||||
|
||||
# Add src directory to path so we can import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
|
||||
from export import export_to_nbai
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
extract_tactical_only
|
||||
)
|
||||
from lichess_importer import import_lichess_evals
|
||||
from dataset import (
|
||||
get_datasets_dir,
|
||||
list_datasets,
|
||||
next_dataset_version,
|
||||
load_dataset_metadata,
|
||||
create_dataset,
|
||||
extend_dataset,
|
||||
get_dataset_labeled_path,
|
||||
delete_dataset,
|
||||
show_datasets_table
|
||||
)
|
||||
|
||||
|
||||
def get_weights_dir():
|
||||
"""Get/create weights directory."""
|
||||
weights_dir = Path(__file__).parent / "weights"
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
return weights_dir
|
||||
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create legacy data directory (for migration)."""
|
||||
data_dir = Path(__file__).parent / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
|
||||
def list_checkpoints():
|
||||
"""List available checkpoint versions."""
|
||||
weights_dir = get_weights_dir()
|
||||
checkpoints = sorted(weights_dir.glob("nnue_weights_v*.pt"))
|
||||
if not checkpoints:
|
||||
return []
|
||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||
|
||||
|
||||
def migrate_legacy_data():
|
||||
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
|
||||
console = Console()
|
||||
data_dir = get_data_dir()
|
||||
legacy_file = data_dir / "training_data.jsonl"
|
||||
datasets = list_datasets()
|
||||
|
||||
# Only migrate if legacy data exists and no datasets exist yet
|
||||
if legacy_file.exists() and not datasets:
|
||||
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
|
||||
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
|
||||
|
||||
|
||||
def show_header():
|
||||
"""Display application header."""
|
||||
console = Console()
|
||||
console.clear()
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def show_checkpoints_table():
|
||||
"""Display available checkpoints in a table."""
|
||||
console = Console()
|
||||
available = list_checkpoints()
|
||||
|
||||
if not available:
|
||||
console.print("[yellow]ℹ No model checkpoints found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("File Size", justify="right")
|
||||
table.add_column("Status", justify="center")
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
for v in sorted(available):
|
||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||
if weights_file.exists():
|
||||
size = weights_file.stat().st_size / (1024**2)
|
||||
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||
else:
|
||||
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def show_main_menu():
|
||||
"""Display and handle main menu."""
|
||||
console = Console()
|
||||
|
||||
# Migrate legacy data on first run
|
||||
migrate_legacy_data()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Manage Training Data")
|
||||
console.print("[cyan]2[/cyan] - Train Model")
|
||||
console.print("[cyan]3[/cyan] - Export Model")
|
||||
console.print("[cyan]4[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
datasets_menu()
|
||||
elif choice == "2":
|
||||
training_menu()
|
||||
elif choice == "3":
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
|
||||
def datasets_menu():
|
||||
"""Dataset management submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
|
||||
console.print("\n[bold]Training Data Management[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Create new dataset")
|
||||
console.print("[cyan]2[/cyan] - Extend existing dataset")
|
||||
console.print("[cyan]3[/cyan] - View all datasets")
|
||||
console.print("[cyan]4[/cyan] - Delete dataset")
|
||||
console.print("[cyan]5[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
|
||||
if choice == "1":
|
||||
create_dataset_interactive()
|
||||
elif choice == "2":
|
||||
extend_dataset_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
delete_dataset_interactive()
|
||||
elif choice == "5":
|
||||
return
|
||||
|
||||
|
||||
def create_dataset_interactive():
|
||||
"""Interactive dataset creation flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add multiple sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source (repeat until done):[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Generation failed[/red]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Determine whether any sources still need Stockfish labeling.
|
||||
# Lichess sources are already labeled; only generated/tactical/file sources need it.
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||
console.print(f" Total positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to create dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
# --- Step 1: Collect already-labeled data (Lichess source) ---
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# --- Step 3: Create dataset ---
|
||||
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||
version = next_dataset_version()
|
||||
create_dataset(
|
||||
version=version,
|
||||
labeled_jsonl_path=str(labeled_file),
|
||||
sources=sources,
|
||||
stockfish_depth=stockfish_depth,
|
||||
)
|
||||
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def extend_dataset_interactive():
|
||||
"""Interactive dataset extension flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets available to extend[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source:[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Extension cancelled[/yellow]")
|
||||
return
|
||||
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Extension Summary:[/bold]")
|
||||
console.print(f" Target dataset: ds_v{version}")
|
||||
console.print(f" New positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to extend dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
# Copy pre-labeled Lichess data if present
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
# Combine and label remaining sources with Stockfish
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Extend dataset
|
||||
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||
success = extend_dataset(
|
||||
version=version,
|
||||
new_labeled_path=str(labeled_file),
|
||||
new_source_entry={
|
||||
"type": "merged_sources",
|
||||
"count": combined_count,
|
||||
"sources": sources,
|
||||
}
|
||||
)
|
||||
|
||||
if success:
|
||||
metadata = load_dataset_metadata(version)
|
||||
console.print(f"[green]✓ Dataset extended[/green]")
|
||||
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
|
||||
else:
|
||||
console.print("[red]✗ Extension failed[/red]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def delete_dataset_interactive():
|
||||
"""Interactive dataset deletion."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets to delete[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
|
||||
if delete_dataset(version):
|
||||
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Deletion failed[/red]")
|
||||
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def training_menu():
|
||||
"""Training submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold]Training[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Standard Training")
|
||||
console.print("[cyan]2[/cyan] - Burst Training")
|
||||
console.print("[cyan]3[/cyan] - View Model Checkpoints")
|
||||
console.print("[cyan]4[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
burst_train_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
return
|
||||
|
||||
|
||||
def train_interactive():
|
||||
"""Interactive training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
# Checkpoint selection
|
||||
available = list_checkpoints()
|
||||
use_checkpoint = False
|
||||
checkpoint_version = None
|
||||
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
|
||||
if use_checkpoint:
|
||||
checkpoint_version = Prompt.ask(
|
||||
"Enter checkpoint version",
|
||||
default=str(max(available))
|
||||
)
|
||||
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
if use_checkpoint:
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(f" Checkpoint: None (training from scratch)")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
console.print("[yellow]Training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
# Execute training
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Training Model[/bold cyan]")
|
||||
checkpoint = None
|
||||
if use_checkpoint:
|
||||
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||
|
||||
train_nnue(
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping,
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
# Show result
|
||||
available = list_checkpoints()
|
||||
new_version = max(available) if available else 1
|
||||
console.print(f"\n[bold green]✓ Training successful![/bold green]")
|
||||
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def burst_train_interactive():
|
||||
"""Interactive burst training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||
|
||||
# Optional initial checkpoint
|
||||
available = list_checkpoints()
|
||||
checkpoint = None
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
if Confirm.ask("Start from an existing checkpoint?", default=False):
|
||||
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
|
||||
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
|
||||
|
||||
# Training hyperparameters
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
|
||||
if not Confirm.ask("\nStart burst training?", default=True):
|
||||
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||
burst_train(
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
duration_minutes=duration_minutes,
|
||||
epochs_per_season=epochs_per_season,
|
||||
early_stopping_patience=early_stopping_patience,
|
||||
batch_size=batch_size,
|
||||
initial_checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Burst training complete[/green]")
|
||||
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
|
||||
|
||||
# Select weights version
|
||||
available = list_checkpoints()
|
||||
if not available:
|
||||
console.print("[red]✗ No checkpoints available to export[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||
|
||||
weights_file = f"nnue_weights_v{version}.pt"
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
|
||||
|
||||
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||
console.print(f" Source: {weights_file}")
|
||||
console.print(f" Destination: {output_file}")
|
||||
|
||||
if not Confirm.ask("\nExport weights?", default=True):
|
||||
console.print("[yellow]Export cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
weights_dir = get_weights_dir()
|
||||
weights_path = weights_dir / weights_file
|
||||
|
||||
if not weights_path.exists():
|
||||
console.print(f"[red]✗ {weights_file} not found[/red]")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||
export_to_nbai(str(weights_path), output_file)
|
||||
console.print(f"\n[green]✓ Export complete![/green]")
|
||||
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
show_main_menu()
|
||||
return 0
|
||||
except KeyboardInterrupt:
|
||||
console = Console()
|
||||
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console = Console()
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,6 @@
|
||||
chess==1.11.2
|
||||
torch==2.11.0
|
||||
tqdm==4.67.3
|
||||
numpy==2.4.4
|
||||
rich==13.7.0
|
||||
zstandard==0.23.0
|
||||
@@ -0,0 +1,66 @@
|
||||
@echo off
|
||||
REM NNUE Training Pipeline for Windows
|
||||
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
echo.
|
||||
echo === NNUE Training Pipeline ===
|
||||
echo.
|
||||
|
||||
REM Get the directory where this script is located
|
||||
set SCRIPT_DIR=%~dp0
|
||||
|
||||
cd /d "%SCRIPT_DIR%"
|
||||
|
||||
REM Step 1: Generate positions
|
||||
echo Step 1: Generating 500,000 random positions...
|
||||
python generate_positions.py positions.txt
|
||||
if not exist positions.txt (
|
||||
echo ERROR: positions.txt not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Positions generated
|
||||
echo.
|
||||
|
||||
REM Step 2: Label positions with Stockfish
|
||||
echo Step 2: Labeling positions with Stockfish (depth 12^)...
|
||||
if "%STOCKFISH_PATH%"=="" (
|
||||
set STOCKFISH_PATH=stockfish
|
||||
)
|
||||
python label_positions.py positions.txt training_data.jsonl "%STOCKFISH_PATH%"
|
||||
if not exist training_data.jsonl (
|
||||
echo ERROR: training_data.jsonl not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Positions labeled
|
||||
echo.
|
||||
|
||||
REM Step 3: Train NNUE model
|
||||
echo Step 3: Training NNUE model (20 epochs^)...
|
||||
python train_nnue.py training_data.jsonl nnue_weights.pt
|
||||
if not exist nnue_weights.pt (
|
||||
echo ERROR: nnue_weights.pt not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Model trained
|
||||
echo.
|
||||
|
||||
REM Step 4: Export weights to Scala
|
||||
echo Step 4: Exporting weights to Scala...
|
||||
python export_weights.py nnue_weights.pt ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala
|
||||
if not exist ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala (
|
||||
echo ERROR: NNUEWeights.scala not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Weights exported
|
||||
echo.
|
||||
|
||||
echo === Pipeline Complete ===
|
||||
echo.
|
||||
echo Next steps:
|
||||
echo 1. Navigate to project root: cd ..\..
|
||||
echo 2. Compile: .\compile.bat
|
||||
echo 3. Test: .\test.bat
|
||||
echo.
|
||||
|
||||
endlocal
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
|
||||
# NNUE Training Pipeline (bash version)
|
||||
# Uses the central CLI (nnue.py) for all operations
|
||||
# Works on Linux, macOS, and Windows (with Git Bash or WSL)
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use python or python3 (check which is available)
|
||||
PYTHON_CMD="python3"
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
echo "=== NNUE Training Pipeline ==="
|
||||
echo ""
|
||||
echo "Python command: $PYTHON_CMD"
|
||||
echo "Working directory: $SCRIPT_DIR"
|
||||
echo ""
|
||||
|
||||
# Run the unified training pipeline
|
||||
# With `set -e` the script would exit before a later `$?` check could run, so guard the call directly.
if ! $PYTHON_CMD nnue.py train; then
    echo ""
    echo "ERROR: Training pipeline failed"
    exit 1
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Pipeline Complete ==="
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Navigate to project root: cd ../.."
|
||||
echo "2. Compile: ./compile"
|
||||
echo "3. Test: ./test"
|
||||
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Dataset versioning and management for NNUE training data."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def get_datasets_dir() -> Path:
|
||||
"""Get/create datasets directory."""
|
||||
datasets_dir = Path(__file__).parent.parent / "datasets"
|
||||
datasets_dir.mkdir(exist_ok=True)
|
||||
return datasets_dir
|
||||
|
||||
|
||||
def next_dataset_version() -> int:
|
||||
"""Find the next available dataset version number."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
versions = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
versions.append(v)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return max(versions) + 1 if versions else 1
|
||||
|
||||
|
||||
def list_datasets() -> List[Tuple[int, Dict]]:
|
||||
"""List all datasets with their metadata.
|
||||
|
||||
Returns:
|
||||
List of (version, metadata_dict) tuples, sorted by version.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
datasets = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
metadata_file = d / "metadata.json"
|
||||
if metadata_file.exists():
|
||||
with open(metadata_file, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
datasets.append((v, metadata))
|
||||
except (ValueError, IndexError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
return sorted(datasets, key=lambda x: x[0])
|
||||
|
||||
|
||||
def load_dataset_metadata(version: int) -> Optional[Dict]:
|
||||
"""Load metadata for a specific dataset version.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None if not found.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
|
||||
|
||||
if not metadata_file.exists():
|
||||
return None
|
||||
|
||||
with open(metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_dataset_metadata(version: int, metadata: Dict) -> None:
|
||||
"""Save metadata for a dataset version."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
metadata_file = dataset_dir / "metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
|
||||
def create_dataset(
|
||||
version: int,
|
||||
labeled_jsonl_path: str,
|
||||
sources: List[Dict],
|
||||
stockfish_depth: int = 12
|
||||
) -> Path:
|
||||
"""Create a new versioned dataset.
|
||||
|
||||
Args:
|
||||
version: Dataset version number
|
||||
labeled_jsonl_path: Path to labeled.jsonl to copy
|
||||
sources: List of source dicts (see the metadata schema in the dataset docs)
|
||||
stockfish_depth: Depth used for labeling
|
||||
|
||||
Returns:
|
||||
Path to the created dataset directory.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Copy labeled data with deduplication (in case source has duplicates)
|
||||
source_path = Path(labeled_jsonl_path)
|
||||
if source_path.exists():
|
||||
dest_path = dataset_dir / "labeled.jsonl"
|
||||
seen_fens = set()
|
||||
unique_count = 0
|
||||
|
||||
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
|
||||
for line in src:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in seen_fens:
|
||||
dst.write(line)
|
||||
seen_fens.add(fen)
|
||||
unique_count += 1
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed lines
|
||||
pass
|
||||
|
||||
# Count positions
|
||||
total_positions = 0
|
||||
if (dataset_dir / "labeled.jsonl").exists():
|
||||
with open(dataset_dir / "labeled.jsonl", 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"version": version,
|
||||
"created": datetime.now().isoformat(),
|
||||
"total_positions": total_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"sources": sources
|
||||
}
|
||||
|
||||
save_dataset_metadata(version, metadata)
|
||||
return dataset_dir
|
||||
|
||||
|
||||
def extend_dataset(
|
||||
version: int,
|
||||
new_labeled_path: str,
|
||||
new_source_entry: Dict
|
||||
) -> bool:
|
||||
"""Extend an existing dataset with new labeled positions (with deduplication).
|
||||
|
||||
Args:
|
||||
version: Dataset version to extend
|
||||
new_labeled_path: Path to new labeled.jsonl to merge
|
||||
new_source_entry: Source entry to add to metadata
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
labeled_file = dataset_dir / "labeled.jsonl"
|
||||
new_labeled_file = Path(new_labeled_path)
|
||||
|
||||
if not new_labeled_file.exists():
|
||||
return False
|
||||
|
||||
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
|
||||
existing_fens = set()
|
||||
if labeled_file.exists():
|
||||
with open(labeled_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen:
|
||||
existing_fens.add(fen)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Merge new positions, skipping duplicates
|
||||
new_count = 0
|
||||
new_lines = []
|
||||
with open(new_labeled_file, 'r') as f_new:
|
||||
for line in f_new:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in existing_fens:
|
||||
new_lines.append(line)
|
||||
existing_fens.add(fen)
|
||||
new_count += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Append only the new, unique positions
|
||||
if new_lines:
|
||||
with open(labeled_file, 'a') as f_append:
|
||||
for line in new_lines:
|
||||
f_append.write(line)
|
||||
|
||||
# Update metadata
|
||||
metadata = load_dataset_metadata(version)
|
||||
if metadata:
|
||||
# Count total positions
|
||||
total_positions = 0
|
||||
with open(labeled_file, 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
metadata['total_positions'] = total_positions
|
||||
# Update the source entry with actual count of new positions added
|
||||
new_source_entry['actual_count'] = new_count
|
||||
metadata['sources'].append(new_source_entry)
|
||||
save_dataset_metadata(version, metadata)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_dataset_labeled_path(version: int) -> Optional[Path]:
|
||||
"""Get the path to a dataset's labeled.jsonl file.
|
||||
|
||||
Returns:
|
||||
Path to labeled.jsonl or None if dataset doesn't exist.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
|
||||
|
||||
if labeled_file.exists():
|
||||
return labeled_file
|
||||
return None
|
||||
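# Usage sketch (assuming get_datasets_dir() resolves to the local "datasets/" folder):
#   get_dataset_labeled_path(2) -> Path("datasets/ds_v2/labeled.jsonl") if it exists, else None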
|
||||
|
||||
def delete_dataset(version: int) -> bool:
|
||||
"""Delete a dataset (recursively removes directory).
|
||||
|
||||
Args:
|
||||
version: Dataset version to delete
|
||||
|
||||
Returns:
|
||||
True if successful.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
import shutil
|
||||
shutil.rmtree(dataset_dir)
|
||||
return True
|
||||
|
||||
|
||||
def show_datasets_table(console: Console = None) -> None:
|
||||
"""Display all datasets in a Rich table."""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
datasets = list_datasets()
|
||||
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Positions", justify="right")
|
||||
table.add_column("Sources", justify="left")
|
||||
table.add_column("Depth", justify="center")
|
||||
table.add_column("Created", justify="left")
|
||||
|
||||
for v, metadata in datasets:
|
||||
positions = metadata.get('total_positions', 0)
|
||||
sources = metadata.get('sources', [])
|
||||
source_str = ", ".join([s.get('type', '?') for s in sources])
|
||||
depth = metadata.get('stockfish_depth', '?')
|
||||
created = metadata.get('created', '?')
|
||||
if created != '?':
|
||||
created = created.split('T')[0] # Just the date
|
||||
|
||||
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
|
||||
|
||||
console.print(table)
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
MAGIC = 0x4942_414E  # written as a little-endian uint32; the file therefore starts with the bytes 'N','A','B','I'
|
||||
VERSION = 1
|
||||
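# For reference, the binary layout written by export_to_nbai() below is
# (all integers little-endian; this summary is derived from the struct.pack
# calls in this file, the Java-side reader remains the source of truth):
#
#   uint32  magic (MAGIC)
#   uint16  format version (VERSION)
#   uint32  metadata byte length, followed by that many bytes of UTF-8 JSON
#   uint16  layer count
#   per layer:  uint8 activation-name length, ASCII activation name,
#               uint32 inputSize, uint32 outputSize
#   per layer:  uint32 count + count * float32 weights (row-major),
#               uint32 count + count * float32 biases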
|
||||
|
||||
def _read_sidecar(weights_file: str) -> dict:
|
||||
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||
if Path(sidecar).exists():
|
||||
with open(sidecar) as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||
"""Derive layer descriptors from state_dict weight shapes.
|
||||
|
||||
Assumes layers named l1, l2, ..., lN.
|
||||
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||
"""
|
||||
names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
layers = []
|
||||
for i, name in enumerate(names):
|
||||
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||
activation = "linear" if i == len(names) - 1 else "relu"
|
||||
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||
return layers
|
||||
|
||||
|
||||
def _write_floats(f, tensor):
|
||||
data = tensor.float().flatten().cpu().numpy()
|
||||
f.write(struct.pack("<I", len(data)))
|
||||
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||
|
||||
|
||||
def export_to_nbai(
|
||||
weights_file: str,
|
||||
output_file: str,
|
||||
trained_by: str = "unknown",
|
||||
train_loss: float = 0.0,
|
||||
):
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
|
||||
loaded = torch.load(weights_file, map_location="cpu")
|
||||
state_dict = (
|
||||
loaded["model_state_dict"]
|
||||
if isinstance(loaded, dict) and "model_state_dict" in loaded
|
||||
else loaded
|
||||
)
|
||||
|
||||
sidecar = _read_sidecar(weights_file)
|
||||
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
|
||||
trained_at = sidecar.get("date", datetime.now().isoformat())
|
||||
training_data_count = int(sidecar.get("num_positions", 0))
|
||||
|
||||
metadata = {
|
||||
"trainedBy": trained_by,
|
||||
"trainedAt": trained_at,
|
||||
"trainingDataCount": training_data_count,
|
||||
"valLoss": val_loss,
|
||||
"trainLoss": train_loss,
|
||||
}
|
||||
|
||||
layers = _infer_layers(state_dict)
|
||||
layer_names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
|
||||
print(f"Architecture ({len(layers)} layers):")
|
||||
for i, l in enumerate(layers):
|
||||
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
|
||||
|
||||
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
# Header
|
||||
f.write(struct.pack("<I", MAGIC))
|
||||
f.write(struct.pack("<H", VERSION))
|
||||
|
||||
# Metadata (length-prefixed UTF-8 JSON)
|
||||
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
|
||||
f.write(struct.pack("<I", len(meta_bytes)))
|
||||
f.write(meta_bytes)
|
||||
|
||||
# Layer descriptors
|
||||
f.write(struct.pack("<H", len(layers)))
|
||||
for layer in layers:
|
||||
name_bytes = layer["activation"].encode("ascii")
|
||||
f.write(struct.pack("<B", len(name_bytes)))
|
||||
f.write(name_bytes)
|
||||
f.write(struct.pack("<I", layer["inputSize"]))
|
||||
f.write(struct.pack("<I", layer["outputSize"]))
|
||||
|
||||
# Weights: weight tensor then bias tensor per layer
|
||||
for name in layer_names:
|
||||
w = state_dict[f"{name}.weight"]
|
||||
b = state_dict[f"{name}.bias"]
|
||||
_write_floats(f, w)
|
||||
_write_floats(f, b)
|
||||
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
|
||||
|
||||
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
|
||||
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights_file = "nnue_weights.pt"
|
||||
output_file = "../src/main/resources/nnue_weights.nbai"
|
||||
trained_by = "unknown"
|
||||
train_loss = 0.0
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
weights_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
trained_by = sys.argv[3]
|
||||
if len(sys.argv) > 4:
|
||||
train_loss = float(sys.argv[4])
|
||||
|
||||
export_to_nbai(weights_file, output_file, trained_by, train_loss)
|
||||
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate random chess positions for NNUE training with multiprocessing."""
|
||||
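# Typical invocation (this module is imported elsewhere as `generate`, so the
# file is assumed to be saved as generate.py; flags match the argparse setup below):
#   python generate.py positions.txt --positions 1000000 --workers 8 --min-move 5 --max-move 40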
|
||||
import chess
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
|
||||
"""Generate games for one worker.
|
||||
|
||||
Returns:
|
||||
list of FENs generated by this worker
|
||||
"""
|
||||
positions = []
|
||||
|
||||
for game_num in range(games_per_worker):
|
||||
board = chess.Board()
|
||||
move_history = []
|
||||
|
||||
# Play a complete random game
|
||||
while not board.is_game_over() and len(move_history) < 200:
|
||||
legal_moves = list(board.legal_moves)
|
||||
if not legal_moves:
|
||||
break
|
||||
move = random.choice(legal_moves)
|
||||
board.push(move)
|
||||
move_history.append(board.copy())
|
||||
|
||||
# Determine the range of moves to sample from
|
||||
game_length = len(move_history)
|
||||
valid_start = max(min_move, 0)
|
||||
valid_end = min(max_move, game_length)
|
||||
|
||||
if valid_start >= valid_end:
|
||||
continue
|
||||
|
||||
# Randomly sample positions from this game
|
||||
sample_count = min(samples_per_game, valid_end - valid_start)
|
||||
if sample_count > 0:
|
||||
sample_indices = random.sample(
|
||||
range(valid_start, valid_end),
|
||||
k=sample_count
|
||||
)
|
||||
|
||||
for idx in sample_indices:
|
||||
sampled_board = move_history[idx]
|
||||
|
||||
# Only filter truly invalid or terminal positions
|
||||
if not sampled_board.is_valid() or sampled_board.is_game_over():
|
||||
continue
|
||||
|
||||
# Save position (include check, captures, all positions)
|
||||
fen = sampled_board.fen()
|
||||
positions.append(fen)
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def play_random_game_and_collect_positions(
|
||||
output_file,
|
||||
total_positions=3000000,
|
||||
samples_per_game=1,
|
||||
min_move=1,
|
||||
max_move=50,
|
||||
num_workers=8
|
||||
):
|
||||
"""Generate positions using multiprocessing with multiple workers.
|
||||
|
||||
Args:
|
||||
output_file: Output file for positions
|
||||
total_positions: Target number of positions to generate
|
||||
samples_per_game: Number of positions to sample per game (1-N)
|
||||
min_move: Minimum move number to start sampling from
|
||||
max_move: Maximum move number for sampling
|
||||
num_workers: Number of parallel worker processes
|
||||
|
||||
Returns:
|
||||
Number of valid positions saved
|
||||
"""
|
||||
# Estimate games needed (each game contributes roughly samples_per_game positions)
|
||||
total_games = max(total_positions // samples_per_game, num_workers)
|
||||
games_per_worker = total_games // num_workers
|
||||
|
||||
print(f"Generating {total_positions:,} positions using {num_workers} workers")
|
||||
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
|
||||
print()
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
# Generate positions in parallel
|
||||
worker_tasks = [
|
||||
(i, games_per_worker, samples_per_game, min_move, max_move)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
|
||||
positions_count = 0
|
||||
all_positions = []
|
||||
|
||||
with Pool(num_workers) as pool:
|
||||
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
|
||||
for positions in pool.starmap(_worker_generate_games, worker_tasks):
|
||||
all_positions.extend(positions)
|
||||
positions_count += len(positions)
|
||||
pbar.update(1)
|
||||
|
||||
# Write all positions to file
|
||||
print(f"Writing {positions_count:,} positions to {output_file}...")
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in all_positions:
|
||||
f.write(fen + '\n')
|
||||
|
||||
elapsed_time = datetime.now() - start_time
|
||||
elapsed_seconds = elapsed_time.total_seconds()
|
||||
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
|
||||
|
||||
# Print summary
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("POSITION GENERATION SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Target positions: {total_positions:,}")
|
||||
print(f"Actual positions saved: {positions_count:,}")
|
||||
print(f"Workers: {num_workers}")
|
||||
print(f"Games per worker: {games_per_worker:,}")
|
||||
print(f"Samples per game: {samples_per_game}")
|
||||
print(f"Move range: {min_move}-{max_move}")
|
||||
print(f"Elapsed time: {elapsed_time}")
|
||||
print(f"Throughput: {positions_per_second:.0f} positions/second")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if positions_count == 0:
|
||||
print("WARNING: No valid positions were generated!")
|
||||
return 0
|
||||
|
||||
return positions_count
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
|
||||
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
||||
help="Output file for positions (default: positions.txt)")
|
||||
parser.add_argument("--positions", type=int, default=3000000,
|
||||
help="Target number of positions to generate (default: 3000000)")
|
||||
parser.add_argument("--samples-per-game", type=int, default=1,
|
||||
help="Number of positions to sample per game (default: 1)")
|
||||
parser.add_argument("--min-move", type=int, default=1,
|
||||
help="Minimum move number to sample from (default: 1)")
|
||||
parser.add_argument("--max-move", type=int, default=50,
|
||||
help="Maximum move number to sample from (default: 50)")
|
||||
parser.add_argument("--workers", type=int, default=8,
|
||||
help="Number of parallel worker processes (default: 8)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
count = play_random_game_and_collect_positions(
|
||||
output_file=args.output_file,
|
||||
total_positions=args.positions,
|
||||
samples_per_game=args.samples_per_game,
|
||||
min_move=args.min_move,
|
||||
max_move=args.max_move,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
sys.exit(0 if count > 0 else 1)
|
||||
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||
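# Typical invocation (this script is referred to as label.py elsewhere in the
# pipeline; the Stockfish path below is a placeholder, flags match the argparse setup at the bottom):
#   STOCKFISH_PATH=/usr/bin/stockfish python label.py positions.txt training_data.jsonl --depth 12 --workers 4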
|
||||
import json
|
||||
import chess.engine
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
|
||||
"""Normalize centipawn evaluation to a bounded range.
|
||||
|
||||
Args:
|
||||
cp_value: Centipawn evaluation from Stockfish
|
||||
method: 'tanh' (default) or 'sigmoid'
|
||||
scale: Scale factor (tanh: 300 is typical)
|
||||
|
||||
Returns:
|
||||
Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
|
||||
"""
|
||||
if method == 'tanh':
|
||||
return np.tanh(cp_value / scale)
|
||||
elif method == 'sigmoid':
|
||||
return 1.0 / (1.0 + np.exp(-cp_value / scale))
|
||||
else:
|
||||
return cp_value / 100.0
|
||||
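# Illustrative values for the default tanh mapping with scale=300:
#   normalize_evaluation(0)     ->  0.0
#   normalize_evaluation(100)   ->  ~0.32
#   normalize_evaluation(300)   ->  ~0.76
#   normalize_evaluation(-1000) -> ~-0.997
# Large material advantages saturate near +/-1 instead of dominating the loss.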
|
||||
def _evaluate_fen_batch(args):
|
||||
"""Worker function to evaluate a batch of FENs with Stockfish threading.
|
||||
|
||||
Args:
|
||||
args: tuple of (fens, stockfish_path, depth, normalize)
|
||||
|
||||
Returns:
|
||||
list of (fen, eval_normalized, eval_raw) tuples
|
||||
"""
|
||||
fens, stockfish_path, depth, normalize = args
|
||||
|
||||
results = []
|
||||
|
||||
try:
|
||||
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
try:
|
||||
for fen in fens:
|
||||
try:
|
||||
board = chess.Board(fen)
|
||||
if not board.is_valid():
|
||||
continue
|
||||
|
||||
info = engine.analyse(board, chess.engine.Limit(depth=depth))
|
||||
|
||||
if info.get('score') is None:
|
||||
continue
|
||||
|
||||
score = info['score'].white()
|
||||
|
||||
if score.is_mate():
|
||||
eval_cp = 2000 if score.mate() > 0 else -2000
|
||||
else:
|
||||
eval_cp = score.cp
|
||||
|
||||
eval_cp = max(-2000, min(2000, eval_cp))
|
||||
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
|
||||
|
||||
results.append((fen, eval_normalized, eval_cp))
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
engine.quit()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
|
||||
"""Read positions and label them with Stockfish evaluations.
|
||||
|
||||
Args:
|
||||
positions_file: Path to positions.txt
|
||||
output_file: Path to training_data.jsonl
|
||||
stockfish_path: Path to stockfish binary
|
||||
batch_size: Batch size for processing (positions per worker task, default: 1000)
|
||||
depth: Stockfish depth
|
||||
verbose: Print detailed error messages
|
||||
normalize: If True, normalize evals using tanh
|
||||
num_workers: Number of parallel Stockfish processes
|
||||
"""
|
||||
|
||||
# Check if stockfish exists
|
||||
if not Path(stockfish_path).exists():
|
||||
print(f"Error: Stockfish not found at {stockfish_path}")
|
||||
print(f"Tried: {stockfish_path}")
|
||||
print(f"Set STOCKFISH_PATH environment variable or pass as argument")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using Stockfish: {stockfish_path}")
|
||||
print(f"Number of workers: {num_workers}")
|
||||
|
||||
# Check if positions file exists
|
||||
if not Path(positions_file).exists():
|
||||
print(f"Error: Positions file not found at {positions_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Load existing evaluations if resuming
|
||||
evaluated_fens = set()
|
||||
position_count = 0
|
||||
|
||||
if Path(output_file).exists():
|
||||
with open(output_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
evaluated_fens.add(data['fen'])
|
||||
position_count += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
print(f"Resuming from {position_count} already evaluated positions")
|
||||
|
||||
# Load all FENs that need evaluation
|
||||
fens_to_evaluate = []
|
||||
fens_seen_in_batch = set() # Track duplicates within current batch
|
||||
skipped_invalid = 0
|
||||
skipped_duplicate = 0
|
||||
|
||||
with open(positions_file, 'r') as f:
|
||||
for fen in f:
|
||||
fen = fen.strip()
|
||||
|
||||
if not fen:
|
||||
skipped_invalid += 1
|
||||
continue
|
||||
|
||||
if fen in evaluated_fens:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
if fen in fens_seen_in_batch:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
fens_to_evaluate.append(fen)
|
||||
fens_seen_in_batch.add(fen)
|
||||
|
||||
total_to_evaluate = len(fens_to_evaluate)
|
||||
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
|
||||
|
||||
if total_to_evaluate == 0:
|
||||
if position_count == 0:
|
||||
print(f"Error: No valid positions to evaluate in {positions_file}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"All positions already evaluated. No new positions to process.")
|
||||
return True
|
||||
|
||||
print(f"Total positions to process: {total_lines}")
|
||||
print(f"New positions to evaluate: {total_to_evaluate}")
|
||||
print(f"Using depth: {depth}")
|
||||
print()
|
||||
|
||||
# Split FENs into batches for workers
|
||||
batches = []
|
||||
for i in range(0, total_to_evaluate, batch_size):
|
||||
batch = fens_to_evaluate[i:i+batch_size]
|
||||
batches.append((batch, stockfish_path, depth, normalize))
|
||||
|
||||
# Process batches in parallel
|
||||
evaluated = 0
|
||||
errors = 0
|
||||
raw_evals = []
|
||||
normalized_evals = []
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
with Pool(num_workers) as pool:
|
||||
with tqdm(total=position_count + total_to_evaluate, initial=position_count, desc="Labeling positions") as pbar:
|
||||
with open(output_file, 'a') as out:
|
||||
for batch_idx, batch_results in enumerate(pool.imap(_evaluate_fen_batch, batches)):  # ordered imap so batch_idx matches batches[batch_idx]
|
||||
for fen, eval_normalized, eval_cp in batch_results:
|
||||
# Skip if already evaluated in output file during this run
|
||||
if fen in evaluated_fens:
|
||||
continue
|
||||
|
||||
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||
out.write(json.dumps(data) + '\n')
|
||||
evaluated_fens.add(fen) # Track as evaluated
|
||||
evaluated += 1
|
||||
raw_evals.append(eval_cp)
|
||||
normalized_evals.append(eval_normalized)
|
||||
pbar.update(1)
|
||||
|
||||
# Update progress for any failed evaluations in the batch
|
||||
batch_size_actual = len(batches[batch_idx][0])
|
||||
failed = batch_size_actual - len(batch_results)
|
||||
if failed > 0:
|
||||
errors += failed
|
||||
pbar.update(failed)
|
||||
|
||||
# Calculate and show throughput and ETA
|
||||
elapsed = time.time() - start_time
|
||||
throughput = evaluated / elapsed if elapsed > 0 else 0
|
||||
remaining_positions = total_to_evaluate - evaluated
|
||||
eta_seconds = remaining_positions / throughput if throughput > 0 else 0
|
||||
eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
|
||||
|
||||
if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
|
||||
pbar.set_postfix({
|
||||
'rate': f'{throughput:.0f} pos/s',
|
||||
'eta': eta_str
|
||||
})
|
||||
|
||||
# Print summary and analysis
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LABELING SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Successfully evaluated: {evaluated}")
|
||||
print(f"Skipped (duplicates): {skipped_duplicate}")
|
||||
print(f"Skipped (invalid): {skipped_invalid}")
|
||||
print(f"Errors: {errors}")
|
||||
print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if evaluated == 0:
|
||||
print("WARNING: No positions were successfully evaluated!")
|
||||
print("Check that:")
|
||||
print(" 1. positions.txt is not empty")
|
||||
print(" 2. positions.txt contains valid FENs")
|
||||
print(" 3. Stockfish is installed and working")
|
||||
print(" 4. Stockfish path is correct")
|
||||
return False
|
||||
|
||||
# Print distribution analysis
|
||||
if raw_evals:
|
||||
raw_evals_arr = np.array(raw_evals)
|
||||
norm_evals_arr = np.array(normalized_evals)
|
||||
|
||||
print("=" * 60)
|
||||
print("EVALUATION DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("Raw Evaluations (centipawns):")
|
||||
print(f" Min: {raw_evals_arr.min():.1f}")
|
||||
print(f" Max: {raw_evals_arr.max():.1f}")
|
||||
print(f" Mean: {raw_evals_arr.mean():.1f}")
|
||||
print(f" Median: {np.median(raw_evals_arr):.1f}")
|
||||
print(f" Std: {raw_evals_arr.std():.1f}")
|
||||
print()
|
||||
|
||||
print("Normalized Evaluations (tanh):")
|
||||
print(f" Min: {norm_evals_arr.min():.4f}")
|
||||
print(f" Max: {norm_evals_arr.max():.4f}")
|
||||
print(f" Mean: {norm_evals_arr.mean():.4f}")
|
||||
print(f" Median: {np.median(norm_evals_arr):.4f}")
|
||||
print(f" Std: {norm_evals_arr.std():.4f}")
|
||||
print()
|
||||
|
||||
# Distribution buckets
|
||||
print("Raw Evaluation Buckets (counts):")
|
||||
buckets = [
|
||||
(-float('inf'), -500, "< -5.00"),
|
||||
(-500, -300, "[-5.00, -3.00)"),
|
||||
(-300, -100, "[-3.00, -1.00)"),
|
||||
(-100, 0, "[-1.00, 0.00)"),
|
||||
(0, 100, "[0.00, 1.00)"),
|
||||
(100, 300, "[1.00, 3.00)"),
|
||||
(300, 500, "[3.00, 5.00)"),
|
||||
(500, float('inf'), "> 5.00"),
|
||||
]
|
||||
for low, high, label in buckets:
|
||||
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
|
||||
pct = 100.0 * count / len(raw_evals_arr)
|
||||
print(f" {label}: {count:6d} ({pct:5.1f}%)")
|
||||
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
print(f"✓ Labeling complete. Output saved to {output_file}")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
|
||||
parser.add_argument("positions_file", nargs="?", default="positions.txt",
|
||||
help="Input positions file (default: positions.txt)")
|
||||
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||
help="Output file (default: training_data.jsonl)")
|
||||
parser.add_argument("stockfish_path", nargs="?", default=None,
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||
parser.add_argument("--depth", type=int, default=12,
|
||||
help="Stockfish depth (default: 12)")
|
||||
parser.add_argument("--batch-size", type=int, default=1000,
|
||||
help="Batch size for processing (default: 1000)")
|
||||
parser.add_argument("--no-normalize", action="store_true",
|
||||
help="Disable evaluation normalization (keep raw centipawns)")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Print detailed error messages")
|
||||
parser.add_argument("--workers", type=int, default=1,
|
||||
help="Number of parallel Stockfish processes (default: 1)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine Stockfish path
|
||||
stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")
|
||||
|
||||
success = label_positions_with_stockfish(
|
||||
positions_file=args.positions_file,
|
||||
output_file=args.output_file,
|
||||
stockfish_path=stockfish_path,
|
||||
batch_size=args.batch_size,
|
||||
depth=args.depth,
|
||||
normalize=not args.no_normalize,
|
||||
verbose=args.verbose,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Import pre-labeled positions from the Lichess evaluation database.
|
||||
|
||||
Source: https://database.lichess.org/#evals
|
||||
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
|
||||
|
||||
Each line:
|
||||
{
|
||||
"fen": "<pieces> <turn> <castling> <ep>",
|
||||
"evals": [
|
||||
{
|
||||
"knodes": <int>,
|
||||
"depth": <int>,
|
||||
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
cp and mate are from White's perspective (positive = White winning), matching
|
||||
the sign convention used by label.py (score.white()) and expected by train.py.
|
||||
"""
|
||||
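# Typical invocation (filename assumed, e.g. import_lichess.py; flags match the
# argparse setup at the bottom of this file):
#   python import_lichess.py lichess_db_eval.jsonl.zst labeled.jsonl --min-depth 20 --max-positions 1000000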
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
MATE_CP = 20000
|
||||
SCALE = 300.0
|
||||
|
||||
|
||||
def _best_eval(evals: list) -> dict | None:
|
||||
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
|
||||
if not evals:
|
||||
return None
|
||||
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
|
||||
|
||||
|
||||
def _cp_from_pv(pv: dict) -> int | None:
|
||||
"""Extract centipawn value from a principal variation entry."""
|
||||
if "cp" in pv:
|
||||
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
|
||||
if "mate" in pv:
|
||||
return MATE_CP if pv["mate"] > 0 else -MATE_CP
|
||||
return None
|
||||
|
||||
|
||||
def _normalize(cp: int) -> float:
|
||||
return float(np.tanh(cp / SCALE))
|
||||
|
||||
|
||||
def import_lichess_evals(
|
||||
input_path: str,
|
||||
output_file: str,
|
||||
max_positions: int | None = None,
|
||||
min_depth: int = 0,
|
||||
verbose: bool = False,
|
||||
) -> int:
|
||||
"""Stream the Lichess eval database and write a labeled.jsonl file.
|
||||
|
||||
Args:
|
||||
input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
|
||||
output_file: Destination labeled.jsonl (appended — supports resuming).
|
||||
max_positions: Stop after this many new positions (None = no limit).
|
||||
min_depth: Skip positions whose best eval has depth < min_depth.
|
||||
verbose: Print warnings for skipped lines.
|
||||
|
||||
Returns:
|
||||
Number of new positions written.
|
||||
"""
|
||||
import zstandard as zstd
|
||||
|
||||
input_path = Path(input_path)
|
||||
if not input_path.exists():
|
||||
print(f"Error: {input_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Resume: collect already-written FENs so we skip duplicates.
|
||||
seen_fens: set[str] = set()
|
||||
if Path(output_file).exists():
|
||||
with open(output_file, "r") as f:
|
||||
for line in f:
|
||||
try:
|
||||
seen_fens.add(json.loads(line)["fen"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
if seen_fens:
|
||||
print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
|
||||
|
||||
written = 0
|
||||
skipped_depth = 0
|
||||
skipped_no_eval = 0
|
||||
skipped_dup = 0
|
||||
|
||||
def iter_lines():
|
||||
"""Yield decoded text lines from either a .zst or plain .jsonl file."""
|
||||
import io
|
||||
if input_path.suffix == ".zst":
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
with open(input_path, "rb") as fh:
|
||||
with dctx.stream_reader(fh) as reader:
|
||||
text_stream = io.TextIOWrapper(reader, encoding="utf-8")
|
||||
yield from text_stream
|
||||
else:
|
||||
with open(input_path, "r", encoding="utf-8") as fh:
|
||||
yield from fh
|
||||
|
||||
try:
|
||||
with open(output_file, "a") as out:
|
||||
with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
|
||||
for raw_line in iter_lines():
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
if verbose:
|
||||
print("Warning: malformed JSON line skipped")
|
||||
continue
|
||||
|
||||
fen = data.get("fen", "")
|
||||
if not fen:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
if fen in seen_fens:
|
||||
skipped_dup += 1
|
||||
continue
|
||||
|
||||
best = _best_eval(data.get("evals", []))
|
||||
if best is None:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
if best.get("depth", 0) < min_depth:
|
||||
skipped_depth += 1
|
||||
continue
|
||||
|
||||
pvs = best.get("pvs", [])
|
||||
if not pvs:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
cp = _cp_from_pv(pvs[0])
|
||||
if cp is None:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
record = {
|
||||
"fen": fen,
|
||||
"eval": _normalize(cp),
|
||||
"eval_raw": cp,
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
seen_fens.add(fen)
|
||||
written += 1
|
||||
pbar.update(1)
|
||||
|
||||
if max_positions and written >= max_positions:
|
||||
print(f"\nReached max_positions limit ({max_positions:,})")
|
||||
break
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LICHESS IMPORT SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Positions written: {written:,}")
|
||||
print(f"Skipped (dup): {skipped_dup:,}")
|
||||
print(f"Skipped (no eval): {skipped_no_eval:,}")
|
||||
print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
|
||||
print("=" * 60)
|
||||
print(f"\n✓ Output: {output_file}")
|
||||
|
||||
return written
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Import Lichess pre-labeled positions into labeled.jsonl"
|
||||
)
|
||||
parser.add_argument("input_path",
|
||||
help="Path to lichess_db_eval.jsonl.zst")
|
||||
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||
help="Output labeled.jsonl (default: training_data.jsonl)")
|
||||
parser.add_argument("--max-positions", type=int, default=None,
|
||||
help="Stop after N positions (default: no limit)")
|
||||
parser.add_argument("--min-depth", type=int, default=0,
|
||||
help="Minimum eval depth to accept (default: 0)")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Print warnings for skipped lines")
|
||||
|
||||
args = parser.parse_args()
|
||||
count = import_lichess_evals(
|
||||
input_path=args.input_path,
|
||||
output_file=args.output_file,
|
||||
max_positions=args.max_positions,
|
||||
min_depth=args.min_depth,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
sys.exit(0 if count > 0 else 1)
|
||||
@@ -0,0 +1,249 @@
|
||||
import chess
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from typing import Set, Tuple
|
||||
|
||||
try:
|
||||
import zstandard as zstd
|
||||
except ImportError:
|
||||
print("zstandard library not found. Install with: pip install zstandard")
|
||||
sys.exit(1)
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
|
||||
|
||||
def download_and_extract_puzzle_db(
|
||||
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||
output_dir: str = 'tactical_data'
|
||||
):
|
||||
"""Download and extract the Lichess puzzle database."""
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
csv_file = output_path / 'lichess_db_puzzle.csv'
|
||||
zst_file = output_path / 'lichess_db_puzzle.csv.zst'
|
||||
|
||||
# Download if not already present
|
||||
if not zst_file.exists():
|
||||
print(f"Downloading puzzle database from {url}...")
|
||||
try:
|
||||
urllib.request.urlretrieve(url, zst_file)
|
||||
print(f"Downloaded to {zst_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to download: {e}")
|
||||
return None
|
||||
|
||||
# Extract if CSV doesn't exist
|
||||
if not csv_file.exists():
|
||||
print(f"Extracting {zst_file}...")
|
||||
try:
|
||||
with open(zst_file, 'rb') as f:
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
with dctx.stream_reader(f) as reader:
|
||||
with open(csv_file, 'wb') as out:
|
||||
out.write(reader.read())
|
||||
print(f"Extracted to {csv_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to extract: {e}")
|
||||
return None
|
||||
|
||||
return str(csv_file)
|
||||
|
||||
|
||||
def extract_puzzle_positions(
|
||||
puzzle_csv: str,
|
||||
max_puzzles: int = 300_000
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Extract the position BEFORE the blunder from each puzzle.
|
||||
This is exactly the type of position where tactical
|
||||
recognition matters most.
|
||||
|
||||
Returns a set of unique FENs.
|
||||
"""
|
||||
positions = set()
|
||||
|
||||
with open(puzzle_csv) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if len(positions) >= max_puzzles:
|
||||
break
|
||||
|
||||
try:
|
||||
board = chess.Board(row['FEN'])
|
||||
|
||||
# The puzzle FEN is the position BEFORE the opponent's blunder; the
# first entry in 'Moves' is that blunder, and the solution starts with
# the reply. Playing the blunder yields the position in which the
# tactic has to be found, not just replayed.
moves = row['Moves'].split()

# Apply the opponent's blunder to reach the pre-solution position
board.push_uci(moves[0])
|
||||
fen = board.fen()
|
||||
|
||||
# Filter for useful tactical themes
|
||||
themes = row.get('Themes', '')
|
||||
useful = any(t in themes for t in [
|
||||
'fork', 'pin', 'skewer', 'discoveredAttack',
|
||||
'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
|
||||
'trappedPiece', 'sacrifice'
|
||||
])
|
||||
|
||||
if useful:
|
||||
positions.add(fen)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def load_positions_from_file(file_path: str) -> Set[str]:
|
||||
"""Load positions from a text file (one FEN per line)."""
|
||||
positions = set()
|
||||
try:
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
positions.add(line)
|
||||
print(f"Loaded {len(positions)} positions from {file_path}")
|
||||
return positions
|
||||
except Exception as e:
|
||||
print(f"Failed to load from {file_path}: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def merge_positions(
|
||||
tactical: Set[str],
|
||||
other: Set[str],
|
||||
output_file: str = 'position.txt'
|
||||
):
|
||||
"""Merge two position sets and write to file."""
|
||||
merged = tactical | other
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in merged:
|
||||
f.write(fen + '\n')
|
||||
|
||||
overlap = len(tactical & other)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MERGE SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
print(f"Tactical positions: {len(tactical):,}")
|
||||
print(f"Other positions: {len(other):,}")
|
||||
print(f"Overlap (deduplicated): {overlap:,}")
|
||||
print(f"Total merged positions: {len(merged):,}")
|
||||
print(f"Written to: {output_file}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def extract_tactical_only(
|
||||
puzzle_csv: str,
|
||||
output_file: str,
|
||||
max_puzzles: int = 300_000
|
||||
) -> int:
|
||||
"""Extract tactical positions and save to file (no merge prompts).
|
||||
|
||||
Args:
|
||||
puzzle_csv: Path to Lichess puzzle CSV
|
||||
output_file: Where to save the FEN positions
|
||||
max_puzzles: Maximum puzzles to extract
|
||||
|
||||
Returns:
|
||||
Number of positions extracted
|
||||
"""
|
||||
print("Extracting tactical positions from puzzle database...")
|
||||
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in tactical_positions:
|
||||
f.write(fen + '\n')
|
||||
|
||||
return len(tactical_positions)
|
||||
|
||||
|
||||
def interactive_merge_positions(
|
||||
puzzle_csv: str,
|
||||
output_file: str = 'position.txt',
|
||||
max_puzzles: int = 300_000
|
||||
):
|
||||
"""Interactive workflow: extract tactical positions and merge with user selection."""
|
||||
print("\n" + "="*60)
|
||||
print("TACTICAL POSITION EXTRACTOR & MERGER")
|
||||
print("="*60 + "\n")
|
||||
|
||||
# Extract tactical positions
|
||||
print("Extracting tactical positions from puzzle database...")
|
||||
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||
print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
|
||||
|
||||
# Ask what to merge with
|
||||
print("What would you like to merge with these tactical positions?")
|
||||
print("1. Load from a position file")
|
||||
print("2. Generate random positions")
|
||||
print("3. Skip merging (save tactical only)")
|
||||
|
||||
choice = input("\nEnter choice (1-3): ").strip()
|
||||
|
||||
other_positions = set()
|
||||
|
||||
if choice == '1':
|
||||
file_path = input("Enter path to position file: ").strip()
|
||||
other_positions = load_positions_from_file(file_path)
|
||||
|
||||
elif choice == '2':
|
||||
positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
|
||||
try:
|
||||
positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
|
||||
except ValueError:
|
||||
positions_to_gen = 1000000
|
||||
|
||||
temp_file = 'temp_generated_positions.txt'
|
||||
print(f"\nGenerating {positions_to_gen:,} random positions...")
|
||||
play_random_game_and_collect_positions(
|
||||
output_file=temp_file,
|
||||
total_positions=positions_to_gen,
|
||||
samples_per_game=1,
|
||||
min_move=1,
|
||||
max_move=50,
|
||||
num_workers=8
|
||||
)
|
||||
other_positions = load_positions_from_file(temp_file)
|
||||
|
||||
elif choice == '3':
|
||||
print("Skipping merge, saving tactical positions only...")
|
||||
|
||||
else:
|
||||
print("Invalid choice, saving tactical positions only...")
|
||||
|
||||
merge_positions(tactical_positions, other_positions, output_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
|
||||
parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||
help="URL to download puzzle database from")
|
||||
parser.add_argument("--output-dir", default='trainingdata',
|
||||
help="Directory to extract puzzle database to")
|
||||
parser.add_argument("--max-puzzles", type=int, default=300_000,
|
||||
help="Maximum puzzles to extract (default: 300000)")
|
||||
parser.add_argument("--output-file", default='position.txt',
|
||||
help="Output file for merged positions (default: position.txt)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Download and extract
|
||||
csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
|
||||
|
||||
if csv_path:
|
||||
# Interactive merge
|
||||
interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
|
||||
else:
|
||||
print("Failed to download/extract puzzle database")
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,676 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train NNUE neural network for chess evaluation."""
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime, timedelta
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
class NNUEDataset(Dataset):
|
||||
"""Dataset of chess positions with evaluations."""
|
||||
|
||||
def __init__(self, data_file):
|
||||
self.positions = []
|
||||
self.evals = []
|
||||
self.evals_raw = []
|
||||
self.is_normalized = None
|
||||
|
||||
with open(data_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data['fen']
|
||||
eval_val = data['eval']
|
||||
self.positions.append(fen)
|
||||
self.evals.append(eval_val)
|
||||
|
||||
# Check if normalized or raw
|
||||
if self.is_normalized is None:
|
||||
# If eval is in range [-1, 1], assume normalized
|
||||
self.is_normalized = abs(eval_val) <= 1.0
|
||||
|
||||
# Store raw if available
|
||||
if 'eval_raw' in data:
|
||||
self.evals_raw.append(data['eval_raw'])
|
||||
else:
|
||||
self.evals_raw.append(eval_val)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return len(self.positions)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
fen = self.positions[idx]
|
||||
eval_val = self.evals[idx]
|
||||
features = fen_to_features(fen)
|
||||
|
||||
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
|
||||
if self.is_normalized:
|
||||
target = torch.tensor(eval_val, dtype=torch.float32)
|
||||
else:
|
||||
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
|
||||
|
||||
return features, target
|
||||
|
||||
def fen_to_features(fen):
|
||||
"""Convert FEN to 768-dimensional binary feature vector."""
|
||||
# Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
|
||||
piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
|
||||
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
|
||||
|
||||
features = torch.zeros(768, dtype=torch.float32)
|
||||
|
||||
try:
|
||||
board = chess.Board(fen)
|
||||
|
||||
# 12 piece types × 64 squares = 768
|
||||
for square in chess.SQUARES:
|
||||
piece = board.piece_at(square)
|
||||
if piece is not None:
|
||||
piece_char = piece.symbol()
|
||||
if piece_char in piece_to_idx:
|
||||
piece_idx = piece_to_idx[piece_char]
|
||||
feature_idx = piece_idx * 64 + square
|
||||
features[feature_idx] = 1.0
|
||||
except Exception:  # invalid FEN: leave the all-zero feature vector
|
||||
pass
|
||||
|
||||
return features
|
||||
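# Index layout example (derived from piece_to_idx and chess.SQUARES above):
#   white knight on g1: piece index 7, square 6  -> feature 7*64 + 6 = 454
#   black pawn on e7:   piece index 0, square 52 -> feature 0*64 + 52 = 52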
|
||||
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
|
||||
|
||||
|
||||
class NNUE(nn.Module):
|
||||
"""NNUE neural network with configurable hidden layers.
|
||||
|
||||
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
|
||||
Layer attributes follow the naming l1, l2, ..., lN so export.py can
|
||||
infer the architecture directly from the state_dict.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
|
||||
super().__init__()
|
||||
if hidden_sizes is None:
|
||||
hidden_sizes = DEFAULT_HIDDEN_SIZES
|
||||
self.hidden_sizes = list(hidden_sizes)
|
||||
sizes = [768] + self.hidden_sizes + [1]
|
||||
num_hidden = len(self.hidden_sizes)
|
||||
|
||||
for i in range(num_hidden):
|
||||
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
|
||||
setattr(self, f"relu{i + 1}", nn.ReLU())
|
||||
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
|
||||
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
|
||||
self._num_hidden = num_hidden
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(1, self._num_hidden + 1):
|
||||
layer = getattr(self, f"l{i}")
|
||||
relu = getattr(self, f"relu{i}")
|
||||
drop = getattr(self, f"drop{i}")
|
||||
x = drop(relu(layer(x)))
|
||||
return getattr(self, f"l{self._num_hidden + 1}")(x)
|
||||
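# With the default hidden sizes [1536, 1024, 512, 256] the state_dict holds five
# linear layers, which is exactly the naming export.py's _infer_layers() expects:
#   l1.weight [1536, 768], l2.weight [1024, 1536], l3.weight [512, 1024],
#   l4.weight [256, 512],  l5.weight [1, 256]  (plus the matching biases)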
|
||||
def find_next_version(base_name="nnue_weights"):
|
||||
"""Find the next version number for model versioning.
|
||||
|
||||
Looks for nnue_weights_v*.pt files and returns the next version number.
|
||||
If no versioned files exist, returns 1.
|
||||
"""
|
||||
base_path = Path(base_name)
|
||||
directory = base_path.parent
|
||||
filename = base_path.name
|
||||
|
||||
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
|
||||
versions = []
|
||||
|
||||
for file in directory.glob(f"{filename}_v*.pt"):
|
||||
match = pattern.match(file.name)
|
||||
if match:
|
||||
versions.append(int(match.group(1)))
|
||||
|
||||
if versions:
|
||||
return max(versions) + 1
|
||||
return 1
|
||||
|
||||
def save_metadata(weights_file, metadata):
|
||||
"""Save training metadata alongside the weights file.
|
||||
|
||||
Args:
|
||||
weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
|
||||
metadata: Dictionary with training info
|
||||
"""
|
||||
metadata_file = weights_file.replace(".pt", "_metadata.json")
|
||||
|
||||
with open(metadata_file, "w") as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
return metadata_file
|
||||
|
||||
def _setup_training(data_file, batch_size, subsample_ratio):
|
||||
"""Set up device, dataset, and data loaders.
|
||||
|
||||
Returns:
|
||||
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
|
||||
"""
|
||||
print("Checking GPU availability...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
|
||||
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
||||
else:
|
||||
print("⚠ GPU not available, using CPU")
|
||||
print(f"Using device: {device}")
|
||||
print()
|
||||
|
||||
print("Loading dataset...")
|
||||
dataset = NNUEDataset(data_file)
|
||||
num_positions = len(dataset)
|
||||
print(f"Dataset size: {num_positions}")
|
||||
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||
|
||||
evals_array = np.array(dataset.evals)
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("TRAINING DATASET DIAGNOSTICS")
|
||||
print("=" * 60)
|
||||
print(f"Min evaluation: {evals_array.min():.4f}")
|
||||
print(f"Max evaluation: {evals_array.max():.4f}")
|
||||
print(f"Mean evaluation: {evals_array.mean():.4f}")
|
||||
print(f"Median evaluation: {np.median(evals_array):.4f}")
|
||||
print(f"Std deviation: {evals_array.std():.4f}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
train_size = int(0.9 * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
|
||||
from torch.utils.data import random_split, RandomSampler
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
|
||||
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
|
||||
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
|
||||
|
||||
def _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
start_epoch, epochs, early_stopping_patience,
|
||||
season_start_time, deadline=None, initial_best_val_loss=float('inf')
|
||||
):
|
||||
"""Run one training season until epoch limit, early stopping, or deadline.
|
||||
|
||||
Args:
|
||||
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
|
||||
toward early stopping and do not save snapshots.
|
||||
Returns:
|
||||
(best_val_loss, best_model_state, last_epoch)
|
||||
best_model_state is None if no epoch beat initial_best_val_loss.
|
||||
"""
|
||||
best_val_loss = initial_best_val_loss
|
||||
best_model_state = None
|
||||
epochs_without_improvement = 0
|
||||
total_epochs = start_epoch + epochs
|
||||
last_epoch = start_epoch - 1
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + epochs):
|
||||
if deadline and datetime.now() >= deadline:
|
||||
print("Time limit reached, stopping season.")
|
||||
break
|
||||
|
||||
epoch_display = epoch + 1
|
||||
|
||||
# Train
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
|
||||
for batch_features, batch_targets in train_loader:
|
||||
batch_features = batch_features.to(device)
|
||||
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
outputs = model(batch_features)
|
||||
loss = criterion(outputs, batch_targets)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
train_loss += loss.item() * batch_features.size(0)
|
||||
pbar.update(1)
|
||||
|
||||
train_loss /= len(train_loader.sampler)  # positions actually sampled this epoch (matters when subsample_ratio < 1)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
|
||||
for batch_features, batch_targets in val_loader:
|
||||
batch_features = batch_features.to(device)
|
||||
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||
|
||||
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
outputs = model(batch_features)
|
||||
loss = criterion(outputs, batch_targets)
|
||||
val_loss += loss.item() * batch_features.size(0)
|
||||
pbar.update(1)
|
||||
|
||||
val_loss /= len(val_dataset)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
|
||||
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
|
||||
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
|
||||
|
||||
elapsed_time = datetime.now() - season_start_time
|
||||
time_per_epoch = elapsed_time.total_seconds() / (epoch - start_epoch + 1)  # epochs completed this season
|
||||
remaining_epochs = total_epochs - epoch_display
|
||||
eta_seconds = time_per_epoch * remaining_epochs
|
||||
eta_str = str(timedelta(seconds=int(eta_seconds)))
|
||||
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
|
||||
|
||||
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": model.hidden_sizes,
|
||||
}, checkpoint_file)
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}  # deep copy so later epochs don't overwrite the best weights
|
||||
epochs_without_improvement = 0
|
||||
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
|
||||
torch.save(best_model_state, snapshot_file)
|
||||
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
|
||||
last_epoch = epoch
|
||||
|
||||
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||
break
|
||||
|
||||
return best_val_loss, best_model_state, last_epoch
|
||||
|
||||
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time, hidden_sizes=None,
|
||||
extra_metadata=None):
|
||||
"""Save the best model with optional versioning and metadata."""
|
||||
final_output_file = output_file
|
||||
metadata = {}
|
||||
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
|
||||
|
||||
if use_versioning:
|
||||
base_name = output_file.replace(".pt", "")
|
||||
version = find_next_version(base_name)
|
||||
final_output_file = f"{base_name}_v{version}.pt"
|
||||
|
||||
metadata = {
|
||||
"version": version,
|
||||
"date": training_start_time.isoformat(),
|
||||
"num_positions": num_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"final_val_loss": float(best_val_loss),
|
||||
"architecture": architecture,
|
||||
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
torch.save({
|
||||
"model_state_dict": best_model_state,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"epoch": last_epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
|
||||
}, final_output_file)
|
||||
print(f"Best model saved to {final_output_file}")
|
||||
|
||||
if use_versioning and metadata:
|
||||
metadata_file = save_metadata(final_output_file, metadata)
|
||||
print(f"Metadata saved to {metadata_file}")
|
||||
print(f"\nTraining Summary:")
|
||||
for key, val in metadata.items():
|
||||
print(f" {key}: {val}")
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
    """Train the NNUE model with GPU optimizations and automatic mixed precision.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
        epochs: Number of training epochs (default: 100)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate (default: 0.001)
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
        weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
    """
    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    start_epoch = 0
    best_val_loss = float('inf')
    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)

    # First pass over the checkpoint: only peek at the stored architecture so the
    # model is built with the layer sizes it was actually trained with.
    if checkpoint:
        print(f"Loading checkpoint: {checkpoint}")
        ckpt = torch.load(checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            ckpt_hidden = ckpt.get("hidden_sizes")
            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
                print(f" Using architecture from checkpoint: {ckpt_hidden}")
                resolved_hidden_sizes = ckpt_hidden

    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

    # Second pass: restore training state. Full checkpoints resume the optimizer,
    # scheduler, and scaler; a plain .pt state dict only restores the weights.
    if checkpoint:
        ckpt = torch.load(checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            start_epoch = ckpt["epoch"] + 1
            best_val_loss = ckpt.get("best_val_loss", float('inf'))
            print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no optimizer state)")

    checkpoint_val_loss = best_val_loss if checkpoint else float('inf')

    subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
    print(f"Architecture: {arch_str}")
    print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
    print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
    print("Mixed precision training: enabled")
    print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
    if subsample_ratio < 1.0:
        print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
    if early_stopping_patience:
        print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
    print()

    training_start_time = datetime.now()

    best_val_loss, best_model_state, last_epoch = _run_training_season(
        model, optimizer, scheduler, scaler,
        train_loader, val_loader, train_dataset, val_dataset,
        device, criterion, output_file,
        start_epoch, epochs, early_stopping_patience,
        training_start_time
    )

    if best_model_state is None or best_val_loss >= checkpoint_val_loss:
        print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
        print("No new model created.")
        return

    _save_versioned_model(
        best_model_state, optimizer, scheduler, scaler, last_epoch,
        best_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, training_start_time,
        hidden_sizes=resolved_hidden_sizes,
        extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
                        "checkpoint": str(checkpoint) if checkpoint else None}
    )


def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
                epochs_per_season=50, early_stopping_patience=10,
                batch_size=16384, lr=0.001, initial_checkpoint=None,
                stockfish_depth=12, use_versioning=True,
                weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
    """Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.

    Each season trains with early stopping. When early stopping fires, the model reloads the
    global best weights and begins a fresh season with a reset optimizer and scheduler.
    This prevents the model from drifting away from its best known state.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Output file base name
        duration_minutes: Total training budget in minutes
        epochs_per_season: Max epochs per restart season (default: 50)
        early_stopping_patience: Patience for early stopping within each season (default: 10)
        batch_size: Training batch size (default: 16384)
        lr: Learning rate, reset to this value at the start of each season (default: 0.001)
        initial_checkpoint: Optional weights-only .pt file to start from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        weight_decay: L2 regularization strength (default: 1e-4)
        subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
        hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
    """
    deadline = datetime.now() + timedelta(minutes=duration_minutes)

    device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
        _setup_training(data_file, batch_size, subsample_ratio)

    resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)

    # Peek at the checkpoint first so the model is built with the architecture it was trained with.
    if initial_checkpoint:
        print(f"Loading initial weights: {initial_checkpoint}")
        ckpt = torch.load(initial_checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            ckpt_hidden = ckpt.get("hidden_sizes")
            if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
                print(f" Using architecture from checkpoint: {ckpt_hidden}")
                resolved_hidden_sizes = ckpt_hidden

    model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
    criterion = nn.MSELoss()
    best_global_val_loss = float('inf')

    if initial_checkpoint:
        ckpt = torch.load(initial_checkpoint, map_location=device)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            model.load_state_dict(ckpt["model_state_dict"])
            best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
            if best_global_val_loss < float('inf'):
                print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
            else:
                print("Initial weights loaded (no val loss in checkpoint).")
        else:
            model.load_state_dict(ckpt)
            print("Loaded weights-only checkpoint (no val loss info).")

    arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
    print(f"Architecture: {arch_str}")
    print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
    print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
    print()

    burst_start_time = datetime.now()
    season = 0
    best_global_state = None
    last_optimizer = None
    last_scheduler = None
    last_scaler = None
    last_epoch = 0

    while datetime.now() < deadline:
        season += 1
        remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
        print(f"\n{'=' * 60}")
        print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
        if best_global_val_loss < float('inf'):
            print(f"Global best val loss so far: {best_global_val_loss:.6f}")
        print(f"{'=' * 60}\n")

        # Each season starts with a fresh optimizer, scheduler, and AMP scaler.
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
        scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

        season_start_time = datetime.now()
        val_loss, model_state, last_epoch = _run_training_season(
            model, optimizer, scheduler, scaler,
            train_loader, val_loader, train_dataset, val_dataset,
            device, criterion, output_file,
            0, epochs_per_season, early_stopping_patience,
            season_start_time, deadline=deadline,
            initial_best_val_loss=best_global_val_loss
        )

        last_optimizer = optimizer
        last_scheduler = scheduler
        last_scaler = scaler

        if model_state is not None and val_loss < best_global_val_loss:
            best_global_val_loss = val_loss
            best_global_state = model_state
            print(f" New global best: {best_global_val_loss:.6f} (season {season})")

        # Reload global best for the next season so we never drift backwards
        if best_global_state is not None:
            model.load_state_dict(best_global_state)

    total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
    print(f"\n{'=' * 60}")
    print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
    print(f"Best val loss: {best_global_val_loss:.6f}")
    print(f"{'=' * 60}\n")

    if best_global_state is None:
        print("No model improvement found. No file saved.")
        return

    _save_versioned_model(
        best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
        best_global_val_loss, output_file, use_versioning, num_positions,
        stockfish_depth, burst_start_time,
        hidden_sizes=resolved_hidden_sizes,
        extra_metadata={
            "mode": "burst",
            "duration_minutes": duration_minutes,
            "epochs_per_season": epochs_per_season,
            "early_stopping_patience": early_stopping_patience,
            "seasons_completed": season,
            "batch_size": batch_size,
            "learning_rate": lr,
            "initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
        }
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
    parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
                        help="Path to training_data.jsonl (default: training_data.jsonl)")
    parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
                        help="Output file base name (default: nnue_weights.pt)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint file to resume training from (optional)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs to train (default: 100)")
    parser.add_argument("--batch-size", type=int, default=16384,
                        help="Batch size (default: 16384)")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="Learning rate (default: 0.001)")
    parser.add_argument("--early-stopping", type=int, default=None,
                        help="Stop if val loss doesn't improve for N epochs (optional)")
    parser.add_argument("--stockfish-depth", type=int, default=12,
                        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    parser.add_argument("--no-versioning", action="store_true",
                        help="Disable automatic versioning (save directly to output file)")
    parser.add_argument("--weight-decay", type=float, default=5e-5,
                        help="L2 regularization strength (default: 5e-5, helps prevent overfitting)")
    parser.add_argument("--subsample-ratio", type=float, default=1.0,
                        help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
    parser.add_argument("--hidden-layers", type=str, default=None,
                        help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")

    # Burst mode
    parser.add_argument("--burst-duration", type=float, default=None,
                        help="Enable burst mode: total training budget in minutes")
    parser.add_argument("--epochs-per-season", type=int, default=50,
                        help="Max epochs per burst season before restarting (default: 50, burst mode only)")

    args = parser.parse_args()

    hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None

    if args.burst_duration is not None:
        burst_train(
            data_file=args.data_file,
            output_file=args.output_file,
            duration_minutes=args.burst_duration,
            epochs_per_season=args.epochs_per_season,
            early_stopping_patience=args.early_stopping or 10,
            batch_size=args.batch_size,
            lr=args.lr,
            initial_checkpoint=args.checkpoint,
            stockfish_depth=args.stockfish_depth,
            use_versioning=not args.no_versioning,
            weight_decay=args.weight_decay,
            subsample_ratio=args.subsample_ratio,
            hidden_sizes=hidden_sizes,
        )
    else:
        train_nnue(
            data_file=args.data_file,
            output_file=args.output_file,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            checkpoint=args.checkpoint,
            stockfish_depth=args.stockfish_depth,
            use_versioning=not args.no_versioning,
            early_stopping_patience=args.early_stopping,
            weight_decay=args.weight_decay,
            subsample_ratio=args.subsample_ratio,
            hidden_sizes=hidden_sizes,
        )
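For reference, the two entry points above map onto straightforward programmatic calls. This is a minimal sketch, not part of the commit: the `from nnue import ...` module name is inferred from the `nnue.py` filename used by the setup scripts below, and all values and paths are illustrative.

```python
# Illustrative only: programmatic equivalents of the CLI invocations handled above.
from nnue import train_nnue, burst_train  # module name assumed from nnue.py

# Equivalent to: python nnue.py training_data.jsonl --epochs 100 --early-stopping 10
train_nnue("training_data.jsonl", epochs=100, early_stopping_patience=10)

# Equivalent to: python nnue.py training_data.jsonl --burst-duration 60 \
#                --epochs-per-season 50 --checkpoint weights/nnue_weights_v8.pt
burst_train("training_data.jsonl", duration_minutes=60, epochs_per_season=50,
            initial_checkpoint="weights/nnue_weights_v8.pt")
```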
@@ -0,0 +1,30 @@
# Setup and run NNUE training pipeline

$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
$VenvDir = Join-Path $ScriptDir ".venv"

# Check if virtual environment exists
if (-not (Test-Path $VenvDir)) {
    Write-Host "Creating virtual environment..."
    python -m venv $VenvDir
    if ($LASTEXITCODE -ne 0) {
        Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
        exit 1
    }
}

# Activate virtual environment
Write-Host "Activating virtual environment..."
$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
& $ActivateScript

# Install/update dependencies if requirements.txt exists
$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
if (Test-Path $RequirementsFile) {
    Write-Host "Installing dependencies..."
    pip install -q -r $RequirementsFile
}

# Run nnue.py
Write-Host "Starting NNUE Training Pipeline..."
python (Join-Path $ScriptDir "nnue.py")
@@ -0,0 +1,29 @@
#!/bin/bash
# Setup and run NNUE training pipeline

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/.venv"

# Check if virtual environment exists
if [ ! -d "$VENV_DIR" ]; then
    echo "Creating virtual environment..."
    python3 -m venv "$VENV_DIR"
    if [ $? -ne 0 ]; then
        echo "Error: Failed to create virtual environment. Make sure python3 is installed."
        exit 1
    fi
fi

# Activate virtual environment
echo "Activating virtual environment..."
source "$VENV_DIR/bin/activate"

# Install/update dependencies if requirements.txt exists
if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
    echo "Installing dependencies..."
    pip install -q -r "$SCRIPT_DIR/requirements.txt"
fi

# Run nnue.py
echo "Starting NNUE Training Pipeline..."
python "$SCRIPT_DIR/nnue.py"
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
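The binary files above are the versioned `.pt` checkpoints themselves. A minimal loading sketch follows, mirroring the key layout that the resume logic in `train_nnue` expects; the `from nnue import NNUE` import, the file path, and the fallback layer sizes are assumptions for illustration, not taken from this diff.

```python
import torch
from nnue import NNUE  # module/class names assumed from the training script above

# Load a versioned checkpoint for inference (path illustrative).
ckpt = torch.load("weights/nnue_weights_v10.pt", map_location="cpu")
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    # Full training checkpoint: rebuild the architecture it was trained with.
    model = NNUE(hidden_sizes=ckpt.get("hidden_sizes", [1536, 1024, 512, 256]))
    model.load_state_dict(ckpt["model_state_dict"])
else:
    # Weights-only file saved without optimizer state.
    model = NNUE()
    model.load_state_dict(ckpt)
model.eval()
```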
@@ -0,0 +1,21 @@
{
  "version": 10,
  "date": "2026-04-14T22:18:38.824577",
  "num_positions": 3022562,
  "stockfish_depth": 12,
  "final_val_loss": 6.248612448225196e-05,
  "architecture": [
    768,
    1536,
    1024,
    512,
    256,
    1
  ],
  "device": "cuda",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)",
  "epochs": 100,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "checkpoint": null
}
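Each checkpoint is committed together with a metadata JSON like the one above. Below is a small sketch for listing the committed versions by scanning those files; the `weights/` directory and the `nnue_weights_v*.json` naming are assumptions made for illustration, since the diff does not show the actual filenames.

```python
import json
from pathlib import Path

def list_model_versions(weights_dir="weights"):
    """Collect (version, final_val_loss, num_positions) from per-version metadata files."""
    versions = []
    for meta_path in Path(weights_dir).glob("nnue_weights_v*.json"):  # naming assumed
        with open(meta_path) as f:
            meta = json.load(f)
        versions.append((meta["version"], meta.get("final_val_loss"), meta.get("num_positions")))
    return sorted(versions)

if __name__ == "__main__":
    for version, val_loss, positions in list_model_versions():
        print(f"v{version}: val_loss={val_loss}, positions={positions}")
```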
@@ -0,0 +1,13 @@
{
  "version": 1,
  "date": "2026-04-07T22:56:23.259658",
  "num_positions": 2086,
  "stockfish_depth": 12,
  "epochs": 20,
  "batch_size": 4096,
  "learning_rate": 0.001,
  "final_val_loss": 0.016311248764395714,
  "device": "cuda",
  "checkpoint": null,
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 2,
  "date": "2026-04-07T23:50:05.390402",
  "num_positions": 6886,
  "stockfish_depth": 12,
  "epochs": 100,
  "batch_size": 4096,
  "learning_rate": 0.001,
  "final_val_loss": 0.007848377339541912,
  "device": "cuda",
  "checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 3,
  "date": "2026-04-08T09:43:28.000579",
  "num_positions": 71610,
  "stockfish_depth": 12,
  "epochs": 20,
  "batch_size": 4096,
  "learning_rate": 0.001,
  "final_val_loss": 0.006398905136849695,
  "device": "cpu",
  "checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 4,
  "date": "2026-04-09T00:28:07.572209",
  "num_positions": 2009355,
  "stockfish_depth": 12,
  "epochs": 40,
  "batch_size": 4096,
  "learning_rate": 0.001,
  "final_val_loss": 9.106677896235248e-05,
  "device": "cuda",
  "checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 5,
  "date": "2026-04-09T18:50:27.845632",
  "num_positions": 2009355,
  "stockfish_depth": 12,
  "epochs": 100,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "final_val_loss": 9.180421525105905e-05,
  "device": "cuda",
  "checkpoint": null,
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 6,
  "date": "2026-04-09T21:28:21.000832",
  "num_positions": 1958728,
  "stockfish_depth": 12,
  "epochs": 100,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "final_val_loss": 0.2984530149085532,
  "device": "cuda",
  "checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 7,
  "date": "2026-04-09T22:06:50.439858",
  "num_positions": 1958728,
  "stockfish_depth": 12,
  "epochs": 100,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "final_val_loss": 0.2997283308762831,
  "device": "cuda",
  "checkpoint": null,
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
  "version": 8,
  "date": "2026-04-09T22:22:47.859730",
  "num_positions": 1958728,
  "stockfish_depth": 12,
  "epochs": 100,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "final_val_loss": 0.24803777390839968,
  "device": "cuda",
  "checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,17 @@
{
  "version": 9,
  "date": "2026-04-13T20:19:08.123315",
  "num_positions": 2522562,
  "stockfish_depth": 12,
  "final_val_loss": 6.994176222619626e-05,
  "device": "cuda",
  "notes": "Win rate vs classical eval: TBD (requires benchmark games)",
  "mode": "burst",
  "duration_minutes": 30.0,
  "epochs_per_season": 50,
  "early_stopping_patience": 10,
  "seasons_completed": 3,
  "batch_size": 16384,
  "learning_rate": 0.001,
  "initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
}