From c88159ececc6276bc4f749753c4f903d7aac0b62 Mon Sep 17 00:00:00 2001
From: Janis
Date: Mon, 13 Apr 2026 21:30:26 +0200
Subject: [PATCH] feat: Enhance NNUE evaluation with incremental updates and
 validation checks

---
 .../bot/bots/nnue/EvaluationNNUE.scala        |  8 +++-
 .../de/nowchess/bot/bots/nnue/NNUE.scala      | 43 +++++++++++++++++--
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala
index 478a021..e1957b6 100644
--- a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala
@@ -23,7 +23,11 @@ object EvaluationNNUE extends Evaluation:
     nnue.copyAccumulator(parentPly, childPly)
 
   override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
-    nnue.pushAccumulator(childPly, move, parent.board)
+    // Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
+    if (childPly % 10 == 0) then
+      nnue.recomputeAccumulator(childPly, child.board)
+    else
+      nnue.pushAccumulator(childPly, move, parent.board)
 
   override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
-    nnue.evaluateAtPly(ply, context.turn, hash)
+    nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala
index 3a23d62..efc1c8f 100644
--- a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala
+++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala
@@ -8,6 +8,7 @@ class NNUE(model: NbaiModel):
   private val featureSize = model.layers(0).inputSize
   private val accSize = model.layers(0).outputSize
+  private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
 
   // Column-major L1 weights for cache-friendly sparse & incremental updates.
   // l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
 
@@ -67,12 +68,40 @@ class NNUE(model: NbaiModel):
   def copyAccumulator(parentPly: Int, childPly: Int): Unit =
     System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
 
+  def recomputeAccumulator(ply: Int, board: Board): Unit =
+    System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
+    for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
+
+  def validateAccumulator(ply: Int, board: Board): Boolean =
+    // Compute what L1 should be from scratch
+    val expectedL1 = new Array[Float](accSize)
+    System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
+    for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
+
+    // Compare with actual L1
+    val actual = l1Stack(ply)
+    var maxError = 0f
+    for i <- 0 until accSize do
+      val error = math.abs(actual(i) - expectedL1(i))
+      if error > maxError then maxError = error
+
+    maxError < 0.001f // Allow small floating-point errors
+
   private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
+    // Extract source and destination square indices early
+    val fromNum = squareNum(move.from)
+    val toNum = squareNum(move.to)
+
+    // Get the moving piece
     board.pieceAt(move.from).foreach { mover =>
-      val fromNum = squareNum(move.from)
-      val toNum = squareNum(move.to)
       subtractColumn(l1, featureIndex(mover, fromNum))
-      board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
+
+      // If there's a capture, subtract the captured piece
+      board.pieceAt(move.to).foreach { cap =>
+        subtractColumn(l1, featureIndex(cap, toNum))
+      }
+
+      // Add the piece to its new location
       addColumn(l1, featureIndex(mover, toNum))
     }
 
@@ -123,6 +152,14 @@ class NNUE(model: NbaiModel):
       evalCacheScores(idx) = score
       score
 
+  def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
+    // For debugging: validate that incremental accumulator matches recomputation
+    if validateAccum && ply > 0 && ply % 10 != 0 then
+      val isValid = validateAccumulator(ply, board)
+      if !isValid then
+        System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
+    evaluateAtPly(ply, turn, hash)
+
   private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
     val l1ReLU = evalBuffers(0)
     for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f