feat(ncs-110): feed NNUE root-move scores into search move ordering #83

Merged
Janis merged 1 commit from feat/NCS-110-nnue-root-ordering into main 2026-06-24 20:09:30 +02:00
6 changed files with 130 additions and 8 deletions
@@ -28,7 +28,7 @@ object NNUEBot:
else else
val scored = batchEvaluateRoot(rules, context, moves) val scored = batchEvaluateRoot(rules, context, moves)
val bestMove = scored.maxBy(_._2)._1 val bestMove = scored.maxBy(_._2)._1
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove)) search.bestMoveWithTime(context, allocateTime(scored), blockedMoves, scored.toMap).orElse(Some(bestMove))
} }
private def batchEvaluateRoot(rules: RuleSet, context: GameContext, moves: List[Move]): List[(Move, Int)] = private def batchEvaluateRoot(rules: RuleSet, context: GameContext, moves: List[Move]): List[(Move, Int)] =
@@ -32,6 +32,8 @@ final class AlphaBetaSearch(
private val nodeCount = AtomicInteger(0) private val nodeCount = AtomicInteger(0)
private val ordering = MoveOrdering.OrderingContext() private val ordering = MoveOrdering.OrderingContext()
/** Current value of this search's atomic node counter.
  *
  * NOTE(review): the name suggests "nodes visited by the most recent search"; where the counter is
  * incremented and reset is not visible here — confirm before relying on that.
  */
def lastNodeCount: Int =
  nodeCount.get()
private final case class QuiescenceNode( private final case class QuiescenceNode(
context: GameContext, context: GameContext,
ply: Int, ply: Int,
@@ -47,6 +49,17 @@ final class AlphaBetaSearch(
bestMove(context, maxDepth, Set.empty) bestMove(context, maxDepth, Set.empty)
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] = def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] =
doDepthSearch(context, maxDepth, excludedRootMoves, Map.empty)
/** Depth-limited search that additionally accepts per-move root hints.
  *
  * Thin public entry point: every argument is forwarded unchanged to `doDepthSearch`.
  *
  * @param context           position to search from
  * @param maxDepth          depth limit handed to the search
  * @param excludedRootMoves root moves passed through as the exclusion set
  * @param hints             integer score per move, passed through as root hints
  * @return the move selected by `doDepthSearch`, if any
  */
def bestMove(
  context: GameContext,
  maxDepth: Int,
  excludedRootMoves: Set[Move],
  hints: Map[Move, Int],
): Option[Move] =
  doDepthSearch(
    context = context,
    maxDepth = maxDepth,
    excludedRootMoves = excludedRootMoves,
    hints = hints,
  )
private def doDepthSearch(
context: GameContext,
maxDepth: Int,
excludedRootMoves: Set[Move],
hints: Map[Move, Int],
): Option[Move] =
tt.clear() tt.clear()
ordering.clear() ordering.clear()
weights.initAccumulator(context) weights.initAccumulator(context)
@@ -66,6 +79,7 @@ final class AlphaBetaSearch(
ASPIRATION_DELTA, ASPIRATION_DELTA,
rootHash, rootHash,
excludedRootMoves, excludedRootMoves,
hints,
) )
(move.orElse(bestSoFar), score) (move.orElse(bestSoFar), score)
} }
@@ -78,6 +92,22 @@ final class AlphaBetaSearch(
bestMoveWithTime(context, timeBudgetMs, Set.empty) bestMoveWithTime(context, timeBudgetMs, Set.empty)
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] = def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] =
doTimedSearch(context, timeBudgetMs, excludedRootMoves, Map.empty)
/** Time-budgeted search that additionally accepts per-move root hints.
  *
  * Thin public entry point: every argument is forwarded unchanged to `doTimedSearch`.
  *
  * @param context           position to search from
  * @param timeBudgetMs      time budget handed to the search (milliseconds, per the parameter name)
  * @param excludedRootMoves root moves passed through as the exclusion set
  * @param hints             integer score per move, passed through as root hints
  * @return the move selected by `doTimedSearch`, if any
  */
def bestMoveWithTime(
  context: GameContext,
  timeBudgetMs: Long,
  excludedRootMoves: Set[Move],
  hints: Map[Move, Int],
): Option[Move] =
  doTimedSearch(
    context = context,
    timeBudgetMs = timeBudgetMs,
    excludedRootMoves = excludedRootMoves,
    hints = hints,
  )
private def doTimedSearch(
context: GameContext,
timeBudgetMs: Long,
excludedRootMoves: Set[Move],
hints: Map[Move, Int],
): Option[Move] =
tt.clear() tt.clear()
ordering.clear() ordering.clear()
weights.initAccumulator(context) weights.initAccumulator(context)
@@ -100,6 +130,7 @@ final class AlphaBetaSearch(
ASPIRATION_DELTA, ASPIRATION_DELTA,
rootHash, rootHash,
excludedRootMoves, excludedRootMoves,
hints,
) )
loop(move.orElse(bestSoFar), score, depth + 1, depth) loop(move.orElse(bestSoFar), score, depth + 1, depth)
@@ -124,14 +155,17 @@ final class AlphaBetaSearch(
initialWindow: Int, initialWindow: Int,
rootHash: Long, rootHash: Long,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
hints: Map[Move, Int],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
val state = SearchState(rootHash, Map(rootHash -> 1)) val state = SearchState(rootHash, Map(rootHash -> 1))
@scala.annotation.tailrec @scala.annotation.tailrec
def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) = def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) =
if attempt >= 3 || attempt >= depth then search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves) if attempt >= 3 || attempt >= depth then
search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves, hints)
else else
val (score, move) = search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves) val (score, move) =
search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves, hints)
if score > currentAlpha && score < currentBeta then (score, move) if score > currentAlpha && score < currentBeta then (score, move)
else if score <= currentAlpha then else if score <= currentAlpha then
loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1) loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
@@ -156,12 +190,14 @@ final class AlphaBetaSearch(
beta: Int, beta: Int,
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
hints: Map[Move, Int],
): Option[Int] = ): Option[Int] =
val nullCtx = nullMoveContext(context) val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx)) val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R) val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1) weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves) val (score, _) =
search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves, hints)
if -score >= beta then Some(beta) else None if -score >= beta then Some(beta) else None
/** Negamax alpha-beta search returning (score, best move). */ /** Negamax alpha-beta search returning (score, best move). */
@@ -172,8 +208,9 @@ final class AlphaBetaSearch(
window: Window, window: Window,
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
hints: Map[Move, Int],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves) val params = SearchParams(context, depth, ply, window, state, excludedRootMoves, hints)
searchNode(params) searchNode(params)
private def searchNode(params: SearchParams): (Int, Option[Move]) = private def searchNode(params: SearchParams): (Int, Option[Move]) =
@@ -235,13 +272,14 @@ final class AlphaBetaSearch(
params.window.beta, params.window.beta,
params.state, params.state,
params.excludedRootMoves, params.excludedRootMoves,
params.rootHints,
), ),
) )
.flatten .flatten
nullResult.map((_, None)).getOrElse { nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove) val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering) val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering, params.rootHints)
searchSequential( searchSequential(
params.context, params.context,
params.depth, params.depth,
@@ -250,6 +288,7 @@ final class AlphaBetaSearch(
ordered, ordered,
params.state, params.state,
params.excludedRootMoves, params.excludedRootMoves,
params.rootHints,
) )
} }
@@ -280,6 +319,7 @@ final class AlphaBetaSearch(
Window(-a - 1, -a), Window(-a - 1, -a),
childState, childState,
params.excludedRootMoves, params.excludedRootMoves,
params.rootHints,
) )
val s = -rs val s = -rs
if s > a then if s > a then
@@ -290,6 +330,7 @@ final class AlphaBetaSearch(
Window(betaNeg, -a), Window(betaNeg, -a),
childState, childState,
params.excludedRootMoves, params.excludedRootMoves,
params.rootHints,
) )
-fs -fs
else s else s
@@ -301,6 +342,7 @@ final class AlphaBetaSearch(
Window(betaNeg, -a), Window(betaNeg, -a),
childState, childState,
params.excludedRootMoves, params.excludedRootMoves,
params.rootHints,
) )
-rs -rs
@@ -364,8 +406,9 @@ final class AlphaBetaSearch(
ordered: List[Move], ordered: List[Move],
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
rootHints: Map[Move, Int] = Map.empty,
): (Int, Option[Move]) = ): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves) val params = SearchParams(context, depth, ply, window, state, excludedRootMoves, rootHints)
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered) val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
val flag = val flag =
if cutoff then TTFlag.Lower if cutoff then TTFlag.Lower
@@ -38,8 +38,10 @@ object MoveOrdering:
ttBestMove: Option[Move], ttBestMove: Option[Move],
ply: Int = 0, ply: Int = 0,
ordering: OrderingContext = new OrderingContext(), ordering: OrderingContext = new OrderingContext(),
rootHints: Map[Move, Int] = Map.empty,
): Int = ): Int =
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue
else if ply == 0 && rootHints.nonEmpty then rootHints.getOrElse(move, Int.MinValue / 2)
else else
move.moveType match move.moveType match
case MoveType.Promotion(PromotionPiece.Queen) => case MoveType.Promotion(PromotionPiece.Queen) =>
@@ -56,8 +58,9 @@ object MoveOrdering:
ttBestMove: Option[Move], ttBestMove: Option[Move],
ply: Int = 0, ply: Int = 0,
ordering: OrderingContext = new OrderingContext(), ordering: OrderingContext = new OrderingContext(),
rootHints: Map[Move, Int] = Map.empty,
): List[Move] = ): List[Move] =
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering)) moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering, rootHints))
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int = private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to) val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
@@ -14,6 +14,7 @@ final case class SearchParams(
window: Window, window: Window,
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
rootHints: Map[Move, Int] = Map.empty,
) )
final case class SearchState(hash: Long, repetitions: Map[Long, Int]): final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
@@ -312,6 +312,24 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
val search = AlphaBetaSearch(qRules, weights = ZeroEval) val search = AlphaBetaSearch(qRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove)) search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
// Smoke test: supplying root hints to the depth-limited search still yields a legal move.
test("bestMove with root hints returns a valid move without regression"):
  val context = GameContext.initial
  val legalMoves = DefaultRules.allLegalMoves(context)
  // Descending hint values: the first generated move gets the largest hint, the last gets 1.
  val hints = legalMoves.zipWithIndex.map { case (m, i) => m -> (legalMoves.length - i) }.toMap
  val withHints = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
    .bestMove(context, maxDepth = 2, Set.empty, hints)
  // Pattern match instead of `should not be None` + `.get`: no partial call on the Option,
  // and the absent case fails with an explicit message.
  withHints match
    case Some(move) => legalMoves should contain(move)
    case None       => fail("search returned no move for the initial position")
// Smoke test: supplying root hints to the time-budgeted search still yields a legal move.
test("bestMoveWithTime with root hints returns a valid move without regression"):
  val context = GameContext.initial
  val legalMoves = DefaultRules.allLegalMoves(context)
  // Descending hint values: the first generated move gets the largest hint, the last gets 1.
  val hints = legalMoves.zipWithIndex.map { case (m, i) => m -> (legalMoves.length - i) }.toMap
  val withHints = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
    .bestMoveWithTime(context, 500L, Set.empty, hints)
  // Pattern match instead of `should not be None` + `.get`: no partial call on the Option,
  // and the absent case fails with an explicit message.
  withHints match
    case Some(move) => legalMoves should contain(move)
    case None       => fail("timed search returned no move for the initial position")
test("quiescence depth-limit in-check branch is exercised"): test("quiescence depth-limit in-check branch is exercised"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal()) val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true)) val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
@@ -217,3 +217,60 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside) val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
MoveOrdering.score(context, castle, None) should be(0) MoveOrdering.score(context, castle, None) should be(0)
// At ply 0 a non-empty hint map supplies the score directly, so a quiet move with a larger
// hint outranks a capture with a smaller one.
test("root hints override capture heuristics at ply 0"):
  val queenSquare = Square(File.E, Rank.R4)
  val rookSquare = Square(File.D, Rank.R5)
  val position = GameContext.initial
    .withBoard(
      Board(
        Map(
          queenSquare -> Piece.WhiteQueen,
          Square(File.E, Rank.R5) -> Piece.BlackPawn,
          rookSquare -> Piece.BlackRook,
        ),
      ),
    )
    .withTurn(Color.White)
  val quietMove = Move(queenSquare, Square(File.E, Rank.R6))
  val rookCapture = Move(queenSquare, rookSquare, MoveType.Normal(true))
  val hints = Map(quietMove -> 500, rookCapture -> 100)

  // Root-ply score of `move` under the hint map above, with no TT move.
  def rootScore(move: Move): Int =
    MoveOrdering.score(position, move, None, ply = 0, rootHints = hints)

  rootScore(quietMove) should equal(500)
  rootScore(rookCapture) should equal(100)
  rootScore(rookCapture) should be < rootScore(quietMove)
// Below the root (ply 1) the hint map must not influence scoring: the capture still
// outscores the quiet move even though the hints say the opposite.
test("root hints ignored at ply > 0"):
  val queenSquare = Square(File.E, Rank.R4)
  val pawnSquare = Square(File.E, Rank.R5)
  val position = GameContext.initial
    .withBoard(Board(Map(queenSquare -> Piece.WhiteQueen, pawnSquare -> Piece.BlackPawn)))
    .withTurn(Color.White)
  val capture = Move(queenSquare, pawnSquare, MoveType.Normal(true))
  val quiet = Move(queenSquare, Square(File.D, Rank.R4))
  // Hints deliberately favour the quiet move.
  val misleadingHints = Map(quiet -> 99999, capture -> -99999)

  // Score of `move` one ply below the root, with no TT move.
  def scoreAtPlyOne(move: Move): Int =
    MoveOrdering.score(position, move, None, ply = 1, rootHints = misleadingHints)

  scoreAtPlyOne(capture) should be > scoreAtPlyOne(quiet)
// With a non-empty hint map at ply 0, a move that has no entry scores Int.MinValue / 2.
test("move absent from root hints gets Int.MinValue / 2 fallback"):
  val origin = Square(File.E, Rank.R4)
  val position = GameContext.initial
    .withBoard(Board(Map(origin -> Piece.WhiteQueen)))
    .withTurn(Color.White)
  val hintedMove = Move(origin, Square(File.E, Rank.R6))
  val unhintedMove = Move(origin, Square(File.E, Rank.R5))
  val onlyHint = Map(hintedMove -> 0)
  val fallbackScore = MoveOrdering.score(position, unhintedMove, None, ply = 0, rootHints = onlyHint)
  fallbackScore should equal(Int.MinValue / 2)
// At ply 0 `sort` orders moves by descending hint value, regardless of capture status.
test("sort uses root hints at ply 0 to reorder moves"):
  val queenSquare = Square(File.E, Rank.R4)
  val pawnSquare = Square(File.E, Rank.R5)
  val rookSquare = Square(File.D, Rank.R5)
  val position = GameContext.initial
    .withBoard(
      Board(
        Map(
          queenSquare -> Piece.WhiteQueen,
          pawnSquare -> Piece.BlackPawn,
          rookSquare -> Piece.BlackRook,
        ),
      ),
    )
    .withTurn(Color.White)
  val rookCapture = Move(queenSquare, rookSquare, MoveType.Normal(true))
  val pawnCapture = Move(queenSquare, pawnSquare, MoveType.Normal(true))
  val quiet = Move(queenSquare, Square(File.E, Rank.R6))
  val hints = Map(quiet -> 9999, pawnCapture -> 500, rookCapture -> 100)
  // Input is given in ascending-hint order so a no-op sort would fail.
  val sorted = MoveOrdering.sort(position, List(rookCapture, pawnCapture, quiet), None, ply = 0, rootHints = hints)
  // Three moves in, three moves out: comparing the whole list checks every position at once.
  sorted should equal(List(quiet, pawnCapture, rookCapture))