378 lines
12 KiB
Plaintext
378 lines
12 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5,
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"name": "python",
|
|
"version": "3.10.0"
|
|
},
|
|
"colab": {
|
|
"provenance": [],
|
|
"gpuType": "T4"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# NNUE Training Pipeline\n",
|
|
"\n",
|
|
"End-to-end notebook: data generation → Stockfish labeling → training → `.nbai` export.\n",
|
|
"\n",
|
|
"**Runtime:** GPU (T4 or better). Runtime → Change runtime type → T4 GPU.\n",
|
|
"\n",
|
|
"**Persistence:** Checkpoints and datasets are saved to Google Drive so training can resume after session timeout."
|
|
],
|
|
"id": "intro-md"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"## ⚙️ 1 — Setup"
|
|
],
|
|
"id": "setup-md"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Mount Google Drive for checkpoint persistence\n",
|
|
"from google.colab import drive\n",
|
|
"drive.mount('/content/drive')"
|
|
],
|
|
"id": "mount-drive"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"# ── Configure these paths once ───────────────────────────────────────────────\n",
|
|
"REPO_URL = 'https://git.janis-eccarius.de/NowChess/NowChessSystems.git'\n",
|
|
"DRIVE_ROOT = '/content/drive/MyDrive/NowChess'\n",
|
|
"REPO_DIR = f'{DRIVE_ROOT}/NowChessSystems'\n",
|
|
"PYTHON_DIR = f'{REPO_DIR}/modules/official-bots/python'\n",
|
|
"# ─────────────────────────────────────────────────────────────────────────────\n",
|
|
"\n",
|
|
"os.makedirs(DRIVE_ROOT, exist_ok=True)\n",
|
|
"\n",
|
|
"if not os.path.isdir(REPO_DIR):\n",
|
|
" !git clone --depth=1 \"{REPO_URL}\" \"{REPO_DIR}\"\n",
|
|
" print('Repo cloned to Drive.')\n",
|
|
"else:\n",
|
|
" !git -C \"{REPO_DIR}\" pull --ff-only\n",
|
|
" print('Repo updated.')"
|
|
],
|
|
"id": "clone-repo"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Install Python dependencies\n",
|
|
"!pip install -q chess tqdm rich zstandard\n",
|
|
"\n",
|
|
"# Stockfish for position labeling\n",
|
|
"!apt-get install -q -y stockfish\n",
|
|
"import shutil\n",
|
|
"STOCKFISH_PATH = shutil.which('stockfish') or '/usr/games/stockfish'\n",
|
|
"print(f'Stockfish: {STOCKFISH_PATH}')\n",
|
|
"\n",
|
|
"# Add pipeline source to path\n",
|
|
"import sys\n",
|
|
"sys.path.insert(0, f'{PYTHON_DIR}/src')\n",
|
|
"sys.path.insert(0, PYTHON_DIR)\n",
|
|
"print('Python path configured.')"
|
|
],
|
|
"id": "install-deps"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"## 🗄️ 2 — Data\n",
|
|
"\n",
|
|
"Choose **one** of the two options below:\n",
|
|
"- **Option A** — generate FEN positions with random play, then label them with Stockfish.\n",
|
|
"- **Option B** — upload an existing `labeled.jsonl` from your machine or Drive."
|
|
],
|
|
"id": "data-md"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pathlib import Path\n",
|
|
"\n",
|
|
"# Paths (all on Drive so they survive session restarts)\n",
|
|
"DATA_DIR = Path(DRIVE_ROOT) / 'training_data'\n",
|
|
"DATA_DIR.mkdir(parents=True, exist_ok=True)\n",
|
|
"POSITIONS_FILE = DATA_DIR / 'positions.txt' # raw FENs\n",
|
|
"LABELED_FILE = DATA_DIR / 'labeled.jsonl' # FEN + eval pairs\n",
|
|
"\n",
|
|
"print(f'Data directory: {DATA_DIR}')"
|
|
],
|
|
"id": "data-paths"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── Option A: Generate + label ────────────────────────────────────────────────\n",
|
|
"# Adjust NUM_POSITIONS to taste. 50 000 trains in ~10 min on T4;\n",
|
|
"# 200 000+ gives better generalisation.\n",
|
|
"NUM_POSITIONS = 50_000\n",
|
|
"STOCKFISH_DEPTH = 12\n",
|
|
"LABEL_WORKERS = 4 # parallel Stockfish processes\n",
|
|
"MIN_MOVE = 5 # skip opening book moves\n",
|
|
"MAX_MOVE = 60\n",
|
|
"\n",
|
|
"from generate import play_random_game_and_collect_positions\n",
|
|
"from label import label_positions_with_stockfish\n",
|
|
"\n",
|
|
"print(f'Generating {NUM_POSITIONS:,} positions...')\n",
|
|
"count = play_random_game_and_collect_positions(\n",
|
|
" str(POSITIONS_FILE),\n",
|
|
" total_positions=NUM_POSITIONS,\n",
|
|
" samples_per_game=1,\n",
|
|
" min_move=MIN_MOVE,\n",
|
|
" max_move=MAX_MOVE,\n",
|
|
" num_workers=4,\n",
|
|
")\n",
|
|
"print(f'{count:,} positions written to {POSITIONS_FILE}')\n",
|
|
"\n",
|
|
"print('Labeling with Stockfish (this is the slow step)...')\n",
|
|
"ok = label_positions_with_stockfish(\n",
|
|
" str(POSITIONS_FILE),\n",
|
|
" str(LABELED_FILE),\n",
|
|
" STOCKFISH_PATH,\n",
|
|
" depth=STOCKFISH_DEPTH,\n",
|
|
" num_workers=LABEL_WORKERS,\n",
|
|
")\n",
|
|
"if ok:\n",
|
|
" print(f'Labeled dataset saved: {LABELED_FILE}')\n",
|
|
"else:\n",
|
|
" print('ERROR: labeling failed')"
|
|
],
|
|
"id": "option-a-generate"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── Option B: Upload existing labeled.jsonl ───────────────────────────────────\n",
|
|
"# If you already have a labeled dataset, uncomment ONE of the snippets below and run this cell instead of Option A (as-is, the cell only checks that the file exists).\n",
|
|
"#\n",
|
|
"# To upload from local machine:\n",
|
|
"# from google.colab import files\n",
|
|
"# uploaded = files.upload() # pick your labeled.jsonl\n",
|
|
"# import shutil, os\n",
|
|
"# shutil.move(next(iter(uploaded)), str(LABELED_FILE))\n",
|
|
"#\n",
|
|
"# Or copy from Drive:\n",
|
|
"# import shutil\n",
|
|
"# shutil.copy('/content/drive/MyDrive/path/to/labeled.jsonl', str(LABELED_FILE))\n",
|
|
"\n",
|
|
"import os\n",
|
|
"if LABELED_FILE.exists():\n",
|
|
" lines = sum(1 for _ in open(LABELED_FILE))\n",
|
|
" print(f'Ready: {lines:,} labeled positions at {LABELED_FILE}')\n",
|
|
"else:\n",
|
|
" print('No labeled.jsonl found — run Option A first or upload one.')"
|
|
],
|
|
"id": "option-b-upload"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"## 🏋️ 3 — Train\n",
|
|
"\n",
|
|
"Standard training runs a fixed number of epochs. \n",
|
|
"**Burst mode** is better for Colab: it repeatedly restarts from the best checkpoint within a time budget, surviving session disconnects gracefully."
|
|
],
|
|
"id": "train-md"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n",
|
|
"\n",
|
|
"WEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\n",
|
|
"WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\n",
|
|
"OUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n",
|
|
"\n",
|
|
"# ── Training hyperparameters ──────────────────────────────────────────────────\n",
|
|
"HIDDEN_SIZES = DEFAULT_HIDDEN_SIZES # [1536, 1024, 512, 256]\n",
|
|
"BATCH_SIZE = 16384\n",
|
|
"EPOCHS = 100\n",
|
|
"EARLY_STOPPING = 10 # None to disable\n",
|
|
"SUBSAMPLE_RATIO = 1.0\n",
|
|
"\n",
|
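|
"# Resume from the last checkpoint in name order, if one exists (NOTE: plain name sort, so 'v10' sorts before 'v2' unless version suffixes are zero-padded; TODO confirm)\n",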
|
|
"checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\n",
|
|
"CHECKPOINT = str(checkpoints[-1]) if checkpoints else None\n",
|
|
"if CHECKPOINT:\n",
|
|
" print(f'Resuming from checkpoint: {CHECKPOINT}')\n",
|
|
"else:\n",
|
|
" print('Starting training from scratch.')"
|
|
],
|
|
"id": "train-config"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── Standard training ─────────────────────────────────────────────────────────\n",
|
|
"# Use this when you have a reliable long-running session.\n",
|
|
"\n",
|
|
"train_nnue(\n",
|
|
" data_file=str(LABELED_FILE),\n",
|
|
" output_file=OUTPUT_FILE,\n",
|
|
" epochs=EPOCHS,\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" checkpoint=CHECKPOINT,\n",
|
|
" use_versioning=True,\n",
|
|
" early_stopping_patience=EARLY_STOPPING,\n",
|
|
" subsample_ratio=SUBSAMPLE_RATIO,\n",
|
|
" hidden_sizes=HIDDEN_SIZES,\n",
|
|
")"
|
|
],
|
|
"id": "standard-train"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ── Burst training (recommended for Colab free tier) ─────────────────────────\n",
|
|
"# Restarts from the global best each time early stopping fires.\n",
|
|
"# Set BURST_MINUTES slightly below your Colab session limit (the value used below is 70 min).\n",
|
|
"\n",
|
|
"BURST_MINUTES = 70\n",
|
|
"EPOCHS_PER_SEASON = 30\n",
|
|
"BURST_PATIENCE = 8\n",
|
|
"\n",
|
|
"burst_train(\n",
|
|
" data_file=str(LABELED_FILE),\n",
|
|
" output_file=OUTPUT_FILE,\n",
|
|
" duration_minutes=BURST_MINUTES,\n",
|
|
" epochs_per_season=EPOCHS_PER_SEASON,\n",
|
|
" early_stopping_patience=BURST_PATIENCE,\n",
|
|
" batch_size=BATCH_SIZE,\n",
|
|
" initial_checkpoint=CHECKPOINT,\n",
|
|
" use_versioning=True,\n",
|
|
" subsample_ratio=SUBSAMPLE_RATIO,\n",
|
|
" hidden_sizes=HIDDEN_SIZES,\n",
|
|
")"
|
|
],
|
|
"id": "burst-train"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"## 📦 4 — Export\n",
|
|
"\n",
|
|
"Convert the latest versioned `.pt` checkpoint to the `.nbai` binary format read by `NbaiLoader` in Scala."
|
|
],
|
|
"id": "export-md"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from export import export_to_nbai\n",
|
|
"\n",
|
|
"NBAI_FILE = Path(DRIVE_ROOT) / 'nnue_weights.nbai'\n",
|
|
"\n",
|
|
"# Pick the last versioned checkpoint in name order (NOTE: 'v10' sorts before 'v2' unless version suffixes are zero-padded; TODO confirm)\n",
|
|
"checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\n",
|
|
"if not checkpoints:\n",
|
|
" raise FileNotFoundError('No checkpoints found in ' + str(WEIGHTS_DIR))\n",
|
|
"\n",
|
|
"latest = checkpoints[-1]\n",
|
|
"print(f'Exporting {latest.name} → {NBAI_FILE.name}')\n",
|
|
"\n",
|
|
"export_to_nbai(\n",
|
|
" weights_file=str(latest),\n",
|
|
" output_file=str(NBAI_FILE),\n",
|
|
" trained_by='colab',\n",
|
|
")\n",
|
|
"print('Export complete.')"
|
|
],
|
|
"id": "export-cell"
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"## ⬇️ 5 — Download\n",
|
|
"\n",
|
|
"Download the `.nbai` weights file and the latest `.pt` checkpoint to your local machine.\n",
|
|
"\n",
|
|
"Place `nnue_weights.nbai` in `modules/official-bots/src/main/resources/` and rebuild the native image."
|
|
],
|
|
"id": "download-md"
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from google.colab import files\n",
|
|
"\n",
|
|
"if NBAI_FILE.exists():\n",
|
|
" files.download(str(NBAI_FILE))\n",
|
|
" print(f'Downloading {NBAI_FILE.name}')\n",
|
|
"else:\n",
|
|
" print('No .nbai file found — run the Export cell first.')\n",
|
|
"\n",
|
|
"checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\n",
|
|
"if checkpoints:\n",
|
|
" latest = checkpoints[-1]\n",
|
|
" files.download(str(latest))\n",
|
|
" print(f'Downloading checkpoint {latest.name}')\n",
|
|
"else:\n",
|
|
" print('No .pt checkpoint found.')"
|
|
],
|
|
"id": "download-cell"
|
|
}
|
|
]
|
|
}
|