#!/usr/bin/env python3 """Export NNUE weights to Scala code.""" import torch import sys from pathlib import Path def export_weights_to_scala(weights_file, output_file): """Load PyTorch weights and export as Scala code.""" if not Path(weights_file).exists(): print(f"Error: Weights file not found at {weights_file}") sys.exit(1) # Load weights (weights_only=False for compatibility with older PyTorch versions) state_dict = torch.load(weights_file, map_location='cpu') # Create output directory if needed output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_file, 'w') as f: f.write("package de.nowchess.bot.bots.nnue\n\n") f.write("object NNUEWeights:\n") for layer_name, tensor in sorted(state_dict.items()): # Sanitize name safe_name = layer_name.replace('.', '_').replace(' ', '_') # Convert tensor to flat list values = tensor.flatten().tolist() # Format as Scala array f.write(f"\n val {safe_name} = Array(\n") # Write values in chunks for readability chunk_size = 16 for i in range(0, len(values), chunk_size): chunk = values[i:i + chunk_size] formatted_chunk = ", ".join(f"{v:.10g}f" for v in chunk) f.write(f" {formatted_chunk}") if i + chunk_size < len(values): f.write(",\n") else: f.write("\n") f.write(f" )\n") # Store shape for reference shape = list(tensor.shape) f.write(f" // Shape: {shape}\n") print(f"Weights exported to {output_file}") if __name__ == "__main__": weights_file = "nnue_weights.pt" output_file = "../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights.scala" if len(sys.argv) > 1: weights_file = sys.argv[1] if len(sys.argv) > 2: output_file = sys.argv[2] export_weights_to_scala(weights_file, output_file)