feat: Implement dataset versioning and management for NNUE training data
This commit is contained in:
@@ -1,67 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to binary format for runtime loading."""
|
||||
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||
|
||||
import torch
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
def export_weights_to_binary(weights_file, output_file):
|
||||
"""Load PyTorch weights and export as binary file."""
|
||||
import torch
|
||||
|
||||
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
|
||||
VERSION = 1
|
||||
|
||||
|
||||
def _read_sidecar(weights_file: str) -> dict:
|
||||
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||
if Path(sidecar).exists():
|
||||
with open(sidecar) as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||
"""Derive layer descriptors from state_dict weight shapes.
|
||||
|
||||
Assumes layers named l1, l2, ..., lN.
|
||||
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||
"""
|
||||
names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
layers = []
|
||||
for i, name in enumerate(names):
|
||||
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||
activation = "linear" if i == len(names) - 1 else "relu"
|
||||
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||
return layers
|
||||
|
||||
|
||||
def _write_floats(f, tensor):
|
||||
data = tensor.float().flatten().cpu().numpy()
|
||||
f.write(struct.pack("<I", len(data)))
|
||||
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||
|
||||
|
||||
def export_to_nbai(
|
||||
weights_file: str,
|
||||
output_file: str,
|
||||
trained_by: str = "unknown",
|
||||
train_loss: float = 0.0,
|
||||
):
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: Weights file not found at {weights_file}")
|
||||
print(f"Error: weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Load weights — handle both raw state dicts and full training checkpoints
|
||||
loaded = torch.load(weights_file, map_location='cpu')
|
||||
state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded
|
||||
loaded = torch.load(weights_file, map_location="cpu")
|
||||
state_dict = (
|
||||
loaded["model_state_dict"]
|
||||
if isinstance(loaded, dict) and "model_state_dict" in loaded
|
||||
else loaded
|
||||
)
|
||||
|
||||
# Debug: print available layers
|
||||
print(f"Available layers in {weights_file}:")
|
||||
for key in sorted(state_dict.keys()):
|
||||
print(f" {key}: {state_dict[key].shape}")
|
||||
sidecar = _read_sidecar(weights_file)
|
||||
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
|
||||
trained_at = sidecar.get("date", datetime.now().isoformat())
|
||||
training_data_count = int(sidecar.get("num_positions", 0))
|
||||
|
||||
# Create output directory if needed
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
metadata = {
|
||||
"trainedBy": trained_by,
|
||||
"trainedAt": trained_at,
|
||||
"trainingDataCount": training_data_count,
|
||||
"valLoss": val_loss,
|
||||
"trainLoss": train_loss,
|
||||
}
|
||||
|
||||
with open(output_file, 'wb') as f:
|
||||
# Write magic number and version
|
||||
f.write(b'NNUE')
|
||||
f.write(struct.pack('<I', 1)) # version 1
|
||||
layers = _infer_layers(state_dict)
|
||||
layer_names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
|
||||
# Write each weight tensor in order
|
||||
for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias']:
|
||||
if layer_name not in state_dict:
|
||||
print(f"Error: Missing layer {layer_name}")
|
||||
sys.exit(1)
|
||||
print(f"Architecture ({len(layers)} layers):")
|
||||
for i, l in enumerate(layers):
|
||||
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
|
||||
|
||||
tensor = state_dict[layer_name]
|
||||
# Convert to float32 and flatten
|
||||
data = tensor.float().flatten().cpu().numpy()
|
||||
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write shape (allows validation on load)
|
||||
shape = list(tensor.shape)
|
||||
f.write(struct.pack('<I', len(shape)))
|
||||
for dim in shape:
|
||||
f.write(struct.pack('<I', dim))
|
||||
with open(output_file, "wb") as f:
|
||||
# Header
|
||||
f.write(struct.pack("<I", MAGIC))
|
||||
f.write(struct.pack("<H", VERSION))
|
||||
|
||||
# Write flattened data as binary floats
|
||||
f.write(struct.pack(f'<{len(data)}f', *data))
|
||||
# Metadata (length-prefixed UTF-8 JSON)
|
||||
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
|
||||
f.write(struct.pack("<I", len(meta_bytes)))
|
||||
f.write(meta_bytes)
|
||||
|
||||
print(f" {layer_name}: shape {shape}, {len(data)} floats")
|
||||
# Layer descriptors
|
||||
f.write(struct.pack("<H", len(layers)))
|
||||
for layer in layers:
|
||||
name_bytes = layer["activation"].encode("ascii")
|
||||
f.write(struct.pack("<B", len(name_bytes)))
|
||||
f.write(name_bytes)
|
||||
f.write(struct.pack("<I", layer["inputSize"]))
|
||||
f.write(struct.pack("<I", layer["outputSize"]))
|
||||
|
||||
# Weights: weight tensor then bias tensor per layer
|
||||
for name in layer_names:
|
||||
w = state_dict[f"{name}.weight"]
|
||||
b = state_dict[f"{name}.bias"]
|
||||
_write_floats(f, w)
|
||||
_write_floats(f, b)
|
||||
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
|
||||
|
||||
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
|
||||
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
file_size_mb = output_path.stat().st_size / (1024**2)
|
||||
print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights_file = "nnue_weights.pt"
|
||||
output_file = "../src/main/resources/nnue_weights.bin"
|
||||
output_file = "../src/main/resources/nnue_weights.nbai"
|
||||
trained_by = "unknown"
|
||||
train_loss = 0.0
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
weights_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
trained_by = sys.argv[3]
|
||||
if len(sys.argv) > 4:
|
||||
train_loss = float(sys.argv[4])
|
||||
|
||||
export_weights_to_binary(weights_file, output_file)
|
||||
export_to_nbai(weights_file, output_file, trained_by, train_loss)
|
||||
|
||||
Reference in New Issue
Block a user