feat: Implement dataset versioning and management for NNUE training data

This commit is contained in:
2026-04-13 21:19:26 +02:00
parent 4b52199754
commit 8fb872e958
18 changed files with 1399 additions and 335 deletions
+287
@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""Dataset versioning and management for NNUE training data."""
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
from rich.console import Console
from rich.table import Table
def get_datasets_dir() -> Path:
"""Get/create datasets directory."""
datasets_dir = Path(__file__).parent.parent / "datasets"
datasets_dir.mkdir(exist_ok=True)
return datasets_dir
def next_dataset_version() -> int:
"""Find the next available dataset version number."""
datasets_dir = get_datasets_dir()
versions = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
versions.append(v)
except (ValueError, IndexError):
pass
return max(versions) + 1 if versions else 1
def list_datasets() -> List[Tuple[int, Dict]]:
"""List all datasets with their metadata.
Returns:
List of (version, metadata_dict) tuples, sorted by version.
"""
datasets_dir = get_datasets_dir()
datasets = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
metadata_file = d / "metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
datasets.append((v, metadata))
except (ValueError, IndexError, json.JSONDecodeError):
pass
return sorted(datasets, key=lambda x: x[0])
def load_dataset_metadata(version: int) -> Optional[Dict]:
"""Load metadata for a specific dataset version.
Returns:
Metadata dict or None if not found.
"""
datasets_dir = get_datasets_dir()
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
if not metadata_file.exists():
return None
with open(metadata_file, 'r') as f:
return json.load(f)
def save_dataset_metadata(version: int, metadata: Dict) -> None:
"""Save metadata for a dataset version."""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
metadata_file = dataset_dir / "metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
def create_dataset(
version: int,
labeled_jsonl_path: str,
sources: List[Dict],
stockfish_depth: int = 12
) -> Path:
"""Create a new versioned dataset.
Args:
version: Dataset version number
labeled_jsonl_path: Path to labeled.jsonl to copy
sources: List of source dicts (see plan for schema)
stockfish_depth: Depth used for labeling
Returns:
Path to the created dataset directory.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
# Copy labeled data with deduplication (in case source has duplicates)
source_path = Path(labeled_jsonl_path)
if source_path.exists():
dest_path = dataset_dir / "labeled.jsonl"
seen_fens = set()
unique_count = 0
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
for line in src:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in seen_fens:
dst.write(line)
seen_fens.add(fen)
unique_count += 1
except json.JSONDecodeError:
# Skip malformed lines
pass
# Count positions
total_positions = 0
if (dataset_dir / "labeled.jsonl").exists():
with open(dataset_dir / "labeled.jsonl", 'r') as f:
total_positions = sum(1 for _ in f)
# Create metadata
metadata = {
"version": version,
"created": datetime.now().isoformat(),
"total_positions": total_positions,
"stockfish_depth": stockfish_depth,
"sources": sources
}
save_dataset_metadata(version, metadata)
return dataset_dir
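# For orientation: the metadata.json written here (and later updated by extend_dataset)
# looks roughly like the example below. The values and any extra per-source fields are
# illustrative; the code only reads each source's 'type' for display and adds
# 'actual_count' when a dataset is extended.
#
# {
#   "version": 3,
#   "created": "2026-04-13T21:19:26",
#   "total_positions": 250000,
#   "stockfish_depth": 12,
#   "sources": [
#     {"type": "selfplay"},
#     {"type": "tactical", "actual_count": 48211}
#   ]
# }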
def extend_dataset(
version: int,
new_labeled_path: str,
new_source_entry: Dict
) -> bool:
"""Extend an existing dataset with new labeled positions (with deduplication).
Args:
version: Dataset version to extend
new_labeled_path: Path to new labeled.jsonl to merge
new_source_entry: Source entry to add to metadata
Returns:
True if successful, False otherwise.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
labeled_file = dataset_dir / "labeled.jsonl"
new_labeled_file = Path(new_labeled_path)
if not new_labeled_file.exists():
return False
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
existing_fens = set()
if labeled_file.exists():
with open(labeled_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data.get('fen')
if fen:
existing_fens.add(fen)
except json.JSONDecodeError:
pass
# Merge new positions, skipping duplicates
new_count = 0
new_lines = []
with open(new_labeled_file, 'r') as f_new:
for line in f_new:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in existing_fens:
new_lines.append(line)
existing_fens.add(fen)
new_count += 1
except json.JSONDecodeError:
pass
# Append only the new, unique positions
if new_lines:
with open(labeled_file, 'a') as f_append:
for line in new_lines:
f_append.write(line)
# Update metadata
metadata = load_dataset_metadata(version)
if metadata:
# Count total positions
total_positions = 0
with open(labeled_file, 'r') as f:
total_positions = sum(1 for _ in f)
metadata['total_positions'] = total_positions
# Update the source entry with actual count of new positions added
new_source_entry['actual_count'] = new_count
metadata['sources'].append(new_source_entry)
save_dataset_metadata(version, metadata)
return True
def get_dataset_labeled_path(version: int) -> Optional[Path]:
"""Get the path to a dataset's labeled.jsonl file.
Returns:
Path to labeled.jsonl or None if dataset doesn't exist.
"""
datasets_dir = get_datasets_dir()
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
if labeled_file.exists():
return labeled_file
return None
def delete_dataset(version: int) -> bool:
"""Delete a dataset (recursively removes directory).
Args:
version: Dataset version to delete
Returns:
True if successful.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
import shutil
shutil.rmtree(dataset_dir)
return True
def show_datasets_table(console: Optional[Console] = None) -> None:
"""Display all datasets in a Rich table."""
if console is None:
console = Console()
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets found yet[/yellow]")
return
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("Positions", justify="right")
table.add_column("Sources", justify="left")
table.add_column("Depth", justify="center")
table.add_column("Created", justify="left")
for v, metadata in datasets:
positions = metadata.get('total_positions', 0)
sources = metadata.get('sources', [])
source_str = ", ".join([s.get('type', '?') for s in sources])
depth = metadata.get('stockfish_depth', '?')
created = metadata.get('created', '?')
if created != '?':
created = created.split('T')[0] # Just the date
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
console.print(table)
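Taken together, the helpers above form a small lifecycle API: pick the next version, create it from a labeled JSONL file, extend it with freshly labeled positions, and inspect what exists. A minimal sketch of that flow from a training script follows; the module name dataset_manager and the contents of the source entries are assumptions for illustration, since the commit only shows that 'type' is read for display and 'actual_count' is written back by extend_dataset.

# Hypothetical wiring of the helpers above (module name assumed).
from dataset_manager import (
    next_dataset_version, create_dataset, extend_dataset,
    get_dataset_labeled_path, show_datasets_table,
)

version = next_dataset_version()
create_dataset(
    version,
    labeled_jsonl_path="labeled.jsonl",
    sources=[{"type": "selfplay"}],  # illustrative source entry
    stockfish_depth=12,
)
# Merge newly labeled tactical positions into the same version (deduplicated by FEN).
extend_dataset(version, "labeled_tactical.jsonl", {"type": "tactical"})
show_datasets_table()
print(get_dataset_labeled_path(version))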
+109 -39
@@ -1,67 +1,137 @@
#!/usr/bin/env python3
"""Export NNUE weights to .nbai format for runtime loading."""
import torch
import json
import struct
import sys
from datetime import datetime
from pathlib import Path
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
VERSION = 1
def _read_sidecar(weights_file: str) -> dict:
sidecar = weights_file.replace(".pt", "_metadata.json")
if Path(sidecar).exists():
with open(sidecar) as f:
return json.load(f)
return {}
def _infer_layers(state_dict: dict) -> list[dict]:
"""Derive layer descriptors from state_dict weight shapes.
Assumes layers named l1, l2, ..., lN.
All hidden layers get activation 'relu'; the last gets 'linear'.
"""
names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
layers = []
for i, name in enumerate(names):
out_size, in_size = state_dict[f"{name}.weight"].shape
activation = "linear" if i == len(names) - 1 else "relu"
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
return layers
def _write_floats(f, tensor):
data = tensor.float().flatten().cpu().numpy()
f.write(struct.pack("<I", len(data)))
f.write(struct.pack(f"<{len(data)}f", *data))
def export_to_nbai(
weights_file: str,
output_file: str,
trained_by: str = "unknown",
train_loss: float = 0.0,
):
if not Path(weights_file).exists():
print(f"Error: weights file not found at {weights_file}")
sys.exit(1)
# Load weights — handle both raw state dicts and full training checkpoints
loaded = torch.load(weights_file, map_location="cpu")
state_dict = (
    loaded["model_state_dict"]
    if isinstance(loaded, dict) and "model_state_dict" in loaded
    else loaded
)
sidecar = _read_sidecar(weights_file)
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
trained_at = sidecar.get("date", datetime.now().isoformat())
training_data_count = int(sidecar.get("num_positions", 0))
metadata = {
"trainedBy": trained_by,
"trainedAt": trained_at,
"trainingDataCount": training_data_count,
"valLoss": val_loss,
"trainLoss": train_loss,
}
layers = _infer_layers(state_dict)
layer_names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
print(f"Architecture ({len(layers)} layers):")
for i, l in enumerate(layers):
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
with open(output_file, "wb") as f:
# Header
f.write(struct.pack("<I", MAGIC))
f.write(struct.pack("<H", VERSION))
# Metadata (length-prefixed UTF-8 JSON)
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
f.write(struct.pack("<I", len(meta_bytes)))
f.write(meta_bytes)
# Layer descriptors
f.write(struct.pack("<H", len(layers)))
for layer in layers:
name_bytes = layer["activation"].encode("ascii")
f.write(struct.pack("<B", len(name_bytes)))
f.write(name_bytes)
f.write(struct.pack("<I", layer["inputSize"]))
f.write(struct.pack("<I", layer["outputSize"]))
# Weights: weight tensor then bias tensor per layer
for name in layer_names:
w = state_dict[f"{name}.weight"]
b = state_dict[f"{name}.bias"]
_write_floats(f, w)
_write_floats(f, b)
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
if __name__ == "__main__":
weights_file = "nnue_weights.pt"
output_file = "../src/main/resources/nnue_weights.nbai"
trained_by = "unknown"
train_loss = 0.0
if len(sys.argv) > 1:
weights_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
if len(sys.argv) > 3:
trained_by = sys.argv[3]
if len(sys.argv) > 4:
train_loss = float(sys.argv[4])
export_to_nbai(weights_file, output_file, trained_by, train_loss)
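The writer above fixes a simple little-endian layout: a 4-byte magic and 2-byte version, a length-prefixed UTF-8 JSON metadata blob, a list of layer descriptors, then one float32 block per weight and bias tensor. The engine-side loader that consumes the file from src/main/resources is not part of this diff; the following is only a minimal Python sketch of how that layout can be parsed, assuming the format exactly as written by export_to_nbai.

import json
import struct

def read_nbai(path):
    """Parse a .nbai file written by export_to_nbai (illustrative sketch)."""
    with open(path, "rb") as f:
        # Header: magic and version
        magic, = struct.unpack("<I", f.read(4))
        version, = struct.unpack("<H", f.read(2))
        assert magic == 0x4942_414E and version == 1
        # Metadata: length-prefixed UTF-8 JSON
        meta_len, = struct.unpack("<I", f.read(4))
        metadata = json.loads(f.read(meta_len).decode("utf-8"))
        # Layer descriptors: activation name, input size, output size
        layers = []
        layer_count, = struct.unpack("<H", f.read(2))
        for _ in range(layer_count):
            act_len, = struct.unpack("<B", f.read(1))
            activation = f.read(act_len).decode("ascii")
            in_size, out_size = struct.unpack("<II", f.read(8))
            layers.append({"activation": activation, "inputSize": in_size, "outputSize": out_size})
        # Weights: per layer, a length-prefixed float32 block for weight, then bias
        weights = []
        for _ in layers:
            w_len, = struct.unpack("<I", f.read(4))
            w = struct.unpack(f"<{w_len}f", f.read(4 * w_len))
            b_len, = struct.unpack("<I", f.read(4))
            b = struct.unpack(f"<{b_len}f", f.read(4 * b_len))
            weights.append((w, b))
    return metadata, layers, weights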
+12 -1
@@ -125,6 +125,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
# Load all FENs that need evaluation
fens_to_evaluate = []
fens_seen_in_batch = set() # Track duplicates within current batch
skipped_invalid = 0
skipped_duplicate = 0
@@ -140,7 +141,12 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
skipped_duplicate += 1
continue
if fen in fens_seen_in_batch:
skipped_duplicate += 1
continue
fens_to_evaluate.append(fen)
fens_seen_in_batch.add(fen)
total_to_evaluate = len(fens_to_evaluate)
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
@@ -178,8 +184,13 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
with open(output_file, 'a') as out:
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
for fen, eval_normalized, eval_cp in batch_results:
# Skip if already evaluated in output file during this run
if fen in evaluated_fens:
continue
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
out.write(json.dumps(data) + '\n')
evaluated_fens.add(fen) # Track as evaluated
evaluated += 1
raw_evals.append(eval_cp)
normalized_evals.append(eval_normalized)
@@ -287,7 +298,7 @@ if __name__ == "__main__":
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)")
parser.add_argument("--batch-size", type=int, default=1000,
help="Batch size for processing (default: 1000)")
parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)")
@@ -17,7 +17,7 @@ from generate import play_random_game_and_collect_positions
def download_and_extract_puzzle_db(
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
output_dir: str = 'tactical_data'
):
"""Download and extract the Lichess puzzle database."""
output_path = Path(output_dir)
@@ -141,6 +141,31 @@ def merge_positions(
print(f"{'='*60}\n")
def extract_tactical_only(
puzzle_csv: str,
output_file: str,
max_puzzles: int = 300_000
) -> int:
"""Extract tactical positions and save to file (no merge prompts).
Args:
puzzle_csv: Path to Lichess puzzle CSV
output_file: Where to save the FEN positions
max_puzzles: Maximum puzzles to extract
Returns:
Number of positions extracted
"""
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
with open(output_file, 'w') as f:
for fen in tactical_positions:
f.write(fen + '\n')
return len(tactical_positions)
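# Example call (illustrative paths, not part of this module):
#   n = extract_tactical_only(
#       puzzle_csv="tactical_data/lichess_db_puzzle.csv",
#       output_file="tactical_positions.txt",
#       max_puzzles=100_000,
#   )
#   print(f"Saved {n} tactical positions")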
def interactive_merge_positions(
puzzle_csv: str,
output_file: str = 'position.txt',