feat: Implement dataset versioning and management for NNUE training data

2026-04-13 21:19:26 +02:00
parent 4b52199754
commit 8fb872e958
18 changed files with 1399 additions and 335 deletions
@@ -6,7 +6,7 @@ import de.nowchess.bot.ai.Evaluation
object EvaluationNNUE extends Evaluation:
-private val nnue = NNUE()
+private val nnue = NNUE(NbaiLoader.loadDefault())
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
@@ -3,84 +3,31 @@ package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
-import java.nio.ByteBuffer
-import java.nio.ByteOrder
-class NNUE:
+class NNUE(model: NbaiModel):
-private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) =
-loadWeights()
+private val featureSize = model.layers(0).inputSize
+private val accSize = model.layers(0).outputSize
// Column-major L1 weights for cache-friendly sparse & incremental updates.
-// l1WeightsT(featureIdx * 1536 + outputIdx) = l1Weights(outputIdx * 768 + featureIdx)
+// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
private val l1WeightsT: Array[Float] =
-val t = new Array[Float](768 * 1536)
-for j <- 0 until 768; i <- 0 until 1536 do t(j * 1536 + i) = l1Weights(i * 768 + j)
+val w = model.weights(0).weights
+val t = new Array[Float](featureSize * accSize)
+for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
t
-private def loadWeights(): (
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-Array[Float],
-) =
-val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
-.getOrElse(sys.error("NNUE weights file not found in resources"))
-try
-val bytes = stream.readAllBytes()
-val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
-val magic = buffer.getInt()
-if magic != 0x4555_4e4e then sys.error(s"Invalid magic number: 0x${magic.toHexString}")
-val version = buffer.getInt()
-if version != 1 then sys.error(s"Unsupported weight version: $version")
-val l1w = readTensor(buffer)
-val l1b = readTensor(buffer)
-val l2w = readTensor(buffer)
-val l2b = readTensor(buffer)
-val l3w = readTensor(buffer)
-val l3b = readTensor(buffer)
-val l4w = readTensor(buffer)
-val l4b = readTensor(buffer)
-val l5w = readTensor(buffer)
-val l5b = readTensor(buffer)
-(l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
-finally stream.close()
-private def readTensor(buffer: ByteBuffer): Array[Float] =
-val shapeLen = buffer.getInt()
-val shape = Array.ofDim[Int](shapeLen)
-for i <- 0 until shapeLen do shape(i) = buffer.getInt()
-val totalElements = shape.product
-val floats = Array.ofDim[Float](totalElements)
-for i <- 0 until totalElements do floats(i) = buffer.getFloat()
-floats
// ── Accumulator stack ────────────────────────────────────────────────────
// l1Stack(ply) holds the L1 pre-activations (before ReLU) for that ply.
// Initialised once at root; each child ply is derived incrementally.
private val MAX_PLY = 128
-private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](1536))
+private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
-// Shared buffers for the dense L2-L5 layers (single-threaded, non-reentrant).
-private val l1ReLU = new Array[Float](1536)
-private val l2Output = new Array[Float](1024)
-private val l3Output = new Array[Float](512)
-private val l4Output = new Array[Float](256)
+// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
+private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
// ── Eval cache ───────────────────────────────────────────────────────────
-private val EVAL_CACHE_MASK = (1 << 18) - 1L // 256 K slots ≈ 3 MB
+private val EVAL_CACHE_MASK = (1 << 18) - 1L
private val evalCacheHashes = new Array[Long](1 << 18)
private val evalCacheScores = new Array[Int](1 << 18)
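The transposed L1 layout above is the heart of the sparse update: activating one feature now touches a contiguous run of accSize floats instead of a strided walk across row-major rows. A minimal sketch of the identity, using toy sizes (the values of featureSize and accSize are illustrative stand-ins for 768 and 1536):

```scala
// Toy demonstration of the row-major -> column-major transpose behind l1WeightsT.
val featureSize = 4
val accSize = 3
val w = Array.tabulate(accSize * featureSize)(_.toFloat) // row-major (out x in)
val t = new Array[Float](featureSize * accSize)
for j <- 0 until featureSize; i <- 0 until accSize do
  t(j * accSize + i) = w(i * featureSize + j)

// Activating feature j is now one contiguous pass over accSize floats:
def addColumn(acc: Array[Float], j: Int): Unit =
  val off = j * accSize
  for i <- 0 until accSize do acc(i) += t(off + i)
```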
@@ -93,35 +40,32 @@ class NNUE:
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
-val offset = featureIdx * 1536
-for i <- 0 until 1536 do l1Pre(i) += l1WeightsT(offset + i)
+val offset = featureIdx * accSize
+for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
-val offset = featureIdx * 1536
-for i <- 0 until 1536 do l1Pre(i) -= l1WeightsT(offset + i)
+val offset = featureIdx * accSize
+for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
// ── Accumulator init ─────────────────────────────────────────────────────
/** Initialise l1Stack(0) from scratch using sparse active features. */
def initAccumulator(board: Board): Unit =
-System.arraycopy(l1Bias, 0, l1Stack(0), 0, 1536)
+System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
// ── Accumulator push (incremental updates) ───────────────────────────────
/** Copy parent ply's pre-activations to childPly, then apply move deltas. */
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
-System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, 1536)
+System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly)
move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
/** Copy pre-activations from parentPly to childPly without any move delta (null-move). */
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
-System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, 1536)
+System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { mover =>
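initAccumulator, pushAccumulator and copyAccumulator together give the search an O(changed-features) evaluation path. A sketch of the intended call pattern from a hypothetical search driver (turn and hash bookkeeping are omitted for brevity; the driver itself is not part of this commit):

```scala
import de.nowchess.api.board.{Board, Color}
import de.nowchess.api.move.Move

// Hypothetical driver: the root ply is built once from scratch, every child
// ply is derived by copying the parent and applying only the move's deltas.
def visit(nnue: NNUE, board: Board, turn: Color, hash: Long,
          moves: Seq[Move], ply: Int): Unit =
  if ply == 0 then nnue.initAccumulator(board)   // full sparse rebuild, root only
  for move <- moves do
    nnue.pushAccumulator(ply + 1, move, board)   // copy parent + apply move delta
    val score = nnue.evaluateAtPly(ply + 1, turn, hash)
    // ... recurse; "popping" is implicit, the ply slot is simply overwritten
```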
@@ -170,9 +114,6 @@ class NNUE:
// ── Evaluation from accumulator ──────────────────────────────────────────
/** Evaluate from pre-computed L1 pre-activations at the given ply. Probes eval cache first; stores result after
* computation.
*/
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
val idx = (hash & EVAL_CACHE_MASK).toInt
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
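The probe above returns a cached score only on a full 64-bit hash match; on a miss the score is computed and stored into the same slot. A self-contained sketch of that direct-mapped, always-replace scheme, as a hypothetical standalone class (the real code inlines the two arrays):

```scala
// Direct-mapped eval cache: 2^bits slots, collisions overwrite unconditionally.
// Overwriting is safe because a hit is verified against the full hash.
final class EvalCacheSketch(bits: Int = 18):
  private val mask = (1L << bits) - 1L
  private val hashes = new Array[Long](1 << bits)
  private val scores = new Array[Int](1 << bits)

  def probe(hash: Long): Option[Int] =
    val idx = (hash & mask).toInt
    if hashes(idx) == hash then Some(scores(idx)) else None

  def store(hash: Long, score: Int): Unit =
    val idx = (hash & mask).toInt
    hashes(idx) = hash
    scores(idx) = score
```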
@@ -183,11 +124,19 @@ class NNUE:
score
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
-for i <- 0 until 1536 do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
-runDenseReLU(l1ReLU, 1536, l2Weights, l2Bias, l2Output, 1024)
-runDenseReLU(l2Output, 1024, l3Weights, l3Bias, l3Output, 512)
-runDenseReLU(l3Output, 512, l4Weights, l4Bias, l4Output, 256)
-val output = runOutputLayer(l4Output, 256)
+val l1ReLU = evalBuffers(0)
+for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
+var input = l1ReLU
+for i <- 1 until model.layers.length - 1 do
+val lw = model.weights(i)
+val out = evalBuffers(i)
+val ld = model.layers(i)
+runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
+input = out
+val lastIdx = model.layers.length - 1
+val output = runOutputLayer(input, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn)
private def runDenseReLU(
@@ -202,8 +151,8 @@ class NNUE:
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
output(i) = if sum > 0f then sum else 0f
-private def runOutputLayer(input: Array[Float], inSize: Int): Float =
-(0 until inSize).foldLeft(l5Bias(0))((sum, j) => sum + input(j) * l5Weights(j))
+private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
+(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
private def scoreFromOutput(output: Float, turn: Color): Int =
val cp =
@@ -214,21 +163,15 @@ class NNUE:
val cpFromTurn = if turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
-// ── Legacy full-board evaluate (kept for Evaluation.evaluate compatibility) ──
+// ── Legacy full-board evaluate ────────────────────────────────────────────
// Pre-allocated buffers used only by the legacy evaluate path.
private val features = new Array[Float](768)
-private val legacyL1 = new Array[Float](1536)
+private val legacyL1 = new Array[Float](accSize)
/** Evaluate using full board scan (sparse over active features). Layout: black pieces at indices 0-5, white at 6-11.
*/
def evaluate(context: GameContext): Int =
-val l1Pre = legacyL1
-System.arraycopy(l1Bias, 0, l1Pre, 0, 1536)
-for (sq, piece) <- context.board.pieces do addColumn(l1Pre, featureIndex(piece, squareNum(sq)))
-runL2toOutput(l1Pre, context.turn)
+System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
+for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
+runL2toOutput(legacyL1, context.turn)
/** Benchmark: time 1M evaluations and report ns/eval. */
def benchmark(): Unit =
val context = GameContext.initial
val iterations = 1_000_000
@@ -0,0 +1,50 @@
package de.nowchess.bot.bots.nnue
import java.io.InputStream
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiLoader:
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
val MAGIC: Int = 0x4941_424e
def load(stream: InputStream): NbaiModel =
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkHeader(buf)
val metadata = readMetadata(buf)
val descs = readLayerDescriptors(buf)
val weights = descs.map(_ => readLayerWeights(buf))
NbaiModel(metadata, descs, weights)
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) => try load(s) finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
val version = buf.getShort() & 0xffff
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
val bytes = new Array[Byte](buf.getInt())
buf.get(bytes)
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
Array.tabulate(buf.getShort() & 0xffff) { _ =>
val nameBytes = new Array[Byte](buf.get() & 0xff)
buf.get(nameBytes)
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
}
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readFloats(buf), readFloats(buf))
private def readFloats(buf: ByteBuffer): Array[Float] =
val arr = new Array[Float](buf.getInt())
for i <- arr.indices do arr(i) = buf.getFloat()
arr
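Read back from the loader, the .nbai layout is: a 4-byte magic and 2-byte version, a length-prefixed UTF-8 metadata JSON blob, a 2-byte layer count followed by per-layer descriptors (1-byte name length, ASCII activation name, two 4-byte sizes), then per-layer weights and bias as length-prefixed float32 arrays, all little-endian. A quick usage sketch (the printed values are illustrative and depend on which resource is found):

```scala
// Sketch: load the default model and inspect its shape and provenance.
val model = NbaiLoader.loadDefault()
val shape = model.layers
  .map(l => s"${l.activation} ${l.inputSize}->${l.outputSize}")
  .mkString(", ")
println(shape)                    // e.g. relu 768->1536, ..., linear 256->1
println(model.metadata.trainedBy) // "unknown" when migrated from the legacy .bin
```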
@@ -0,0 +1,43 @@
package de.nowchess.bot.bots.nnue
import java.nio.{ByteBuffer, ByteOrder}
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
object NbaiMigrator:
private val BinMagic = 0x4555_4e4e
private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
)
private val UnknownMetadata: NbaiMetadata =
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
def migrateFromBin(): NbaiModel =
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
try
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkBinHeader(buf)
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
NbaiModel(UnknownMetadata, DefaultLayers, weights)
finally stream.close()
private def checkBinHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
val version = buf.getInt()
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readBinTensor(buf), readBinTensor(buf))
private def readBinTensor(buf: ByteBuffer): Array[Float] =
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
Array.tabulate(shape.product)(_ => buf.getFloat())
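Because DefaultLayers pins down the legacy architecture exactly, it also fixes the model's size. A small arithmetic check of the implied parameter count (plain Scala, no engine types):

```scala
// Parameters per dense layer: weights (out * in) plus bias (out), for
// 768x1536, 1536x1024, 1024x512, 512x256 and the 256x1 output layer.
val dims = Seq((768, 1536), (1536, 1024), (1024, 512), (512, 256), (256, 1))
val params = dims.map((in, out) => in * out + out).sum
println(params) // 3411457 floats, roughly 13 MB of float32 weights
```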
@@ -0,0 +1,39 @@
package de.nowchess.bot.bots.nnue
/** Descriptor for a single dense layer stored in a .nbai file. */
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
/** Training metadata embedded in every .nbai file. */
case class NbaiMetadata(
trainedBy: String,
trainedAt: String,
trainingDataCount: Long,
valLoss: Double,
trainLoss: Double,
):
def toJson: String =
s"""{
| "trainedBy": "$trainedBy",
| "trainedAt": "$trainedAt",
| "trainingDataCount": $trainingDataCount,
| "valLoss": $valLoss,
| "trainLoss": $trainLoss
|}""".stripMargin
object NbaiMetadata:
def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(str("trainedBy"), str("trainedAt"), num("trainingDataCount").toLong, num("valLoss").toDouble, num("trainLoss").toDouble)
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])
/** A fully deserialized .nbai model ready to initialize NNUE. */
case class NbaiModel(
metadata: NbaiMetadata,
layers: Array[LayerDescriptor],
weights: Array[LayerWeights],
):
require(layers.length == weights.length, "Layer count must match weight count")
require(layers.length >= 2, "Model must have at least 2 layers")
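The hand-rolled JSON codec in NbaiMetadata avoids pulling in a JSON dependency. A round-trip sketch with illustrative values (the regex-based fromJson assumes string fields contain no escaped quotes):

```scala
// Illustrative metadata round trip through the minimal JSON codec.
val meta = NbaiMetadata(
  trainedBy = "nowchess-trainer",
  trainedAt = "2026-04-13T21:19:26+02:00",
  trainingDataCount = 12_500_000L,
  valLoss = 0.0421,
  trainLoss = 0.0387,
)
assert(NbaiMetadata.fromJson(meta.toJson) == meta)
```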
@@ -0,0 +1,51 @@
package de.nowchess.bot.bots.nnue
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiWriter:
def write(model: NbaiModel, out: OutputStream): Unit =
val acc = new ByteArrayOutputStream()
writeHeader(acc)
writeMetadata(acc, model.metadata)
writeLayerDescriptors(acc, model.layers)
model.weights.foreach(lw => writeLayerWeights(acc, lw))
out.write(acc.toByteArray)
private def writeHeader(out: ByteArrayOutputStream): Unit =
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(NbaiLoader.MAGIC)
buf.putShort(1.toShort)
out.write(buf.array())
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(json.length)
buf.put(json)
out.write(buf.array())
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
buf.putShort(layers.length.toShort)
layers.zip(nameBytes).foreach { (l, nb) =>
buf.put(nb.length.toByte)
buf.put(nb)
buf.putInt(l.inputSize)
buf.putInt(l.outputSize)
}
out.write(buf.array())
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
writeFloats(out, lw.weights)
writeFloats(out, lw.bias)
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(floats.length)
floats.foreach(buf.putFloat)
out.write(buf.array())
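Writer and loader are mirror images, so a byte-level round trip should be lossless for descriptors and weights. A sketch using the migrator as a convenient model source; in practice the bytes could be saved as src/main/resources/nnue_weights.nbai (path illustrative) so loadDefault() takes the fast path on the next run:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

// Sketch: serialize a model with NbaiWriter and read it back with NbaiLoader.
val model = NbaiMigrator.migrateFromBin()
val out = new ByteArrayOutputStream()
NbaiWriter.write(model, out)
val reloaded = NbaiLoader.load(new ByteArrayInputStream(out.toByteArray))
assert(reloaded.layers.sameElements(model.layers))  // descriptors survive intact
assert(reloaded.weights(0).weights.sameElements(model.weights(0).weights))
```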