Co-authored-by: Janis <janis@nowchess.de> Reviewed-on: #33 Co-authored-by: Janis <janis.e.20@gmx.de> Co-committed-by: Janis <janis.e.20@gmx.de>
This commit was merged in pull request #33.
This commit is contained in:
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||
|
||||
import json
|
||||
import chess.engine
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    """Map a centipawn evaluation onto a bounded range.

    Args:
        cp_value: Centipawn evaluation from Stockfish.
        method: 'tanh' (default) or 'sigmoid'; any other value falls
            back to plain pawn units (cp / 100).
        scale: Divisor applied before the squashing function
            (300 is a typical choice for tanh).

    Returns:
        Value in approximately [-1, 1] for 'tanh', [0, 1] for
        'sigmoid', or pawns (cp / 100) for any other method.
    """
    scaled = cp_value / scale
    if method == 'sigmoid':
        return 1.0 / (1.0 + np.exp(-scaled))
    if method == 'tanh':
        return np.tanh(scaled)
    # Unrecognized method: report raw pawn units instead of squashing.
    return cp_value / 100.0
|
||||
|
||||
def _evaluate_fen_batch(args):
    """Worker: evaluate one batch of FEN strings with a private engine.

    Args:
        args: tuple of (fens, stockfish_path, depth, normalize)

    Returns:
        list of (fen, eval_normalized, eval_raw) tuples
    """
    fens, stockfish_path, depth, normalize = args

    # Each worker owns its own engine process; if it cannot start,
    # the entire batch is reported as failed (empty result list).
    try:
        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    except Exception:
        return []

    evaluated = []
    limit = chess.engine.Limit(depth=depth)

    try:
        for fen in fens:
            try:
                board = chess.Board(fen)
                if not board.is_valid():
                    continue

                info = engine.analyse(board, limit)
                if info.get('score') is None:
                    continue

                # Score is reported from White's point of view.
                white_score = info['score'].white()

                if white_score.is_mate():
                    # Represent forced mates with a large fixed magnitude.
                    cp = 2000 if white_score.mate() > 0 else -2000
                else:
                    cp = white_score.cp

                # Clamp to +/-2000 cp so extreme evals don't dominate.
                cp = max(-2000, min(2000, cp))
                value = normalize_evaluation(cp) if normalize else cp
                evaluated.append((fen, value, cp))

            except Exception:
                # Skip individual positions that fail to parse or analyse.
                continue
    finally:
        engine.quit()

    return evaluated
|
||||
|
||||
|
||||
def _load_evaluated_fens(output_file):
    """Return (set of FENs already present in output_file, their count).

    Used for resuming: malformed lines (bad JSON or missing "fen" key)
    are skipped rather than aborting the run.
    """
    evaluated_fens = set()
    count = 0
    if Path(output_file).exists():
        with open(output_file, 'r') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    evaluated_fens.add(data['fen'])
                    count += 1
                except (json.JSONDecodeError, KeyError):
                    pass
    return evaluated_fens, count


def _collect_new_fens(positions_file, evaluated_fens):
    """Read positions_file and return (new fens, n_invalid, n_duplicate).

    Blank lines count as invalid; FENs already evaluated or repeated
    within the file count as duplicates.
    """
    fens = []
    seen = set()  # duplicates within the positions file itself
    skipped_invalid = 0
    skipped_duplicate = 0
    with open(positions_file, 'r') as f:
        for line in f:
            fen = line.strip()
            if not fen:
                skipped_invalid += 1
            elif fen in evaluated_fens or fen in seen:
                skipped_duplicate += 1
            else:
                fens.append(fen)
                seen.add(fen)
    return fens, skipped_invalid, skipped_duplicate


def _print_distribution(raw_evals, normalized_evals):
    """Print min/max/mean/median/std and bucket counts for the evals."""
    raw_arr = np.array(raw_evals)
    norm_arr = np.array(normalized_evals)

    print("=" * 60)
    print("EVALUATION DISTRIBUTION ANALYSIS")
    print("=" * 60)
    print()
    print("Raw Evaluations (centipawns):")
    print(f"  Min: {raw_arr.min():.1f}")
    print(f"  Max: {raw_arr.max():.1f}")
    print(f"  Mean: {raw_arr.mean():.1f}")
    print(f"  Median: {np.median(raw_arr):.1f}")
    print(f"  Std: {raw_arr.std():.1f}")
    print()

    print("Normalized Evaluations (tanh):")
    print(f"  Min: {norm_arr.min():.4f}")
    print(f"  Max: {norm_arr.max():.4f}")
    print(f"  Mean: {norm_arr.mean():.4f}")
    print(f"  Median: {np.median(norm_arr):.4f}")
    print(f"  Std: {norm_arr.std():.4f}")
    print()

    # Buckets are labelled in pawns (100 cp = 1.00), half-open [low, high).
    # The comparison is >= low, < high so counts match the printed labels
    # (the previous > low, <= high form contradicted them at boundaries).
    print("Raw Evaluation Buckets (counts):")
    buckets = [
        (-float('inf'), -500, "< -5.00"),
        (-500, -300, "[-5.00, -3.00)"),
        (-300, -100, "[-3.00, -1.00)"),
        (-100, 0, "[-1.00, 0.00)"),
        (0, 100, "[0.00, 1.00)"),
        (100, 300, "[1.00, 3.00)"),
        (300, 500, "[3.00, 5.00)"),
        (500, float('inf'), ">= 5.00"),
    ]
    for low, high, label in buckets:
        count = np.sum((raw_arr >= low) & (raw_arr < high))
        pct = 100.0 * count / len(raw_arr)
        print(f"  {label}: {count:6d} ({pct:5.1f}%)")

    print("=" * 60)
    print()


def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
    """Read positions and label them with Stockfish evaluations.

    Appends JSON lines {"fen": ..., "eval": ..., "eval_raw": ...} to
    output_file; existing lines are read first so interrupted runs
    resume without re-evaluating positions.

    Args:
        positions_file: Path to positions.txt (one FEN per line)
        output_file: Path to training_data.jsonl (opened in append mode)
        stockfish_path: Path to stockfish binary
        batch_size: Positions per worker task (default: 1000)
        depth: Stockfish search depth
        verbose: Accepted for interface compatibility; currently unused
        normalize: If True, store tanh-normalized evals in "eval"
        num_workers: Number of parallel Stockfish processes

    Returns:
        True on success, False if nothing was successfully evaluated.
        Calls sys.exit(1) when input files or the engine binary are missing.
    """
    import time  # local import: only needed for throughput/ETA reporting

    # Check if stockfish exists
    if not Path(stockfish_path).exists():
        print(f"Error: Stockfish not found at {stockfish_path}")
        print(f"Tried: {stockfish_path}")
        print(f"Set STOCKFISH_PATH environment variable or pass as argument")
        sys.exit(1)

    print(f"Using Stockfish: {stockfish_path}")
    print(f"Number of workers: {num_workers}")

    # Check if positions file exists
    if not Path(positions_file).exists():
        print(f"Error: Positions file not found at {positions_file}")
        sys.exit(1)

    # Load existing evaluations if resuming
    evaluated_fens, position_count = _load_evaluated_fens(output_file)
    if Path(output_file).exists():
        print(f"Resuming from {position_count} already evaluated positions")

    # Collect FENs that still need evaluation
    fens_to_evaluate, skipped_invalid, skipped_duplicate = _collect_new_fens(
        positions_file, evaluated_fens)

    total_to_evaluate = len(fens_to_evaluate)
    total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate

    if total_to_evaluate == 0:
        if position_count == 0:
            print(f"Error: No valid positions to evaluate in {positions_file}")
            sys.exit(1)
        print(f"All positions already evaluated. No new positions to process.")
        return True

    print(f"Total positions to process: {total_lines}")
    print(f"New positions to evaluate: {total_to_evaluate}")
    print(f"Using depth: {depth}")
    print()

    # Split FENs into per-worker task tuples matching _evaluate_fen_batch's args
    batches = [(fens_to_evaluate[i:i + batch_size], stockfish_path, depth, normalize)
               for i in range(0, total_to_evaluate, batch_size)]

    evaluated = 0
    errors = 0
    raw_evals = []
    normalized_evals = []
    start_time = time.time()
    postfix_every = max(1, len(batches) // 10)

    with Pool(num_workers) as pool:
        with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
            with open(output_file, 'a') as out:
                # Ordered imap keeps batch_idx aligned with batches[batch_idx],
                # so the per-batch failure count is exact even for the short
                # final batch. (The previous imap_unordered + len(batches[0][0])
                # sizing miscounted errors and progress for any batch whose
                # size differed from the first one.)
                for batch_idx, batch_results in enumerate(pool.imap(_evaluate_fen_batch, batches)):
                    for fen, eval_normalized, eval_cp in batch_results:
                        # Safety net: never write the same FEN twice this run
                        if fen in evaluated_fens:
                            continue

                        data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
                        out.write(json.dumps(data) + '\n')
                        evaluated_fens.add(fen)
                        evaluated += 1
                        raw_evals.append(eval_cp)
                        normalized_evals.append(eval_normalized)
                        pbar.update(1)

                    # Account for positions the worker dropped (invalid FEN,
                    # analyse error, or an engine that failed to start).
                    failed = len(batches[batch_idx][0]) - len(batch_results)
                    if failed > 0:
                        errors += failed
                        pbar.update(failed)

                    # Periodically refresh throughput / ETA in the bar postfix
                    if (batch_idx + 1) % postfix_every == 0:
                        elapsed = time.time() - start_time
                        throughput = evaluated / elapsed if elapsed > 0 else 0
                        remaining = total_to_evaluate - evaluated
                        eta_seconds = remaining / throughput if throughput > 0 else 0
                        pbar.set_postfix({
                            'rate': f'{throughput:.0f} pos/s',
                            'eta': f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}",
                        })

    # Print summary and analysis
    print()
    print("=" * 60)
    print("LABELING SUMMARY")
    print("=" * 60)
    print(f"Successfully evaluated: {evaluated}")
    print(f"Skipped (duplicates): {skipped_duplicate}")
    print(f"Skipped (invalid): {skipped_invalid}")
    print(f"Errors: {errors}")
    print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
    print("=" * 60)
    print()

    if evaluated == 0:
        print("WARNING: No positions were successfully evaluated!")
        print("Check that:")
        print("  1. positions.txt is not empty")
        print("  2. positions.txt contains valid FENs")
        print("  3. Stockfish is installed and working")
        print("  4. Stockfish path is correct")
        return False

    # Print distribution analysis
    if raw_evals:
        _print_distribution(raw_evals, normalized_evals)

    print(f"✓ Labeling complete. Output saved to {output_file}")
    return True
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Command-line interface: positional args are all optional with
    # sensible defaults so the script runs bare as well.
    cli = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
    cli.add_argument("positions_file", nargs="?", default="positions.txt",
                     help="Input positions file (default: positions.txt)")
    cli.add_argument("output_file", nargs="?", default="training_data.jsonl",
                     help="Output file (default: training_data.jsonl)")
    cli.add_argument("stockfish_path", nargs="?", default=None,
                     help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    cli.add_argument("--depth", type=int, default=12,
                     help="Stockfish depth (default: 12)")
    cli.add_argument("--batch-size", type=int, default=1000,
                     help="Batch size for processing (default: 1000)")
    cli.add_argument("--no-normalize", action="store_true",
                     help="Disable evaluation normalization (keep raw centipawns)")
    cli.add_argument("--verbose", action="store_true",
                     help="Print detailed error messages")
    cli.add_argument("--workers", type=int, default=1,
                     help="Number of parallel Stockfish processes (default: 1)")
    opts = cli.parse_args()

    # Explicit CLI argument wins over $STOCKFISH_PATH; last resort is
    # a bare "stockfish" looked up on PATH.
    engine_path = opts.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")

    ok = label_positions_with_stockfish(
        positions_file=opts.positions_file,
        output_file=opts.output_file,
        stockfish_path=engine_path,
        batch_size=opts.batch_size,
        depth=opts.depth,
        normalize=not opts.no_normalize,
        verbose=opts.verbose,
        num_workers=opts.workers,
    )

    sys.exit(0 if ok else 1)
|
||||
Reference in New Issue
Block a user