feat: refactor AlphaBetaSearch and ClassicalBot for improved evaluation and organization
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to Scala code."""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def export_weights_to_scala(weights_file, output_file):
|
||||
"""Load PyTorch weights and export as Scala code."""
|
||||
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: Weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Load weights (weights_only=False for compatibility with older PyTorch versions)
|
||||
state_dict = torch.load(weights_file, map_location='cpu')
|
||||
|
||||
# Create output directory if needed
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write("package de.nowchess.bot.bots.nnue\n\n")
|
||||
f.write("object NNUEWeights:\n")
|
||||
|
||||
for layer_name, tensor in sorted(state_dict.items()):
|
||||
# Sanitize name
|
||||
safe_name = layer_name.replace('.', '_').replace(' ', '_')
|
||||
|
||||
# Convert tensor to flat list
|
||||
values = tensor.flatten().tolist()
|
||||
|
||||
# Format as Scala array
|
||||
f.write(f"\n val {safe_name} = Array(\n")
|
||||
|
||||
# Write values in chunks for readability
|
||||
chunk_size = 16
|
||||
for i in range(0, len(values), chunk_size):
|
||||
chunk = values[i:i + chunk_size]
|
||||
formatted_chunk = ", ".join(f"{v:.10g}f" for v in chunk)
|
||||
f.write(f" {formatted_chunk}")
|
||||
if i + chunk_size < len(values):
|
||||
f.write(",\n")
|
||||
else:
|
||||
f.write("\n")
|
||||
|
||||
f.write(f" )\n")
|
||||
|
||||
# Store shape for reference
|
||||
shape = list(tensor.shape)
|
||||
f.write(f" // Shape: {shape}\n")
|
||||
|
||||
print(f"Weights exported to {output_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights_file = "nnue_weights.pt"
|
||||
output_file = "../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights.scala"
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
weights_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
|
||||
export_weights_to_scala(weights_file, output_file)
|
||||
Reference in New Issue
Block a user