feat: Improved how NNUE Evaluates

This commit is contained in:
2026-04-13 17:37:24 +02:00
parent ed26406185
commit 5df5a1875f
23 changed files with 438 additions and 292 deletions
@@ -4,7 +4,7 @@ import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.AlphaBetaSearch
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
import de.nowchess.bot.{Bot, BotDifficulty}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
@@ -22,11 +22,36 @@ final class NNUEBot(
override def nextMove(context: GameContext): Option[Move] =
book
.flatMap(_.probe(context))
.orElse(search.bestMoveWithTime(context, allocateTime(context)))
.orElse {
val moves = rules.allLegalMoves(context)
if moves.isEmpty then None
else
val scored = batchEvaluateRoot(context, moves)
val bestMove = scored.maxBy(_._2)._1
search.bestMoveWithTime(context, allocateTime(scored)).orElse(Some(bestMove))
}
/** Allocate more time for complex or critical positions. */
private def allocateTime(context: GameContext): Long =
val moveCount = rules.allLegalMoves(context).length
/** Evaluate all root moves shallowly via incremental NNUE accumulator updates. Returns (move, score) pairs with score
* from the root player's perspective.
*
* @param context
*   the root position the moves are played from
* @param moves
*   the candidate root moves to score; the result has one entry per input move, in the same order
* @return
*   each move paired with the negated accumulator evaluation of the position it leads to
*/
private def batchEvaluateRoot(context: GameContext, moves: List[Move]): List[(Move, Int)] =
// The accumulator is initialised from the root position once, before any child is pushed.
EvaluationNNUE.initAccumulator(context)
val rootHash = ZobristHash.hash(context)
moves.map { move =>
val child = rules.applyMove(context)(move)
// nextHash is given the root hash together with the move and both positions.
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
// Every root move is pushed at ply 1, so each iteration reuses the same ply slot.
EvaluationNNUE.pushAccumulator(1, move, context, child)
// Negated to flip the sign back to the root player's view.
// NOTE(review): presumably evaluateAccumulator scores from the side to move in `child` — confirm.
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
(move, score)
}
/** Allocate more time for complex positions; less when one move clearly dominates. */
private def allocateTime(scored: List[(Move, Int)]): Long =
val moveCount = scored.length
if moveCount > 30 then 1500L
else if moveCount < 5 then 500L
else 1000L
else
val scores = scored.map(_._2)
val best = scores.max
val second = scores.filter(_ < best).maxOption.getOrElse(best)
if best - second > 200 then 600L else 1000L
@@ -249,7 +249,7 @@ object EvaluationClassic extends Evaluation:
@scala.annotation.tailrec
def countRay(current: Option[Square], acc: Int): Int =
current match
case None => acc
case None => acc
case Some(target) =>
board.pieceAt(target) match
case Some(piece) if piece.color == color => acc
@@ -287,10 +287,10 @@ object EvaluationClassic extends Evaluation:
val friendlyHasPair =
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 0) &&
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
val enemyHasPair =
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 0) &&
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
val baseMg = (if friendlyHasPair then bishopPairMg else 0) - (if enemyHasPair then bishopPairMg else 0)
val baseEg = (if friendlyHasPair then bishopPairEg else 0) - (if enemyHasPair then bishopPairEg else 0)
@@ -312,7 +312,7 @@ object EvaluationClassic extends Evaluation:
val kingCentralBonus =
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
val friendlyMaterial = materialCount(context, context.turn)
val enemyMaterial = materialCount(context, context.turn.opposite)
@@ -15,8 +15,7 @@ class NNUE:
// l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
private val l1WeightsT: Array[Float] =
val t = new Array[Float](768 * 1536)
for j <- 0 until 768; i <- 0 until 1536 do
t(j * 1536 + i) = l1Weights(i * 768 + j)
for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j)
t
private def loadWeights(): (
@@ -42,11 +42,14 @@ final class AlphaBetaSearch(
timeLimitMs.set(Long.MaxValue / 4)
nodeCount.set(0)
val rootHash = ZobristHash.hash(context)
(1 to maxDepth).foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
(move.orElse(bestSoFar), score)
}._1
(1 to maxDepth)
.foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
val (alpha, beta) =
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
(move.orElse(bestSoFar), score)
}
._1
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
* runs out.
@@ -64,7 +67,8 @@ final class AlphaBetaSearch(
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
if isOutOfTime then bestSoFar
else
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (alpha, beta) =
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
loop(move.orElse(bestSoFar), score, depth + 1)
@@ -85,15 +89,13 @@ final class AlphaBetaSearch(
@scala.annotation.tailrec
def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
if attempt >= 3 || attempt >= depth then
search(context, depth, 0, -INF, INF, rootHash, repetitions)
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, rootHash, repetitions)
else
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions)
if score > currentAlpha && score < currentBeta then (score, move)
else if score <= currentAlpha then
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
else
loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
else loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
loop(alpha, beta, initialWindow, 0)
@@ -136,10 +138,8 @@ final class AlphaBetaSearch(
repetitions: Map[Long, Int],
): (Int, Option[Move]) =
val count = nodeCount.incrementAndGet()
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
(weights.evaluateAccumulator(ply, context, hash), None)
else if repetitions.getOrElse(hash, 0) >= 3 then
(weights.DRAW_SCORE, None)
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None)
else if repetitions.getOrElse(hash, 0) >= 3 then (weights.DRAW_SCORE, None)
else
val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
entry.flag match
@@ -155,10 +155,8 @@ final class AlphaBetaSearch(
val legalMoves = rules.allLegalMoves(context)
if legalMoves.isEmpty then
(if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then
(weights.DRAW_SCORE, None)
else if depth == 0 then
(quiescence(context, ply, alpha, beta, hash), None)
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then (weights.DRAW_SCORE, None)
else if depth == 0 then (quiescence(context, ply, alpha, beta, hash), None)
else
val nullResult = Option
.when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
@@ -192,7 +190,7 @@ final class AlphaBetaSearch(
): (Option[Move], Int, Boolean) =
if idx >= ordered.length then (bestMove, bestScore, false)
else
val move = ordered(idx)
val move = ordered(idx)
val isQuiet = !isCapture(context, move) &&
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
@@ -233,11 +231,14 @@ final class AlphaBetaSearch(
if newA >= beta then
if isQuiet then
ordering.addHistory(move.from.rank.ordinal * 8 + move.from.file.ordinal, move.to.rank.ordinal * 8 + move.to.file.ordinal, depth * depth)
ordering.addHistory(
move.from.rank.ordinal * 8 + move.from.file.ordinal,
move.to.rank.ordinal * 8 + move.to.file.ordinal,
depth * depth,
)
ordering.addKillerMove(ply, move)
(newBestMove, newBestScore, true)
else
loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
else loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
val (bestMove, bestScore, cutoff) = loop(0, None, -INF, alpha, 0)
val flag =
@@ -262,8 +263,7 @@ final class AlphaBetaSearch(
else
val a0 = if inCheck then alpha else math.max(alpha, standPat)
if ply >= MAX_QUIESCENCE_PLY then
if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
if ply >= MAX_QUIESCENCE_PLY then if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
else
val allMoves = rules.allLegalMoves(context)
val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
@@ -34,5 +34,4 @@ final class TranspositionTable(val sizePow2: Int = 20):
}
def clear(): Unit =
for i <- 0 until size do
locks(i).synchronized { table(i) = None }
for i <- 0 until size do locks(i).synchronized { table(i) = None }
@@ -82,15 +82,15 @@ final class PolyglotBook(path: String):
if isKingMove(context, from) && isRookSquare(to, context) then Some(decodeCastling(from, to))
else
val moveTypeOpt: Option[MoveType] = if promotionBits > 0 then
promotionBits match
case 1 => Some(MoveType.Promotion(PromotionPiece.Knight))
case 2 => Some(MoveType.Promotion(PromotionPiece.Bishop))
case 3 => Some(MoveType.Promotion(PromotionPiece.Rook))
case 4 => Some(MoveType.Promotion(PromotionPiece.Queen))
case _ => None
else
Some(MoveType.Normal(context.board.pieces.contains(to)))
val moveTypeOpt: Option[MoveType] =
if promotionBits > 0 then
promotionBits match
case 1 => Some(MoveType.Promotion(PromotionPiece.Knight))
case 2 => Some(MoveType.Promotion(PromotionPiece.Bishop))
case 3 => Some(MoveType.Promotion(PromotionPiece.Rook))
case 4 => Some(MoveType.Promotion(PromotionPiece.Queen))
case _ => None
else Some(MoveType.Normal(context.board.pieces.contains(to)))
moveTypeOpt.map(moveType => Move(from, to, moveType))
@@ -74,7 +74,7 @@ object ZobristHash:
context.board.pieceAt(move.from).fold(h0) { pawn =>
val capturedSquare = Square(move.to.file, move.from.rank)
val h1 = h0 ^ pieceKey(move.from, pawn)
val h2 = context.board.pieceAt(capturedSquare).fold(h1)(captured => h1 ^ pieceKey(capturedSquare, captured))
val h2 = context.board.pieceAt(capturedSquare).fold(h1)(captured => h1 ^ pieceKey(capturedSquare, captured))
h2 ^ pieceKey(move.to, pawn)
}
@@ -106,9 +106,12 @@ object ZobristHash:
case PromotionPiece.Queen => PieceType.Queen
private def toggleCastling(h0: Long, before: GameContext, after: GameContext): Long =
val h1 = if before.castlingRights.whiteKingSide != after.castlingRights.whiteKingSide then h0 ^ castlingRands(0) else h0
val h2 = if before.castlingRights.whiteQueenSide != after.castlingRights.whiteQueenSide then h1 ^ castlingRands(1) else h1
val h3 = if before.castlingRights.blackKingSide != after.castlingRights.blackKingSide then h2 ^ castlingRands(2) else h2
val h1 =
if before.castlingRights.whiteKingSide != after.castlingRights.whiteKingSide then h0 ^ castlingRands(0) else h0
val h2 =
if before.castlingRights.whiteQueenSide != after.castlingRights.whiteQueenSide then h1 ^ castlingRands(1) else h1
val h3 =
if before.castlingRights.blackKingSide != after.castlingRights.blackKingSide then h2 ^ castlingRands(2) else h2
if before.castlingRights.blackQueenSide != after.castlingRights.blackQueenSide then h3 ^ castlingRands(3) else h3
private def toggleEnPassant(h0: Long, before: GameContext, after: GameContext): Long =