Files
NowChessSystems/modules/official-bots/python/build_dataset.py
T
Janis Eccarius 1c80abdb8a
Build & Test (NowChessSystems) TeamCity build finished
feat(official-bots): standalone self-play + one-shot dataset builder for NNUE training
Add an easy local data pipeline feeding GPU training on Colab.

- SelfPlayMain: standalone NNUEBot self-play (no microservices) writing FENs
  for labeling; randomised openings for game diversity, sequential due to the
  shared EvaluationNNUE accumulator. Exposed via the `selfPlay` Gradle task and
  selfplay.sh.
- NNUEBot: optional fixedMoveTimeMs so self-play runs fast (default unchanged).
- NbaiLoader: honor `-Dnnue.weights=<path>` to load weights from a file before
  falling back to the bundled resource.
- build_dataset.py / dataset.sh: one command builds the entire dataset
  (Lichess eval-DB backbone + self-play + tactical + random filler), dedups,
  balances the eval histogram, writes append-only zstd shards + manifest, and
  rclone-pushes to Drive.
- train.py: NNUEDataset reads a directory of .jsonl.zst shards (streaming) in
  addition to a single file.
- NNUETraining.ipynb: clone to ephemeral /content, sync shards from Drive
  (cache-aware), train on the shards dir; removed Colab generation/upload steps.
- Concept + implementation plan docs.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-24 22:04:22 +02:00

282 lines
10 KiB
Python

#!/usr/bin/env python3
"""Build the ENTIRE NNUE training dataset with one command.
Orchestrates the existing source modules (Lichess eval DB, self-play, tactical puzzles,
random filler), labels what needs labeling with local Stockfish, deduplicates, balances
the eval distribution, writes append-only compressed shards + a manifest, and pushes to
Google Drive with rclone.
./dataset.sh # build everything + push
./dataset.sh --no-push # build only
./dataset.sh --no-lichess # skip the (large) Lichess backbone
Tune the CONFIG block below — that is the only thing you normally touch.
"""
import argparse
import hashlib
import json
import os
import random
import shutil
import subprocess
import sys
import urllib.request
from datetime import datetime, timezone
from pathlib import Path

import zstandard as zstd
HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(HERE / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from lichess_importer import import_lichess_evals
from tactical_positions_extractor import download_and_extract_puzzle_db, extract_tactical_only
# ── CONFIG — the only knobs you normally touch ───────────────────────────────
LICHESS_POSITIONS = 2_000_000  # backbone positions from the Lichess eval DB
USE_SELFPLAY = True  # label data/selfplay.txt if present
TACTICAL_PUZZLES = 200_000  # tactical positions from the Lichess puzzle DB
RANDOM_FILLER = 100_000  # cheap random-play positions
STOCKFISH_DEPTH = 14  # local labeling depth (selfplay/tactical/random)
RCLONE_REMOTE = "gdrive:NowChess/datasets"  # rclone destination that push() copies datasets/ into
# ─────────────────────────────────────────────────────────────────────────────
# Internals below: normally left alone.
LABEL_BATCH = 64  # positions per Stockfish task (small = smooth progress + load balance)
SHARD_SIZE = 100_000  # positions per shard
# NOTE(review): balancing assumes labeled evals are normalized to roughly [-1, 1];
# confirm against the label / lichess_importer output format.
BALANCE_BINS = 64  # eval histogram bins over [-1, 1]
BALANCE_FACTOR = 2.0  # cap each bin at FACTOR x the uniform bin size
# Remote Lichess eval DB; downloaded once and cached at LICHESS_DB.
LICHESS_EVAL_URL = "https://database.lichess.org/lichess_db_eval.jsonl.zst"
# Stockfish binary used for labeling; override with the STOCKFISH_PATH environment variable.
STOCKFISH_PATH = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
# Worker count for labeling and random-play generation (4 if the CPU count is unknown).
WORKERS = os.cpu_count() or 4
DATA_DIR = HERE / "data"  # holds selfplay.txt
WORK_DIR = HERE / "data" / "_build"  # scratch: per-source FEN lists and labeled jsonl files
DATASETS_DIR = HERE / "datasets"  # the whole directory is what push() copies to RCLONE_REMOTE
SHARDS_DIR = DATASETS_DIR / "shards"
MANIFEST = DATASETS_DIR / "manifest.json"  # lists builds + shards; also the basis for dedup of new builds
LICHESS_DB = HERE / "trainingdata" / "lichess_db_eval.jsonl.zst"  # local cache of LICHESS_EVAL_URL
def label(fens_file: Path, out: Path) -> int:
    """Run local Stockfish over the FENs in *fens_file*, writing labels to *out*.

    Returns the number of lines in *out* afterwards, or 0 without doing any
    work when *fens_file* does not exist.
    """
    if fens_file.exists():
        label_positions_with_stockfish(
            str(fens_file),
            str(out),
            STOCKFISH_PATH,
            batch_size=LABEL_BATCH,
            depth=STOCKFISH_DEPTH,
            num_workers=WORKERS,
        )
        return count_lines(out)
    return 0
def count_lines(path: Path) -> int:
    """Return the number of lines in the file at *path*, or 0 if it does not exist.

    The file is read in binary mode. Counting lines needs no decoding, whereas
    text mode decodes with the locale's default encoding and can raise
    ``UnicodeDecodeError`` on non-UTF-8 locales.
    """
    try:
        with open(path, "rb") as fh:
            return sum(1 for _ in fh)
    except FileNotFoundError:
        # Missing output simply means nothing was written.
        return 0
def source_lichess(out: Path) -> int:
    """Import up to LICHESS_POSITIONS pre-evaluated positions from the Lichess eval DB.

    The DB is downloaded on first use to a ``.part`` file and renamed to
    LICHESS_DB only once the download completes. Because later runs test only
    ``LICHESS_DB.exists()``, this keeps an interrupted download from being
    mistaken for a finished one.

    Returns the value of ``import_lichess_evals``, which the caller treats as
    the number of positions written to *out*.
    """
    if not LICHESS_DB.exists():
        print(f"Downloading Lichess eval DB → {LICHESS_DB} (large, one-time)...")
        LICHESS_DB.parent.mkdir(parents=True, exist_ok=True)
        partial = LICHESS_DB.with_name(LICHESS_DB.name + ".part")
        try:
            urllib.request.urlretrieve(LICHESS_EVAL_URL, partial)
        except BaseException:
            # Includes KeyboardInterrupt: drop the (possibly huge) partial file, then re-raise.
            partial.unlink(missing_ok=True)
            raise
        partial.replace(LICHESS_DB)
    return import_lichess_evals(str(LICHESS_DB), str(out), max_positions=LICHESS_POSITIONS)
def source_selfplay(out: Path) -> int:
    """Label the self-play FENs in data/selfplay.txt.

    Honours the USE_SELFPLAY config knob in addition to ``--no-selfplay``.
    Returns the number of labeled positions written to *out*. The result is 0
    when USE_SELFPLAY is False or data/selfplay.txt does not exist.
    """
    if not USE_SELFPLAY:
        print("USE_SELFPLAY is False — skipping self-play.")
        return 0
    return label(DATA_DIR / "selfplay.txt", out)
def source_tactical(out: Path) -> int:
    """Extract up to TACTICAL_PUZZLES tactical FENs from the puzzle DB and label them.

    Returns 0 when ``download_and_extract_puzzle_db`` yields no CSV path.
    Otherwise returns the number of labeled positions written to *out*.
    """
    csv_path = download_and_extract_puzzle_db(output_dir=str(HERE / "tactical_data"))
    if csv_path is None:
        return 0
    fen_list = WORK_DIR / "tactical_fens.txt"
    extract_tactical_only(str(csv_path), str(fen_list), max_puzzles=TACTICAL_PUZZLES)
    return label(fen_list, out)
def source_random(out: Path) -> int:
    """Generate RANDOM_FILLER random-play positions and label them with Stockfish.

    Returns the number of labeled positions written to *out*.
    """
    random_fens = WORK_DIR / "random_fens.txt"
    play_random_game_and_collect_positions(
        str(random_fens),
        total_positions=RANDOM_FILLER,
        num_workers=WORKERS,
    )
    return label(random_fens, out)
def build_sources(args) -> dict[str, Path]:
    """Produce one labeled jsonl per enabled source under WORK_DIR.

    ``args`` carries one boolean attribute per source (lichess, selfplay,
    tactical, random). Each source's previous output file is deleted before
    the source runs.

    Returns ``{source_name: jsonl_path}`` for the sources that reported at
    least one written position, in the order lichess, selfplay, tactical, random.
    """
    WORK_DIR.mkdir(parents=True, exist_ok=True)
    steps = [
        ("lichess", args.lichess, source_lichess),
        ("selfplay", args.selfplay, source_selfplay),
        ("tactical", args.tactical, source_tactical),
        ("random", args.random, source_random),
    ]
    produced: dict[str, Path] = {}
    for source_name, wanted, runner in steps:
        if wanted:
            target = WORK_DIR / f"{source_name}_labeled.jsonl"
            target.unlink(missing_ok=True)  # never mix in a previous run's output
            print(f"\n=== Source: {source_name} ===")
            count = runner(target)
            print(f"{source_name}: {count:,} labeled positions")
            if count:
                produced[source_name] = target
    return produced
def existing_fens() -> set[str]:
    """Collect every FEN stored in the shards listed by MANIFEST.

    New builds are deduplicated against this set, so the dataset only ever
    grows with unseen positions. Returns an empty set when no manifest exists.
    """
    if not MANIFEST.exists():
        return set()
    listed = json.loads(MANIFEST.read_text()).get("shards", [])
    return {
        record["fen"]
        for entry in listed
        for record in read_shard(SHARDS_DIR / entry["file"])
    }
def read_shard(path: Path):
    """Yield each JSON record from the zstd-compressed jsonl shard at *path*."""
    decompressor = zstd.ZstdDecompressor()
    with open(path, "rb") as raw:
        with decompressor.stream_reader(raw) as stream:
            for text_line in iter_text(stream):
                yield json.loads(text_line)
def iter_text(reader):
    """Lazily yield UTF-8-decoded lines (line endings included) from a binary *reader*."""
    import io

    text_stream = io.TextIOWrapper(reader, encoding="utf-8")
    for text_line in text_stream:
        yield text_line
def merge_dedup(outputs: dict[str, Path], skip: set[str]):
    """Merge the per-source jsonl files, dropping duplicate FENs.

    A record is dropped if its FEN is in *skip* (already in the dataset) or
    was already taken earlier in this merge. Sources earlier in *outputs*
    therefore win ties. Every kept record is tagged with ``rec["source"] = name``.
    *skip* itself is not modified. Blank lines in the input files are ignored.

    Returns:
        ``(records, per_source)``: the kept records in input order, and the
        number of records kept from each source.
    """
    seen = set(skip)
    records: list[dict] = []
    per_source: dict[str, int] = {}
    for name, path in outputs.items():
        kept = 0
        # Explicit UTF-8: the platform default encoding is locale-dependent.
        with open(path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # a blank/trailing line would otherwise crash json.loads
                rec = json.loads(line)
                fen = rec["fen"]
                if fen in seen:
                    continue
                seen.add(fen)
                rec["source"] = name
                records.append(rec)
                kept += 1
        per_source[name] = kept
    return records, per_source
def balance(records: list, *, bins: "int | None" = None,
            factor: "float | None" = None,
            rng: "random.Random | None" = None) -> list:
    """Flatten the eval histogram by capping how many records each eval bin keeps.

    Evals are bucketed into equal-width bins over [-1, 1]. Values outside that
    range are clamped into the first or last bin. Each bin keeps at most
    ``max(1, int(factor * len(records) / bins))`` records, picked after a
    random shuffle.

    Args:
        records: dicts with a numeric ``"eval"`` key. The list is not modified.
        bins: number of histogram bins; defaults to BALANCE_BINS.
        factor: per-bin cap as a multiple of the uniform bin size; defaults
            to BALANCE_FACTOR.
        rng: random source for the shuffle. Pass a seeded ``random.Random`` for
            reproducible builds; defaults to the global ``random`` module.

    Returns:
        The kept records in shuffled order (*records* itself when it is empty).
    """
    if not records:
        return records
    n_bins = BALANCE_BINS if bins is None else bins
    cap_factor = BALANCE_FACTOR if factor is None else factor
    shuffler = random if rng is None else rng
    cap = max(1, int(cap_factor * len(records) / n_bins))
    counts: dict[int, int] = {}
    kept = []
    shuffled = list(records)  # shuffle a copy so the caller's list is untouched
    shuffler.shuffle(shuffled)
    for rec in shuffled:
        b = int((rec["eval"] + 1.0) / 2.0 * n_bins)
        # Clamp both ends: an eval below -1 would otherwise map to a negative
        # bin index with its own cap and escape balancing entirely.
        b = min(n_bins - 1, max(0, b))
        if counts.get(b, 0) < cap:
            counts[b] = counts.get(b, 0) + 1
            kept.append(rec)
    return kept
def sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, reading 1 MiB at a time."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while block := stream.read(1 << 20):
            digest.update(block)
    return digest.hexdigest()
def write_shards(records: list, build_id: str) -> list[dict]:
    """Write *records* as zstd-compressed jsonl shards of up to SHARD_SIZE records.

    Shards go to SHARDS_DIR and are named ``{build_id}_{index:05d}.jsonl.zst``.
    Returns one manifest entry per shard (file name, position count, SHA-256,
    build id), in write order.
    """
    SHARDS_DIR.mkdir(parents=True, exist_ok=True)
    compressor = zstd.ZstdCompressor(level=10)
    written: list[dict] = []
    for shard_no, start in enumerate(range(0, len(records), SHARD_SIZE)):
        batch = records[start:start + SHARD_SIZE]
        shard_name = f"{build_id}_{shard_no:05d}.jsonl.zst"
        shard_path = SHARDS_DIR / shard_name
        with open(shard_path, "wb") as raw:
            with compressor.stream_writer(raw) as sink:
                for rec in batch:
                    sink.write((json.dumps(rec) + "\n").encode("utf-8"))
        written.append({
            "file": shard_name,
            "positions": len(batch),
            "sha256": sha256(shard_path),
            "build_id": build_id,
        })
        print(f" wrote {shard_name} ({len(batch):,} positions)")
    return written
def update_manifest(new_shards: list[dict], build: dict, *,
                    manifest_path: "Path | None" = None) -> None:
    """Append a build and its shards to the dataset manifest and bump its version.

    If the manifest does not exist yet, it is created with version 0 (bumped
    to 1 here) and ``scale`` 300.0. The new JSON is written to a sibling
    ``.tmp`` file and swapped in with ``os.replace``. A crash mid-write
    therefore cannot leave a truncated manifest, which is the file that lists
    the dataset's shards.

    Args:
        new_shards: shard entries (as produced by write_shards); each needs "positions".
        build: build metadata; must contain a "created" timestamp.
        manifest_path: manifest file to update; defaults to MANIFEST.
    """
    target = MANIFEST if manifest_path is None else manifest_path
    manifest = json.loads(target.read_text(encoding="utf-8")) if target.exists() else {
        "dataset_version": 0, "scale": 300.0, "builds": [], "shards": [],
    }
    manifest["dataset_version"] += 1
    manifest["created"] = build["created"]
    manifest["builds"].append(build)
    manifest["shards"].extend(new_shards)
    manifest["total_positions"] = sum(s["positions"] for s in manifest["shards"])
    target.parent.mkdir(parents=True, exist_ok=True)
    tmp = target.with_name(target.name + ".tmp")
    tmp.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    os.replace(tmp, target)  # atomic swap on the same filesystem
    print(f"\nDataset version {manifest['dataset_version']}: "
          f"{manifest['total_positions']:,} total positions across {len(manifest['shards'])} shards")
def push() -> None:
    """Copy DATASETS_DIR to RCLONE_REMOTE with ``rclone copy``.

    Prints a message and returns without pushing when rclone is not on PATH
    or DATASETS_DIR does not exist yet. Raises
    ``subprocess.CalledProcessError`` if rclone exits non-zero.
    """
    # shutil.which is portable; spawning an external `which` fails where that
    # binary is absent (e.g. Windows).
    if shutil.which("rclone") is None:
        print("rclone not found — skipping push.")
        return
    if not DATASETS_DIR.is_dir():
        print(f"{DATASETS_DIR} does not exist — nothing to push.")
        return
    print(f"\nPushing {DATASETS_DIR} → {RCLONE_REMOTE} ...")
    subprocess.run(["rclone", "copy", str(DATASETS_DIR), RCLONE_REMOTE, "--progress"], check=True)
def parse_args():
    """Parse the command-line flags.

    Each ``--no-<name>`` flag (lichess, selfplay, tactical, random, push)
    stores False into the attribute ``<name>``, which defaults to True.
    ``--push-only`` pushes the existing dataset without building.
    """
    parser = argparse.ArgumentParser(description="Build the entire NNUE dataset.")
    toggles = ("lichess", "selfplay", "tactical", "random", "push")
    for toggle in toggles:
        parser.add_argument("--no-" + toggle, dest=toggle, action="store_false")
    parser.add_argument(
        "--push-only",
        action="store_true",
        help="Push the existing dataset, build nothing.",
    )
    return parser.parse_args()
def main() -> None:
    """Build the dataset end to end, or only push it with --push-only.

    Steps: run the enabled sources, merge and dedup against the existing
    dataset, balance the eval histogram, write shards, update the manifest,
    and push unless --no-push was given.
    """
    args = parse_args()
    if args.push_only:
        push()
        return
    build_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    outputs = build_sources(args)
    if not outputs:
        print("No sources produced data — nothing to build.")
        return
    print("\n=== Merge / dedup / balance ===")
    records, per_source = merge_dedup(outputs, existing_fens())
    print(f"merged unique (new): {len(records):,}")
    records = balance(records)
    print(f"after balancing: {len(records):,}")
    if not records:
        # Everything was already in the dataset. Recording a build would only
        # bump dataset_version without adding a single shard.
        print("No new positions after dedup — dataset unchanged.")
        return
    new_shards = write_shards(records, build_id)
    update_manifest(new_shards, {
        "build_id": build_id,
        "created": datetime.now(timezone.utc).isoformat(),
        "stockfish_depth": STOCKFISH_DEPTH,
        "sources": per_source,
        "kept_after_balance": len(records),
    })
    if args.push:
        push()
    print("\nDone.")


if __name__ == "__main__":
    main()