feat: Improved how NNUE Evaluates
This commit is contained in:
@@ -4,7 +4,7 @@ import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
|
||||
import de.nowchess.bot.{Bot, BotDifficulty}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
@@ -22,11 +22,36 @@ final class NNUEBot(
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
book
|
||||
.flatMap(_.probe(context))
|
||||
.orElse(search.bestMoveWithTime(context, allocateTime(context)))
|
||||
.orElse {
|
||||
val moves = rules.allLegalMoves(context)
|
||||
if moves.isEmpty then None
|
||||
else
|
||||
val scored = batchEvaluateRoot(context, moves)
|
||||
val bestMove = scored.maxBy(_._2)._1
|
||||
search.bestMoveWithTime(context, allocateTime(scored)).orElse(Some(bestMove))
|
||||
}
|
||||
|
||||
/** Allocate more time for complex or critical positions. */
|
||||
private def allocateTime(context: GameContext): Long =
|
||||
val moveCount = rules.allLegalMoves(context).length
|
||||
/** Evaluate all root moves shallowly via incremental NNUE accumulator updates. Returns (move, score) pairs with score
|
||||
* from the root player's perspective.
|
||||
*/
|
||||
private def batchEvaluateRoot(context: GameContext, moves: List[Move]): List[(Move, Int)] =
|
||||
EvaluationNNUE.initAccumulator(context)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
moves.map { move =>
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
|
||||
EvaluationNNUE.pushAccumulator(1, move, context, child)
|
||||
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
|
||||
(move, score)
|
||||
}
|
||||
|
||||
/** Allocate more time for complex positions; less when one move clearly dominates. */
|
||||
private def allocateTime(scored: List[(Move, Int)]): Long =
|
||||
val moveCount = scored.length
|
||||
if moveCount > 30 then 1500L
|
||||
else if moveCount < 5 then 500L
|
||||
else 1000L
|
||||
else
|
||||
val scores = scored.map(_._2)
|
||||
val best = scores.max
|
||||
val second = scores.filter(_ < best).maxOption.getOrElse(best)
|
||||
if best - second > 200 then 600L else 1000L
|
||||
|
||||
@@ -249,7 +249,7 @@ object EvaluationClassic extends Evaluation:
|
||||
@scala.annotation.tailrec
|
||||
def countRay(current: Option[Square], acc: Int): Int =
|
||||
current match
|
||||
case None => acc
|
||||
case None => acc
|
||||
case Some(target) =>
|
||||
board.pieceAt(target) match
|
||||
case Some(piece) if piece.color == color => acc
|
||||
@@ -287,10 +287,10 @@ object EvaluationClassic extends Evaluation:
|
||||
|
||||
val friendlyHasPair =
|
||||
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 0) &&
|
||||
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
|
||||
friendlyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
|
||||
val enemyHasPair =
|
||||
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 0) &&
|
||||
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
|
||||
enemyBishops.exists((sq, _) => (sq.file.ordinal + sq.rank.ordinal) % 2 == 1)
|
||||
|
||||
val baseMg = (if friendlyHasPair then bishopPairMg else 0) - (if enemyHasPair then bishopPairMg else 0)
|
||||
val baseEg = (if friendlyHasPair then bishopPairEg else 0) - (if enemyHasPair then bishopPairEg else 0)
|
||||
@@ -312,7 +312,7 @@ object EvaluationClassic extends Evaluation:
|
||||
|
||||
val kingCentralBonus =
|
||||
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
|
||||
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
|
||||
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
|
||||
|
||||
val friendlyMaterial = materialCount(context, context.turn)
|
||||
val enemyMaterial = materialCount(context, context.turn.opposite)
|
||||
|
||||
@@ -15,8 +15,7 @@ class NNUE:
|
||||
// l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
|
||||
private val l1WeightsT: Array[Float] =
|
||||
val t = new Array[Float](768 * 1536)
|
||||
for j <- 0 until 768; i <- 0 until 1536 do
|
||||
t(j * 1536 + i) = l1Weights(i * 768 + j)
|
||||
for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j)
|
||||
t
|
||||
|
||||
private def loadWeights(): (
|
||||
|
||||
@@ -42,11 +42,14 @@ final class AlphaBetaSearch(
|
||||
timeLimitMs.set(Long.MaxValue / 4)
|
||||
nodeCount.set(0)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
(1 to maxDepth).foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
|
||||
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
|
||||
(move.orElse(bestSoFar), score)
|
||||
}._1
|
||||
(1 to maxDepth)
|
||||
.foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
|
||||
(move.orElse(bestSoFar), score)
|
||||
}
|
||||
._1
|
||||
|
||||
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
|
||||
* runs out.
|
||||
@@ -64,7 +67,8 @@ final class AlphaBetaSearch(
|
||||
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
|
||||
if isOutOfTime then bestSoFar
|
||||
else
|
||||
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
|
||||
loop(move.orElse(bestSoFar), score, depth + 1)
|
||||
|
||||
@@ -85,15 +89,13 @@ final class AlphaBetaSearch(
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
|
||||
if attempt >= 3 || attempt >= depth then
|
||||
search(context, depth, 0, -INF, INF, rootHash, repetitions)
|
||||
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, rootHash, repetitions)
|
||||
else
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions)
|
||||
if score > currentAlpha && score < currentBeta then (score, move)
|
||||
else if score <= currentAlpha then
|
||||
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
else
|
||||
loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
else loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
|
||||
loop(alpha, beta, initialWindow, 0)
|
||||
|
||||
@@ -136,10 +138,8 @@ final class AlphaBetaSearch(
|
||||
repetitions: Map[Long, Int],
|
||||
): (Int, Option[Move]) =
|
||||
val count = nodeCount.incrementAndGet()
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
|
||||
(weights.evaluateAccumulator(ply, context, hash), None)
|
||||
else if repetitions.getOrElse(hash, 0) >= 3 then
|
||||
(weights.DRAW_SCORE, None)
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None)
|
||||
else if repetitions.getOrElse(hash, 0) >= 3 then (weights.DRAW_SCORE, None)
|
||||
else
|
||||
val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
|
||||
entry.flag match
|
||||
@@ -155,10 +155,8 @@ final class AlphaBetaSearch(
|
||||
val legalMoves = rules.allLegalMoves(context)
|
||||
if legalMoves.isEmpty then
|
||||
(if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
|
||||
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then
|
||||
(weights.DRAW_SCORE, None)
|
||||
else if depth == 0 then
|
||||
(quiescence(context, ply, alpha, beta, hash), None)
|
||||
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then (weights.DRAW_SCORE, None)
|
||||
else if depth == 0 then (quiescence(context, ply, alpha, beta, hash), None)
|
||||
else
|
||||
val nullResult = Option
|
||||
.when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
|
||||
@@ -192,7 +190,7 @@ final class AlphaBetaSearch(
|
||||
): (Option[Move], Int, Boolean) =
|
||||
if idx >= ordered.length then (bestMove, bestScore, false)
|
||||
else
|
||||
val move = ordered(idx)
|
||||
val move = ordered(idx)
|
||||
val isQuiet = !isCapture(context, move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
@@ -233,11 +231,14 @@ final class AlphaBetaSearch(
|
||||
|
||||
if newA >= beta then
|
||||
if isQuiet then
|
||||
ordering.addHistory(move.from.rank.ordinal * 8 + move.from.file.ordinal, move.to.rank.ordinal * 8 + move.to.file.ordinal, depth * depth)
|
||||
ordering.addHistory(
|
||||
move.from.rank.ordinal * 8 + move.from.file.ordinal,
|
||||
move.to.rank.ordinal * 8 + move.to.file.ordinal,
|
||||
depth * depth,
|
||||
)
|
||||
ordering.addKillerMove(ply, move)
|
||||
(newBestMove, newBestScore, true)
|
||||
else
|
||||
loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
|
||||
else loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
|
||||
|
||||
val (bestMove, bestScore, cutoff) = loop(0, None, -INF, alpha, 0)
|
||||
val flag =
|
||||
@@ -262,8 +263,7 @@ final class AlphaBetaSearch(
|
||||
else
|
||||
val a0 = if inCheck then alpha else math.max(alpha, standPat)
|
||||
|
||||
if ply >= MAX_QUIESCENCE_PLY then
|
||||
if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
|
||||
if ply >= MAX_QUIESCENCE_PLY then if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
|
||||
else
|
||||
val allMoves = rules.allLegalMoves(context)
|
||||
val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
|
||||
|
||||
@@ -34,5 +34,4 @@ final class TranspositionTable(val sizePow2: Int = 20):
|
||||
}
|
||||
|
||||
def clear(): Unit =
|
||||
for i <- 0 until size do
|
||||
locks(i).synchronized { table(i) = None }
|
||||
for i <- 0 until size do locks(i).synchronized { table(i) = None }
|
||||
|
||||
@@ -82,15 +82,15 @@ final class PolyglotBook(path: String):
|
||||
|
||||
if isKingMove(context, from) && isRookSquare(to, context) then Some(decodeCastling(from, to))
|
||||
else
|
||||
val moveTypeOpt: Option[MoveType] = if promotionBits > 0 then
|
||||
promotionBits match
|
||||
case 1 => Some(MoveType.Promotion(PromotionPiece.Knight))
|
||||
case 2 => Some(MoveType.Promotion(PromotionPiece.Bishop))
|
||||
case 3 => Some(MoveType.Promotion(PromotionPiece.Rook))
|
||||
case 4 => Some(MoveType.Promotion(PromotionPiece.Queen))
|
||||
case _ => None
|
||||
else
|
||||
Some(MoveType.Normal(context.board.pieces.contains(to)))
|
||||
val moveTypeOpt: Option[MoveType] =
|
||||
if promotionBits > 0 then
|
||||
promotionBits match
|
||||
case 1 => Some(MoveType.Promotion(PromotionPiece.Knight))
|
||||
case 2 => Some(MoveType.Promotion(PromotionPiece.Bishop))
|
||||
case 3 => Some(MoveType.Promotion(PromotionPiece.Rook))
|
||||
case 4 => Some(MoveType.Promotion(PromotionPiece.Queen))
|
||||
case _ => None
|
||||
else Some(MoveType.Normal(context.board.pieces.contains(to)))
|
||||
|
||||
moveTypeOpt.map(moveType => Move(from, to, moveType))
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ object ZobristHash:
|
||||
context.board.pieceAt(move.from).fold(h0) { pawn =>
|
||||
val capturedSquare = Square(move.to.file, move.from.rank)
|
||||
val h1 = h0 ^ pieceKey(move.from, pawn)
|
||||
val h2 = context.board.pieceAt(capturedSquare).fold(h1)(captured => h1 ^ pieceKey(capturedSquare, captured))
|
||||
val h2 = context.board.pieceAt(capturedSquare).fold(h1)(captured => h1 ^ pieceKey(capturedSquare, captured))
|
||||
h2 ^ pieceKey(move.to, pawn)
|
||||
}
|
||||
|
||||
@@ -106,9 +106,12 @@ object ZobristHash:
|
||||
case PromotionPiece.Queen => PieceType.Queen
|
||||
|
||||
private def toggleCastling(h0: Long, before: GameContext, after: GameContext): Long =
|
||||
val h1 = if before.castlingRights.whiteKingSide != after.castlingRights.whiteKingSide then h0 ^ castlingRands(0) else h0
|
||||
val h2 = if before.castlingRights.whiteQueenSide != after.castlingRights.whiteQueenSide then h1 ^ castlingRands(1) else h1
|
||||
val h3 = if before.castlingRights.blackKingSide != after.castlingRights.blackKingSide then h2 ^ castlingRands(2) else h2
|
||||
val h1 =
|
||||
if before.castlingRights.whiteKingSide != after.castlingRights.whiteKingSide then h0 ^ castlingRands(0) else h0
|
||||
val h2 =
|
||||
if before.castlingRights.whiteQueenSide != after.castlingRights.whiteQueenSide then h1 ^ castlingRands(1) else h1
|
||||
val h3 =
|
||||
if before.castlingRights.blackKingSide != after.castlingRights.blackKingSide then h2 ^ castlingRands(2) else h2
|
||||
if before.castlingRights.blackQueenSide != after.castlingRights.blackQueenSide then h3 ^ castlingRands(3) else h3
|
||||
|
||||
private def toggleEnPassant(h0: Long, before: GameContext, after: GameContext): Long =
|
||||
|
||||
Reference in New Issue
Block a user