feat: add rich console interface for NNUE training pipeline and update requirements
This commit is contained in:
@@ -7,6 +7,36 @@ import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# Standard piece values for capture filtering
|
||||
PIECE_VALUES = {
|
||||
chess.PAWN: 1,
|
||||
chess.KNIGHT: 3,
|
||||
chess.BISHOP: 3,
|
||||
chess.ROOK: 5,
|
||||
chess.QUEEN: 9,
|
||||
}
|
||||
|
||||
def has_winning_or_equal_capture(board):
|
||||
"""Check if position has a capture where victim >= attacker (winning or equal trade).
|
||||
|
||||
Returns True only if there's at least one favorable capture.
|
||||
Positions with only losing captures return False (are kept).
|
||||
"""
|
||||
for move in board.legal_moves:
|
||||
if board.is_capture(move):
|
||||
attacker_piece = board.piece_at(move.from_square)
|
||||
victim_piece = board.piece_at(move.to_square)
|
||||
|
||||
if attacker_piece and victim_piece:
|
||||
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
|
||||
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
|
||||
|
||||
# If victim >= attacker, it's a winning or equal capture
|
||||
if victim_value >= attacker_value:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
||||
"""Play random games and save positions after 8-20 random moves.
|
||||
|
||||
@@ -49,10 +79,9 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
# Check if any captures are available (if filtering enabled)
|
||||
# Check if there are winning or equal captures (if filtering enabled)
|
||||
if filter_captures:
|
||||
has_captures = any(board.is_capture(move) for move in board.legal_moves)
|
||||
if has_captures:
|
||||
if has_winning_or_equal_capture(board):
|
||||
filtered_captures += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
@@ -69,21 +98,23 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
||||
print("=" * 60)
|
||||
print("POSITION GENERATION SUMMARY")
|
||||
print("=" * 60)
|
||||
total_filtered = filtered_check + filtered_captures + filtered_game_over
|
||||
print(f"Total games: {total_games}")
|
||||
print(f"Saved positions: {positions_count}")
|
||||
print(f"Filtered (check): {filtered_check}")
|
||||
print(f"Filtered (captures): {filtered_captures}")
|
||||
print(f"Filtered (in check): {filtered_check}")
|
||||
print(f"Filtered (winning+ cap): {filtered_captures}")
|
||||
print(f"Filtered (game over): {filtered_game_over}")
|
||||
print(f"Total filtered: {filtered_check + filtered_captures + filtered_game_over}")
|
||||
print(f"Total filtered: {total_filtered}")
|
||||
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
||||
print(f"(Keeps positions with only losing/bad captures)")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if positions_count == 0:
|
||||
print("WARNING: No valid positions were generated!")
|
||||
print("This might happen if:")
|
||||
print(" - The filter criteria are too strict (captures, checks)")
|
||||
print(" - Try using: --no-filter-captures to accept positions with captures")
|
||||
print(" - Most positions have checks or game-over states")
|
||||
print(" - Try using: --no-filter-captures to accept positions with winning captures")
|
||||
return 0
|
||||
|
||||
return positions_count
|
||||
@@ -97,7 +128,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--games", type=int, default=5000,
|
||||
help="Number of games to play (default: 500000)")
|
||||
parser.add_argument("--no-filter-captures", action="store_true",
|
||||
help="Include positions with available captures (increases output)")
|
||||
help="Include positions with winning/equal captures (increases output)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user