#!/usr/bin/env python3
"""Export NNUE weights to Scala code."""
import inspect
import math
import re
import sys
from collections.abc import Mapping
from pathlib import Path

import torch


def export_weights_to_scala(weights_file, output_file):
    """Load a PyTorch state dict and export every tensor as a Scala array.

    The generated file declares ``package de.nowchess.bot.bots.nnue`` and an
    ``object NNUEWeights`` holding one ``val <name> = Array[Float](...)`` per
    state-dict entry, in sorted key order. Each array is followed by a
    ``// Shape: [...]`` comment, because values are written flattened.

    Args:
        weights_file: Path to a ``torch.save`` file containing a mapping of
            name -> tensor (e.g. ``model.state_dict()``).
        output_file: Destination ``.scala`` path; missing parent directories
            are created.

    Exits the process with status 1 (message on stderr) if the weights file is
    missing or does not contain a mapping of tensors. The output file is only
    written once the whole Scala source has been rendered successfully.

    NOTE(review): every element becomes a literal in the object's initializer,
    and JVM methods are capped at 64 KiB of bytecode, so very large layers may
    fail to compile on the Scala side. Confirm the layer sizes stay small enough.
    """
    if not Path(weights_file).exists():
        _fail(f"Error: Weights file not found at {weights_file}")

    state_dict = _load_state_dict(weights_file)
    scala_source = _render_scala(state_dict)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(scala_source, encoding="utf-8")

    print(f"Weights exported to {output_file}")


# Scala 3 hard keywords; a layer named like one of these must be back-quoted.
_SCALA_KEYWORDS = frozenset(
    """
    abstract case catch class def do else enum export extends false final
    finally for given if implicit import lazy match new null object override
    package private protected return sealed super then throw trait true try
    type val var while with yield
    """.split()
)


def _fail(message):
    """Print *message* to stderr and terminate the process with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _load_state_dict(weights_file):
    """Load *weights_file* onto the CPU and return it as a name -> tensor mapping.

    Calls ``_fail`` if the file does not hold a mapping whose values are all
    tensors.
    """
    load_kwargs = {"map_location": "cpu"}
    # weights_only=True refuses arbitrary pickled objects, and a plain state
    # dict loads fine with it. PyTorch < 1.13 has no such parameter, so only
    # pass it when torch.load accepts it.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = True
    loaded = torch.load(weights_file, **load_kwargs)

    if not isinstance(loaded, Mapping):
        _fail(
            f"Error: Expected a state dict in {weights_file}, "
            f"got {type(loaded).__name__}"
        )
    for name, value in loaded.items():
        if not isinstance(value, torch.Tensor):
            _fail(f"Error: Entry {name!r} in {weights_file} is not a tensor")
    return loaded


def _scala_identifier(name):
    """Turn a state-dict key such as ``"fc1.weight"`` into a valid Scala identifier."""
    ident = re.sub(r"[^0-9A-Za-z_]", "_", str(name))
    if not ident or ident[0].isdigit():
        # Scala identifiers cannot be empty or start with a digit
        # (nn.Sequential produces keys like "0.weight").
        ident = f"_{ident or 'empty'}"
    if ident in _SCALA_KEYWORDS:
        ident = f"`{ident}`"
    return ident


def _scala_float_literal(value):
    """Format one tensor element as a Scala ``Float`` expression."""
    number = float(value)
    # "nanf" / "inff" would not compile; use the named Float constants.
    if math.isnan(number):
        return "Float.NaN"
    if math.isinf(number):
        return "Float.PositiveInfinity" if number > 0 else "Float.NegativeInfinity"
    # 10 significant digits exceed the 9 needed to round-trip a float32.
    return f"{number:.10g}f"


def _render_scala(state_dict, values_per_line=16):
    """Return the Scala source of ``object NNUEWeights`` built from *state_dict*."""
    parts = ["package de.nowchess.bot.bots.nnue\n\n", "object NNUEWeights:\n"]
    seen = {}
    for layer_name in sorted(state_dict):
        tensor = state_dict[layer_name]
        ident = _scala_identifier(layer_name)
        if ident in seen:
            # Two keys collapsing to one name would produce a duplicate `val`.
            _fail(
                f"Error: Entries {seen[ident]!r} and {layer_name!r} "
                f"both map to Scala name {ident}"
            )
        seen[ident] = layer_name

        values = tensor.detach().flatten().tolist()
        rows = [
            "    "
            + ", ".join(
                _scala_float_literal(v)
                for v in values[start:start + values_per_line]
            )
            for start in range(0, len(values), values_per_line)
        ]
        # The explicit element type keeps the declaration valid for empty tensors.
        parts.append(f"\n  val {ident} = Array[Float](\n")
        if rows:
            parts.append(",\n".join(rows) + "\n")
        parts.append("  )\n")
        # Values are flattened, so record the original shape for reference.
        parts.append(f"  // Shape: {list(tensor.shape)}\n")
    return "".join(parts)


if __name__ == "__main__":
    # Optional positional arguments replace the defaults in order:
    # argv[1] -> weights file, argv[2] -> output file; extra arguments are ignored.
    default_paths = [
        "nnue_weights.pt",
        "../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights.scala",
    ]
    given_paths = sys.argv[1:3]
    weights_path, scala_path = given_paths + default_paths[len(given_paths):]

    export_weights_to_scala(weights_path, scala_path)