feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
package de.nowchess.bot
|
||||
|
||||
object Config:
|
||||
|
||||
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
|
||||
* the move is vetoed (not accepted as a suggestion). */
|
||||
val VETO_THRESHOLD: Int = 100
|
||||
@@ -0,0 +1,23 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.logic.HybridSearch
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{Bot, BotDifficulty}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
final class HybridBot(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None
|
||||
) extends Bot:
|
||||
|
||||
private val search: HybridSearch = HybridSearch(rules)
|
||||
|
||||
override val name: String = s"HybridBot(${difficulty.toString})"
|
||||
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
book.flatMap(_.probe(context))
|
||||
.orElse(search.bestMove(context))
|
||||
@@ -7,9 +7,9 @@ import java.nio.ByteOrder
|
||||
|
||||
class NNUE:
|
||||
|
||||
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias) = loadWeights()
|
||||
private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights()
|
||||
|
||||
private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
|
||||
private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
|
||||
val stream = getClass.getResourceAsStream("/nnue_weights.bin")
|
||||
if stream == null then
|
||||
throw RuntimeException("NNUE weights file not found in resources")
|
||||
@@ -35,8 +35,12 @@ class NNUE:
|
||||
val l2b = readTensor(buffer)
|
||||
val l3w = readTensor(buffer)
|
||||
val l3b = readTensor(buffer)
|
||||
val l4w = readTensor(buffer)
|
||||
val l4b = readTensor(buffer)
|
||||
val l5w = readTensor(buffer)
|
||||
val l5b = readTensor(buffer)
|
||||
|
||||
(l1w, l1b, l2w, l2b, l3w, l3b)
|
||||
(l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
|
||||
finally stream.close()
|
||||
|
||||
private def readTensor(buffer: ByteBuffer): Array[Float] =
|
||||
@@ -55,10 +59,12 @@ class NNUE:
|
||||
floats(i) = buffer.getFloat()
|
||||
floats
|
||||
|
||||
// Pre-allocated buffers for inference
|
||||
// Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1)
|
||||
private val features = new Array[Float](768)
|
||||
private val l1Output = new Array[Float](256)
|
||||
private val l2Output = new Array[Float](32)
|
||||
private val l1Output = new Array[Float](1536)
|
||||
private val l2Output = new Array[Float](1024)
|
||||
private val l3Output = new Array[Float](512)
|
||||
private val l4Output = new Array[Float](256)
|
||||
|
||||
/** Convert a position to 768-dimensional binary feature vector.
|
||||
* 12 piece types (white pawn to black king) × 64 squares from white's perspective. */
|
||||
@@ -110,28 +116,43 @@ class NNUE:
|
||||
|
||||
/** Run NNUE inference on the given position.
|
||||
* Returns centipawn score from the perspective of the side-to-move.
|
||||
* No allocations in the hot path (uses pre-allocated buffers). */
|
||||
* No allocations in the hot path (uses pre-allocated buffers).
|
||||
* Architecture: 768→1536→1024→512→256→1 */
|
||||
def evaluate(context: GameContext): Int =
|
||||
val features = positionToFeatures(context.board, context.turn)
|
||||
|
||||
// Layer 1: Dense(768 -> 256) + ReLU
|
||||
for i <- 0 until 256 do
|
||||
// Layer 1: Dense(768 → 1536) + ReLU
|
||||
for i <- 0 until 1536 do
|
||||
var sum = l1Bias(i)
|
||||
for j <- 0 until 768 do
|
||||
sum += features(j) * l1Weights(i * 768 + j)
|
||||
l1Output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
// Layer 2: Dense(256 -> 32) + ReLU
|
||||
for i <- 0 until 32 do
|
||||
// Layer 2: Dense(1536 → 1024) + ReLU
|
||||
for i <- 0 until 1024 do
|
||||
var sum = l2Bias(i)
|
||||
for j <- 0 until 256 do
|
||||
sum += l1Output(j) * l2Weights(i * 256 + j)
|
||||
for j <- 0 until 1536 do
|
||||
sum += l1Output(j) * l2Weights(i * 1536 + j)
|
||||
l2Output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
// Layer 3: Dense(32 -> 1), no activation
|
||||
var output = l3Bias(0)
|
||||
for j <- 0 until 32 do
|
||||
output += l2Output(j) * l3Weights(j)
|
||||
// Layer 3: Dense(1024 → 512) + ReLU
|
||||
for i <- 0 until 512 do
|
||||
var sum = l3Bias(i)
|
||||
for j <- 0 until 1024 do
|
||||
sum += l2Output(j) * l3Weights(i * 1024 + j)
|
||||
l3Output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
// Layer 4: Dense(512 → 256) + ReLU
|
||||
for i <- 0 until 256 do
|
||||
var sum = l4Bias(i)
|
||||
for j <- 0 until 512 do
|
||||
sum += l3Output(j) * l4Weights(i * 512 + j)
|
||||
l4Output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
// Layer 5: Dense(256 → 1), no activation
|
||||
var output = l5Bias(0)
|
||||
for j <- 0 until 256 do
|
||||
output += l4Output(j) * l5Weights(j)
|
||||
|
||||
// Convert from tanh-normalized output back to centipawns
|
||||
// Training uses: eval_normalized = tanh(eval_cp / 300)
|
||||
@@ -145,3 +166,33 @@ class NNUE:
|
||||
(300f * atanh).toInt
|
||||
|
||||
math.max(-20000, math.min(20000, cp))
|
||||
|
||||
/** Benchmark: time 1M evaluations and report ns/eval.
|
||||
* This measures the performance of the inference on the starting position. */
|
||||
def benchmark(): Unit =
|
||||
val context = GameContext.initial
|
||||
val iterations = 1_000_000
|
||||
|
||||
// Warm up
|
||||
for _ <- 0 until 10000 do
|
||||
evaluate(context)
|
||||
|
||||
// Actual benchmark
|
||||
val startNanos = System.nanoTime()
|
||||
for _ <- 0 until iterations do
|
||||
evaluate(context)
|
||||
val endNanos = System.nanoTime()
|
||||
|
||||
val totalNanos = endNanos - startNanos
|
||||
val nanosPerEval = totalNanos.toDouble / iterations
|
||||
|
||||
println()
|
||||
println("=" * 60)
|
||||
println("NNUE BENCHMARK RESULTS")
|
||||
println("=" * 60)
|
||||
println(f"Iterations: $iterations%,d")
|
||||
println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
|
||||
println(f"ns/eval: $nanosPerEval%.2f ns")
|
||||
println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
|
||||
println("=" * 60)
|
||||
println()
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.Config
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import scala.util.boundary
|
||||
import scala.util.boundary.break
|
||||
|
||||
final class HybridSearch(
|
||||
rules: RuleSet = DefaultRules
|
||||
):
|
||||
|
||||
private var vetoCount = 0
|
||||
private var approvalCount = 0
|
||||
private val TOP_MOVES_TO_VALIDATE = 10
|
||||
|
||||
/** Find the best move by scoring all legal moves with NNUE, then validating top 5 with classical eval.
|
||||
* If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
|
||||
* If all top 5 are vetoed, fall back to the best classical move overall.
|
||||
*/
|
||||
def bestMove(context: GameContext): Option[Move] =
|
||||
val legalMoves = rules.allLegalMoves(context)
|
||||
if legalMoves.isEmpty then None else findBestMove(legalMoves, context)
|
||||
|
||||
private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
|
||||
// Score all moves with NNUE
|
||||
val moveScores = legalMoves.map { move =>
|
||||
val nextContext = rules.applyMove(context)(move)
|
||||
val nnueScore = EvaluationNNUE.evaluate(nextContext)
|
||||
(move, nnueScore, nextContext)
|
||||
}
|
||||
|
||||
// Sort by NNUE score descending
|
||||
val sortedByNNUE = moveScores.sortBy(_._2).reverse
|
||||
|
||||
// Validate top N moves with classical evaluation
|
||||
val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)
|
||||
|
||||
boundary:
|
||||
for (move, nnueScore, nextContext) <- topMovesToCheck do
|
||||
val classicalScore = EvaluationClassic.evaluate(nextContext)
|
||||
val difference = (classicalScore - nnueScore).abs
|
||||
if difference <= Config.VETO_THRESHOLD then
|
||||
approvalCount += 1
|
||||
println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
|
||||
break(Some(move))
|
||||
else
|
||||
vetoCount += 1
|
||||
println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")
|
||||
|
||||
// All top 10 were vetoed, fall back to best classical move
|
||||
println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
|
||||
val bestByClassical = moveScores
|
||||
.map { case (move, _, nextContext) =>
|
||||
(move, EvaluationClassic.evaluate(nextContext))
|
||||
}
|
||||
.maxBy(_._2)
|
||||
|
||||
println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
|
||||
println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
|
||||
Some(bestByClassical._1)
|
||||
|
||||
def getStats: (Int, Int) = (approvalCount, vetoCount)
|
||||
Reference in New Issue
Block a user