feat: Improve code formatting and readability in AlphaBetaSearch and related tests
Build & Test (NowChessSystems) TeamCity build failed

This commit is contained in:
2026-04-17 21:57:25 +02:00
parent 473c62666a
commit 0f6e32cf08
24 changed files with 352 additions and 268 deletions
@@ -5,5 +5,4 @@ import de.nowchess.api.player.PlayerInfo
sealed trait Participant sealed trait Participant
final case class Human(playerInfo: PlayerInfo) extends Participant final case class Human(playerInfo: PlayerInfo) extends Participant
final case class BotParticipant(bot: Bot) extends Participant final case class BotParticipant(bot: Bot) extends Participant
@@ -17,4 +17,3 @@ object BotMoveRepetition:
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] = def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
val blocked = blockedMoves(context) val blocked = blockedMoves(context)
moves.filterNot(blocked.contains) moves.filterNot(blocked.contains)
@@ -24,10 +24,8 @@ object EvaluationNNUE extends Evaluation:
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors // Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
if (childPly % 10 == 0) then if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
nnue.recomputeAccumulator(childPly, child.board) else nnue.pushAccumulator(childPly, move, parent.board)
else
nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int = override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board) nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
@@ -1,14 +1,14 @@
package de.nowchess.bot.bots.nnue package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square} import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece} import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
class NNUE(model: NbaiModel): class NNUE(model: NbaiModel):
private val featureSize = model.layers(0).inputSize private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1 private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
// Column-major L1 weights for cache-friendly sparse & incremental updates. // Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx) // l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
@@ -60,10 +60,10 @@ class NNUE(model: NbaiModel):
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize) System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly) val l1 = l1Stack(childPly)
move.moveType match move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board) case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board) case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board) case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board) case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
def copyAccumulator(parentPly: Int, childPly: Int): Unit = def copyAccumulator(parentPly: Int, childPly: Int): Unit =
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize) System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
@@ -80,12 +80,13 @@ class NNUE(model: NbaiModel):
// Compare with actual L1 // Compare with actual L1
val actual = l1Stack(ply) val actual = l1Stack(ply)
var maxError = 0f val maxError =
for i <- 0 until accSize do (0 until accSize).foldLeft(0f) { (currentMax, i) =>
val error = math.abs(actual(i) - expectedL1(i)) val error = math.abs(actual(i) - expectedL1(i))
if error > maxError then maxError = error math.max(currentMax, error)
}
maxError < 0.001f // Allow small floating-point errors maxError < 0.001f // Allow small floating-point errors
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit = private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
// Extract source and destination square indices early // Extract source and destination square indices early
@@ -156,24 +157,24 @@ class NNUE(model: NbaiModel):
// For debugging: validate that incremental accumulator matches recomputation // For debugging: validate that incremental accumulator matches recomputation
if validateAccum && ply > 0 && ply % 10 != 0 then if validateAccum && ply > 0 && ply % 10 != 0 then
val isValid = validateAccumulator(ply, board) val isValid = validateAccumulator(ply, board)
if !isValid then if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
evaluateAtPly(ply, turn, hash) evaluateAtPly(ply, turn, hash)
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int = private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0) val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
var input = l1ReLU val finalInput =
for i <- 1 until model.layers.length - 1 do (1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
val lw = model.weights(i) val lw = model.weights(i)
val out = evalBuffers(i) val out = evalBuffers(i)
val ld = model.layers(i) val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize) runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
input = out out
}
val lastIdx = model.layers.length - 1 val lastIdx = model.layers.length - 1
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx)) val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn) scoreFromOutput(output, turn)
private def runDenseReLU( private def runDenseReLU(
@@ -20,8 +20,10 @@ object NbaiLoader:
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */ /** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel = def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) => try load(s) finally s.close() case Some(s) =>
case None => NbaiMigrator.migrateFromBin() try load(s)
finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit = private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt() val magic = buf.getInt()
@@ -9,11 +9,11 @@ object NbaiMigrator:
private val BinVersion = 1 private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array( private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536), LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024), LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512), LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256), LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1), LayerDescriptor("linear", 256, 1),
) )
private val UnknownMetadata: NbaiMetadata = private val UnknownMetadata: NbaiMetadata =
@@ -24,7 +24,13 @@ object NbaiMetadata:
def fromJson(json: String): NbaiMetadata = def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("") def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0") def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble) NbaiMetadata(
str("trainedBy"),
str("trainedAt"),
num("trainingDataCount").toLong,
num("valLoss").toDouble,
num("trainLoss").toDouble,
)
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */ /** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float]) case class LayerWeights(weights: Array[Float], bias: Array[Float])
@@ -147,8 +147,8 @@ final class AlphaBetaSearch(
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
): Option[Int] = ): Option[Int] =
val nullCtx = nullMoveContext(context) val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx)) val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R) val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1) weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves) val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
@@ -199,8 +199,14 @@ final class AlphaBetaSearch(
legalMoves: List[Move], legalMoves: List[Move],
): Option[(Int, Option[Move])] = ): Option[(Int, Option[Move])] =
if legalMoves.isEmpty then if legalMoves.isEmpty then
Some((if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE, None)) Some(
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then Some((weights.DRAW_SCORE, None)) (
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
None,
),
)
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
Some((weights.DRAW_SCORE, None))
else if params.depth == 0 then else if params.depth == 0 then
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None)) Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
else None else None
@@ -210,7 +216,18 @@ final class AlphaBetaSearch(
legalMoves: List[Move], legalMoves: List[Move],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
val nullResult = val nullResult =
Option.when(canTryNullMove(params))(tryNullMove(params.context, params.depth, params.ply, params.window.beta, params.state, params.excludedRootMoves)).flatten Option
.when(canTryNullMove(params))(
tryNullMove(
params.context,
params.depth,
params.ply,
params.window.beta,
params.state,
params.excludedRootMoves,
),
)
.flatten
nullResult.map((_, None)).getOrElse { nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove) val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
@@ -228,13 +245,13 @@ final class AlphaBetaSearch(
private def canTryNullMove(params: SearchParams): Boolean = private def canTryNullMove(params: SearchParams): Boolean =
params.depth >= 3 && params.depth >= 3 &&
!rules.isCheck(params.context) && !rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context) hasNonPawnMaterial(params.context)
private def isQuietMove(context: GameContext, move: Move): Boolean = private def isQuietMove(context: GameContext, move: Move): Boolean =
!isCapture(context, move) && !isCapture(context, move) &&
move.moveType != MoveType.CastleKingside && move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside move.moveType != MoveType.CastleQueenside
private def scoreMove( private def scoreMove(
child: GameContext, child: GameContext,
@@ -246,14 +263,35 @@ final class AlphaBetaSearch(
): Int = ): Int =
val betaNeg = -params.window.beta val betaNeg = -params.window.beta
if reduction > 0 then if reduction > 0 then
val (rs, _) = search(child, math.max(0, params.depth - 1 - reduction + extension), params.ply + 1, Window(-a - 1, -a), childState, params.excludedRootMoves) val (rs, _) = search(
val s = -rs child,
math.max(0, params.depth - 1 - reduction + extension),
params.ply + 1,
Window(-a - 1, -a),
childState,
params.excludedRootMoves,
)
val s = -rs
if s > a then if s > a then
val (fs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves) val (fs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-fs -fs
else s else s
else else
val (rs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves) val (rs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-rs -rs
private def evalSingleMove( private def evalSingleMove(
@@ -268,8 +306,8 @@ final class AlphaBetaSearch(
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
if skipRoot || futility then None if skipRoot || futility then None
else else
val child = rules.applyMove(params.context)(move) val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child) val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
weights.pushAccumulator(params.ply + 1, move, params.context, child) weights.pushAccumulator(params.ply + 1, move, params.context, child)
val childState = params.state.advance(childHash) val childState = params.state.advance(childHash)
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0 val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
@@ -317,7 +355,7 @@ final class AlphaBetaSearch(
state: SearchState, state: SearchState,
excludedRootMoves: Set[Move], excludedRootMoves: Set[Move],
): (Int, Option[Move]) = ): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves) val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered) val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
val flag = val flag =
if cutoff then TTFlag.Lower if cutoff then TTFlag.Lower
@@ -18,10 +18,13 @@ final case class SearchParams(
final case class SearchState(hash: Long, repetitions: Map[Long, Int]): final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
def advance(nextHash: Long): SearchState = def advance(nextHash: Long): SearchState =
SearchState(nextHash, repetitions.updatedWith(nextHash) { SearchState(
case Some(v) => Some(v + 1) nextHash,
case None => Some(1) repetitions.updatedWith(nextHash) {
}) case Some(v) => Some(v + 1)
case None => Some(1)
},
)
enum TTFlag: enum TTFlag:
case Exact // Score is exact case Exact // Score is exact
@@ -11,11 +11,13 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import de.nowchess.rules.sets.DefaultRules import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.AtomicBoolean
class AlphaBetaSearchTest extends AnyFunSuite with Matchers: class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
private object ZeroEval extends Evaluation: private object ZeroEval extends Evaluation:
val CHECKMATE_SCORE: Int = 1_000_000 val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0 def evaluate(context: GameContext): Int = 0
test("bestMove on initial position returns a move"): test("bestMove on initial position returns a move"):
@@ -43,7 +45,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic) val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -61,7 +63,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic) val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -109,7 +111,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = true def isStalemate(context: GameContext): Boolean = true
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic) val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic)
@@ -126,7 +128,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = true def isInsufficientMaterial(context: GameContext): Boolean = true
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic) val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic)
@@ -143,7 +145,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = true def isFiftyMoveRule(context: GameContext): Boolean = true
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic) val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic)
@@ -170,7 +172,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic) val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic)
@@ -188,7 +190,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic) val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic)
@@ -311,18 +313,16 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove)) search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
test("quiescence depth-limit in-check branch is exercised"): test("quiescence depth-limit in-check branch is exercised"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal()) val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true)) val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
var firstChildCheckCall = true val firstChildCheckCall = AtomicBoolean(true)
val deepQRules = new RuleSet: val deepQRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context) def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context) def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def allLegalMoves(context: GameContext): List[Move] = def allLegalMoves(context: GameContext): List[Move] =
if context.moves.isEmpty then List(rootMove) else List(capMove) if context.moves.isEmpty then List(rootMove) else List(capMove)
def isCheck(context: GameContext): Boolean = def isCheck(context: GameContext): Boolean =
if context.moves.length == 1 && firstChildCheckCall then if context.moves.length == 1 && firstChildCheckCall.compareAndSet(true, false) then false
firstChildCheckCall = false
false
else context.moves.nonEmpty else context.moves.nonEmpty
def isCheckmate(context: GameContext): Boolean = false def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
@@ -334,4 +334,3 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
val search = AlphaBetaSearch(deepQRules, weights = ZeroEval) val search = AlphaBetaSearch(deepQRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove)) search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
@@ -35,7 +35,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules) val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -66,7 +66,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules) val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -89,11 +89,10 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context def applyMove(context: GameContext)(move: Move): GameContext = context
val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove)) val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
val bot = ClassicalBot(BotDifficulty.Easy, stubRules) val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
bot.nextMove(context) should be(None) bot.nextMove(context) should be(None)
@@ -12,6 +12,7 @@ import org.scalatest.matchers.should.Matchers
import java.io.{DataOutputStream, FileOutputStream} import java.io.{DataOutputStream, FileOutputStream}
import java.nio.file.Files import java.nio.file.Files
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Using import scala.util.Using
class HybridBotTest extends AnyFunSuite with Matchers: class HybridBotTest extends AnyFunSuite with Matchers:
@@ -68,8 +69,8 @@ class HybridBotTest extends AnyFunSuite with Matchers:
test("HybridBot uses book move when available"): test("HybridBot uses book move when available"):
val tempFile = Files.createTempFile("hybrid_book", ".bin") val tempFile = Files.createTempFile("hybrid_book", ".bin")
try try
val ctx = GameContext.initial val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx) val hash = PolyglotHash.hash(ctx)
val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos => Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
@@ -100,26 +101,26 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move) context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation: object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000 val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0 def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation: object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000 val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000 def evaluate(context: GameContext): Int = 10_000
var reported = false val reported = AtomicBoolean(false)
val bot = HybridBot( val bot = HybridBot(
BotDifficulty.Easy, BotDifficulty.Easy,
rules = oneMoveRules, rules = oneMoveRules,
nnueEvaluation = LowNnue, nnueEvaluation = LowNnue,
classicalEvaluation = HighClassic, classicalEvaluation = HighClassic,
vetoReporter = _ => reported = true, vetoReporter = _ => reported.set(true),
) )
bot.nextMove(GameContext.initial) should be(Some(forcedMove)) bot.nextMove(GameContext.initial) should be(Some(forcedMove))
reported should be(true) reported.get should be(true)
test("HybridBot default veto reporter prints when threshold is exceeded"): test("HybridBot default veto reporter prints when threshold is exceeded"):
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal()) val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
@@ -137,13 +138,13 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move) context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation: object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000 val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0 def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation: object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000 val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0 val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000 def evaluate(context: GameContext): Int = 10_000
val bot = HybridBot( val bot = HybridBot(
@@ -157,4 +158,3 @@ class HybridBotTest extends AnyFunSuite with Matchers:
bot.nextMove(GameContext.initial) bot.nextMove(GameContext.initial)
} }
printed should be(Some(forcedMove)) printed should be(Some(forcedMove))
@@ -9,16 +9,6 @@ import org.scalatest.matchers.should.Matchers
class MoveOrderingTest extends AnyFunSuite with Matchers: class MoveOrderingTest extends AnyFunSuite with Matchers:
private val moveOrderingClass = Class.forName("de.nowchess.bot.logic.MoveOrdering$")
private val moveOrderingObject = moveOrderingClass.getField("MODULE$").get(null)
private def invokeMoveOrderingPrivate[T](methodPrefix: String, args: Seq[AnyRef]): T =
val method = moveOrderingClass.getDeclaredMethods
.find(m => (m.getName == methodPrefix || m.getName.startsWith(methodPrefix + "$")) && m.getParameterCount == args.length)
.getOrElse(throw RuntimeException(s"Method not found: $methodPrefix/${args.length}"))
method.setAccessible(true)
method.invoke(moveOrderingObject, args*).asInstanceOf[T]
test("queen capture ranks higher than rook capture"): test("queen capture ranks higher than rook capture"):
val board = Board( val board = Board(
Map( Map(
@@ -200,13 +190,13 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
Square(File.D, Rank.R8) -> Piece.BlackRook, Square(File.D, Rank.R8) -> Piece.BlackRook,
), ),
) )
val context = GameContext.initial.withBoard(board).withTurn(Color.White) val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight)) val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop)) val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop))
val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook)) val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook))
MoveOrdering.score(context, knightPromo, None) should be > 0 MoveOrdering.score(context, knightPromo, None) should be > 0
MoveOrdering.score(context, bishopPromo, None) should be > 0 MoveOrdering.score(context, bishopPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0 MoveOrdering.score(context, rookPromo, None) should be > 0
test("negative SEE capture path is scored below neutral capture baseline"): test("negative SEE capture path is scored below neutral capture baseline"):
val board = Board( val board = Board(
@@ -221,23 +211,9 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
MoveOrdering.score(context, move, None) should be < 100_000 MoveOrdering.score(context, move, None) should be < 100_000
test("private fallback branches in MoveOrdering are covered"): test("non-capture move keeps fallback scoring at zero"):
val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing)) val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing))
val context = GameContext.initial.withBoard(board).withTurn(Color.White) val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside) val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
val victim = invokeMoveOrderingPrivate[Int]("victimValue", Seq[AnyRef](context, castle)) MoveOrdering.score(context, castle, None) should be(0)
victim should be(0)
val capture = invokeMoveOrderingPrivate[Boolean]("isCapture", Seq[AnyRef](context, castle))
capture should be(false)
val see = invokeMoveOrderingPrivate[Int]("staticExchange", Seq[AnyRef](context, castle))
see should be(0)
val clear = invokeMoveOrderingPrivate[Boolean](
"pathClear",
Seq[AnyRef](board, Square(File.A, Rank.R1), Square(File.H, Rank.R1), Int.box(-1), Int.box(0)),
)
clear should be(false)
@@ -98,8 +98,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing, Square(File.E, Rank.R8) -> Piece.BlackKing,
), ),
) )
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White) val ctx = GameContext.initial
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false)) .withBoard(board)
.withTurn(Color.White)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside) val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
val next = DefaultRules.applyMove(ctx)(move) val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next)) ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -113,7 +117,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing, Square(File.E, Rank.R8) -> Piece.BlackKing,
), ),
) )
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White) val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withEnPassantSquare(Some(Square(File.D, Rank.R6))) .withEnPassantSquare(Some(Square(File.D, Rank.R6)))
val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant) val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val next = DefaultRules.applyMove(ctx)(move) val next = DefaultRules.applyMove(ctx)(move)
@@ -127,8 +133,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.E, Rank.R1) -> Piece.WhiteKing,
), ),
) )
val ctx = GameContext.initial.withBoard(board).withTurn(Color.Black) val ctx = GameContext.initial
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false)) .withBoard(board)
.withTurn(Color.Black)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside) val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside)
val next = DefaultRules.applyMove(ctx)(move) val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next)) ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -141,7 +151,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing, Square(File.E, Rank.R8) -> Piece.BlackKing,
), ),
) )
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White) val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(CastlingRights(false, false, false, false)) .withCastlingRights(CastlingRights(false, false, false, false))
for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do
@@ -5,8 +5,8 @@ import de.nowchess.api.move.PromotionPiece
object Parser: object Parser:
/** Parses UCI move notation: "e2e4" (4 chars) or "e7e8q" (5 chars with promotion piece suffix). /** Parses UCI move notation: "e2e4" (4 chars) or "e7e8q" (5 chars with promotion piece suffix). The promotion suffix
* The promotion suffix is q=Queen, r=Rook, b=Bishop, n=Knight. Returns None for invalid input. * is q=Queen, r=Rook, b=Bishop, n=Knight. Returns None for invalid input.
*/ */
def parseMove(input: String): Option[(Square, Square, Option[PromotionPiece])] = def parseMove(input: String): Option[(Square, Square, Option[PromotionPiece])] =
val trimmed = input.trim.toLowerCase val trimmed = input.trim.toLowerCase
@@ -19,13 +19,16 @@ import scala.concurrent.{ExecutionContext, Future}
class GameEngine( class GameEngine(
val initialContext: GameContext = GameContext.initial, val initialContext: GameContext = GameContext.initial,
val ruleSet: RuleSet = DefaultRules, val ruleSet: RuleSet = DefaultRules,
val participants: Map[Color, Participant] = Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))), val participants: Map[Color, Participant] = Map(
Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")),
Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2")),
),
) extends Observable: ) extends Observable:
// Ensure that initialBoard is set correctly for threefold repetition detection // Ensure that initialBoard is set correctly for threefold repetition detection
private val contextWithInitialBoard = if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then private val contextWithInitialBoard =
initialContext.copy(initialBoard = initialContext.board) if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then
else initialContext.copy(initialBoard = initialContext.board)
initialContext else initialContext
@SuppressWarnings(Array("DisableSyntax.var")) @SuppressWarnings(Array("DisableSyntax.var"))
private var currentContext: GameContext = contextWithInitialBoard private var currentContext: GameContext = contextWithInitialBoard
private val invoker = new CommandInvoker() private val invoker = new CommandInvoker()
@@ -109,10 +112,15 @@ class GameEngine(
notifyObservers(InvalidMoveEvent(currentContext, "Illegal move.")) notifyObservers(InvalidMoveEvent(currentContext, "Illegal move."))
case _ if isPromotionMove(piece, to) => case _ if isPromotionMove(piece, to) =>
if promotionPiece.isEmpty then if promotionPiece.isEmpty then
notifyObservers(InvalidMoveEvent(currentContext, "Promotion piece required: append q, r, b, or n to the move.")) notifyObservers(
InvalidMoveEvent(currentContext, "Promotion piece required: append q, r, b, or n to the move."),
)
else else
candidates.find(_.moveType == MoveType.Promotion(promotionPiece.get)) match candidates.find(_.moveType == MoveType.Promotion(promotionPiece.get)) match
case None => notifyObservers(InvalidMoveEvent(currentContext, "Error completing promotion: no matching legal move.")) case None =>
notifyObservers(
InvalidMoveEvent(currentContext, "Error completing promotion: no matching legal move."),
)
case Some(move) => executeMove(move) case Some(move) => executeMove(move)
case move :: _ => case move :: _ =>
executeMove(move) executeMove(move)
@@ -159,7 +167,7 @@ class GameEngine(
result result
private def applyReplayMove(move: Move): Either[String, Unit] = private def applyReplayMove(move: Move): Either[String, Unit] =
val legal = ruleSet.legalMoves(currentContext)(move.from) val legal = ruleSet.legalMoves(currentContext)(move.from)
val candidate = move.moveType match val candidate = move.moveType match
case MoveType.Promotion(pp) => legal.find(m => m.to == move.to && m.moveType == MoveType.Promotion(pp)) case MoveType.Promotion(pp) => legal.find(m => m.to == move.to && m.moveType == MoveType.Promotion(pp))
case _ => legal.find(_.to == move.to) case _ => legal.find(_.to == move.to)
@@ -174,10 +182,9 @@ class GameEngine(
/** Load an arbitrary board position, clearing all history and undo/redo state. */ /** Load an arbitrary board position, clearing all history and undo/redo state. */
def loadPosition(newContext: GameContext): Unit = synchronized { def loadPosition(newContext: GameContext): Unit = synchronized {
val contextWithInitialBoard = if newContext.moves.isEmpty then val contextWithInitialBoard =
newContext.copy(initialBoard = newContext.board) if newContext.moves.isEmpty then newContext.copy(initialBoard = newContext.board)
else else newContext
newContext
currentContext = contextWithInitialBoard currentContext = contextWithInitialBoard
invoker.clear() invoker.clear()
notifyObservers(BoardResetEvent(currentContext)) notifyObservers(BoardResetEvent(currentContext))
@@ -235,7 +242,8 @@ class GameEngine(
else if ruleSet.isCheck(currentContext) then notifyObservers(CheckDetectedEvent(currentContext)) else if ruleSet.isCheck(currentContext) then notifyObservers(CheckDetectedEvent(currentContext))
if currentContext.halfMoveClock >= 100 then notifyObservers(FiftyMoveRuleAvailableEvent(currentContext)) if currentContext.halfMoveClock >= 100 then notifyObservers(FiftyMoveRuleAvailableEvent(currentContext))
if ruleSet.isThreefoldRepetition(currentContext) then notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext)) if ruleSet.isThreefoldRepetition(currentContext) then
notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext))
else requestBotMoveIfNeeded() else requestBotMoveIfNeeded()
private def translateMoveToNotation(move: Move, boardBefore: Board): String = private def translateMoveToNotation(move: Move, boardBefore: Board): String =
@@ -18,8 +18,8 @@ class CommandInvokerBranchTest extends AnyFunSuite with Matchers:
initialShouldFailOnUndo: Boolean = false, initialShouldFailOnUndo: Boolean = false,
initialShouldFailOnExecute: Boolean = false, initialShouldFailOnExecute: Boolean = false,
) extends Command: ) extends Command:
val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo) val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo)
val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute) val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute)
override def execute(): Boolean = !shouldFailOnExecute.get() override def execute(): Boolean = !shouldFailOnExecute.get()
override def undo(): Boolean = !shouldFailOnUndo.get() override def undo(): Boolean = !shouldFailOnUndo.get()
override def description: String = "Conditional fail" override def description: String = "Conditional fail"
@@ -251,10 +251,10 @@ class GameEngineOutcomesTest extends AnyFunSuite with Matchers:
engine.processUserInput("f3g1") engine.processUserInput("f3g1")
observer.clear() observer.clear()
engine.processUserInput("f6g8") // 3rd occurrence of initial position engine.processUserInput("f6g8") // 3rd occurrence of initial position
observer.hasEvent[ThreefoldRepetitionAvailableEvent] shouldBe true observer.hasEvent[ThreefoldRepetitionAvailableEvent] shouldBe true
engine.context.result shouldBe None // claimable, not automatic engine.context.result shouldBe None // claimable, not automatic
test("draw claim via threefold repetition ends game with DrawEvent"): test("draw claim via threefold repetition ends game with DrawEvent"):
val engine = EngineTestHelpers.makeEngine() val engine = EngineTestHelpers.makeEngine()
@@ -268,7 +268,7 @@ class GameEngineOutcomesTest extends AnyFunSuite with Matchers:
engine.processUserInput("g1f3") engine.processUserInput("g1f3")
engine.processUserInput("g8f6") engine.processUserInput("g8f6")
engine.processUserInput("f3g1") engine.processUserInput("f3g1")
engine.processUserInput("f6g8") // threefold now available engine.processUserInput("f6g8") // threefold now available
observer.clear() observer.clear()
engine.processUserInput("draw") engine.processUserInput("draw")
@@ -15,17 +15,17 @@ import org.scalatest.matchers.should.Matchers
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
private class NoMoveBot extends Bot: private class NoMoveBot extends Bot:
def name: String = "nomove" def name: String = "nomove"
def nextMove(context: GameContext): Option[Move] = None def nextMove(context: GameContext): Option[Move] = None
private class FixedMoveBot(move: Move) extends Bot: private class FixedMoveBot(move: Move) extends Bot:
def name: String = "fixed" def name: String = "fixed"
def nextMove(context: GameContext): Option[Move] = Some(move) def nextMove(context: GameContext): Option[Move] = Some(move)
class GameEngineWithBotTest extends AnyFunSuite with Matchers: class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("GameEngine can play against a ClassicalBot"): test("GameEngine can play against a ClassicalBot"):
val bot = ClassicalBot(BotDifficulty.Easy) val bot = ClassicalBot(BotDifficulty.Easy)
val engine = GameEngine( val engine = GameEngine(
GameContext.initial, GameContext.initial,
DefaultRules, DefaultRules,
@@ -99,7 +99,7 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
movesMade.get() should be >= 1 movesMade.get() should be >= 1
test("GameEngine plays valid bot moves"): test("GameEngine plays valid bot moves"):
val bot = ClassicalBot(BotDifficulty.Easy) val bot = ClassicalBot(BotDifficulty.Easy)
val engine = GameEngine( val engine = GameEngine(
GameContext.initial, GameContext.initial,
DefaultRules, DefaultRules,
@@ -125,17 +125,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
engine.context.moves.nonEmpty should be(true) engine.context.moves.nonEmpty should be(true)
test("startGame triggers bot when the starting player is a bot"): test("startGame triggers bot when the starting player is a bot"):
val bot = new FixedMoveBot(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4))) val bot = new FixedMoveBot(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
val engine = GameEngine( val engine = GameEngine(
GameContext.initial, GameContext.initial,
DefaultRules, DefaultRules,
Map(Color.White -> BotParticipant(bot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))), Map(Color.White -> BotParticipant(bot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
) )
val movesMade = new AtomicInteger(0) val movesMade = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: MoveExecutedEvent => movesMade.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _ => () case _: MoveExecutedEvent => movesMade.incrementAndGet()
case _ => (),
) )
engine.startGame() engine.startGame()
Thread.sleep(500) Thread.sleep(500)
@@ -143,17 +144,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("applyBotMove fires InvalidMoveEvent when bot move destination is illegal"): test("applyBotMove fires InvalidMoveEvent when bot move destination is illegal"):
val illegalMove = Move(Square(File.E, Rank.R7), Square(File.E, Rank.R3), MoveType.Normal()) val illegalMove = Move(Square(File.E, Rank.R7), Square(File.E, Rank.R3), MoveType.Normal())
val bot = new FixedMoveBot(illegalMove) val bot = new FixedMoveBot(illegalMove)
val engine = GameEngine( val engine = GameEngine(
GameContext.initial, GameContext.initial,
DefaultRules, DefaultRules,
Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)), Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)),
) )
val invalidCount = new AtomicInteger(0) val invalidCount = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: InvalidMoveEvent => invalidCount.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _ => () case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => (),
) )
engine.processUserInput("e2e4") engine.processUserInput("e2e4")
Thread.sleep(1000) Thread.sleep(1000)
@@ -161,17 +163,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("applyBotMove fires InvalidMoveEvent when bot move source square is invalid"): test("applyBotMove fires InvalidMoveEvent when bot move source square is invalid"):
val invalidMove = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6), MoveType.Normal()) val invalidMove = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6), MoveType.Normal())
val bot = new FixedMoveBot(invalidMove) val bot = new FixedMoveBot(invalidMove)
val engine = GameEngine( val engine = GameEngine(
GameContext.initial, GameContext.initial,
DefaultRules, DefaultRules,
Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)), Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)),
) )
val invalidCount = new AtomicInteger(0) val invalidCount = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: InvalidMoveEvent => invalidCount.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _ => () case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => (),
) )
engine.processUserInput("e2e4") engine.processUserInput("e2e4")
Thread.sleep(1000) Thread.sleep(1000)
@@ -179,12 +182,14 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("handleBotNoMove fires CheckmateEvent when position is checkmate"): test("handleBotNoMove fires CheckmateEvent when position is checkmate"):
// White king at A1 in check from Qb2; Rb8 protects queen so king can't capture it // White king at A1 in check from Qb2; Rb8 protects queen so king can't capture it
val board = Board(Map( val board = Board(
Square(File.A, Rank.R1) -> Piece.WhiteKing, Map(
Square(File.B, Rank.R2) -> Piece.BlackQueen, Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.B, Rank.R8) -> Piece.BlackRook, Square(File.B, Rank.R2) -> Piece.BlackQueen,
Square(File.H, Rank.R8) -> Piece.BlackKing, Square(File.B, Rank.R8) -> Piece.BlackRook,
)) Square(File.H, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.copy( val ctx = GameContext.initial.copy(
board = board, board = board,
turn = Color.White, turn = Color.White,
@@ -193,12 +198,17 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
halfMoveClock = 0, halfMoveClock = 0,
moves = List.empty, moves = List.empty,
) )
val engine = GameEngine(ctx, DefaultRules, Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2")))) val engine = GameEngine(
ctx,
DefaultRules,
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val checkmateCount = new AtomicInteger(0) val checkmateCount = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: CheckmateEvent => checkmateCount.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _ => () case _: CheckmateEvent => checkmateCount.incrementAndGet()
case _ => (),
) )
engine.startGame() engine.startGame()
Thread.sleep(1000) Thread.sleep(1000)
@@ -206,11 +216,13 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("handleBotNoMove fires DrawEvent when position is stalemate"): test("handleBotNoMove fires DrawEvent when position is stalemate"):
// White king at A1 not in check but has no legal moves (queen at B3 covers A2, B1, B2) // White king at A1 not in check but has no legal moves (queen at B3 covers A2, B1, B2)
val board = Board(Map( val board = Board(
Square(File.A, Rank.R1) -> Piece.WhiteKing, Map(
Square(File.B, Rank.R3) -> Piece.BlackQueen, Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.H, Rank.R8) -> Piece.BlackKing, Square(File.B, Rank.R3) -> Piece.BlackQueen,
)) Square(File.H, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.copy( val ctx = GameContext.initial.copy(
board = board, board = board,
turn = Color.White, turn = Color.White,
@@ -219,12 +231,17 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
halfMoveClock = 0, halfMoveClock = 0,
moves = List.empty, moves = List.empty,
) )
val engine = GameEngine(ctx, DefaultRules, Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2")))) val engine = GameEngine(
ctx,
DefaultRules,
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val drawCount = new AtomicInteger(0) val drawCount = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: DrawEvent => drawCount.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _ => () case _: DrawEvent => drawCount.incrementAndGet()
case _ => (),
) )
engine.startGame() engine.startGame()
Thread.sleep(1000) Thread.sleep(1000)
@@ -237,13 +254,13 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))), Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
) )
val unexpectedEvents = new AtomicInteger(0) val unexpectedEvents = new AtomicInteger(0)
engine.subscribe(new Observer: engine.subscribe(
def onGameEvent(event: GameEvent): Unit = event match new Observer:
case _: CheckmateEvent => unexpectedEvents.incrementAndGet() def onGameEvent(event: GameEvent): Unit = event match
case _: DrawEvent => unexpectedEvents.incrementAndGet() case _: CheckmateEvent => unexpectedEvents.incrementAndGet()
case _ => () case _: DrawEvent => unexpectedEvents.incrementAndGet()
case _ => (),
) )
engine.startGame() engine.startGame()
Thread.sleep(500) Thread.sleep(500)
unexpectedEvents.get() shouldBe 0 unexpectedEvents.get() shouldBe 0
@@ -70,58 +70,66 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
FenParserFastParse.parseBoard("8pp/8/8/8/8/8/8/8") shouldBe None FenParserFastParse.parseBoard("8pp/8/8/8/8/8/8/8") shouldBe None
test("parseFen handles all individual castling rights"): test("parseFen handles all individual castling rights"):
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w K - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.whiteKingSide shouldBe true .parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w K - 0 1")
ctx.castlingRights.whiteQueenSide shouldBe false .fold(
ctx.castlingRights.blackKingSide shouldBe false _ => fail(),
ctx.castlingRights.blackQueenSide shouldBe false ctx =>
) ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w Q - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.whiteQueenSide shouldBe true .parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w Q - 0 1")
ctx.castlingRights.whiteKingSide shouldBe false .fold(
) _ => fail(),
ctx =>
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w k - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.blackKingSide shouldBe true .parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w k - 0 1")
ctx.castlingRights.whiteKingSide shouldBe false .fold(
) _ => fail(),
ctx =>
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w q - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.blackQueenSide shouldBe true .parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w q - 0 1")
ctx.castlingRights.whiteKingSide shouldBe false .fold(
) _ => fail(),
ctx =>
ctx.castlingRights.blackQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
test("parseFen parses all en passant squares"): test("parseFen parses all en passant squares"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - a3 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.enPassantSquare shouldBe Some(Square(File.A, Rank.R3)) .parseFen("8/8/8/8/8/8/8/8 w - a3 0 1")
) .fold(_ => fail(), ctx => ctx.enPassantSquare shouldBe Some(Square(File.A, Rank.R3)))
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - h6 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.enPassantSquare shouldBe Some(Square(File.H, Rank.R6)) .parseFen("8/8/8/8/8/8/8/8 w - h6 0 1")
) .fold(_ => fail(), ctx => ctx.enPassantSquare shouldBe Some(Square(File.H, Rank.R6)))
test("parseFen parses different halfMove and fullMove clocks"): test("parseFen parses different halfMove and fullMove clocks"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 5 10").fold(_ => fail(), ctx => FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 5 10").fold(_ => fail(), ctx => ctx.halfMoveClock shouldBe 5)
ctx.halfMoveClock shouldBe 5
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 100").fold(_ => fail(), ctx => FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 100").fold(_ => fail(), ctx => ctx.halfMoveClock shouldBe 0)
ctx.halfMoveClock shouldBe 0
)
test("parseBoard parses boards with mixed empty and piece tokens"): test("parseBoard parses boards with mixed empty and piece tokens"):
val mixed = "8/1p1p1p1p/8/1P1P1P1P/8/8/8/8" val mixed = "8/1p1p1p1p/8/1P1P1P1P/8/8/8/8"
FenParserFastParse.parseBoard(mixed) should not be empty FenParserFastParse.parseBoard(mixed) should not be empty
test("parseFen handles turn transitions"): test("parseFen handles turn transitions"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 1").fold(_ => fail(), ctx => FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 1").fold(_ => fail(), ctx => ctx.turn shouldBe Color.White)
ctx.turn shouldBe Color.White
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), ctx => FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), ctx => ctx.turn shouldBe Color.Black)
ctx.turn shouldBe Color.Black
)
test("parseFen rejects invalid piece characters"): test("parseFen rejects invalid piece characters"):
FenParserFastParse.parseFen("8x/8/8/8/8/8/8/8 w - - 0 1").isLeft shouldBe true FenParserFastParse.parseFen("8x/8/8/8/8/8/8/8 w - - 0 1").isLeft shouldBe true
@@ -133,7 +141,7 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
test("parseBoard tests all piece types in various positions"): test("parseBoard tests all piece types in various positions"):
// Test each piece type: pawn, rook, knight, bishop, queen, king (both colors) // Test each piece type: pawn, rook, knight, bishop, queen, king (both colors)
val allPieces = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR" val allPieces = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR"
val parsed = FenParserFastParse.parseBoard(allPieces) val parsed = FenParserFastParse.parseBoard(allPieces)
parsed.map(_.pieces.size) shouldBe Some(32) parsed.map(_.pieces.size) shouldBe Some(32)
parsed.map(_.pieceAt(Square(File.A, Rank.R8))) shouldBe Some(Some(Piece.BlackRook)) parsed.map(_.pieceAt(Square(File.A, Rank.R8))) shouldBe Some(Some(Piece.BlackRook))
parsed.map(_.pieceAt(Square(File.B, Rank.R8))) shouldBe Some(Some(Piece.BlackKnight)) parsed.map(_.pieceAt(Square(File.B, Rank.R8))) shouldBe Some(Some(Piece.BlackKnight))
@@ -150,25 +158,33 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), _.turn shouldBe Color.Black) FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), _.turn shouldBe Color.Black)
test("parseFen tests all castling combinations"): test("parseFen tests all castling combinations"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w KQkq - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.whiteKingSide shouldBe true .parseFen("8/8/8/8/8/8/8/8 w KQkq - 0 1")
ctx.castlingRights.whiteQueenSide shouldBe true .fold(
ctx.castlingRights.blackKingSide shouldBe true _ => fail(),
ctx.castlingRights.blackQueenSide shouldBe true ctx =>
) ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.blackQueenSide shouldBe true,
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w Kq - 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.castlingRights.whiteKingSide shouldBe true .parseFen("8/8/8/8/8/8/8/8 w Kq - 0 1")
ctx.castlingRights.whiteQueenSide shouldBe false .fold(
ctx.castlingRights.blackKingSide shouldBe false _ => fail(),
ctx.castlingRights.blackQueenSide shouldBe true ctx =>
) ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe true,
)
test("parseFen tests all en passant files"): test("parseFen tests all en passant files"):
for file <- Seq("a", "b", "c", "d", "e", "f", "g", "h") do for file <- Seq("a", "b", "c", "d", "e", "f", "g", "h") do
FenParserFastParse.parseFen(s"8/8/8/8/8/8/8/8 w - ${file}3 0 1").fold(_ => fail(), ctx => FenParserFastParse
ctx.enPassantSquare should not be empty .parseFen(s"8/8/8/8/8/8/8/8 w - ${file}3 0 1")
) .fold(_ => fail(), ctx => ctx.enPassantSquare should not be empty)
test("parseBoard with mixed pieces and empty squares"): test("parseBoard with mixed pieces and empty squares"):
FenParserFastParse.parseBoard("r1bqkb1r/pppppppp/2n2n2/8/8/2N2N2/PPPPPPPP/R1BQKB1R") should not be empty FenParserFastParse.parseBoard("r1bqkb1r/pppppppp/2n2n2/8/8/2N2N2/PPPPPPPP/R1BQKB1R") should not be empty
@@ -81,8 +81,7 @@ object DefaultRules extends RuleSet:
private def countPositionOccurrences(context: GameContext, targetPosition: Position): Int = private def countPositionOccurrences(context: GameContext, targetPosition: Position): Int =
try try
var count = 0 val initialCtx = GameContext(
var tempCtx = GameContext(
board = context.initialBoard, board = context.initialBoard,
turn = Color.White, turn = Color.White,
castlingRights = CastlingRights.Initial, castlingRights = CastlingRights.Initial,
@@ -91,20 +90,24 @@ object DefaultRules extends RuleSet:
moves = List.empty, moves = List.empty,
initialBoard = context.initialBoard, initialBoard = context.initialBoard,
) )
var tempPos = Position(tempCtx.board, tempCtx.turn, tempCtx.castlingRights, tempCtx.enPassantSquare)
if tempPos == targetPosition then count += 1
for move <- context.moves do def positionOf(ctx: GameContext): Position =
tempCtx = applyMove(tempCtx)(move) Position(
tempPos = Position( board = ctx.board,
board = tempCtx.board, turn = ctx.turn,
turn = tempCtx.turn, castlingRights = ctx.castlingRights,
castlingRights = tempCtx.castlingRights, enPassantSquare = ctx.enPassantSquare,
enPassantSquare = tempCtx.enPassantSquare,
) )
if tempPos == targetPosition then count += 1
count val initialCount = if positionOf(initialCtx) == targetPosition then 1 else 0
context.moves
.foldLeft((initialCtx, initialCount)) { case ((tempCtx, count), move) =>
val nextCtx = applyMove(tempCtx)(move)
val nextCount = if positionOf(nextCtx) == targetPosition then count + 1 else count
(nextCtx, nextCount)
}
._2
catch catch
case _: Exception => case _: Exception =>
// If replay fails, conservatively count only the current position (never triggers a draw) // If replay fails, conservatively count only the current position (never triggers a draw)
@@ -29,10 +29,12 @@ class DefaultRulesTest extends AnyFunSuite with Matchers:
test("pawn can capture diagonally"): test("pawn can capture diagonally"):
// FEN: white pawn e4, black pawn d5 // FEN: white pawn e4, black pawn d5
val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1" val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1"
val context = FenParser.parseFen(fen).fold(_ => fail(), identity) val context = FenParser.parseFen(fen).fold(_ => fail(), identity)
val moves = rules.allLegalMoves(context) val moves = rules.allLegalMoves(context)
val captures = moves.filter(m => m.from == Square(File.E, Rank.R4) && (m.moveType match { case _: MoveType.Normal => true; case _ => false })) val captures = moves.filter(m =>
m.from == Square(File.E, Rank.R4) && (m.moveType match { case _: MoveType.Normal => true; case _ => false }),
)
captures.exists(m => m.to == Square(File.D, Rank.R5)) shouldBe true captures.exists(m => m.to == Square(File.D, Rank.R5)) shouldBe true
test("pawn cannot move backward"): test("pawn cannot move backward"):
@@ -208,7 +210,7 @@ class DefaultRulesTest extends AnyFunSuite with Matchers:
test("threefold repetition catch block returns false for inconsistent context"): test("threefold repetition catch block returns false for inconsistent context"):
// A context whose moves cannot be replayed from initialBoard (forces the catch path) // A context whose moves cannot be replayed from initialBoard (forces the catch path)
val m = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6)) // e5→e6, no pawn there in initial board val m = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6)) // e5→e6, no pawn there in initial board
val brokenCtx = GameContext( val brokenCtx = GameContext(
board = Board.initial, board = Board.initial,
turn = Color.White, turn = Color.White,
@@ -205,7 +205,12 @@ class ChessBoardView(val stage: Stage, private val engine: GameEngine) extends B
else else
val isPromo = engine.ruleSet val isPromo = engine.ruleSet
.legalMoves(engine.context)(fromSquare) .legalMoves(engine.context)(fromSquare)
.exists(m => m.to == clickedSquare && m.moveType.isInstanceOf[MoveType.Promotion]) .exists(m =>
m.to == clickedSquare && (m.moveType match
case MoveType.Promotion(_) => true
case _ => false
),
)
if isPromo then showPromotionDialog(fromSquare, clickedSquare) if isPromo then showPromotionDialog(fromSquare, clickedSquare)
else engine.processUserInput(s"${fromSquare}$clickedSquare") else engine.processUserInput(s"${fromSquare}$clickedSquare")
selectedSquare.set(None) selectedSquare.set(None)
@@ -267,7 +272,8 @@ class ChessBoardView(val stage: Stage, private val engine: GameEngine) extends B
case Some(piece) => case Some(piece) =>
Seq(bgRect) ++ PieceSprites.loadPieceImage(piece, squareSize * 0.8).toSeq Seq(bgRect) ++ PieceSprites.loadPieceImage(piece, squareSize * 0.8).toSeq
case None => case None =>
Seq(bgRect)): Seq[scalafx.scene.Node] Seq(bgRect)
): Seq[scalafx.scene.Node]
} }
def showMessage(msg: String): Unit = def showMessage(msg: String): Unit =
@@ -30,9 +30,9 @@ class ChessGUIApp extends JFXApplication:
stage.scene = new Scene { stage.scene = new Scene {
root = boardView root = boardView
// Load CSS if available // Load CSS if available
try { try
Option(getClass.getResource("/styles.css")).foreach(url => stylesheets.add(url.toExternalForm)) Option(getClass.getResource("/styles.css")).foreach(url => stylesheets.add(url.toExternalForm))
} catch { catch {
case _: Exception => // CSS is optional case _: Exception => // CSS is optional
} }
} }