From e2b4342f602215b5e8de6fccafc4105525a1ddd1 Mon Sep 17 00:00:00 2001 From: Janis Eccarius Date: Wed, 24 Jun 2026 22:18:03 +0200 Subject: [PATCH] fix(official-bots): prevent Colab OOM in NNUE training Dense 98304-dim HalfKP features at batch_size=16384 cost ~6.4 GB/batch on the host; with 8 hardcoded DataLoader workers and prefetch this OOM-killed the Colab runtime. - train.py: adaptive DataLoader workers (min(4, cpu_count), Colab free tier = 2), overridable via NNUE_LOADER_WORKERS; persistent_workers only when > 0. - NNUETraining.ipynb: lower BATCH_SIZE 16384 -> 4096 with a memory-cost note. Co-Authored-By: Claude Opus 4.8 --- .../official-bots/python/NNUETraining.ipynb | 23 +------------------ modules/official-bots/python/src/train.py | 13 +++++++---- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/modules/official-bots/python/NNUETraining.ipynb b/modules/official-bots/python/NNUETraining.ipynb index 3fc92e1..ac7f758 100644 --- a/modules/official-bots/python/NNUETraining.ipynb +++ b/modules/official-bots/python/NNUETraining.ipynb @@ -92,28 +92,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n", - "\n", - "WEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\n", - "WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\n", - "OUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n", - "\n", - "# ── Training hyperparameters ──────────────────────────────────────────────────\n", - "HIDDEN_SIZES = DEFAULT_HIDDEN_SIZES # [1536, 1024, 512, 256]\n", - "BATCH_SIZE = 16384\n", - "EPOCHS = 100\n", - "EARLY_STOPPING = 10 # None to disable\n", - "SUBSAMPLE_RATIO = 1.0\n", - "\n", - "# Resume from latest checkpoint if one exists\n", - "checkpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\n", - "CHECKPOINT = str(checkpoints[-1]) if checkpoints else None\n", - "if CHECKPOINT:\n", - " print(f'Resuming from checkpoint: {CHECKPOINT}')\n", - "else:\n", - " print('Starting training from scratch.')" - ], + "source": "from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n\nWEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\nWEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\nOUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n\n# ── Training hyperparameters ──────────────────────────────────────────────────\nHIDDEN_SIZES = DEFAULT_HIDDEN_SIZES\n# fen_to_features builds a DENSE 98304-dim input, so a batch costs\n# batch_size * 98304 * 4 bytes on the host (× DataLoader prefetch). On Colab's\n# ~12 GB RAM keep this small; raise it only if you have headroom.\nBATCH_SIZE = 4096\nEPOCHS = 100\nEARLY_STOPPING = 10 # None to disable\nSUBSAMPLE_RATIO = 1.0\n\n# Resume from latest checkpoint if one exists\ncheckpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\nCHECKPOINT = str(checkpoints[-1]) if checkpoints else None\nif CHECKPOINT:\n print(f'Resuming from checkpoint: {CHECKPOINT}')\nelse:\n print('Starting training from scratch.')", "id": "train-config" }, { diff --git a/modules/official-bots/python/src/train.py b/modules/official-bots/python/src/train.py index fdf633a..b18165f 100644 --- a/modules/official-bots/python/src/train.py +++ b/modules/official-bots/python/src/train.py @@ -13,6 +13,11 @@ import chess from datetime import datetime, timedelta import re import numpy as np +import os + +# DataLoader workers: cap to the machine's CPUs (Colab free tier = 2). Too many +# workers each fork the dataset and OOM-kill the runtime. +LOADER_WORKERS = int(os.environ.get("NNUE_LOADER_WORKERS", min(4, os.cpu_count() or 2))) def _shard_files(data_file): @@ -256,17 +261,17 @@ def _setup_training(data_file, batch_size, subsample_ratio): train_dataset, batch_size=batch_size, sampler=train_sampler, - num_workers=8, + num_workers=LOADER_WORKERS, pin_memory=True, - persistent_workers=True + persistent_workers=LOADER_WORKERS > 0 ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, - num_workers=8, + num_workers=LOADER_WORKERS, pin_memory=True, - persistent_workers=True + persistent_workers=LOADER_WORKERS > 0 ) return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions