From 9d656624d85889f55746faa5704578e248f9b088 Mon Sep 17 00:00:00 2001 From: Janis Eccarius Date: Wed, 24 Jun 2026 22:28:53 +0200 Subject: [PATCH] fix(official-bots): stream NNUE features as sparse indices to stop host OOM Densifying the 98304-dim HalfKP vector per item filled host RAM and crashed the Colab runtime even at small batch sizes. The dataset now yields only the ~64 active feature indices; a custom collate carries (row, col) pairs and the training loop scatters them into a dense [B, INPUT_SIZE] tensor on the GPU. Host RAM stays tiny; GPU holds one dense batch transiently. - NNUEDataset.__getitem__ returns indices via new fen_to_indices. - fen_to_features now derives from fen_to_indices (kept for external callers). - _collate_sparse builds row/col index batches; loaders use it. - train/val loops scatter to a GPU dense batch; loss weighting uses batch size. - Notebook: BATCH_SIZE 4096 -> 8192 (host no longer the limit; GPU is). Co-Authored-By: Claude Opus 4.8 --- .../official-bots/python/NNUETraining.ipynb | 2 +- modules/official-bots/python/src/train.py | 66 +++++++++++++------ 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/modules/official-bots/python/NNUETraining.ipynb b/modules/official-bots/python/NNUETraining.ipynb index ac7f758..690e82a 100644 --- a/modules/official-bots/python/NNUETraining.ipynb +++ b/modules/official-bots/python/NNUETraining.ipynb @@ -92,7 +92,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n\nWEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\nWEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\nOUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n\n# ── Training hyperparameters ──────────────────────────────────────────────────\nHIDDEN_SIZES = DEFAULT_HIDDEN_SIZES\n# fen_to_features builds a DENSE 98304-dim input, so a batch costs\n# batch_size * 98304 * 4 bytes on the host (× DataLoader prefetch). On Colab's\n# ~12 GB RAM keep this small; raise it only if you have headroom.\nBATCH_SIZE = 4096\nEPOCHS = 100\nEARLY_STOPPING = 10 # None to disable\nSUBSAMPLE_RATIO = 1.0\n\n# Resume from latest checkpoint if one exists\ncheckpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\nCHECKPOINT = str(checkpoints[-1]) if checkpoints else None\nif CHECKPOINT:\n print(f'Resuming from checkpoint: {CHECKPOINT}')\nelse:\n print('Starting training from scratch.')", + "source": "from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES\n\nWEIGHTS_DIR = Path(DRIVE_ROOT) / 'weights'\nWEIGHTS_DIR.mkdir(parents=True, exist_ok=True)\nOUTPUT_FILE = str(WEIGHTS_DIR / 'nnue_weights.pt')\n\n# ── Training hyperparameters ──────────────────────────────────────────────────\nHIDDEN_SIZES = DEFAULT_HIDDEN_SIZES\n# Features are streamed as sparse indices and densified on the GPU per batch, so\n# host RAM is no longer the limit — GPU memory is. A dense batch is\n# batch_size * 98304 * 4 bytes on the GPU (~3.2 GB at 8192 on a 16 GB T4).\nBATCH_SIZE = 8192\nEPOCHS = 100\nEARLY_STOPPING = 10 # None to disable\nSUBSAMPLE_RATIO = 1.0\n\n# Resume from latest checkpoint if one exists\ncheckpoints = sorted(WEIGHTS_DIR.glob('nnue_weights_v*.pt'))\nCHECKPOINT = str(checkpoints[-1]) if checkpoints else None\nif CHECKPOINT:\n print(f'Resuming from checkpoint: {CHECKPOINT}')\nelse:\n print('Starting training from scratch.')", "id": "train-config" }, { diff --git a/modules/official-bots/python/src/train.py b/modules/official-bots/python/src/train.py index b18165f..7e1fe17 100644 --- a/modules/official-bots/python/src/train.py +++ b/modules/official-bots/python/src/train.py @@ -82,9 +82,12 @@ class NNUEDataset(Dataset): def __getitem__(self, idx): fen = self.positions[idx] eval_val = self.evals[idx] - features = fen_to_features(fen) + # Return only the active feature indices (~64), not a dense 98304-dim vector. + # The training loop scatters these into a dense batch on the GPU, keeping host + # RAM tiny. Densifying per-item here OOM-kills the runtime. + indices = fen_to_indices(fen) - # Board is flipped for Black-to-move in fen_to_features; negate eval + # Board is flipped for Black-to-move in fen_to_indices; negate eval # so the label still means "good for the side shown as White after flip" if ' b ' in fen: eval_val = -eval_val @@ -95,7 +98,7 @@ class NNUEDataset(Dataset): else: target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32)) - return features, target + return indices, target # King-relative (HalfKP) encoding: two perspectives, one per side's king. # Each piece is encoded as: kingSq * 768 + pieceIdx * 64 + sq @@ -110,14 +113,14 @@ _PIECE_TO_IDX = { } -def fen_to_features(fen): - """Convert FEN to 98304-dim king-relative (HalfKP) feature vector. +def fen_to_indices(fen): + """Active king-relative (HalfKP) feature indices for a FEN (~64 entries). For Black-to-move positions the board is mirrored (ranks flipped, colours swapped) so the network always sees the position from the side-to-move's - perspective. The caller is responsible for negating the eval label to match. + perspective. The caller is responsible for negating the eval label to match. """ - features = torch.zeros(INPUT_SIZE, dtype=torch.float32) + indices = [] try: board = chess.Board(fen) # Perspective flip: present all positions as if White is to move @@ -126,20 +129,41 @@ def fen_to_features(fen): wk = board.king(chess.WHITE) bk = board.king(chess.BLACK) if wk is None or bk is None: - return features + return torch.zeros(0, dtype=torch.long) for sq in chess.SQUARES: piece = board.piece_at(sq) if piece is None: continue pidx = _PIECE_TO_IDX[piece.symbol()] # White-king perspective (indices 0 .. _HALF_SIZE-1) - features[wk * 768 + pidx * 64 + sq] = 1.0 + indices.append(wk * 768 + pidx * 64 + sq) # Black-king perspective (indices _HALF_SIZE .. INPUT_SIZE-1) - features[_HALF_SIZE + bk * 768 + pidx * 64 + sq] = 1.0 + indices.append(_HALF_SIZE + bk * 768 + pidx * 64 + sq) except Exception: - pass + return torch.zeros(0, dtype=torch.long) + return torch.tensor(indices, dtype=torch.long) + + +def fen_to_features(fen): + """Dense 98304-dim HalfKP vector. Kept for external callers; training uses the + sparse indices + GPU scatter path instead (see _collate_sparse).""" + features = torch.zeros(INPUT_SIZE, dtype=torch.float32) + features[fen_to_indices(fen)] = 1.0 return features + +def _collate_sparse(batch): + """Collate (indices, target) items into (row_idx, col_idx, batch_size), targets. + + Row/col index pairs address the active features of a dense [B, INPUT_SIZE] tensor + that the training loop allocates on the GPU — so the host only ever holds the + sparse indices, never a dense batch.""" + idx_list, targets = zip(*batch) + rows = torch.cat([torch.full((idx.numel(),), i, dtype=torch.long) + for i, idx in enumerate(idx_list)]) + cols = torch.cat(idx_list) + return (rows, cols, len(idx_list)), torch.stack(targets) + # Smaller hidden layers are appropriate: the L1 input is very sparse (~64 active # features out of 98304) so the L1 itself is cheap to update incrementally; the # larger capacity comes from the wider perspective encoding, not deeper layers. @@ -263,7 +287,8 @@ def _setup_training(data_file, batch_size, subsample_ratio): sampler=train_sampler, num_workers=LOADER_WORKERS, pin_memory=True, - persistent_workers=LOADER_WORKERS > 0 + persistent_workers=LOADER_WORKERS > 0, + collate_fn=_collate_sparse, ) val_loader = DataLoader( val_dataset, @@ -271,7 +296,8 @@ def _setup_training(data_file, batch_size, subsample_ratio): shuffle=False, num_workers=LOADER_WORKERS, pin_memory=True, - persistent_workers=LOADER_WORKERS > 0 + persistent_workers=LOADER_WORKERS > 0, + collate_fn=_collate_sparse, ) return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions @@ -309,8 +335,9 @@ def _run_training_season( model.train() train_loss = 0.0 with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar: - for batch_features, batch_targets in train_loader: - batch_features = batch_features.to(device) + for (rows, cols, bsz), batch_targets in train_loader: + batch_features = torch.zeros(bsz, INPUT_SIZE, device=device) + batch_features[rows.to(device), cols.to(device)] = 1.0 batch_targets = batch_targets.to(device).unsqueeze(1) optimizer.zero_grad() @@ -323,7 +350,7 @@ def _run_training_season( scaler.step(optimizer) scaler.update() - train_loss += loss.item() * batch_features.size(0) + train_loss += loss.item() * bsz pbar.update(1) train_loss /= len(train_dataset) @@ -333,14 +360,15 @@ def _run_training_season( val_loss = 0.0 with torch.no_grad(): with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar: - for batch_features, batch_targets in val_loader: - batch_features = batch_features.to(device) + for (rows, cols, bsz), batch_targets in val_loader: + batch_features = torch.zeros(bsz, INPUT_SIZE, device=device) + batch_features[rows.to(device), cols.to(device)] = 1.0 batch_targets = batch_targets.to(device).unsqueeze(1) with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'): outputs = model(batch_features) loss = criterion(outputs, batch_targets) - val_loss += loss.item() * batch_features.size(0) + val_loss += loss.item() * bsz pbar.update(1) val_loss /= len(val_dataset)