feat: enhance training and evaluation processes with new parameters and normalization options

This commit is contained in:
2026-04-08 10:21:38 +02:00
committed by Janis
parent 34945e4fb8
commit a5560285fd
8 changed files with 242 additions and 90 deletions
+64 -73
View File
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Generate 500,000 random chess positions for NNUE training."""
"""Generate random chess positions for NNUE training with minimal filtering."""
import chess
import random
@@ -7,89 +7,78 @@ import sys
from pathlib import Path
from tqdm import tqdm
# Standard piece values for capture filtering
PIECE_VALUES = {
chess.PAWN: 1,
chess.KNIGHT: 3,
chess.BISHOP: 3,
chess.ROOK: 5,
chess.QUEEN: 9,
}
def play_random_game_and_collect_positions(
output_file,
total_games=500000,
samples_per_game=1,
min_move=1,
max_move=50
):
"""Play random games and sample multiple positions from each.
def has_winning_or_equal_capture(board):
"""Check if position has a capture where victim >= attacker (winning or equal trade).
Returns True only if there's at least one favorable capture.
Positions with only losing captures return False (are kept).
"""
for move in board.legal_moves:
if board.is_capture(move):
attacker_piece = board.piece_at(move.from_square)
victim_piece = board.piece_at(move.to_square)
if attacker_piece and victim_piece:
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
# If victim >= attacker, it's a winning or equal capture
if victim_value >= attacker_value:
return True
return False
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
"""Play random games and save positions after 8-20 random moves.
Args:
output_file: Output file for positions
total_games: Number of games to play
samples_per_game: Number of positions to sample per game (1-N)
min_move: Minimum move number to start sampling from
max_move: Maximum move number for sampling
Returns:
Number of valid positions saved
"""
positions_count = 0
filtered_check = 0
filtered_captures = 0
filtered_game_over = 0
filtered_illegal = 0
with open(output_file, 'w') as f:
with tqdm(total=total_games, desc="Generating positions") as pbar:
for game_num in range(total_games):
board = chess.Board()
move_history = []
# Play 8-20 random opening moves
num_moves = random.randint(8, 20)
for move_num in range(num_moves):
if board.is_game_over():
break
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Skip if game over
if board.is_game_over():
filtered_game_over += 1
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
# Game too short
pbar.update(1)
continue
# Skip if in check
if board.is_check():
filtered_check += 1
pbar.update(1)
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
# Check if there are winning or equal captures (if filtering enabled)
if filter_captures:
if has_winning_or_equal_capture(board):
filtered_captures += 1
pbar.update(1)
continue
for idx in sample_indices:
sampled_board = move_history[idx]
# Save valid position
fen = board.fen()
f.write(fen + '\n')
positions_count += 1
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid():
filtered_illegal += 1
continue
if sampled_board.is_game_over():
filtered_game_over += 1
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
f.write(fen + '\n')
positions_count += 1
pbar.update(1)
@@ -98,23 +87,19 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
total_filtered = filtered_check + filtered_captures + filtered_game_over
print(f"Total games: {total_games}")
print(f"Total games played: {total_games}")
print(f"Samples per game: {samples_per_game}")
print(f"Move range: {min_move}-{max_move}")
print(f"Saved positions: {positions_count}")
print(f"Filtered (in check): {filtered_check}")
print(f"Filtered (winning+ cap): {filtered_captures}")
print(f"Filtered (game over): {filtered_game_over}")
print(f"Total filtered: {total_filtered}")
print(f"Filtered (illegal): {filtered_illegal}")
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
print(f"(Keeps positions with only losing/bad captures)")
print(f"(Includes checks, captures, all realistic positions)")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
print("This might happen if:")
print(" - Most positions have checks or game-over states")
print(" - Try using: --no-filter-captures to accept positions with winning captures")
return 0
return positions_count
@@ -126,16 +111,22 @@ if __name__ == "__main__":
parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)")
parser.add_argument("--games", type=int, default=5000,
help="Number of games to play (default: 500000)")
parser.add_argument("--no-filter-captures", action="store_true",
help="Include positions with winning/equal captures (increases output)")
help="Number of games to play (default: 5000)")
parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)")
args = parser.parse_args()
count = play_random_game_and_collect_positions(
output_file=args.output_file,
total_games=args.games,
filter_captures=not args.no_filter_captures
samples_per_game=args.samples_per_game,
min_move=args.min_move,
max_move=args.max_move
)
sys.exit(0 if count > 0 else 1)