feat: Improved how NNUE Evalutes

This commit is contained in:
2026-04-13 17:37:24 +02:00
parent ed26406185
commit 5df5a1875f
23 changed files with 438 additions and 292 deletions
@@ -27,7 +27,7 @@ object DefaultRules extends RuleSet:
List((2, 1), (2, -1), (-2, 1), (-2, -1), (1, 2), (1, -2), (-1, 2), (-1, -2))
// ── Pawn configuration helpers ─────────────────────────────────────
private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1
private def pawnForward(color: Color): Int = if color == Color.White then 1 else -1
private def pawnStartRank(color: Color): Int = if color == Color.White then 1 else 6
private def pawnPromoRank(color: Color): Int = if color == Color.White then 7 else 0
@@ -36,13 +36,14 @@ object DefaultRules extends RuleSet:
override def candidateMoves(context: GameContext)(square: Square): List[Move] =
context.board.pieceAt(square).fold(List.empty[Move]) { piece =>
if piece.color != context.turn then List.empty[Move]
else piece.pieceType match
case PieceType.Pawn => pawnCandidates(context, square, piece.color)
case PieceType.Knight => knightCandidates(context, square, piece.color)
case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs)
case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs)
case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs)
case PieceType.King => kingCandidates(context, square, piece.color)
else
piece.pieceType match
case PieceType.Pawn => pawnCandidates(context, square, piece.color)
case PieceType.Knight => knightCandidates(context, square, piece.color)
case PieceType.Bishop => slidingMoves(context, square, piece.color, BishopDirs)
case PieceType.Rook => slidingMoves(context, square, piece.color, RookDirs)
case PieceType.Queen => slidingMoves(context, square, piece.color, QueenDirs)
case PieceType.King => kingCandidates(context, square, piece.color)
}
override def legalMoves(context: GameContext)(square: Square): List[Move] =
@@ -51,13 +52,10 @@ object DefaultRules extends RuleSet:
}
override def allLegalMoves(context: GameContext): List[Move] =
context.board.pieces
.collect { case (sq, p) if p.color == context.turn => legalMoves(context)(sq) }
.flatten
.toList
Square.all.flatMap(sq => legalMoves(context)(sq)).toList
override def isCheck(context: GameContext): Boolean =
context.kingSquare(context.turn)
kingSquare(context.board, context.turn)
.fold(false)(sq => isAttackedBy(context.board, sq, context.turn.opposite))
override def isCheckmate(context: GameContext): Boolean =
@@ -115,18 +113,18 @@ object DefaultRules extends RuleSet:
// ── Sliding pieces (Bishop, Rook, Queen) ───────────────────────────
private def slidingMoves(
context: GameContext,
from: Square,
color: Color,
dirs: List[(Int, Int)]
context: GameContext,
from: Square,
color: Color,
dirs: List[(Int, Int)],
): List[Move] =
dirs.flatMap(dir => castRay(context.board, from, color, dir))
private def castRay(
board: Board,
from: Square,
color: Color,
dir: (Int, Int)
board: Board,
from: Square,
color: Color,
dir: (Int, Int),
): List[Move] =
@tailrec
def loop(sq: Square, acc: List[Move]): List[Move] =
@@ -134,40 +132,40 @@ object DefaultRules extends RuleSet:
case None => acc
case Some(next) =>
board.pieceAt(next) match
case None => loop(next, Move(from, next) :: acc)
case None => loop(next, Move(from, next) :: acc)
case Some(p) if p.color != color => Move(from, next, MoveType.Normal(isCapture = true)) :: acc
case Some(_) => acc
case Some(_) => acc
loop(from, Nil).reverse
// ── Knight ─────────────────────────────────────────────────────────
private def knightCandidates(
context: GameContext,
from: Square,
color: Color
context: GameContext,
from: Square,
color: Color,
): List[Move] =
KnightJumps.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match
case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
}
}
// ── King ───────────────────────────────────────────────────────────
private def kingCandidates(
context: GameContext,
from: Square,
color: Color
context: GameContext,
from: Square,
color: Color,
): List[Move] =
val steps = QueenDirs.flatMap { (df, dr) =>
from.offset(df, dr).flatMap { to =>
context.board.pieceAt(to) match
case Some(p) if p.color == color => None
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
case Some(_) => Some(Move(from, to, MoveType.Normal(isCapture = true)))
case None => Some(Move(from, to))
}
}
steps ++ castlingCandidates(context, from, color)
@@ -175,17 +173,17 @@ object DefaultRules extends RuleSet:
// ── Castling ───────────────────────────────────────────────────────
private case class CastlingMove(
kingFromAlg: String,
kingToAlg: String,
middleAlg: String,
rookFromAlg: String,
moveType: MoveType
kingFromAlg: String,
kingToAlg: String,
middleAlg: String,
rookFromAlg: String,
moveType: MoveType,
)
private def castlingCandidates(
context: GameContext,
from: Square,
color: Color
context: GameContext,
from: Square,
color: Color,
): List[Move] =
color match
case Color.White => whiteCastles(context, from)
@@ -196,10 +194,18 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty
else
val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove(context, moves, context.castlingRights.whiteKingSide,
CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside))
addCastleMove(context, moves, context.castlingRights.whiteQueenSide,
CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside))
addCastleMove(
context,
moves,
context.castlingRights.whiteKingSide,
CastlingMove("e1", "g1", "f1", "h1", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.whiteQueenSide,
CastlingMove("e1", "c1", "d1", "a1", MoveType.CastleQueenside),
)
moves.toList
private def blackCastles(context: GameContext, from: Square): List[Move] =
@@ -207,10 +213,18 @@ object DefaultRules extends RuleSet:
if from != expected then List.empty
else
val moves = scala.collection.mutable.ListBuffer[Move]()
addCastleMove(context, moves, context.castlingRights.blackKingSide,
CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside))
addCastleMove(context, moves, context.castlingRights.blackQueenSide,
CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside))
addCastleMove(
context,
moves,
context.castlingRights.blackKingSide,
CastlingMove("e8", "g8", "f8", "h8", MoveType.CastleKingside),
)
addCastleMove(
context,
moves,
context.castlingRights.blackQueenSide,
CastlingMove("e8", "c8", "d8", "a8", MoveType.CastleQueenside),
)
moves.toList
private def queensideBSquare(kingToAlg: String): List[String] =
@@ -220,10 +234,10 @@ object DefaultRules extends RuleSet:
case _ => List.empty
private def addCastleMove(
context: GameContext,
moves: scala.collection.mutable.ListBuffer[Move],
castlingRight: Boolean,
castlingMove: CastlingMove
context: GameContext,
moves: scala.collection.mutable.ListBuffer[Move],
castlingRight: Boolean,
castlingMove: CastlingMove,
): Unit =
if castlingRight then
val clearSqs = (List(castlingMove.middleAlg, castlingMove.kingToAlg) ++ queensideBSquare(castlingMove.kingToAlg))
@@ -235,16 +249,15 @@ object DefaultRules extends RuleSet:
kt <- Square.fromAlgebraic(castlingMove.kingToAlg)
rf <- Square.fromAlgebraic(castlingMove.rookFromAlg)
do
val color = context.turn
val color = context.turn
val kingPresent = context.board.pieceAt(kf).exists(p => p.color == color && p.pieceType == PieceType.King)
val rookPresent = context.board.pieceAt(rf).exists(p => p.color == color && p.pieceType == PieceType.Rook)
val squaresSafe =
!isAttackedBy(context.board, kf, color.opposite) &&
!isAttackedBy(context.board, km, color.opposite) &&
!isAttackedBy(context.board, kt, color.opposite)
!isAttackedBy(context.board, km, color.opposite) &&
!isAttackedBy(context.board, kt, color.opposite)
if kingPresent && rookPresent && squaresSafe then
moves += Move(kf, kt, castlingMove.moveType)
if kingPresent && rookPresent && squaresSafe then moves += Move(kf, kt, castlingMove.moveType)
private def squaresEmpty(board: Board, squares: List[Square]): Boolean =
squares.forall(sq => board.pieceAt(sq).isEmpty)
@@ -252,22 +265,26 @@ object DefaultRules extends RuleSet:
// ── Pawn ───────────────────────────────────────────────────────────
private def pawnCandidates(
context: GameContext,
from: Square,
color: Color
context: GameContext,
from: Square,
color: Color,
): List[Move] =
val fwd = pawnForward(color)
val fwd = pawnForward(color)
val startRank = pawnStartRank(color)
val promoRank = pawnPromoRank(color)
val single = from.offset(0, fwd).filter(to => context.board.pieceAt(to).isEmpty)
val double = Option.when(from.rank.ordinal == startRank) {
from.offset(0, fwd).flatMap { mid =>
Option.when(context.board.pieceAt(mid).isEmpty) {
from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
}.flatten
val double = Option
.when(from.rank.ordinal == startRank) {
from.offset(0, fwd).flatMap { mid =>
Option
.when(context.board.pieceAt(mid).isEmpty) {
from.offset(0, fwd * 2).filter(to => context.board.pieceAt(to).isEmpty)
}
.flatten
}
}
}.flatten
.flatten
val diagonalCaptures = List(-1, 1).flatMap { df =>
from.offset(df, fwd).flatMap { to =>
@@ -286,55 +303,56 @@ object DefaultRules extends RuleSet:
def toMoves(dest: Square, isCapture: Boolean): List[Move] =
if dest.rank.ordinal == promoRank then
List(
PromotionPiece.Queen, PromotionPiece.Rook,
PromotionPiece.Bishop, PromotionPiece.Knight
PromotionPiece.Queen,
PromotionPiece.Rook,
PromotionPiece.Bishop,
PromotionPiece.Knight,
).map(pt => Move(from, dest, MoveType.Promotion(pt)))
else List(Move(from, dest, MoveType.Normal(isCapture = isCapture)))
val stepSquares = single.toList ++ double.toList
val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false))
val stepSquares = single.toList ++ double.toList
val stepMoves = stepSquares.flatMap(dest => toMoves(dest, isCapture = false))
val captureMoves = diagonalCaptures.flatMap(dest => toMoves(dest, isCapture = true))
stepMoves ++ captureMoves ++ epCaptures
// ── Check detection ────────────────────────────────────────────────
/** Cast rays outward from `target` to detect attackers — O(rays) instead of O(64×rays). */
private def kingSquare(board: Board, color: Color): Option[Square] =
Square.all.find(sq => board.pieceAt(sq).exists(p => p.color == color && p.pieceType == PieceType.King))
private def isAttackedBy(board: Board, target: Square, attacker: Color): Boolean =
attackedBySlider(board, target, attacker, RookDirs, PieceType.Rook) ||
attackedBySlider(board, target, attacker, BishopDirs, PieceType.Bishop) ||
attackedByKnight(board, target, attacker) ||
attackedByPawn(board, target, attacker) ||
attackedByKing(board, target, attacker)
Square.all.exists { sq =>
board.pieceAt(sq).fold(false) { p =>
p.color == attacker && squareAttacks(board, sq, p, target)
}
}
private def attackedBySlider(board: Board, target: Square, attacker: Color, dirs: List[(Int, Int)], sliderType: PieceType): Boolean =
private def squareAttacks(board: Board, from: Square, piece: Piece, target: Square): Boolean =
val fwd = pawnForward(piece.color)
piece.pieceType match
case PieceType.Pawn =>
from.offset(-1, fwd).contains(target) || from.offset(1, fwd).contains(target)
case PieceType.Knight =>
KnightJumps.exists((df, dr) => from.offset(df, dr).contains(target))
case PieceType.Bishop => rayReaches(board, from, BishopDirs, target)
case PieceType.Rook => rayReaches(board, from, RookDirs, target)
case PieceType.Queen => rayReaches(board, from, QueenDirs, target)
case PieceType.King =>
QueenDirs.exists((df, dr) => from.offset(df, dr).contains(target))
private def rayReaches(board: Board, from: Square, dirs: List[(Int, Int)], target: Square): Boolean =
dirs.exists { dir =>
@tailrec def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match
case None => false
case Some(next) => board.pieceAt(next) match
case None => loop(next)
case Some(p) if p.color == attacker && (p.pieceType == sliderType || p.pieceType == PieceType.Queen) => true
case _ => false
loop(target)
}
private def attackedByKnight(board: Board, target: Square, attacker: Color): Boolean =
KnightJumps.exists { (df, dr) =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Knight))
}
private def attackedByPawn(board: Board, target: Square, attacker: Color): Boolean =
val dr = if attacker == Color.White then -1 else 1
List(-1, 1).exists { df =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.Pawn))
}
private def attackedByKing(board: Board, target: Square, attacker: Color): Boolean =
QueenDirs.exists { (df, dr) =>
target.offset(df, dr).exists(sq => board.pieceAt(sq).exists(p => p.color == attacker && p.pieceType == PieceType.King))
@tailrec
def loop(sq: Square): Boolean = sq.offset(dir._1, dir._2) match
case None => false
case Some(next) if next == target => true
case Some(next) if board.pieceAt(next).isEmpty => loop(next)
case Some(_) => false
loop(from)
}
private def leavesKingInCheck(context: GameContext, move: Move): Boolean =
val nextBoard = context.board.applyMove(move)
val nextBoard = context.board.applyMove(move)
val nextContext = context.withBoard(nextBoard)
isCheck(nextContext)
@@ -342,7 +360,7 @@ object DefaultRules extends RuleSet:
override def applyMove(context: GameContext)(move: Move): GameContext =
val color = context.turn
val board = context.board
val board = context.board
val newBoard = move.moveType match
case MoveType.CastleKingside => applyCastle(board, color, kingside = true)
@@ -351,14 +369,14 @@ object DefaultRules extends RuleSet:
case MoveType.Promotion(pp) => applyPromotion(board, move, color, pp)
case MoveType.Normal(_) => board.applyMove(move)
val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color)
val newCastlingRights = updateCastlingRights(context.castlingRights, board, move, color)
val newEnPassantSquare = computeEnPassantSquare(board, move)
val isCapture = move.moveType match
case MoveType.Normal(capture) => capture
case MoveType.EnPassant => true
case _ => board.pieceAt(move.to).isDefined
val isPawnMove = board.pieceAt(move.from).exists(_.pieceType == PieceType.Pawn)
val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1
val newClock = if isPawnMove || isCapture then 0 else context.halfMoveClock + 1
context
.withBoard(newBoard)
@@ -371,19 +389,18 @@ object DefaultRules extends RuleSet:
private def applyCastle(board: Board, color: Color, kingside: Boolean): Board =
val rank = if color == Color.White then Rank.R1 else Rank.R8
val (kingFrom, kingTo, rookFrom, rookTo) =
if kingside then
(Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank))
else
(Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank))
if kingside then (Square(File.E, rank), Square(File.G, rank), Square(File.H, rank), Square(File.F, rank))
else (Square(File.E, rank), Square(File.C, rank), Square(File.A, rank), Square(File.D, rank))
val king = board.pieceAt(kingFrom).getOrElse(Piece(color, PieceType.King))
val rook = board.pieceAt(rookFrom).getOrElse(Piece(color, PieceType.Rook))
board
.removed(kingFrom).removed(rookFrom)
.removed(kingFrom)
.removed(rookFrom)
.updated(kingTo, king)
.updated(rookTo, rook)
private def applyEnPassant(board: Board, move: Move): Board =
val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn
val capturedRank = move.from.rank // the captured pawn is on the same rank as the moving pawn
val capturedSquare = Square(move.to.file, capturedRank)
board.applyMove(move).removed(capturedSquare)
@@ -396,7 +413,7 @@ object DefaultRules extends RuleSet:
board.removed(move.from).updated(move.to, Piece(color, promotedType))
private def updateCastlingRights(rights: CastlingRights, board: Board, move: Move, color: Color): CastlingRights =
val piece = board.pieceAt(move.from)
val piece = board.pieceAt(move.from)
val isKingMove = piece.exists(_.pieceType == PieceType.King)
val isRookMove = piece.exists(_.pieceType == PieceType.Rook)
@@ -406,19 +423,25 @@ object DefaultRules extends RuleSet:
val blackKingsideRook = Square(File.H, Rank.R8)
val blackQueensideRook = Square(File.A, Rank.R8)
var r = rights
if isKingMove then r = r.revokeColor(color)
else if isRookMove then
if move.from == whiteKingsideRook then r = r.revokeKingSide(Color.White)
if move.from == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
if move.from == blackKingsideRook then r = r.revokeKingSide(Color.Black)
if move.from == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
val afterKingMove = if isKingMove then rights.revokeColor(color) else rights
val afterRookMove =
if !isRookMove then afterKingMove
else
move.from match
case `whiteKingsideRook` => afterKingMove.revokeKingSide(Color.White)
case `whiteQueensideRook` => afterKingMove.revokeQueenSide(Color.White)
case `blackKingsideRook` => afterKingMove.revokeKingSide(Color.Black)
case `blackQueensideRook` => afterKingMove.revokeQueenSide(Color.Black)
case _ => afterKingMove
// Also revoke if a rook is captured
if move.to == whiteKingsideRook then r = r.revokeKingSide(Color.White)
if move.to == whiteQueensideRook then r = r.revokeQueenSide(Color.White)
if move.to == blackKingsideRook then r = r.revokeKingSide(Color.Black)
if move.to == blackQueensideRook then r = r.revokeQueenSide(Color.Black)
r
move.to match
case `whiteKingsideRook` => afterRookMove.revokeKingSide(Color.White)
case `whiteQueensideRook` => afterRookMove.revokeQueenSide(Color.White)
case `blackKingsideRook` => afterRookMove.revokeKingSide(Color.Black)
case `blackQueensideRook` => afterRookMove.revokeQueenSide(Color.Black)
case _ => afterRookMove
private def computeEnPassantSquare(board: Board, move: Move): Option[Square] =
val piece = board.pieceAt(move.from)
@@ -435,9 +458,10 @@ object DefaultRules extends RuleSet:
private def insufficientMaterial(board: Board): Boolean =
val pieces = board.pieces.values.toList.filter(_.pieceType != PieceType.King)
pieces match
case Nil => true
case Nil => true
case List(p) if p.pieceType == PieceType.Bishop || p.pieceType == PieceType.Knight => true
case List(p1, p2)
if p1.pieceType == PieceType.Bishop && p2.pieceType == PieceType.Bishop
&& p1.color != p2.color => true
if p1.pieceType == PieceType.Bishop && p2.pieceType == PieceType.Bishop
&& p1.color != p2.color =>
true
case _ => false
@@ -29,10 +29,15 @@ class DefaultRulesTest extends AnyFunSuite with Matchers:
test("pawn can capture diagonally"):
// FEN: white pawn e4, black pawn d5
val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1"
val context = FenParser.parseFen(fen).fold(_ => fail(), identity)
val moves = rules.allLegalMoves(context)
val captures = moves.filter(m => m.from == Square(File.E, Rank.R4) && (m.moveType match { case _: MoveType.Normal => true; case _ => false }))
val fen = "8/8/8/3p4/4P3/8/8/8 w - - 0 1"
val context = FenParser.parseFen(fen).fold(_ => fail(), identity)
val moves = rules.allLegalMoves(context)
val captures = moves.filter { m =>
m.from == Square(File.E, Rank.R4) && (m.moveType match
case _: MoveType.Normal => true
case _ => false
)
}
captures.exists(m => m.to == Square(File.D, Rank.R5)) shouldBe true
test("pawn cannot move backward"):