#!/usr/bin/env python3
"""Generate random chess positions for NNUE training with multiprocessing."""

import random
import sys
import time
from datetime import datetime
from multiprocessing import Pool, Queue
from pathlib import Path
from typing import List

import chess
from tqdm import tqdm

# Hard cap on the number of plies (half-moves) played in one random game.
_MAX_PLIES_PER_GAME = 200


def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ``numerator / denominator`` rounded up to an integer."""
    return -(-numerator // denominator)


def _sample_positions_from_random_game(
    samples_per_game: int, min_move: int, max_move: int
) -> List[str]:
    """Play one random game and return FENs sampled from it.

    Positions are indexed by the ply that produced them: index 0 is the
    position after the first move of the game, index 1 the position after
    the second, and so on. Only indices in ``[max(min_move, 0), max_move)``
    are eligible for sampling.

    Args:
        samples_per_game: Maximum number of positions to draw from the game.
        min_move: First eligible position index (negative values act as 0).
        max_move: One past the last eligible position index.

    Returns:
        FEN strings of the sampled positions (possibly empty). A position
        that ends the game is never returned.
    """
    board = chess.Board()
    first_index = max(min_move, 0)
    candidates = []
    plies_played = 0

    game_over = board.is_game_over()
    while plies_played < _MAX_PLIES_PER_GAME and not game_over:
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            break
        board.push(random.choice(legal_moves))
        if first_index <= plies_played < max_move:
            # Only positions inside the sampling window are kept, and the
            # move stack is not copied: fen() and is_valid() do not need it,
            # and copying it on every ply makes a game quadratic in length.
            candidates.append(board.copy(stack=False))
        plies_played += 1
        # Evaluated on the live board (full move stack) so that
        # repetition-based game endings are still detected.
        game_over = board.is_game_over()

    # Every earlier position let the loop continue, so only the final one
    # can be terminal. Drop it before sampling so it does not use up a draw.
    last_index = plies_played - 1
    if game_over and candidates and first_index <= last_index < max_move:
        candidates.pop()

    sample_count = min(samples_per_game, len(candidates))
    if sample_count <= 0:
        return []

    return [
        sampled.fen()
        for sampled in random.sample(candidates, sample_count)
        if sampled.is_valid()
    ]


def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
    """Generate games for one worker.

    Args:
        worker_id: Index of the worker task; not used by the generation
            itself, kept so each task tuple identifies its worker.
        games_per_worker: Number of random games this worker plays.
        samples_per_game: Maximum number of positions sampled per game.
        min_move: First eligible position index within a game.
        max_move: One past the last eligible position index within a game.

    Returns:
        list of FENs generated by this worker
    """
    positions = []
    for _ in range(games_per_worker):
        positions.extend(
            _sample_positions_from_random_game(samples_per_game, min_move, max_move)
        )
    return positions


def play_random_game_and_collect_positions(
    output_file,
    total_positions=3000000,
    samples_per_game=1,
    min_move=1,
    max_move=50,
    num_workers=8
):
    """Generate positions using multiprocessing with multiple workers.

    Args:
        output_file: Output file for positions (one FEN per line)
        total_positions: Target number of positions to generate
        samples_per_game: Number of positions to sample per game (1-N)
        min_move: Minimum move number to start sampling from
        max_move: Maximum move number for sampling
        num_workers: Number of parallel worker processes

    Returns:
        Number of valid positions saved

    Raises:
        ValueError: If ``samples_per_game`` or ``num_workers`` is below 1.
    """
    if samples_per_game < 1:
        raise ValueError("samples_per_game must be at least 1")
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")

    # Each game yields at most samples_per_game positions. Round up at both
    # steps so integer division never schedules fewer games than the target
    # needs, and give every worker at least one game.
    requested_games = max(_ceil_div(total_positions, samples_per_game), num_workers)
    games_per_worker = _ceil_div(requested_games, num_workers)
    total_games = games_per_worker * num_workers

    print(f"Generating {total_positions:,} positions using {num_workers} workers")
    print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
    print()

    start_time = datetime.now()

    # Generate positions in parallel
    worker_tasks = [
        (i, games_per_worker, samples_per_game, min_move, max_move)
        for i in range(num_workers)
    ]

    all_positions = []
    with Pool(num_workers) as pool:
        with tqdm(total=num_workers, desc="Workers generating games") as pbar:
            # apply_async hands back one result handle per task, so the bar
            # advances as each worker's result is collected; starmap would
            # block until every task had finished.
            pending = [
                pool.apply_async(_worker_generate_games, task)
                for task in worker_tasks
            ]
            for result in pending:
                all_positions.extend(result.get())
                pbar.update(1)

    positions_count = len(all_positions)

    # Write all positions to file
    print(f"Writing {positions_count:,} positions to {output_file}...")
    with open(output_file, 'w', encoding="utf-8") as f:
        f.writelines(f"{fen}\n" for fen in all_positions)

    elapsed_time = datetime.now() - start_time
    elapsed_seconds = elapsed_time.total_seconds()
    positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0

    # Print summary
    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Target positions:       {total_positions:,}")
    print(f"Actual positions saved: {positions_count:,}")
    print(f"Workers:                {num_workers}")
    print(f"Games per worker:       {games_per_worker:,}")
    print(f"Samples per game:       {samples_per_game}")
    print(f"Move range:             {min_move}-{max_move}")
    print(f"Elapsed time:           {elapsed_time}")
    print(f"Throughput:             {positions_per_second:.0f} positions/second")
    print("=" * 60)
    print()

    if positions_count == 0:
        print("WARNING: No valid positions were generated!")

    return positions_count


def main() -> int:
    """Parse command-line arguments, run the generator, return an exit code.

    Returns:
        0 if at least one position was saved, otherwise 1.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    parser.add_argument("output_file", nargs="?", default="positions.txt",
                        help="Output file for positions (default: positions.txt)")
    parser.add_argument("--positions", type=int, default=3000000,
                        help="Target number of positions to generate (default: 3000000)")
    parser.add_argument("--samples-per-game", type=int, default=1,
                        help="Number of positions to sample per game (default: 1)")
    parser.add_argument("--min-move", type=int, default=1,
                        help="Minimum move number to sample from (default: 1)")
    parser.add_argument("--max-move", type=int, default=50,
                        help="Maximum move number to sample from (default: 50)")
    parser.add_argument("--workers", type=int, default=8,
                        help="Number of parallel worker processes (default: 8)")

    args = parser.parse_args()

    # Reject values that cannot produce any work, with a usage message
    # instead of a traceback.
    if args.samples_per_game < 1:
        parser.error("--samples-per-game must be at least 1")
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    count = play_random_game_and_collect_positions(
        output_file=args.output_file,
        total_positions=args.positions,
        samples_per_game=args.samples_per_game,
        min_move=args.min_move,
        max_move=args.max_move,
        num_workers=args.workers
    )

    return 0 if count > 0 else 1


if __name__ == "__main__":
    sys.exit(main())