From 4f6cc2c0f848b9bd6c94db0ca378ea527ac234fc Mon Sep 17 00:00:00 2001 From: Janis Date: Fri, 17 Apr 2026 19:33:43 +0200 Subject: [PATCH] feat: Refactor alpha-beta search to use SearchState for repetition tracking --- .../nowchess/bot/logic/AlphaBetaSearch.scala | 57 ++++++++----------- .../bot/logic/TranspositionTable.scala | 7 +++ 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala b/modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala index bc0360c..49a8a2e 100644 --- a/modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala +++ b/modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala @@ -108,13 +108,13 @@ final class AlphaBetaSearch( rootHash: Long, excludedRootMoves: Set[Move], ): (Int, Option[Move]) = - val repetitions = Map(rootHash -> 1) + val state = SearchState(rootHash, Map(rootHash -> 1)) @scala.annotation.tailrec def loop(currentAlpha: Int, currentBeta: Int, window: Int, attempt: Int): (Int, Option[Move]) = - if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, rootHash, repetitions, excludedRootMoves) + if attempt >= 3 || attempt >= depth then search(context, depth, 0, -INF, INF, state, excludedRootMoves) else - val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, rootHash, repetitions, excludedRootMoves) + val (score, move) = search(context, depth, 0, currentAlpha, currentBeta, state, excludedRootMoves) if score > currentAlpha && score < currentBeta then (score, move) else if score <= currentAlpha then loop(score - window, currentBeta, math.min(window * 2, ASPIRATION_DELTA_MAX), attempt + 1) @@ -137,18 +137,14 @@ final class AlphaBetaSearch( depth: Int, ply: Int, beta: Int, - repetitions: Map[Long, Int], + state: SearchState, excludedRootMoves: Set[Move], ): Option[Int] = - val nullCtx = nullMoveContext(context) - val nullHash = ZobristHash.hash(nullCtx) - val nullRepetitions = repetitions.updatedWith(nullHash) { - case Some(v) => Some(v + 1) - case None => Some(1) - } + val nullCtx = nullMoveContext(context) + val nullState = state.advance(ZobristHash.hash(nullCtx)) val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R) weights.copyAccumulator(ply, ply + 1) - val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullHash, nullRepetitions, excludedRootMoves) + val (score, _) = search(nullCtx, reductionDepth, ply + 1, -beta, -beta + 1, nullState, excludedRootMoves) if -score >= beta then Some(beta) else None /** Negamax alpha-beta search returning (score, best move). */ @@ -158,15 +154,14 @@ final class AlphaBetaSearch( ply: Int, alpha: Int, beta: Int, - hash: Long, - repetitions: Map[Long, Int], + state: SearchState, excludedRootMoves: Set[Move], ): (Int, Option[Move]) = val count = nodeCount.incrementAndGet() - if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, hash), None) - else if repetitions.getOrElse(hash, 0) >= 3 then (weights.DRAW_SCORE, None) + if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then (weights.evaluateAccumulator(ply, context, state.hash), None) + else if state.repetitions.getOrElse(state.hash, 0) >= 3 then (weights.DRAW_SCORE, None) else - val ttCutoff = tt.probe(hash).filter(_.depth >= depth).flatMap { entry => + val ttCutoff = tt.probe(state.hash).filter(_.depth >= depth).flatMap { entry => entry.flag match case TTFlag.Exact => Some((entry.score, entry.bestMove)) case TTFlag.Lower => @@ -181,17 +176,17 @@ final class AlphaBetaSearch( if legalMoves.isEmpty then (if rules.isCheckmate(context) then -(weights.CHECKMATE_SCORE - ply) else weights.DRAW_SCORE, None) else if rules.isInsufficientMaterial(context) || rules.isFiftyMoveRule(context) then (weights.DRAW_SCORE, None) - else if depth == 0 then (quiescence(context, ply, alpha, beta, hash), None) + else if depth == 0 then (quiescence(context, ply, alpha, beta, state.hash), None) else val nullResult = Option .when(depth >= 3 && !rules.isCheck(context) && hasNonPawnMaterial(context)) { - tryNullMove(context, depth, ply, beta, repetitions, excludedRootMoves) + tryNullMove(context, depth, ply, beta, state, excludedRootMoves) } .flatten nullResult.map((_, None)).getOrElse { - val ttBest = tt.probe(hash).flatMap(_.bestMove) + val ttBest = tt.probe(state.hash).flatMap(_.bestMove) val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering) - searchSequential(context, depth, ply, alpha, beta, ordered, hash, repetitions, excludedRootMoves) + searchSequential(context, depth, ply, alpha, beta, ordered, state, excludedRootMoves) } } @@ -202,8 +197,7 @@ final class AlphaBetaSearch( alpha: Int, beta: Int, ordered: List[Move], - hash: Long, - repetitions: Map[Long, Int], + state: SearchState, excludedRootMoves: Set[Move], ): (Int, Option[Move]) = @scala.annotation.tailrec @@ -222,17 +216,14 @@ final class AlphaBetaSearch( move.moveType != MoveType.CastleKingside && move.moveType != MoveType.CastleQueenside val pruneByFutility = depth == 1 && isQuiet && moveNumber > 2 && - weights.evaluateAccumulator(ply, context, hash) + FUTILITY_MARGIN < alpha + weights.evaluateAccumulator(ply, context, state.hash) + FUTILITY_MARGIN < alpha if skipRootMove || pruneByFutility then loop(idx + 1, bestMove, bestScore, a, moveNumber + 1) else - val child = rules.applyMove(context)(move) - val childHash = ZobristHash.nextHash(context, hash, move, child) + val child = rules.applyMove(context)(move) + val childHash = ZobristHash.nextHash(context, state.hash, move, child) weights.pushAccumulator(ply + 1, move, context, child) - val childRepetitions = repetitions.updatedWith(childHash) { - case Some(v) => Some(v + 1) - case None => Some(1) - } + val childState = state.advance(childHash) val givesCheck = rules.isCheck(child) val extension = if givesCheck then CHECK_EXTENSION else 0 val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0 @@ -240,16 +231,16 @@ final class AlphaBetaSearch( val score = if reduction > 0 then val reducedDepth = math.max(0, depth - 1 - reduction + extension) - val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childHash, childRepetitions, excludedRootMoves) + val (reducedScore, _) = search(child, reducedDepth, ply + 1, -a - 1, -a, childState, excludedRootMoves) val s = -reducedScore if s > a then val fullDepth = math.max(0, depth - 1 + extension) - val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves) + val (fullScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves) -fullScore else s else val fullDepth = math.max(0, depth - 1 + extension) - val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childHash, childRepetitions, excludedRootMoves) + val (rawScore, _) = search(child, fullDepth, ply + 1, -beta, -a, childState, excludedRootMoves) -rawScore val newBestScore = math.max(bestScore, score) @@ -272,7 +263,7 @@ final class AlphaBetaSearch( if cutoff then TTFlag.Lower else if bestScore <= alpha then TTFlag.Upper else TTFlag.Exact - tt.store(TTEntry(hash, depth, bestScore, flag, bestMove)) + tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove)) (bestScore, bestMove) /** Quiescence search: only captures until position is quiet. */ diff --git a/modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala b/modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala index 27e7a13..b4221e1 100644 --- a/modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala +++ b/modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala @@ -2,6 +2,13 @@ package de.nowchess.bot.logic import de.nowchess.api.move.Move +final case class SearchState(hash: Long, repetitions: Map[Long, Int]): + def advance(nextHash: Long): SearchState = + SearchState(nextHash, repetitions.updatedWith(nextHash) { + case Some(v) => Some(v + 1) + case None => Some(1) + }) + enum TTFlag: case Exact // Score is exact case Lower // Score is a lower bound