feat: Improve code formatting and readability in AlphaBetaSearch and related tests
Build & Test (NowChessSystems) TeamCity build failed

This commit is contained in:
2026-04-17 21:57:25 +02:00
parent 473c62666a
commit 0f6e32cf08
24 changed files with 352 additions and 268 deletions
@@ -5,5 +5,4 @@ import de.nowchess.api.player.PlayerInfo
sealed trait Participant
final case class Human(playerInfo: PlayerInfo) extends Participant
final case class BotParticipant(bot: Bot) extends Participant
final case class BotParticipant(bot: Bot) extends Participant
@@ -17,4 +17,3 @@ object BotMoveRepetition:
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
val blocked = blockedMoves(context)
moves.filterNot(blocked.contains)
@@ -24,10 +24,8 @@ object EvaluationNNUE extends Evaluation:
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
if (childPly % 10 == 0) then
nnue.recomputeAccumulator(childPly, child.board)
else
nnue.pushAccumulator(childPly, move, parent.board)
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
else nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
@@ -1,14 +1,14 @@
package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
class NNUE(model: NbaiModel):
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
// Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
@@ -60,10 +60,10 @@ class NNUE(model: NbaiModel):
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly)
move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
@@ -80,12 +80,13 @@ class NNUE(model: NbaiModel):
// Compare with actual L1
val actual = l1Stack(ply)
var maxError = 0f
for i <- 0 until accSize do
val error = math.abs(actual(i) - expectedL1(i))
if error > maxError then maxError = error
val maxError =
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
val error = math.abs(actual(i) - expectedL1(i))
math.max(currentMax, error)
}
maxError < 0.001f // Allow small floating-point errors
maxError < 0.001f // Allow small floating-point errors
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
// Extract source and destination square indices early
@@ -156,24 +157,24 @@ class NNUE(model: NbaiModel):
// For debugging: validate that incremental accumulator matches recomputation
if validateAccum && ply > 0 && ply % 10 != 0 then
val isValid = validateAccumulator(ply, board)
if !isValid then
System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
evaluateAtPly(ply, turn, hash)
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
var input = l1ReLU
for i <- 1 until model.layers.length - 1 do
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
input = out
val finalInput =
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
out
}
val lastIdx = model.layers.length - 1
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn)
private def runDenseReLU(
@@ -20,8 +20,10 @@ object NbaiLoader:
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) => try load(s) finally s.close()
case None => NbaiMigrator.migrateFromBin()
case Some(s) =>
try load(s)
finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
@@ -9,11 +9,11 @@ object NbaiMigrator:
private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
)
private val UnknownMetadata: NbaiMetadata =
@@ -24,7 +24,13 @@ object NbaiMetadata:
def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
NbaiMetadata(
str("trainedBy"),
str("trainedAt"),
num("trainingDataCount").toLong,
num("valLoss").toDouble,
num("trainLoss").toDouble,
)
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])
@@ -147,8 +147,8 @@ final class AlphaBetaSearch(
state: SearchState,
excludedRootMoves: Set[Move],
): Option[Int] =
val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx))
val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
@@ -199,8 +199,14 @@ final class AlphaBetaSearch(
legalMoves: List[Move],
): Option[(Int, Option[Move])] =
if legalMoves.isEmpty then
Some((if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE, None))
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then Some((weights.DRAW_SCORE, None))
Some(
(
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
None,
),
)
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
Some((weights.DRAW_SCORE, None))
else if params.depth == 0 then
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
else None
@@ -210,7 +216,18 @@ final class AlphaBetaSearch(
legalMoves: List[Move],
): (Int, Option[Move]) =
val nullResult =
Option.when(canTryNullMove(params))(tryNullMove(params.context, params.depth, params.ply, params.window.beta, params.state, params.excludedRootMoves)).flatten
Option
.when(canTryNullMove(params))(
tryNullMove(
params.context,
params.depth,
params.ply,
params.window.beta,
params.state,
params.excludedRootMoves,
),
)
.flatten
nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
@@ -228,13 +245,13 @@ final class AlphaBetaSearch(
private def canTryNullMove(params: SearchParams): Boolean =
params.depth >= 3 &&
!rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context)
!rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context)
private def isQuietMove(context: GameContext, move: Move): Boolean =
!isCapture(context, move) &&
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
private def scoreMove(
child: GameContext,
@@ -246,14 +263,35 @@ final class AlphaBetaSearch(
): Int =
val betaNeg = -params.window.beta
if reduction > 0 then
val (rs, _) = search(child, math.max(0, params.depth - 1 - reduction + extension), params.ply + 1, Window(-a - 1, -a), childState, params.excludedRootMoves)
val s = -rs
val (rs, _) = search(
child,
math.max(0, params.depth - 1 - reduction + extension),
params.ply + 1,
Window(-a - 1, -a),
childState,
params.excludedRootMoves,
)
val s = -rs
if s > a then
val (fs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
val (fs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-fs
else s
else
val (rs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
val (rs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-rs
private def evalSingleMove(
@@ -268,8 +306,8 @@ final class AlphaBetaSearch(
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
if skipRoot || futility then None
else
val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
weights.pushAccumulator(params.ply + 1, move, params.context, child)
val childState = params.state.advance(childHash)
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
@@ -317,7 +355,7 @@ final class AlphaBetaSearch(
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
val flag =
if cutoff then TTFlag.Lower
@@ -18,10 +18,13 @@ final case class SearchParams(
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
def advance(nextHash: Long): SearchState =
SearchState(nextHash, repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
})
SearchState(
nextHash,
repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
},
)
enum TTFlag:
case Exact // Score is exact
@@ -11,11 +11,13 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.AtomicBoolean
class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
private object ZeroEval extends Evaluation:
val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
test("bestMove on initial position returns a move"):
@@ -43,7 +45,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -61,7 +63,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -109,7 +111,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = true
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic)
@@ -126,7 +128,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = true
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic)
@@ -143,7 +145,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = true
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic)
@@ -170,7 +172,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic)
@@ -188,7 +190,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic)
@@ -311,18 +313,16 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
test("quiescence depth-limit in-check branch is exercised"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
var firstChildCheckCall = true
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
val firstChildCheckCall = AtomicBoolean(true)
val deepQRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def allLegalMoves(context: GameContext): List[Move] =
if context.moves.isEmpty then List(rootMove) else List(capMove)
def isCheck(context: GameContext): Boolean =
if context.moves.length == 1 && firstChildCheckCall then
firstChildCheckCall = false
false
if context.moves.length == 1 && firstChildCheckCall.compareAndSet(true, false) then false
else context.moves.nonEmpty
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
@@ -334,4 +334,3 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
val search = AlphaBetaSearch(deepQRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
@@ -35,7 +35,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -66,7 +66,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -89,11 +89,10 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
bot.nextMove(context) should be(None)
@@ -12,6 +12,7 @@ import org.scalatest.matchers.should.Matchers
import java.io.{DataOutputStream, FileOutputStream}
import java.nio.file.Files
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Using
class HybridBotTest extends AnyFunSuite with Matchers:
@@ -68,8 +69,8 @@ class HybridBotTest extends AnyFunSuite with Matchers:
test("HybridBot uses book move when available"):
val tempFile = Files.createTempFile("hybrid_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
@@ -100,26 +101,26 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
var reported = false
val reported = AtomicBoolean(false)
val bot = HybridBot(
BotDifficulty.Easy,
rules = oneMoveRules,
nnueEvaluation = LowNnue,
classicalEvaluation = HighClassic,
vetoReporter = _ => reported = true,
vetoReporter = _ => reported.set(true),
)
bot.nextMove(GameContext.initial) should be(Some(forcedMove))
reported should be(true)
reported.get should be(true)
test("HybridBot default veto reporter prints when threshold is exceeded"):
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
@@ -137,13 +138,13 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
val bot = HybridBot(
@@ -157,4 +158,3 @@ class HybridBotTest extends AnyFunSuite with Matchers:
bot.nextMove(GameContext.initial)
}
printed should be(Some(forcedMove))
@@ -9,16 +9,6 @@ import org.scalatest.matchers.should.Matchers
class MoveOrderingTest extends AnyFunSuite with Matchers:
private val moveOrderingClass = Class.forName("de.nowchess.bot.logic.MoveOrdering$")
private val moveOrderingObject = moveOrderingClass.getField("MODULE$").get(null)
private def invokeMoveOrderingPrivate[T](methodPrefix: String, args: Seq[AnyRef]): T =
val method = moveOrderingClass.getDeclaredMethods
.find(m => (m.getName == methodPrefix || m.getName.startsWith(methodPrefix + "$")) && m.getParameterCount == args.length)
.getOrElse(throw RuntimeException(s"Method not found: $methodPrefix/${args.length}"))
method.setAccessible(true)
method.invoke(moveOrderingObject, args*).asInstanceOf[T]
test("queen capture ranks higher than rook capture"):
val board = Board(
Map(
@@ -200,13 +190,13 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
Square(File.D, Rank.R8) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop))
val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook))
MoveOrdering.score(context, knightPromo, None) should be > 0
MoveOrdering.score(context, bishopPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0
test("negative SEE capture path is scored below neutral capture baseline"):
val board = Board(
@@ -221,23 +211,9 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
MoveOrdering.score(context, move, None) should be < 100_000
test("private fallback branches in MoveOrdering are covered"):
test("non-capture move keeps fallback scoring at zero"):
val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing))
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
val victim = invokeMoveOrderingPrivate[Int]("victimValue", Seq[AnyRef](context, castle))
victim should be(0)
val capture = invokeMoveOrderingPrivate[Boolean]("isCapture", Seq[AnyRef](context, castle))
capture should be(false)
val see = invokeMoveOrderingPrivate[Int]("staticExchange", Seq[AnyRef](context, castle))
see should be(0)
val clear = invokeMoveOrderingPrivate[Boolean](
"pathClear",
Seq[AnyRef](board, Square(File.A, Rank.R1), Square(File.H, Rank.R1), Int.box(-1), Int.box(0)),
)
clear should be(false)
MoveOrdering.score(context, castle, None) should be(0)
@@ -98,8 +98,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false))
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -113,7 +117,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withEnPassantSquare(Some(Square(File.D, Rank.R6)))
val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val next = DefaultRules.applyMove(ctx)(move)
@@ -127,8 +133,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R1) -> Piece.WhiteKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.Black)
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false))
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.Black)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -141,7 +151,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(CastlingRights(false, false, false, false))
for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do
@@ -5,8 +5,8 @@ import de.nowchess.api.move.PromotionPiece
object Parser:
/** Parses UCI move notation: "e2e4" (4 chars) or "e7e8q" (5 chars with promotion piece suffix).
* The promotion suffix is q=Queen, r=Rook, b=Bishop, n=Knight. Returns None for invalid input.
/** Parses UCI move notation: "e2e4" (4 chars) or "e7e8q" (5 chars with promotion piece suffix). The promotion suffix
* is q=Queen, r=Rook, b=Bishop, n=Knight. Returns None for invalid input.
*/
def parseMove(input: String): Option[(Square, Square, Option[PromotionPiece])] =
val trimmed = input.trim.toLowerCase
@@ -19,13 +19,16 @@ import scala.concurrent.{ExecutionContext, Future}
class GameEngine(
val initialContext: GameContext = GameContext.initial,
val ruleSet: RuleSet = DefaultRules,
val participants: Map[Color, Participant] = Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
val participants: Map[Color, Participant] = Map(
Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")),
Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2")),
),
) extends Observable:
// Ensure that initialBoard is set correctly for threefold repetition detection
private val contextWithInitialBoard = if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then
initialContext.copy(initialBoard = initialContext.board)
else
initialContext
private val contextWithInitialBoard =
if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then
initialContext.copy(initialBoard = initialContext.board)
else initialContext
@SuppressWarnings(Array("DisableSyntax.var"))
private var currentContext: GameContext = contextWithInitialBoard
private val invoker = new CommandInvoker()
@@ -109,10 +112,15 @@ class GameEngine(
notifyObservers(InvalidMoveEvent(currentContext, "Illegal move."))
case _ if isPromotionMove(piece, to) =>
if promotionPiece.isEmpty then
notifyObservers(InvalidMoveEvent(currentContext, "Promotion piece required: append q, r, b, or n to the move."))
notifyObservers(
InvalidMoveEvent(currentContext, "Promotion piece required: append q, r, b, or n to the move."),
)
else
candidates.find(_.moveType == MoveType.Promotion(promotionPiece.get)) match
case None => notifyObservers(InvalidMoveEvent(currentContext, "Error completing promotion: no matching legal move."))
case None =>
notifyObservers(
InvalidMoveEvent(currentContext, "Error completing promotion: no matching legal move."),
)
case Some(move) => executeMove(move)
case move :: _ =>
executeMove(move)
@@ -159,7 +167,7 @@ class GameEngine(
result
private def applyReplayMove(move: Move): Either[String, Unit] =
val legal = ruleSet.legalMoves(currentContext)(move.from)
val legal = ruleSet.legalMoves(currentContext)(move.from)
val candidate = move.moveType match
case MoveType.Promotion(pp) => legal.find(m => m.to == move.to && m.moveType == MoveType.Promotion(pp))
case _ => legal.find(_.to == move.to)
@@ -174,10 +182,9 @@ class GameEngine(
/** Load an arbitrary board position, clearing all history and undo/redo state. */
def loadPosition(newContext: GameContext): Unit = synchronized {
val contextWithInitialBoard = if newContext.moves.isEmpty then
newContext.copy(initialBoard = newContext.board)
else
newContext
val contextWithInitialBoard =
if newContext.moves.isEmpty then newContext.copy(initialBoard = newContext.board)
else newContext
currentContext = contextWithInitialBoard
invoker.clear()
notifyObservers(BoardResetEvent(currentContext))
@@ -235,7 +242,8 @@ class GameEngine(
else if ruleSet.isCheck(currentContext) then notifyObservers(CheckDetectedEvent(currentContext))
if currentContext.halfMoveClock >= 100 then notifyObservers(FiftyMoveRuleAvailableEvent(currentContext))
if ruleSet.isThreefoldRepetition(currentContext) then notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext))
if ruleSet.isThreefoldRepetition(currentContext) then
notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext))
else requestBotMoveIfNeeded()
private def translateMoveToNotation(move: Move, boardBefore: Board): String =
@@ -18,8 +18,8 @@ class CommandInvokerBranchTest extends AnyFunSuite with Matchers:
initialShouldFailOnUndo: Boolean = false,
initialShouldFailOnExecute: Boolean = false,
) extends Command:
val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo)
val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute)
val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo)
val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute)
override def execute(): Boolean = !shouldFailOnExecute.get()
override def undo(): Boolean = !shouldFailOnUndo.get()
override def description: String = "Conditional fail"
@@ -251,10 +251,10 @@ class GameEngineOutcomesTest extends AnyFunSuite with Matchers:
engine.processUserInput("f3g1")
observer.clear()
engine.processUserInput("f6g8") // 3rd occurrence of initial position
engine.processUserInput("f6g8") // 3rd occurrence of initial position
observer.hasEvent[ThreefoldRepetitionAvailableEvent] shouldBe true
engine.context.result shouldBe None // claimable, not automatic
engine.context.result shouldBe None // claimable, not automatic
test("draw claim via threefold repetition ends game with DrawEvent"):
val engine = EngineTestHelpers.makeEngine()
@@ -268,7 +268,7 @@ class GameEngineOutcomesTest extends AnyFunSuite with Matchers:
engine.processUserInput("g1f3")
engine.processUserInput("g8f6")
engine.processUserInput("f3g1")
engine.processUserInput("f6g8") // threefold now available
engine.processUserInput("f6g8") // threefold now available
observer.clear()
engine.processUserInput("draw")
@@ -15,17 +15,17 @@ import org.scalatest.matchers.should.Matchers
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
private class NoMoveBot extends Bot:
def name: String = "nomove"
def nextMove(context: GameContext): Option[Move] = None
def name: String = "nomove"
def nextMove(context: GameContext): Option[Move] = None
private class FixedMoveBot(move: Move) extends Bot:
def name: String = "fixed"
def nextMove(context: GameContext): Option[Move] = Some(move)
def name: String = "fixed"
def nextMove(context: GameContext): Option[Move] = Some(move)
class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("GameEngine can play against a ClassicalBot"):
val bot = ClassicalBot(BotDifficulty.Easy)
val bot = ClassicalBot(BotDifficulty.Easy)
val engine = GameEngine(
GameContext.initial,
DefaultRules,
@@ -99,7 +99,7 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
movesMade.get() should be >= 1
test("GameEngine plays valid bot moves"):
val bot = ClassicalBot(BotDifficulty.Easy)
val bot = ClassicalBot(BotDifficulty.Easy)
val engine = GameEngine(
GameContext.initial,
DefaultRules,
@@ -125,17 +125,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
engine.context.moves.nonEmpty should be(true)
test("startGame triggers bot when the starting player is a bot"):
val bot = new FixedMoveBot(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
val bot = new FixedMoveBot(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
val engine = GameEngine(
GameContext.initial,
DefaultRules,
Map(Color.White -> BotParticipant(bot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val movesMade = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: MoveExecutedEvent => movesMade.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: MoveExecutedEvent => movesMade.incrementAndGet()
case _ => (),
)
engine.startGame()
Thread.sleep(500)
@@ -143,17 +144,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("applyBotMove fires InvalidMoveEvent when bot move destination is illegal"):
val illegalMove = Move(Square(File.E, Rank.R7), Square(File.E, Rank.R3), MoveType.Normal())
val bot = new FixedMoveBot(illegalMove)
val bot = new FixedMoveBot(illegalMove)
val engine = GameEngine(
GameContext.initial,
DefaultRules,
Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)),
)
val invalidCount = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => (),
)
engine.processUserInput("e2e4")
Thread.sleep(1000)
@@ -161,17 +163,18 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("applyBotMove fires InvalidMoveEvent when bot move source square is invalid"):
val invalidMove = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6), MoveType.Normal())
val bot = new FixedMoveBot(invalidMove)
val bot = new FixedMoveBot(invalidMove)
val engine = GameEngine(
GameContext.initial,
DefaultRules,
Map(Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")), Color.Black -> BotParticipant(bot)),
)
val invalidCount = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: InvalidMoveEvent => invalidCount.incrementAndGet()
case _ => (),
)
engine.processUserInput("e2e4")
Thread.sleep(1000)
@@ -179,12 +182,14 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("handleBotNoMove fires CheckmateEvent when position is checkmate"):
// White king at A1 in check from Qb2; Rb8 protects queen so king can't capture it
val board = Board(Map(
Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.B, Rank.R2) -> Piece.BlackQueen,
Square(File.B, Rank.R8) -> Piece.BlackRook,
Square(File.H, Rank.R8) -> Piece.BlackKing,
))
val board = Board(
Map(
Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.B, Rank.R2) -> Piece.BlackQueen,
Square(File.B, Rank.R8) -> Piece.BlackRook,
Square(File.H, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.copy(
board = board,
turn = Color.White,
@@ -193,12 +198,17 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
halfMoveClock = 0,
moves = List.empty,
)
val engine = GameEngine(ctx, DefaultRules, Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))))
val engine = GameEngine(
ctx,
DefaultRules,
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val checkmateCount = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: CheckmateEvent => checkmateCount.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: CheckmateEvent => checkmateCount.incrementAndGet()
case _ => (),
)
engine.startGame()
Thread.sleep(1000)
@@ -206,11 +216,13 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
test("handleBotNoMove fires DrawEvent when position is stalemate"):
// White king at A1 not in check but has no legal moves (queen at B3 covers A2, B1, B2)
val board = Board(Map(
Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.B, Rank.R3) -> Piece.BlackQueen,
Square(File.H, Rank.R8) -> Piece.BlackKing,
))
val board = Board(
Map(
Square(File.A, Rank.R1) -> Piece.WhiteKing,
Square(File.B, Rank.R3) -> Piece.BlackQueen,
Square(File.H, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.copy(
board = board,
turn = Color.White,
@@ -219,12 +231,17 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
halfMoveClock = 0,
moves = List.empty,
)
val engine = GameEngine(ctx, DefaultRules, Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))))
val engine = GameEngine(
ctx,
DefaultRules,
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val drawCount = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: DrawEvent => drawCount.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: DrawEvent => drawCount.incrementAndGet()
case _ => (),
)
engine.startGame()
Thread.sleep(1000)
@@ -237,13 +254,13 @@ class GameEngineWithBotTest extends AnyFunSuite with Matchers:
Map(Color.White -> BotParticipant(new NoMoveBot), Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2"))),
)
val unexpectedEvents = new AtomicInteger(0)
engine.subscribe(new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: CheckmateEvent => unexpectedEvents.incrementAndGet()
case _: DrawEvent => unexpectedEvents.incrementAndGet()
case _ => ()
engine.subscribe(
new Observer:
def onGameEvent(event: GameEvent): Unit = event match
case _: CheckmateEvent => unexpectedEvents.incrementAndGet()
case _: DrawEvent => unexpectedEvents.incrementAndGet()
case _ => (),
)
engine.startGame()
Thread.sleep(500)
unexpectedEvents.get() shouldBe 0
@@ -70,58 +70,66 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
FenParserFastParse.parseBoard("8pp/8/8/8/8/8/8/8") shouldBe None
test("parseFen handles all individual castling rights"):
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w K - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe false
)
FenParserFastParse
.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w K - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w Q - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false
)
FenParserFastParse
.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w Q - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w k - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false
)
FenParserFastParse
.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w k - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
FenParserFastParse.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w q - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.blackQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false
)
FenParserFastParse
.parseFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w q - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.blackQueenSide shouldBe true
ctx.castlingRights.whiteKingSide shouldBe false,
)
test("parseFen parses all en passant squares"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - a3 0 1").fold(_ => fail(), ctx =>
ctx.enPassantSquare shouldBe Some(Square(File.A, Rank.R3))
)
FenParserFastParse
.parseFen("8/8/8/8/8/8/8/8 w - a3 0 1")
.fold(_ => fail(), ctx => ctx.enPassantSquare shouldBe Some(Square(File.A, Rank.R3)))
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - h6 0 1").fold(_ => fail(), ctx =>
ctx.enPassantSquare shouldBe Some(Square(File.H, Rank.R6))
)
FenParserFastParse
.parseFen("8/8/8/8/8/8/8/8 w - h6 0 1")
.fold(_ => fail(), ctx => ctx.enPassantSquare shouldBe Some(Square(File.H, Rank.R6)))
test("parseFen parses different halfMove and fullMove clocks"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 5 10").fold(_ => fail(), ctx =>
ctx.halfMoveClock shouldBe 5
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 5 10").fold(_ => fail(), ctx => ctx.halfMoveClock shouldBe 5)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 100").fold(_ => fail(), ctx =>
ctx.halfMoveClock shouldBe 0
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 100").fold(_ => fail(), ctx => ctx.halfMoveClock shouldBe 0)
test("parseBoard parses boards with mixed empty and piece tokens"):
val mixed = "8/1p1p1p1p/8/1P1P1P1P/8/8/8/8"
FenParserFastParse.parseBoard(mixed) should not be empty
test("parseFen handles turn transitions"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 1").fold(_ => fail(), ctx =>
ctx.turn shouldBe Color.White
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w - - 0 1").fold(_ => fail(), ctx => ctx.turn shouldBe Color.White)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), ctx =>
ctx.turn shouldBe Color.Black
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), ctx => ctx.turn shouldBe Color.Black)
test("parseFen rejects invalid piece characters"):
FenParserFastParse.parseFen("8x/8/8/8/8/8/8/8 w - - 0 1").isLeft shouldBe true
@@ -133,7 +141,7 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
test("parseBoard tests all piece types in various positions"):
// Test each piece type: pawn, rook, knight, bishop, queen, king (both colors)
val allPieces = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR"
val parsed = FenParserFastParse.parseBoard(allPieces)
val parsed = FenParserFastParse.parseBoard(allPieces)
parsed.map(_.pieces.size) shouldBe Some(32)
parsed.map(_.pieceAt(Square(File.A, Rank.R8))) shouldBe Some(Some(Piece.BlackRook))
parsed.map(_.pieceAt(Square(File.B, Rank.R8))) shouldBe Some(Some(Piece.BlackKnight))
@@ -150,25 +158,33 @@ class FenParserFastParseTest extends AnyFunSuite with Matchers:
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 b - - 0 1").fold(_ => fail(), _.turn shouldBe Color.Black)
test("parseFen tests all castling combinations"):
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w KQkq - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.blackQueenSide shouldBe true
)
FenParserFastParse
.parseFen("8/8/8/8/8/8/8/8 w KQkq - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe true
ctx.castlingRights.blackKingSide shouldBe true
ctx.castlingRights.blackQueenSide shouldBe true,
)
FenParserFastParse.parseFen("8/8/8/8/8/8/8/8 w Kq - 0 1").fold(_ => fail(), ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe true
)
FenParserFastParse
.parseFen("8/8/8/8/8/8/8/8 w Kq - 0 1")
.fold(
_ => fail(),
ctx =>
ctx.castlingRights.whiteKingSide shouldBe true
ctx.castlingRights.whiteQueenSide shouldBe false
ctx.castlingRights.blackKingSide shouldBe false
ctx.castlingRights.blackQueenSide shouldBe true,
)
test("parseFen tests all en passant files"):
for file <- Seq("a", "b", "c", "d", "e", "f", "g", "h") do
FenParserFastParse.parseFen(s"8/8/8/8/8/8/8/8 w - ${file}3 0 1").fold(_ => fail(), ctx =>
ctx.enPassantSquare should not be empty
)
FenParserFastParse
.parseFen(s"8/8/8/8/8/8/8/8 w - ${file}3 0 1")
.fold(_ => fail(), ctx => ctx.enPassantSquare should not be empty)
test("parseBoard with mixed pieces and empty squares"):
FenParserFastParse.parseBoard("r1bqkb1r/pppppppp/2n2n2/8/8/2N2N2/PPPPPPPP/R1BQKB1R") should not be empty
@@ -81,8 +81,7 @@ object DefaultRules extends RuleSet:
private def countPositionOccurrences(context: GameContext, targetPosition: Position): Int =
try
var count = 0
var tempCtx = GameContext(
val initialCtx = GameContext(
board = context.initialBoard,
turn = Color.White,
castlingRights = CastlingRights.Initial,
@@ -91,20 +90,24 @@ object DefaultRules extends RuleSet:
moves = List.empty,
initialBoard = context.initialBoard,
)
var tempPos = Position(tempCtx.board, tempCtx.turn, tempCtx.castlingRights, tempCtx.enPassantSquare)
if tempPos == targetPosition then count += 1
for move <- context.moves do
tempCtx = applyMove(tempCtx)(move)
tempPos = Position(
board = tempCtx.board,
turn = tempCtx.turn,
castlingRights = tempCtx.castlingRights,
enPassantSquare = tempCtx.enPassantSquare,
def positionOf(ctx: GameContext): Position =
Position(
board = ctx.board,
turn = ctx.turn,
castlingRights = ctx.castlingRights,
enPassantSquare = ctx.enPassantSquare,
)
if tempPos == targetPosition then count += 1
count
val initialCount = if positionOf(initialCtx) == targetPosition then 1 else 0
context.moves
.foldLeft((initialCtx, initialCount)) { case ((tempCtx, count), move) =>
val nextCtx = applyMove(tempCtx)(move)
val nextCount = if positionOf(nextCtx) == targetPosition then count + 1 else count
(nextCtx, nextCount)
}
._2
catch
case _: Exception =>
// If replay fails, conservatively count only the current position (never triggers a draw)
@@ -29,10 +29,12 @@ class DefaultRulesTest extends AnyFunSuite with Matchers:
test("pawn can capture diagonally"):
// FEN: white pawn e4, black pawn d5
val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1"
val context = FenParser.parseFen(fen).fold(_ => fail(), identity)
val moves = rules.allLegalMoves(context)
val captures = moves.filter(m => m.from == Square(File.E, Rank.R4) && (m.moveType match { case _: MoveType.Normal => true; case _ => false }))
val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1"
val context = FenParser.parseFen(fen).fold(_ => fail(), identity)
val moves = rules.allLegalMoves(context)
val captures = moves.filter(m =>
m.from == Square(File.E, Rank.R4) && (m.moveType match { case _: MoveType.Normal => true; case _ => false }),
)
captures.exists(m => m.to == Square(File.D, Rank.R5)) shouldBe true
test("pawn cannot move backward"):
@@ -208,7 +210,7 @@ class DefaultRulesTest extends AnyFunSuite with Matchers:
test("threefold repetition catch block returns false for inconsistent context"):
// A context whose moves cannot be replayed from initialBoard (forces the catch path)
val m = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6)) // e5→e6, no pawn there in initial board
val m = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6)) // e5→e6, no pawn there in initial board
val brokenCtx = GameContext(
board = Board.initial,
turn = Color.White,
@@ -205,7 +205,12 @@ class ChessBoardView(val stage: Stage, private val engine: GameEngine) extends B
else
val isPromo = engine.ruleSet
.legalMoves(engine.context)(fromSquare)
.exists(m => m.to == clickedSquare && m.moveType.isInstanceOf[MoveType.Promotion])
.exists(m =>
m.to == clickedSquare && (m.moveType match
case MoveType.Promotion(_) => true
case _ => false
),
)
if isPromo then showPromotionDialog(fromSquare, clickedSquare)
else engine.processUserInput(s"${fromSquare}$clickedSquare")
selectedSquare.set(None)
@@ -267,7 +272,8 @@ class ChessBoardView(val stage: Stage, private val engine: GameEngine) extends B
case Some(piece) =>
Seq(bgRect) ++ PieceSprites.loadPieceImage(piece, squareSize * 0.8).toSeq
case None =>
Seq(bgRect)): Seq[scalafx.scene.Node]
Seq(bgRect)
): Seq[scalafx.scene.Node]
}
def showMessage(msg: String): Unit =
@@ -30,9 +30,9 @@ class ChessGUIApp extends JFXApplication:
stage.scene = new Scene {
root = boardView
// Load CSS if available
try {
try
Option(getClass.getResource("/styles.css")).foreach(url => stylesheets.add(url.toExternalForm))
} catch {
catch {
case _: Exception => // CSS is optional
}
}