feat: enhance training and evaluation processes with new parameters and normalization options
This commit is contained in:
@@ -129,11 +129,17 @@ def train_interactive():
|
|||||||
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||||
positions_file = None
|
positions_file = None
|
||||||
num_games = 500000
|
num_games = 500000
|
||||||
|
samples_per_game = 1
|
||||||
|
min_move = 1
|
||||||
|
max_move = 50
|
||||||
|
|
||||||
if use_existing:
|
if use_existing:
|
||||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||||
else:
|
else:
|
||||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
|
||||||
|
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
|
||||||
|
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||||
|
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||||
|
|
||||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||||
labels_file = None
|
labels_file = None
|
||||||
@@ -147,6 +153,9 @@ def train_interactive():
|
|||||||
# Training parameters
|
# Training parameters
|
||||||
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||||
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||||
|
early_stopping = None
|
||||||
|
if Confirm.ask("Enable early stopping?", default=False):
|
||||||
|
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||||
|
|
||||||
# Confirm and start
|
# Confirm and start
|
||||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
@@ -154,9 +163,18 @@ def train_interactive():
|
|||||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||||
else:
|
else:
|
||||||
console.print(" Checkpoint: None (training from scratch)")
|
console.print(" Checkpoint: None (training from scratch)")
|
||||||
console.print(f" Games: {num_games:,}")
|
if not use_existing:
|
||||||
|
console.print(f" Games: {num_games:,}")
|
||||||
|
console.print(f" Samples per game: {samples_per_game}")
|
||||||
|
console.print(f" Move range: {min_move}-{max_move}")
|
||||||
|
else:
|
||||||
|
console.print(f" Positions file: {positions_file}")
|
||||||
console.print(f" Epochs: {epochs}")
|
console.print(f" Epochs: {epochs}")
|
||||||
console.print(f" Batch size: {batch_size}")
|
console.print(f" Batch size: {batch_size}")
|
||||||
|
if early_stopping:
|
||||||
|
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||||
|
else:
|
||||||
|
console.print(f" Early stopping: No")
|
||||||
console.print(f" Stockfish: {stockfish_path}")
|
console.print(f" Stockfish: {stockfish_path}")
|
||||||
|
|
||||||
if not Confirm.ask("\nStart training?", default=True):
|
if not Confirm.ask("\nStart training?", default=True):
|
||||||
@@ -174,7 +192,9 @@ def train_interactive():
|
|||||||
count = play_random_game_and_collect_positions(
|
count = play_random_game_and_collect_positions(
|
||||||
str(data_dir / "positions.txt"),
|
str(data_dir / "positions.txt"),
|
||||||
total_games=num_games,
|
total_games=num_games,
|
||||||
filter_captures=True
|
samples_per_game=samples_per_game,
|
||||||
|
min_move=min_move,
|
||||||
|
max_move=max_move
|
||||||
)
|
)
|
||||||
if count == 0:
|
if count == 0:
|
||||||
console.print("[red]✗ No valid positions generated[/red]")
|
console.print("[red]✗ No valid positions generated[/red]")
|
||||||
@@ -219,7 +239,8 @@ def train_interactive():
|
|||||||
epochs=epochs,
|
epochs=epochs,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
use_versioning=True
|
use_versioning=True,
|
||||||
|
early_stopping_patience=early_stopping
|
||||||
)
|
)
|
||||||
console.print("[green]✓ Training complete[/green]")
|
console.print("[green]✓ Training complete[/green]")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Generate 500,000 random chess positions for NNUE training."""
|
"""Generate random chess positions for NNUE training with minimal filtering."""
|
||||||
|
|
||||||
import chess
|
import chess
|
||||||
import random
|
import random
|
||||||
@@ -7,89 +7,78 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Standard piece values for capture filtering
|
def play_random_game_and_collect_positions(
|
||||||
PIECE_VALUES = {
|
output_file,
|
||||||
chess.PAWN: 1,
|
total_games=500000,
|
||||||
chess.KNIGHT: 3,
|
samples_per_game=1,
|
||||||
chess.BISHOP: 3,
|
min_move=1,
|
||||||
chess.ROOK: 5,
|
max_move=50
|
||||||
chess.QUEEN: 9,
|
):
|
||||||
}
|
"""Play random games and sample multiple positions from each.
|
||||||
|
|
||||||
def has_winning_or_equal_capture(board):
|
Args:
|
||||||
"""Check if position has a capture where victim >= attacker (winning or equal trade).
|
output_file: Output file for positions
|
||||||
|
total_games: Number of games to play
|
||||||
Returns True only if there's at least one favorable capture.
|
samples_per_game: Number of positions to sample per game (1-N)
|
||||||
Positions with only losing captures return False (are kept).
|
min_move: Minimum move number to start sampling from
|
||||||
"""
|
max_move: Maximum move number for sampling
|
||||||
for move in board.legal_moves:
|
|
||||||
if board.is_capture(move):
|
|
||||||
attacker_piece = board.piece_at(move.from_square)
|
|
||||||
victim_piece = board.piece_at(move.to_square)
|
|
||||||
|
|
||||||
if attacker_piece and victim_piece:
|
|
||||||
attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
|
|
||||||
victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0)
|
|
||||||
|
|
||||||
# If victim >= attacker, it's a winning or equal capture
|
|
||||||
if victim_value >= attacker_value:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
|
||||||
"""Play random games and save positions after 8-20 random moves.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of valid positions saved
|
Number of valid positions saved
|
||||||
"""
|
"""
|
||||||
positions_count = 0
|
positions_count = 0
|
||||||
filtered_check = 0
|
|
||||||
filtered_captures = 0
|
|
||||||
filtered_game_over = 0
|
filtered_game_over = 0
|
||||||
|
filtered_illegal = 0
|
||||||
|
|
||||||
with open(output_file, 'w') as f:
|
with open(output_file, 'w') as f:
|
||||||
with tqdm(total=total_games, desc="Generating positions") as pbar:
|
with tqdm(total=total_games, desc="Generating positions") as pbar:
|
||||||
for game_num in range(total_games):
|
for game_num in range(total_games):
|
||||||
board = chess.Board()
|
board = chess.Board()
|
||||||
|
move_history = []
|
||||||
|
|
||||||
# Play 8-20 random opening moves
|
# Play a complete random game
|
||||||
num_moves = random.randint(8, 20)
|
while not board.is_game_over() and len(move_history) < 200:
|
||||||
|
|
||||||
for move_num in range(num_moves):
|
|
||||||
if board.is_game_over():
|
|
||||||
break
|
|
||||||
|
|
||||||
legal_moves = list(board.legal_moves)
|
legal_moves = list(board.legal_moves)
|
||||||
if not legal_moves:
|
if not legal_moves:
|
||||||
break
|
break
|
||||||
|
|
||||||
move = random.choice(legal_moves)
|
move = random.choice(legal_moves)
|
||||||
board.push(move)
|
board.push(move)
|
||||||
|
move_history.append(board.copy())
|
||||||
|
|
||||||
# Skip if game over
|
# Determine the range of moves to sample from
|
||||||
if board.is_game_over():
|
game_length = len(move_history)
|
||||||
filtered_game_over += 1
|
valid_start = max(min_move, 0)
|
||||||
|
valid_end = min(max_move, game_length)
|
||||||
|
|
||||||
|
if valid_start >= valid_end:
|
||||||
|
# Game too short
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip if in check
|
# Randomly sample positions from this game
|
||||||
if board.is_check():
|
sample_count = min(samples_per_game, valid_end - valid_start)
|
||||||
filtered_check += 1
|
if sample_count > 0:
|
||||||
pbar.update(1)
|
sample_indices = random.sample(
|
||||||
continue
|
range(valid_start, valid_end),
|
||||||
|
k=sample_count
|
||||||
|
)
|
||||||
|
|
||||||
# Check if there are winning or equal captures (if filtering enabled)
|
for idx in sample_indices:
|
||||||
if filter_captures:
|
sampled_board = move_history[idx]
|
||||||
if has_winning_or_equal_capture(board):
|
|
||||||
filtered_captures += 1
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Save valid position
|
# Only filter truly invalid or terminal positions
|
||||||
fen = board.fen()
|
if not sampled_board.is_valid():
|
||||||
f.write(fen + '\n')
|
filtered_illegal += 1
|
||||||
positions_count += 1
|
continue
|
||||||
|
|
||||||
|
if sampled_board.is_game_over():
|
||||||
|
filtered_game_over += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save position (include check, captures, all positions)
|
||||||
|
fen = sampled_board.fen()
|
||||||
|
f.write(fen + '\n')
|
||||||
|
positions_count += 1
|
||||||
|
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
@@ -98,23 +87,19 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("POSITION GENERATION SUMMARY")
|
print("POSITION GENERATION SUMMARY")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
total_filtered = filtered_check + filtered_captures + filtered_game_over
|
print(f"Total games played: {total_games}")
|
||||||
print(f"Total games: {total_games}")
|
print(f"Samples per game: {samples_per_game}")
|
||||||
|
print(f"Move range: {min_move}-{max_move}")
|
||||||
print(f"Saved positions: {positions_count}")
|
print(f"Saved positions: {positions_count}")
|
||||||
print(f"Filtered (in check): {filtered_check}")
|
|
||||||
print(f"Filtered (winning+ cap): {filtered_captures}")
|
|
||||||
print(f"Filtered (game over): {filtered_game_over}")
|
print(f"Filtered (game over): {filtered_game_over}")
|
||||||
print(f"Total filtered: {total_filtered}")
|
print(f"Filtered (illegal): {filtered_illegal}")
|
||||||
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
||||||
print(f"(Keeps positions with only losing/bad captures)")
|
print(f"(Includes checks, captures, all realistic positions)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
if positions_count == 0:
|
if positions_count == 0:
|
||||||
print("WARNING: No valid positions were generated!")
|
print("WARNING: No valid positions were generated!")
|
||||||
print("This might happen if:")
|
|
||||||
print(" - Most positions have checks or game-over states")
|
|
||||||
print(" - Try using: --no-filter-captures to accept positions with winning captures")
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
return positions_count
|
return positions_count
|
||||||
@@ -126,16 +111,22 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
||||||
help="Output file for positions (default: positions.txt)")
|
help="Output file for positions (default: positions.txt)")
|
||||||
parser.add_argument("--games", type=int, default=5000,
|
parser.add_argument("--games", type=int, default=5000,
|
||||||
help="Number of games to play (default: 500000)")
|
help="Number of games to play (default: 5000)")
|
||||||
parser.add_argument("--no-filter-captures", action="store_true",
|
parser.add_argument("--samples-per-game", type=int, default=1,
|
||||||
help="Include positions with winning/equal captures (increases output)")
|
help="Number of positions to sample per game (default: 1)")
|
||||||
|
parser.add_argument("--min-move", type=int, default=1,
|
||||||
|
help="Minimum move number to sample from (default: 1)")
|
||||||
|
parser.add_argument("--max-move", type=int, default=50,
|
||||||
|
help="Maximum move number to sample from (default: 50)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
count = play_random_game_and_collect_positions(
|
count = play_random_game_and_collect_positions(
|
||||||
output_file=args.output_file,
|
output_file=args.output_file,
|
||||||
total_games=args.games,
|
total_games=args.games,
|
||||||
filter_captures=not args.no_filter_captures
|
samples_per_game=args.samples_per_game,
|
||||||
|
min_move=args.min_move,
|
||||||
|
max_move=args.max_move
|
||||||
)
|
)
|
||||||
|
|
||||||
sys.exit(0 if count > 0 else 1)
|
sys.exit(0 if count > 0 else 1)
|
||||||
|
|||||||
@@ -1,14 +1,33 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Label positions with Stockfish evaluations."""
|
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import chess.engine
|
import chess.engine
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False):
|
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    """Normalize a centipawn evaluation to a bounded range.

    Args:
        cp_value: Centipawn evaluation from Stockfish.
        method: 'tanh' (default) or 'sigmoid'. Any other value selects the
            unbounded fallback, which only converts centipawns to pawns.
        scale: Divisor applied to ``cp_value`` before squashing
            (default 300.0).

    Returns:
        A value in [-1, 1] for 'tanh', in [0, 1] for 'sigmoid', or
        ``cp_value / 100.0`` for any other method.
    """
    if method == 'tanh':
        return np.tanh(cp_value / scale)
    if method == 'sigmoid':
        # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). The tanh form saturates
        # cleanly at 0 and 1, whereas 1 / (1 + exp(-x)) overflows inside
        # exp() for strongly negative x and emits a floating-point warning.
        return 0.5 * (1.0 + np.tanh(cp_value / (2.0 * scale)))
    # Unknown method: no squashing, just express the score in pawns.
    return cp_value / 100.0
|
||||||
|
|
||||||
|
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True):
|
||||||
"""Read positions and label them with Stockfish evaluations.
|
"""Read positions and label them with Stockfish evaluations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -18,6 +37,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
|||||||
batch_size: Batch size (not used, kept for compatibility)
|
batch_size: Batch size (not used, kept for compatibility)
|
||||||
depth: Stockfish depth
|
depth: Stockfish depth
|
||||||
verbose: Print detailed error messages
|
verbose: Print detailed error messages
|
||||||
|
normalize: If True, normalize evals using tanh
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check if stockfish exists
|
# Check if stockfish exists
|
||||||
@@ -75,6 +95,8 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
|||||||
skipped_invalid = 0
|
skipped_invalid = 0
|
||||||
skipped_duplicate = 0
|
skipped_duplicate = 0
|
||||||
errors = 0
|
errors = 0
|
||||||
|
raw_evals = []
|
||||||
|
normalized_evals = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(positions_file, 'r') as f:
|
with open(positions_file, 'r') as f:
|
||||||
@@ -123,8 +145,15 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
|||||||
# Clamp to [-2000, 2000]
|
# Clamp to [-2000, 2000]
|
||||||
eval_cp = max(-2000, min(2000, eval_cp))
|
eval_cp = max(-2000, min(2000, eval_cp))
|
||||||
|
|
||||||
# Save evaluation
|
# Normalize evaluation
|
||||||
data = {"fen": fen, "eval": eval_cp}
|
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
|
||||||
|
|
||||||
|
# Track statistics
|
||||||
|
raw_evals.append(eval_cp)
|
||||||
|
normalized_evals.append(eval_normalized)
|
||||||
|
|
||||||
|
# Save evaluation (normalized if requested)
|
||||||
|
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||||
out.write(json.dumps(data) + '\n')
|
out.write(json.dumps(data) + '\n')
|
||||||
out.flush() # Force write to disk
|
out.flush() # Force write to disk
|
||||||
evaluated += 1
|
evaluated += 1
|
||||||
@@ -142,7 +171,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
|||||||
finally:
|
finally:
|
||||||
engine.quit()
|
engine.quit()
|
||||||
|
|
||||||
# Print summary
|
# Print summary and analysis
|
||||||
print()
|
print()
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("LABELING SUMMARY")
|
print("LABELING SUMMARY")
|
||||||
@@ -164,6 +193,51 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
|||||||
print(" 4. Stockfish path is correct")
|
print(" 4. Stockfish path is correct")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Print distribution analysis
|
||||||
|
if raw_evals:
|
||||||
|
raw_evals_arr = np.array(raw_evals)
|
||||||
|
norm_evals_arr = np.array(normalized_evals)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("EVALUATION DISTRIBUTION ANALYSIS")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
print("Raw Evaluations (centipawns):")
|
||||||
|
print(f" Min: {raw_evals_arr.min():.1f}")
|
||||||
|
print(f" Max: {raw_evals_arr.max():.1f}")
|
||||||
|
print(f" Mean: {raw_evals_arr.mean():.1f}")
|
||||||
|
print(f" Median: {np.median(raw_evals_arr):.1f}")
|
||||||
|
print(f" Std: {raw_evals_arr.std():.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Normalized Evaluations (tanh):")
|
||||||
|
print(f" Min: {norm_evals_arr.min():.4f}")
|
||||||
|
print(f" Max: {norm_evals_arr.max():.4f}")
|
||||||
|
print(f" Mean: {norm_evals_arr.mean():.4f}")
|
||||||
|
print(f" Median: {np.median(norm_evals_arr):.4f}")
|
||||||
|
print(f" Std: {norm_evals_arr.std():.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Distribution buckets
|
||||||
|
print("Raw Evaluation Buckets (counts):")
|
||||||
|
buckets = [
|
||||||
|
(-float('inf'), -500, "< -5.00"),
|
||||||
|
(-500, -300, "[-5.00, -3.00)"),
|
||||||
|
(-300, -100, "[-3.00, -1.00)"),
|
||||||
|
(-100, 0, "[-1.00, 0.00)"),
|
||||||
|
(0, 100, "[0.00, 1.00)"),
|
||||||
|
(100, 300, "[1.00, 3.00)"),
|
||||||
|
(300, 500, "[3.00, 5.00)"),
|
||||||
|
(500, float('inf'), "> 5.00"),
|
||||||
|
]
|
||||||
|
for low, high, label in buckets:
|
||||||
|
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
|
||||||
|
pct = 100.0 * count / len(raw_evals_arr)
|
||||||
|
print(f" {label}: {count:6d} ({pct:5.1f}%)")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
print(f"✓ Labeling complete. Output saved to {output_file}")
|
print(f"✓ Labeling complete. Output saved to {output_file}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -179,6 +253,8 @@ if __name__ == "__main__":
|
|||||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||||
parser.add_argument("--depth", type=int, default=12,
|
parser.add_argument("--depth", type=int, default=12,
|
||||||
help="Stockfish depth (default: 12)")
|
help="Stockfish depth (default: 12)")
|
||||||
|
parser.add_argument("--no-normalize", action="store_true",
|
||||||
|
help="Disable evaluation normalization (keep raw centipawns)")
|
||||||
parser.add_argument("--verbose", action="store_true",
|
parser.add_argument("--verbose", action="store_true",
|
||||||
help="Print detailed error messages")
|
help="Print detailed error messages")
|
||||||
|
|
||||||
@@ -192,6 +268,7 @@ if __name__ == "__main__":
|
|||||||
output_file=args.output_file,
|
output_file=args.output_file,
|
||||||
stockfish_path=stockfish_path,
|
stockfish_path=stockfish_path,
|
||||||
depth=args.depth,
|
depth=args.depth,
|
||||||
|
normalize=not args.no_normalize,
|
||||||
verbose=args.verbose
|
verbose=args.verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
|||||||
import chess
|
import chess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import re
|
import re
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
class NNUEDataset(Dataset):
|
class NNUEDataset(Dataset):
|
||||||
"""Dataset of chess positions with evaluations."""
|
"""Dataset of chess positions with evaluations."""
|
||||||
@@ -19,15 +20,28 @@ class NNUEDataset(Dataset):
|
|||||||
def __init__(self, data_file):
|
def __init__(self, data_file):
|
||||||
self.positions = []
|
self.positions = []
|
||||||
self.evals = []
|
self.evals = []
|
||||||
|
self.evals_raw = []
|
||||||
|
self.is_normalized = None
|
||||||
|
|
||||||
with open(data_file, 'r') as f:
|
with open(data_file, 'r') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
try:
|
try:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
fen = data['fen']
|
fen = data['fen']
|
||||||
eval_cp = data['eval']
|
eval_val = data['eval']
|
||||||
self.positions.append(fen)
|
self.positions.append(fen)
|
||||||
self.evals.append(eval_cp)
|
self.evals.append(eval_val)
|
||||||
|
|
||||||
|
# Check if normalized or raw
|
||||||
|
if self.is_normalized is None:
|
||||||
|
# If eval is in range [-1, 1], assume normalized
|
||||||
|
self.is_normalized = abs(eval_val) <= 1.0
|
||||||
|
|
||||||
|
# Store raw if available
|
||||||
|
if 'eval_raw' in data:
|
||||||
|
self.evals_raw.append(data['eval_raw'])
|
||||||
|
else:
|
||||||
|
self.evals_raw.append(eval_val)
|
||||||
except (json.JSONDecodeError, KeyError):
|
except (json.JSONDecodeError, KeyError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -36,9 +50,15 @@ class NNUEDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
fen = self.positions[idx]
|
fen = self.positions[idx]
|
||||||
eval_cp = self.evals[idx]
|
eval_val = self.evals[idx]
|
||||||
features = fen_to_features(fen)
|
features = fen_to_features(fen)
|
||||||
target = torch.sigmoid(torch.tensor(eval_cp / 400.0, dtype=torch.float32))
|
|
||||||
|
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
|
||||||
|
if self.is_normalized:
|
||||||
|
target = torch.tensor(eval_val, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
|
||||||
|
|
||||||
return features, target
|
return features, target
|
||||||
|
|
||||||
def fen_to_features(fen):
|
def fen_to_features(fen):
|
||||||
@@ -122,7 +142,7 @@ def save_metadata(weights_file, metadata):
|
|||||||
|
|
||||||
return metadata_file
|
return metadata_file
|
||||||
|
|
||||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
|
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
|
||||||
"""Train the NNUE model.
|
"""Train the NNUE model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -134,12 +154,28 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
|||||||
checkpoint: Optional path to checkpoint file to resume from
|
checkpoint: Optional path to checkpoint file to resume from
|
||||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||||
|
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
dataset = NNUEDataset(data_file)
|
dataset = NNUEDataset(data_file)
|
||||||
num_positions = len(dataset)
|
num_positions = len(dataset)
|
||||||
print(f"Dataset size: {num_positions}")
|
print(f"Dataset size: {num_positions}")
|
||||||
|
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||||
|
|
||||||
|
# Print dataset diagnostics
|
||||||
|
evals_array = np.array(dataset.evals)
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("TRAINING DATASET DIAGNOSTICS")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Min evaluation: {evals_array.min():.4f}")
|
||||||
|
print(f"Max evaluation: {evals_array.max():.4f}")
|
||||||
|
print(f"Mean evaluation: {evals_array.mean():.4f}")
|
||||||
|
print(f"Median evaluation: {np.median(evals_array):.4f}")
|
||||||
|
print(f"Std deviation: {evals_array.std():.4f}")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
# Split 90% train, 10% validation
|
# Split 90% train, 10% validation
|
||||||
train_size = int(0.9 * len(dataset))
|
train_size = int(0.9 * len(dataset))
|
||||||
@@ -179,8 +215,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
|||||||
|
|
||||||
best_val_loss = float('inf')
|
best_val_loss = float('inf')
|
||||||
best_model_state = None
|
best_model_state = None
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
|
||||||
print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
|
print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
|
||||||
|
if early_stopping_patience:
|
||||||
|
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
training_start_time = datetime.now()
|
training_start_time = datetime.now()
|
||||||
@@ -228,6 +267,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
|||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
best_val_loss = val_loss
|
best_val_loss = val_loss
|
||||||
best_model_state = model.state_dict().copy()
|
best_model_state = model.state_dict().copy()
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
else:
|
||||||
|
epochs_without_improvement += 1
|
||||||
|
|
||||||
|
# Early stopping check
|
||||||
|
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||||
|
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||||
|
break
|
||||||
|
|
||||||
# Save best model
|
# Save best model
|
||||||
if best_model_state is not None:
|
if best_model_state is not None:
|
||||||
@@ -286,6 +333,8 @@ if __name__ == "__main__":
|
|||||||
help="Batch size (default: 4096)")
|
help="Batch size (default: 4096)")
|
||||||
parser.add_argument("--lr", type=float, default=1e-3,
|
parser.add_argument("--lr", type=float, default=1e-3,
|
||||||
help="Learning rate (default: 1e-3)")
|
help="Learning rate (default: 1e-3)")
|
||||||
|
parser.add_argument("--early-stopping", type=int, default=None,
|
||||||
|
help="Stop if val loss doesn't improve for N epochs (optional)")
|
||||||
parser.add_argument("--stockfish-depth", type=int, default=12,
|
parser.add_argument("--stockfish-depth", type=int, default=12,
|
||||||
help="Stockfish depth used for evaluations (for metadata, default: 12)")
|
help="Stockfish depth used for evaluations (for metadata, default: 12)")
|
||||||
parser.add_argument("--no-versioning", action="store_true",
|
parser.add_argument("--no-versioning", action="store_true",
|
||||||
@@ -301,5 +350,6 @@ if __name__ == "__main__":
|
|||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
checkpoint=args.checkpoint,
|
checkpoint=args.checkpoint,
|
||||||
stockfish_depth=args.stockfish_depth,
|
stockfish_depth=args.stockfish_depth,
|
||||||
use_versioning=not args.no_versioning
|
use_versioning=not args.no_versioning,
|
||||||
|
early_stopping_patience=args.early_stopping
|
||||||
)
|
)
|
||||||
|
|||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 3,
|
||||||
|
"date": "2026-04-08T09:43:28.000579",
|
||||||
|
"num_positions": 71610,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 20,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.006398905136849695,
|
||||||
|
"device": "cpu",
|
||||||
|
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
@@ -52,7 +52,7 @@ object FenParser extends GameContextImport:
|
|||||||
)
|
)
|
||||||
else None
|
else None
|
||||||
|
|
||||||
/** Parse en passant target square ("-" for none, or algebraic like "e3"). */
|
/** Parse en passant target square ("-" for none, or algebraic like "e3"). */
|
||||||
private def parseEnPassant(s: String): Option[Option[Square]] =
|
private def parseEnPassant(s: String): Option[Option[Square]] =
|
||||||
if s == "-" then Some(None)
|
if s == "-" then Some(None)
|
||||||
else Square.fromAlgebraic(s).map(Some(_))
|
else Square.fromAlgebraic(s).map(Some(_))
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ object Main:
|
|||||||
|
|
||||||
val book = PolyglotBook("../../modules/bot/codekiddy.bin")
|
val book = PolyglotBook("../../modules/bot/codekiddy.bin")
|
||||||
|
|
||||||
engine.setOpponentBot(NNUEBot(BotDifficulty.Easy, book = Some(book)), Black);
|
engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black);
|
||||||
|
|
||||||
// Launch ScalaFX GUI in separate thread
|
// Launch ScalaFX GUI in separate thread
|
||||||
ChessGUILauncher.launch(engine)
|
ChessGUILauncher.launch(engine)
|
||||||
|
|||||||
Reference in New Issue
Block a user