feat: enhance AlphaBetaSearch with move timing and improve king square detection in GameContext

This commit is contained in:
2026-04-10 16:47:19 +02:00
parent 01c8a0f8fe
commit 75286a8773
11 changed files with 321 additions and 269 deletions
Binary file not shown.
@@ -4,4 +4,7 @@ object Config:
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
* the move is vetoed (not accepted as a suggestion). */
val VETO_THRESHOLD: Int = 100
val VETO_THRESHOLD: Int = 150
/** Time budget per move for iterative deepening (milliseconds). */
val TIME_LIMIT_MS: Long = 2000L
@@ -2,9 +2,11 @@ package de.nowchess.bot.bots
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.logic.HybridSearch
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{Bot, BotDifficulty}
import de.nowchess.bot.{Bot, BotDifficulty, Config}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
@@ -14,10 +16,20 @@ final class HybridBot(
book: Option[PolyglotBook] = None
) extends Bot:
private val search: HybridSearch = HybridSearch(rules)
private val search = AlphaBetaSearch(rules, TranspositionTable(), EvaluationClassic)
override val name: String = s"HybridBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] =
book.flatMap(_.probe(context))
.orElse(search.bestMove(context))
book.flatMap(_.probe(context)).orElse(searchWithVeto(context))
/** Run the timed alpha-beta search and cross-check its chosen move against both static evaluators.
  *
  * The position after the move is scored by the NNUE and the classical evaluation. When the two
  * scores differ by more than `Config.VETO_THRESHOLD`, a diagnostic line is printed. The move is
  * returned in every case: the disagreement is only reported, it never rejects the move.
  *
  * @param context position to search from
  * @return the move found by the search, or `None` when the search yields no move
  */
private def searchWithVeto(context: GameContext): Option[Move] =
  for move <- search.bestMoveWithTime(context, Config.TIME_LIMIT_MS) yield
    // Both evaluators score the same resulting position so the numbers are comparable.
    val afterMove = rules.applyMove(context)(move)
    val staticNnue = EvaluationNNUE.evaluate(afterMove)
    val classical = EvaluationClassic.evaluate(afterMove)
    val diff = (classical - staticNnue).abs
    val evaluatorsDisagree = diff > Config.VETO_THRESHOLD
    if evaluatorsDisagree then
      println(f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)")
    move
@@ -67,28 +67,11 @@ class NNUE:
private val l4Output = new Array[Float](256)
/** Convert a position to 768-dimensional binary feature vector.
* 12 piece types (white pawn to black king) × 64 squares from white's perspective. */
private def positionToFeatures(board: Board, sideToMove: Color): Array[Float] =
// Zero out features array
* Layout matches training: black pieces at indices 0-5, white at 6-11.
* feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */
private def positionToFeatures(board: Board): Array[Float] =
java.util.Arrays.fill(features, 0f)
// Piece type to feature index offset: wp=0, wn=64, wb=128, wr=192, wq=256, wk=320, bp=384, bn=448, bb=512, br=576, bq=640, bk=704
val pieceToFeatureOffset = Array(
0, // White Pawn (0)
64, // White Knight (1)
128, // White Bishop (2)
192, // White Rook (3)
256, // White Queen (4)
320, // White King (5)
384, // Black Pawn (6)
448, // Black Knight (7)
512, // Black Bishop (8)
576, // Black Rook (9)
640, // Black Queen (10)
704 // Black King (11)
)
// Build features: always from white's perspective
for
fileIdx <- 0 until 8
rankIdx <- 0 until 8
@@ -99,17 +82,11 @@ class NNUE:
val squareNum = rankIdx * 8 + fileIdx
board.pieceAt(square).foreach { piece =>
val featureIdx = if sideToMove == Color.Black then
// Mirror square for black side-to-move
val mirroredSq = squareNum ^ 56
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
offset + mirroredSq
else
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
offset + squareNum
if featureIdx >= 0 && featureIdx < 768 then
features(featureIdx) = 1f
// black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
val colorOffset = if piece.color == Color.White then 6 else 0
val pieceIdx = colorOffset + piece.pieceType.ordinal
val featureIdx = pieceIdx * 64 + squareNum
features(featureIdx) = 1f
}
features
@@ -119,7 +96,7 @@ class NNUE:
* No allocations in the hot path (uses pre-allocated buffers).
* Architecture: 768→1536→1024→512→256→1 */
def evaluate(context: GameContext): Int =
val features = positionToFeatures(context.board, context.turn)
val features = positionToFeatures(context.board)
// Layer 1: Dense(768 → 1536) + ReLU
for i <- 0 until 1536 do
@@ -154,18 +131,17 @@ class NNUE:
for j <- 0 until 256 do
output += l4Output(j) * l5Weights(j)
// Convert from tanh-normalized output back to centipawns
// Training uses: eval_normalized = tanh(eval_cp / 300)
// Inverse: eval_cp = 300 * atanh(output)
// atanh(x) = 0.5 * ln((1 + x) / (1 - x))
// Convert from tanh-normalized output back to centipawns.
// Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective.
// Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective.
val cp = if math.abs(output) >= 0.9999f then
// Clamp for numerical stability (avoid ln of very small numbers)
if output > 0f then 20000 else -20000
else
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
(300f * atanh).toInt
math.max(-20000, math.min(20000, cp))
val cpFromTurn = if context.turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
/** Benchmark: time 1M evaluations and report ns/eval.
* This measures the performance of the inference on the starting position. */
@@ -73,8 +73,12 @@ final class AlphaBetaSearch(
if depth == 1 then (-INF, INF)
else (prevScore - aspWindow, prevScore + aspWindow)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
val elapsed = System.currentTimeMillis() - timeStartMs
prevScore = score
move.foreach(m => bestSoFar = Some(m))
move.foreach { m =>
bestSoFar = Some(m)
println(f"[Depth $depth%2d | ${elapsed}%5dms | ${nodeCount}%7d nodes] best=${m.from}->${m.to} score=$score")
}
aspWindow = ASPIRATION_DELTA
depth += 1
bestSoFar
@@ -1,67 +0,0 @@
package de.nowchess.bot.logic
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.Config
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import scala.util.boundary
import scala.util.boundary.break
/** One-ply move selector that ranks legal moves by NNUE score and cross-checks the best
  * candidates with the classical evaluation.
  *
  * The instance keeps running counters of approved and vetoed candidates across calls;
  * read them with [[getStats]].
  *
  * @param rules rule set used to enumerate legal moves and to apply a move to a position
  */
final class HybridSearch(
    rules: RuleSet = DefaultRules
):

  private var vetoCount = 0
  private var approvalCount = 0

  /** How many of the highest NNUE-scored moves are cross-checked before falling back. */
  private val TOP_MOVES_TO_VALIDATE = 10

  /** Find the best move by scoring all legal moves with NNUE, then validating the top
    * `TOP_MOVES_TO_VALIDATE` of them with the classical evaluation.
    *
    * Candidates are examined in descending NNUE order. The first one whose classical score is
    * within `Config.VETO_THRESHOLD` of its NNUE score is approved and returned. If every examined
    * candidate is vetoed, the move with the highest classical score among all legal moves is
    * returned instead.
    *
    * @param context position to choose a move for
    * @return the selected move, or `None` when there are no legal moves
    */
  def bestMove(context: GameContext): Option[Move] =
    val legalMoves = rules.allLegalMoves(context)
    if legalMoves.isEmpty then None else findBestMove(legalMoves, context)

  /** Selection logic behind [[bestMove]]; `legalMoves` must be non-empty (the fallback uses `maxBy`). */
  private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
    // Score every move with NNUE, keeping the resulting position so it is applied only once.
    val moveScores = legalMoves.map { move =>
      val nextContext = rules.applyMove(context)(move)
      val nnueScore = EvaluationNNUE.evaluate(nextContext)
      (move, nnueScore, nextContext)
    }

    // Highest NNUE score first; only the leading candidates are validated.
    val sortedByNNUE = moveScores.sortBy(_._2).reverse
    val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)

    boundary:
      for (move, nnueScore, nextContext) <- topMovesToCheck do
        val classicalScore = EvaluationClassic.evaluate(nextContext)
        val difference = (classicalScore - nnueScore).abs

        if difference <= Config.VETO_THRESHOLD then
          approvalCount += 1
          println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
          // First approved candidate wins; leave the search immediately.
          break(Some(move))
        else
          vetoCount += 1
          println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")

      // Every validated candidate was vetoed: fall back to the best classical move overall.
      println(s"[HybridSearch] All top $TOP_MOVES_TO_VALIDATE NNUE moves vetoed. Falling back to best classical move.")
      val bestByClassical = moveScores
        .map { case (move, _, nextContext) =>
          (move, EvaluationClassic.evaluate(nextContext))
        }
        .maxBy(_._2)

      println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
      println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
      Some(bestByClassical._1)

  /** Running totals since this instance was created.
    *
    * @return `(approvalCount, vetoCount)`, in that order
    */
  def getStats: (Int, Int) = (approvalCount, vetoCount)