feat: enhance AlphaBetaSearch with move timing and improve king square detection in GameContext

This commit is contained in:
2026-04-10 16:47:19 +02:00
parent 01c8a0f8fe
commit 75286a8773
11 changed files with 321 additions and 269 deletions
+5
View File
@@ -0,0 +1,5 @@
# Config
## Environment Variables
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
+4
View File
@@ -0,0 +1,4 @@
# Middleware
## custom
- generate — `modules/bot/python/src/generate.py`
+133
View File
@@ -0,0 +1,133 @@
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<AndroidXmlCodeStyleSettings>
<option name="USE_CUSTOM_SETTINGS" value="true" />
</AndroidXmlCodeStyleSettings>
<JetCodeStyleSettings>
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
</JetCodeStyleSettings>
<ScalaCodeStyleSettings>
<option name="FORMATTER" value="1" />
</ScalaCodeStyleSettings>
<XML>
<option name="XML_KEEP_LINE_BREAKS" value="false" />
<option name="XML_ALIGN_ATTRIBUTES" value="false" />
<option name="XML_SPACE_INSIDE_EMPTY_TAG" value="true" />
</XML>
<codeStyleSettings language="XML">
<option name="FORCE_REARRANGE_MODE" value="1" />
<indentOptions>
<option name="CONTINUATION_INDENT_SIZE" value="4" />
</indentOptions>
<arrangement>
<rules>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:android</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:id</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>style</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>.*</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
</rules>
</arrangement>
</codeStyleSettings>
<codeStyleSettings language="kotlin">
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
</codeStyleSettings>
</code_scheme>
</component>
@@ -1,6 +1,6 @@
package de.nowchess.api.game package de.nowchess.api.game
import de.nowchess.api.board.{Board, CastlingRights, Color, Square} import de.nowchess.api.board.{Board, CastlingRights, Color, PieceType, Square}
import de.nowchess.api.move.Move import de.nowchess.api.move.Move
/** Immutable bundle of complete game state. All state changes produce new GameContext instances. /** Immutable bundle of complete game state. All state changes produce new GameContext instances.
@@ -15,6 +15,13 @@ case class GameContext(
result: Option[GameResult] = None, result: Option[GameResult] = None,
initialBoard: Board = Board.initial, initialBoard: Board = Board.initial,
): ):
private lazy val whiteKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.White && p.pieceType == PieceType.King).map(_._1)
private lazy val blackKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.Black && p.pieceType == PieceType.King).map(_._1)
def kingSquare(color: Color): Option[Square] =
if color == Color.White then whiteKingSquare else blackKingSquare
/** Create new context with updated board. */ /** Create new context with updated board. */
def withBoard(newBoard: Board): GameContext = copy(board = newBoard) def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
Binary file not shown.
@@ -4,4 +4,7 @@ object Config:
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this, /** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
* the move is vetoed (not accepted as a suggestion). */ * the move is vetoed (not accepted as a suggestion). */
val VETO_THRESHOLD: Int = 100 val VETO_THRESHOLD: Int = 150
/** Time budget per move for iterative deepening (milliseconds). */
val TIME_LIMIT_MS: Long = 2000L
@@ -2,9 +2,11 @@ package de.nowchess.bot.bots
import de.nowchess.api.game.GameContext import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move import de.nowchess.api.move.Move
import de.nowchess.bot.logic.HybridSearch import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
import de.nowchess.bot.util.PolyglotBook import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{Bot, BotDifficulty} import de.nowchess.bot.{Bot, BotDifficulty, Config}
import de.nowchess.rules.RuleSet import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules import de.nowchess.rules.sets.DefaultRules
@@ -14,10 +16,20 @@ final class HybridBot(
book: Option[PolyglotBook] = None book: Option[PolyglotBook] = None
) extends Bot: ) extends Bot:
private val search: HybridSearch = HybridSearch(rules) private val search = AlphaBetaSearch(rules, TranspositionTable(), EvaluationClassic)
override val name: String = s"HybridBot(${difficulty.toString})" override val name: String = s"HybridBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] = override def nextMove(context: GameContext): Option[Move] =
book.flatMap(_.probe(context)) book.flatMap(_.probe(context)).orElse(searchWithVeto(context))
.orElse(search.bestMove(context))
private def searchWithVeto(context: GameContext): Option[Move] =
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS).map { move =>
val next = rules.applyMove(context)(move)
val staticNnue = EvaluationNNUE.evaluate(next)
val classical = EvaluationClassic.evaluate(next)
val diff = (classical - staticNnue).abs
if diff > Config.VETO_THRESHOLD then
println(f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)")
move
}
@@ -67,28 +67,11 @@ class NNUE:
private val l4Output = new Array[Float](256) private val l4Output = new Array[Float](256)
/** Convert a position to 768-dimensional binary feature vector. /** Convert a position to 768-dimensional binary feature vector.
* 12 piece types (white pawn to black king) × 64 squares from white's perspective. */ * Layout matches training: black pieces at indices 0-5, white at 6-11.
private def positionToFeatures(board: Board, sideToMove: Color): Array[Float] = * feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */
// Zero out features array private def positionToFeatures(board: Board): Array[Float] =
java.util.Arrays.fill(features, 0f) java.util.Arrays.fill(features, 0f)
// Piece type to feature index offset: wp=0, wn=64, wb=128, wr=192, wq=256, wk=320, bp=384, bn=448, bb=512, br=576, bq=640, bk=704
val pieceToFeatureOffset = Array(
0, // White Pawn (0)
64, // White Knight (1)
128, // White Bishop (2)
192, // White Rook (3)
256, // White Queen (4)
320, // White King (5)
384, // Black Pawn (6)
448, // Black Knight (7)
512, // Black Bishop (8)
576, // Black Rook (9)
640, // Black Queen (10)
704 // Black King (11)
)
// Build features: always from white's perspective
for for
fileIdx <- 0 until 8 fileIdx <- 0 until 8
rankIdx <- 0 until 8 rankIdx <- 0 until 8
@@ -99,17 +82,11 @@ class NNUE:
val squareNum = rankIdx * 8 + fileIdx val squareNum = rankIdx * 8 + fileIdx
board.pieceAt(square).foreach { piece => board.pieceAt(square).foreach { piece =>
val featureIdx = if sideToMove == Color.Black then // black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
// Mirror square for black side-to-move val colorOffset = if piece.color == Color.White then 6 else 0
val mirroredSq = squareNum ^ 56 val pieceIdx = colorOffset + piece.pieceType.ordinal
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal) val featureIdx = pieceIdx * 64 + squareNum
offset + mirroredSq features(featureIdx) = 1f
else
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
offset + squareNum
if featureIdx >= 0 && featureIdx < 768 then
features(featureIdx) = 1f
} }
features features
@@ -119,7 +96,7 @@ class NNUE:
* No allocations in the hot path (uses pre-allocated buffers). * No allocations in the hot path (uses pre-allocated buffers).
* Architecture: 768→1536→1024→512→256→1 */ * Architecture: 768→1536→1024→512→256→1 */
def evaluate(context: GameContext): Int = def evaluate(context: GameContext): Int =
val features = positionToFeatures(context.board, context.turn) val features = positionToFeatures(context.board)
// Layer 1: Dense(768 → 1536) + ReLU // Layer 1: Dense(768 → 1536) + ReLU
for i <- 0 until 1536 do for i <- 0 until 1536 do
@@ -154,18 +131,17 @@ class NNUE:
for j <- 0 until 256 do for j <- 0 until 256 do
output += l4Output(j) * l5Weights(j) output += l4Output(j) * l5Weights(j)
// Convert from tanh-normalized output back to centipawns // Convert from tanh-normalized output back to centipawns.
// Training uses: eval_normalized = tanh(eval_cp / 300) // Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective.
// Inverse: eval_cp = 300 * atanh(output) // Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective.
// atanh(x) = 0.5 * ln((1 + x) / (1 - x))
val cp = if math.abs(output) >= 0.9999f then val cp = if math.abs(output) >= 0.9999f then
// Clamp for numerical stability (avoid ln of very small numbers)
if output > 0f then 20000 else -20000 if output > 0f then 20000 else -20000
else else
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
(300f * atanh).toInt (300f * atanh).toInt
math.max(-20000, math.min(20000, cp)) val cpFromTurn = if context.turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
/** Benchmark: time 1M evaluations and report ns/eval. /** Benchmark: time 1M evaluations and report ns/eval.
* This measures the performance of the inference on the starting position. */ * This measures the performance of the inference on the starting position. */
@@ -73,8 +73,12 @@ final class AlphaBetaSearch(
if depth == 1 then (-INF, INF) if depth == 1 then (-INF, INF)
else (prevScore - aspWindow, prevScore + aspWindow) else (prevScore - aspWindow, prevScore + aspWindow)
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow) val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
val elapsed = System.currentTimeMillis() - timeStartMs
prevScore = score prevScore = score
move.foreach(m => bestSoFar = Some(m)) move.foreach { m =>
bestSoFar = Some(m)
println(f"[Depth $depth%2d | ${elapsed}%5dms | ${nodeCount}%7d nodes] best=${m.from}->${m.to} score=$score")
}
aspWindow = ASPIRATION_DELTA aspWindow = ASPIRATION_DELTA
depth += 1 depth += 1
bestSoFar bestSoFar
@@ -1,67 +0,0 @@
package de.nowchess.bot.logic
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.Config
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import scala.util.boundary
import scala.util.boundary.break
final class HybridSearch(
rules: RuleSet = DefaultRules
):
private var vetoCount = 0
private var approvalCount = 0
private val TOP_MOVES_TO_VALIDATE = 10
/** Find the best move by scoring all legal moves with NNUE, then validating top 5 with classical eval.
* If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
* If all top 5 are vetoed, fall back to the best classical move overall.
*/
def bestMove(context: GameContext): Option[Move] =
val legalMoves = rules.allLegalMoves(context)
if legalMoves.isEmpty then None else findBestMove(legalMoves, context)
private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
// Score all moves with NNUE
val moveScores = legalMoves.map { move =>
val nextContext = rules.applyMove(context)(move)
val nnueScore = EvaluationNNUE.evaluate(nextContext)
(move, nnueScore, nextContext)
}
// Sort by NNUE score descending
val sortedByNNUE = moveScores.sortBy(_._2).reverse
// Validate top N moves with classical evaluation
val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)
boundary:
for (move, nnueScore, nextContext) <- topMovesToCheck do
val classicalScore = EvaluationClassic.evaluate(nextContext)
val difference = (classicalScore - nnueScore).abs
if difference <= Config.VETO_THRESHOLD then
approvalCount += 1
println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
break(Some(move))
else
vetoCount += 1
println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")
// All top 10 were vetoed, fall back to best classical move
println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
val bestByClassical = moveScores
.map { case (move, _, nextContext) =>
(move, EvaluationClassic.evaluate(nextContext))
}
.maxBy(_._2)
println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
Some(bestByClassical._1)
def getStats: (Int, Int) = (approvalCount, vetoCount)
@@ -27,7 +27,7 @@ object DefaultRules extends RuleSet:
List((2, 1), (2, -1), (-2, 1), (-2, -1), (1, 2), (1, -2), (-1, 2), (-1, -2)) List((2, 1), (2, -1), (-2, 1), (-2, -1), (1, 2), (1, -2), (-1, 2), (-1, -2))
// ── Pawn configuration helpers ───────────────────────────────────── // ── Pawn configuration helpers ─────────────────────────────────────
private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1 private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1
private def pawnStartRank(color: Color): Int = if color == Color.White then 1 else 6 private def pawnStartRank(color: Color): Int = if color == Color.White then 1 else 6
private def pawnPromoRank(color: Color): Int = if color == Color.White then 7 else 0 private def pawnPromoRank(color: Color): Int = if color == Color.White then 7 else 0
@@ -36,14 +36,13 @@ object DefaultRules extends RuleSet:
override def candidateMoves(context: GameContext)(square: Square): List[Move] = override def candidateMoves(context: GameContext)(square: Square): List[Move] =
context.board.pieceAt(square).fold(List.empty[Move]) { piece => context.board.pieceAt(square).fold(List.empty[Move]) { piece =>
if piece.color != context.turn then List.empty[Move] if piece.color != context.turn then List.empty[Move]
else else piece.pieceType match
piece.pieceType match case PieceType.Pawn => pawnCandidates(context, square, piece.color)
case PieceType.Pawn => pawnCandidates(context, square, piece.color) case PieceType.Knight => knightCandidates(context, square, piece.color)
case PieceType.Knight => knightCandidates(context, square, piece.color) case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs)
case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs) case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs)
case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs) case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs)
case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs) case PieceType.King => kingCandidates(context, square, piece.color)
case PieceType.King => kingCandidates(context, square, piece.color)
} }
override def legalMoves(context: GameContext)(square: Square): List[Move] = override def legalMoves(context: GameContext)(square: Square): List[Move] =
@@ -52,10 +51,13 @@ object DefaultRules extends RuleSet:
} }
override def allLegalMoves(context: GameContext): List[Move] = override def allLegalMoves(context: GameContext): List[Move] =
Square.all.flatMap(sq => legalMoves(context)(sq)).toList context.board.pieces
.collect { case (sq, p) if p.color == context.turn => legalMoves(context)(sq) }
.flatten
.toList
override def isCheck(context: GameContext): Boolean = override def isCheck(context: GameContext): Boolean =
kingSquare(context.board, context.turn) context.kingSquare(context.turn)
.fold(false)(sq => isAttackedBy(context.board, sq, context.turn.opposite)) .fold(false)(sq => isAttackedBy(context.board, sq, context.turn.opposite))
override def isCheckmate(context: GameContext): Boolean = override def isCheckmate(context: GameContext): Boolean =
@@ -113,18 +115,18 @@ object DefaultRules extends RuleSet:
// ── Sliding pieces (Bishop, Rook, Queen) ─────────────────────────── // ── Sliding pieces (Bishop, Rook, Queen) ───────────────────────────
private def slidingMoves( private def slidingMoves(
context: GameContext, context: GameContext,
from: Square, from: Square,
color: Color, color: Color,
dirs: List[(Int, Int)], dirs: List[(Int, Int)]
): List[Move] = ): List[Move] =
dirs.flatMap(dir => castRay(context.board, from, color, dir)) dirs.flatMap(dir => castRay(context.board, from, color, dir))
private def castRay( private def castRay(
board: Board, board: Board,
from: Square, from: Square,
color: Color, color: Color,
dir: (Int, Int), dir: (Int, Int)
): List[Move] = ): List[Move] =
@tailrec @tailrec
def loop(sq: Square, acc: List[Move]): List[Move] = def loop(sq: Square, acc: List[Move]): List[Move] =
@@ -132,40 +134,40 @@ object DefaultRules extends RuleSet:
case None => acc case None => acc
case Some(next) => case Some(next) =>
board.pieceAt(next) match board.pieceAt(next) match
case None => loop(next, Move(from, next) :: acc) case None => loop(next, Move(from, next) :: acc)
case Some(p) if p.color != color => Move(from, next, MoveType.Normal(isCapture = true)) :: acc case Some(p) if p.color != color => Move(from, next, MoveType.Normal(isCapture = true)) :: acc
case Some(_) => acc case Some(_) => acc
loop(from, Nil).reverse loop(from, Nil).reverse
// ── Knight ───────────────────────────────────────────────────────── // ── Knight ─────────────────────────────────────────────────────────
private def knightCandidates( private def knightCandidates(
context: GameContext, context: GameContext,
from: Square, from: Square,
color: Color, color: Color
): List[Move] = ): List[Move] =
KnightJumps.flatMap { (df, dr) => KnightJumps.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to => from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match context.board.pieceAt(to) match
case Some(p) if p.color == color => None case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true))) case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to)) case None => Some(Move(from, to))
} }
} }
// ── King ─────────────────────────────────────────────────────────── // ── King ───────────────────────────────────────────────────────────
private def kingCandidates( private def kingCandidates(
context: GameContext, context: GameContext,
from: Square, from: Square,
color: Color, color: Color
): List[Move] = ): List[Move] =
val steps = QueenDirs.flatMap { (df, dr) => val steps = QueenDirs.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to => from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match context.board.pieceAt(to) match
case Some(p) if p.color == color => None case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true))) case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to)) case None => Some(Move(from, to))
} }
} }
steps ++ castlingCandidates(context, from, color) steps ++ castlingCandidates(context, from, color)
@@ -173,17 +175,17 @@ object DefaultRules extends RuleSet:
// ── Castling ─────────────────────────────────────────────────────── // ── Castling ───────────────────────────────────────────────────────
private case class CastlingMove( private case class CastlingMove(
kingFromAlg: String, kingFromAlg: String,
kingToAlg: String, kingToAlg: String,
middleAlg: String, middleAlg: String,
rookFromAlg: String, rookFromAlg: String,
moveType: MoveType, moveType: MoveType
) )
private def castlingCandidates( private def castlingCandidates(
context: GameContext, context: GameContext,
from: Square, from: Square,
color: Color, color: Color
): List[Move] = ): List[Move] =
color match color match
case Color.White => whiteCastles(context, from) case Color.White => whiteCastles(context, from)
@@ -194,18 +196,10 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty if from != expected then List.empty
else else
val moves = scala.collection.mutable.ListBuffer[Move]() val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove( addCastleMove(context, moves, context.castlingRights.whiteKingSide,
context, CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside))
moves, addCastleMove(context, moves, context.castlingRights.whiteQueenSide,
context.castlingRights.whiteKingSide, CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside))
CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.whiteQueenSide,
CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside),
)
moves.toList moves.toList
private def blackCastles(context: GameContext, from: Square): List[Move] = private def blackCastles(context: GameContext, from: Square): List[Move] =
@@ -213,18 +207,10 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty if from != expected then List.empty
else else
val moves = scala.collection.mutable.ListBuffer[Move]() val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove( addCastleMove(context, moves, context.castlingRights.blackKingSide,
context, CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside))
moves, addCastleMove(context, moves, context.castlingRights.blackQueenSide,
context.castlingRights.blackKingSide, CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside))
CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.blackQueenSide,
CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside),
)
moves.toList moves.toList
private def queensideBSquare(kingToAlg: String): List[String] = private def queensideBSquare(kingToAlg: String): List[String] =
@@ -234,10 +220,10 @@ object DefaultRules extends RuleSet:
case _ => List.empty case _ => List.empty
private def addCastleMove( private def addCastleMove(
context: GameContext, context: GameContext,
moves: scala.collection.mutable.ListBuffer[Move], moves: scala.collection.mutable.ListBuffer[Move],
castlingRight: Boolean, castlingRight: Boolean,
castlingMove: CastlingMove, castlingMove: CastlingMove
): Unit = ): Unit =
if castlingRight then if castlingRight then
val clearSqs = (List(castlingMove.middleAlg, castlingMove.kingToAlg) ++ queensideBSquare(castlingMove.kingToAlg)) val clearSqs = (List(castlingMove.middleAlg, castlingMove.kingToAlg) ++ queensideBSquare(castlingMove.kingToAlg))
@@ -249,15 +235,16 @@ object DefaultRules extends RuleSet:
kt <- Square.fromAlgebraic(castlingMove.kingToAlg) kt <- Square.fromAlgebraic(castlingMove.kingToAlg)
rf <- Square.fromAlgebraic(castlingMove.rookFromAlg) rf <- Square.fromAlgebraic(castlingMove.rookFromAlg)
do do
val color = context.turn val color = context.turn
val kingPresent = context.board.pieceAt(kf).exists(p => p.color == color && p.pieceType == PieceType.King) val kingPresent = context.board.pieceAt(kf).exists(p => p.color == color && p.pieceType == PieceType.King)
val rookPresent = context.board.pieceAt(rf).exists(p => p.color == color && p.pieceType == PieceType.Rook) val rookPresent = context.board.pieceAt(rf).exists(p => p.color == color && p.pieceType == PieceType.Rook)
val squaresSafe = val squaresSafe =
!isAttackedBy(context.board, kf, color.opposite) && !isAttackedBy(context.board, kf, color.opposite) &&
!isAttackedBy(context.board, km, color.opposite) && !isAttackedBy(context.board, km, color.opposite) &&
!isAttackedBy(context.board, kt, color.opposite) !isAttackedBy(context.board, kt, color.opposite)
if kingPresent && rookPresent && squaresSafe then moves += Move(kf, kt, castlingMove.moveType) if kingPresent && rookPresent && squaresSafe then
moves += Move(kf, kt, castlingMove.moveType)
private def squaresEmpty(board: Board, squares: List[Square]): Boolean = private def squaresEmpty(board: Board, squares: List[Square]): Boolean =
squares.forall(sq => board.pieceAt(sq).isEmpty) squares.forall(sq => board.pieceAt(sq).isEmpty)
@@ -265,26 +252,22 @@ object DefaultRules extends RuleSet:
// ── Pawn ─────────────────────────────────────────────────────────── // ── Pawn ───────────────────────────────────────────────────────────
private def pawnCandidates( private def pawnCandidates(
context: GameContext, context: GameContext,
from: Square, from: Square,
color: Color, color: Color
): List[Move] = ): List[Move] =
val fwd = pawnForward(color) val fwd = pawnForward(color)
val startRank = pawnStartRank(color) val startRank = pawnStartRank(color)
val promoRank = pawnPromoRank(color) val promoRank = pawnPromoRank(color)
val single = from.offset(0, fwd).filter(to => context.board.pieceAt(to).isEmpty) val single = from.offset(0, fwd).filter(to => context.board.pieceAt(to).isEmpty)
val double = Option val double = Option.when(from.rank.ordinal == startRank) {
.when(from.rank.ordinal == startRank) { from.offset(0, fwd).flatMap { mid =>
from.offset(0, fwd).flatMap { mid => Option.when(context.board.pieceAt(mid).isEmpty) {
Option from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
.when(context.board.pieceAt(mid).isEmpty) { }.flatten
from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
}
.flatten
}
} }
.flatten }.flatten
val diagonalCaptures = List(-1, 1).flatMap { df => val diagonalCaptures = List(-1, 1).flatMap { df =>
from.offset(df, fwd).flatMap { to => from.offset(df, fwd).flatMap { to =>
@@ -303,56 +286,55 @@ object DefaultRules extends RuleSet:
def toMoves(dest: Square, isCapture: Boolean): List[Move] = def toMoves(dest: Square, isCapture: Boolean): List[Move] =
if dest.rank.ordinal == promoRank then if dest.rank.ordinal == promoRank then
List( List(
PromotionPiece.Queen, PromotionPiece.Queen, PromotionPiece.Rook,
PromotionPiece.Rook, PromotionPiece.Bishop, PromotionPiece.Knight
PromotionPiece.Bishop,
PromotionPiece.Knight,
).map(pt => Move(from, dest, MoveType.Promotion(pt))) ).map(pt => Move(from, dest, MoveType.Promotion(pt)))
else List(Move(from, dest, MoveType.Normal(isCapture = isCapture))) else List(Move(from, dest, MoveType.Normal(isCapture = isCapture)))
val stepSquares = single.toList ++ double.toList val stepSquares = single.toList ++ double.toList
val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false)) val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false))
val captureMoves = diagonalCaptures.flatMap(dest => toMoves(dest, isCapture = true)) val captureMoves = diagonalCaptures.flatMap(dest => toMoves(dest, isCapture = true))
stepMoves ++ captureMoves ++ epCaptures stepMoves ++ captureMoves ++ epCaptures
// ── Check detection ──────────────────────────────────────────────── // ── Check detection ────────────────────────────────────────────────
private def kingSquare(board: Board, color: Color): Option[Square] = /** Cast rays outward from `target` to detect attackers — O(rays) instead of O(64×rays). */
Square.all.find(sq => board.pieceAt(sq).exists(p => p.color == color && p.pieceType == PieceType.King))
private def isAttackedBy(board: Board, target: Square, attacker: Color): Boolean = private def isAttackedBy(board: Board, target: Square, attacker: Color): Boolean =
Square.all.exists { sq => attackedBySlider(board, target, attacker, RookDirs, PieceType.Rook) ||
board.pieceAt(sq).fold(false) { p => attackedBySlider(board, target, attacker, BishopDirs, PieceType.Bishop) ||
p.color == attacker && squareAttacks(board, sq, p, target) attackedByKnight(board, target, attacker) ||
} attackedByPawn(board, target, attacker) ||
attackedByKing(board, target, attacker)
private def attackedBySlider(board: Board, target: Square, attacker: Color, dirs: List[(Int, Int)], sliderType: PieceType): Boolean =
dirs.exists { dir =>
@tailrec def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match
case None => false
case Some(next) => board.pieceAt(next) match
case None => loop(next)
case Some(p) if p.color == attacker && (p.pieceType == sliderType || p.pieceType == PieceType.Queen) => true
case _ => false
loop(target)
} }
private def squareAttacks(board: Board, from: Square, piece: Piece, target: Square): Boolean = private def attackedByKnight(board: Board, target: Square, attacker: Color): Boolean =
val fwd = pawnForward(piece.color) KnightJumps.exists { (df, dr) =>
piece.pieceType match target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Knight))
case PieceType.Pawn => }
from.offset(-1, fwd).contains(target) || from.offset(1, fwd).contains(target)
case PieceType.Knight =>
KnightJumps.exists((df, dr) => from.offset(df, dr).contains(target))
case PieceType.Bishop => rayReaches(board, from, BishopDirs, target)
case PieceType.Rook => rayReaches(board, from, RookDirs, target)
case PieceType.Queen => rayReaches(board, from, QueenDirs, target)
case PieceType.King =>
QueenDirs.exists((df, dr) => from.offset(df, dr).contains(target))
private def rayReaches(board: Board, from: Square, dirs: List[(Int, Int)], target: Square): Boolean = private def attackedByPawn(board: Board, target: Square, attacker: Color): Boolean =
dirs.exists { dir => val dr = if attacker == Color.White then -1 else 1
@tailrec List(-1, 1).exists { df =>
def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Pawn))
case None => false }
case Some(next) if next == target => true
case Some(next) if board.pieceAt(next).isEmpty => loop(next) private def attackedByKing(board: Board, target: Square, attacker: Color): Boolean =
case Some(_) => false QueenDirs.exists { (df, dr) =>
loop(from) target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.King))
} }
private def leavesKingInCheck(context: GameContext, move: Move): Boolean = private def leavesKingInCheck(context: GameContext, move: Move): Boolean =
val nextBoard = context.board.applyMove(move) val nextBoard = context.board.applyMove(move)
val nextContext = context.withBoard(nextBoard) val nextContext = context.withBoard(nextBoard)
isCheck(nextContext) isCheck(nextContext)
@@ -360,7 +342,7 @@ object DefaultRules extends RuleSet:
override def applyMove(context: GameContext)(move: Move): GameContext = override def applyMove(context: GameContext)(move: Move): GameContext =
val color = context.turn val color = context.turn
val board = context.board val board = context.board
val newBoard = move.moveType match val newBoard = move.moveType match
case MoveType.CastleKingside => applyCastle(board, color, kingside = true) case MoveType.CastleKingside => applyCastle(board, color, kingside = true)
@@ -369,14 +351,14 @@ object DefaultRules extends RuleSet:
case MoveType.Promotion(pp) => applyPromotion(board, move, color, pp) case MoveType.Promotion(pp) => applyPromotion(board, move, color, pp)
case MoveType.Normal(_) => board.applyMove(move) case MoveType.Normal(_) => board.applyMove(move)
val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color) val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color)
val newEnPassantSquare = computeEnPassantSquare(board, move) val newEnPassantSquare = computeEnPassantSquare(board, move)
val isCapture = move.moveType match val isCapture = move.moveType match
case MoveType.Normal(capture) => capture case MoveType.Normal(capture) => capture
case MoveType.EnPassant => true case MoveType.EnPassant => true
case _ => board.pieceAt(move.to).isDefined case _ => board.pieceAt(move.to).isDefined
val isPawnMove = board.pieceAt(move.from).exists(_.pieceType == PieceType.Pawn) val isPawnMove = board.pieceAt(move.from).exists(_.pieceType == PieceType.Pawn)
val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1 val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1
context context
.withBoard(newBoard) .withBoard(newBoard)
@@ -389,18 +371,19 @@ object DefaultRules extends RuleSet:
private def applyCastle(board: Board, color: Color, kingside: Boolean): Board = private def applyCastle(board: Board, color: Color, kingside: Boolean): Board =
val rank = if color == Color.White then Rank.R1 else Rank.R8 val rank = if color == Color.White then Rank.R1 else Rank.R8
val (kingFrom, kingTo, rookFrom, rookTo) = val (kingFrom, kingTo, rookFrom, rookTo) =
if kingside then (Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank)) if kingside then
else (Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank)) (Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank))
else
(Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank))
val king = board.pieceAt(kingFrom).getOrElse(Piece(color, PieceType.King)) val king = board.pieceAt(kingFrom).getOrElse(Piece(color, PieceType.King))
val rook = board.pieceAt(rookFrom).getOrElse(Piece(color, PieceType.Rook)) val rook = board.pieceAt(rookFrom).getOrElse(Piece(color, PieceType.Rook))
board board
.removed(kingFrom) .removed(kingFrom).removed(rookFrom)
.removed(rookFrom)
.updated(kingTo, king) .updated(kingTo, king)
.updated(rookTo, rook) .updated(rookTo, rook)
private def applyEnPassant(board: Board, move: Move): Board = private def applyEnPassant(board: Board, move: Move): Board =
val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn
val capturedSquare = Square(move.to.file, capturedRank) val capturedSquare = Square(move.to.file, capturedRank)
board.applyMove(move).removed(capturedSquare) board.applyMove(move).removed(capturedSquare)
@@ -413,7 +396,7 @@ object DefaultRules extends RuleSet:
board.removed(move.from).updated(move.to, Piece(color, promotedType)) board.removed(move.from).updated(move.to, Piece(color, promotedType))
private def updateCastlingRights(rights: CastlingRights, board: Board, move: Move, color: Color): CastlingRights = private def updateCastlingRights(rights: CastlingRights, board: Board, move: Move, color: Color): CastlingRights =
val piece = board.pieceAt(move.from) val piece = board.pieceAt(move.from)
val isKingMove = piece.exists(_.pieceType == PieceType.King) val isKingMove = piece.exists(_.pieceType == PieceType.King)
val isRookMove = piece.exists(_.pieceType == PieceType.Rook) val isRookMove = piece.exists(_.pieceType == PieceType.Rook)
@@ -423,25 +406,19 @@ object DefaultRules extends RuleSet:
val blackKingsideRook = Square(File.H, Rank.R8) val blackKingsideRook = Square(File.H, Rank.R8)
val blackQueensideRook = Square(File.A, Rank.R8) val blackQueensideRook = Square(File.A, Rank.R8)
val afterKingMove = if isKingMove then rights.revokeColor(color) else rights var r = rights
if isKingMove then r = r.revokeColor(color)
val afterRookMove = else if isRookMove then
if !isRookMove then afterKingMove if move.from == whiteKingsideRook then r = r.revokeKingSide(Color.White)
else if move.from == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
move.from match if move.from == blackKingsideRook then r = r.revokeKingSide(Color.Black)
case `whiteKingsideRook` => afterKingMove.revokeKingSide(Color.White) if move.from == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
case `whiteQueensideRook` => afterKingMove.revokeQueenSide(Color.White)
case `blackKingsideRook` => afterKingMove.revokeKingSide(Color.Black)
case `blackQueensideRook` => afterKingMove.revokeQueenSide(Color.Black)
case _ => afterKingMove
// Also revoke if a rook is captured // Also revoke if a rook is captured
move.to match if move.to == whiteKingsideRook then r = r.revokeKingSide(Color.White)
case `whiteKingsideRook` => afterRookMove.revokeKingSide(Color.White) if move.to == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
case `whiteQueensideRook` => afterRookMove.revokeQueenSide(Color.White) if move.to == blackKingsideRook then r = r.revokeKingSide(Color.Black)
case `blackKingsideRook` => afterRookMove.revokeKingSide(Color.Black) if move.to == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
case `blackQueensideRook` => afterRookMove.revokeQueenSide(Color.Black) r
case _ => afterRookMove
private def computeEnPassantSquare(board: Board, move: Move): Option[Square] = private def computeEnPassantSquare(board: Board, move: Move): Option[Square] =
val piece = board.pieceAt(move.from) val piece = board.pieceAt(move.from)
@@ -455,14 +432,12 @@ object DefaultRules extends RuleSet:
// ── Insufficient material ────────────────────────────────────────── // ── Insufficient material ──────────────────────────────────────────
private def squareColor(sq: Square): Int = (sq.file.ordinal + sq.rank.ordinal) % 2
private def insufficientMaterial(board: Board): Boolean = private def insufficientMaterial(board: Board): Boolean =
val nonKings = board.pieces.toList.filter { case (_, p) => p.pieceType != PieceType.King } val pieces = board.pieces.values.toList.filter(_.pieceType != PieceType.King)
nonKings match pieces match
case Nil => true case Nil => true
case List((_, p)) if p.pieceType == PieceType.Bishop || p.pieceType == PieceType.Knight => true case List(p) if p.pieceType == PieceType.Bishop || p.pieceType == PieceType.Knight => true
case bishops if bishops.forall { case (_, p) => p.pieceType == PieceType.Bishop } => case List(p1, p2)
// All non-king pieces are bishops: draw only if they all share the same square color if p1.pieceType == PieceType.Bishop && p2.pieceType == PieceType.Bishop
bishops.map { case (sq, _) => squareColor(sq) }.distinct.sizeIs == 1 && p1.color != p2.color => true
case _ => false case _ => false