feat: integrate NNUE bot and add Python training pipeline with weight export functionality

This commit is contained in:
2026-04-07 23:33:20 +02:00
parent 6a9ac55b31
commit b25be99dcf
29 changed files with 338 additions and 2538 deletions
+66
View File
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
"""Export NNUE weights to binary format for runtime loading."""
import torch
import struct
import sys
from pathlib import Path
def export_weights_to_binary(weights_file, output_file):
"""Load PyTorch weights and export as binary file."""
if not Path(weights_file).exists():
print(f"Error: Weights file not found at {weights_file}")
sys.exit(1)
# Load weights
state_dict = torch.load(weights_file, map_location='cpu')
# Debug: print available layers
print(f"Available layers in {weights_file}:")
for key in sorted(state_dict.keys()):
print(f" {key}: {state_dict[key].shape}")
# Create output directory if needed
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'wb') as f:
# Write magic number and version
f.write(b'NNUE')
f.write(struct.pack('<I', 1)) # version 1
# Write each weight tensor in order
for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias']:
if layer_name not in state_dict:
print(f"Error: Missing layer {layer_name}")
sys.exit(1)
tensor = state_dict[layer_name]
# Convert to float32 and flatten
data = tensor.float().flatten().cpu().numpy()
# Write shape (allows validation on load)
shape = list(tensor.shape)
f.write(struct.pack('<I', len(shape)))
for dim in shape:
f.write(struct.pack('<I', dim))
# Write flattened data as binary floats
f.write(struct.pack(f'<{len(data)}f', *data))
print(f" {layer_name}: shape {shape}, {len(data)} floats")
file_size_mb = output_path.stat().st_size / (1024**2)
print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
if __name__ == "__main__":
weights_file = "nnue_weights.pt"
output_file = "../src/main/resources/nnue_weights.bin"
if len(sys.argv) > 1:
weights_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
export_weights_to_binary(weights_file, output_file)