67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Export NNUE weights to binary format for runtime loading."""
|
|
|
|
import torch
|
|
import struct
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
def export_weights_to_binary(weights_file, output_file):
    """Load PyTorch weights and export them as a binary file.

    Output layout (integers are little-endian unsigned 32-bit, values are
    little-endian float32):

        b'NNUE' | version (1) | then, for each of l1.weight, l1.bias,
        l2.weight, l2.bias, l3.weight, l3.bias, in that order:
        number of dims | each dim | flattened tensor data

    Args:
        weights_file: Path to a file loadable with ``torch.load`` that holds
            a mapping from layer name to tensor.
        output_file: Destination path. Missing parent directories are created.

    Raises:
        SystemExit: With exit code 1 if ``weights_file`` does not exist or a
            required layer is missing. In both cases nothing is written.
    """
    layer_names = (
        'l1.weight', 'l1.bias',
        'l2.weight', 'l2.bias',
        'l3.weight', 'l3.bias',
    )

    if not Path(weights_file).exists():
        print(f"Error: Weights file not found at {weights_file}")
        sys.exit(1)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code when full unpickling is enabled. Only load trusted weight files, or
    # pass weights_only=True if the installed PyTorch version supports it.
    state_dict = torch.load(weights_file, map_location='cpu')

    # Debug: print available layers
    print(f"Available layers in {weights_file}:")
    for key in sorted(state_dict):
        print(f" {key}: {state_dict[key].shape}")

    # Validate every required layer BEFORE touching the output path, so a
    # failed export never leaves a truncated file that still carries a valid
    # magic number and version.
    missing = [name for name in layer_names if name not in state_dict]
    if missing:
        for name in missing:
            print(f"Error: Missing layer {name}")
        sys.exit(1)

    # Create output directory if needed
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'wb') as f:
        # Magic number and format version
        f.write(b'NNUE')
        f.write(struct.pack('<I', 1))  # version 1

        for layer_name in layer_names:
            tensor = state_dict[layer_name]
            # detach() lets tensors that require grad be converted to numpy.
            data = tensor.detach().float().flatten().cpu().numpy()

            # Shape header (allows validation on load): dim count, then dims.
            shape = list(tensor.shape)
            f.write(struct.pack(f'<{len(shape) + 1}I', len(shape), *shape))

            # Write the float32 buffer directly instead of unpacking every
            # element into struct.pack arguments; '<f4' pins little-endian
            # regardless of the host byte order.
            f.write(data.astype('<f4', copy=False).tobytes())

            print(f" {layer_name}: shape {shape}, {len(data)} floats")

    file_size_mb = output_path.stat().st_size / (1024**2)
    print(f"Weights exported to {output_file} ({file_size_mb:.2f} MB)")
|
|
|
|
if __name__ == "__main__":
    # Usage: script [weights_file [output_file]]
    # Up to two positional arguments override the defaults, in order; any
    # argument not supplied keeps its default value.
    defaults = ["nnue_weights.pt", "../src/main/resources/nnue_weights.bin"]
    cli_args = sys.argv[1:3]
    weights_file, output_file = cli_args + defaults[len(cli_args):]

    export_weights_to_binary(weights_file, output_file)
|