feat: Refactor alpha-beta search to use SearchState for repetition tracking

This commit is contained in:
2026-04-17 19:33:43 +02:00
parent e2f980df28
commit 4f6cc2c0f8
2 changed files with 31 additions and 33 deletions
@@ -108,13 +108,13 @@ final class AlphaBetaSearch(
rootHash: Long,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val repetitions = Map(rootHash -> 1)
val state = SearchState(rootHash, Map(rootHash -> 1))
@scala.annotation.tailrec
def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, rootHash, repetitions, excludedRootMoves)
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, state, excludedRootMoves)
else
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions, excludedRootMoves)
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, state, excludedRootMoves)
if score > currentAlpha && score < currentBeta then (score, move)
else if score <= currentAlpha then
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
@@ -137,18 +137,14 @@ final class AlphaBetaSearch(
depth: Int,
ply: Int,
beta: Int,
repetitions: Map[Long, Int],
state: SearchState,
excludedRootMoves: Set[Move],
): Option[Int] =
val nullCtx = nullMoveContext(context)
val nullHash = ZobristHash.hash(nullCtx)
val nullRepetitions = repetitions.updatedWith(nullHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
}
val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions, excludedRootMoves)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullState, excludedRootMoves)
if -score >= beta then Some(beta) else None
/** Negamax alpha-beta search returning (score, best move). */
@@ -158,15 +154,14 @@ final class AlphaBetaSearch(
ply: Int,
alpha: Int,
beta: Int,
hash: Long,
repetitions: Map[Long, Int],
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val count = nodeCount.incrementAndGet()
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None)
else if repetitions.getOrElse(hash, 0) >= 3 then (weights.DRAW_SCORE, None)
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, state.hash), None)
else if state.repetitions.getOrElse(state.hash, 0) >= 3 then (weights.DRAW_SCORE, None)
else
val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
val ttCutoff = tt.probe(state.hash).filter(_.depth >= depth).flatMap { entry =>
entry.flag match
case TTFlag.Exact => Some((entry.score, entry.bestMove))
case TTFlag.Lower =>
@@ -181,17 +176,17 @@ final class AlphaBetaSearch(
if legalMoves.isEmpty then
(if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then (weights.DRAW_SCORE, None)
else if depth == 0 then (quiescence(context, ply, alpha, beta, hash), None)
else if depth == 0 then (quiescence(context, ply, alpha, beta, state.hash), None)
else
val nullResult = Option
.when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
tryNullMove(context, depth, ply, beta, repetitions, excludedRootMoves)
tryNullMove(context, depth, ply, beta, state, excludedRootMoves)
}
.flatten
nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(hash).flatMap(_.bestMove)
val ttBest = tt.probe(state.hash).flatMap(_.bestMove)
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions, excludedRootMoves)
searchSequential(context, depth, ply, alpha, beta, ordered, state, excludedRootMoves)
}
}
@@ -202,8 +197,7 @@ final class AlphaBetaSearch(
alpha: Int,
beta: Int,
ordered: List[Move],
hash: Long,
repetitions: Map[Long, Int],
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
@scala.annotation.tailrec
@@ -222,17 +216,14 @@ final class AlphaBetaSearch(
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
val pruneByFutility = depth == 1 && isQuiet && moveNumber > 2 &&
weights.evaluateAccumulator(ply, context, hash) + FUTILITY_MARGIN < alpha
weights.evaluateAccumulator(ply, context, state.hash) + FUTILITY_MARGIN < alpha
if skipRootMove || pruneByFutility then loop(idx + 1, bestMove, bestScore, a, moveNumber + 1)
else
val child = rules.applyMove(context)(move)
val childHash = ZobristHash.nextHash(context, hash, move, child)
val childHash = ZobristHash.nextHash(context, state.hash, move, child)
weights.pushAccumulator(ply + 1, move, context, child)
val childRepetitions = repetitions.updatedWith(childHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
}
val childState = state.advance(childHash)
val givesCheck = rules.isCheck(child)
val extension = if givesCheck then CHECK_EXTENSION else 0
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
@@ -240,16 +231,16 @@ final class AlphaBetaSearch(
val score =
if reduction > 0 then
val reducedDepth = math.max(0, depth - 1 - reduction + extension)
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions, excludedRootMoves)
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childState, excludedRootMoves)
val s = -reducedScore
if s > a then
val fullDepth = math.max(0, depth - 1 + extension)
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves)
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves)
-fullScore
else s
else
val fullDepth = math.max(0, depth - 1 + extension)
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves)
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves)
-rawScore
val newBestScore = math.max(bestScore, score)
@@ -272,7 +263,7 @@ final class AlphaBetaSearch(
if cutoff then TTFlag.Lower
else if bestScore <= alpha then TTFlag.Upper
else TTFlag.Exact
tt.store(TTEntry(hash, depth, bestScore, flag, bestMove))
tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove))
(bestScore, bestMove)
/** Quiescence search: only captures until position is quiet. */
@@ -2,6 +2,13 @@ package de.nowchess.bot.logic
import de.nowchess.api.move.Move
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
def advance(nextHash: Long): SearchState =
SearchState(nextHash, repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
})
enum TTFlag:
case Exact // Score is exact
case Lower // Score is a lower bound