{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" }, "colab": { "provenance": [], "gpuType": "T4" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": "# NNUE Training Pipeline\n\nGPU training on Colab. Data is built **locally** (`./dataset.sh` → sharded, pushed to\nDrive via rclone); this notebook only **syncs shards → trains → exports `.nbai`**.\nNo generation, no Stockfish labeling, no browser uploads here.\n\n**Runtime:** GPU (T4 or better). Runtime → Change runtime type → T4 GPU.\n\n**Persistence:** Datasets and checkpoints live on Google Drive, so training resumes\nafter a session timeout. The repo is cloned to ephemeral `/content` for speed.", "id": "intro-md" },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## ⚙️ 1 — Setup" ], "id": "setup-md" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Mount Google Drive so datasets and checkpoints persist across sessions\n", "from google.colab import drive\n", "drive.mount('/content/drive')" ], "id": "mount-drive" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import os\n\n# ── Configure these paths once ───────────────────────────────────────────────\nREPO_URL = 'https://git.janis-eccarius.de/NowChess/NowChessSystems.git'\nDRIVE_ROOT = '/content/drive/MyDrive/NowChess'  # datasets + weights persist here\nREPO_DIR = '/content/NowChessSystems'  # ephemeral, fast local clone\nPYTHON_DIR = f'{REPO_DIR}/modules/official-bots/python'\n# ─────────────────────────────────────────────────────────────────────────────\n\nos.makedirs(DRIVE_ROOT, exist_ok=True)\n\n# Clone to ephemeral /content (NOT Drive) — fast checkout, no Drive bloat.\nif not os.path.isdir(REPO_DIR):\n    !git clone --depth=1 \"{REPO_URL}\" \"{REPO_DIR}\"\n    print('Repo cloned to /content.')\nelse:\n    !git -C \"{REPO_DIR}\" pull --ff-only\n    print('Repo updated.')", "id": "clone-repo" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "# Install Python dependencies. No Stockfish — labeling happens on the local box,\n# this notebook only trains on already-labeled shards.\n!pip install -q chess tqdm rich zstandard\n\nimport sys\nsys.path.insert(0, f'{PYTHON_DIR}/src')\nsys.path.insert(0, PYTHON_DIR)\nprint('Python path configured.')", "id": "install-deps" },
{ "cell_type": "markdown", "metadata": {}, "source": "---\n## 🗄️ 2 — Data\n\nDatasets are built **locally** (`./dataset.sh`) and pushed to Drive with rclone as\ncompressed shards under `MyDrive/NowChess/datasets/`. Here we just sync those shards\nto the fast local disk — no generation, no labeling, no browser uploads.\n\nThe cell reads `manifest.json` and copies only shards that are missing or incomplete on `/content`.", "id": "data-md" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import json, shutil\nfrom pathlib import Path\n\n# Source: shards synced from the local box via `rclone copy datasets/ gdrive:NowChess/datasets`\nDRIVE_DATASETS = Path(DRIVE_ROOT) / 'datasets'\nLOCAL_DATASETS = Path('/content/datasets')\n(LOCAL_DATASETS / 'shards').mkdir(parents=True, exist_ok=True)\n\nwith open(DRIVE_DATASETS / 'manifest.json') as f:\n    manifest = json.load(f)\nprint(f\"Dataset v{manifest['dataset_version']}: \"\n      f\"{manifest['total_positions']:,} positions across {len(manifest['shards'])} shards\")\n\ncopied = 0\nfor sh in manifest['shards']:\n    src = DRIVE_DATASETS / 'shards' / sh['file']\n    dst = LOCAL_DATASETS / 'shards' / sh['file']\n    # Cache: skip shards already copied in full; the size check also catches\n    # a copy that was interrupted partway through.\n    if not dst.exists() or dst.stat().st_size != src.stat().st_size:\n        shutil.copy(src, dst)\n        copied += 1\nshutil.copy(DRIVE_DATASETS / 'manifest.json', LOCAL_DATASETS / 'manifest.json')\n\nDATA_PATH = str(LOCAL_DATASETS)  # train_nnue / burst_train read this dir of shards directly\nprint(f\"Synced {copied} new shard(s). Dataset ready at {DATA_PATH}\")", "id": "data-paths" },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 🏋️ 3 — Train\n", "\n", "Standard training runs up to `EPOCHS` epochs and stops early once `EARLY_STOPPING` epochs pass without improvement.\n", "\n", "**Burst mode** is better suited to Colab: within a fixed time budget it trains in short seasons and restarts from the best checkpoint each time early stopping fires. Versioned checkpoints are written to Drive, so after a disconnect you can rerun the setup cells and resume from the latest one.\n", "\n", "Run one of the two training cells below, not both." ], "id": "train-md" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n", "\n", "WEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\n", "WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\n", "OUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n", "\n", "# ── Training hyperparameters ──────────────────────────────────────────────────\n", "HIDDEN_SIZES = DEFAULT_HIDDEN_SIZES  # [1536, 1024, 512, 256]\n", "BATCH_SIZE = 16384\n", "EPOCHS = 100\n", "EARLY_STOPPING = 10  # None to disable\n", "SUBSAMPLE_RATIO = 1.0\n", "\n", "# Resume from the latest checkpoint if one exists. Sort numerically on the\n", "# version digits; a plain string sort would put v10 before v9.\n", "checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'),\n", "                     key=lambda p: int(''.join(filter(str.isdigit, p.stem)) or 0))\n", "CHECKPOINT = str(checkpoints[-1]) if checkpoints else None\n", "if CHECKPOINT:\n", "    print(f'Resuming from checkpoint: {CHECKPOINT}')\n", "else:\n", "    print('Starting training from scratch.')" ], "id": "train-config" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "# ── Standard training ─────────────────────────────────────────────────────────\n# Use this when you have a reliable long-running session.\n\ntrain_nnue(\n    data_file=DATA_PATH,\n    output_file=OUTPUT_FILE,\n    epochs=EPOCHS,\n    batch_size=BATCH_SIZE,\n    checkpoint=CHECKPOINT,\n    use_versioning=True,\n    early_stopping_patience=EARLY_STOPPING,\n    subsample_ratio=SUBSAMPLE_RATIO,\n    hidden_sizes=HIDDEN_SIZES,\n)", "id": "standard-train" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "# ── Burst training (recommended for Colab free tier) ─────────────────────────\n# Restarts from the global best each time early stopping fires.\n# Set BURST_MINUTES slightly below the session length you expect Colab to give you;\n# the default here is 70 min.\n\nBURST_MINUTES = 70\nEPOCHS_PER_SEASON = 30\nBURST_PATIENCE = 8\n\nburst_train(\n    data_file=DATA_PATH,\n    output_file=OUTPUT_FILE,\n    duration_minutes=BURST_MINUTES,\n    epochs_per_season=EPOCHS_PER_SEASON,\n    early_stopping_patience=BURST_PATIENCE,\n    batch_size=BATCH_SIZE,\n    initial_checkpoint=CHECKPOINT,\n    use_versioning=True,\n    subsample_ratio=SUBSAMPLE_RATIO,\n    hidden_sizes=HIDDEN_SIZES,\n)", "id": "burst-train" },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 📦 4 — Export\n", "\n", "Convert the latest versioned `.pt` checkpoint to the `.nbai` binary format read by `NbaiLoader` in Scala." ], "id": "export-md" },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from export import export_to_nbai\n", "\n", "NBAI_FILE = Path(DRIVE_ROOT) / 'nnue_weights.nbai'\n", "\n", "# Pick the latest versioned checkpoint (numeric sort, so v10 comes after v9)\n", "checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'),\n", "                     key=lambda p: int(''.join(filter(str.isdigit, p.stem)) or 0))\n", "if not checkpoints:\n", "    raise FileNotFoundError('No checkpoints found in ' + str(WEIGHTS_DIR))\n", "\n", "latest = checkpoints[-1]\n", "print(f'Exporting {latest.name} → {NBAI_FILE.name}')\n", "\n", "export_to_nbai(\n", "    weights_file=str(latest),\n", "    output_file=str(NBAI_FILE),\n", "    trained_by='colab',\n", ")\n", "print('Export complete.')" ], "id": "export-cell" },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## ⬇️ 5 — Download\n", "\n", "Download the `.nbai` weights file and the latest `.pt` checkpoint to your local machine.\n", "\n", "Place `nnue_weights.nbai` in `modules/official-bots/src/main/resources/` and rebuild the native image."
], "id": "download-md" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from google.colab import files\n", "\n", "if NBAI_FILE.exists():\n", "    files.download(str(NBAI_FILE))\n", "    print(f'Downloading {NBAI_FILE.name}')\n", "else:\n", "    print('No .nbai file found — run the Export cell first.')\n", "\n", "# Numeric sort on the version digits, so v10 comes after v9\n", "checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'),\n", "                     key=lambda p: int(''.join(filter(str.isdigit, p.stem)) or 0))\n", "if checkpoints:\n", "    latest = checkpoints[-1]\n", "    files.download(str(latest))\n", "    print(f'Downloading checkpoint {latest.name}')\n", "else:\n", "    print('No .pt checkpoint found.')" ], "id": "download-cell" } ] }