feat: NCS-43 NNUE high-impact performance optimisations

Implements three high-impact improvements from NCS-43:

1. Incremental L1 accumulator — L1 pre-activations are maintained
   per-ply on a stack and updated by column-add/subtract when pieces
   move, reducing L1 cost from O(768×1536) to O(pieces×1536) for root
   init and O(changed_pieces×1536) per node. Column-major weight
   transposition (l1WeightsT) makes each update sequential in memory.

2. Eval cache — a 256 K direct-mapped cache keyed by Zobrist hash
   skips the L2-L5 forward pass for positions seen during quiescence,
   where the same position often recurs across capture sequences.

3. Dynamic time allocation — NNUEBot now grants 1 500 ms in complex
   positions (> 30 legal moves), 500 ms in nearly-terminal ones
   (< 5 legal moves), and 1 000 ms otherwise, instead of a fixed
   1 000 ms budget.

Accumulator hooks (initAccumulator, pushAccumulator, copyAccumulator,
evaluateAccumulator) are added to the Evaluation trait with default
implementations: the first three are no-ops and evaluateAccumulator
falls back to evaluate, so ClassicalBot and HybridBot are unaffected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Janis
2026-04-12 22:30:03 +02:00
committed by Janis
parent 1b5759828b
commit 3750931251
5 changed files with 450 additions and 356 deletions
@@ -0,0 +1,29 @@
package de.nowchess.bot.ai
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
trait Evaluation:

  /** Base score for a checkmate. The search returns `-(CHECKMATE_SCORE - ply)` for the mated side. */
  def CHECKMATE_SCORE: Int

  /** Score returned for drawn positions (stalemate, repetition, insufficient material, fifty-move rule). */
  def DRAW_SCORE: Int

  /** Full static evaluation of `context`, computed from scratch. */
  def evaluate(context: GameContext): Int

  // ── Accumulator hooks ─────────────────────────────────────────────────────
  // Evaluators without incremental state can ignore these: the three
  // bookkeeping hooks default to no-ops and evaluateAccumulator delegates to
  // evaluate. NNUE-capable evaluators override them for the incremental L1 speedup.

  /** Initialise the accumulator for the root position at ply 0.
    *
    * @param context root position of the search
    */
  def initAccumulator(context: GameContext): Unit = ()

  /** Copy the accumulator of `parentPly` to `childPly` without applying any move delta (null-move).
    *
    * @param parentPly ply whose accumulator is the source
    * @param childPly  ply whose accumulator is overwritten
    */
  def copyAccumulator(parentPly: Int, childPly: Int): Unit = ()

  /** Derive the accumulator of `childPly` from its parent ply by applying the deltas of `move`.
    *
    * @param childPly ply of the position reached by `move`
    * @param move     move leading from `parent` to `child`
    * @param parent   position before the move
    * @param child    position after the move
    */
  def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = ()

  /** Evaluate from the pre-computed accumulator at `ply`, using `hash` as the eval-cache key.
    *
    * When not overridden this ignores `ply` and `hash` and performs a full [[evaluate]].
    *
    * @param ply     ply whose accumulator should be evaluated
    * @param context position at that ply
    * @param hash    hash of `context`, used as the cache key by overriding evaluators
    */
  def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
    evaluate(context)
@@ -12,14 +12,21 @@ import de.nowchess.rules.sets.DefaultRules
final class NNUEBot( final class NNUEBot(
difficulty: BotDifficulty, difficulty: BotDifficulty,
rules: RuleSet = DefaultRules, rules: RuleSet = DefaultRules,
book: Option[PolyglotBook] = None book: Option[PolyglotBook] = None,
) extends Bot: ) extends Bot:
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE) private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE)
private val TIME_BUDGET_MS = 1000L
override val name: String = s"NNUEBot(${difficulty.toString})" override val name: String = s"NNUEBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] = override def nextMove(context: GameContext): Option[Move] =
book.flatMap(_.probe(context)) book
.orElse(search.bestMoveWithTime(context, TIME_BUDGET_MS)) .flatMap(_.probe(context))
.orElse(search.bestMoveWithTime(context, allocateTime(context)))
/** Choose the search budget in milliseconds from the number of legal moves in `context`:
  * 1500 ms with more than 30 moves, 500 ms with fewer than 5, and 1000 ms otherwise.
  */
private def allocateTime(context: GameContext): Long =
  rules.allLegalMoves(context).length match
    case n if n > 30 => 1500L
    case n if n < 5  => 500L
    case _           => 1000L
@@ -1,16 +1,29 @@
package de.nowchess.bot.bots.nnue package de.nowchess.bot.bots.nnue
import de.nowchess.api.game.GameContext import de.nowchess.api.game.GameContext
import de.nowchess.bot.ai.Weights import de.nowchess.api.move.Move
import de.nowchess.bot.ai.Evaluation
object EvaluationNNUE extends Weights: object EvaluationNNUE extends Evaluation:
private val nnue = NNUE() private val nnue = NNUE()
val CHECKMATE_SCORE: Int = 10_000_000 val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
/** Evaluate the position using NNUE neural network. /** Full-board evaluate — used as fallback and by non-search callers. */
* Returns score from the perspective of context.turn (positive = good for the side to move). */ def evaluate(context: GameContext): Int = nnue.evaluate(context)
def evaluate(context: GameContext): Int =
nnue.evaluate(context) // ── Accumulator hooks (incremental L1) ───────────────────────────────────
override def initAccumulator(context: GameContext): Unit =
nnue.initAccumulator(context.board)
override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
nnue.copyAccumulator(parentPly, childPly)
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
nnue.evaluateAtPly(ply, context.turn, hash)
@@ -1,34 +1,49 @@
package de.nowchess.bot.bots.nnue package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, PieceType, Rank, Square} import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.ByteOrder import java.nio.ByteOrder
class NNUE: class NNUE:
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights() private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) =
loadWeights()
private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) = // Column-major L1 weights for cache-friendly sparse & incremental updates.
val stream = getClass.getResourceAsStream("/nnue_weights.bin") // l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
if stream == null then private val l1WeightsT: Array[Float] =
throw RuntimeException("NNUE weights file not found in resources") val t = new Array[Float](768 * 1536)
for j <- 0 until 768; i <- 0 until 1536 do
t(j * 1536 + i) = l1Weights(i * 768 + j)
t
private def loadWeights(): (
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
Array[Float],
) =
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
.getOrElse(sys.error("NNUE weights file not found in resources"))
try try
val bytes = stream.readAllBytes() val bytes = stream.readAllBytes()
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN) val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
// Read and verify magic number
val magic = buffer.getInt() val magic = buffer.getInt()
if magic != 0x4555_4e4e then // "NNUE" in little-endian if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}")
throw RuntimeException(s"Invalid magic number: 0x${magic.toHexString}")
// Read version
val version = buffer.getInt() val version = buffer.getInt()
if version != 1 then if version != 1 then sys.error(s"Unsupported weight version: $version")
throw RuntimeException(s"Unsupported weight version: $version")
// Read all weight tensors in order
val l1w = readTensor(buffer) val l1w = readTensor(buffer)
val l1b = readTensor(buffer) val l1b = readTensor(buffer)
val l2w = readTensor(buffer) val l2w = readTensor(buffer)
@@ -44,124 +59,186 @@ class NNUE:
finally stream.close() finally stream.close()
private def readTensor(buffer: ByteBuffer): Array[Float] = private def readTensor(buffer: ByteBuffer): Array[Float] =
// Read shape
val shapeLen = buffer.getInt() val shapeLen = buffer.getInt()
val shape = Array.ofDim[Int](shapeLen) val shape = Array.ofDim[Int](shapeLen)
for i <- 0 until shapeLen do for i <- 0 until shapeLen do shape(i) = buffer.getInt()
shape(i) = buffer.getInt()
// Calculate total elements
val totalElements = shape.product val totalElements = shape.product
// Read float data
val floats = Array.ofDim[Float](totalElements) val floats = Array.ofDim[Float](totalElements)
for i <- 0 until totalElements do for i <- 0 until totalElements do floats(i) = buffer.getFloat()
floats(i) = buffer.getFloat()
floats floats
// Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1) // ── Accumulator stack ────────────────────────────────────────────────────
private val features = new Array[Float](768) // l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply.
private val l1Output = new Array[Float](1536) // Initialised once at root; each child ply is derived incrementally.
private val MAX_PLY = 128
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536))
// Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant).
private val l1ReLU = new Array[Float](1536)
private val l2Output = new Array[Float](1024) private val l2Output = new Array[Float](1024)
private val l3Output = new Array[Float](512) private val l3Output = new Array[Float](512)
private val l4Output = new Array[Float](256) private val l4Output = new Array[Float](256)
/** Convert a position to 768-dimensional binary feature vector. // ── Eval cache ───────────────────────────────────────────────────────────
* Layout matches training: black pieces at indices 0-5, white at 6-11. private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB
* feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */ private val evalCacheHashes = new Array[Long](1 << 18)
private def positionToFeatures(board: Board): Array[Float] = private val evalCacheScores = new Array[Int](1 << 18)
java.util.Arrays.fill(features, 0f)
for // ── Feature helpers ──────────────────────────────────────────────────────
fileIdx <- 0 until 8
rankIdx <- 0 until 8
do
val file = File.values(fileIdx)
val rank = Rank.values(rankIdx)
val square = Square(file, rank)
val squareNum = rankIdx * 8 + fileIdx
board.pieceAt(square).foreach { piece => private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
// black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
private def featureIndex(piece: Piece, sqNum: Int): Int =
val colorOffset = if piece.color == Color.White then 6 else 0 val colorOffset = if piece.color == Color.White then 6 else 0
val pieceIdx = colorOffset + piece.pieceType.ordinal (colorOffset + piece.pieceType.ordinal) * 64 + sqNum
val featureIdx = pieceIdx * 64 + squareNum
features(featureIdx) = 1f private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
val offset = featureIdx * 1536
for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i)
/** Remove feature `featureIdx` from the accumulator: subtracts that feature's 1536-entry column of
  * `l1WeightsT` from `l1Pre` in place.
  */
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
  val base = featureIdx * 1536
  var k = 0
  while k < 1536 do
    l1Pre(k) -= l1WeightsT(base + k)
    k += 1
// ── Accumulator init ─────────────────────────────────────────────────────
/** Rebuild the ply-0 accumulator from scratch: start from the L1 bias, then add one weight column
  * for every piece present on `board`.
  */
def initAccumulator(board: Board): Unit =
  val root = l1Stack(0)
  System.arraycopy(l1Bias, 0, root, 0, 1536)
  board.pieces.foreach { case (sq, piece) =>
    addColumn(root, featureIndex(piece, squareNum(sq)))
  }
// ── Accumulator push (incremental updates) ───────────────────────────────
/** Derive the accumulator for `childPly` from the one at `childPly - 1`.
  *
  * The parent's pre-activations are copied first, then the feature deltas of `move` are applied.
  * `board` must be the position BEFORE the move: the delta helpers look up the moving and captured
  * pieces on it. `l1Stack` holds MAX_PLY + 1 entries, so a `childPly` beyond MAX_PLY fails with an
  * ArrayIndexOutOfBoundsException.
  */
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
  val parentAcc = l1Stack(childPly - 1)
  val childAcc  = l1Stack(childPly)
  System.arraycopy(parentAcc, 0, childAcc, 0, 1536)
  move.moveType match
    case MoveType.Promotion(p) =>
      applyPromotionDelta(childAcc, move, p, board)
    case MoveType.CastleKingside | MoveType.CastleQueenside =>
      applyCastleDelta(childAcc, move, board)
    case MoveType.EnPassant =>
      applyEnPassantDelta(childAcc, move, board)
    case MoveType.Normal(_) =>
      applyNormalDelta(childAcc, move, board)
/** Null-move helper: duplicate the pre-activations stored at `parentPly` into the slot for `childPly`,
  * applying no move delta.
  */
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
  val src = l1Stack(parentPly)
  val dst = l1Stack(childPly)
  System.arraycopy(src, 0, dst, 0, 1536)
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { mover =>
val fromNum = squareNum(move.from)
val toNum = squareNum(move.to)
subtractColumn(l1, featureIndex(mover, fromNum))
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
addColumn(l1, featureIndex(mover, toNum))
} }
features private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { pawn =>
val capturedSq = Square(move.to.file, move.from.rank)
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
}
/** Run NNUE inference on the given position. private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
* Returns centipawn score from the perspective of the side-to-move. board.pieceAt(move.from).foreach { king =>
* No allocations in the hot path (uses pre-allocated buffers). val rank = move.from.rank
* Architecture: 768→1536→1024→512→256→1 */ val kingside = move.moveType == MoveType.CastleKingside
def evaluate(context: GameContext): Int = val (rookFrom, rookTo) =
val features = positionToFeatures(context.board) if kingside then (Square(File.H, rank), Square(File.F, rank))
else (Square(File.A, rank), Square(File.D, rank))
val rook = Piece(king.color, PieceType.Rook)
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
addColumn(l1, featureIndex(king, squareNum(move.to)))
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
}
// Layer 1: Dense(768 → 1536) + ReLU private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
for i <- 0 until 1536 do board.pieceAt(move.from).foreach { pawn =>
var sum = l1Bias(i) val toNum = squareNum(move.to)
for j <- 0 until 768 do subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
sum += features(j) * l1Weights(i * 768 + j) board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
l1Output(i) = if sum > 0f then sum else 0f addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
}
// Layer 2: Dense(1536 → 1024) + ReLU private def promotedType(promo: PromotionPiece): PieceType = promo match
for i <- 0 until 1024 do case PromotionPiece.Knight => PieceType.Knight
var sum = l2Bias(i) case PromotionPiece.Bishop => PieceType.Bishop
for j <- 0 until 1536 do case PromotionPiece.Rook => PieceType.Rook
sum += l1Output(j) * l2Weights(i * 1536 + j) case PromotionPiece.Queen => PieceType.Queen
l2Output(i) = if sum > 0f then sum else 0f
// Layer 3: Dense(1024 → 512) + ReLU // ── Evaluation from accumulator ──────────────────────────────────────────
for i <- 0 until 512 do
var sum = l3Bias(i)
for j <- 0 until 1024 do
sum += l2Output(j) * l3Weights(i * 1024 + j)
l3Output(i) = if sum > 0f then sum else 0f
// Layer 4: Dense(512 → 256) + ReLU /** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after
for i <- 0 until 256 do * computation.
var sum = l4Bias(i) */
for j <- 0 until 512 do def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
sum += l3Output(j) * l4Weights(i * 512 + j) val idx = (hash & EVAL_CACHE_MASK).toInt
l4Output(i) = if sum > 0f then sum else 0f if evalCacheHashes(idx) == hash then evalCacheScores(idx)
else
val score = runL2toOutput(l1Stack(ply), turn)
evalCacheHashes(idx) = hash
evalCacheScores(idx) = score
score
// Layer 5: Dense(256 → 1), no activation private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
var output = l5Bias(0) for i <- 0 until 1536 do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
for j <- 0 until 256 do runDenseReLU(l1ReLU, 1536, l2Weights, l2Bias, l2Output, 1024)
output += l4Output(j) * l5Weights(j) runDenseReLU(l2Output, 1024, l3Weights, l3Bias, l3Output, 512)
runDenseReLU(l3Output, 512, l4Weights, l4Bias, l4Output, 256)
val output = runOutputLayer(l4Output, 256)
scoreFromOutput(output, turn)
// Convert from tanh-normalized output back to centipawns. private def runDenseReLU(
// Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective. input: Array[Float],
// Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective. inSize: Int,
val cp = if math.abs(output) >= 0.9999f then weights: Array[Float],
if output > 0f then 20000 else -20000 bias: Array[Float],
output: Array[Float],
outSize: Int,
): Unit =
for i <- 0 until outSize do
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
output(i) = if sum > 0f then sum else 0f
/** Final dense layer (inSize → 1, no activation).
  *
  * Computes `l5Bias(0) + Σ input(j) * l5Weights(j)` for j in [0, inSize).
  *
  * Uses a primitive `while` loop instead of `Range.foldLeft`: the generic fold boxes the Float
  * accumulator on every step, which allocates in the inference path. Terms are added in the same
  * left-to-right order, so the result is bit-identical.
  *
  * @param input  activations of the previous layer
  * @param inSize number of leading entries of `input` to consume
  * @return raw network output before conversion to centipawns
  */
private def runOutputLayer(input: Array[Float], inSize: Int): Float =
  var acc = l5Bias(0)
  var j   = 0
  while j < inSize do
    acc += input(j) * l5Weights(j)
    j += 1
  acc
private def scoreFromOutput(output: Float, turn: Color): Int =
val cp =
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
else else
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
(300f * atanh).toInt (300f * atanh).toInt
val cpFromTurn = if turn == Color.Black then -cp else cp
val cpFromTurn = if context.turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn)) math.max(-20000, math.min(20000, cpFromTurn))
/** Benchmark: time 1M evaluations and report ns/eval. // ── Legacy full-board evaluate (kept for Evaluation.evaluate compatibility) ──
* This measures the performance of the inference on the starting position. */
// Pre-allocated buffers for the legacy evaluate path. NOTE(review): evaluate only touches legacyL1; features looks unused now — confirm before removing.
private val features = new Array[Float](768)
private val legacyL1 = new Array[Float](1536)
/** Full-board evaluation that rebuilds the L1 pre-activations from scratch.
  *
  * Starts from the L1 bias, adds one weight column per piece on the board (feature layout: black
  * pieces at piece indices 0-5, white at 6-11), then runs the dense layers. This path goes straight
  * to `runL2toOutput`, so it neither reads nor fills the eval cache.
  */
def evaluate(context: GameContext): Int =
  System.arraycopy(l1Bias, 0, legacyL1, 0, 1536)
  context.board.pieces.foreach { case (sq, piece) =>
    addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
  }
  runL2toOutput(legacyL1, context.turn)
/** Benchmark: time 1M evaluations and report ns/eval. */
def benchmark(): Unit = def benchmark(): Unit =
val context = GameContext.initial val context = GameContext.initial
val iterations = 1_000_000 val iterations = 1_000_000
for _ <- 0 until 10000 do evaluate(context)
// Warm up
for _ <- 0 until 10000 do
evaluate(context)
// Actual benchmark
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
for _ <- 0 until iterations do for _ <- 0 until iterations do evaluate(context)
evaluate(context)
val endNanos = System.nanoTime() val endNanos = System.nanoTime()
val totalNanos = endNanos - startNanos val totalNanos = endNanos - startNanos
val nanosPerEval = totalNanos.toDouble / iterations val nanosPerEval = totalNanos.toDouble / iterations
println() println()
println("=" * 60) println("=" * 60)
println("NNUE BENCHMARK RESULTS") println("NNUE BENCHMARK RESULTS")
@@ -3,17 +3,18 @@ package de.nowchess.bot.logic
import de.nowchess.api.board.PieceType import de.nowchess.api.board.PieceType
import de.nowchess.api.game.GameContext import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType} import de.nowchess.api.move.{Move, MoveType}
import de.nowchess.bot.ai.Weights import de.nowchess.bot.ai.Evaluation
import de.nowchess.bot.logic.{MoveOrdering, TTEntry, TTFlag, TranspositionTable} import de.nowchess.bot.logic.{MoveOrdering, TTEntry, TTFlag, TranspositionTable}
import de.nowchess.bot.util.ZobristHash import de.nowchess.bot.util.ZobristHash
import de.nowchess.rules.RuleSet import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
final class AlphaBetaSearch( final class AlphaBetaSearch(
rules: RuleSet = DefaultRules, rules: RuleSet = DefaultRules,
tt: TranspositionTable = TranspositionTable(), tt: TranspositionTable = TranspositionTable(),
weights: Weights, weights: Evaluation,
numThreads: Int = Runtime.getRuntime.availableProcessors numThreads: Int = Runtime.getRuntime.availableProcessors,
): ):
private val INF = Int.MaxValue / 2 private val INF = Int.MaxValue / 2
@@ -25,63 +26,52 @@ final class AlphaBetaSearch(
private val FUTILITY_MARGIN = 100 private val FUTILITY_MARGIN = 100
private val CHECK_EXTENSION = 1 private val CHECK_EXTENSION = 1
@volatile private var timeStartMs = 0L private val timeStartMs = AtomicLong(0L)
@volatile private var timeLimitMs = 0L private val timeLimitMs = AtomicLong(0L)
@volatile private var nodeCount = 0 private val nodeCount = AtomicInteger(0)
private val ordering = MoveOrdering.OrderingContext() private val ordering = MoveOrdering.OrderingContext()
/** Return the best move for the side to move, searching to maxDepth plies. /** Return the best move for the side to move, searching to maxDepth plies. Uses iterative deepening with aspiration
* Uses iterative deepening with aspiration windows. */ * windows.
*/
def bestMove(context: GameContext, maxDepth: Int): Option[Move] = def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
tt.clear() tt.clear()
ordering.clear() ordering.clear()
timeStartMs = System.currentTimeMillis weights.initAccumulator(context)
timeLimitMs = Long.MaxValue / 4 timeStartMs.set(System.currentTimeMillis)
nodeCount = 0 timeLimitMs.set(Long.MaxValue / 4)
var bestSoFar: Option[Move] = None nodeCount.set(0)
var prevScore = 0
var aspWindow = ASPIRATION_DELTA
for depth <- 1 to maxDepth do
val (alpha, beta) =
if depth == 1 then (-INF, INF)
else (prevScore - aspWindow, prevScore + aspWindow)
val rootHash = ZobristHash.hash(context) val rootHash = ZobristHash.hash(context)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow, rootHash) (1 to maxDepth).foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
prevScore = score val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
move.foreach(m => bestSoFar = Some(m)) val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
aspWindow = ASPIRATION_DELTA (move.orElse(bestSoFar), score)
bestSoFar }._1
/** Return the best move for the side to move within a time budget (ms). /** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
* Uses iterative deepening, stopping when time runs out. */ * runs out.
*/
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] = def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
tt.clear() tt.clear()
ordering.clear() ordering.clear()
timeStartMs = System.currentTimeMillis weights.initAccumulator(context)
timeLimitMs = timeBudgetMs timeStartMs.set(System.currentTimeMillis)
nodeCount = 0 timeLimitMs.set(timeBudgetMs)
var bestSoFar: Option[Move] = None nodeCount.set(0)
var prevScore = 0
var depth = 1
var aspWindow = ASPIRATION_DELTA
while !isOutOfTime do
val (alpha, beta) =
if depth == 1 then (-INF, INF)
else (prevScore - aspWindow, prevScore + aspWindow)
val rootHash = ZobristHash.hash(context) val rootHash = ZobristHash.hash(context)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow, rootHash)
prevScore = score @scala.annotation.tailrec
move match def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
case Some(m) => if isOutOfTime then bestSoFar
bestSoFar = Some(m) else
case None => val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
aspWindow = ASPIRATION_DELTA val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
depth += 1 loop(move.orElse(bestSoFar), score, depth + 1)
bestSoFar
loop(None, 0, 1)
private def isOutOfTime: Boolean = private def isOutOfTime: Boolean =
System.currentTimeMillis - timeStartMs >= timeLimitMs System.currentTimeMillis - timeStartMs.get >= timeLimitMs.get
private def searchWithAspiration( private def searchWithAspiration(
context: GameContext, context: GameContext,
@@ -89,27 +79,23 @@ final class AlphaBetaSearch(
alpha: Int, alpha: Int,
beta: Int, beta: Int,
initialWindow: Int, initialWindow: Int,
rootHash: Long rootHash: Long,
): (Int, Option[Move]) = ): (Int, Option[Move]) =
var currentAlpha = alpha
var currentBeta = beta
var window = initialWindow
var attempt = 0
val repetitions = Map(rootHash -> 1) val repetitions = Map(rootHash -> 1)
while attempt < 3 && attempt < depth do @scala.annotation.tailrec
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions) def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
if score > currentAlpha && score < currentBeta then if attempt >= 3 || attempt >= depth then
return (score, move)
if score <= currentAlpha then
currentAlpha = score - window
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
if score >= currentBeta then
currentBeta = score + window
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
attempt += 1
search(context, depth, 0, -INF, INF, rootHash, repetitions) search(context, depth, 0, -INF, INF, rootHash, repetitions)
else
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions)
if score > currentAlpha && score < currentBeta then (score, move)
else if score <= currentAlpha then
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
else
loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
loop(alpha, beta, initialWindow, 0)
private def hasNonPawnMaterial(context: GameContext): Boolean = private def hasNonPawnMaterial(context: GameContext): Boolean =
context.board.pieces.values.exists { piece => context.board.pieces.values.exists { piece =>
@@ -126,7 +112,7 @@ final class AlphaBetaSearch(
depth: Int, depth: Int,
ply: Int, ply: Int,
beta: Int, beta: Int,
repetitions: Map[Long, Int] repetitions: Map[Long, Int],
): Option[Int] = ): Option[Int] =
val nullCtx = nullMoveContext(context) val nullCtx = nullMoveContext(context)
val nullHash = ZobristHash.hash(nullCtx) val nullHash = ZobristHash.hash(nullCtx)
@@ -135,6 +121,7 @@ final class AlphaBetaSearch(
case None => Some(1) case None => Some(1)
} }
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R) val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions) val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions)
if -score >= beta then Some(beta) else None if -score >= beta then Some(beta) else None
@@ -146,58 +133,44 @@ final class AlphaBetaSearch(
alpha: Int, alpha: Int,
beta: Int, beta: Int,
hash: Long, hash: Long,
repetitions: Map[Long, Int] repetitions: Map[Long, Int],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
// Periodic time check val count = nodeCount.incrementAndGet()
nodeCount += 1 if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
if nodeCount % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None)
return (weights.evaluate(context), None) else if repetitions.getOrElse(hash, 0) >= 3 then
(weights.DRAW_SCORE, None)
if repetitions.getOrElse(hash, 0) >= 3 then else
return (weights.DRAW_SCORE, None) val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
// TT probe
tt.probe(hash) match
case Some(entry) if entry.depth >= depth =>
entry.flag match entry.flag match
case TTFlag.Exact => return (entry.score, entry.bestMove) case TTFlag.Exact => Some((entry.score, entry.bestMove))
case TTFlag.Lower => case TTFlag.Lower =>
val newAlpha = math.max(alpha, entry.score) val newAlpha = math.max(alpha, entry.score)
if newAlpha >= beta then return (entry.score, entry.bestMove) Option.when(newAlpha >= beta)((entry.score, entry.bestMove))
case TTFlag.Upper => case TTFlag.Upper =>
val newBeta = math.min(beta, entry.score) val newBeta = math.min(beta, entry.score)
if alpha >= newBeta then return (entry.score, entry.bestMove) Option.when(alpha >= newBeta)((entry.score, entry.bestMove))
case _ => }
ttCutoff.getOrElse {
// Terminal node check
val legalMoves = rules.allLegalMoves(context) val legalMoves = rules.allLegalMoves(context)
if legalMoves.isEmpty then if legalMoves.isEmpty then
val score = if rules.isCheckmate(context) then (if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
-(weights.CHECKMATE_SCORE - ply) else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then
(weights.DRAW_SCORE, None)
else if depth == 0 then
(quiescence(context, ply, alpha, beta, hash), None)
else else
weights.DRAW_SCORE val nullResult = Option
return (score, None) .when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
tryNullMove(context, depth, ply, beta, repetitions)
if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then }
return (weights.DRAW_SCORE, None) .flatten
nullResult.map((_, None)).getOrElse {
// Leaf node: call quiescence
if depth == 0 then
return (quiescence(context, ply, alpha, beta), None)
// Null move pruning
if depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context) then
tryNullMove(context, depth, ply, beta, repetitions) match
case Some(score) => return (score, None)
case None =>
// Get TT best move for ordering
val ttBest = tt.probe(hash).flatMap(_.bestMove) val ttBest = tt.probe(hash).flatMap(_.bestMove)
// Order moves
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering) val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions) searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions)
}
}
private def searchSequential( private def searchSequential(
context: GameContext, context: GameContext,
@@ -207,32 +180,30 @@ final class AlphaBetaSearch(
beta: Int, beta: Int,
ordered: List[Move], ordered: List[Move],
hash: Long, hash: Long,
repetitions: Map[Long, Int] repetitions: Map[Long, Int],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
var bestMove: Option[Move] = None @scala.annotation.tailrec
var bestScore = -INF def loop(
var a = alpha idx: Int,
var moveNumber = 0 bestMove: Option[Move],
var cutoff = false bestScore: Int,
a: Int,
var idx = 0 moveNumber: Int,
while idx < ordered.length && !cutoff do ): (Option[Move], Int, Boolean) =
if idx >= ordered.length then (bestMove, bestScore, false)
else
val move = ordered(idx) val move = ordered(idx)
idx += 1
moveNumber += 1
val isQuiet = !isCapture(context, move) && val isQuiet = !isCapture(context, move) &&
move.moveType != MoveType.CastleKingside && move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside move.moveType != MoveType.CastleQueenside
val pruneByFutility = depth == 1 && isQuiet && moveNumber > 2 &&
weights.evaluateAccumulator(ply, context, hash) + FUTILITY_MARGIN < alpha
// Futility pruning at frontier nodes: if static eval + margin is still below alpha, skip quiet moves if pruneByFutility then loop(idx + 1, bestMove, bestScore, a, moveNumber + 1)
val pruneByFutility = if depth == 1 && isQuiet && moveNumber > 2 then else
val staticEval = weights.evaluate(context)
staticEval + FUTILITY_MARGIN < alpha
else false
if !pruneByFutility then
val child = rules.applyMove(context)(move) val child = rules.applyMove(context)(move)
val childHash = ZobristHash.nextHash(context, hash, move, child) val childHash = ZobristHash.nextHash(context, hash, move, child)
weights.pushAccumulator(ply + 1, move, context, child)
val childRepetitions = repetitions.updatedWith(childHash) { val childRepetitions = repetitions.updatedWith(childHash) {
case Some(v) => Some(v + 1) case Some(v) => Some(v + 1)
case None => Some(1) case None => Some(1)
@@ -241,7 +212,8 @@ final class AlphaBetaSearch(
val extension = if givesCheck then CHECK_EXTENSION else 0 val extension = if givesCheck then CHECK_EXTENSION else 0
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0 val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
val score = if reduction > 0 then val score =
if reduction > 0 then
val reducedDepth = math.max(0, depth - 1 - reduction + extension) val reducedDepth = math.max(0, depth - 1 - reduction + extension)
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions) val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions)
val s = -reducedScore val s = -reducedScore
@@ -255,20 +227,19 @@ final class AlphaBetaSearch(
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions) val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions)
-rawScore -rawScore
if score > bestScore then val newBestScore = math.max(bestScore, score)
bestScore = score val newBestMove = if score > bestScore then Some(move) else bestMove
bestMove = Some(move) val newA = math.max(a, score)
a = math.max(a, score) if newA >= beta then
if a >= beta then
if isQuiet then if isQuiet then
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal ordering.addHistory(move.from.rank.ordinal * 8 + move.from.file.ordinal, move.to.rank.ordinal * 8 + move.to.file.ordinal, depth * depth)
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
ordering.addHistory(fromIdx, toIdx, depth * depth)
ordering.addKillerMove(ply, move) ordering.addKillerMove(ply, move)
cutoff = true (newBestMove, newBestScore, true)
else
loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
val (bestMove, bestScore, cutoff) = loop(0, None, -INF, alpha, 0)
val flag = val flag =
if cutoff then TTFlag.Lower if cutoff then TTFlag.Lower
else if bestScore <= alpha then TTFlag.Upper else if bestScore <= alpha then TTFlag.Upper
@@ -281,43 +252,40 @@ final class AlphaBetaSearch(
context: GameContext, context: GameContext,
ply: Int, ply: Int,
alpha: Int, alpha: Int,
beta: Int beta: Int,
hash: Long,
): Int = ): Int =
val inCheck = rules.isCheck(context) val inCheck = rules.isCheck(context)
val standPat = if inCheck then -INF else weights.evaluate(context) val standPat = if inCheck then -INF else weights.evaluateAccumulator(ply, context, hash)
if !inCheck && standPat >= beta then
return beta
var a = if inCheck then alpha else math.max(alpha, standPat) if !inCheck && standPat >= beta then beta
else
val a0 = if inCheck then alpha else math.max(alpha, standPat)
// Guard against infinite quiescence
if ply >= MAX_QUIESCENCE_PLY then if ply >= MAX_QUIESCENCE_PLY then
return if inCheck then weights.evaluate(context) else standPat if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
else
// Generate captures, or all evasions when side-to-move is in check.
val allMoves = rules.allLegalMoves(context) val allMoves = rules.allLegalMoves(context)
val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m)) val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
if inCheck && tacticalMoves.isEmpty then if inCheck && tacticalMoves.isEmpty then -(weights.CHECKMATE_SCORE - ply)
return -(weights.CHECKMATE_SCORE - ply) else
val ordered = MoveOrdering.sort(context, tacticalMoves, None) val ordered = MoveOrdering.sort(context, tacticalMoves, None)
var cutoff = false @scala.annotation.tailrec
var idx = 0 def loop(idx: Int, a: Int): Int =
while idx < ordered.length && !cutoff do if idx >= ordered.length then a
val move = ordered(idx)
idx += 1
val child = rules.applyMove(context)(move)
val score = -quiescence(child, ply + 1, -beta, -a)
if score >= beta then
a = beta
cutoff = true
else else
a = math.max(a, score) val move = ordered(idx)
a val child = rules.applyMove(context)(move)
val childHash = ZobristHash.nextHash(context, hash, move, child)
weights.pushAccumulator(ply + 1, move, context, child)
val score = -quiescence(child, ply + 1, -beta, -a, childHash)
if score >= beta then beta
else loop(idx + 1, math.max(a, score))
loop(0, a0)
/** Predicate: context-aware capture classification. */
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
case MoveType.Normal(true) => true case MoveType.Normal(true) => true
case MoveType.EnPassant => true case MoveType.EnPassant => true