feat: integrate NNUE bot and add Python training pipeline with weight export functionality
This commit is contained in:
+59
-32
@@ -13,44 +13,63 @@ def get_python_cmd():
|
||||
return "python"
|
||||
return "python3" if os.popen("which python3 2>/dev/null").read() else "python"
|
||||
|
||||
def get_src_module(module_name: str) -> Path:
    """Return the absolute path of ``<module_name>.py`` in the ``src`` directory.

    The ``src`` directory is looked up next to this script.  The script
    location is resolved first so the returned path is absolute and stays
    valid when it is passed to a subprocess that runs with a different
    working directory.

    Args:
        module_name: Module file name without the ``.py`` suffix.

    Returns:
        Absolute path to the module file.  Existence is not checked.
    """
    script_dir = Path(__file__).resolve().parent
    return script_dir / "src" / f"{module_name}.py"
|
||||
|
||||
def get_data_dir() -> Path:
    """Return the ``data`` directory next to this script, creating it if needed.

    The script location is resolved first so the returned path is absolute
    and does not depend on the current working directory of whoever uses it.

    Returns:
        Absolute path to the (now existing) ``data`` directory.
    """
    data_dir = Path(__file__).resolve().parent / "data"
    # exist_ok makes repeated calls safe; the parent is the script's own
    # directory, so it already exists.
    data_dir.mkdir(exist_ok=True)
    return data_dir
|
||||
|
||||
def get_weights_dir() -> Path:
    """Return the ``weights`` directory next to this script, creating it if needed.

    The script location is resolved first so the returned path is absolute
    and can be used directly as a subprocess working directory.

    Returns:
        Absolute path to the (now existing) ``weights`` directory.
    """
    weights_dir = Path(__file__).resolve().parent / "weights"
    # exist_ok makes repeated calls safe; the parent is the script's own
    # directory, so it already exists.
    weights_dir.mkdir(exist_ok=True)
    return weights_dir
|
||||
|
||||
def list_checkpoints():
    """List available checkpoint versions.

    Looks for files named ``nnue_weights_v<N>.pt`` in the weights directory
    and extracts the integer ``<N>`` from each file name.

    Returns:
        A list of version numbers in ascending numeric order (so 10 comes
        after 2).  The list is empty when no checkpoint is found.  Files that
        match the glob but whose suffix is not a plain integer (for example
        ``nnue_weights_v3_best.pt``) are ignored instead of raising.
    """
    prefix = "nnue_weights_v"
    suffixes = (
        checkpoint.stem[len(prefix):]
        for checkpoint in get_weights_dir().glob(f"{prefix}*.pt")
    )
    # Sort the parsed integers, not the paths: path order is lexicographic
    # and would place v10 before v2.
    return sorted(int(suffix) for suffix in suffixes if suffix.isdecimal())
|
||||
|
||||
def run_generate_positions(num_games):
    """Generate random positions by running the ``generate`` module.

    The generator script from ``src/`` is started as a subprocess and told to
    write to ``positions.txt`` in the data directory.

    Args:
        num_games: Value forwarded to the script's ``--games`` option.

    Returns:
        True when the subprocess exits with status 0 and the positions file
        exists afterwards, False otherwise.
    """
    positions_file = get_data_dir() / "positions.txt"

    print(f"Generating {num_games} positions...")
    cmd = [
        get_python_cmd(),
        str(get_src_module("generate")),
        str(positions_file),
        "--games",
        str(num_games),
    ]
    # Output is not captured, so the child's progress is shown directly on
    # this process's stdout/stderr.
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print("ERROR: Position generation failed")
        return False

    return positions_file.exists()
|
||||
|
||||
def run_label_positions(stockfish_path):
    """Label positions with Stockfish by running the ``label`` module.

    Reads ``positions.txt`` from the data directory and asks the labeling
    script from ``src/`` to write ``training_data.jsonl`` next to it.

    Args:
        stockfish_path: Path to the Stockfish executable, forwarded to the
            labeling script as its last argument.

    Returns:
        True when the subprocess exits with status 0 and the output file
        exists afterwards; False when the positions file is missing or the
        subprocess fails.
    """
    data_dir = get_data_dir()
    positions_file = data_dir / "positions.txt"
    output_file = data_dir / "training_data.jsonl"

    # Guard clause: nothing to label without generated positions.
    if not positions_file.exists():
        print("ERROR: positions.txt not found")
        return False

    print("Labeling positions with Stockfish...")
    cmd = [
        get_python_cmd(),
        str(get_src_module("label")),
        str(positions_file),
        str(output_file),
        stockfish_path,
    ]
    # Output is not captured, so the child's progress is shown directly on
    # this process's stdout/stderr.
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print("ERROR: Position labeling failed")
        return False

    return output_file.exists()
|
||||
|
||||
def run_train(positions_file, output_weights, from_checkpoint=None):
|
||||
"""Train NNUE model."""
|
||||
@@ -58,29 +77,34 @@ def run_train(positions_file, output_weights, from_checkpoint=None):
|
||||
print(f"ERROR: {positions_file} not found")
|
||||
return False
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
print(f"Training model (output: {output_weights})...")
|
||||
if from_checkpoint:
|
||||
print(f" Starting from checkpoint: {from_checkpoint}")
|
||||
|
||||
cmd = [get_python_cmd(), "train_nnue.py", positions_file, output_weights]
|
||||
cmd = [get_python_cmd(), str(get_src_module("train")), str(positions_file), str(output_weights)]
|
||||
if from_checkpoint:
|
||||
cmd.extend(["--checkpoint", from_checkpoint])
|
||||
cmd.extend(["--checkpoint", str(from_checkpoint)])
|
||||
|
||||
result = subprocess.run(cmd, capture_output=False)
|
||||
# Run from weights directory so outputs save there
|
||||
result = subprocess.run(cmd, cwd=str(weights_dir), capture_output=False)
|
||||
if result.returncode != 0:
|
||||
print("ERROR: Training failed")
|
||||
return False
|
||||
return True # train_nnue creates versioned file, not the base name
|
||||
return True
|
||||
|
||||
def run_export(weights_file, output_file):
|
||||
"""Export weights to Scala."""
|
||||
if not Path(weights_file).exists():
|
||||
print(f"ERROR: {weights_file} not found")
|
||||
weights_dir = get_weights_dir()
|
||||
weights_path = weights_dir / Path(weights_file).name
|
||||
|
||||
if not weights_path.exists():
|
||||
print(f"ERROR: {weights_file} not found in {weights_dir}")
|
||||
return False
|
||||
|
||||
print(f"Exporting {weights_file} to Scala...")
|
||||
result = subprocess.run(
|
||||
[get_python_cmd(), "export_weights.py", weights_file, output_file],
|
||||
[get_python_cmd(), str(get_src_module("export")), str(weights_path), output_file],
|
||||
capture_output=False
|
||||
)
|
||||
if result.returncode != 0:
|
||||
@@ -91,13 +115,16 @@ def run_export(weights_file, output_file):
|
||||
def cmd_train(args):
|
||||
"""Handle train command."""
|
||||
stockfish_path = args.stockfish or os.environ.get("STOCKFISH_PATH", "/usr/games/stockfish")
|
||||
data_dir = get_data_dir()
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
# Determine checkpoint
|
||||
checkpoint = None
|
||||
if args.from_checkpoint:
|
||||
checkpoint_version = args.from_checkpoint
|
||||
checkpoint = f"nnue_weights_v{checkpoint_version}.pt"
|
||||
if not Path(checkpoint).exists():
|
||||
checkpoint_path = weights_dir / checkpoint
|
||||
if not checkpoint_path.exists():
|
||||
print(f"ERROR: Checkpoint {checkpoint} not found")
|
||||
return False
|
||||
else:
|
||||
@@ -109,12 +136,12 @@ def cmd_train(args):
|
||||
|
||||
# Generate or use existing positions
|
||||
if args.positions_file:
|
||||
if not Path(args.positions_file).exists():
|
||||
positions_file = Path(args.positions_file)
|
||||
if not positions_file.exists():
|
||||
print(f"ERROR: {args.positions_file} not found")
|
||||
return False
|
||||
positions_file = args.positions_file
|
||||
else:
|
||||
positions_file = "positions.txt"
|
||||
positions_file = data_dir / "positions.txt"
|
||||
num_games = args.games or 500000
|
||||
if not run_generate_positions(num_games):
|
||||
return False
|
||||
@@ -125,8 +152,9 @@ def cmd_train(args):
|
||||
|
||||
print("\nStarting training...")
|
||||
|
||||
# Train (train_nnue.py handles versioning internally)
|
||||
if not run_train("training_data.jsonl", "nnue_weights.pt", checkpoint):
|
||||
# Train with absolute path to data, checkpoint is relative to weights dir
|
||||
training_data = str(data_dir / "training_data.jsonl")
|
||||
if not run_train(training_data, "nnue_weights.pt", checkpoint):
|
||||
return False
|
||||
|
||||
# Show created version
|
||||
@@ -143,13 +171,8 @@ def cmd_export(args):
|
||||
if not weights_file.endswith(".pt"):
|
||||
weights_file = f"nnue_weights_v{weights_file}.pt"
|
||||
|
||||
if not Path(weights_file).exists():
|
||||
print(f"ERROR: {weights_file} not found")
|
||||
return False
|
||||
|
||||
# Determine version from filename
|
||||
version = Path(weights_file).stem.split("_v")[1] if "_v" in weights_file else "1"
|
||||
output_file = f"../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_v{version}.scala"
|
||||
# Output to resources directory as binary format
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.bin")
|
||||
|
||||
if not run_export(weights_file, output_file):
|
||||
return False
|
||||
@@ -164,11 +187,15 @@ def cmd_list(args):
|
||||
print("No checkpoints found")
|
||||
return True
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
print("Available checkpoints:")
|
||||
for v in available:
|
||||
weights_file = f"nnue_weights_v{v}.pt"
|
||||
size = Path(weights_file).stat().st_size / (1024**2) # MB
|
||||
print(f" v{v} ({size:.1f} MB)")
|
||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||
if weights_file.exists():
|
||||
size = weights_file.stat().st_size / (1024**2) # MB
|
||||
print(f" v{v} ({size:.1f} MB)")
|
||||
else:
|
||||
print(f" v{v} (file not found)")
|
||||
return True
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user