feat: Improve code formatting and readability in AlphaBetaSearch and related tests
Build & Test (NowChessSystems) TeamCity build failed
Build & Test (NowChessSystems) TeamCity build failed
This commit is contained in:
@@ -17,4 +17,3 @@ object BotMoveRepetition:
|
||||
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
|
||||
val blocked = blockedMoves(context)
|
||||
moves.filterNot(blocked.contains)
|
||||
|
||||
|
||||
@@ -24,10 +24,8 @@ object EvaluationNNUE extends Evaluation:
|
||||
|
||||
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
|
||||
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
|
||||
if (childPly % 10 == 0) then
|
||||
nnue.recomputeAccumulator(childPly, child.board)
|
||||
else
|
||||
nnue.pushAccumulator(childPly, move, parent.board)
|
||||
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
|
||||
else nnue.pushAccumulator(childPly, move, parent.board)
|
||||
|
||||
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
|
||||
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
|
||||
class NNUE(model: NbaiModel):
|
||||
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||
|
||||
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||
@@ -60,10 +60,10 @@ class NNUE(model: NbaiModel):
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
|
||||
val l1 = l1Stack(childPly)
|
||||
move.moveType match
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||
@@ -80,12 +80,13 @@ class NNUE(model: NbaiModel):
|
||||
|
||||
// Compare with actual L1
|
||||
val actual = l1Stack(ply)
|
||||
var maxError = 0f
|
||||
for i <- 0 until accSize do
|
||||
val error = math.abs(actual(i) - expectedL1(i))
|
||||
if error > maxError then maxError = error
|
||||
val maxError =
|
||||
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
|
||||
val error = math.abs(actual(i) - expectedL1(i))
|
||||
math.max(currentMax, error)
|
||||
}
|
||||
|
||||
maxError < 0.001f // Allow small floating-point errors
|
||||
maxError < 0.001f // Allow small floating-point errors
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
// Extract source and destination square indices early
|
||||
@@ -156,24 +157,24 @@ class NNUE(model: NbaiModel):
|
||||
// For debugging: validate that incremental accumulator matches recomputation
|
||||
if validateAccum && ply > 0 && ply % 10 != 0 then
|
||||
val isValid = validateAccumulator(ply, board)
|
||||
if !isValid then
|
||||
System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||
evaluateAtPly(ply, turn, hash)
|
||||
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
val l1ReLU = evalBuffers(0)
|
||||
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
|
||||
var input = l1ReLU
|
||||
for i <- 1 until model.layers.length - 1 do
|
||||
val lw = model.weights(i)
|
||||
val out = evalBuffers(i)
|
||||
val ld = model.layers(i)
|
||||
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||
input = out
|
||||
val finalInput =
|
||||
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
|
||||
val lw = model.weights(i)
|
||||
val out = evalBuffers(i)
|
||||
val ld = model.layers(i)
|
||||
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||
out
|
||||
}
|
||||
|
||||
val lastIdx = model.layers.length - 1
|
||||
val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||
scoreFromOutput(output, turn)
|
||||
|
||||
private def runDenseReLU(
|
||||
|
||||
@@ -20,8 +20,10 @@ object NbaiLoader:
|
||||
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
|
||||
def loadDefault(): NbaiModel =
|
||||
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
|
||||
case Some(s) => try load(s) finally s.close()
|
||||
case None => NbaiMigrator.migrateFromBin()
|
||||
case Some(s) =>
|
||||
try load(s)
|
||||
finally s.close()
|
||||
case None => NbaiMigrator.migrateFromBin()
|
||||
|
||||
private def checkHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
|
||||
@@ -9,11 +9,11 @@ object NbaiMigrator:
|
||||
private val BinVersion = 1
|
||||
|
||||
private val DefaultLayers: Array[LayerDescriptor] = Array(
|
||||
LayerDescriptor("relu", 768, 1536),
|
||||
LayerDescriptor("relu", 1536, 1024),
|
||||
LayerDescriptor("relu", 1024, 512),
|
||||
LayerDescriptor("relu", 512, 256),
|
||||
LayerDescriptor("linear", 256, 1),
|
||||
LayerDescriptor("relu", 768, 1536),
|
||||
LayerDescriptor("relu", 1536, 1024),
|
||||
LayerDescriptor("relu", 1024, 512),
|
||||
LayerDescriptor("relu", 512, 256),
|
||||
LayerDescriptor("linear", 256, 1),
|
||||
)
|
||||
|
||||
private val UnknownMetadata: NbaiMetadata =
|
||||
|
||||
@@ -24,7 +24,13 @@ object NbaiMetadata:
|
||||
def fromJson(json: String): NbaiMetadata =
|
||||
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
|
||||
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
|
||||
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
|
||||
NbaiMetadata(
|
||||
str("trainedBy"),
|
||||
str("trainedAt"),
|
||||
num("trainingDataCount").toLong,
|
||||
num("valLoss").toDouble,
|
||||
num("trainLoss").toDouble,
|
||||
)
|
||||
|
||||
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
|
||||
case class LayerWeights(weights: Array[Float], bias: Array[Float])
|
||||
|
||||
@@ -147,8 +147,8 @@ final class AlphaBetaSearch(
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): Option[Int] =
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||
weights.copyAccumulator(ply, ply + 1)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
|
||||
@@ -199,8 +199,14 @@ final class AlphaBetaSearch(
|
||||
legalMoves: List[Move],
|
||||
): Option[(Int, Option[Move])] =
|
||||
if legalMoves.isEmpty then
|
||||
Some((if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE, None))
|
||||
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then Some((weights.DRAW_SCORE, None))
|
||||
Some(
|
||||
(
|
||||
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
|
||||
None,
|
||||
),
|
||||
)
|
||||
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
|
||||
Some((weights.DRAW_SCORE, None))
|
||||
else if params.depth == 0 then
|
||||
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
|
||||
else None
|
||||
@@ -210,7 +216,18 @@ final class AlphaBetaSearch(
|
||||
legalMoves: List[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val nullResult =
|
||||
Option.when(canTryNullMove(params))(tryNullMove(params.context, params.depth, params.ply, params.window.beta, params.state, params.excludedRootMoves)).flatten
|
||||
Option
|
||||
.when(canTryNullMove(params))(
|
||||
tryNullMove(
|
||||
params.context,
|
||||
params.depth,
|
||||
params.ply,
|
||||
params.window.beta,
|
||||
params.state,
|
||||
params.excludedRootMoves,
|
||||
),
|
||||
)
|
||||
.flatten
|
||||
|
||||
nullResult.map((_, None)).getOrElse {
|
||||
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
|
||||
@@ -228,13 +245,13 @@ final class AlphaBetaSearch(
|
||||
|
||||
private def canTryNullMove(params: SearchParams): Boolean =
|
||||
params.depth >= 3 &&
|
||||
!rules.isCheck(params.context) &&
|
||||
hasNonPawnMaterial(params.context)
|
||||
!rules.isCheck(params.context) &&
|
||||
hasNonPawnMaterial(params.context)
|
||||
|
||||
private def isQuietMove(context: GameContext, move: Move): Boolean =
|
||||
!isCapture(context, move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
|
||||
private def scoreMove(
|
||||
child: GameContext,
|
||||
@@ -246,14 +263,35 @@ final class AlphaBetaSearch(
|
||||
): Int =
|
||||
val betaNeg = -params.window.beta
|
||||
if reduction > 0 then
|
||||
val (rs, _) = search(child, math.max(0, params.depth - 1 - reduction + extension), params.ply + 1, Window(-a - 1, -a), childState, params.excludedRootMoves)
|
||||
val s = -rs
|
||||
val (rs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 - reduction + extension),
|
||||
params.ply + 1,
|
||||
Window(-a - 1, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
val s = -rs
|
||||
if s > a then
|
||||
val (fs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
|
||||
val (fs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 + extension),
|
||||
params.ply + 1,
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
-fs
|
||||
else s
|
||||
else
|
||||
val (rs, _) = search(child, math.max(0, params.depth - 1 + extension), params.ply + 1, Window(betaNeg, -a), childState, params.excludedRootMoves)
|
||||
val (rs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 + extension),
|
||||
params.ply + 1,
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
-rs
|
||||
|
||||
private def evalSingleMove(
|
||||
@@ -268,8 +306,8 @@ final class AlphaBetaSearch(
|
||||
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
|
||||
if skipRoot || futility then None
|
||||
else
|
||||
val child = rules.applyMove(params.context)(move)
|
||||
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
|
||||
val child = rules.applyMove(params.context)(move)
|
||||
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
|
||||
weights.pushAccumulator(params.ply + 1, move, params.context, child)
|
||||
val childState = params.state.advance(childHash)
|
||||
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
|
||||
@@ -317,7 +355,7 @@ final class AlphaBetaSearch(
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
|
||||
val flag =
|
||||
if cutoff then TTFlag.Lower
|
||||
|
||||
@@ -18,10 +18,13 @@ final case class SearchParams(
|
||||
|
||||
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
|
||||
def advance(nextHash: Long): SearchState =
|
||||
SearchState(nextHash, repetitions.updatedWith(nextHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
})
|
||||
SearchState(
|
||||
nextHash,
|
||||
repetitions.updatedWith(nextHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
},
|
||||
)
|
||||
|
||||
enum TTFlag:
|
||||
case Exact // Score is exact
|
||||
|
||||
@@ -11,11 +11,13 @@ import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
|
||||
private object ZeroEval extends Evaluation:
|
||||
val CHECKMATE_SCORE: Int = 1_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
val CHECKMATE_SCORE: Int = 1_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
def evaluate(context: GameContext): Int = 0
|
||||
|
||||
test("bestMove on initial position returns a move"):
|
||||
@@ -43,7 +45,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
|
||||
@@ -61,7 +63,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
|
||||
@@ -109,7 +111,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = true
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic)
|
||||
@@ -126,7 +128,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = true
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic)
|
||||
@@ -143,7 +145,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = true
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic)
|
||||
@@ -170,7 +172,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic)
|
||||
@@ -188,7 +190,7 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic)
|
||||
@@ -311,18 +313,16 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
|
||||
|
||||
test("quiescence depth-limit in-check branch is exercised"):
|
||||
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
|
||||
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
|
||||
var firstChildCheckCall = true
|
||||
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
|
||||
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
|
||||
val firstChildCheckCall = AtomicBoolean(true)
|
||||
val deepQRules = new RuleSet:
|
||||
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
|
||||
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
|
||||
def allLegalMoves(context: GameContext): List[Move] =
|
||||
if context.moves.isEmpty then List(rootMove) else List(capMove)
|
||||
def isCheck(context: GameContext): Boolean =
|
||||
if context.moves.length == 1 && firstChildCheckCall then
|
||||
firstChildCheckCall = false
|
||||
false
|
||||
if context.moves.length == 1 && firstChildCheckCall.compareAndSet(true, false) then false
|
||||
else context.moves.nonEmpty
|
||||
def isCheckmate(context: GameContext): Boolean = false
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
@@ -334,4 +334,3 @@ class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
|
||||
|
||||
val search = AlphaBetaSearch(deepQRules, weights = ZeroEval)
|
||||
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
|
||||
@@ -66,7 +66,7 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
|
||||
@@ -89,11 +89,10 @@ class ClassicalBotTest extends AnyFunSuite with Matchers:
|
||||
def isStalemate(context: GameContext): Boolean = false
|
||||
def isInsufficientMaterial(context: GameContext): Boolean = false
|
||||
def isFiftyMoveRule(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def isThreefoldRepetition(context: GameContext): Boolean = false
|
||||
def applyMove(context: GameContext)(move: Move): GameContext = context
|
||||
|
||||
val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
|
||||
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
|
||||
|
||||
bot.nextMove(context) should be(None)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import org.scalatest.matchers.should.Matchers
|
||||
|
||||
import java.io.{DataOutputStream, FileOutputStream}
|
||||
import java.nio.file.Files
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import scala.util.Using
|
||||
|
||||
class HybridBotTest extends AnyFunSuite with Matchers:
|
||||
@@ -68,8 +69,8 @@ class HybridBotTest extends AnyFunSuite with Matchers:
|
||||
test("HybridBot uses book move when available"):
|
||||
val tempFile = Files.createTempFile("hybrid_book", ".bin")
|
||||
try
|
||||
val ctx = GameContext.initial
|
||||
val hash = PolyglotHash.hash(ctx)
|
||||
val ctx = GameContext.initial
|
||||
val hash = PolyglotHash.hash(ctx)
|
||||
val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
|
||||
|
||||
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
|
||||
@@ -100,26 +101,26 @@ class HybridBotTest extends AnyFunSuite with Matchers:
|
||||
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
|
||||
|
||||
object LowNnue extends Evaluation:
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
def evaluate(context: GameContext): Int = 0
|
||||
|
||||
object HighClassic extends Evaluation:
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
def evaluate(context: GameContext): Int = 10_000
|
||||
|
||||
var reported = false
|
||||
val reported = AtomicBoolean(false)
|
||||
val bot = HybridBot(
|
||||
BotDifficulty.Easy,
|
||||
rules = oneMoveRules,
|
||||
nnueEvaluation = LowNnue,
|
||||
classicalEvaluation = HighClassic,
|
||||
vetoReporter = _ => reported = true,
|
||||
vetoReporter = _ => reported.set(true),
|
||||
)
|
||||
|
||||
bot.nextMove(GameContext.initial) should be(Some(forcedMove))
|
||||
reported should be(true)
|
||||
reported.get should be(true)
|
||||
|
||||
test("HybridBot default veto reporter prints when threshold is exceeded"):
|
||||
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
|
||||
@@ -137,13 +138,13 @@ class HybridBotTest extends AnyFunSuite with Matchers:
|
||||
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
|
||||
|
||||
object LowNnue extends Evaluation:
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
def evaluate(context: GameContext): Int = 0
|
||||
|
||||
object HighClassic extends Evaluation:
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
def evaluate(context: GameContext): Int = 10_000
|
||||
|
||||
val bot = HybridBot(
|
||||
@@ -157,4 +158,3 @@ class HybridBotTest extends AnyFunSuite with Matchers:
|
||||
bot.nextMove(GameContext.initial)
|
||||
}
|
||||
printed should be(Some(forcedMove))
|
||||
|
||||
|
||||
@@ -9,16 +9,6 @@ import org.scalatest.matchers.should.Matchers
|
||||
|
||||
class MoveOrderingTest extends AnyFunSuite with Matchers:
|
||||
|
||||
private val moveOrderingClass = Class.forName("de.nowchess.bot.logic.MoveOrdering$")
|
||||
private val moveOrderingObject = moveOrderingClass.getField("MODULE$").get(null)
|
||||
|
||||
private def invokeMoveOrderingPrivate[T](methodPrefix: String, args: Seq[AnyRef]): T =
|
||||
val method = moveOrderingClass.getDeclaredMethods
|
||||
.find(m => (m.getName == methodPrefix || m.getName.startsWith(methodPrefix + "$")) && m.getParameterCount == args.length)
|
||||
.getOrElse(throw RuntimeException(s"Method not found: $methodPrefix/${args.length}"))
|
||||
method.setAccessible(true)
|
||||
method.invoke(moveOrderingObject, args*).asInstanceOf[T]
|
||||
|
||||
test("queen capture ranks higher than rook capture"):
|
||||
val board = Board(
|
||||
Map(
|
||||
@@ -200,13 +190,13 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
|
||||
Square(File.D, Rank.R8) -> Piece.BlackRook,
|
||||
),
|
||||
)
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
|
||||
val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop))
|
||||
val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook))
|
||||
MoveOrdering.score(context, knightPromo, None) should be > 0
|
||||
MoveOrdering.score(context, bishopPromo, None) should be > 0
|
||||
MoveOrdering.score(context, rookPromo, None) should be > 0
|
||||
MoveOrdering.score(context, rookPromo, None) should be > 0
|
||||
|
||||
test("negative SEE capture path is scored below neutral capture baseline"):
|
||||
val board = Board(
|
||||
@@ -221,23 +211,9 @@ class MoveOrderingTest extends AnyFunSuite with Matchers:
|
||||
|
||||
MoveOrdering.score(context, move, None) should be < 100_000
|
||||
|
||||
test("private fallback branches in MoveOrdering are covered"):
|
||||
test("non-capture move keeps fallback scoring at zero"):
|
||||
val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing))
|
||||
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
|
||||
|
||||
val victim = invokeMoveOrderingPrivate[Int]("victimValue", Seq[AnyRef](context, castle))
|
||||
victim should be(0)
|
||||
|
||||
val capture = invokeMoveOrderingPrivate[Boolean]("isCapture", Seq[AnyRef](context, castle))
|
||||
capture should be(false)
|
||||
|
||||
val see = invokeMoveOrderingPrivate[Int]("staticExchange", Seq[AnyRef](context, castle))
|
||||
see should be(0)
|
||||
|
||||
val clear = invokeMoveOrderingPrivate[Boolean](
|
||||
"pathClear",
|
||||
Seq[AnyRef](board, Square(File.A, Rank.R1), Square(File.H, Rank.R1), Int.box(-1), Int.box(0)),
|
||||
)
|
||||
clear should be(false)
|
||||
|
||||
MoveOrdering.score(context, castle, None) should be(0)
|
||||
|
||||
@@ -98,8 +98,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
|
||||
Square(File.E, Rank.R8) -> Piece.BlackKing,
|
||||
),
|
||||
)
|
||||
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false))
|
||||
val ctx = GameContext.initial
|
||||
.withBoard(board)
|
||||
.withTurn(Color.White)
|
||||
.withCastlingRights(
|
||||
CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false),
|
||||
)
|
||||
val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
|
||||
val next = DefaultRules.applyMove(ctx)(move)
|
||||
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
|
||||
@@ -113,7 +117,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
|
||||
Square(File.E, Rank.R8) -> Piece.BlackKing,
|
||||
),
|
||||
)
|
||||
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val ctx = GameContext.initial
|
||||
.withBoard(board)
|
||||
.withTurn(Color.White)
|
||||
.withEnPassantSquare(Some(Square(File.D, Rank.R6)))
|
||||
val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
|
||||
val next = DefaultRules.applyMove(ctx)(move)
|
||||
@@ -127,8 +133,12 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
|
||||
Square(File.E, Rank.R1) -> Piece.WhiteKing,
|
||||
),
|
||||
)
|
||||
val ctx = GameContext.initial.withBoard(board).withTurn(Color.Black)
|
||||
.withCastlingRights(CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false))
|
||||
val ctx = GameContext.initial
|
||||
.withBoard(board)
|
||||
.withTurn(Color.Black)
|
||||
.withCastlingRights(
|
||||
CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false),
|
||||
)
|
||||
val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside)
|
||||
val next = DefaultRules.applyMove(ctx)(move)
|
||||
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
|
||||
@@ -141,7 +151,9 @@ class ZobristHashTest extends AnyFunSuite with Matchers:
|
||||
Square(File.E, Rank.R8) -> Piece.BlackKing,
|
||||
),
|
||||
)
|
||||
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
|
||||
val ctx = GameContext.initial
|
||||
.withBoard(board)
|
||||
.withTurn(Color.White)
|
||||
.withCastlingRights(CastlingRights(false, false, false, false))
|
||||
|
||||
for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do
|
||||
|
||||
Reference in New Issue
Block a user