feat: add rich console interface for NNUE training pipeline and update requirements

This commit is contained in:
2026-04-07 23:59:17 +02:00
parent b25be99dcf
commit adc2de23bc
5 changed files with 301 additions and 243 deletions
+40 -9
View File
@@ -7,6 +7,36 @@ import sys
from pathlib import Path
from tqdm import tqdm
# Standard piece values for capture filtering
PIECE_VALUES = {
chess.PAWN: 1,
chess.KNIGHT: 3,
chess.BISHOP: 3,
chess.ROOK: 5,
chess.QUEEN: 9,
}
def has_winning_or_equal_capture(board):
"""Check if position has a capture where victim >= attacker (winning or equal trade).
Returns True only if there's at least one favorable capture.
Positions with only losing captures return False (are kept).
"""
for move in board.legal_moves:
if board.is_capture(move):
attacker_piece = board.piece_at(move.from_square)
victim_piece = board.piece_at(move.to_square)
if attacker_piece and victim_piece:
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
# If victim >= attacker, it's a winning or equal capture
if victim_value >= attacker_value:
return True
return False
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
"""Play random games and save positions after 8-20 random moves.
@@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
pbar.update(1)
continue
# Check if any captures are available (if filtering enabled)
# Check if there are winning or equal captures (if filtering enabled)
if filter_captures:
has_captures = any(board.is_capture(move) for move in board.legal_moves)
if has_captures:
if has_winning_or_equal_capture(board):
filtered_captures += 1
pbar.update(1)
continue
@@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
total_filtered = filtered_check + filtered_captures + filtered_game_over
print(f"Total games: {total_games}")
print(f"Saved positions: {positions_count}")
print(f"Filtered (check): {filtered_check}")
print(f"Filtered (captures): {filtered_captures}")
print(f"Filtered (in check): {filtered_check}")
print(f"Filtered (winning+ cap): {filtered_captures}")
print(f"Filtered (game over): {filtered_game_over}")
print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
print(f"Total filtered: {total_filtered}")
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
print(f"(Keeps positions with only losing/bad captures)")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
print("This might happen if:")
print(" - The filter criteria are too strict (captures, checks)")
print(" - Try using: --no-filter-captures to accept positions with captures")
print(" - Most positions have checks or game-over states")
print(" - Try using: --no-filter-captures to accept positions with winning captures")
return 0
return positions_count
@@ -97,7 +128,7 @@ if __name__ == "__main__":
parser.add_argument("--games", type=int, default=5000,
help="Number of games to play (default: 500000)")
parser.add_argument("--no-filter-captures", action="store_true",
help="Include positions with available captures (increases output)")
help="Include positions with winning/equal captures (increases output)")
args = parser.parse_args()