feat: Refactor alpha-beta search to use SearchState for repetition tracking
This commit is contained in:
@@ -108,13 +108,13 @@ final class AlphaBetaSearch(
|
||||
rootHash: Long,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val repetitions = Map(rootHash -> 1)
|
||||
val state = SearchState(rootHash, Map(rootHash -> 1))
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) =
|
||||
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, rootHash, repetitions, excludedRootMoves)
|
||||
if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, state, excludedRootMoves)
|
||||
else
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions, excludedRootMoves)
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, state, excludedRootMoves)
|
||||
if score > currentAlpha && score < currentBeta then (score, move)
|
||||
else if score <= currentAlpha then
|
||||
loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
@@ -137,18 +137,14 @@ final class AlphaBetaSearch(
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
beta: Int,
|
||||
repetitions: Map[Long, Int],
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): Option[Int] =
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullHash = ZobristHash.hash(nullCtx)
|
||||
val nullRepetitions = repetitions.updatedWith(nullHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
}
|
||||
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||
weights.copyAccumulator(ply, ply + 1)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions, excludedRootMoves)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullState, excludedRootMoves)
|
||||
if -score >= beta then Some(beta) else None
|
||||
|
||||
/** Negamax alpha-beta search returning (score, best move). */
|
||||
@@ -158,15 +154,14 @@ final class AlphaBetaSearch(
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int],
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val count = nodeCount.incrementAndGet()
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None)
|
||||
else if repetitions.getOrElse(hash, 0) >= 3 then (weights.DRAW_SCORE, None)
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, state.hash), None)
|
||||
else if state.repetitions.getOrElse(state.hash, 0) >= 3 then (weights.DRAW_SCORE, None)
|
||||
else
|
||||
val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry =>
|
||||
val ttCutoff = tt.probe(state.hash).filter(_.depth >= depth).flatMap { entry =>
|
||||
entry.flag match
|
||||
case TTFlag.Exact => Some((entry.score, entry.bestMove))
|
||||
case TTFlag.Lower =>
|
||||
@@ -181,17 +176,17 @@ final class AlphaBetaSearch(
|
||||
if legalMoves.isEmpty then
|
||||
(if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None)
|
||||
else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then (weights.DRAW_SCORE, None)
|
||||
else if depth == 0 then (quiescence(context, ply, alpha, beta, hash), None)
|
||||
else if depth == 0 then (quiescence(context, ply, alpha, beta, state.hash), None)
|
||||
else
|
||||
val nullResult = Option
|
||||
.when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) {
|
||||
tryNullMove(context, depth, ply, beta, repetitions, excludedRootMoves)
|
||||
tryNullMove(context, depth, ply, beta, state, excludedRootMoves)
|
||||
}
|
||||
.flatten
|
||||
nullResult.map((_, None)).getOrElse {
|
||||
val ttBest = tt.probe(hash).flatMap(_.bestMove)
|
||||
val ttBest = tt.probe(state.hash).flatMap(_.bestMove)
|
||||
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
|
||||
searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions, excludedRootMoves)
|
||||
searchSequential(context, depth, ply, alpha, beta, ordered, state, excludedRootMoves)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,8 +197,7 @@ final class AlphaBetaSearch(
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
ordered: List[Move],
|
||||
hash: Long,
|
||||
repetitions: Map[Long, Int],
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
@scala.annotation.tailrec
|
||||
@@ -222,17 +216,14 @@ final class AlphaBetaSearch(
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
val pruneByFutility = depth == 1 && isQuiet && moveNumber > 2 &&
|
||||
weights.evaluateAccumulator(ply, context, hash) + FUTILITY_MARGIN < alpha
|
||||
weights.evaluateAccumulator(ply, context, state.hash) + FUTILITY_MARGIN < alpha
|
||||
|
||||
if skipRootMove || pruneByFutility then loop(idx + 1, bestMove, bestScore, a, moveNumber + 1)
|
||||
else
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, hash, move, child)
|
||||
val childHash = ZobristHash.nextHash(context, state.hash, move, child)
|
||||
weights.pushAccumulator(ply + 1, move, context, child)
|
||||
val childRepetitions = repetitions.updatedWith(childHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
}
|
||||
val childState = state.advance(childHash)
|
||||
val givesCheck = rules.isCheck(child)
|
||||
val extension = if givesCheck then CHECK_EXTENSION else 0
|
||||
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
|
||||
@@ -240,16 +231,16 @@ final class AlphaBetaSearch(
|
||||
val score =
|
||||
if reduction > 0 then
|
||||
val reducedDepth = math.max(0, depth - 1 - reduction + extension)
|
||||
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions, excludedRootMoves)
|
||||
val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childState, excludedRootMoves)
|
||||
val s = -reducedScore
|
||||
if s > a then
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves)
|
||||
val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves)
|
||||
-fullScore
|
||||
else s
|
||||
else
|
||||
val fullDepth = math.max(0, depth - 1 + extension)
|
||||
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves)
|
||||
val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves)
|
||||
-rawScore
|
||||
|
||||
val newBestScore = math.max(bestScore, score)
|
||||
@@ -272,7 +263,7 @@ final class AlphaBetaSearch(
|
||||
if cutoff then TTFlag.Lower
|
||||
else if bestScore <= alpha then TTFlag.Upper
|
||||
else TTFlag.Exact
|
||||
tt.store(TTEntry(hash, depth, bestScore, flag, bestMove))
|
||||
tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove))
|
||||
(bestScore, bestMove)
|
||||
|
||||
/** Quiescence search: only captures until position is quiet. */
|
||||
|
||||
@@ -2,6 +2,13 @@ package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
/** Immutable pairing of the current position hash with per-hash occurrence counts.
  *
  * The search threads one value of this type through its recursion instead of passing the
  * hash and the repetition map as two separate arguments.
  *
  * @param hash        hash of the position this state refers to
  * @param repetitions how many times each position hash has been recorded so far
  */
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):

  /** Returns the state for the position identified by `nextHash`.
    *
    * The result's `hash` is `nextHash`, and the count stored for `nextHash` in `repetitions`
    * is one higher than before (1 if the hash was not recorded yet). The receiver is not
    * modified.
    *
    * @param nextHash hash of the position being entered
    * @return a new state pointing at `nextHash` with its occurrence count incremented
    */
  def advance(nextHash: Long): SearchState =
    // A hash that has not been recorded yet counts as zero prior occurrences.
    val seenBefore = repetitions.getOrElse(nextHash, 0)
    SearchState(nextHash, repetitions.updated(nextHash, seenBefore + 1))
|
||||
|
||||
enum TTFlag:
|
||||
case Exact // Score is exact
|
||||
case Lower // Score is a lower bound
|
||||
|
||||
Reference in New Issue
Block a user