dceab0875e
Build & Test (NowChessSystems) TeamCity build finished
Co-authored-by: Janis <janis@nowchess.de> Reviewed-on: #33 Co-authored-by: Janis <janis.e.20@gmx.de> Co-committed-by: Janis <janis.e.20@gmx.de>
172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
"""Generate random chess positions for NNUE training with multiprocessing."""
|
|
|
|
import chess
|
|
import random
|
|
import sys
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
from multiprocessing import Pool, Queue
|
|
from datetime import datetime
|
|
import time
|
|
|
|
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
|
|
"""Generate games for one worker.
|
|
|
|
Returns:
|
|
list of FENs generated by this worker
|
|
"""
|
|
positions = []
|
|
|
|
for game_num in range(games_per_worker):
|
|
board = chess.Board()
|
|
move_history = []
|
|
|
|
# Play a complete random game
|
|
while not board.is_game_over() and len(move_history) < 200:
|
|
legal_moves = list(board.legal_moves)
|
|
if not legal_moves:
|
|
break
|
|
move = random.choice(legal_moves)
|
|
board.push(move)
|
|
move_history.append(board.copy())
|
|
|
|
# Determine the range of moves to sample from
|
|
game_length = len(move_history)
|
|
valid_start = max(min_move, 0)
|
|
valid_end = min(max_move, game_length)
|
|
|
|
if valid_start >= valid_end:
|
|
continue
|
|
|
|
# Randomly sample positions from this game
|
|
sample_count = min(samples_per_game, valid_end - valid_start)
|
|
if sample_count > 0:
|
|
sample_indices = random.sample(
|
|
range(valid_start, valid_end),
|
|
k=sample_count
|
|
)
|
|
|
|
for idx in sample_indices:
|
|
sampled_board = move_history[idx]
|
|
|
|
# Only filter truly invalid or terminal positions
|
|
if not sampled_board.is_valid() or sampled_board.is_game_over():
|
|
continue
|
|
|
|
# Save position (include check, captures, all positions)
|
|
fen = sampled_board.fen()
|
|
positions.append(fen)
|
|
|
|
return positions
|
|
|
|
|
|
def play_random_game_and_collect_positions(
    output_file,
    total_positions=3000000,
    samples_per_game=1,
    min_move=1,
    max_move=50,
    num_workers=8
):
    """Generate positions using multiprocessing with multiple workers.

    The work is split evenly across ``num_workers`` processes, each running
    ``_worker_generate_games``.  All returned FENs are written to
    ``output_file``, one per line, and a summary is printed.  The number of
    saved positions can be lower than ``total_positions`` because the worker
    skips games with no eligible index and drops terminal positions.

    Args:
        output_file: Output file for positions
        total_positions: Target number of positions to generate
        samples_per_game: Number of positions to sample per game (1-N)
        min_move: Minimum move number to start sampling from
        max_move: Maximum move number for sampling
        num_workers: Number of parallel worker processes

    Returns:
        Number of valid positions saved

    Raises:
        ValueError: If ``samples_per_game`` or ``num_workers`` is below 1.
    """
    # Both values are used as divisors below; reject them before any work
    # starts instead of failing with a bare ZeroDivisionError.
    if samples_per_game < 1:
        raise ValueError(f"samples_per_game must be at least 1, got {samples_per_game}")
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")

    # Games needed if every game yields samples_per_game positions; at least
    # one game per worker.
    total_games = max(total_positions // samples_per_game, num_workers)
    # Round up so the workers together play at least total_games games
    # (floor division would silently drop the remainder).
    games_per_worker = -(-total_games // num_workers)

    print(f"Generating {total_positions:,} positions using {num_workers} workers")
    print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
    print()

    start_time = datetime.now()

    # Generate positions in parallel
    worker_tasks = [
        (i, games_per_worker, samples_per_game, min_move, max_move)
        for i in range(num_workers)
    ]

    all_positions = []

    with Pool(num_workers) as pool:
        # Submit every task up front and collect results one by one so the
        # progress bar advances as workers finish; starmap() would only
        # return once all of them are done.
        pending = [
            pool.apply_async(_worker_generate_games, task)
            for task in worker_tasks
        ]
        with tqdm(total=num_workers, desc="Workers generating games") as pbar:
            for result in pending:
                # get() re-raises any exception raised inside the worker
                all_positions.extend(result.get())
                pbar.update(1)

    positions_count = len(all_positions)

    # Write all positions to file
    print(f"Writing {positions_count:,} positions to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(fen + '\n' for fen in all_positions)

    elapsed_time = datetime.now() - start_time
    elapsed_seconds = elapsed_time.total_seconds()
    positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0

    # Print summary
    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Target positions: {total_positions:,}")
    print(f"Actual positions saved: {positions_count:,}")
    print(f"Workers: {num_workers}")
    print(f"Games per worker: {games_per_worker:,}")
    print(f"Samples per game: {samples_per_game}")
    print(f"Move range: {min_move}-{max_move}")
    print(f"Elapsed time: {elapsed_time}")
    print(f"Throughput: {positions_per_second:.0f} positions/second")
    print("=" * 60)
    print()

    if positions_count == 0:
        print("WARNING: No valid positions were generated!")

    return positions_count
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options, run the generator and
    # turn its result into the process exit status.
    cli = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    cli.add_argument(
        "output_file",
        nargs="?",
        default="positions.txt",
        help="Output file for positions (default: positions.txt)",
    )

    # Every remaining option is an integer flag: (flag, default, help text).
    int_options = (
        ("--positions", 3000000, "Target number of positions to generate (default: 3000000)"),
        ("--samples-per-game", 1, "Number of positions to sample per game (default: 1)"),
        ("--min-move", 1, "Minimum move number to sample from (default: 1)"),
        ("--max-move", 50, "Maximum move number to sample from (default: 50)"),
        ("--workers", 8, "Number of parallel worker processes (default: 8)"),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    opts = cli.parse_args()

    saved = play_random_game_and_collect_positions(
        output_file=opts.output_file,
        total_positions=opts.positions,
        samples_per_game=opts.samples_per_game,
        min_move=opts.min_move,
        max_move=opts.max_move,
        num_workers=opts.workers,
    )

    # Exit status 0 only when the generator reports at least one saved position.
    exit_code = 0 if saved > 0 else 1
    sys.exit(exit_code)
|