111 lines
3.8 KiB
Python
111 lines
3.8 KiB
Python
#!/usr/bin/env python3
|
|
"""Generate 500,000 random chess positions for NNUE training."""
|
|
|
|
import chess
|
|
import random
|
|
import sys
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
|
|
def _play_random_opening(board, num_moves):
    """Push up to ``num_moves`` uniformly random legal moves onto ``board`` in place.

    Stops early as soon as the game is over (or, defensively, when no legal
    move exists).
    """
    for _ in range(num_moves):
        if board.is_game_over():
            break

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            break

        board.push(random.choice(legal_moves))


def _rejection_reason(board, filter_captures):
    """Return why ``board`` must be skipped, or None if it should be saved.

    Possible reasons are "game_over", "check" and (only when
    ``filter_captures`` is true) "captures", checked in that order.
    """
    if board.is_game_over():
        return "game_over"
    if board.is_check():
        return "check"
    # Positions with a pending capture are "noisy"; skip them when requested.
    if filter_captures and any(board.is_capture(m) for m in board.legal_moves):
        return "captures"
    return None


def _print_summary(total_games, positions_count, filtered):
    """Print the generation statistics; ``filtered`` maps rejection reason -> count."""
    total_filtered = filtered["check"] + filtered["captures"] + filtered["game_over"]
    # Guard against ZeroDivisionError when no games were requested.
    acceptance = positions_count / total_games * 100 if total_games > 0 else 0.0

    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Total games: {total_games}")
    print(f"Saved positions: {positions_count}")
    print(f"Filtered (check): {filtered['check']}")
    print(f"Filtered (captures): {filtered['captures']}")
    print(f"Filtered (game over): {filtered['game_over']}")
    print(f"Total filtered: {total_filtered}")
    print(f"Acceptance rate: {acceptance:.2f}%")
    print("=" * 60)
    print()


def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True,
                                           min_moves=8, max_moves=20):
    """Play random games and save the resulting positions as FEN lines.

    For each game, a fresh board is advanced by ``random.randint(min_moves,
    max_moves)`` random legal moves (fewer if the game ends early). The final
    position is written to ``output_file`` (one FEN per line) unless the game
    is over, the side to move is in check, or (when ``filter_captures`` is
    true) any legal move is a capture.

    Args:
        output_file: Path of the text file to (over)write.
        total_games: Number of random games to play; must be >= 0.
        filter_captures: Skip positions that have at least one legal capture.
        min_moves: Minimum number of random moves per game (inclusive).
        max_moves: Maximum number of random moves per game (inclusive).

    Returns:
        Number of valid positions saved.

    Raises:
        ValueError: If ``total_games`` is negative, or (from ``random.randint``)
            if ``min_moves > max_moves``.
    """
    if total_games < 0:
        raise ValueError(f"total_games must be non-negative, got {total_games}")

    positions_count = 0
    filtered = {"check": 0, "captures": 0, "game_over": 0}

    with open(output_file, 'w', encoding='utf-8') as f:
        with tqdm(total=total_games, desc="Generating positions") as pbar:
            for _ in range(total_games):
                board = chess.Board()
                _play_random_opening(board, random.randint(min_moves, max_moves))

                reason = _rejection_reason(board, filter_captures)
                if reason is None:
                    f.write(board.fen() + '\n')
                    positions_count += 1
                else:
                    filtered[reason] += 1

                # Every game counts toward progress, whether saved or filtered.
                pbar.update(1)

    _print_summary(total_games, positions_count, filtered)

    if positions_count == 0:
        print("WARNING: No valid positions were generated!")
        print("This might happen if:")
        if total_games == 0:
            print(" - No games were requested")
        print(" - The filter criteria are too strict (captures, checks)")
        # Only suggest disabling the capture filter when it is actually enabled.
        if filter_captures:
            print(" - Try using: --no-filter-captures to accept positions with captures")

    return positions_count
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, generate positions, and exit
    # with status 0 if at least one position was saved, 1 otherwise.
    import argparse

    parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    parser.add_argument("output_file", nargs="?", default="positions.txt",
                        help="Output file for positions (default: %(default)s)")
    # NOTE(review): the module docstring and the function's own default target
    # 500,000 games, but the CLI default is 5000. The help text now reports the
    # real default instead of a hard-coded (wrong) number; confirm which is intended.
    parser.add_argument("--games", type=int, default=5000,
                        help="Number of games to play (default: %(default)s)")
    parser.add_argument("--no-filter-captures", action="store_true",
                        help="Include positions with available captures (increases output)")

    args = parser.parse_args()

    # Reject nonsensical counts up front with a proper usage error.
    if args.games < 0:
        parser.error("--games must be a non-negative integer")

    count = play_random_game_and_collect_positions(
        output_file=args.output_file,
        total_games=args.games,
        filter_captures=not args.no_filter_captures,
    )

    sys.exit(0 if count > 0 else 1)
|