feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction

This commit is contained in:
2026-04-09 22:36:09 +02:00
parent 0bf1d52132
commit 5d4cf5f13c
16 changed files with 940 additions and 245 deletions
+106 -67
View File
@@ -1,100 +1,136 @@
#!/usr/bin/env python3
"""Generate random chess positions for NNUE training with minimal filtering."""
"""Generate random chess positions for NNUE training with multiprocessing."""
import chess
import random
import sys
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool, Queue
from datetime import datetime
import time
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
"""Generate games for one worker.
Returns:
list of FENs generated by this worker
"""
positions = []
for game_num in range(games_per_worker):
board = chess.Board()
move_history = []
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
for idx in sample_indices:
sampled_board = move_history[idx]
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid() or sampled_board.is_game_over():
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
positions.append(fen)
return positions
def play_random_game_and_collect_positions(
output_file,
total_games=500000,
total_positions=3000000,
samples_per_game=1,
min_move=1,
max_move=50
max_move=50,
num_workers=8
):
"""Play random games and sample multiple positions from each.
"""Generate positions using multiprocessing with multiple workers.
Args:
output_file: Output file for positions
total_games: Number of games to play
total_positions: Target number of positions to generate
samples_per_game: Number of positions to sample per game (1-N)
min_move: Minimum move number to start sampling from
max_move: Maximum move number for sampling
num_workers: Number of parallel worker processes
Returns:
Number of valid positions saved
"""
# Estimate games needed (roughly 1 position per game on average)
total_games = max(total_positions // samples_per_game, num_workers)
games_per_worker = total_games // num_workers
print(f"Generating {total_positions:,} positions using {num_workers} workers")
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
print()
start_time = datetime.now()
# Generate positions in parallel
worker_tasks = [
(i, games_per_worker, samples_per_game, min_move, max_move)
for i in range(num_workers)
]
positions_count = 0
filtered_game_over = 0
filtered_illegal = 0
with open(output_file, 'w') as f:
with tqdm(total=total_games, desc="Generating positions") as pbar:
for game_num in range(total_games):
board = chess.Board()
move_history = []
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
# Game too short
pbar.update(1)
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
for idx in sample_indices:
sampled_board = move_history[idx]
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid():
filtered_illegal += 1
continue
if sampled_board.is_game_over():
filtered_game_over += 1
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
f.write(fen + '\n')
positions_count += 1
all_positions = []
with Pool(num_workers) as pool:
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
for positions in pool.starmap(_worker_generate_games, worker_tasks):
all_positions.extend(positions)
positions_count += len(positions)
pbar.update(1)
# Write all positions to file
print(f"Writing {positions_count:,} positions to {output_file}...")
with open(output_file, 'w') as f:
for fen in all_positions:
f.write(fen + '\n')
elapsed_time = datetime.now() - start_time
elapsed_seconds = elapsed_time.total_seconds()
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
# Print summary
print()
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
print(f"Total games played: {total_games}")
print(f"Target positions: {total_positions:,}")
print(f"Actual positions saved: {positions_count:,}")
print(f"Workers: {num_workers}")
print(f"Games per worker: {games_per_worker:,}")
print(f"Samples per game: {samples_per_game}")
print(f"Move range: {min_move}-{max_move}")
print(f"Saved positions: {positions_count}")
print(f"Filtered (game over): {filtered_game_over}")
print(f"Filtered (illegal): {filtered_illegal}")
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
print(f"(Includes checks, captures, all realistic positions)")
print(f"Elapsed time: {elapsed_time}")
print(f"Throughput: {positions_per_second:.0f} positions/second")
print("=" * 60)
print()
@@ -110,23 +146,26 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)")
parser.add_argument("--games", type=int, default=5000,
help="Number of games to play (default: 5000)")
parser.add_argument("--positions", type=int, default=3000000,
help="Target number of positions to generate (default: 3000000)")
parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)")
parser.add_argument("--workers", type=int, default=8,
help="Number of parallel worker processes (default: 8)")
args = parser.parse_args()
count = play_random_game_and_collect_positions(
output_file=args.output_file,
total_games=args.games,
total_positions=args.positions,
samples_per_game=args.samples_per_game,
min_move=args.min_move,
max_move=args.max_move
max_move=args.max_move,
num_workers=args.workers
)
sys.exit(0 if count > 0 else 1)