Files
NowChessSystems/modules/bot/python/export_weights.py
T

65 lines
2.1 KiB
Python

#!/usr/bin/env python3
"""Export NNUE weights to Scala code."""
import math
import re
import sys
from pathlib import Path

import torch
def _scala_identifier(layer_name):
    """Turn a state-dict key into a legal Scala identifier.

    Every non-word character ('.', ' ', '-', '/', ...) becomes '_', and a
    leading digit -- e.g. the '0.weight' keys produced by nn.Sequential --
    gets a '_' prefix, because Scala identifiers cannot start with a digit.
    Keys that only contained '.' and ' ' map exactly as before.
    """
    safe_name = re.sub(r"\W", "_", layer_name)
    if safe_name[:1].isdigit():
        safe_name = "_" + safe_name
    return safe_name


def _scala_float_literal(value):
    """Format one number as a Scala Float expression.

    '.10g' keeps 10 significant digits (a 32-bit float needs 9 to round-trip).
    NaN and infinities have no literal form in Scala, so they are emitted as
    the corresponding Float constants instead of invalid tokens like 'nanf'.
    """
    if math.isnan(value):
        return "Float.NaN"
    if math.isinf(value):
        return "Float.PositiveInfinity" if value > 0 else "Float.NegativeInfinity"
    return f"{value:.10g}f"


def export_weights_to_scala(weights_file, output_file):
    """Load PyTorch weights and export them as a Scala 3 object.

    Each state-dict entry (in sorted key order) becomes
    `val <name> = Array(...)` inside `object NNUEWeights` in package
    `de.nowchess.bot.bots.nnue`. The entry holds the tensor's values from
    `tensor.flatten()`, 16 per line, and is followed by a `// Shape: [...]`
    comment recording the original dimensions.

    Args:
        weights_file: Path to a file written by `torch.save`; its payload is
            expected to be a mapping of parameter names to tensors.
        output_file: Destination .scala path (str or Path). Missing parent
            directories are created and an existing file is overwritten.

    Exits the process with status 1 (message on stderr) if `weights_file`
    does not exist.

    Raises:
        ValueError: if two keys sanitize to the same Scala identifier; this
            is checked before the output file is created or modified.
    """
    if not Path(weights_file).exists():
        print(f"Error: Weights file not found at {weights_file}", file=sys.stderr)
        sys.exit(1)

    # weights_only is deliberately not passed: PyTorch < 1.13 lacks the
    # parameter, so its default applies (True from PyTorch 2.6, False before).
    # NOTE(review): with weights_only=False torch.load is a full pickle load
    # and can execute arbitrary code -- only export trusted weight files.
    state_dict = torch.load(weights_file, map_location="cpu")

    # Resolve all identifiers up front so a naming clash aborts before any
    # output is written (otherwise Scala would get duplicate `val`s).
    entries = []
    source_key = {}
    for layer_name, tensor in sorted(state_dict.items()):
        safe_name = _scala_identifier(layer_name)
        if safe_name in source_key:
            raise ValueError(
                f"State-dict keys {source_key[safe_name]!r} and {layer_name!r} "
                f"both map to Scala identifier {safe_name!r}"
            )
        source_key[safe_name] = layer_name
        entries.append((safe_name, tensor))

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    chunk_size = 16  # values per generated line, for readability
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("package de.nowchess.bot.bots.nnue\n\n")
        f.write("object NNUEWeights:\n")
        for safe_name, tensor in entries:
            values = tensor.flatten().tolist()
            f.write(f"\n val {safe_name} = Array(\n")
            rows = [
                " " + ", ".join(_scala_float_literal(v) for v in values[i:i + chunk_size])
                for i in range(0, len(values), chunk_size)
            ]
            # Rows are comma-separated; the last row carries no trailing comma.
            if rows:
                f.write(",\n".join(rows) + "\n")
            f.write(" )\n")
            f.write(f" // Shape: {list(tensor.shape)}\n")

    print(f"Weights exported to {output_file}")
if __name__ == "__main__":
    # Usage: export_weights.py [WEIGHTS_FILE [OUTPUT_FILE]]
    # Positional arguments beyond the second are ignored.
    #
    # The default output is resolved against the directory holding this
    # script, not the current working directory. The target is a Scala source
    # tree beside this script's directory, and export_weights_to_scala creates
    # missing parent directories. A CWD-relative default would therefore
    # silently build a stray copy of that tree when run from anywhere else.
    # The default weights file stays relative to the working directory; if it
    # is missing, the exporter reports it and exits.
    weights_file = "nnue_weights.pt"
    output_file = (
        Path(__file__).resolve().parent
        / "../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights.scala"
    ).resolve()
    if len(sys.argv) > 1:
        weights_file = sys.argv[1]
    if len(sys.argv) > 2:
        output_file = sys.argv[2]
    export_weights_to_scala(weights_file, output_file)