199 lines
7.0 KiB
Python
199 lines
7.0 KiB
Python
#!/usr/bin/env python3
|
|
"""Label positions with Stockfish evaluations."""
|
|
|
|
import json
|
|
import chess.engine
|
|
import sys
|
|
import os
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
|
|
def _resolve_engine_command(stockfish_path):
    """Return a command that can launch Stockfish, or None if none is found.

    Accepts either an explicit path to the binary or a bare command name
    that is looked up on ``PATH``.
    """
    import shutil  # local import keeps the file's top-level imports unchanged

    # shutil.which() searches PATH for bare names (the CLI default is
    # "stockfish") and checks explicit paths directly. A plain
    # Path.exists() check would reject a binary that is only on PATH.
    found = shutil.which(stockfish_path)
    if found:
        return found
    # Fall back to a plain file check so an existing (e.g. non-executable)
    # file still reaches popen_uci, which then reports the real OS error.
    candidate = Path(stockfish_path)
    if candidate.is_file():
        # Absolute, so a bare name is not re-resolved against PATH by popen.
        return str(candidate.absolute())
    return None


def _load_evaluated_fens(output_file):
    """Collect FENs already labeled in an existing JSONL output file.

    Returns:
        ``(fens, record_count, needs_newline)``. ``fens`` is the set of FENs
        found, ``record_count`` the number of readable records, and
        ``needs_newline`` is True when the file's last line has no trailing
        newline (e.g. a previous run was killed mid-write), so the next
        append must start on a fresh line.
    """
    fens = set()
    record_count = 0
    last_line = ""
    with open(output_file, "r", encoding="utf-8") as f:
        for line in f:
            last_line = line
            try:
                fens.add(json.loads(line)["fen"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Corrupt, partial, or foreign records are skipped; those
                # positions will simply be re-evaluated.
                continue
            record_count += 1
    needs_newline = bool(last_line) and not last_line.endswith("\n")
    return fens, record_count, needs_newline


def _score_to_cp(white_score, limit=2000):
    """Convert a White-POV python-chess score to centipawns in [-limit, limit].

    Mate scores map to +limit (White mates) or -limit (White is mated).
    """
    # score(mate_score=...) maps Mate(n) to +/-(big - n) with the correct
    # sign, including mate already on the board (MateGiven / Mate(-0)),
    # where .mate() is 0 and a `mate() > 0` test would pick the wrong side.
    return max(-limit, min(limit, white_score.score(mate_score=limit * 100)))


def _evaluate_fen(engine, fen, depth):
    """Analyse one FEN; return its clamped White-POV eval, or None if unusable.

    Returns None when the board is not a legal position or the engine reports
    no score. Exceptions from ``chess.Board`` (e.g. a malformed FEN) or from
    the engine propagate; the caller counts them as errors.
    """
    board = chess.Board(fen)
    if not board.is_valid():
        return None
    info = engine.analyse(board, chess.engine.Limit(depth=depth))
    score = info.get("score")
    if score is None:
        return None
    return _score_to_cp(score.white())


def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False):
    """Read positions and label them with Stockfish evaluations.

    Each non-empty line of ``positions_file`` is treated as a FEN. Legal
    positions are analysed to ``depth`` and appended to ``output_file`` as
    JSON lines ``{"fen": ..., "eval": ...}``, where ``eval`` is White's score
    in centipawns clamped to [-2000, 2000] (mate scores become +/-2000).
    FENs already present in ``output_file``, or repeated earlier in the input,
    are skipped, so an interrupted run can be resumed.

    Args:
        positions_file: Path to positions.txt (one FEN per line).
        output_file: Path to training_data.jsonl (appended to).
        stockfish_path: Path to the Stockfish binary, or a command name
            found on ``PATH`` (e.g. ``"stockfish"``).
        batch_size: Batch size (not used, kept for compatibility).
        depth: Stockfish search depth.
        verbose: Print detailed error messages.

    Returns:
        True if positions were evaluated in this run or were already labeled;
        False if nothing in the input ended up labeled.

    Exits the process with status 1 (``sys.exit``) if Stockfish or the
    positions file cannot be found, the positions file is empty, or the
    engine cannot be started.
    """
    # Check if stockfish exists (on disk or on PATH)
    engine_cmd = _resolve_engine_command(stockfish_path)
    if engine_cmd is None:
        print(f"Error: Stockfish not found at {stockfish_path}")
        print(f"Tried: {stockfish_path}")
        print("Set STOCKFISH_PATH environment variable or pass as argument")
        sys.exit(1)

    print(f"Using Stockfish: {engine_cmd}")

    # Check if positions file exists
    if not Path(positions_file).exists():
        print(f"Error: Positions file not found at {positions_file}")
        sys.exit(1)

    # Load existing evaluations if resuming
    evaluated_fens = set()
    position_count = 0
    needs_newline = False
    if Path(output_file).exists():
        evaluated_fens, position_count, needs_newline = _load_evaluated_fens(output_file)
        print(f"Resuming from {position_count} already evaluated positions")

    # Count total positions
    with open(positions_file, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)

    if total_lines == 0:
        print(f"Error: Positions file is empty ({positions_file})")
        sys.exit(1)

    print(f"Total positions to process: {total_lines}")
    print(f"Using depth: {depth}")
    print()

    # Initialize engine
    try:
        engine = chess.engine.SimpleEngine.popen_uci(engine_cmd)
    except Exception as e:
        print("Error: Could not start Stockfish engine")
        print(f" Stockfish path: {stockfish_path}")
        print(f" Error: {e}")
        sys.exit(1)

    # Track statistics
    evaluated = 0
    skipped_invalid = 0
    skipped_duplicate = 0
    errors = 0

    try:
        with open(positions_file, "r", encoding="utf-8") as f, \
                open(output_file, "a", encoding="utf-8") as out:
            if needs_newline:
                # Terminate a truncated final line so the first new record
                # is not glued onto it.
                out.write("\n")
            # Every input line advances the bar exactly once, so start at 0;
            # starting at the resumed count made the bar overshoot `total`.
            with tqdm(total=total_lines, desc="Labeling positions") as pbar:
                for line in f:
                    fen = line.strip()
                    if not fen:
                        skipped_invalid += 1  # empty line
                    elif fen in evaluated_fens:
                        # Labeled by a previous run or by an earlier
                        # duplicate line of this input file.
                        skipped_duplicate += 1
                    else:
                        try:
                            eval_cp = _evaluate_fen(engine, fen, depth)
                            if eval_cp is None:
                                skipped_invalid += 1
                            else:
                                data = {"fen": fen, "eval": eval_cp}
                                out.write(json.dumps(data) + "\n")
                                # Flush per record so an interrupted run
                                # loses at most the position in flight.
                                out.flush()
                                evaluated_fens.add(fen)
                                evaluated += 1
                        except Exception as e:
                            # Best-effort: one bad position must not stop
                            # the whole labeling run.
                            errors += 1
                            if verbose:
                                # tqdm.write keeps the progress bar intact.
                                tqdm.write(f"Error evaluating position: {fen[:50]}...")
                                tqdm.write(f" {type(e).__name__}: {e}")
                    pbar.update(1)
    finally:
        engine.quit()

    # Print summary
    print()
    print("=" * 60)
    print("LABELING SUMMARY")
    print("=" * 60)
    print(f"Successfully evaluated: {evaluated}")
    print(f"Skipped (duplicates): {skipped_duplicate}")
    print(f"Skipped (invalid): {skipped_invalid}")
    print(f"Errors: {errors}")
    print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
    print("=" * 60)
    print()

    # A re-run over an already fully labeled file (all duplicates) is a
    # success, not a failure.
    if evaluated == 0 and skipped_duplicate == 0:
        print("WARNING: No positions were successfully evaluated!")
        print("Check that:")
        print(" 1. positions.txt is not empty")
        print(" 2. positions.txt contains valid FENs")
        print(" 3. Stockfish is installed and working")
        print(" 4. Stockfish path is correct")
        return False

    print(f"✓ Labeling complete. Output saved to {output_file}")
    return True
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse arguments, pick the Stockfish binary,
    # run the labeler and turn its boolean result into an exit status.
    parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
    parser.add_argument("positions_file", nargs="?", default="positions.txt",
                        help="Input positions file (default: positions.txt)")
    parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
                        help="Output file (default: training_data.jsonl)")
    parser.add_argument("stockfish_path", nargs="?", default=None,
                        help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    parser.add_argument("--depth", type=int, default=12,
                        help="Stockfish depth (default: 12)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print detailed error messages")

    args = parser.parse_args()

    # Determine Stockfish path. Chaining with `or` (instead of a get()
    # default) also falls back to "stockfish" when STOCKFISH_PATH is set
    # but empty; an empty path would otherwise be passed to the engine.
    stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH") or "stockfish"

    success = label_positions_with_stockfish(
        positions_file=args.positions_file,
        output_file=args.output_file,
        stockfish_path=stockfish_path,
        depth=args.depth,
        verbose=args.verbose,
    )

    sys.exit(0 if success else 1)
|