@@ -0,0 +1,25 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.rules.RuleSet
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
object ClassicalBot:
|
||||
def apply(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
): Bot =
|
||||
val search = AlphaBetaSearch(rules, weights = EvaluationClassic)
|
||||
val timeBudgetMs = 1000L
|
||||
context =>
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book
|
||||
.flatMap(_.probe(context))
|
||||
.filterNot(blockedMoves.contains)
|
||||
.orElse(search.bestMoveWithTime(context, timeBudgetMs, blockedMoves))
|
||||
@@ -0,0 +1,39 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.api.rules.RuleSet
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition, Config}
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
object HybridBot:
|
||||
def apply(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
nnueEvaluation: Evaluation = EvaluationNNUE,
|
||||
classicalEvaluation: Evaluation = EvaluationClassic,
|
||||
vetoReporter: String => Unit = println(_),
|
||||
): Bot =
|
||||
val search = AlphaBetaSearch(rules, TranspositionTable(), classicalEvaluation)
|
||||
context =>
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book.flatMap(_.probe(context)).filterNot(blockedMoves.contains).orElse {
|
||||
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS, blockedMoves).map { move =>
|
||||
val next = rules.applyMove(context)(move)
|
||||
val staticNnue = nnueEvaluation.evaluate(next)
|
||||
val classical = classicalEvaluation.evaluate(next)
|
||||
val diff = (classical - staticNnue).abs
|
||||
if diff > Config.VETO_THRESHOLD then
|
||||
vetoReporter(
|
||||
f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)",
|
||||
)
|
||||
move
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.api.rules.RuleSet
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
object NNUEBot:
|
||||
def apply(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
): Bot =
|
||||
val search = AlphaBetaSearch(rules, weights = EvaluationNNUE)
|
||||
context =>
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book
|
||||
.flatMap(_.probe(context))
|
||||
.filterNot(blockedMoves.contains)
|
||||
.orElse {
|
||||
val moves = BotMoveRepetition.filterAllowed(context, rules.allLegalMoves(context))
|
||||
if moves.isEmpty then None
|
||||
else
|
||||
val scored = batchEvaluateRoot(rules, context, moves)
|
||||
val bestMove = scored.maxBy(_._2)._1
|
||||
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove))
|
||||
}
|
||||
|
||||
private def batchEvaluateRoot(rules: RuleSet, context: GameContext, moves: List[Move]): List[(Move, Int)] =
|
||||
EvaluationNNUE.initAccumulator(context)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
moves.map { move =>
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
|
||||
EvaluationNNUE.pushAccumulator(1, move, context, child)
|
||||
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
|
||||
(move, score)
|
||||
}
|
||||
|
||||
private def allocateTime(scored: List[(Move, Int)]): Long =
|
||||
val moveCount = scored.length
|
||||
if moveCount > 30 then 1500L
|
||||
else if moveCount < 5 then 500L
|
||||
else
|
||||
val scores = scored.map(_._2)
|
||||
val best = scores.max
|
||||
val second = scores.filter(_ < best).maxOption.getOrElse(best)
|
||||
if best - second > 200 then 600L else 1000L
|
||||
+361
@@ -0,0 +1,361 @@
|
||||
package de.nowchess.bot.bots.classic
|
||||
|
||||
import de.nowchess.api.board.{Color, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationClassic extends Evaluation:
|
||||
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
|
||||
// Material values in centipawns (indexed by PieceType.ordinal: Pawn=0, Knight=1, Bishop=2, Rook=3, Queen=4, King=5)
|
||||
private val mgMaterial = Array(100, 325, 335, 500, 900, 20_000)
|
||||
private val egMaterial = Array(110, 310, 310, 530, 1_000, 20_000)
|
||||
|
||||
private val TEMPO_BONUS: Int = 10
|
||||
|
||||
// Piece-square tables (Simplified Evaluation Function, Michniewski)
|
||||
// Indexed by squareIndex = rank.ordinal * 8 + file.ordinal
|
||||
// White's perspective: rank 0 = home (r1), rank 7 = back rank (r8)
|
||||
// Black is vertically mirrored
|
||||
|
||||
private val mgPawnTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 10, 10, 20, 30, 30, 20, 10, 10, 5, 5, 10, 25, 25, 10, 5, 5,
|
||||
0, 0, 0, 20, 20, 0, 0, 0, 5, -5, -10, 0, 0, -10, -5, 5, 5, 10, 10, -20, -20, 10, 10, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val egPawnTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 70, 70, 70, 70, 70, 70, 70, 70, 40, 40, 40, 40, 40, 40, 40, 40, 30, 30, 30, 30, 30, 30, 30,
|
||||
30, 20, 20, 20, 20, 20, 20, 20, 20, 10, 10, 10, 10, 10, 10, 10, 10, 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val mgKnightTable: Array[Int] = Array(
|
||||
-50, -40, -30, -30, -30, -30, -40, -50, -40, -20, 0, 0, 0, 0, -20, -40, -30, 0, 10, 15, 15, 10, 0, -30, -30, 5, 15,
|
||||
20, 20, 15, 5, -30, -30, 0, 15, 20, 20, 15, 0, -30, -30, 5, 10, 15, 15, 10, 5, -30, -40, -20, 0, 5, 5, 0, -20, -40,
|
||||
-50, -40, -30, -30, -30, -30, -40, -50,
|
||||
)
|
||||
|
||||
private val egKnightTable: Array[Int] = Array(
|
||||
-30, -20, -10, -10, -10, -10, -20, -30, -20, 0, 5, 5, 5, 5, 0, -20, -10, 5, 15, 20, 20, 15, 5, -10, -10, 5, 20, 25,
|
||||
25, 20, 5, -10, -10, 5, 20, 25, 25, 20, 5, -10, -10, 5, 15, 20, 20, 15, 5, -10, -20, 0, 5, 5, 5, 5, 0, -20, -30,
|
||||
-20, -10, -10, -10, -10, -20, -30,
|
||||
)
|
||||
|
||||
private val mgBishopTable: Array[Int] = Array(
|
||||
-20, -10, -10, -10, -10, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 10, 10, 5, 0, -10, -10, 5, 5, 10, 10,
|
||||
5, 5, -10, -10, 0, 10, 10, 10, 10, 0, -10, -10, 10, 10, 10, 10, 10, 10, -10, -10, 5, 0, 0, 0, 0, 5, -10, -20, -10,
|
||||
-10, -10, -10, -10, -10, -20,
|
||||
)
|
||||
|
||||
private val egBishopTable: Array[Int] = Array(
|
||||
-20, -10, -5, -5, -5, -5, -10, -20, -10, 0, 5, 5, 5, 5, 0, -10, -5, 5, 10, 10, 10, 10, 5, -5, -5, 5, 10, 15, 15, 10,
|
||||
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -5, 5, 10, 10, 10, 10, 5, -5, -10, 0, 5, 5, 5, 5, 0, -10, -20, -10, -5, -5, -5,
|
||||
-5, -10, -20,
|
||||
)
|
||||
|
||||
private val mgRookTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val egRookTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val mgQueenTable: Array[Int] = Array(
|
||||
-20, -10, -10, -5, -5, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 5, 5, 5, 0, -10, -5, 0, 5, 5, 5, 5, 0,
|
||||
-5, 0, 0, 5, 5, 5, 5, 0, -5, -10, 5, 5, 5, 5, 5, 0, -10, -10, 0, 5, 0, 0, 0, 0, -10, -20, -10, -10, -5, -5, -10,
|
||||
-10, -20,
|
||||
)
|
||||
|
||||
private val egQueenTable: Array[Int] = Array(
|
||||
-15, -10, -8, -5, -5, -8, -10, -15, -10, 0, 3, 5, 5, 3, 0, -10, -8, 3, 10, 10, 10, 10, 3, -8, -5, 5, 10, 15, 15, 10,
|
||||
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -8, 3, 10, 10, 10, 10, 3, -8, -10, 0, 3, 5, 5, 3, 0, -10, -15, -10, -8, -5, -5,
|
||||
-8, -10, -15,
|
||||
)
|
||||
|
||||
private val mgKingTable: Array[Int] = Array(
|
||||
-30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40,
|
||||
-30, -30, -40, -40, -50, -50, -40, -40, -30, -20, -30, -30, -40, -40, -30, -30, -20, -10, -20, -20, -20, -20, -20,
|
||||
-20, -10, 20, 20, 0, 0, 0, 0, 20, 20, 20, 30, 10, 0, 0, 10, 30, 20,
|
||||
)
|
||||
|
||||
private val egKingTable: Array[Int] = Array(
|
||||
-50, -40, -30, -20, -20, -30, -40, -50, -30, -20, -10, 0, 0, -10, -20, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30,
|
||||
-10, 30, 40, 40, 30, -10, -30, -30, -10, 30, 40, 40, 30, -10, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30, -30, 0,
|
||||
0, 0, 0, -30, -30, -50, -30, -30, -30, -30, -30, -30, -50,
|
||||
)
|
||||
|
||||
private val phaseWeight: Map[PieceType, Int] = Map(
|
||||
PieceType.Knight -> 1,
|
||||
PieceType.Bishop -> 1,
|
||||
PieceType.Rook -> 2,
|
||||
PieceType.Queen -> 4,
|
||||
)
|
||||
private val maxPhase = 24 // 4*4 + 4*2 + 4*1 + 4*1
|
||||
|
||||
private val passedPawnBonus: Array[Int] = Array(0, 5, 10, 20, 35, 60, 100, 0)
|
||||
private val egPassedPawnBonus: Array[Int] = Array(0, 20, 40, 80, 150, 250, 400, 0)
|
||||
|
||||
// Pawn structure penalties
|
||||
private val doubledMg = -10
|
||||
private val doubledEg = -25
|
||||
private val isolatedMg = -15
|
||||
private val isolatedEg = -20
|
||||
|
||||
// Mobility weights: centipawns per reachable square (indexed by PieceType.ordinal)
|
||||
private val mobilityMg = Array(0, 4, 3, 2, 1, 0, 0)
|
||||
private val mobilityEg = Array(0, 4, 3, 4, 2, 0, 0)
|
||||
|
||||
// Direction offsets for sliding pieces
|
||||
private val diagonals = List((-1, -1), (-1, 1), (1, -1), (1, 1))
|
||||
private val orthogonals = List((-1, 0), (1, 0), (0, -1), (0, 1))
|
||||
private val knightOffsets = List((-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1))
|
||||
|
||||
// Rook and bishop bonuses
|
||||
private val bishopPairMg = 50
|
||||
private val bishopPairEg = 70
|
||||
private val rookOn7thMg = 20
|
||||
private val rookOn7thEg = 10
|
||||
|
||||
/** Evaluate the position from the perspective of context.turn. Positive = good for context.turn.
|
||||
*/
|
||||
def evaluate(context: GameContext): Int =
|
||||
val phase = gamePhase(context.board)
|
||||
val isEg = isEndgame(phase)
|
||||
val material = materialAndPositional(context, phase)
|
||||
val structure = pawnStructure(context, phase)
|
||||
val mobility = mobilityScore(context, phase)
|
||||
val rookBishop = rookAndBishopBonuses(context, phase)
|
||||
val bonuses = positionalBonuses(context, phase, isEg)
|
||||
val egBonuses = if isEg then endgameBonus(context) else 0
|
||||
material + structure + mobility + rookBishop + bonuses + egBonuses + TEMPO_BONUS
|
||||
|
||||
private def gamePhase(board: de.nowchess.api.board.Board): Int =
|
||||
val phase = board.pieces.values.foldLeft(0) { (acc, piece) =>
|
||||
acc + phaseWeight.getOrElse(piece.pieceType, 0)
|
||||
}
|
||||
math.min(phase, maxPhase)
|
||||
|
||||
private def isEndgame(phase: Int): Boolean =
|
||||
phase < 8 // Significantly reduced material indicates endgame
|
||||
|
||||
private def taper(mg: Int, eg: Int, phase: Int): Int =
|
||||
(mg * phase + eg * (maxPhase - phase)) / maxPhase
|
||||
|
||||
private def materialAndPositional(context: GameContext, phase: Int): Int =
|
||||
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (square, piece)) =>
|
||||
val (psqMg, psqEg) = squareBonus(piece.pieceType, piece.color, square)
|
||||
val pieceMg = mgMaterial(piece.pieceType.ordinal) + psqMg
|
||||
val pieceEg = egMaterial(piece.pieceType.ordinal) + psqEg
|
||||
val sign = if piece.color == context.turn then 1 else -1
|
||||
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||
}
|
||||
taper(mg, eg, phase)
|
||||
|
||||
private def squareBonus(pieceType: PieceType, color: Color, sq: Square): (Int, Int) =
|
||||
val rankIdx = if color == Color.White then sq.rank.ordinal else 7 - sq.rank.ordinal
|
||||
val fileIdx = sq.file.ordinal
|
||||
val squareIdx = rankIdx * 8 + fileIdx
|
||||
|
||||
pieceType match
|
||||
case PieceType.Pawn => (mgPawnTable(squareIdx), egPawnTable(squareIdx))
|
||||
case PieceType.Knight => (mgKnightTable(squareIdx), egKnightTable(squareIdx))
|
||||
case PieceType.Bishop => (mgBishopTable(squareIdx), egBishopTable(squareIdx))
|
||||
case PieceType.Rook => (mgRookTable(squareIdx), egRookTable(squareIdx))
|
||||
case PieceType.Queen => (mgQueenTable(squareIdx), egQueenTable(squareIdx))
|
||||
case PieceType.King => (mgKingTable(squareIdx), egKingTable(squareIdx))
|
||||
|
||||
private def pawnStructure(context: GameContext, phase: Int): Int =
|
||||
val friendlyPawns = context.board.pieces.filter((_, p) => p.color == context.turn && p.pieceType == PieceType.Pawn)
|
||||
val enemyPawns = context.board.pieces.filter((_, p) => p.color != context.turn && p.pieceType == PieceType.Pawn)
|
||||
|
||||
val friendlyByFile = friendlyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||
val enemyByFile = enemyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||
|
||||
val (fMg, fEg) = structureScore(friendlyByFile)
|
||||
val (eMg, eEg) = structureScore(enemyByFile)
|
||||
taper(fMg - eMg, fEg - eEg, phase)
|
||||
|
||||
private def structureScore(byFile: Map[Int, Iterable[Int]]): (Int, Int) =
|
||||
byFile.foldLeft((0, 0)) { case ((mg, eg), (file, ranks)) =>
|
||||
val doubled = (ranks.size - 1).max(0)
|
||||
val hasAdjacent = (file - 1 to file + 1).filter(f => f >= 0 && f < 8 && f != file).exists(byFile.contains)
|
||||
val isolated = if !hasAdjacent then ranks.size else 0
|
||||
(mg + doubled * doubledMg + isolated * isolatedMg, eg + doubled * doubledEg + isolated * isolatedEg)
|
||||
}
|
||||
|
||||
private def positionalBonuses(context: GameContext, phase: Int, isEg: Boolean): Int =
|
||||
context.board.pieces.foldLeft(0) { case (score, (sq, piece)) =>
|
||||
val bonus = piece.pieceType match
|
||||
case PieceType.Pawn =>
|
||||
if isPassedPawn(context.board, sq, piece.color) then
|
||||
if isEg then egPassedPawnBonus(sq.rank.ordinal) else passedPawnBonus(sq.rank.ordinal)
|
||||
else 0
|
||||
case PieceType.Rook => rookOpenFileBonus(context.board, sq, piece.color)
|
||||
case PieceType.King => kingShieldBonus(context.board, sq, piece.color, phase)
|
||||
case _ => 0
|
||||
if piece.color == context.turn then score + bonus else score - bonus
|
||||
}
|
||||
|
||||
private def isPassedPawn(board: de.nowchess.api.board.Board, sq: Square, color: Color): Boolean =
|
||||
val enemyColor = color.opposite
|
||||
val pawnRank = sq.rank.ordinal
|
||||
val fileRange = (sq.file.ordinal - 1 to sq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||
val rankCheck = if color == Color.White then (r: Int) => r > pawnRank else (r: Int) => r < pawnRank
|
||||
|
||||
board.pieces.forall { (enemySq, enemyPiece) =>
|
||||
!(enemyPiece.color == enemyColor &&
|
||||
enemyPiece.pieceType == PieceType.Pawn &&
|
||||
fileRange.contains(enemySq.file.ordinal) &&
|
||||
rankCheck(enemySq.rank.ordinal))
|
||||
}
|
||||
|
||||
private def rookOpenFileBonus(board: de.nowchess.api.board.Board, rookSq: Square, color: Color): Int =
|
||||
val hasFriendlyPawn = board.pieces.exists { (sq, piece) =>
|
||||
piece.color == color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||
}
|
||||
val hasEnemyPawn = board.pieces.exists { (sq, piece) =>
|
||||
piece.color != color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||
}
|
||||
if !hasFriendlyPawn && !hasEnemyPawn then 20 // open file
|
||||
else if !hasFriendlyPawn then 10 // semi-open file
|
||||
else 0
|
||||
|
||||
private def kingShieldBonus(board: de.nowchess.api.board.Board, kingSq: Square, color: Color, phase: Int): Int =
|
||||
val shieldRankDelta = if color == Color.White then 1 else -1
|
||||
val shieldFiles = (kingSq.file.ordinal - 1 to kingSq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||
val shieldRank = kingSq.rank.ordinal + shieldRankDelta
|
||||
|
||||
if shieldRank < 0 || shieldRank > 7 then 0
|
||||
else
|
||||
val rawBonus = board.pieces.count { (sq, piece) =>
|
||||
piece.color == color &&
|
||||
piece.pieceType == PieceType.Pawn &&
|
||||
shieldFiles.contains(sq.file.ordinal) &&
|
||||
sq.rank.ordinal == shieldRank
|
||||
} * 10
|
||||
(rawBonus * phase) / maxPhase
|
||||
|
||||
private def slidingCount(
|
||||
sq: Square,
|
||||
board: de.nowchess.api.board.Board,
|
||||
color: Color,
|
||||
directions: List[(Int, Int)],
|
||||
): Int =
|
||||
directions.foldLeft(0) { case (total, (fileDelta, rankDelta)) =>
|
||||
@scala.annotation.tailrec
|
||||
def countRay(current: Option[Square], acc: Int): Int =
|
||||
current match
|
||||
case None => acc
|
||||
case Some(target) =>
|
||||
board.pieceAt(target) match
|
||||
case Some(piece) if piece.color == color => acc
|
||||
case Some(_) => acc + 1
|
||||
case None => countRay(target.offset(fileDelta, rankDelta), acc + 1)
|
||||
total + countRay(sq.offset(fileDelta, rankDelta), 0)
|
||||
}
|
||||
|
||||
private def knightCount(sq: Square, board: de.nowchess.api.board.Board, color: Color): Int =
|
||||
knightOffsets.count { case (fileDelta, rankDelta) =>
|
||||
sq.offset(fileDelta, rankDelta).forall { target =>
|
||||
board.pieceAt(target).forall(_.color != color)
|
||||
}
|
||||
}
|
||||
|
||||
private def mobilityScore(context: GameContext, phase: Int): Int =
|
||||
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||
val count = piece.pieceType match
|
||||
case PieceType.Knight => knightCount(sq, context.board, piece.color)
|
||||
case PieceType.Bishop => slidingCount(sq, context.board, piece.color, diagonals)
|
||||
case PieceType.Rook => slidingCount(sq, context.board, piece.color, orthogonals)
|
||||
case PieceType.Queen => slidingCount(sq, context.board, piece.color, diagonals ++ orthogonals)
|
||||
case _ => 0
|
||||
val pieceMg = count * mobilityMg(piece.pieceType.ordinal)
|
||||
val pieceEg = count * mobilityEg(piece.pieceType.ordinal)
|
||||
val sign = if piece.color == context.turn then 1 else -1
|
||||
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||
}
|
||||
taper(mg, eg, phase)
|
||||
|
||||
private def rookAndBishopBonuses(context: GameContext, phase: Int): Int =
|
||||
val (baseMg, baseEg) = bishopPairBase(context)
|
||||
val (rookMg, rookEg) = rookOn7thDelta(context)
|
||||
taper(baseMg + rookMg, baseEg + rookEg, phase)
|
||||
|
||||
private def bishopPairBase(context: GameContext): (Int, Int) =
|
||||
val friendlyHasPair = hasBishopPair(context, context.turn)
|
||||
val enemyHasPair = hasBishopPair(context, context.turn.opposite)
|
||||
val mg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairMg)
|
||||
val eg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairEg)
|
||||
(mg, eg)
|
||||
|
||||
private def hasBishopPair(context: GameContext, color: Color): Boolean =
|
||||
val bishopSquares = context.board.pieces.collect {
|
||||
case (sq, piece) if piece.color == color && piece.pieceType == PieceType.Bishop => sq
|
||||
}
|
||||
bishopSquares.exists(isEvenSquare) && bishopSquares.exists(sq => !isEvenSquare(sq))
|
||||
|
||||
private def isEvenSquare(square: Square): Boolean =
|
||||
(square.file.ordinal + square.rank.ordinal) % 2 == 0
|
||||
|
||||
private def pairDelta(friendlyHasPair: Boolean, enemyHasPair: Boolean, bonus: Int): Int =
|
||||
(if friendlyHasPair then bonus else 0) - (if enemyHasPair then bonus else 0)
|
||||
|
||||
private def rookOn7thDelta(context: GameContext): (Int, Int) =
|
||||
context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||
rookOn7thContribution(piece, sq, context.turn).fold((mg, eg)) { case (dMg, dEg) =>
|
||||
(mg + dMg, eg + dEg)
|
||||
}
|
||||
}
|
||||
|
||||
private def rookOn7thContribution(piece: de.nowchess.api.board.Piece, sq: Square, turn: Color): Option[(Int, Int)] =
|
||||
Option.when(piece.pieceType == PieceType.Rook && isRookOn7th(piece.color, sq)) {
|
||||
val sign = if piece.color == turn then 1 else -1
|
||||
(sign * rookOn7thMg, sign * rookOn7thEg)
|
||||
}
|
||||
|
||||
private def isRookOn7th(color: Color, sq: Square): Boolean =
|
||||
if color == Color.White then sq.rank.ordinal == 6 else sq.rank.ordinal == 1
|
||||
|
||||
private def endgameBonus(context: GameContext): Int =
|
||||
val friendlyKing = context.board.pieces.find((_, p) => p.color == context.turn && p.pieceType == PieceType.King)
|
||||
val enemyKing = context.board.pieces.find((_, p) => p.color != context.turn && p.pieceType == PieceType.King)
|
||||
|
||||
val kingCentralBonus =
|
||||
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
|
||||
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
|
||||
|
||||
val friendlyMaterial = materialCount(context, context.turn)
|
||||
val enemyMaterial = materialCount(context, context.turn.opposite)
|
||||
val edgeBonus =
|
||||
if friendlyMaterial > enemyMaterial then enemyKing.fold(0)((kSq, _) => (7 - kingEdgeDistance(kSq)) * 10)
|
||||
else 0
|
||||
|
||||
kingCentralBonus + edgeBonus
|
||||
|
||||
private def kingCentralizationDistance(sq: Square): Int =
|
||||
val fileFromCenter = (sq.file.ordinal - 3.5).abs.toInt
|
||||
val rankFromCenter = (sq.rank.ordinal - 3.5).abs.toInt
|
||||
math.max(fileFromCenter, rankFromCenter)
|
||||
|
||||
private def kingEdgeDistance(sq: Square): Int =
|
||||
val fileFromEdge = math.min(sq.file.ordinal, 7 - sq.file.ordinal)
|
||||
val rankFromEdge = math.min(sq.rank.ordinal, 7 - sq.rank.ordinal)
|
||||
math.min(fileFromEdge, rankFromEdge)
|
||||
|
||||
private def materialCount(context: GameContext, color: Color): Int =
|
||||
context.board.pieces.foldLeft(0) { case (sum, (_, piece)) =>
|
||||
if piece.color == color then
|
||||
sum + (piece.pieceType match
|
||||
case PieceType.Knight => 300
|
||||
case PieceType.Bishop => 300
|
||||
case PieceType.Rook => 500
|
||||
case PieceType.Queen => 900
|
||||
case PieceType.Pawn => 0
|
||||
case PieceType.King => 0
|
||||
)
|
||||
else sum
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationNNUE extends Evaluation:
|
||||
|
||||
private val nnue = NNUE(NbaiLoader.loadDefault())
|
||||
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
|
||||
/** Full-board evaluate — used as fallback and by non-search callers. */
|
||||
def evaluate(context: GameContext): Int = nnue.evaluate(context)
|
||||
|
||||
// ── Accumulator hooks (incremental L1) ───────────────────────────────────
|
||||
|
||||
override def initAccumulator(context: GameContext): Unit =
|
||||
nnue.initAccumulator(context.board)
|
||||
|
||||
override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
nnue.copyAccumulator(parentPly, childPly)
|
||||
|
||||
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
|
||||
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
|
||||
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
|
||||
else nnue.pushAccumulator(childPly, move, parent.board)
|
||||
|
||||
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
|
||||
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
|
||||
@@ -0,0 +1,231 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
|
||||
class NNUE(model: NbaiModel):
|
||||
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||
|
||||
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||
private val l1WeightsT: Array[Float] =
|
||||
val w = model.weights(0).weights
|
||||
val t = new Array[Float](featureSize * accSize)
|
||||
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
|
||||
t
|
||||
|
||||
// ── Accumulator stack ────────────────────────────────────────────────────
|
||||
|
||||
private val MAX_PLY = 128
|
||||
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
|
||||
|
||||
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
|
||||
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
|
||||
|
||||
// ── Eval cache ───────────────────────────────────────────────────────────
|
||||
|
||||
private val EVAL_CACHE_MASK = (1 << 18) - 1L
|
||||
private val evalCacheHashes = new Array[Long](1 << 18)
|
||||
private val evalCacheScores = new Array[Int](1 << 18)
|
||||
|
||||
// ── Feature helpers ──────────────────────────────────────────────────────
|
||||
|
||||
private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
|
||||
|
||||
private def featureIndex(piece: Piece, sqNum: Int): Int =
|
||||
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
|
||||
|
||||
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
|
||||
|
||||
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
|
||||
|
||||
// ── Accumulator init ─────────────────────────────────────────────────────
|
||||
|
||||
def initAccumulator(board: Board): Unit =
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// ── Accumulator push (incremental updates) ───────────────────────────────
|
||||
|
||||
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
|
||||
val l1 = l1Stack(childPly)
|
||||
move.moveType match
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||
|
||||
def recomputeAccumulator(ply: Int, board: Board): Unit =
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
def validateAccumulator(ply: Int, board: Board): Boolean =
|
||||
// Compute what L1 should be from scratch
|
||||
val expectedL1 = new Array[Float](accSize)
|
||||
System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// Compare with actual L1
|
||||
val actual = l1Stack(ply)
|
||||
val maxError =
|
||||
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
|
||||
val error = math.abs(actual(i) - expectedL1(i))
|
||||
math.max(currentMax, error)
|
||||
}
|
||||
|
||||
maxError < 0.001f // Allow small floating-point errors
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
// Extract source and destination square indices early
|
||||
val fromNum = squareNum(move.from)
|
||||
val toNum = squareNum(move.to)
|
||||
|
||||
// Get the moving piece
|
||||
board.pieceAt(move.from).foreach { mover =>
|
||||
subtractColumn(l1, featureIndex(mover, fromNum))
|
||||
|
||||
// If there's a capture, subtract the captured piece
|
||||
board.pieceAt(move.to).foreach { cap =>
|
||||
subtractColumn(l1, featureIndex(cap, toNum))
|
||||
}
|
||||
|
||||
// Add the piece to its new location
|
||||
addColumn(l1, featureIndex(mover, toNum))
|
||||
}
|
||||
|
||||
private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val capturedSq = Square(move.to.file, move.from.rank)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
|
||||
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
|
||||
}
|
||||
|
||||
private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { king =>
|
||||
val rank = move.from.rank
|
||||
val kingside = move.moveType == MoveType.CastleKingside
|
||||
val (rookFrom, rookTo) =
|
||||
if kingside then (Square(File.H, rank), Square(File.F, rank))
|
||||
else (Square(File.A, rank), Square(File.D, rank))
|
||||
val rook = Piece(king.color, PieceType.Rook)
|
||||
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
|
||||
addColumn(l1, featureIndex(king, squareNum(move.to)))
|
||||
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
|
||||
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
|
||||
}
|
||||
|
||||
private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val toNum = squareNum(move.to)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||
addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
|
||||
}
|
||||
|
||||
private def promotedType(promo: PromotionPiece): PieceType = promo match
|
||||
case PromotionPiece.Knight => PieceType.Knight
|
||||
case PromotionPiece.Bishop => PieceType.Bishop
|
||||
case PromotionPiece.Rook => PieceType.Rook
|
||||
case PromotionPiece.Queen => PieceType.Queen
|
||||
|
||||
// ── Evaluation from accumulator ──────────────────────────────────────────
|
||||
|
||||
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
|
||||
val idx = (hash & EVAL_CACHE_MASK).toInt
|
||||
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
|
||||
else
|
||||
val score = runL2toOutput(l1Stack(ply), turn)
|
||||
evalCacheHashes(idx) = hash
|
||||
evalCacheScores(idx) = score
|
||||
score
|
||||
|
||||
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
|
||||
// For debugging: validate that incremental accumulator matches recomputation
|
||||
if validateAccum && ply > 0 && ply % 10 != 0 then
|
||||
val isValid = validateAccumulator(ply, board)
|
||||
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||
evaluateAtPly(ply, turn, hash)
|
||||
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
val l1ReLU = evalBuffers(0)
|
||||
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
|
||||
val finalInput =
|
||||
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
|
||||
val lw = model.weights(i)
|
||||
val out = evalBuffers(i)
|
||||
val ld = model.layers(i)
|
||||
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||
out
|
||||
}
|
||||
|
||||
val lastIdx = model.layers.length - 1
|
||||
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||
scoreFromOutput(output, turn)
|
||||
|
||||
private def runDenseReLU(
|
||||
input: Array[Float],
|
||||
inSize: Int,
|
||||
weights: Array[Float],
|
||||
bias: Array[Float],
|
||||
output: Array[Float],
|
||||
outSize: Int,
|
||||
): Unit =
|
||||
for i <- 0 until outSize do
|
||||
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
|
||||
output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
|
||||
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
|
||||
|
||||
private def scoreFromOutput(output: Float, turn: Color): Int =
|
||||
val cp =
|
||||
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
|
||||
else
|
||||
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||
(300f * atanh).toInt
|
||||
val cpFromTurn = if turn == Color.Black then -cp else cp
|
||||
math.max(-20000, math.min(20000, cpFromTurn))
|
||||
|
||||
// ── Legacy full-board evaluate ────────────────────────────────────────────
|
||||
|
||||
private val legacyL1 = new Array[Float](accSize)
|
||||
|
||||
def evaluate(context: GameContext): Int =
|
||||
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
|
||||
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
|
||||
runL2toOutput(legacyL1, context.turn)
|
||||
|
||||
def benchmark(): Unit =
|
||||
val context = GameContext.initial
|
||||
val iterations = 1_000_000
|
||||
for _ <- 0 until 10000 do evaluate(context)
|
||||
val startNanos = System.nanoTime()
|
||||
for _ <- 0 until iterations do evaluate(context)
|
||||
val endNanos = System.nanoTime()
|
||||
val totalNanos = endNanos - startNanos
|
||||
val nanosPerEval = totalNanos.toDouble / iterations
|
||||
println()
|
||||
println("=" * 60)
|
||||
println("NNUE BENCHMARK RESULTS")
|
||||
println("=" * 60)
|
||||
println(f"Iterations: $iterations%,d")
|
||||
println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
|
||||
println(f"ns/eval: $nanosPerEval%.2f ns")
|
||||
println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
|
||||
println("=" * 60)
|
||||
println()
|
||||
@@ -0,0 +1,52 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.InputStream
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiLoader:
|
||||
|
||||
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
|
||||
val MAGIC: Int = 0x4942_414e
|
||||
|
||||
def load(stream: InputStream): NbaiModel =
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkHeader(buf)
|
||||
val metadata = readMetadata(buf)
|
||||
val descs = readLayerDescriptors(buf)
|
||||
val weights = descs.map(_ => readLayerWeights(buf))
|
||||
NbaiModel(metadata, descs, weights)
|
||||
|
||||
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
|
||||
def loadDefault(): NbaiModel =
|
||||
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
|
||||
case Some(s) =>
|
||||
try load(s)
|
||||
finally s.close()
|
||||
case None => NbaiMigrator.migrateFromBin()
|
||||
|
||||
private def checkHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
|
||||
val version = buf.getShort() & 0xffff
|
||||
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
|
||||
|
||||
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
|
||||
val bytes = new Array[Byte](buf.getInt())
|
||||
buf.get(bytes)
|
||||
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
|
||||
|
||||
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
|
||||
Array.tabulate(buf.getShort() & 0xffff) { _ =>
|
||||
val nameBytes = new Array[Byte](buf.get() & 0xff)
|
||||
buf.get(nameBytes)
|
||||
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
|
||||
}
|
||||
|
||||
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readFloats(buf), readFloats(buf))
|
||||
|
||||
private def readFloats(buf: ByteBuffer): Array[Float] =
|
||||
val arr = new Array[Float](buf.getInt())
|
||||
for i <- arr.indices do arr(i) = buf.getFloat()
|
||||
arr
|
||||
@@ -0,0 +1,43 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
|
||||
object NbaiMigrator:
|
||||
|
||||
private val BinMagic = 0x4555_4e4e
|
||||
private val BinVersion = 1
|
||||
|
||||
private val DefaultLayers: Array[LayerDescriptor] = Array(
|
||||
LayerDescriptor("relu", 768, 1536),
|
||||
LayerDescriptor("relu", 1536, 1024),
|
||||
LayerDescriptor("relu", 1024, 512),
|
||||
LayerDescriptor("relu", 512, 256),
|
||||
LayerDescriptor("linear", 256, 1),
|
||||
)
|
||||
|
||||
private val UnknownMetadata: NbaiMetadata =
|
||||
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
|
||||
|
||||
def migrateFromBin(): NbaiModel =
|
||||
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
|
||||
try
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkBinHeader(buf)
|
||||
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
|
||||
NbaiModel(UnknownMetadata, DefaultLayers, weights)
|
||||
finally stream.close()
|
||||
|
||||
private def checkBinHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
|
||||
val version = buf.getInt()
|
||||
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
|
||||
|
||||
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readBinTensor(buf), readBinTensor(buf))
|
||||
|
||||
private def readBinTensor(buf: ByteBuffer): Array[Float] =
|
||||
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
|
||||
Array.tabulate(shape.product)(_ => buf.getFloat())
|
||||
@@ -0,0 +1,45 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
/** Descriptor for a single dense layer stored in a .nbai file. */
|
||||
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
|
||||
|
||||
/** Training metadata embedded in every .nbai file. */
|
||||
case class NbaiMetadata(
|
||||
trainedBy: String,
|
||||
trainedAt: String,
|
||||
trainingDataCount: Long,
|
||||
valLoss: Double,
|
||||
trainLoss: Double,
|
||||
):
|
||||
def toJson: String =
|
||||
s"""{
|
||||
| "trainedBy": "$trainedBy",
|
||||
| "trainedAt": "$trainedAt",
|
||||
| "trainingDataCount": $trainingDataCount,
|
||||
| "valLoss": $valLoss,
|
||||
| "trainLoss": $trainLoss
|
||||
|}""".stripMargin
|
||||
|
||||
object NbaiMetadata:
|
||||
def fromJson(json: String): NbaiMetadata =
|
||||
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
|
||||
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
|
||||
NbaiMetadata(
|
||||
str("trainedBy"),
|
||||
str("trainedAt"),
|
||||
num("trainingDataCount").toLong,
|
||||
num("valLoss").toDouble,
|
||||
num("trainLoss").toDouble,
|
||||
)
|
||||
|
||||
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
|
||||
case class LayerWeights(weights: Array[Float], bias: Array[Float])
|
||||
|
||||
/** A fully deserialized .nbai model ready to initialize NNUE. */
|
||||
case class NbaiModel(
|
||||
metadata: NbaiMetadata,
|
||||
layers: Array[LayerDescriptor],
|
||||
weights: Array[LayerWeights],
|
||||
):
|
||||
require(layers.length == weights.length, "Layer count must match weight count")
|
||||
require(layers.length >= 2, "Model must have at least 2 layers")
|
||||
@@ -0,0 +1,51 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.{ByteArrayOutputStream, OutputStream}
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiWriter:
|
||||
|
||||
def write(model: NbaiModel, out: OutputStream): Unit =
|
||||
val acc = new ByteArrayOutputStream()
|
||||
writeHeader(acc)
|
||||
writeMetadata(acc, model.metadata)
|
||||
writeLayerDescriptors(acc, model.layers)
|
||||
model.weights.foreach(lw => writeLayerWeights(acc, lw))
|
||||
out.write(acc.toByteArray)
|
||||
|
||||
private def writeHeader(out: ByteArrayOutputStream): Unit =
|
||||
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(NbaiLoader.MAGIC)
|
||||
buf.putShort(1.toShort)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
|
||||
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
|
||||
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(json.length)
|
||||
buf.put(json)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
|
||||
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
|
||||
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
|
||||
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putShort(layers.length.toShort)
|
||||
layers.zip(nameBytes).foreach { (l, nb) =>
|
||||
buf.put(nb.length.toByte)
|
||||
buf.put(nb)
|
||||
buf.putInt(l.inputSize)
|
||||
buf.putInt(l.outputSize)
|
||||
}
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
|
||||
writeFloats(out, lw.weights)
|
||||
writeFloats(out, lw.bias)
|
||||
|
||||
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
|
||||
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(floats.length)
|
||||
floats.foreach(buf.putFloat)
|
||||
out.write(buf.array())
|
||||
Reference in New Issue
Block a user