feat: Enhance NNUE evaluation with incremental updates and validation checks

This commit is contained in:
2026-04-13 21:30:26 +02:00
parent 8fb872e958
commit c88159ecec
2 changed files with 46 additions and 5 deletions
@@ -23,7 +23,11 @@ object EvaluationNNUE extends Evaluation:
nnue.copyAccumulator(parentPly, childPly)
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
  // Incremental NNUE update for the child ply. A stray pre-refactor call to
  // nnue.pushAccumulator before the branch has been removed: it pushed the
  // accumulator unconditionally, so on non-multiple-of-10 plies the move's
  // feature delta was applied twice.
  // Recompute from scratch every 10 plies to stop floating-point drift from
  // accumulating across many incremental updates.
  if childPly % 10 == 0 then
    nnue.recomputeAccumulator(childPly, child.board)
  else
    nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
  // Delegate to the validating variant only; the plain nnue.evaluateAtPly call
  // that preceded it was the pre-refactor line — its result was discarded, so
  // the position was evaluated twice (and the eval cache consulted twice) per call.
  nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
@@ -8,6 +8,7 @@ class NNUE(model: NbaiModel):
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
// Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
@@ -67,12 +68,40 @@ class NNUE(model: NbaiModel):
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
  // Seed the child ply's L1 accumulator with the parent's values; Array.copy
  // delegates to System.arraycopy, so this is the same bulk copy of accSize floats.
  Array.copy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
def recomputeAccumulator(ply: Int, board: Board): Unit =
  // Rebuild the L1 accumulator for this ply from scratch: start from the
  // layer-0 bias, then add one weight column per piece on the board.
  val acc = l1Stack(ply)
  Array.copy(model.weights(0).bias, 0, acc, 0, accSize)
  board.pieces.foreach { case (sq, piece) =>
    addColumn(acc, featureIndex(piece, squareNum(sq)))
  }
def validateAccumulator(ply: Int, board: Board): Boolean =
  // Rebuild what the L1 accumulator should be from scratch (bias + one weight
  // column per piece), then check the incrementally maintained values against it.
  val expected = new Array[Float](accSize)
  Array.copy(model.weights(0).bias, 0, expected, 0, accSize)
  board.pieces.foreach { case (sq, piece) =>
    addColumn(expected, featureIndex(piece, squareNum(sq)))
  }
  val current = l1Stack(ply)
  // Equivalent to "max absolute error < threshold": every element must be
  // within the tolerance that absorbs small floating-point drift.
  (0 until accSize).forall(i => math.abs(current(i) - expected(i)) < 0.001f)
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
  // Apply a normal (non-special) move's feature delta to the L1 accumulator.
  // The diff-merged duplicates have been removed: fromNum/toNum were extracted
  // twice, and the captured piece's column was subtracted twice — which would
  // have silently corrupted the accumulator on every capture.
  val fromNum = squareNum(move.from)
  val toNum = squareNum(move.to)
  board.pieceAt(move.from).foreach { mover =>
    // Remove the mover from its origin square.
    subtractColumn(l1, featureIndex(mover, fromNum))
    // If the destination is occupied, remove the captured piece exactly once.
    board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
    // Place the mover on its destination square.
    addColumn(l1, featureIndex(mover, toNum))
  }
@@ -123,6 +152,14 @@ class NNUE(model: NbaiModel):
evalCacheScores(idx) = score
score
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
  // Debug hook: when NNUE_VALIDATE is set, cross-check the incrementally
  // maintained accumulator against a from-scratch recomputation. Plies that are
  // multiples of 10 are skipped — those were just recomputed, so there is
  // nothing incremental to verify. Root (ply 0) is skipped as well.
  val shouldValidate = validateAccum && ply > 0 && ply % 10 != 0
  if shouldValidate && !validateAccumulator(ply, board) then
    System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
  // Validation only warns; the evaluation itself is unchanged either way.
  evaluateAtPly(ply, turn, hash)
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f