feat: enhance AlphaBetaSearch with move timing and improve king square detection in GameContext

This commit is contained in:
2026-04-10 16:47:19 +02:00
parent 01c8a0f8fe
commit 75286a8773
11 changed files with 321 additions and 269 deletions
@@ -1,6 +1,6 @@
package de.nowchess.api.game
import de.nowchess.api.board.{Board, CastlingRights, Color, Square}
import de.nowchess.api.board.{Board, CastlingRights, Color, PieceType, Square}
import de.nowchess.api.move.Move
/** Immutable bundle of complete game state. All state changes produce new GameContext instances.
@@ -15,6 +15,13 @@ case class GameContext(
result: Option[GameResult] = None,
initialBoard: Board = Board.initial,
):
private lazy val whiteKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.White && p.pieceType == PieceType.King).map(_._1)
private lazy val blackKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.Black && p.pieceType == PieceType.King).map(_._1)
def kingSquare(color: Color): Option[Square] =
if color == Color.White then whiteKingSquare else blackKingSquare
/** Create new context with updated board. */
def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
Binary file not shown.
@@ -4,4 +4,7 @@ object Config:
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
* the move is vetoed (not accepted as a suggestion). */
val VETO_THRESHOLD: Int = 100
val VETO_THRESHOLD: Int = 150
/** Time budget per move for iterative deepening (milliseconds). */
val TIME_LIMIT_MS: Long = 2000L
@@ -2,9 +2,11 @@ package de.nowchess.bot.bots
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.logic.HybridSearch
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{Bot, BotDifficulty}
import de.nowchess.bot.{Bot, BotDifficulty, Config}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
@@ -14,10 +16,20 @@ final class HybridBot(
book: Option[PolyglotBook] = None
) extends Bot:
private val search: HybridSearch = HybridSearch(rules)
private val search = AlphaBetaSearch(rules, TranspositionTable(), EvaluationClassic)
override val name: String = s"HybridBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] =
book.flatMap(_.probe(context))
.orElse(search.bestMove(context))
book.flatMap(_.probe(context)).orElse(searchWithVeto(context))
private def searchWithVeto(context: GameContext): Option[Move] =
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS).map { move =>
val next = rules.applyMove(context)(move)
val staticNnue = EvaluationNNUE.evaluate(next)
val classical = EvaluationClassic.evaluate(next)
val diff = (classical - staticNnue).abs
if diff > Config.VETO_THRESHOLD then
println(f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)")
move
}
@@ -67,28 +67,11 @@ class NNUE:
private val l4Output = new Array[Float](256)
/** Convert a position to 768-dimensional binary feature vector.
* 12 piece types (white pawn to black king) × 64 squares from white's perspective. */
private def positionToFeatures(board: Board, sideToMove: Color): Array[Float] =
// Zero out features array
* Layout matches training: black pieces at indices 0-5, white at 6-11.
* feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */
private def positionToFeatures(board: Board): Array[Float] =
java.util.Arrays.fill(features, 0f)
// Piece type to feature index offset: wp=0, wn=64, wb=128, wr=192, wq=256, wk=320, bp=384, bn=448, bb=512, br=576, bq=640, bk=704
val pieceToFeatureOffset = Array(
0, // White Pawn (0)
64, // White Knight (1)
128, // White Bishop (2)
192, // White Rook (3)
256, // White Queen (4)
320, // White King (5)
384, // Black Pawn (6)
448, // Black Knight (7)
512, // Black Bishop (8)
576, // Black Rook (9)
640, // Black Queen (10)
704 // Black King (11)
)
// Build features: always from white's perspective
for
fileIdx <- 0 until 8
rankIdx <- 0 until 8
@@ -99,17 +82,11 @@ class NNUE:
val squareNum = rankIdx * 8 + fileIdx
board.pieceAt(square).foreach { piece =>
val featureIdx = if sideToMove == Color.Black then
// Mirror square for black side-to-move
val mirroredSq = squareNum ^ 56
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
offset + mirroredSq
else
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
offset + squareNum
if featureIdx >= 0 && featureIdx < 768 then
features(featureIdx) = 1f
// black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
val colorOffset = if piece.color == Color.White then 6 else 0
val pieceIdx = colorOffset + piece.pieceType.ordinal
val featureIdx = pieceIdx * 64 + squareNum
features(featureIdx) = 1f
}
features
@@ -119,7 +96,7 @@ class NNUE:
* No allocations in the hot path (uses pre-allocated buffers).
* Architecture: 768→1536→1024→512→256→1 */
def evaluate(context: GameContext): Int =
val features = positionToFeatures(context.board, context.turn)
val features = positionToFeatures(context.board)
// Layer 1: Dense(768 → 1536) + ReLU
for i <- 0 until 1536 do
@@ -154,18 +131,17 @@ class NNUE:
for j <- 0 until 256 do
output += l4Output(j) * l5Weights(j)
// Convert from tanh-normalized output back to centipawns
// Training uses: eval_normalized = tanh(eval_cp / 300)
// Inverse: eval_cp = 300 * atanh(output)
// atanh(x) = 0.5 * ln((1 + x) / (1 - x))
// Convert from tanh-normalized output back to centipawns.
// Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective.
// Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective.
val cp = if math.abs(output) >= 0.9999f then
// Clamp for numerical stability (avoid ln of very small numbers)
if output > 0f then 20000 else -20000
else
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
(300f * atanh).toInt
math.max(-20000, math.min(20000, cp))
val cpFromTurn = if context.turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
/** Benchmark: time 1M evaluations and report ns/eval.
* This measures the performance of the inference on the starting position. */
@@ -73,8 +73,12 @@ final class AlphaBetaSearch(
if depth == 1 then (-INF, INF)
else (prevScore - aspWindow, prevScore + aspWindow)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
val elapsed = System.currentTimeMillis() - timeStartMs
prevScore = score
move.foreach(m => bestSoFar = Some(m))
move.foreach { m =>
bestSoFar = Some(m)
println(f"[Depth $depth%2d | ${elapsed}%5dms | ${nodeCount}%7d nodes] best=${m.from}->${m.to} score=$score")
}
aspWindow = ASPIRATION_DELTA
depth += 1
bestSoFar
@@ -1,67 +0,0 @@
package de.nowchess.bot.logic
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.Config
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import scala.util.boundary
import scala.util.boundary.break
final class HybridSearch(
rules: RuleSet = DefaultRules
):
private var vetoCount = 0
private var approvalCount = 0
private val TOP_MOVES_TO_VALIDATE = 10
/** Find the best move by scoring all legal moves with NNUE, then validating top 5 with classical eval.
* If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
* If all top 5 are vetoed, fall back to the best classical move overall.
*/
def bestMove(context: GameContext): Option[Move] =
val legalMoves = rules.allLegalMoves(context)
if legalMoves.isEmpty then None else findBestMove(legalMoves, context)
private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
// Score all moves with NNUE
val moveScores = legalMoves.map { move =>
val nextContext = rules.applyMove(context)(move)
val nnueScore = EvaluationNNUE.evaluate(nextContext)
(move, nnueScore, nextContext)
}
// Sort by NNUE score descending
val sortedByNNUE = moveScores.sortBy(_._2).reverse
// Validate top N moves with classical evaluation
val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)
boundary:
for (move, nnueScore, nextContext) <- topMovesToCheck do
val classicalScore = EvaluationClassic.evaluate(nextContext)
val difference = (classicalScore - nnueScore).abs
if difference <= Config.VETO_THRESHOLD then
approvalCount += 1
println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
break(Some(move))
else
vetoCount += 1
println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")
// All top 10 were vetoed, fall back to best classical move
println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
val bestByClassical = moveScores
.map { case (move, _, nextContext) =>
(move, EvaluationClassic.evaluate(nextContext))
}
.maxBy(_._2)
println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
Some(bestByClassical._1)
def getStats: (Int, Int) = (approvalCount, vetoCount)
@@ -27,7 +27,7 @@ object DefaultRules extends RuleSet:
List((2, 1), (2, -1), (-2, 1), (-2, -1), (1, 2), (1, -2), (-1, 2), (-1, -2))
// ── Pawn configuration helpers ─────────────────────────────────────
private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1
private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1
private def pawnStartRank(color: Color): Int = if color == Color.White then 1 else 6
private def pawnPromoRank(color: Color): Int = if color == Color.White then 7 else 0
@@ -36,14 +36,13 @@ object DefaultRules extends RuleSet:
override def candidateMoves(context: GameContext)(square: Square): List[Move] =
context.board.pieceAt(square).fold(List.empty[Move]) { piece =>
if piece.color != context.turn then List.empty[Move]
else
piece.pieceType match
case PieceType.Pawn => pawnCandidates(context, square, piece.color)
case PieceType.Knight => knightCandidates(context, square, piece.color)
case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs)
case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs)
case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs)
case PieceType.King => kingCandidates(context, square, piece.color)
else piece.pieceType match
case PieceType.Pawn => pawnCandidates(context, square, piece.color)
case PieceType.Knight => knightCandidates(context, square, piece.color)
case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs)
case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs)
case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs)
case PieceType.King => kingCandidates(context, square, piece.color)
}
override def legalMoves(context: GameContext)(square: Square): List[Move] =
@@ -52,10 +51,13 @@ object DefaultRules extends RuleSet:
}
override def allLegalMoves(context: GameContext): List[Move] =
Square.all.flatMap(sq => legalMoves(context)(sq)).toList
context.board.pieces
.collect { case (sq, p) if p.color == context.turn => legalMoves(context)(sq) }
.flatten
.toList
override def isCheck(context: GameContext): Boolean =
kingSquare(context.board, context.turn)
context.kingSquare(context.turn)
.fold(false)(sq => isAttackedBy(context.board, sq, context.turn.opposite))
override def isCheckmate(context: GameContext): Boolean =
@@ -113,18 +115,18 @@ object DefaultRules extends RuleSet:
// ── Sliding pieces (Bishop, Rook, Queen) ───────────────────────────
private def slidingMoves(
context: GameContext,
from: Square,
color: Color,
dirs: List[(Int, Int)],
context: GameContext,
from: Square,
color: Color,
dirs: List[(Int, Int)]
): List[Move] =
dirs.flatMap(dir => castRay(context.board, from, color, dir))
private def castRay(
board: Board,
from: Square,
color: Color,
dir: (Int, Int),
board: Board,
from: Square,
color: Color,
dir: (Int, Int)
): List[Move] =
@tailrec
def loop(sq: Square, acc: List[Move]): List[Move] =
@@ -132,40 +134,40 @@ object DefaultRules extends RuleSet:
case None => acc
case Some(next) =>
board.pieceAt(next) match
case None => loop(next, Move(from, next) :: acc)
case None => loop(next, Move(from, next) :: acc)
case Some(p) if p.color != color => Move(from, next, MoveType.Normal(isCapture = true)) :: acc
case Some(_) => acc
case Some(_) => acc
loop(from, Nil).reverse
// ── Knight ─────────────────────────────────────────────────────────
private def knightCandidates(
context: GameContext,
from: Square,
color: Color,
context: GameContext,
from: Square,
color: Color
): List[Move] =
KnightJumps.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match
case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
}
}
// ── King ───────────────────────────────────────────────────────────
private def kingCandidates(
context: GameContext,
from: Square,
color: Color,
context: GameContext,
from: Square,
color: Color
): List[Move] =
val steps = QueenDirs.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match
case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
}
}
steps ++ castlingCandidates(context, from, color)
@@ -173,17 +175,17 @@ object DefaultRules extends RuleSet:
// ── Castling ───────────────────────────────────────────────────────
private case class CastlingMove(
kingFromAlg: String,
kingToAlg: String,
middleAlg: String,
rookFromAlg: String,
moveType: MoveType,
kingFromAlg: String,
kingToAlg: String,
middleAlg: String,
rookFromAlg: String,
moveType: MoveType
)
private def castlingCandidates(
context: GameContext,
from: Square,
color: Color,
context: GameContext,
from: Square,
color: Color
): List[Move] =
color match
case Color.White => whiteCastles(context, from)
@@ -194,18 +196,10 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty
else
val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove(
context,
moves,
context.castlingRights.whiteKingSide,
CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.whiteQueenSide,
CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside),
)
addCastleMove(context, moves, context.castlingRights.whiteKingSide,
CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside))
addCastleMove(context, moves, context.castlingRights.whiteQueenSide,
CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside))
moves.toList
private def blackCastles(context: GameContext, from: Square): List[Move] =
@@ -213,18 +207,10 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty
else
val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove(
context,
moves,
context.castlingRights.blackKingSide,
CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.blackQueenSide,
CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside),
)
addCastleMove(context, moves, context.castlingRights.blackKingSide,
CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside))
addCastleMove(context, moves, context.castlingRights.blackQueenSide,
CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside))
moves.toList
private def queensideBSquare(kingToAlg: String): List[String] =
@@ -234,10 +220,10 @@ object DefaultRules extends RuleSet:
case _ => List.empty
private def addCastleMove(
context: GameContext,
moves: scala.collection.mutable.ListBuffer[Move],
castlingRight: Boolean,
castlingMove: CastlingMove,
context: GameContext,
moves: scala.collection.mutable.ListBuffer[Move],
castlingRight: Boolean,
castlingMove: CastlingMove
): Unit =
if castlingRight then
val clearSqs = (List(castlingMove.middleAlg, castlingMove.kingToAlg) ++ queensideBSquare(castlingMove.kingToAlg))
@@ -249,15 +235,16 @@ object DefaultRules extends RuleSet:
kt <- Square.fromAlgebraic(castlingMove.kingToAlg)
rf <- Square.fromAlgebraic(castlingMove.rookFromAlg)
do
val color = context.turn
val color = context.turn
val kingPresent = context.board.pieceAt(kf).exists(p => p.color == color && p.pieceType == PieceType.King)
val rookPresent = context.board.pieceAt(rf).exists(p => p.color == color && p.pieceType == PieceType.Rook)
val squaresSafe =
!isAttackedBy(context.board, kf, color.opposite) &&
!isAttackedBy(context.board, km, color.opposite) &&
!isAttackedBy(context.board, kt, color.opposite)
!isAttackedBy(context.board, km, color.opposite) &&
!isAttackedBy(context.board, kt, color.opposite)
if kingPresent && rookPresent && squaresSafe then moves += Move(kf, kt, castlingMove.moveType)
if kingPresent && rookPresent && squaresSafe then
moves += Move(kf, kt, castlingMove.moveType)
private def squaresEmpty(board: Board, squares: List[Square]): Boolean =
squares.forall(sq => board.pieceAt(sq).isEmpty)
@@ -265,26 +252,22 @@ object DefaultRules extends RuleSet:
// ── Pawn ───────────────────────────────────────────────────────────
private def pawnCandidates(
context: GameContext,
from: Square,
color: Color,
context: GameContext,
from: Square,
color: Color
): List[Move] =
val fwd = pawnForward(color)
val fwd = pawnForward(color)
val startRank = pawnStartRank(color)
val promoRank = pawnPromoRank(color)
val single = from.offset(0, fwd).filter(to => context.board.pieceAt(to).isEmpty)
val double = Option
.when(from.rank.ordinal == startRank) {
from.offset(0, fwd).flatMap { mid =>
Option
.when(context.board.pieceAt(mid).isEmpty) {
from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
}
.flatten
}
val double = Option.when(from.rank.ordinal == startRank) {
from.offset(0, fwd).flatMap { mid =>
Option.when(context.board.pieceAt(mid).isEmpty) {
from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
}.flatten
}
.flatten
}.flatten
val diagonalCaptures = List(-1, 1).flatMap { df =>
from.offset(df, fwd).flatMap { to =>
@@ -303,56 +286,55 @@ object DefaultRules extends RuleSet:
def toMoves(dest: Square, isCapture: Boolean): List[Move] =
if dest.rank.ordinal == promoRank then
List(
PromotionPiece.Queen,
PromotionPiece.Rook,
PromotionPiece.Bishop,
PromotionPiece.Knight,
PromotionPiece.Queen, PromotionPiece.Rook,
PromotionPiece.Bishop, PromotionPiece.Knight
).map(pt => Move(from, dest, MoveType.Promotion(pt)))
else List(Move(from, dest, MoveType.Normal(isCapture = isCapture)))
val stepSquares = single.toList ++ double.toList
val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false))
val stepSquares = single.toList ++ double.toList
val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false))
val captureMoves = diagonalCaptures.flatMap(dest => toMoves(dest, isCapture = true))
stepMoves ++ captureMoves ++ epCaptures
// ── Check detection ────────────────────────────────────────────────
private def kingSquare(board: Board, color: Color): Option[Square] =
Square.all.find(sq => board.pieceAt(sq).exists(p => p.color == color && p.pieceType == PieceType.King))
/** Cast rays outward from `target` to detect attackers — O(rays) instead of O(64×rays). */
private def isAttackedBy(board: Board, target: Square, attacker: Color): Boolean =
Square.all.exists { sq =>
board.pieceAt(sq).fold(false) { p =>
p.color == attacker && squareAttacks(board, sq, p, target)
}
attackedBySlider(board, target, attacker, RookDirs, PieceType.Rook) ||
attackedBySlider(board, target, attacker, BishopDirs, PieceType.Bishop) ||
attackedByKnight(board, target, attacker) ||
attackedByPawn(board, target, attacker) ||
attackedByKing(board, target, attacker)
private def attackedBySlider(board: Board, target: Square, attacker: Color, dirs: List[(Int, Int)], sliderType: PieceType): Boolean =
dirs.exists { dir =>
@tailrec def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match
case None => false
case Some(next) => board.pieceAt(next) match
case None => loop(next)
case Some(p) if p.color == attacker && (p.pieceType == sliderType || p.pieceType == PieceType.Queen) => true
case _ => false
loop(target)
}
private def squareAttacks(board: Board, from: Square, piece: Piece, target: Square): Boolean =
val fwd = pawnForward(piece.color)
piece.pieceType match
case PieceType.Pawn =>
from.offset(-1, fwd).contains(target) || from.offset(1, fwd).contains(target)
case PieceType.Knight =>
KnightJumps.exists((df, dr) => from.offset(df, dr).contains(target))
case PieceType.Bishop => rayReaches(board, from, BishopDirs, target)
case PieceType.Rook => rayReaches(board, from, RookDirs, target)
case PieceType.Queen => rayReaches(board, from, QueenDirs, target)
case PieceType.King =>
QueenDirs.exists((df, dr) => from.offset(df, dr).contains(target))
private def attackedByKnight(board: Board, target: Square, attacker: Color): Boolean =
KnightJumps.exists { (df, dr) =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Knight))
}
private def rayReaches(board: Board, from: Square, dirs: List[(Int, Int)], target: Square): Boolean =
dirs.exists { dir =>
@tailrec
def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match
case None => false
case Some(next) if next == target => true
case Some(next) if board.pieceAt(next).isEmpty => loop(next)
case Some(_) => false
loop(from)
private def attackedByPawn(board: Board, target: Square, attacker: Color): Boolean =
val dr = if attacker == Color.White then -1 else 1
List(-1, 1).exists { df =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Pawn))
}
private def attackedByKing(board: Board, target: Square, attacker: Color): Boolean =
QueenDirs.exists { (df, dr) =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.King))
}
private def leavesKingInCheck(context: GameContext, move: Move): Boolean =
val nextBoard = context.board.applyMove(move)
val nextBoard = context.board.applyMove(move)
val nextContext = context.withBoard(nextBoard)
isCheck(nextContext)
@@ -360,7 +342,7 @@ object DefaultRules extends RuleSet:
override def applyMove(context: GameContext)(move: Move): GameContext =
val color = context.turn
val board = context.board
val board = context.board
val newBoard = move.moveType match
case MoveType.CastleKingside => applyCastle(board, color, kingside = true)
@@ -369,14 +351,14 @@ object DefaultRules extends RuleSet:
case MoveType.Promotion(pp) => applyPromotion(board, move, color, pp)
case MoveType.Normal(_) => board.applyMove(move)
val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color)
val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color)
val newEnPassantSquare = computeEnPassantSquare(board, move)
val isCapture = move.moveType match
case MoveType.Normal(capture) => capture
case MoveType.EnPassant => true
case _ => board.pieceAt(move.to).isDefined
val isPawnMove = board.pieceAt(move.from).exists(_.pieceType == PieceType.Pawn)
val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1
val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1
context
.withBoard(newBoard)
@@ -389,18 +371,19 @@ object DefaultRules extends RuleSet:
private def applyCastle(board: Board, color: Color, kingside: Boolean): Board =
val rank = if color == Color.White then Rank.R1 else Rank.R8
val (kingFrom, kingTo, rookFrom, rookTo) =
if kingside then (Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank))
else (Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank))
if kingside then
(Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank))
else
(Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank))
val king = board.pieceAt(kingFrom).getOrElse(Piece(color, PieceType.King))
val rook = board.pieceAt(rookFrom).getOrElse(Piece(color, PieceType.Rook))
board
.removed(kingFrom)
.removed(rookFrom)
.removed(kingFrom).removed(rookFrom)
.updated(kingTo, king)
.updated(rookTo, rook)
private def applyEnPassant(board: Board, move: Move): Board =
val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn
val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn
val capturedSquare = Square(move.to.file, capturedRank)
board.applyMove(move).removed(capturedSquare)
@@ -413,7 +396,7 @@ object DefaultRules extends RuleSet:
board.removed(move.from).updated(move.to, Piece(color, promotedType))
private def updateCastlingRights(rights: CastlingRights, board: Board, move: Move, color: Color): CastlingRights =
val piece = board.pieceAt(move.from)
val piece = board.pieceAt(move.from)
val isKingMove = piece.exists(_.pieceType == PieceType.King)
val isRookMove = piece.exists(_.pieceType == PieceType.Rook)
@@ -423,25 +406,19 @@ object DefaultRules extends RuleSet:
val blackKingsideRook = Square(File.H, Rank.R8)
val blackQueensideRook = Square(File.A, Rank.R8)
val afterKingMove = if isKingMove then rights.revokeColor(color) else rights
val afterRookMove =
if !isRookMove then afterKingMove
else
move.from match
case `whiteKingsideRook` => afterKingMove.revokeKingSide(Color.White)
case `whiteQueensideRook` => afterKingMove.revokeQueenSide(Color.White)
case `blackKingsideRook` => afterKingMove.revokeKingSide(Color.Black)
case `blackQueensideRook` => afterKingMove.revokeQueenSide(Color.Black)
case _ => afterKingMove
var r = rights
if isKingMove then r = r.revokeColor(color)
else if isRookMove then
if move.from == whiteKingsideRook then r = r.revokeKingSide(Color.White)
if move.from == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
if move.from == blackKingsideRook then r = r.revokeKingSide(Color.Black)
if move.from == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
// Also revoke if a rook is captured
move.to match
case `whiteKingsideRook` => afterRookMove.revokeKingSide(Color.White)
case `whiteQueensideRook` => afterRookMove.revokeQueenSide(Color.White)
case `blackKingsideRook` => afterRookMove.revokeKingSide(Color.Black)
case `blackQueensideRook` => afterRookMove.revokeQueenSide(Color.Black)
case _ => afterRookMove
if move.to == whiteKingsideRook then r = r.revokeKingSide(Color.White)
if move.to == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
if move.to == blackKingsideRook then r = r.revokeKingSide(Color.Black)
if move.to == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
r
private def computeEnPassantSquare(board: Board, move: Move): Option[Square] =
val piece = board.pieceAt(move.from)
@@ -455,14 +432,12 @@ object DefaultRules extends RuleSet:
// ── Insufficient material ──────────────────────────────────────────
private def squareColor(sq: Square): Int = (sq.file.ordinal + sq.rank.ordinal) % 2
private def insufficientMaterial(board: Board): Boolean =
val nonKings = board.pieces.toList.filter { case (_, p) => p.pieceType != PieceType.King }
nonKings match
case Nil => true
case List((_, p)) if p.pieceType == PieceType.Bishop || p.pieceType == PieceType.Knight => true
case bishops if bishops.forall { case (_, p) => p.pieceType == PieceType.Bishop } =>
// All non-king pieces are bishops: draw only if they all share the same square color
bishops.map { case (sq, _) => squareColor(sq) }.distinct.sizeIs == 1
val pieces = board.pieces.values.toList.filter(_.pieceType != PieceType.King)
pieces match
case Nil => true
case List(p) if p.pieceType == PieceType.Bishop || p.pieceType == PieceType.Knight => true
case List(p1, p2)
if p1.pieceType == PieceType.Bishop && p2.pieceType == PieceType.Bishop
&& p1.color != p2.color => true
case _ => false