feat(ncs-110): feed NNUE root-move scores into search move ordering (#83)
Build & Test (NowChessSystems) TeamCity build finished
Build & Test (NowChessSystems) TeamCity build finished
Pre-evaluated NNUE scores from NNUEBot.batchEvaluateRoot are now passed as root hints into AlphaBetaSearch, improving move ordering at ply 0 before the transposition table (TT) is populated. A TT best move, when present, still takes precedence over the hints, and a root move missing from a non-empty hint map is ordered last. Hints are threaded immutably through SearchParams to satisfy the no-var constraint.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Janis Eccarius <eccariusjanis@gmail.com>
Reviewed-on: #83
This commit was merged in pull request #83.
This commit is contained in:
@@ -28,7 +28,7 @@ object NNUEBot:
|
||||
else
|
||||
val scored = batchEvaluateRoot(rules, context, moves)
|
||||
val bestMove = scored.maxBy(_._2)._1
|
||||
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove))
|
||||
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves, scored.toMap).orElse(Some(bestMove))
|
||||
}
|
||||
|
||||
private def batchEvaluateRoot(rules: RuleSet, context: GameContext, moves: List[Move]): List[(Move, Int)] =
|
||||
|
||||
@@ -32,6 +32,8 @@ final class AlphaBetaSearch(
|
||||
private val nodeCount = AtomicInteger(0)
|
||||
private val ordering = MoveOrdering.OrderingContext()
|
||||
|
||||
def lastNodeCount: Int = nodeCount.get()
|
||||
|
||||
private final case class QuiescenceNode(
|
||||
context: GameContext,
|
||||
ply: Int,
|
||||
@@ -47,6 +49,17 @@ final class AlphaBetaSearch(
|
||||
bestMove(context, maxDepth, Set.empty)
|
||||
|
||||
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] =
|
||||
doDepthSearch(context, maxDepth, excludedRootMoves, Map.empty)
|
||||
|
||||
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move], hints: Map[Move, Int]): Option[Move] =
|
||||
doDepthSearch(context, maxDepth, excludedRootMoves, hints)
|
||||
|
||||
private def doDepthSearch(
|
||||
context: GameContext,
|
||||
maxDepth: Int,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
weights.initAccumulator(context)
|
||||
@@ -66,6 +79,7 @@ final class AlphaBetaSearch(
|
||||
ASPIRATION_DELTA,
|
||||
rootHash,
|
||||
excludedRootMoves,
|
||||
hints,
|
||||
)
|
||||
(move.orElse(bestSoFar), score)
|
||||
}
|
||||
@@ -78,6 +92,22 @@ final class AlphaBetaSearch(
|
||||
bestMoveWithTime(context, timeBudgetMs, Set.empty)
|
||||
|
||||
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] =
|
||||
doTimedSearch(context, timeBudgetMs, excludedRootMoves, Map.empty)
|
||||
|
||||
def bestMoveWithTime(
|
||||
context: GameContext,
|
||||
timeBudgetMs: Long,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): Option[Move] =
|
||||
doTimedSearch(context, timeBudgetMs, excludedRootMoves, hints)
|
||||
|
||||
private def doTimedSearch(
|
||||
context: GameContext,
|
||||
timeBudgetMs: Long,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
weights.initAccumulator(context)
|
||||
@@ -100,6 +130,7 @@ final class AlphaBetaSearch(
|
||||
ASPIRATION_DELTA,
|
||||
rootHash,
|
||||
excludedRootMoves,
|
||||
hints,
|
||||
)
|
||||
loop(move.orElse(bestSoFar), score, depth + 1, depth)
|
||||
|
||||
@@ -124,14 +155,17 @@ final class AlphaBetaSearch(
|
||||
initialWindow: Int,
|
||||
rootHash: Long,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): (Int, Option[Move]) =
|
||||
val state = SearchState(rootHash, Map(rootHash -> 1))
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) =
|
||||
if attempt >= 3 || attempt >= depth then search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves)
|
||||
if attempt >= 3 || attempt >= depth then
|
||||
search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves, hints)
|
||||
else
|
||||
val (score, move) = search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves)
|
||||
val (score, move) =
|
||||
search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves, hints)
|
||||
if score > currentAlpha && score < currentBeta then (score, move)
|
||||
else if score <= currentAlpha then
|
||||
loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
@@ -156,12 +190,14 @@ final class AlphaBetaSearch(
|
||||
beta: Int,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): Option[Int] =
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||
weights.copyAccumulator(ply, ply + 1)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
|
||||
val (score, _) =
|
||||
search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves, hints)
|
||||
if -score >= beta then Some(beta) else None
|
||||
|
||||
/** Negamax alpha-beta search returning (score, best move). */
|
||||
@@ -172,8 +208,9 @@ final class AlphaBetaSearch(
|
||||
window: Window,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
hints: Map[Move, Int],
|
||||
): (Int, Option[Move]) =
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves, hints)
|
||||
searchNode(params)
|
||||
|
||||
private def searchNode(params: SearchParams): (Int, Option[Move]) =
|
||||
@@ -235,13 +272,14 @@ final class AlphaBetaSearch(
|
||||
params.window.beta,
|
||||
params.state,
|
||||
params.excludedRootMoves,
|
||||
params.rootHints,
|
||||
),
|
||||
)
|
||||
.flatten
|
||||
|
||||
nullResult.map((_, None)).getOrElse {
|
||||
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
|
||||
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering)
|
||||
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering, params.rootHints)
|
||||
searchSequential(
|
||||
params.context,
|
||||
params.depth,
|
||||
@@ -250,6 +288,7 @@ final class AlphaBetaSearch(
|
||||
ordered,
|
||||
params.state,
|
||||
params.excludedRootMoves,
|
||||
params.rootHints,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -280,6 +319,7 @@ final class AlphaBetaSearch(
|
||||
Window(-a - 1, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
params.rootHints,
|
||||
)
|
||||
val s = -rs
|
||||
if s > a then
|
||||
@@ -290,6 +330,7 @@ final class AlphaBetaSearch(
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
params.rootHints,
|
||||
)
|
||||
-fs
|
||||
else s
|
||||
@@ -301,6 +342,7 @@ final class AlphaBetaSearch(
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
params.rootHints,
|
||||
)
|
||||
-rs
|
||||
|
||||
@@ -364,8 +406,9 @@ final class AlphaBetaSearch(
|
||||
ordered: List[Move],
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
rootHints: Map[Move, Int] = Map.empty,
|
||||
): (Int, Option[Move]) =
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves, rootHints)
|
||||
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
|
||||
val flag =
|
||||
if cutoff then TTFlag.Lower
|
||||
|
||||
@@ -38,8 +38,10 @@ object MoveOrdering:
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext(),
|
||||
rootHints: Map[Move, Int] = Map.empty,
|
||||
): Int =
|
||||
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue
|
||||
else if ply == 0 && rootHints.nonEmpty then rootHints.getOrElse(move, Int.MinValue / 2)
|
||||
else
|
||||
move.moveType match
|
||||
case MoveType.Promotion(PromotionPiece.Queen) =>
|
||||
@@ -56,8 +58,9 @@ object MoveOrdering:
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext(),
|
||||
rootHints: Map[Move, Int] = Map.empty,
|
||||
): List[Move] =
|
||||
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering))
|
||||
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering, rootHints))
|
||||
|
||||
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
|
||||
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
|
||||
|
||||
@@ -14,6 +14,7 @@ final case class SearchParams(
|
||||
window: Window,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
rootHints: Map[Move, Int] = Map.empty,
|
||||
)
|
||||
|
||||
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
|
||||
|
||||
@@ -312,6 +312,24 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
val search = AlphaBetaSearch(qRules, weights = ZeroEval)
|
||||
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
|
||||
|
||||
test("bestMove with root hints returns a valid move without regression"):
|
||||
val context = GameContext.initial
|
||||
val legalMoves = DefaultRules.allLegalMoves(context)
|
||||
val hints = legalMoves.zipWithIndex.map { case (m, i) => m -> (legalMoves.length - i) }.toMap
|
||||
val withHints = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
|
||||
.bestMove(context, maxDepth = 2, Set.empty, hints)
|
||||
withHints should not be None
|
||||
legalMoves should contain(withHints.get)
|
||||
|
||||
test("bestMoveWithTime with root hints returns a valid move without regression"):
|
||||
val context = GameContext.initial
|
||||
val legalMoves = DefaultRules.allLegalMoves(context)
|
||||
val hints = legalMoves.zipWithIndex.map { case (m, i) => m -> (legalMoves.length - i) }.toMap
|
||||
val withHints = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
|
||||
.bestMoveWithTime(context, 500L, Set.empty, hints)
|
||||
withHints should not be None
|
||||
legalMoves should contain(withHints.get)
|
||||
|
||||
test("quiescence depth-limit in-check branch is exercised"):
|
||||
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
|
||||
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
|
||||
|
||||
@@ -217,3 +217,60 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
|
||||
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
|
||||
|
||||
MoveOrdering.score(context, castle, None) should be(0)
|
||||
|
||||
test("root hints override capture heuristics at ply 0"):
|
||||
val board = Board(
|
||||
Map(
|
||||
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
|
||||
Square(File.E, Rank.R5) -> Piece.BlackPawn,
|
||||
Square(File.D, Rank.R5) -> Piece.BlackRook,
|
||||
),
|
||||
)
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val quietMove = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R6))
|
||||
val rookCapture = Move(Square(File.E, Rank.R4), Square(File.D, Rank.R5), MoveType.Normal(true))
|
||||
val hints = Map(quietMove -> 500, rookCapture -> 100)
|
||||
|
||||
MoveOrdering.score(context, quietMove, None, ply = 0, rootHints = hints) should equal(500)
|
||||
MoveOrdering.score(context, rookCapture, None, ply = 0, rootHints = hints) should equal(100)
|
||||
MoveOrdering.score(context, rookCapture, None, ply = 0, rootHints = hints) should be <
|
||||
MoveOrdering.score(context, quietMove, None, ply = 0, rootHints = hints)
|
||||
|
||||
test("root hints ignored at ply > 0"):
|
||||
val board = Board(Map(Square(File.E, Rank.R4) -> Piece.WhiteQueen, Square(File.E, Rank.R5) -> Piece.BlackPawn))
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val capture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
|
||||
val quiet = Move(Square(File.E, Rank.R4), Square(File.D, Rank.R4))
|
||||
val hints = Map(quiet -> 99999, capture -> -99999)
|
||||
|
||||
val captureScore = MoveOrdering.score(context, capture, None, ply = 1, rootHints = hints)
|
||||
val quietScore = MoveOrdering.score(context, quiet, None, ply = 1, rootHints = hints)
|
||||
captureScore should be > quietScore
|
||||
|
||||
test("move absent from root hints gets Int.MinValue / 2 fallback"):
|
||||
val board = Board(Map(Square(File.E, Rank.R4) -> Piece.WhiteQueen))
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val move1 = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R6))
|
||||
val move2 = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5))
|
||||
val hints = Map(move1 -> 0)
|
||||
|
||||
MoveOrdering.score(context, move2, None, ply = 0, rootHints = hints) should equal(Int.MinValue / 2)
|
||||
|
||||
test("sort uses root hints at ply 0 to reorder moves"):
|
||||
val board = Board(
|
||||
Map(
|
||||
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
|
||||
Square(File.E, Rank.R5) -> Piece.BlackPawn,
|
||||
Square(File.D, Rank.R5) -> Piece.BlackRook,
|
||||
),
|
||||
)
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val rookCapture = Move(Square(File.E, Rank.R4), Square(File.D, Rank.R5), MoveType.Normal(true))
|
||||
val pawnCapture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
|
||||
val quiet = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R6))
|
||||
val hints = Map(quiet -> 9999, pawnCapture -> 500, rookCapture -> 100)
|
||||
|
||||
val sorted = MoveOrdering.sort(context, List(rookCapture, pawnCapture, quiet), None, ply = 0, rootHints = hints)
|
||||
sorted.head should equal(quiet)
|
||||
sorted(1) should equal(pawnCapture)
|
||||
sorted(2) should equal(rookCapture)
|
||||
|
||||
Reference in New Issue
Block a user