feat: Enhance NNUE evaluation with incremental updates and validation checks
This commit is contained in:
@@ -23,7 +23,11 @@ object EvaluationNNUE extends Evaluation:
|
||||
nnue.copyAccumulator(parentPly, childPly)
|
||||
|
||||
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
|
||||
nnue.pushAccumulator(childPly, move, parent.board)
|
||||
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
|
||||
if (childPly % 10 == 0) then
|
||||
nnue.recomputeAccumulator(childPly, child.board)
|
||||
else
|
||||
nnue.pushAccumulator(childPly, move, parent.board)
|
||||
|
||||
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
|
||||
nnue.evaluateAtPly(ply, context.turn, hash)
|
||||
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
|
||||
|
||||
@@ -8,6 +8,7 @@ class NNUE(model: NbaiModel):
|
||||
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||
|
||||
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||
@@ -67,12 +68,40 @@ class NNUE(model: NbaiModel):
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||
|
||||
def recomputeAccumulator(ply: Int, board: Board): Unit =
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
def validateAccumulator(ply: Int, board: Board): Boolean =
|
||||
// Compute what L1 should be from scratch
|
||||
val expectedL1 = new Array[Float](accSize)
|
||||
System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// Compare with actual L1
|
||||
val actual = l1Stack(ply)
|
||||
var maxError = 0f
|
||||
for i <- 0 until accSize do
|
||||
val error = math.abs(actual(i) - expectedL1(i))
|
||||
if error > maxError then maxError = error
|
||||
|
||||
maxError < 0.001f // Allow small floating-point errors
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
// Extract source and destination square indices early
|
||||
val fromNum = squareNum(move.from)
|
||||
val toNum = squareNum(move.to)
|
||||
|
||||
// Get the moving piece
|
||||
board.pieceAt(move.from).foreach { mover =>
|
||||
val fromNum = squareNum(move.from)
|
||||
val toNum = squareNum(move.to)
|
||||
subtractColumn(l1, featureIndex(mover, fromNum))
|
||||
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||
|
||||
// If there's a capture, subtract the captured piece
|
||||
board.pieceAt(move.to).foreach { cap =>
|
||||
subtractColumn(l1, featureIndex(cap, toNum))
|
||||
}
|
||||
|
||||
// Add the piece to its new location
|
||||
addColumn(l1, featureIndex(mover, toNum))
|
||||
}
|
||||
|
||||
@@ -123,6 +152,14 @@ class NNUE(model: NbaiModel):
|
||||
evalCacheScores(idx) = score
|
||||
score
|
||||
|
||||
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
|
||||
// For debugging: validate that incremental accumulator matches recomputation
|
||||
if validateAccum && ply > 0 && ply % 10 != 0 then
|
||||
val isValid = validateAccumulator(ply, board)
|
||||
if !isValid then
|
||||
System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||
evaluateAtPly(ply, turn, hash)
|
||||
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
val l1ReLU = evalBuffers(0)
|
||||
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
|
||||
Reference in New Issue
Block a user