feat: enhance training and evaluation processes with new parameters and normalization options

This commit is contained in:
2026-04-08 10:21:38 +02:00
committed by Janis
parent 34945e4fb8
commit a5560285fd
8 changed files with 242 additions and 90 deletions
+25 -4
View File
@@ -129,11 +129,17 @@ def train_interactive():
use_existing = Confirm.ask("Use existing positions file?", default=False)
positions_file = None
num_games = 500000
samples_per_game = 1
min_move = 1
max_move = 50
if use_existing:
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
else:
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
labels_file = None
@@ -147,6 +153,9 @@ def train_interactive():
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="20"))
batch_size = int(Prompt.ask("Batch size", default="4096"))
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
# Confirm and start
console.print("\n[bold]Configuration Summary:[/bold]")
@@ -154,9 +163,18 @@ def train_interactive():
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(" Checkpoint: None (training from scratch)")
console.print(f" Games: {num_games:,}")
if not use_existing:
console.print(f" Games: {num_games:,}")
console.print(f" Samples per game: {samples_per_game}")
console.print(f" Move range: {min_move}-{max_move}")
else:
console.print(f" Positions file: {positions_file}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
if early_stopping:
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
console.print(f" Early stopping: No")
console.print(f" Stockfish: {stockfish_path}")
if not Confirm.ask("\nStart training?", default=True):
@@ -174,7 +192,9 @@ def train_interactive():
count = play_random_game_and_collect_positions(
str(data_dir / "positions.txt"),
total_games=num_games,
filter_captures=True
samples_per_game=samples_per_game,
min_move=min_move,
max_move=max_move
)
if count == 0:
console.print("[red]✗ No valid positions generated[/red]")
@@ -219,7 +239,8 @@ def train_interactive():
epochs=epochs,
batch_size=batch_size,
checkpoint=checkpoint,
use_versioning=True
use_versioning=True,
early_stopping_patience=early_stopping
)
console.print("[green]✓ Training complete[/green]")
+64 -73
View File
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Generate 500,000 random chess positions for NNUE training."""
"""Generate random chess positions for NNUE training with minimal filtering."""
import chess
import random
@@ -7,89 +7,78 @@ import sys
from pathlib import Path
from tqdm import tqdm
# Standard piece values for capture filtering
# NOTE: the king is deliberately absent — lookups use PIECE_VALUES.get(..., 0),
# so any piece type not listed here is treated as value 0.
PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
}
def play_random_game_and_collect_positions(
output_file,
total_games=500000,
samples_per_game=1,
min_move=1,
max_move=50
):
"""Play random games and sample multiple positions from each.
def has_winning_or_equal_capture(board):
    """Check if position has a capture where victim >= attacker (winning or equal trade).

    Returns True only if there's at least one favorable capture.
    Positions with only losing captures return False (are kept).
    """

    def _favorable(move):
        # A capture is favorable when the captured piece is worth at least
        # as much as the capturing piece (unknown piece types count as 0).
        attacker = board.piece_at(move.from_square)
        victim = board.piece_at(move.to_square)
        if attacker is None or victim is None:
            # e.g. en passant: the victim pawn is not on move.to_square
            return False
        return (PIECE_VALUES.get(victim.piece_type, 0)
                >= PIECE_VALUES.get(attacker.piece_type, 0))

    return any(_favorable(m) for m in board.legal_moves if board.is_capture(m))
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
"""Play random games and save positions after 8-20 random moves.
Args:
output_file: Output file for positions
total_games: Number of games to play
samples_per_game: Number of positions to sample per game (1-N)
min_move: Minimum move number to start sampling from
max_move: Maximum move number for sampling
Returns:
Number of valid positions saved
"""
positions_count = 0
filtered_check = 0
filtered_captures = 0
filtered_game_over = 0
filtered_illegal = 0
with open(output_file, 'w') as f:
with tqdm(total=total_games, desc="Generating positions") as pbar:
for game_num in range(total_games):
board = chess.Board()
move_history = []
# Play 8-20 random opening moves
num_moves = random.randint(8, 20)
for move_num in range(num_moves):
if board.is_game_over():
break
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Skip if game over
if board.is_game_over():
filtered_game_over += 1
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
# Game too short
pbar.update(1)
continue
# Skip if in check
if board.is_check():
filtered_check += 1
pbar.update(1)
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
# Check if there are winning or equal captures (if filtering enabled)
if filter_captures:
if has_winning_or_equal_capture(board):
filtered_captures += 1
pbar.update(1)
continue
for idx in sample_indices:
sampled_board = move_history[idx]
# Save valid position
fen = board.fen()
f.write(fen + '\n')
positions_count += 1
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid():
filtered_illegal += 1
continue
if sampled_board.is_game_over():
filtered_game_over += 1
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
f.write(fen + '\n')
positions_count += 1
pbar.update(1)
@@ -98,23 +87,19 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
total_filtered = filtered_check + filtered_captures + filtered_game_over
print(f"Total games: {total_games}")
print(f"Total games played: {total_games}")
print(f"Samples per game: {samples_per_game}")
print(f"Move range: {min_move}-{max_move}")
print(f"Saved positions: {positions_count}")
print(f"Filtered (in check): {filtered_check}")
print(f"Filtered (winning+ cap): {filtered_captures}")
print(f"Filtered (game over): {filtered_game_over}")
print(f"Total filtered: {total_filtered}")
print(f"Filtered (illegal): {filtered_illegal}")
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
print(f"(Keeps positions with only losing/bad captures)")
print(f"(Includes checks, captures, all realistic positions)")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
print("This might happen if:")
print(" - Most positions have checks or game-over states")
print(" - Try using: --no-filter-captures to accept positions with winning captures")
return 0
return positions_count
@@ -126,16 +111,22 @@ if __name__ == "__main__":
parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)")
parser.add_argument("--games", type=int, default=5000,
help="Number of games to play (default: 500000)")
parser.add_argument("--no-filter-captures", action="store_true",
help="Include positions with winning/equal captures (increases output)")
help="Number of games to play (default: 5000)")
parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)")
args = parser.parse_args()
count = play_random_game_and_collect_positions(
output_file=args.output_file,
total_games=args.games,
filter_captures=not args.no_filter_captures
samples_per_game=args.samples_per_game,
min_move=args.min_move,
max_move=args.max_move
)
sys.exit(0 if count > 0 else 1)
+82 -5
View File
@@ -1,14 +1,33 @@
#!/usr/bin/env python3
"""Label positions with Stockfish evaluations."""
"""Label positions with Stockfish evaluations and analyze distribution."""
import json
import chess.engine
import sys
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False):
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    """Normalize centipawn evaluation to a bounded range.

    Args:
        cp_value: Centipawn evaluation from Stockfish
        method: 'tanh' (default) or 'sigmoid'; any other value falls back
            to plain pawn units (centipawns / 100, unbounded)
        scale: Scale factor (tanh: 300 is typical)

    Returns:
        Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
    """
    scaled = cp_value / scale
    if method == 'tanh':
        return np.tanh(scaled)
    if method == 'sigmoid':
        return 1.0 / (1.0 + np.exp(-scaled))
    # Unrecognized method: return the raw score expressed in pawns.
    return cp_value / 100.0
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True):
"""Read positions and label them with Stockfish evaluations.
Args:
@@ -18,6 +37,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
batch_size: Batch size (not used, kept for compatibility)
depth: Stockfish depth
verbose: Print detailed error messages
normalize: If True, normalize evals using tanh
"""
# Check if stockfish exists
@@ -75,6 +95,8 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
skipped_invalid = 0
skipped_duplicate = 0
errors = 0
raw_evals = []
normalized_evals = []
try:
with open(positions_file, 'r') as f:
@@ -123,8 +145,15 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
# Clamp to [-2000, 2000]
eval_cp = max(-2000, min(2000, eval_cp))
# Save evaluation
data = {"fen": fen, "eval": eval_cp}
# Normalize evaluation
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
# Track statistics
raw_evals.append(eval_cp)
normalized_evals.append(eval_normalized)
# Save evaluation (normalized if requested)
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
out.write(json.dumps(data) + '\n')
out.flush() # Force write to disk
evaluated += 1
@@ -142,7 +171,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
finally:
engine.quit()
# Print summary
# Print summary and analysis
print()
print("=" * 60)
print("LABELING SUMMARY")
@@ -164,6 +193,51 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
print(" 4. Stockfish path is correct")
return False
# Print distribution analysis
if raw_evals:
raw_evals_arr = np.array(raw_evals)
norm_evals_arr = np.array(normalized_evals)
print("=" * 60)
print("EVALUATION DISTRIBUTION ANALYSIS")
print("=" * 60)
print()
print("Raw Evaluations (centipawns):")
print(f" Min: {raw_evals_arr.min():.1f}")
print(f" Max: {raw_evals_arr.max():.1f}")
print(f" Mean: {raw_evals_arr.mean():.1f}")
print(f" Median: {np.median(raw_evals_arr):.1f}")
print(f" Std: {raw_evals_arr.std():.1f}")
print()
print("Normalized Evaluations (tanh):")
print(f" Min: {norm_evals_arr.min():.4f}")
print(f" Max: {norm_evals_arr.max():.4f}")
print(f" Mean: {norm_evals_arr.mean():.4f}")
print(f" Median: {np.median(norm_evals_arr):.4f}")
print(f" Std: {norm_evals_arr.std():.4f}")
print()
# Distribution buckets
print("Raw Evaluation Buckets (counts):")
buckets = [
(-float('inf'), -500, "< -5.00"),
(-500, -300, "[-5.00, -3.00)"),
(-300, -100, "[-3.00, -1.00)"),
(-100, 0, "[-1.00, 0.00)"),
(0, 100, "[0.00, 1.00)"),
(100, 300, "[1.00, 3.00)"),
(300, 500, "[3.00, 5.00)"),
(500, float('inf'), "> 5.00"),
]
for low, high, label in buckets:
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
pct = 100.0 * count / len(raw_evals_arr)
print(f" {label}: {count:6d} ({pct:5.1f}%)")
print("=" * 60)
print()
print(f"✓ Labeling complete. Output saved to {output_file}")
return True
@@ -179,6 +253,8 @@ if __name__ == "__main__":
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)")
parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)")
parser.add_argument("--verbose", action="store_true",
help="Print detailed error messages")
@@ -192,6 +268,7 @@ if __name__ == "__main__":
output_file=args.output_file,
stockfish_path=stockfish_path,
depth=args.depth,
normalize=not args.no_normalize,
verbose=args.verbose
)
+56 -6
View File
@@ -12,6 +12,7 @@ from tqdm import tqdm
import chess
from datetime import datetime
import re
import numpy as np
class NNUEDataset(Dataset):
"""Dataset of chess positions with evaluations."""
@@ -19,15 +20,28 @@ class NNUEDataset(Dataset):
def __init__(self, data_file):
self.positions = []
self.evals = []
self.evals_raw = []
self.is_normalized = None
with open(data_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data['fen']
eval_cp = data['eval']
eval_val = data['eval']
self.positions.append(fen)
self.evals.append(eval_cp)
self.evals.append(eval_val)
# Check if normalized or raw
if self.is_normalized is None:
# If eval is in range [-1, 1], assume normalized
self.is_normalized = abs(eval_val) <= 1.0
# Store raw if available
if 'eval_raw' in data:
self.evals_raw.append(data['eval_raw'])
else:
self.evals_raw.append(eval_val)
except (json.JSONDecodeError, KeyError):
pass
@@ -36,9 +50,15 @@ class NNUEDataset(Dataset):
def __getitem__(self, idx):
fen = self.positions[idx]
eval_cp = self.evals[idx]
eval_val = self.evals[idx]
features = fen_to_features(fen)
target = torch.sigmoid(torch.tensor(eval_cp / 400.0, dtype=torch.float32))
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
if self.is_normalized:
target = torch.tensor(eval_val, dtype=torch.float32)
else:
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
return features, target
def fen_to_features(fen):
@@ -122,7 +142,7 @@ def save_metadata(weights_file, metadata):
return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
"""Train the NNUE model.
Args:
@@ -134,12 +154,28 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
"""
print("Loading dataset...")
dataset = NNUEDataset(data_file)
num_positions = len(dataset)
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
# Print dataset diagnostics
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
print("TRAINING DATASET DIAGNOSTICS")
print("=" * 60)
print(f"Min evaluation: {evals_array.min():.4f}")
print(f"Max evaluation: {evals_array.max():.4f}")
print(f"Mean evaluation: {evals_array.mean():.4f}")
print(f"Median evaluation: {np.median(evals_array):.4f}")
print(f"Std deviation: {evals_array.std():.4f}")
print("=" * 60)
print()
# Split 90% train, 10% validation
train_size = int(0.9 * len(dataset))
@@ -179,8 +215,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
best_val_loss = float('inf')
best_model_state = None
epochs_without_improvement = 0
print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
@@ -228,6 +267,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = model.state_dict().copy()
epochs_without_improvement = 0
else:
epochs_without_improvement += 1
# Early stopping check
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break
# Save best model
if best_model_state is not None:
@@ -286,6 +333,8 @@ if __name__ == "__main__":
help="Batch size (default: 4096)")
parser.add_argument("--lr", type=float, default=1e-3,
help="Learning rate (default: 1e-3)")
parser.add_argument("--early-stopping", type=int, default=None,
help="Stop if val loss doesn't improve for N epochs (optional)")
parser.add_argument("--stockfish-depth", type=int, default=12,
help="Stockfish depth used for evaluations (for metadata, default: 12)")
parser.add_argument("--no-versioning", action="store_true",
@@ -301,5 +350,6 @@ if __name__ == "__main__":
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping
)
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 3,
"date": "2026-04-08T09:43:28.000579",
"num_positions": 71610,
"stockfish_depth": 12,
"epochs": 20,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 0.006398905136849695,
"device": "cpu",
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
@@ -52,7 +52,7 @@ object FenParser extends GameContextImport:
)
else None
/** Parse en passant target square ("-" for none, or algebraic like "e3"). */
/** Parse en passant target square ("-" for none, or algebraic like "e3"). */
private def parseEnPassant(s: String): Option[Option[Square]] =
if s == "-" then Some(None)
else Square.fromAlgebraic(s).map(Some(_))
@@ -19,7 +19,7 @@ object Main:
val book = PolyglotBook("../../modules/bot/codekiddy.bin")
engine.setOpponentBot(NNUEBot(BotDifficulty.Easy, book = Some(book)), Black);
engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black);
// Launch ScalaFX GUI in separate thread
ChessGUILauncher.launch(engine)