1c80abdb8a
Build & Test (NowChessSystems) TeamCity build finished
Add an easy local data pipeline feeding GPU training on Colab. - SelfPlayMain: standalone NNUEBot self-play (no microservices) writing FENs for labeling; randomised openings for game diversity, sequential due to the shared EvaluationNNUE accumulator. Exposed via the `selfPlay` Gradle task and selfplay.sh. - NNUEBot: optional fixedMoveTimeMs so self-play runs fast (default unchanged). - NbaiLoader: honor `-Dnnue.weights=<path>` to load weights from a file before falling back to the bundled resource. - build_dataset.py / dataset.sh: one command builds the entire dataset (Lichess eval-DB backbone + self-play + tactical + random filler), dedups, balances the eval histogram, writes append-only zstd shards + manifest, and rclone-pushes to Drive. - train.py: NNUEDataset reads a directory of .jsonl.zst shards (streaming) in addition to a single file. - NNUETraining.ipynb: clone to ephemeral /content, sync shards from Drive (cache-aware), train on the shards dir; removed Colab generation/upload steps. - Concept + implementation plan docs. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
282 lines
10 KiB
Python
282 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""Build the ENTIRE NNUE training dataset with one command.
|
|
|
|
Orchestrates the existing source modules (Lichess eval DB, self-play, tactical puzzles,
|
|
random filler), labels what needs labeling with local Stockfish, deduplicates, balances
|
|
the eval distribution, writes append-only compressed shards + a manifest, and pushes to
|
|
Google Drive with rclone.
|
|
|
|
./dataset.sh # build everything + push
|
|
./dataset.sh --no-push # build only
|
|
./dataset.sh --no-lichess # skip the (large) Lichess backbone
|
|
|
|
Tune the CONFIG block below — that is the only thing you normally touch.
|
|
"""
|
|
|
|
import argparse
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import random
|
|
import subprocess
|
|
import sys
|
|
import urllib.request
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import zstandard as zstd
|
|
|
|
HERE = Path(__file__).resolve().parent
|
|
sys.path.insert(0, str(HERE / "src"))
|
|
|
|
from generate import play_random_game_and_collect_positions
|
|
from label import label_positions_with_stockfish
|
|
from lichess_importer import import_lichess_evals
|
|
from tactical_positions_extractor import download_and_extract_puzzle_db, extract_tactical_only
|
|
|
|
# ── CONFIG — the only knobs you normally touch ───────────────────────────────
LICHESS_POSITIONS = 2_000_000   # backbone positions from the Lichess eval DB
USE_SELFPLAY = True             # label data/selfplay.txt if present
TACTICAL_PUZZLES = 200_000      # tactical positions from the Lichess puzzle DB
RANDOM_FILLER = 100_000         # cheap random-play positions
STOCKFISH_DEPTH = 14            # local labeling depth (selfplay/tactical/random)
RCLONE_REMOTE = "gdrive:NowChess/datasets"
# ─────────────────────────────────────────────────────────────────────────────

# Pipeline tuning — rarely needs changing.
LABEL_BATCH = 64        # positions per Stockfish task (small = smooth progress + load balance)
SHARD_SIZE = 100_000    # positions per shard
BALANCE_BINS = 64       # eval histogram bins over [-1, 1]
BALANCE_FACTOR = 2.0    # cap each bin at FACTOR x the uniform bin size
LICHESS_EVAL_URL = "https://database.lichess.org/lichess_db_eval.jsonl.zst"

# Environment: engine binary (overridable via $STOCKFISH_PATH) and parallelism.
STOCKFISH_PATH = os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
WORKERS = os.cpu_count() or 4   # cpu_count() may return None; fall back to 4

# Filesystem layout, all anchored at this script's directory.
DATA_DIR = HERE / "data"
WORK_DIR = DATA_DIR / "_build"              # scratch space for per-source intermediates
DATASETS_DIR = HERE / "datasets"
SHARDS_DIR = DATASETS_DIR / "shards"
MANIFEST = DATASETS_DIR / "manifest.json"
LICHESS_DB = HERE / "trainingdata" / "lichess_db_eval.jsonl.zst"
|
|
|
|
|
|
def label(fens_file: Path, out: Path) -> int:
    """Label a FEN file with the local Stockfish binary.

    Hands ``fens_file`` and ``out`` to ``label_positions_with_stockfish``
    together with the configured engine path, batch size, depth and worker
    count, then reports how many lines ``out`` contains.

    Returns 0 without invoking the labeler when ``fens_file`` does not exist.
    """
    if fens_file.exists():
        label_positions_with_stockfish(
            str(fens_file),
            str(out),
            STOCKFISH_PATH,
            batch_size=LABEL_BATCH,
            depth=STOCKFISH_DEPTH,
            num_workers=WORKERS,
        )
        return count_lines(out)
    return 0
|
|
|
|
|
|
def count_lines(path: Path) -> int:
    """Return the number of lines in ``path``; a missing file counts as 0."""
    if path.exists():
        total = 0
        with open(path) as handle:
            for _ in handle:
                total += 1
        return total
    return 0
|
|
|
|
|
|
def source_lichess(out: Path) -> int:
    """Import the Lichess eval DB backbone into ``out``.

    Downloads the eval database to ``LICHESS_DB`` on first use, then passes
    it to ``import_lichess_evals`` capped at ``LICHESS_POSITIONS`` and returns
    whatever that importer returns.

    The download goes to a sibling ``.part`` file and is renamed into place
    only once it has finished. An interrupted transfer therefore never leaves
    a truncated file at ``LICHESS_DB`` that a later run would mistake for a
    complete database.
    """
    if not LICHESS_DB.exists():
        print(f"Downloading Lichess eval DB → {LICHESS_DB} (large, one-time)...")
        LICHESS_DB.parent.mkdir(parents=True, exist_ok=True)
        partial = LICHESS_DB.with_name(LICHESS_DB.name + ".part")
        try:
            urllib.request.urlretrieve(LICHESS_EVAL_URL, partial)
            # Same-directory rename: the final path appears only when complete.
            partial.replace(LICHESS_DB)
        finally:
            # No-op after a successful rename; removes the leftover on failure.
            partial.unlink(missing_ok=True)
    return import_lichess_evals(str(LICHESS_DB), str(out), max_positions=LICHESS_POSITIONS)
|
|
|
|
|
|
def source_selfplay(out: Path) -> int:
    """Label the self-play FEN dump (``data/selfplay.txt``) into ``out``.

    Returns the labeled-position count reported by ``label``.
    """
    selfplay_fens = DATA_DIR / "selfplay.txt"
    return label(selfplay_fens, out)
|
|
|
|
|
|
def source_tactical(out: Path) -> int:
    """Extract tactical puzzle positions and label them into ``out``.

    Fetches the puzzle DB into ``tactical_data/``, writes up to
    ``TACTICAL_PUZZLES`` FENs to a scratch file under ``WORK_DIR`` and labels
    that file. Returns 0 when the puzzle DB helper returns ``None``.
    """
    tactical_dir = HERE / "tactical_data"
    puzzle_csv = download_and_extract_puzzle_db(output_dir=str(tactical_dir))
    if puzzle_csv is not None:
        fens_path = WORK_DIR / "tactical_fens.txt"
        extract_tactical_only(str(puzzle_csv), str(fens_path), max_puzzles=TACTICAL_PUZZLES)
        return label(fens_path, out)
    return 0
|
|
|
|
|
|
def source_random(out: Path) -> int:
    """Generate random-play filler positions and label them into ``out``.

    Asks the generator for ``RANDOM_FILLER`` positions, written to a scratch
    file under ``WORK_DIR``, then labels that file.
    """
    random_fens = WORK_DIR / "random_fens.txt"
    play_random_game_and_collect_positions(
        str(random_fens),
        total_positions=RANDOM_FILLER,
        num_workers=WORKERS,
    )
    return label(random_fens, out)
|
|
|
|
|
|
def build_sources(args) -> dict[str, Path]:
    """Run every enabled source into its own labeled jsonl file.

    ``args`` must expose the boolean attributes ``lichess``, ``selfplay``,
    ``tactical`` and ``random``. Each enabled source writes to
    ``WORK_DIR/<name>_labeled.jsonl``; any stale file from an earlier run is
    removed first.

    Returns ``{name: path}`` for the sources that produced at least one
    position.
    """
    WORK_DIR.mkdir(parents=True, exist_ok=True)
    candidates = (
        ("lichess", args.lichess, source_lichess),
        ("selfplay", args.selfplay, source_selfplay),
        ("tactical", args.tactical, source_tactical),
        ("random", args.random, source_random),
    )
    active = [(name, producer) for name, enabled, producer in candidates if enabled]

    outputs: dict[str, Path] = {}
    for name, producer in active:
        target = WORK_DIR / f"{name}_labeled.jsonl"
        target.unlink(missing_ok=True)
        print(f"\n=== Source: {name} ===")
        count = producer(target)
        print(f"{name}: {count:,} labeled positions")
        if count:
            outputs[name] = target
    return outputs
|
|
|
|
|
|
def existing_fens() -> set[str]:
    """Collect the FENs already stored in the dataset's shards.

    Reads every shard listed in the manifest so new builds can be
    deduplicated against what is already there. Returns an empty set when no
    manifest exists yet.
    """
    if not MANIFEST.exists():
        return set()
    manifest = json.loads(MANIFEST.read_text())
    return {
        rec["fen"]
        for shard in manifest.get("shards", [])
        for rec in read_shard(SHARDS_DIR / shard["file"])
    }
|
|
|
|
|
|
def read_shard(path: Path):
    """Yield the JSON records of one zstd-compressed jsonl shard, one per line."""
    decompressor = zstd.ZstdDecompressor()
    with open(path, "rb") as raw:
        with decompressor.stream_reader(raw) as stream:
            for text_line in iter_text(stream):
                yield json.loads(text_line)
|
|
|
|
|
|
def iter_text(reader):
    """Yield UTF-8 decoded text lines from a binary stream ``reader``."""
    import io

    text_stream = io.TextIOWrapper(reader, encoding="utf-8")
    for line in text_stream:
        yield line
|
|
|
|
|
|
def merge_dedup(outputs: dict[str, Path], skip: set[str]):
    """Merge all source jsonl files, dropping duplicate FENs.

    A record is dropped when its ``"fen"`` is in ``skip`` (positions already
    in the dataset) or was seen earlier in this batch. Sources are read in the
    iteration order of ``outputs``, so earlier sources win on conflicts. Each
    kept record is tagged with ``rec["source"] = <source name>``.

    Blank or whitespace-only lines, such as a trailing empty line, are
    skipped rather than handed to the JSON parser.

    Args:
        outputs: ``{source name: path to labeled jsonl}``.
        skip: FENs to treat as already present; not modified.

    Returns:
        ``(records, per_source)``: the kept records in read order, and the
        count kept per source.
    """
    seen = set(skip)  # copy so the caller's set is left untouched
    records: list = []
    per_source: dict = {}
    for name, path in outputs.items():
        kept = 0
        with open(path) as f:
            for line in f:
                if not line.strip():
                    continue
                rec = json.loads(line)
                fen = rec["fen"]
                if fen in seen:
                    continue
                seen.add(fen)
                rec["source"] = name
                records.append(rec)
                kept += 1
        per_source[name] = kept
    return records, per_source
|
|
|
|
|
|
def balance(records: list, *, bins=None, factor=None) -> list:
    """Flatten the eval histogram: cap each bin at FACTOR x the uniform bin size.

    Each record's ``"eval"`` is mapped onto ``bins`` equal-width bins over
    [-1, 1]. Values outside that interval are clamped into the first or last
    bin, so out-of-range evals share a cap with the edge bin instead of each
    opening an uncapped bin of their own. Records are visited in random order
    and kept until their bin reaches the cap.

    Args:
        records: dicts with a numeric ``"eval"`` key. The list is shuffled
            IN PLACE as a side effect.
        bins: number of histogram bins; defaults to ``BALANCE_BINS``.
        factor: cap multiplier relative to a uniform distribution; defaults
            to ``BALANCE_FACTOR``.

    Returns:
        A new list with the kept records (in shuffled order). An empty input
        is returned unchanged.
    """
    if not records:
        return records
    n_bins = BALANCE_BINS if bins is None else bins
    cap_factor = BALANCE_FACTOR if factor is None else factor
    cap = max(1, int(cap_factor * len(records) / n_bins))
    counts: dict = {}
    kept = []
    random.shuffle(records)  # random order decides which records survive a full bin
    last_bin = n_bins - 1
    for rec in records:
        raw_bin = int((rec["eval"] + 1.0) / 2.0 * n_bins)
        # Clamp both ends: eval == 1.0 maps to n_bins, eval < -1 maps below 0.
        b = max(0, min(last_bin, raw_bin))
        filled = counts.get(b, 0)
        if filled < cap:
            counts[b] = filled + 1
            kept.append(rec)
    return kept
|
|
|
|
|
|
def sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in 1 MiB blocks so large shards are never loaded whole.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while block := stream.read(1 << 20):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
def write_shards(records: list, build_id: str) -> list[dict]:
    """Write ``records`` as zstd-compressed jsonl shards under ``SHARDS_DIR``.

    Records are split into consecutive chunks of ``SHARD_SIZE``; each chunk
    becomes ``<build_id>_<index:05d>.jsonl.zst`` with one JSON object per line.

    Returns one manifest entry per shard (file name, position count, SHA-256
    of the written file, build id). An empty ``records`` list writes nothing
    and returns ``[]``.
    """
    SHARDS_DIR.mkdir(parents=True, exist_ok=True)
    compressor = zstd.ZstdCompressor(level=10)
    manifest_entries = []
    for index, start in enumerate(range(0, len(records), SHARD_SIZE)):
        chunk = records[start : start + SHARD_SIZE]
        name = f"{build_id}_{index:05d}.jsonl.zst"
        path = SHARDS_DIR / name
        with open(path, "wb") as fh, compressor.stream_writer(fh) as writer:
            for rec in chunk:
                line = json.dumps(rec) + "\n"
                writer.write(line.encode("utf-8"))
        # Hash only after the writer is closed so the digest covers the full file.
        entry = {
            "file": name,
            "positions": len(chunk),
            "sha256": sha256(path),
            "build_id": build_id,
        }
        manifest_entries.append(entry)
        print(f"  wrote {name} ({len(chunk):,} positions)")
    return manifest_entries
|
|
|
|
|
|
def update_manifest(new_shards: list[dict], build: dict) -> None:
    """Append a build and its shards to the dataset manifest.

    Loads the existing manifest (or starts a fresh one), bumps
    ``dataset_version``, records ``build`` and ``new_shards``, recomputes
    ``total_positions`` and writes the result back to ``MANIFEST``.

    The manifest is written to a temporary sibling file and then renamed over
    the real one, so a crash mid-write cannot leave a truncated manifest.

    Args:
        new_shards: manifest entries for the shards just written; each needs
            a ``"positions"`` key.
        build: build metadata; must contain ``"created"``.
    """
    if MANIFEST.exists():
        manifest = json.loads(MANIFEST.read_text())
    else:
        manifest = {
            "dataset_version": 0, "scale": 300.0, "builds": [], "shards": [],
        }
    manifest["dataset_version"] += 1
    manifest["created"] = build["created"]
    manifest["builds"].append(build)
    manifest["shards"].extend(new_shards)
    manifest["total_positions"] = sum(s["positions"] for s in manifest["shards"])

    MANIFEST.parent.mkdir(parents=True, exist_ok=True)
    tmp = MANIFEST.with_name(MANIFEST.name + ".tmp")
    try:
        tmp.write_text(json.dumps(manifest, indent=2))
        # Same-directory rename: readers see the old or the new manifest, never half.
        tmp.replace(MANIFEST)
    finally:
        # No-op after a successful rename; removes the leftover on failure.
        tmp.unlink(missing_ok=True)

    print(f"\nDataset version {manifest['dataset_version']}: "
          f"{manifest['total_positions']:,} total positions across {len(manifest['shards'])} shards")
|
|
|
|
|
|
def push() -> None:
    """Copy the datasets directory to the configured rclone remote.

    Prints a notice and returns without error when no ``rclone`` executable
    is found on PATH.

    Raises:
        subprocess.CalledProcessError: if ``rclone copy`` exits non-zero.
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    # shutil.which searches PATH in-process, so detection does not depend on
    # an external `which` binary being installed.
    if shutil.which("rclone") is None:
        print("rclone not found — skipping push.")
        return
    print(f"\nPushing {DATASETS_DIR} → {RCLONE_REMOTE} ...")
    subprocess.run(["rclone", "copy", str(DATASETS_DIR), RCLONE_REMOTE, "--progress"], check=True)
|
|
|
|
|
|
def parse_args():
    """Parse the command line.

    Every stage (``lichess``, ``selfplay``, ``tactical``, ``random``,
    ``push``) gets a ``--no-<name>`` switch that sets the matching attribute
    to ``False``. ``selfplay`` defaults to the ``USE_SELFPLAY`` config value,
    so that knob takes effect; the other stages default to ``True``.
    ``--push-only`` sets ``push_only``.

    Returns:
        The ``argparse.Namespace`` with attributes ``lichess``, ``selfplay``,
        ``tactical``, ``random``, ``push`` and ``push_only``.
    """
    p = argparse.ArgumentParser(description="Build the entire NNUE dataset.")
    for name in ("lichess", "selfplay", "tactical", "random", "push"):
        if name == "push":
            help_text = "Build only; do not push to the rclone remote."
        else:
            help_text = f"Skip the {name} source."
        p.add_argument(f"--no-{name}", dest=name, action="store_false", help=help_text)
    # Parser-level defaults override the implicit True of the store_false action.
    p.set_defaults(selfplay=bool(USE_SELFPLAY))
    p.add_argument("--push-only", action="store_true", help="Push the existing dataset, build nothing.")
    return p.parse_args()
|
|
|
|
|
|
def main() -> None:
    """Run the full build: sources → merge/dedup → balance → shards → manifest → push.

    With ``--push-only`` nothing is built and the existing dataset is pushed.
    The run stops early when no source produced any data.
    """
    args = parse_args()
    if args.push_only:
        push()
        return

    # UTC timestamp; it prefixes every shard file written by this build.
    build_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")

    outputs = build_sources(args)
    if not outputs:
        print("No sources produced data — nothing to build.")
        return

    print("\n=== Merge / dedup / balance ===")
    already_known = existing_fens()
    records, per_source = merge_dedup(outputs, already_known)
    print(f"merged unique (new): {len(records):,}")
    records = balance(records)
    print(f"after balancing: {len(records):,}")

    new_shards = write_shards(records, build_id)
    build_info = {
        "build_id": build_id,
        "created": datetime.now(timezone.utc).isoformat(),
        "stockfish_depth": STOCKFISH_DEPTH,
        "sources": per_source,
        "kept_after_balance": len(records),
    }
    update_manifest(new_shards, build_info)

    if args.push:
        push()
    print("\nDone.")
|
|
|
|
|
|
# Script entry point: run the build only when executed directly, not on import.
if __name__ == "__main__":
    main()
|