feat: NCS-43 NNUE high-impact performance optimisations
Implements three high-impact improvements from NCS-43: 1. Incremental L1 accumulator — L1 pre-activations are maintained per-ply on a stack and updated by column-add/subtract when pieces move, reducing L1 cost from O(768×1536) to O(pieces×1536) for root init and O(changed_pieces×1536) per node. Column-major weight transposition (l1WeightsT) makes each update sequential in memory. 2. Eval cache — a 256 K direct-mapped cache keyed by Zobrist hash skips the L2-L5 forward pass for positions seen during quiescence, where the same position often recurs across capture sequences. 3. Dynamic time allocation — NNUEBot now grants 1 500 ms in complex positions (> 30 legal moves) and 500 ms in nearly-terminal ones (< 5 moves), instead of a fixed 1 000 ms budget. Accumulator hooks (initAccumulator, pushAccumulator, copyAccumulator, evaluateAccumulator) are added to the Evaluation trait with no-op defaults, so ClassicalBot and HybridBot are unaffected. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
package de.nowchess.bot.ai
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
trait Evaluation:

  /** Score magnitude used for checkmate (adjusted by ply distance in the search). */
  def CHECKMATE_SCORE: Int

  /** Score returned for drawn positions. */
  def DRAW_SCORE: Int

  /** Statically evaluate the position, positive = good for the side to move. */
  def evaluate(context: GameContext): Int

  // ── Accumulator hooks ─────────────────────────────────────────────────────
  // Default implementations fall back to full re-evaluation each call.
  // Override in NNUE-capable evaluators for incremental L1 speedup.

  /** Initialise the accumulator for the root position at ply 0. */
  def initAccumulator(context: GameContext): Unit = ()

  /** Copy parent ply's accumulator to childPly without move deltas (null-move). */
  def copyAccumulator(parentPly: Int, childPly: Int): Unit = ()

  /** Derive childPly's accumulator from parentPly by applying move deltas. */
  def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = ()

  /** Evaluate from the pre-computed accumulator at ply, using hash for the eval cache. Falls back to full evaluate when
    * not overridden.
    */
  def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int = evaluate(context)
|
||||
@@ -10,16 +10,23 @@ import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
final class NNUEBot(
    difficulty: BotDifficulty,
    rules: RuleSet = DefaultRules,
    book: Option[PolyglotBook] = None,
) extends Bot:

  private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE)

  override val name: String = s"NNUEBot(${difficulty.toString})"

  /** Pick the next move: opening book first, otherwise a time-limited search with a dynamic budget. */
  override def nextMove(context: GameContext): Option[Move] =
    book
      .flatMap(_.probe(context))
      .orElse(search.bestMoveWithTime(context, allocateTime(context)))

  /** Allocate more time for complex or critical positions. */
  private def allocateTime(context: GameContext): Long =
    val branching = rules.allLegalMoves(context).length
    if branching > 30 then 1500L
    else if branching < 5 then 500L
    else 1000L
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.bot.ai.Weights
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationNNUE extends Evaluation:

  private val nnue = NNUE()

  val CHECKMATE_SCORE: Int = 10_000_000
  val DRAW_SCORE: Int = 0

  /** Full-board evaluate — used as fallback and by non-search callers. */
  def evaluate(context: GameContext): Int = nnue.evaluate(context)

  // ── Accumulator hooks (incremental L1) ───────────────────────────────────

  /** Seed the ply-0 accumulator from the root board. */
  override def initAccumulator(context: GameContext): Unit =
    nnue.initAccumulator(context.board)

  /** Duplicate a ply's accumulator unchanged (used for null-move search). */
  override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
    nnue.copyAccumulator(parentPly, childPly)

  /** Apply the move's feature deltas on top of the parent ply's accumulator. */
  override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
    nnue.pushAccumulator(childPly, move, parent.board)

  /** Evaluate from the accumulator at ply, consulting the eval cache keyed by hash. */
  override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
    nnue.evaluateAtPly(ply, context.turn, hash)
|
||||
|
||||
@@ -1,34 +1,49 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.board.{Board, Color, File, PieceType, Rank, Square}
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
class NNUE:
|
||||
|
||||
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights()
|
||||
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) =
|
||||
loadWeights()
|
||||
|
||||
private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
|
||||
val stream = getClass.getResourceAsStream("/nnue_weights.bin")
|
||||
if stream == null then
|
||||
throw RuntimeException("NNUE weights file not found in resources")
|
||||
  // Column-major L1 weights for cache-friendly sparse & incremental updates.
  // l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
  // i.e. all 1536 weights touched by one input feature become contiguous, so a
  // single feature toggle is one sequential 6 KB scan instead of 1536 strided reads.
  private val l1WeightsT: Array[Float] =
    val t = new Array[Float](768 * 1536)
    for j <- 0 until 768; i <- 0 until 1536 do
      t(j * 1536 + i) = l1Weights(i * 768 + j)
    t
|
||||
|
||||
private def loadWeights(): (
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
Array[Float],
|
||||
) =
|
||||
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||
.getOrElse(sys.error("NNUE weights file not found in resources"))
|
||||
|
||||
try
|
||||
val bytes = stream.readAllBytes()
|
||||
val bytes = stream.readAllBytes()
|
||||
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
// Read and verify magic number
|
||||
val magic = buffer.getInt()
|
||||
if magic != 0x4555_4e4e then // "NNUE" in little-endian
|
||||
throw RuntimeException(s"Invalid magic number: 0x${magic.toHexString}")
|
||||
if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}")
|
||||
|
||||
// Read version
|
||||
val version = buffer.getInt()
|
||||
if version != 1 then
|
||||
throw RuntimeException(s"Unsupported weight version: $version")
|
||||
if version != 1 then sys.error(s"Unsupported weight version: $version")
|
||||
|
||||
// Read all weight tensors in order
|
||||
val l1w = readTensor(buffer)
|
||||
val l1b = readTensor(buffer)
|
||||
val l2w = readTensor(buffer)
|
||||
@@ -44,124 +59,186 @@ class NNUE:
|
||||
finally stream.close()
|
||||
|
||||
private def readTensor(buffer: ByteBuffer): Array[Float] =
|
||||
// Read shape
|
||||
val shapeLen = buffer.getInt()
|
||||
val shape = Array.ofDim[Int](shapeLen)
|
||||
for i <- 0 until shapeLen do
|
||||
shape(i) = buffer.getInt()
|
||||
|
||||
// Calculate total elements
|
||||
val shape = Array.ofDim[Int](shapeLen)
|
||||
for i <- 0 until shapeLen do shape(i) = buffer.getInt()
|
||||
val totalElements = shape.product
|
||||
|
||||
// Read float data
|
||||
val floats = Array.ofDim[Float](totalElements)
|
||||
for i <- 0 until totalElements do
|
||||
floats(i) = buffer.getFloat()
|
||||
val floats = Array.ofDim[Float](totalElements)
|
||||
for i <- 0 until totalElements do floats(i) = buffer.getFloat()
|
||||
floats
|
||||
|
||||
// Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1)
|
||||
private val features = new Array[Float](768)
|
||||
private val l1Output = new Array[Float](1536)
|
||||
// ── Accumulator stack ────────────────────────────────────────────────────
|
||||
// l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply.
|
||||
// Initialised once at root; each child ply is derived incrementally.
|
||||
|
||||
private val MAX_PLY = 128
|
||||
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536))
|
||||
|
||||
// Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant).
|
||||
private val l1ReLU = new Array[Float](1536)
|
||||
private val l2Output = new Array[Float](1024)
|
||||
private val l3Output = new Array[Float](512)
|
||||
private val l4Output = new Array[Float](256)
|
||||
|
||||
/** Convert a position to 768-dimensional binary feature vector.
|
||||
* Layout matches training: black pieces at indices 0-5, white at 6-11.
|
||||
* feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */
|
||||
private def positionToFeatures(board: Board): Array[Float] =
|
||||
java.util.Arrays.fill(features, 0f)
|
||||
// ── Eval cache ───────────────────────────────────────────────────────────
|
||||
private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB
|
||||
private val evalCacheHashes = new Array[Long](1 << 18)
|
||||
private val evalCacheScores = new Array[Int](1 << 18)
|
||||
|
||||
for
|
||||
fileIdx <- 0 until 8
|
||||
rankIdx <- 0 until 8
|
||||
do
|
||||
val file = File.values(fileIdx)
|
||||
val rank = Rank.values(rankIdx)
|
||||
val square = Square(file, rank)
|
||||
val squareNum = rankIdx * 8 + fileIdx
|
||||
// ── Feature helpers ──────────────────────────────────────────────────────
|
||||
|
||||
board.pieceAt(square).foreach { piece =>
|
||||
// black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
|
||||
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||
val pieceIdx = colorOffset + piece.pieceType.ordinal
|
||||
val featureIdx = pieceIdx * 64 + squareNum
|
||||
features(featureIdx) = 1f
|
||||
}
|
||||
private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
|
||||
|
||||
features
|
||||
private def featureIndex(piece: Piece, sqNum: Int): Int =
|
||||
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
|
||||
|
||||
/** Run NNUE inference on the given position.
|
||||
* Returns centipawn score from the perspective of the side-to-move.
|
||||
* No allocations in the hot path (uses pre-allocated buffers).
|
||||
* Architecture: 768→1536→1024→512→256→1 */
|
||||
def evaluate(context: GameContext): Int =
|
||||
val features = positionToFeatures(context.board)
|
||||
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * 1536
|
||||
for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i)
|
||||
|
||||
// Layer 1: Dense(768 → 1536) + ReLU
|
||||
for i <- 0 until 1536 do
|
||||
var sum = l1Bias(i)
|
||||
for j <- 0 until 768 do
|
||||
sum += features(j) * l1Weights(i * 768 + j)
|
||||
l1Output(i) = if sum > 0f then sum else 0f
|
||||
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * 1536
|
||||
for i <- 0 until 1536 do l1Pre(i) -= l1WeightsT(offset + i)
|
||||
|
||||
// Layer 2: Dense(1536 → 1024) + ReLU
|
||||
for i <- 0 until 1024 do
|
||||
var sum = l2Bias(i)
|
||||
for j <- 0 until 1536 do
|
||||
sum += l1Output(j) * l2Weights(i * 1536 + j)
|
||||
l2Output(i) = if sum > 0f then sum else 0f
|
||||
// ── Accumulator init ─────────────────────────────────────────────────────
|
||||
|
||||
// Layer 3: Dense(1024 → 512) + ReLU
|
||||
for i <- 0 until 512 do
|
||||
var sum = l3Bias(i)
|
||||
for j <- 0 until 1024 do
|
||||
sum += l2Output(j) * l3Weights(i * 1024 + j)
|
||||
l3Output(i) = if sum > 0f then sum else 0f
|
||||
/** Initialise l1Stack(0) from scratch using sparse active features. */
|
||||
def initAccumulator(board: Board): Unit =
|
||||
System.arraycopy(l1Bias, 0, l1Stack(0), 0, 1536)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// Layer 4: Dense(512 → 256) + ReLU
|
||||
for i <- 0 until 256 do
|
||||
var sum = l4Bias(i)
|
||||
for j <- 0 until 512 do
|
||||
sum += l3Output(j) * l4Weights(i * 512 + j)
|
||||
l4Output(i) = if sum > 0f then sum else 0f
|
||||
// ── Accumulator push (incremental updates) ───────────────────────────────
|
||||
|
||||
// Layer 5: Dense(256 → 1), no activation
|
||||
var output = l5Bias(0)
|
||||
for j <- 0 until 256 do
|
||||
output += l4Output(j) * l5Weights(j)
|
||||
/** Copy parent ply's pre-activations to childPly, then apply move deltas. */
|
||||
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, 1536)
|
||||
val l1 = l1Stack(childPly)
|
||||
move.moveType match
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
|
||||
// Convert from tanh-normalized output back to centipawns.
|
||||
// Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective.
|
||||
// Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective.
|
||||
val cp = if math.abs(output) >= 0.9999f then
|
||||
if output > 0f then 20000 else -20000
|
||||
/** Copy pre-activations from parentPly to childPly without any move delta (null-move). */
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, 1536)
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { mover =>
|
||||
val fromNum = squareNum(move.from)
|
||||
val toNum = squareNum(move.to)
|
||||
subtractColumn(l1, featureIndex(mover, fromNum))
|
||||
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||
addColumn(l1, featureIndex(mover, toNum))
|
||||
}
|
||||
|
||||
private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val capturedSq = Square(move.to.file, move.from.rank)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
|
||||
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
|
||||
}
|
||||
|
||||
private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { king =>
|
||||
val rank = move.from.rank
|
||||
val kingside = move.moveType == MoveType.CastleKingside
|
||||
val (rookFrom, rookTo) =
|
||||
if kingside then (Square(File.H, rank), Square(File.F, rank))
|
||||
else (Square(File.A, rank), Square(File.D, rank))
|
||||
val rook = Piece(king.color, PieceType.Rook)
|
||||
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
|
||||
addColumn(l1, featureIndex(king, squareNum(move.to)))
|
||||
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
|
||||
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
|
||||
}
|
||||
|
||||
private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val toNum = squareNum(move.to)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||
addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
|
||||
}
|
||||
|
||||
private def promotedType(promo: PromotionPiece): PieceType = promo match
|
||||
case PromotionPiece.Knight => PieceType.Knight
|
||||
case PromotionPiece.Bishop => PieceType.Bishop
|
||||
case PromotionPiece.Rook => PieceType.Rook
|
||||
case PromotionPiece.Queen => PieceType.Queen
|
||||
|
||||
// ── Evaluation from accumulator ──────────────────────────────────────────
|
||||
|
||||
/** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after
|
||||
* computation.
|
||||
*/
|
||||
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
|
||||
val idx = (hash & EVAL_CACHE_MASK).toInt
|
||||
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
|
||||
else
|
||||
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||
(300f * atanh).toInt
|
||||
val score = runL2toOutput(l1Stack(ply), turn)
|
||||
evalCacheHashes(idx) = hash
|
||||
evalCacheScores(idx) = score
|
||||
score
|
||||
|
||||
val cpFromTurn = if context.turn == Color.Black then -cp else cp
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
for i <- 0 until 1536 do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
runDenseReLU(l1ReLU, 1536, l2Weights, l2Bias, l2Output, 1024)
|
||||
runDenseReLU(l2Output, 1024, l3Weights, l3Bias, l3Output, 512)
|
||||
runDenseReLU(l3Output, 512, l4Weights, l4Bias, l4Output, 256)
|
||||
val output = runOutputLayer(l4Output, 256)
|
||||
scoreFromOutput(output, turn)
|
||||
|
||||
private def runDenseReLU(
|
||||
input: Array[Float],
|
||||
inSize: Int,
|
||||
weights: Array[Float],
|
||||
bias: Array[Float],
|
||||
output: Array[Float],
|
||||
outSize: Int,
|
||||
): Unit =
|
||||
for i <- 0 until outSize do
|
||||
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
|
||||
output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
private def runOutputLayer(input: Array[Float], inSize: Int): Float =
|
||||
(0 until inSize).foldLeft(l5Bias(0))((sum, j) => sum + input(j) * l5Weights(j))
|
||||
|
||||
private def scoreFromOutput(output: Float, turn: Color): Int =
|
||||
val cp =
|
||||
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
|
||||
else
|
||||
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||
(300f * atanh).toInt
|
||||
val cpFromTurn = if turn == Color.Black then -cp else cp
|
||||
math.max(-20000, math.min(20000, cpFromTurn))
|
||||
|
||||
/** Benchmark: time 1M evaluations and report ns/eval.
|
||||
* This measures the performance of the inference on the starting position. */
|
||||
// ── Legacy full-board evaluate (kept for Evaluation.evaluate compatibility) ──
|
||||
|
||||
// Pre-allocated buffers used only by the legacy evaluate path.
|
||||
private val features = new Array[Float](768)
|
||||
private val legacyL1 = new Array[Float](1536)
|
||||
|
||||
/** Evaluate using full board scan (sparse over active features). Layout: black pieces at indices 0-5, white at 6-11.
|
||||
*/
|
||||
def evaluate(context: GameContext): Int =
|
||||
val l1Pre = legacyL1
|
||||
System.arraycopy(l1Bias, 0, l1Pre, 0, 1536)
|
||||
for (sq, piece) <- context.board.pieces do addColumn(l1Pre, featureIndex(piece, squareNum(sq)))
|
||||
runL2toOutput(l1Pre, context.turn)
|
||||
|
||||
/** Benchmark: time 1M evaluations and report ns/eval. */
|
||||
def benchmark(): Unit =
|
||||
val context = GameContext.initial
|
||||
val context = GameContext.initial
|
||||
val iterations = 1_000_000
|
||||
|
||||
// Warm up
|
||||
for _ <- 0 until 10000 do
|
||||
evaluate(context)
|
||||
|
||||
// Actual benchmark
|
||||
for _ <- 0 until 10000 do evaluate(context)
|
||||
val startNanos = System.nanoTime()
|
||||
for _ <- 0 until iterations do
|
||||
evaluate(context)
|
||||
val endNanos = System.nanoTime()
|
||||
|
||||
val totalNanos = endNanos - startNanos
|
||||
for _ <- 0 until iterations do evaluate(context)
|
||||
val endNanos = System.nanoTime()
|
||||
val totalNanos = endNanos - startNanos
|
||||
val nanosPerEval = totalNanos.toDouble / iterations
|
||||
|
||||
println()
|
||||
println("=" * 60)
|
||||
println("NNUE BENCHMARK RESULTS")
|
||||
|
||||
@@ -3,113 +3,99 @@ package de.nowchess.bot.logic
|
||||
import de.nowchess.api.board.PieceType
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType}
|
||||
import de.nowchess.bot.ai.Weights
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
import de.nowchess.bot.logic.{MoveOrdering, TTEntry, TTFlag, TranspositionTable}
|
||||
import de.nowchess.bot.util.ZobristHash
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
|
||||
|
||||
final class AlphaBetaSearch(
    rules: RuleSet = DefaultRules,
    tt: TranspositionTable = TranspositionTable(),
    weights: Evaluation,
    numThreads: Int = Runtime.getRuntime.availableProcessors,
):

  // Search tuning constants.
  private val INF = Int.MaxValue / 2
  private val MAX_QUIESCENCE_PLY = 64
  private val NULL_MOVE_R = 2
  private val ASPIRATION_DELTA = 50
  private val ASPIRATION_DELTA_MAX = 150
  private val TIME_CHECK_FREQUENCY = 1000
  private val FUTILITY_MARGIN = 100
  private val CHECK_EXTENSION = 1

  // Mutable search state; atomics replace @volatile vars for the periodic time check.
  private val timeStartMs = AtomicLong(0L)
  private val timeLimitMs = AtomicLong(0L)
  private val nodeCount = AtomicInteger(0)
  private val ordering = MoveOrdering.OrderingContext()
|
||||
|
||||
/** Return the best move for the side to move, searching to maxDepth plies.
|
||||
* Uses iterative deepening with aspiration windows. */
|
||||
/** Return the best move for the side to move, searching to maxDepth plies. Uses iterative deepening with aspiration
|
||||
* windows.
|
||||
*/
|
||||
def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
timeStartMs = System.currentTimeMillis
|
||||
timeLimitMs = Long.MaxValue / 4
|
||||
nodeCount = 0
|
||||
var bestSoFar: Option[Move] = None
|
||||
var prevScore = 0
|
||||
var aspWindow = ASPIRATION_DELTA
|
||||
for depth <- 1 to maxDepth do
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF)
|
||||
else (prevScore - aspWindow, prevScore + aspWindow)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow, rootHash)
|
||||
prevScore = score
|
||||
move.foreach(m => bestSoFar = Some(m))
|
||||
aspWindow = ASPIRATION_DELTA
|
||||
bestSoFar
|
||||
weights.initAccumulator(context)
|
||||
timeStartMs.set(System.currentTimeMillis)
|
||||
timeLimitMs.set(Long.MaxValue / 4)
|
||||
nodeCount.set(0)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
(1 to maxDepth).foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
|
||||
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
|
||||
(move.orElse(bestSoFar), score)
|
||||
}._1
|
||||
|
||||
/** Return the best move for the side to move within a time budget (ms).
|
||||
* Uses iterative deepening, stopping when time runs out. */
|
||||
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
|
||||
* runs out.
|
||||
*/
|
||||
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
timeStartMs = System.currentTimeMillis
|
||||
timeLimitMs = timeBudgetMs
|
||||
nodeCount = 0
|
||||
var bestSoFar: Option[Move] = None
|
||||
var prevScore = 0
|
||||
var depth = 1
|
||||
var aspWindow = ASPIRATION_DELTA
|
||||
weights.initAccumulator(context)
|
||||
timeStartMs.set(System.currentTimeMillis)
|
||||
timeLimitMs.set(timeBudgetMs)
|
||||
nodeCount.set(0)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
|
||||
while !isOutOfTime do
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF)
|
||||
else (prevScore - aspWindow, prevScore + aspWindow)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow, rootHash)
|
||||
prevScore = score
|
||||
move match
|
||||
case Some(m) =>
|
||||
bestSoFar = Some(m)
|
||||
case None =>
|
||||
aspWindow = ASPIRATION_DELTA
|
||||
depth += 1
|
||||
bestSoFar
|
||||
@scala.annotation.tailrec
|
||||
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
|
||||
if isOutOfTime then bestSoFar
|
||||
else
|
||||
val (alpha, beta) = if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, ASPIRATION_DELTA, rootHash)
|
||||
loop(move.orElse(bestSoFar), score, depth + 1)
|
||||
|
||||
loop(None, 0, 1)
|
||||
|
||||
private def isOutOfTime: Boolean =
|
||||
System.currentTimeMillis - timeStartMs >= timeLimitMs
|
||||
System.currentTimeMillis - timeStartMs.get >= timeLimitMs.get
|
||||
|
||||
private def searchWithAspiration(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
initialWindow: Int,
|
||||
rootHash: Long
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
initialWindow: Int,
|
||||
rootHash: Long,
|
||||
): (Int, Option[Move]) =
|
||||
var currentAlpha = alpha
|
||||
var currentBeta = beta
|
||||
var window = initialWindow
|
||||
var attempt = 0
|
||||
val repetitions = Map(rootHash -> 1)
|
||||
|
||||
while attempt < 3 && attempt < depth do
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions)
|
||||
if score > currentAlpha && score < currentBeta then
|
||||
return (score, move)
|
||||
if score <= currentAlpha then
|
||||
currentAlpha = score - window
|
||||
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
|
||||
if score >= currentBeta then
|
||||
currentBeta = score + window
|
||||
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
|
||||
attempt += 1
|
||||
@scala.annotation.tailrec
|
||||
def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
|
||||
if attempt >= 3 || attempt >= depth then
|
||||
search(context, depth, 0, -INF, INF, rootHash, repetitions)
|
||||
else
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions)
|
||||
if score > currentAlpha && score < currentBeta then (score, move)
|
||||
else if score <= currentAlpha then
|
||||
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
else
|
||||
loop(currentAlpha, score + window, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
|
||||
search(context, depth, 0, -INF, INF, rootHash, repetitions)
|
||||
loop(alpha, beta, initialWindow, 0)
|
||||
|
||||
private def hasNonPawnMaterial(context: GameContext): Boolean =
|
||||
context.board.pieces.values.exists { piece =>
|
||||
@@ -122,153 +108,138 @@ final class AlphaBetaSearch(
|
||||
context.withTurn(context.turn.opposite).withEnPassantSquare(None)
|
||||
|
||||
private def tryNullMove(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
beta: Int,
|
||||
repetitions: Map[Long, Int]
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
beta: Int,
|
||||
repetitions: Map[Long, Int],
|
||||
): Option[Int] =
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullHash = ZobristHash.hash(nullCtx)
|
||||
val nullRepetitions = repetitions.updatedWith(nullHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
case None => Some(1)
|
||||
}
|
||||
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||
weights.copyAccumulator(ply, ply + 1)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions)
|
||||
if -score >= beta then Some(beta) else None
|
||||
|
||||
/** Negamax alpha-beta search returning (score, best move). */
|
||||
private def search(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int]
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int],
|
||||
): (Int, Option[Move]) =
|
||||
// Periodic time check
|
||||
nodeCount += 1
|
||||
if nodeCount % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
|
||||
return (weights.evaluate(context), None)
|
||||
|
||||
if repetitions.getOrElse(hash, 0) >= 3 then
|
||||
return (weights.DRAW_SCORE, None)
|
||||
|
||||
// TT probe
|
||||
tt.probe(hash) match
|
||||
case Some(entry) if entry.depth >= depth =>
|
||||
val count = nodeCount.incrementAndGet()
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
|
||||
(weights.evaluateAccumulator(ply, context, hash), None)
|
||||
else if repetitions.getOrElse(hash, 0) >= 3 then
|
||||
(weights.DRAW_SCORE, None)
|
||||
else
|
||||
val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
|
||||
entry.flag match
|
||||
case TTFlag.Exact => return (entry.score, entry.bestMove)
|
||||
case TTFlag.Exact => Some((entry.score, entry.bestMove))
|
||||
case TTFlag.Lower =>
|
||||
val newAlpha = math.max(alpha, entry.score)
|
||||
if newAlpha >= beta then return (entry.score, entry.bestMove)
|
||||
Option.when(newAlpha >= beta)((entry.score, entry.bestMove))
|
||||
case TTFlag.Upper =>
|
||||
val newBeta = math.min(beta, entry.score)
|
||||
if alpha >= newBeta then return (entry.score, entry.bestMove)
|
||||
case _ =>
|
||||
|
||||
// Terminal node check
|
||||
val legalMoves = rules.allLegalMoves(context)
|
||||
if legalMoves.isEmpty then
|
||||
val score = if rules.isCheckmate(context) then
|
||||
-(weights.CHECKMATE_SCORE - ply)
|
||||
else
|
||||
weights.DRAW_SCORE
|
||||
return (score, None)
|
||||
|
||||
if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then
|
||||
return (weights.DRAW_SCORE, None)
|
||||
|
||||
// Leaf node: call quiescence
|
||||
if depth == 0 then
|
||||
return (quiescence(context, ply, alpha, beta), None)
|
||||
|
||||
// Null move pruning
|
||||
if depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context) then
|
||||
tryNullMove(context, depth, ply, beta, repetitions) match
|
||||
case Some(score) => return (score, None)
|
||||
case None =>
|
||||
|
||||
// Get TT best move for ordering
|
||||
val ttBest = tt.probe(hash).flatMap(_.bestMove)
|
||||
|
||||
// Order moves
|
||||
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
|
||||
|
||||
searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions)
|
||||
Option.when(alpha >= newBeta)((entry.score, entry.bestMove))
|
||||
}
|
||||
ttCutoff.getOrElse {
|
||||
val legalMoves = rules.allLegalMoves(context)
|
||||
if legalMoves.isEmpty then
|
||||
(if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
|
||||
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then
|
||||
(weights.DRAW_SCORE, None)
|
||||
else if depth == 0 then
|
||||
(quiescence(context, ply, alpha, beta, hash), None)
|
||||
else
|
||||
val nullResult = Option
|
||||
.when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
|
||||
tryNullMove(context, depth, ply, beta, repetitions)
|
||||
}
|
||||
.flatten
|
||||
nullResult.map((_, None)).getOrElse {
|
||||
val ttBest = tt.probe(hash).flatMap(_.bestMove)
|
||||
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
|
||||
searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions)
|
||||
}
|
||||
}
|
||||
|
||||
private def searchSequential(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
ordered: List[Move],
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int]
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
ordered: List[Move],
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int],
|
||||
): (Int, Option[Move]) =
|
||||
var bestMove: Option[Move] = None
|
||||
var bestScore = -INF
|
||||
var a = alpha
|
||||
var moveNumber = 0
|
||||
var cutoff = false
|
||||
@scala.annotation.tailrec
|
||||
def loop(
|
||||
idx: Int,
|
||||
bestMove: Option[Move],
|
||||
bestScore: Int,
|
||||
a: Int,
|
||||
moveNumber: Int,
|
||||
): (Option[Move], Int, Boolean) =
|
||||
if idx >= ordered.length then (bestMove, bestScore, false)
|
||||
else
|
||||
val move = ordered(idx)
|
||||
val isQuiet = !isCapture(context, move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
val pruneByFutility = depth == 1 && isQuiet && moveNumber > 2 &&
|
||||
weights.evaluateAccumulator(ply, context, hash) + FUTILITY_MARGIN < alpha
|
||||
|
||||
var idx = 0
|
||||
while idx < ordered.length && !cutoff do
|
||||
val move = ordered(idx)
|
||||
idx += 1
|
||||
moveNumber += 1
|
||||
val isQuiet = !isCapture(context, move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
|
||||
// Futility pruning at frontier nodes: if static eval + margin is still below alpha, skip quiet moves
|
||||
val pruneByFutility = if depth == 1 && isQuiet && moveNumber > 2 then
|
||||
val staticEval = weights.evaluate(context)
|
||||
staticEval + FUTILITY_MARGIN < alpha
|
||||
else false
|
||||
|
||||
if !pruneByFutility then
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, hash, move, child)
|
||||
val childRepetitions = repetitions.updatedWith(childHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
}
|
||||
val givesCheck = rules.isCheck(child)
|
||||
val extension = if givesCheck then CHECK_EXTENSION else 0
|
||||
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
|
||||
|
||||
val score = if reduction > 0 then
|
||||
val reducedDepth = math.max(0, depth - 1 - reduction + extension)
|
||||
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions)
|
||||
val s = -reducedScore
|
||||
if s > a then
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions)
|
||||
-fullScore
|
||||
else s
|
||||
if pruneByFutility then loop(idx + 1, bestMove, bestScore, a, moveNumber + 1)
|
||||
else
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions)
|
||||
-rawScore
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, hash, move, child)
|
||||
weights.pushAccumulator(ply + 1, move, context, child)
|
||||
val childRepetitions = repetitions.updatedWith(childHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
}
|
||||
val givesCheck = rules.isCheck(child)
|
||||
val extension = if givesCheck then CHECK_EXTENSION else 0
|
||||
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
|
||||
|
||||
if score > bestScore then
|
||||
bestScore = score
|
||||
bestMove = Some(move)
|
||||
val score =
|
||||
if reduction > 0 then
|
||||
val reducedDepth = math.max(0, depth - 1 - reduction + extension)
|
||||
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions)
|
||||
val s = -reducedScore
|
||||
if s > a then
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions)
|
||||
-fullScore
|
||||
else s
|
||||
else
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions)
|
||||
-rawScore
|
||||
|
||||
a = math.max(a, score)
|
||||
val newBestScore = math.max(bestScore, score)
|
||||
val newBestMove = if score > bestScore then Some(move) else bestMove
|
||||
val newA = math.max(a, score)
|
||||
|
||||
if a >= beta then
|
||||
if isQuiet then
|
||||
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||
ordering.addHistory(fromIdx, toIdx, depth * depth)
|
||||
ordering.addKillerMove(ply, move)
|
||||
cutoff = true
|
||||
if newA >= beta then
|
||||
if isQuiet then
|
||||
ordering.addHistory(move.from.rank.ordinal * 8 + move.from.file.ordinal, move.to.rank.ordinal * 8 + move.to.file.ordinal, depth * depth)
|
||||
ordering.addKillerMove(ply, move)
|
||||
(newBestMove, newBestScore, true)
|
||||
else
|
||||
loop(idx + 1, newBestMove, newBestScore, newA, moveNumber + 1)
|
||||
|
||||
val (bestMove, bestScore, cutoff) = loop(0, None, -INF, alpha, 0)
|
||||
val flag =
|
||||
if cutoff then TTFlag.Lower
|
||||
else if bestScore <= alpha then TTFlag.Upper
|
||||
@@ -278,48 +249,45 @@ final class AlphaBetaSearch(
|
||||
|
||||
/** Quiescence search: only captures until position is quiet. */
|
||||
private def quiescence(
|
||||
context: GameContext,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int
|
||||
context: GameContext,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
): Int =
|
||||
val inCheck = rules.isCheck(context)
|
||||
val standPat = if inCheck then -INF else weights.evaluate(context)
|
||||
if !inCheck && standPat >= beta then
|
||||
return beta
|
||||
val inCheck = rules.isCheck(context)
|
||||
val standPat = if inCheck then -INF else weights.evaluateAccumulator(ply, context, hash)
|
||||
|
||||
var a = if inCheck then alpha else math.max(alpha, standPat)
|
||||
if !inCheck && standPat >= beta then beta
|
||||
else
|
||||
val a0 = if inCheck then alpha else math.max(alpha, standPat)
|
||||
|
||||
// Guard against infinite quiescence
|
||||
if ply >= MAX_QUIESCENCE_PLY then
|
||||
return if inCheck then weights.evaluate(context) else standPat
|
||||
|
||||
// Generate captures, or all evasions when side-to-move is in check.
|
||||
val allMoves = rules.allLegalMoves(context)
|
||||
val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
|
||||
|
||||
if inCheck && tacticalMoves.isEmpty then
|
||||
return -(weights.CHECKMATE_SCORE - ply)
|
||||
|
||||
val ordered = MoveOrdering.sort(context, tacticalMoves, None)
|
||||
|
||||
var cutoff = false
|
||||
var idx = 0
|
||||
while idx < ordered.length && !cutoff do
|
||||
val move = ordered(idx)
|
||||
idx += 1
|
||||
val child = rules.applyMove(context)(move)
|
||||
val score = -quiescence(child, ply + 1, -beta, -a)
|
||||
if score >= beta then
|
||||
a = beta
|
||||
cutoff = true
|
||||
if ply >= MAX_QUIESCENCE_PLY then
|
||||
if inCheck then weights.evaluateAccumulator(ply, context, hash) else standPat
|
||||
else
|
||||
a = math.max(a, score)
|
||||
a
|
||||
val allMoves = rules.allLegalMoves(context)
|
||||
val tacticalMoves = if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
|
||||
|
||||
if inCheck && tacticalMoves.isEmpty then -(weights.CHECKMATE_SCORE - ply)
|
||||
else
|
||||
val ordered = MoveOrdering.sort(context, tacticalMoves, None)
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(idx: Int, a: Int): Int =
|
||||
if idx >= ordered.length then a
|
||||
else
|
||||
val move = ordered(idx)
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, hash, move, child)
|
||||
weights.pushAccumulator(ply + 1, move, context, child)
|
||||
val score = -quiescence(child, ply + 1, -beta, -a, childHash)
|
||||
if score >= beta then beta
|
||||
else loop(idx + 1, math.max(a, score))
|
||||
|
||||
loop(0, a0)
|
||||
|
||||
/** Predicate: context-aware capture classification. */
|
||||
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
|
||||
case MoveType.Normal(true) => true
|
||||
case MoveType.EnPassant => true
|
||||
case MoveType.EnPassant => true
|
||||
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
|
||||
case _ => false
|
||||
case _ => false
|
||||
|
||||
Reference in New Issue
Block a user