feat: Improve code formatting and readability in AlphaBetaSearch and related tests
Build & Test (NowChessSystems) TeamCity build failed

This commit is contained in:
2026-04-17 21:57:25 +02:00
parent 473c62666a
commit 0f6e32cf08
24 changed files with 352 additions and 268 deletions
@@ -17,4 +17,3 @@ object BotMoveRepetition:
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
val blocked = blockedMoves(context)
moves.filterNot(blocked.contains)
@@ -24,10 +24,8 @@ object EvaluationNNUE extends Evaluation:
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
if (childPly % 10 == 0) then
nnue.recomputeAccumulator(childPly, child.board)
else
nnue.pushAccumulator(childPly, move, parent.board)
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
else nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
@@ -1,14 +1,14 @@
package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
class NNUE(model: NbaiModel):
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
// Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
@@ -60,10 +60,10 @@ class NNUE(model: NbaiModel):
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly)
move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
@@ -80,12 +80,13 @@ class NNUE(model: NbaiModel):
// Compare with actual L1
val actual = l1Stack(ply)
var maxError = 0f
for i <- 0 until accSize do
val error = math.abs(actual(i) - expectedL1(i))
if error > maxError then maxError = error
val maxError =
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
val error = math.abs(actual(i) - expectedL1(i))
math.max(currentMax, error)
}
maxError < 0.001f // Allow small floating-point errors
maxError < 0.001f // Allow small floating-point errors
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
// Extract source and destination square indices early
@@ -156,24 +157,24 @@ class NNUE(model: NbaiModel):
// For debugging: validate that incremental accumulator matches recomputation
if validateAccum && ply > 0 && ply % 10 != 0 then
val isValid = validateAccumulator(ply, board)
if !isValid then
System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
evaluateAtPly(ply, turn, hash)
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
var input = l1ReLU
for i <- 1 until model.layers.length - 1 do
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
input = out
val finalInput =
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
out
}
val lastIdx = model.layers.length - 1
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn)
private def runDenseReLU(
@@ -20,8 +20,10 @@ object NbaiLoader:
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) => try load(s) finally s.close()
case None => NbaiMigrator.migrateFromBin()
case Some(s) =>
try load(s)
finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
@@ -9,11 +9,11 @@ object NbaiMigrator:
private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
)
private val UnknownMetadata: NbaiMetadata =
@@ -24,7 +24,13 @@ object NbaiMetadata:
def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
NbaiMetadata(
str("trainedBy"),
str("trainedAt"),
num("trainingDataCount").toLong,
num("valLoss").toDouble,
num("trainLoss").toDouble,
)
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])
@@ -147,8 +147,8 @@ final class AlphaBetaSearch(
state: SearchState,
excludedRootMoves: Set[Move],
): Option[Int] =
val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx))
val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
@@ -199,8 +199,14 @@ final class AlphaBetaSearch(
legalMoves: List[Move],
): Option[(Int, Option[Move])] =
if legalMoves.isEmpty then
Some((if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE, None))
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then Some((weights.DRAW_SCORE, None))
Some(
(
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
None,
),
)
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
Some((weights.DRAW_SCORE, None))
else if params.depth == 0 then
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
else None
@@ -210,7 +216,18 @@ final class AlphaBetaSearch(
legalMoves: List[Move],
): (Int, Option[Move]) =
val nullResult =
Option.when(canTryNullMove(params))(tryNullMove(params.context, params.depth, params.ply, params.window.beta, params.state, params.excludedRootMoves)).flatten
Option
.when(canTryNullMove(params))(
tryNullMove(
params.context,
params.depth,
params.ply,
params.window.beta,
params.state,
params.excludedRootMoves,
),
)
.flatten
nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
@@ -228,13 +245,13 @@ final class AlphaBetaSearch(
private def canTryNullMove(params: SearchParams): Boolean =
params.depth >= 3 &&
!rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context)
!rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context)
private def isQuietMove(context: GameContext, move: Move): Boolean =
!isCapture(context, move) &&
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
private def scoreMove(
child: GameContext,
@@ -246,14 +263,35 @@ final class AlphaBetaSearch(
): Int =
val betaNeg = -params.window.beta
if reduction > 0 then
val (rs, _) = search(child, math.max(0, params.depth - 1 - reduction + extension), params.ply + 1, Window(-a - 1, -a), childState, params.excludedRootMoves)
val s = -rs
val (rs, _) = search(
child,
math.max(0, params.depth - 1 - reduction + extension),
params.ply + 1,
Window(-a - 1, -a),
childState,
params.excludedRootMoves,
)
val s = -rs
if s > a then
val (fs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
val (fs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-fs
else s
else
val (rs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
val (rs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-rs
private def evalSingleMove(
@@ -268,8 +306,8 @@ final class AlphaBetaSearch(
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
if skipRoot || futility then None
else
val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
weights.pushAccumulator(params.ply + 1, move, params.context, child)
val childState = params.state.advance(childHash)
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
@@ -317,7 +355,7 @@ final class AlphaBetaSearch(
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
val flag =
if cutoff then TTFlag.Lower
@@ -18,10 +18,13 @@ final case class SearchParams(
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
def advance(nextHash: Long): SearchState =
SearchState(nextHash, repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
})
SearchState(
nextHash,
repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
},
)
enum TTFlag:
case Exact // Score is exact
@@ -11,11 +11,13 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.AtomicBoolean
class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
private object ZeroEval extends Evaluation:
val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
test("bestMove on initial position returns a move"):
@@ -43,7 +45,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -61,7 +63,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
@@ -109,7 +111,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = true
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic)
@@ -126,7 +128,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = true
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic)
@@ -143,7 +145,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = true
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic)
@@ -170,7 +172,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic)
@@ -188,7 +190,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic)
@@ -311,18 +313,16 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
test("quiescence depth-limit in-check branch is exercised"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
var firstChildCheckCall = true
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
val firstChildCheckCall = AtomicBoolean(true)
val deepQRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def allLegalMoves(context: GameContext): List[Move] =
if context.moves.isEmpty then List(rootMove) else List(capMove)
def isCheck(context: GameContext): Boolean =
if context.moves.length == 1 && firstChildCheckCall then
firstChildCheckCall = false
false
if context.moves.length == 1 && firstChildCheckCall.compareAndSet(true, false) then false
else context.moves.nonEmpty
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
@@ -334,4 +334,3 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
val search = AlphaBetaSearch(deepQRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
@@ -35,7 +35,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -66,7 +66,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
@@ -89,11 +89,10 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
bot.nextMove(context) should be(None)
@@ -12,6 +12,7 @@ import org.scalatest.matchers.should.Matchers
import java.io.{DataOutputStream, FileOutputStream}
import java.nio.file.Files
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Using
class HybridBotTest extends AnyFunSuite with Matchers:
@@ -68,8 +69,8 @@ class HybridBotTest extends AnyFunSuite with Matchers:
test("HybridBot uses book move when available"):
val tempFile = Files.createTempFile("hybrid_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
@@ -100,26 +101,26 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
var reported = false
val reported = AtomicBoolean(false)
val bot = HybridBot(
BotDifficulty.Easy,
rules = oneMoveRules,
nnueEvaluation = LowNnue,
classicalEvaluation = HighClassic,
vetoReporter = _ => reported = true,
vetoReporter = _ => reported.set(true),
)
bot.nextMove(GameContext.initial) should be(Some(forcedMove))
reported should be(true)
reported.get should be(true)
test("HybridBot default veto reporter prints when threshold is exceeded"):
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
@@ -137,13 +138,13 @@ class HybridBotTest extends AnyFunSuite with Matchers:
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
val bot = HybridBot(
@@ -157,4 +158,3 @@ class HybridBotTest extends AnyFunSuite with Matchers:
bot.nextMove(GameContext.initial)
}
printed should be(Some(forcedMove))
@@ -9,16 +9,6 @@ import org.scalatest.matchers.should.Matchers
class MoveOrderingTest extends AnyFunSuite with Matchers:
private val moveOrderingClass = Class.forName("de.nowchess.bot.logic.MoveOrdering$")
private val moveOrderingObject = moveOrderingClass.getField("MODULE$").get(null)
private def invokeMoveOrderingPrivate[T](methodPrefix: String, args: Seq[AnyRef]): T =
val method = moveOrderingClass.getDeclaredMethods
.find(m => (m.getName == methodPrefix || m.getName.startsWith(methodPrefix + "$")) && m.getParameterCount == args.length)
.getOrElse(throw RuntimeException(s"Method not found: $methodPrefix/${args.length}"))
method.setAccessible(true)
method.invoke(moveOrderingObject, args*).asInstanceOf[T]
test("queen capture ranks higher than rook capture"):
val board = Board(
Map(
@@ -200,13 +190,13 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
Square(File.D, Rank.R8) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop))
val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook))
MoveOrdering.score(context, knightPromo, None) should be > 0
MoveOrdering.score(context, bishopPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0
test("negative SEE capture path is scored below neutral capture baseline"):
val board = Board(
@@ -221,23 +211,9 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
MoveOrdering.score(context, move, None) should be < 100_000
test("private fallback branches in MoveOrdering are covered"):
test("non-capture move keeps fallback scoring at zero"):
val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing))
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
val victim = invokeMoveOrderingPrivate[Int]("victimValue", Seq[AnyRef](context, castle))
victim should be(0)
val capture = invokeMoveOrderingPrivate[Boolean]("isCapture", Seq[AnyRef](context, castle))
capture should be(false)
val see = invokeMoveOrderingPrivate[Int]("staticExchange", Seq[AnyRef](context, castle))
see should be(0)
val clear = invokeMoveOrderingPrivate[Boolean](
"pathClear",
Seq[AnyRef](board, Square(File.A, Rank.R1), Square(File.H, Rank.R1), Int.box(-1), Int.box(0)),
)
clear should be(false)
MoveOrdering.score(context, castle, None) should be(0)
@@ -98,8 +98,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false))
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -113,7 +117,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withEnPassantSquare(Some(Square(File.D, Rank.R6)))
val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val next = DefaultRules.applyMove(ctx)(move)
@@ -127,8 +133,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R1) -> Piece.WhiteKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.Black)
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false))
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.Black)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
@@ -141,7 +151,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(CastlingRights(false, false, false, false))
for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do