feat: enhance AlphaBetaSearch with move timing and improve king square detection in GameContext
This commit is contained in:
Binary file not shown.
@@ -4,4 +4,7 @@ object Config:
|
||||
|
||||
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
|
||||
* the move is vetoed (not accepted as a suggestion). */
|
||||
val VETO_THRESHOLD: Int = 100
|
||||
val VETO_THRESHOLD: Int = 150
|
||||
|
||||
/** Time budget per move for iterative deepening (milliseconds). */
|
||||
val TIME_LIMIT_MS: Long = 2000L
|
||||
|
||||
@@ -2,9 +2,11 @@ package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.logic.HybridSearch
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{Bot, BotDifficulty}
|
||||
import de.nowchess.bot.{Bot, BotDifficulty, Config}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
@@ -14,10 +16,20 @@ final class HybridBot(
|
||||
book: Option[PolyglotBook] = None
|
||||
) extends Bot:
|
||||
|
||||
private val search: HybridSearch = HybridSearch(rules)
|
||||
private val search = AlphaBetaSearch(rules, TranspositionTable(), EvaluationClassic)
|
||||
|
||||
override val name: String = s"HybridBot(${difficulty.toString})"
|
||||
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
book.flatMap(_.probe(context))
|
||||
.orElse(search.bestMove(context))
|
||||
book.flatMap(_.probe(context)).orElse(searchWithVeto(context))
|
||||
|
||||
private def searchWithVeto(context: GameContext): Option[Move] =
|
||||
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS).map { move =>
|
||||
val next = rules.applyMove(context)(move)
|
||||
val staticNnue = EvaluationNNUE.evaluate(next)
|
||||
val classical = EvaluationClassic.evaluate(next)
|
||||
val diff = (classical - staticNnue).abs
|
||||
if diff > Config.VETO_THRESHOLD then
|
||||
println(f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)")
|
||||
move
|
||||
}
|
||||
|
||||
@@ -67,28 +67,11 @@ class NNUE:
|
||||
private val l4Output = new Array[Float](256)
|
||||
|
||||
/** Convert a position to 768-dimensional binary feature vector.
|
||||
* 12 piece types (white pawn to black king) × 64 squares from white's perspective. */
|
||||
private def positionToFeatures(board: Board, sideToMove: Color): Array[Float] =
|
||||
// Zero out features array
|
||||
* Layout matches training: black pieces at indices 0-5, white at 6-11.
|
||||
* feature_idx = piece_idx * 64 + square (square: a1=0 .. h8=63, no mirroring). */
|
||||
private def positionToFeatures(board: Board): Array[Float] =
|
||||
java.util.Arrays.fill(features, 0f)
|
||||
|
||||
// Piece type to feature index offset: wp=0, wn=64, wb=128, wr=192, wq=256, wk=320, bp=384, bn=448, bb=512, br=576, bq=640, bk=704
|
||||
val pieceToFeatureOffset = Array(
|
||||
0, // White Pawn (0)
|
||||
64, // White Knight (1)
|
||||
128, // White Bishop (2)
|
||||
192, // White Rook (3)
|
||||
256, // White Queen (4)
|
||||
320, // White King (5)
|
||||
384, // Black Pawn (6)
|
||||
448, // Black Knight (7)
|
||||
512, // Black Bishop (8)
|
||||
576, // Black Rook (9)
|
||||
640, // Black Queen (10)
|
||||
704 // Black King (11)
|
||||
)
|
||||
|
||||
// Build features: always from white's perspective
|
||||
for
|
||||
fileIdx <- 0 until 8
|
||||
rankIdx <- 0 until 8
|
||||
@@ -99,17 +82,11 @@ class NNUE:
|
||||
val squareNum = rankIdx * 8 + fileIdx
|
||||
|
||||
board.pieceAt(square).foreach { piece =>
|
||||
val featureIdx = if sideToMove == Color.Black then
|
||||
// Mirror square for black side-to-move
|
||||
val mirroredSq = squareNum ^ 56
|
||||
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
|
||||
offset + mirroredSq
|
||||
else
|
||||
val offset = pieceToFeatureOffset(piece.color.ordinal * 6 + piece.pieceType.ordinal)
|
||||
offset + squareNum
|
||||
|
||||
if featureIdx >= 0 && featureIdx < 768 then
|
||||
features(featureIdx) = 1f
|
||||
// black pieces → 0-5, white pieces → 6-11 (matches Python training encoding)
|
||||
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||
val pieceIdx = colorOffset + piece.pieceType.ordinal
|
||||
val featureIdx = pieceIdx * 64 + squareNum
|
||||
features(featureIdx) = 1f
|
||||
}
|
||||
|
||||
features
|
||||
@@ -119,7 +96,7 @@ class NNUE:
|
||||
* No allocations in the hot path (uses pre-allocated buffers).
|
||||
* Architecture: 768→1536→1024→512→256→1 */
|
||||
def evaluate(context: GameContext): Int =
|
||||
val features = positionToFeatures(context.board, context.turn)
|
||||
val features = positionToFeatures(context.board)
|
||||
|
||||
// Layer 1: Dense(768 → 1536) + ReLU
|
||||
for i <- 0 until 1536 do
|
||||
@@ -154,18 +131,17 @@ class NNUE:
|
||||
for j <- 0 until 256 do
|
||||
output += l4Output(j) * l5Weights(j)
|
||||
|
||||
// Convert from tanh-normalized output back to centipawns
|
||||
// Training uses: eval_normalized = tanh(eval_cp / 300)
|
||||
// Inverse: eval_cp = 300 * atanh(output)
|
||||
// atanh(x) = 0.5 * ln((1 + x) / (1 - x))
|
||||
// Convert from tanh-normalized output back to centipawns.
|
||||
// Training uses: eval_normalized = tanh(eval_cp / 300) always from White's perspective.
|
||||
// Inverse: eval_cp = 300 * atanh(output); negate for Black to return from side-to-move perspective.
|
||||
val cp = if math.abs(output) >= 0.9999f then
|
||||
// Clamp for numerical stability (avoid ln of very small numbers)
|
||||
if output > 0f then 20000 else -20000
|
||||
else
|
||||
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||
(300f * atanh).toInt
|
||||
|
||||
math.max(-20000, math.min(20000, cp))
|
||||
val cpFromTurn = if context.turn == Color.Black then -cp else cp
|
||||
math.max(-20000, math.min(20000, cpFromTurn))
|
||||
|
||||
/** Benchmark: time 1M evaluations and report ns/eval.
|
||||
* This measures the performance of the inference on the starting position. */
|
||||
|
||||
@@ -73,8 +73,12 @@ final class AlphaBetaSearch(
|
||||
if depth == 1 then (-INF, INF)
|
||||
else (prevScore - aspWindow, prevScore + aspWindow)
|
||||
val (score, move) = searchWithAspiration(context, depth, alpha, beta, aspWindow)
|
||||
val elapsed = System.currentTimeMillis() - timeStartMs
|
||||
prevScore = score
|
||||
move.foreach(m => bestSoFar = Some(m))
|
||||
move.foreach { m =>
|
||||
bestSoFar = Some(m)
|
||||
println(f"[Depth $depth%2d | ${elapsed}%5dms | ${nodeCount}%7d nodes] best=${m.from}->${m.to} score=$score")
|
||||
}
|
||||
aspWindow = ASPIRATION_DELTA
|
||||
depth += 1
|
||||
bestSoFar
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.Config
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import scala.util.boundary
|
||||
import scala.util.boundary.break
|
||||
|
||||
final class HybridSearch(
|
||||
rules: RuleSet = DefaultRules
|
||||
):
|
||||
|
||||
private var vetoCount = 0
|
||||
private var approvalCount = 0
|
||||
private val TOP_MOVES_TO_VALIDATE = 10
|
||||
|
||||
/** Find the best move by scoring all legal moves with NNUE, then validating top 5 with classical eval.
|
||||
* If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
|
||||
* If all top 5 are vetoed, fall back to the best classical move overall.
|
||||
*/
|
||||
def bestMove(context: GameContext): Option[Move] =
|
||||
val legalMoves = rules.allLegalMoves(context)
|
||||
if legalMoves.isEmpty then None else findBestMove(legalMoves, context)
|
||||
|
||||
private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
|
||||
// Score all moves with NNUE
|
||||
val moveScores = legalMoves.map { move =>
|
||||
val nextContext = rules.applyMove(context)(move)
|
||||
val nnueScore = EvaluationNNUE.evaluate(nextContext)
|
||||
(move, nnueScore, nextContext)
|
||||
}
|
||||
|
||||
// Sort by NNUE score descending
|
||||
val sortedByNNUE = moveScores.sortBy(_._2).reverse
|
||||
|
||||
// Validate top N moves with classical evaluation
|
||||
val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)
|
||||
|
||||
boundary:
|
||||
for (move, nnueScore, nextContext) <- topMovesToCheck do
|
||||
val classicalScore = EvaluationClassic.evaluate(nextContext)
|
||||
val difference = (classicalScore - nnueScore).abs
|
||||
if difference <= Config.VETO_THRESHOLD then
|
||||
approvalCount += 1
|
||||
println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
|
||||
break(Some(move))
|
||||
else
|
||||
vetoCount += 1
|
||||
println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")
|
||||
|
||||
// All top 10 were vetoed, fall back to best classical move
|
||||
println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
|
||||
val bestByClassical = moveScores
|
||||
.map { case (move, _, nextContext) =>
|
||||
(move, EvaluationClassic.evaluate(nextContext))
|
||||
}
|
||||
.maxBy(_._2)
|
||||
|
||||
println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
|
||||
println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
|
||||
Some(bestByClassical._1)
|
||||
|
||||
def getStats: (Int, Int) = (approvalCount, vetoCount)
|
||||
Reference in New Issue
Block a user