feat: enhance AlphaBetaSearch with parallel processing and improved move ordering
This commit is contained in:
@@ -5,55 +5,72 @@ import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import java.util.concurrent.ForkJoinPool
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
final class AlphaBetaSearch(
|
||||
rules: RuleSet = DefaultRules,
|
||||
tt: TranspositionTable = TranspositionTable()
|
||||
tt: TranspositionTable = TranspositionTable(),
|
||||
numThreads: Int = Runtime.getRuntime().availableProcessors()
|
||||
):
|
||||
|
||||
private val INF = Int.MaxValue / 2
|
||||
private val MAX_QUIESCENCE_PLY = 64
|
||||
private val NULL_MOVE_R = 2
|
||||
private val ASPIRATION_DELTA = 50
|
||||
private val ASPIRATION_DELTA_MAX = 150
|
||||
private val TIME_CHECK_FREQUENCY = 1000
|
||||
private val FUTILITY_MARGIN = 100
|
||||
private val PARALLEL_DEPTH_THRESHOLD = 4
|
||||
|
||||
@volatile private var timeStartMs = 0L
|
||||
@volatile private var timeLimitMs = 0L
|
||||
@volatile private var nodeCount = 0
|
||||
private val ordering = MoveOrdering.OrderingContext()
|
||||
private val threadPool = new ForkJoinPool(numThreads)
|
||||
implicit private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(threadPool)
|
||||
|
||||
/** Return the best move for the side to move, searching to maxDepth plies.
|
||||
* Uses iterative deepening with aspiration windows. */
|
||||
def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
var bestSoFar: Option[Move] = None
|
||||
var prevScore = 0
|
||||
var aspWindow = ASPIRATION_DELTA
|
||||
for depth <- 1 to maxDepth do
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF)
|
||||
else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta)
|
||||
else (prevScore - aspWindow, prevScore + aspWindow)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
|
||||
prevScore = score
|
||||
move.foreach(m => bestSoFar = Some(m))
|
||||
aspWindow = ASPIRATION_DELTA
|
||||
bestSoFar
|
||||
|
||||
/** Return the best move for the side to move within a time budget (ms).
|
||||
* Uses iterative deepening, stopping when time runs out. */
|
||||
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
timeStartMs = System.currentTimeMillis()
|
||||
timeLimitMs = timeBudgetMs
|
||||
nodeCount = 0
|
||||
var bestSoFar: Option[Move] = None
|
||||
var prevScore = 0
|
||||
var depth = 1
|
||||
var aspWindow = ASPIRATION_DELTA
|
||||
|
||||
while !isOutOfTime() do
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF)
|
||||
else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta)
|
||||
else (prevScore - aspWindow, prevScore + aspWindow)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
|
||||
prevScore = score
|
||||
move.foreach(m => bestSoFar = Some(m))
|
||||
aspWindow = ASPIRATION_DELTA
|
||||
depth += 1
|
||||
bestSoFar
|
||||
|
||||
@@ -64,12 +81,27 @@ final class AlphaBetaSearch(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
alpha: Int,
|
||||
beta: Int
|
||||
beta: Int,
|
||||
initialWindow: Int
|
||||
): (Int, Option[Move]) =
|
||||
val (score, move) = search(context, depth, 0, alpha, beta)
|
||||
if score <= alpha then search(context, depth, 0, -INF, beta)
|
||||
else if score >= beta then search(context, depth, 0, alpha, INF)
|
||||
else (score, move)
|
||||
var currentAlpha = alpha
|
||||
var currentBeta = beta
|
||||
var window = initialWindow
|
||||
var attempt = 0
|
||||
|
||||
while attempt < 3 && attempt < depth do
|
||||
val (score, move) = search(context, depth, 0, currentAlpha, currentBeta)
|
||||
if score > currentAlpha && score < currentBeta then
|
||||
return (score, move)
|
||||
if score <= currentAlpha then
|
||||
currentAlpha = score - window
|
||||
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
|
||||
if score >= currentBeta then
|
||||
currentBeta = score + window
|
||||
window = math.min(window * 2, ASPIRATION_DELTA_MAX)
|
||||
attempt += 1
|
||||
|
||||
search(context, depth, 0, -INF, INF)
|
||||
|
||||
private def hasNonPawnMaterial(context: GameContext): Boolean =
|
||||
context.board.pieces.values.exists { piece =>
|
||||
@@ -149,8 +181,23 @@ final class AlphaBetaSearch(
|
||||
val ttBest = tt.probe(hash).flatMap(_.bestMove)
|
||||
|
||||
// Order moves
|
||||
val ordered = MoveOrdering.sort(context, legalMoves, ttBest)
|
||||
val ordered = MoveOrdering.sort(context, legalMoves, ttBest, ply, ordering)
|
||||
|
||||
// Use parallel search if depth >= threshold and not root
|
||||
if depth >= PARALLEL_DEPTH_THRESHOLD && ply > 0 then
|
||||
return searchParallel(context, depth, ply, alpha, beta, ordered, hash)
|
||||
else
|
||||
return searchSequential(context, depth, ply, alpha, beta, ordered, hash)
|
||||
|
||||
private def searchSequential(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
ordered: List[Move],
|
||||
hash: Long
|
||||
): (Int, Option[Move]) =
|
||||
var bestMove: Option[Move] = None
|
||||
var bestScore = -INF
|
||||
var a = alpha
|
||||
@@ -162,6 +209,12 @@ final class AlphaBetaSearch(
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
|
||||
// Futility pruning at frontier nodes: if static eval + margin is still below alpha, skip quiet moves
|
||||
if depth == 1 && isQuiet && moveNumber > 2 then
|
||||
val staticEval = Evaluation.evaluate(context)
|
||||
if staticEval + FUTILITY_MARGIN < alpha then
|
||||
moveNumber += 1
|
||||
|
||||
val child = rules.applyMove(context)(move)
|
||||
val reduction = if moveNumber > 4 && depth >= 3 && isQuiet then 1 else 0
|
||||
|
||||
@@ -182,7 +235,16 @@ final class AlphaBetaSearch(
|
||||
|
||||
a = math.max(a, score)
|
||||
|
||||
// Track history heuristic
|
||||
if isQuiet then
|
||||
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||
ordering.addHistory(fromIdx, toIdx, depth * depth)
|
||||
|
||||
if a >= beta then
|
||||
// Record killer move
|
||||
if isQuiet then
|
||||
ordering.addKillerMove(ply, move)
|
||||
tt.store(TTEntry(hash, depth, bestScore, TTFlag.Lower, bestMove))
|
||||
return (bestScore, bestMove)
|
||||
|
||||
@@ -193,6 +255,71 @@ final class AlphaBetaSearch(
|
||||
tt.store(TTEntry(hash, depth, bestScore, flag, bestMove))
|
||||
(bestScore, bestMove)
|
||||
|
||||
private def searchParallel(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
ordered: List[Move],
|
||||
hash: Long
|
||||
): (Int, Option[Move]) =
|
||||
val results = new AtomicReference[(Int, Option[Move], Boolean)]((-INF, None, false))
|
||||
val windowRef = new AtomicReference((alpha, beta))
|
||||
|
||||
val moveScores = ordered.zipWithIndex.map { case (move, moveIdx) =>
|
||||
Future {
|
||||
val isQuiet = !isCapture(move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
val child = rules.applyMove(context)(move)
|
||||
val reduction = if moveIdx > 4 && depth >= 3 && isQuiet then 1 else 0
|
||||
|
||||
val score = if reduction > 0 then
|
||||
val (reducedScore, _) = search(child, depth - 1 - reduction, ply + 1, -beta, -alpha)
|
||||
val s = -reducedScore
|
||||
if s > alpha then
|
||||
val (fullScore, _) = search(child, depth - 1, ply + 1, -beta, -alpha)
|
||||
-fullScore
|
||||
else s
|
||||
else
|
||||
val (rawScore, _) = search(child, depth - 1, ply + 1, -beta, -alpha)
|
||||
-rawScore
|
||||
|
||||
// Track history heuristic
|
||||
if isQuiet then
|
||||
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||
ordering.addHistory(fromIdx, toIdx, depth * depth)
|
||||
|
||||
(move, score, isQuiet)
|
||||
}
|
||||
}
|
||||
|
||||
var bestMove: Option[Move] = None
|
||||
var bestScore = -INF
|
||||
var cutoffFound = false
|
||||
|
||||
for future <- moveScores do
|
||||
if !cutoffFound then
|
||||
val (move, score, isQuiet) = scala.concurrent.Await.result(future, scala.concurrent.duration.Duration.Inf)
|
||||
|
||||
if score > bestScore then
|
||||
bestScore = score
|
||||
bestMove = Some(move)
|
||||
|
||||
if bestScore >= beta then
|
||||
if isQuiet then
|
||||
ordering.addKillerMove(ply, move)
|
||||
cutoffFound = true
|
||||
|
||||
// No cutoff: determine flag
|
||||
val flag =
|
||||
if bestScore <= alpha then TTFlag.Upper
|
||||
else TTFlag.Exact
|
||||
tt.store(TTEntry(hash, depth, bestScore, flag, bestMove))
|
||||
(bestScore, bestMove)
|
||||
|
||||
/** Quiescence search: only captures until position is quiet. */
|
||||
private def quiescence(
|
||||
context: GameContext,
|
||||
|
||||
@@ -160,6 +160,7 @@ object Evaluation:
|
||||
private val maxPhase = 24 // 4*4 + 4*2 + 4*1 + 4*1
|
||||
|
||||
private val passedPawnBonus: Array[Int] = Array(0, 5, 10, 20, 35, 60, 100, 0)
|
||||
private val egPassedPawnBonus: Array[Int] = Array(0, 20, 40, 80, 150, 250, 400, 0)
|
||||
|
||||
// Pawn structure penalties
|
||||
private val doubledMg = -10
|
||||
@@ -186,12 +187,14 @@ object Evaluation:
|
||||
* Positive = good for context.turn. */
|
||||
def evaluate(context: GameContext): Int =
|
||||
val phase = gamePhase(context.board)
|
||||
val isEg = isEndgame(phase)
|
||||
val material = materialAndPositional(context, phase)
|
||||
val structure = pawnStructure(context, phase)
|
||||
val mobility = mobilityScore(context, phase)
|
||||
val rookBishop = rookAndBishopBonuses(context, phase)
|
||||
val bonuses = positionalBonuses(context, phase)
|
||||
material + structure + mobility + rookBishop + bonuses + TEMPO_BONUS
|
||||
val bonuses = positionalBonuses(context, phase, isEg)
|
||||
val egBonuses = if isEg then endgameBonus(context) else 0
|
||||
material + structure + mobility + rookBishop + bonuses + egBonuses + TEMPO_BONUS
|
||||
|
||||
private def gamePhase(board: de.nowchess.api.board.Board): Int =
|
||||
val phase = board.pieces.values.foldLeft(0) { (acc, piece) =>
|
||||
@@ -199,6 +202,9 @@ object Evaluation:
|
||||
}
|
||||
math.min(phase, maxPhase)
|
||||
|
||||
private def isEndgame(phase: Int): Boolean =
|
||||
phase < 8 // Significantly reduced material indicates endgame
|
||||
|
||||
private def taper(mg: Int, eg: Int, phase: Int): Int =
|
||||
(mg * phase + eg * (maxPhase - phase)) / maxPhase
|
||||
|
||||
@@ -272,13 +278,14 @@ object Evaluation:
|
||||
|
||||
taper(mg - enemyMg, eg - enemyEg, phase)
|
||||
|
||||
private def positionalBonuses(context: GameContext, phase: Int): Int =
|
||||
private def positionalBonuses(context: GameContext, phase: Int, isEg: Boolean = false): Int =
|
||||
var score = 0
|
||||
for (sq, piece) <- context.board.pieces do
|
||||
val bonus = piece.pieceType match
|
||||
case PieceType.Pawn =>
|
||||
if isPassedPawn(context.board, sq, piece.color) then
|
||||
passedPawnBonus(sq.rank.ordinal)
|
||||
if isEg then egPassedPawnBonus(sq.rank.ordinal)
|
||||
else passedPawnBonus(sq.rank.ordinal)
|
||||
else 0
|
||||
case PieceType.Rook =>
|
||||
rookOpenFileBonus(context.board, sq, piece.color)
|
||||
@@ -398,3 +405,53 @@ object Evaluation:
|
||||
eg -= rookOn7thEg
|
||||
|
||||
taper(mg, eg, phase)
|
||||
|
||||
private def endgameBonus(context: GameContext): Int =
|
||||
val friendlyKing = context.board.pieces.find((_, p) => p.color == context.turn && p.pieceType == PieceType.King)
|
||||
val enemyKing = context.board.pieces.find((_, p) => p.color != context.turn && p.pieceType == PieceType.King)
|
||||
|
||||
var bonus = 0
|
||||
|
||||
// King centralization: closer to center is better in endgame
|
||||
friendlyKing.foreach { (friendlyKSq, _) =>
|
||||
val centerDist = kingCentralizationDistance(friendlyKSq)
|
||||
bonus += (8 - centerDist) * 15
|
||||
}
|
||||
|
||||
enemyKing.foreach { (enemyKSq, _) =>
|
||||
val centerDist = kingCentralizationDistance(enemyKSq)
|
||||
bonus -= (8 - centerDist) * 15
|
||||
}
|
||||
|
||||
// Push enemy king to edge when materially ahead
|
||||
val friendlyMaterial = materialCount(context, context.turn)
|
||||
val enemyMaterial = materialCount(context, context.turn.opposite)
|
||||
if friendlyMaterial > enemyMaterial then
|
||||
enemyKing.foreach { (enemyKSq, _) =>
|
||||
val edgeDist = kingEdgeDistance(enemyKSq)
|
||||
bonus += (7 - edgeDist) * 10
|
||||
}
|
||||
|
||||
bonus
|
||||
|
||||
private def kingCentralizationDistance(sq: Square): Int =
|
||||
val fileFromCenter = (sq.file.ordinal - 3.5).abs.toInt
|
||||
val rankFromCenter = (sq.rank.ordinal - 3.5).abs.toInt
|
||||
math.max(fileFromCenter, rankFromCenter)
|
||||
|
||||
private def kingEdgeDistance(sq: Square): Int =
|
||||
val fileFromEdge = math.min(sq.file.ordinal, 7 - sq.file.ordinal)
|
||||
val rankFromEdge = math.min(sq.rank.ordinal, 7 - sq.rank.ordinal)
|
||||
math.min(fileFromEdge, rankFromEdge)
|
||||
|
||||
private def materialCount(context: GameContext, color: Color): Int =
|
||||
context.board.pieces.foldLeft(0) { case (sum, (_, piece)) =>
|
||||
if piece.color == color && piece.pieceType != PieceType.King && piece.pieceType != PieceType.Pawn then
|
||||
sum + (piece.pieceType match
|
||||
case PieceType.Knight => 300
|
||||
case PieceType.Bishop => 300
|
||||
case PieceType.Rook => 500
|
||||
case PieceType.Queen => 900
|
||||
case _ => 0)
|
||||
else sum
|
||||
}
|
||||
|
||||
@@ -3,11 +3,44 @@ package de.nowchess.bot
|
||||
import de.nowchess.api.board.PieceType
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import scala.collection.mutable
|
||||
|
||||
object MoveOrdering:
|
||||
|
||||
class OrderingContext:
|
||||
// Killer moves: 2 per ply (depth) to track moves that caused cutoffs
|
||||
private val killerMoves = mutable.Map[Int, List[Move]]()
|
||||
|
||||
// History heuristic: tracks how often a move improved alpha
|
||||
private val historyTable = mutable.Map[(Int, Int), Int]()
|
||||
|
||||
def addKillerMove(ply: Int, move: Move): Unit =
|
||||
val current = killerMoves.getOrElse(ply, List())
|
||||
if current.isEmpty || (current.head.from != move.from || current.head.to != move.to) then
|
||||
killerMoves(ply) = (move :: current).take(2)
|
||||
|
||||
def getKillerMoves(ply: Int): List[Move] =
|
||||
killerMoves.getOrElse(ply, List())
|
||||
|
||||
def addHistory(from: Int, to: Int, bonus: Int): Unit =
|
||||
val key = (from, to)
|
||||
historyTable(key) = historyTable.getOrElse(key, 0) + bonus
|
||||
|
||||
def getHistory(from: Int, to: Int): Int =
|
||||
historyTable.getOrElse((from, to), 0)
|
||||
|
||||
def clear(): Unit =
|
||||
killerMoves.clear()
|
||||
historyTable.clear()
|
||||
|
||||
/** Score a single move for ordering (higher = search first). */
|
||||
def score(context: GameContext, move: Move, ttBestMove: Option[Move]): Int =
|
||||
def score(
|
||||
context: GameContext,
|
||||
move: Move,
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext()
|
||||
): Int =
|
||||
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then
|
||||
Int.MaxValue // TT best move always first
|
||||
else
|
||||
@@ -21,11 +54,25 @@ object MoveOrdering:
|
||||
case MoveType.Promotion(_) =>
|
||||
50_000 + mvvLva(context, move) // Minor/rook/bishop promotion
|
||||
case _ =>
|
||||
0 // Quiet move
|
||||
scoreQuietMove(move, ply, ordering) // Quiet move with history/killer heuristic
|
||||
|
||||
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
|
||||
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
|
||||
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||
val history = ordering.getHistory(fromIdx, toIdx)
|
||||
if isKiller then 10_000 + (history / 10)
|
||||
else history / 10
|
||||
|
||||
/** Sort moves: TT best move first, then by score descending. */
|
||||
def sort(context: GameContext, moves: List[Move], ttBestMove: Option[Move]): List[Move] =
|
||||
moves.sortBy(m => -score(context, m, ttBestMove))
|
||||
def sort(
|
||||
context: GameContext,
|
||||
moves: List[Move],
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext()
|
||||
): List[Move] =
|
||||
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering))
|
||||
|
||||
/** MVV-LVA score: (victim value * 10) - attacker value.
|
||||
* Higher score = better trade (most valuable victim captured by least valuable attacker). */
|
||||
|
||||
@@ -18,15 +18,24 @@ final case class TTEntry(
|
||||
final class TranspositionTable(val sizePow2: Int = 20):
|
||||
private val size = 1 << sizePow2
|
||||
private val mask = size - 1L
|
||||
private val locks = Array.fill(size)(new Object())
|
||||
private var table: Array[Option[TTEntry]] = Array.fill(size)(None)
|
||||
|
||||
def probe(hash: Long): Option[TTEntry] =
|
||||
val index = (hash & mask).toInt
|
||||
table(index).filter(_.hash == hash)
|
||||
locks(index).synchronized {
|
||||
table(index).filter(_.hash == hash)
|
||||
}
|
||||
|
||||
def store(entry: TTEntry): Unit =
|
||||
val index = (entry.hash & mask).toInt
|
||||
table(index) = Some(entry)
|
||||
locks(index).synchronized {
|
||||
table(index) = Some(entry)
|
||||
}
|
||||
|
||||
def clear(): Unit =
|
||||
for lock <- locks do
|
||||
lock.synchronized {
|
||||
// Clear in-place to avoid reassigning table reference during search
|
||||
}
|
||||
table = Array.fill(size)(None)
|
||||
|
||||
Reference in New Issue
Block a user