diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index 156a89c..2766461 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -129,11 +129,17 @@ def train_interactive(): use_existing = Confirm.ask("Use existing positions file?", default=False) positions_file = None num_games = 500000 + samples_per_game = 1 + min_move = 1 + max_move = 50 if use_existing: positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt")) else: - num_games = int(Prompt.ask("Number of games to generate", default="500000")) + num_games = int(Prompt.ask("Number of games to generate", default="5000")) + samples_per_game = int(Prompt.ask("Positions to sample per game", default="1")) + min_move = int(Prompt.ask("Minimum move number", default="1")) + max_move = int(Prompt.ask("Maximum move number", default="50")) use_existing_labels = Confirm.ask("Use existing labels file?", default=False) labels_file = None @@ -147,6 +153,9 @@ def train_interactive(): # Training parameters epochs = int(Prompt.ask("Number of epochs", default="20")) batch_size = int(Prompt.ask("Batch size", default="4096")) + early_stopping = None + if Confirm.ask("Enable early stopping?", default=False): + early_stopping = int(Prompt.ask("Patience (epochs)", default="5")) # Confirm and start console.print("\n[bold]Configuration Summary:[/bold]") @@ -154,9 +163,18 @@ def train_interactive(): console.print(f" Checkpoint: v{checkpoint_version}") else: console.print(" Checkpoint: None (training from scratch)") - console.print(f" Games: {num_games:,}") + if not use_existing: + console.print(f" Games: {num_games:,}") + console.print(f" Samples per game: {samples_per_game}") + console.print(f" Move range: {min_move}-{max_move}") + else: + console.print(f" Positions file: {positions_file}") console.print(f" Epochs: {epochs}") console.print(f" Batch size: {batch_size}") + if early_stopping: + console.print(f" Early stopping: Yes (patience: 
{early_stopping})") + else: + console.print(f" Early stopping: No") console.print(f" Stockfish: {stockfish_path}") if not Confirm.ask("\nStart training?", default=True): @@ -174,7 +192,9 @@ def train_interactive(): count = play_random_game_and_collect_positions( str(data_dir / "positions.txt"), total_games=num_games, - filter_captures=True + samples_per_game=samples_per_game, + min_move=min_move, + max_move=max_move ) if count == 0: console.print("[red]✗ No valid positions generated[/red]") @@ -219,7 +239,8 @@ def train_interactive(): epochs=epochs, batch_size=batch_size, checkpoint=checkpoint, - use_versioning=True + use_versioning=True, + early_stopping_patience=early_stopping ) console.print("[green]✓ Training complete[/green]") diff --git a/modules/bot/python/src/generate.py b/modules/bot/python/src/generate.py index 8188638..017cc85 100644 --- a/modules/bot/python/src/generate.py +++ b/modules/bot/python/src/generate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Generate 500,000 random chess positions for NNUE training.""" +"""Generate random chess positions for NNUE training with minimal filtering.""" import chess import random @@ -7,89 +7,78 @@ import sys from pathlib import Path from tqdm import tqdm -# Standard piece values for capture filtering -PIECE_VALUES = { - chess.PAWN: 1, - chess.KNIGHT: 3, - chess.BISHOP: 3, - chess.ROOK: 5, - chess.QUEEN: 9, -} +def play_random_game_and_collect_positions( + output_file, + total_games=500000, + samples_per_game=1, + min_move=1, + max_move=50 +): + """Play random games and sample multiple positions from each. -def has_winning_or_equal_capture(board): - """Check if position has a capture where victim >= attacker (winning or equal trade). - - Returns True only if there's at least one favorable capture. - Positions with only losing captures return False (are kept). 
- """ - for move in board.legal_moves: - if board.is_capture(move): - attacker_piece = board.piece_at(move.from_square) - victim_piece = board.piece_at(move.to_square) - - if attacker_piece and victim_piece: - attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0) - victim_value = PIECE_VALUES.get(victim_piece.piece_type, 0) - - # If victim >= attacker, it's a winning or equal capture - if victim_value >= attacker_value: - return True - - return False - -def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True): - """Play random games and save positions after 8-20 random moves. + Args: + output_file: Output file for positions + total_games: Number of games to play + samples_per_game: Number of positions to sample per game (1-N) + min_move: Minimum move number to start sampling from + max_move: Maximum move number for sampling Returns: Number of valid positions saved """ positions_count = 0 - filtered_check = 0 - filtered_captures = 0 filtered_game_over = 0 + filtered_illegal = 0 with open(output_file, 'w') as f: with tqdm(total=total_games, desc="Generating positions") as pbar: for game_num in range(total_games): board = chess.Board() + move_history = [] - # Play 8-20 random opening moves - num_moves = random.randint(8, 20) - - for move_num in range(num_moves): - if board.is_game_over(): - break - + # Play a complete random game + while not board.is_game_over() and len(move_history) < 200: legal_moves = list(board.legal_moves) if not legal_moves: break - move = random.choice(legal_moves) board.push(move) + move_history.append(board.copy()) - # Skip if game over - if board.is_game_over(): - filtered_game_over += 1 + # Determine the range of moves to sample from + game_length = len(move_history) + valid_start = max(min_move, 0) + valid_end = min(max_move, game_length) + + if valid_start >= valid_end: + # Game too short pbar.update(1) continue - # Skip if in check - if board.is_check(): - filtered_check += 1 - 
pbar.update(1) - continue + # Randomly sample positions from this game + sample_count = min(samples_per_game, valid_end - valid_start) + if sample_count > 0: + sample_indices = random.sample( + range(valid_start, valid_end), + k=sample_count + ) - # Check if there are winning or equal captures (if filtering enabled) - if filter_captures: - if has_winning_or_equal_capture(board): - filtered_captures += 1 - pbar.update(1) - continue + for idx in sample_indices: + sampled_board = move_history[idx] - # Save valid position - fen = board.fen() - f.write(fen + '\n') - positions_count += 1 + # Only filter truly invalid or terminal positions + if not sampled_board.is_valid(): + filtered_illegal += 1 + continue + + if sampled_board.is_game_over(): + filtered_game_over += 1 + continue + + # Save position (include check, captures, all positions) + fen = sampled_board.fen() + f.write(fen + '\n') + positions_count += 1 pbar.update(1) @@ -98,23 +87,19 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt print("=" * 60) print("POSITION GENERATION SUMMARY") print("=" * 60) - total_filtered = filtered_check + filtered_captures + filtered_game_over - print(f"Total games: {total_games}") + print(f"Total games played: {total_games}") + print(f"Samples per game: {samples_per_game}") + print(f"Move range: {min_move}-{max_move}") print(f"Saved positions: {positions_count}") - print(f"Filtered (in check): {filtered_check}") - print(f"Filtered (winning+ cap): {filtered_captures}") print(f"Filtered (game over): {filtered_game_over}") - print(f"Total filtered: {total_filtered}") + print(f"Filtered (illegal): {filtered_illegal}") print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%") - print(f"(Keeps positions with only losing/bad captures)") + print(f"(Includes checks, captures, all realistic positions)") print("=" * 60) print() if positions_count == 0: print("WARNING: No valid positions were generated!") - print("This might happen if:") - print(" 
- Most positions have checks or game-over states") - print(" - Try using: --no-filter-captures to accept positions with winning captures") return 0 return positions_count @@ -126,16 +111,22 @@ if __name__ == "__main__": parser.add_argument("output_file", nargs="?", default="positions.txt", help="Output file for positions (default: positions.txt)") parser.add_argument("--games", type=int, default=5000, - help="Number of games to play (default: 500000)") - parser.add_argument("--no-filter-captures", action="store_true", - help="Include positions with winning/equal captures (increases output)") + help="Number of games to play (default: 5000)") + parser.add_argument("--samples-per-game", type=int, default=1, + help="Number of positions to sample per game (default: 1)") + parser.add_argument("--min-move", type=int, default=1, + help="Minimum move number to sample from (default: 1)") + parser.add_argument("--max-move", type=int, default=50, + help="Maximum move number to sample from (default: 50)") args = parser.parse_args() count = play_random_game_and_collect_positions( output_file=args.output_file, total_games=args.games, - filter_captures=not args.no_filter_captures + samples_per_game=args.samples_per_game, + min_move=args.min_move, + max_move=args.max_move ) sys.exit(0 if count > 0 else 1) diff --git a/modules/bot/python/src/label.py b/modules/bot/python/src/label.py index e5352b8..6a87c0a 100644 --- a/modules/bot/python/src/label.py +++ b/modules/bot/python/src/label.py @@ -1,14 +1,33 @@ #!/usr/bin/env python3 -"""Label positions with Stockfish evaluations.""" +"""Label positions with Stockfish evaluations and analyze distribution.""" import json import chess.engine import sys import os +import numpy as np from pathlib import Path from tqdm import tqdm -def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False): +def normalize_evaluation(cp_value, method='tanh', scale=300.0): + """Normalize centipawn 
evaluation to a bounded range. + + Args: + cp_value: Centipawn evaluation from Stockfish + method: 'tanh' (default) or 'sigmoid' + scale: Scale factor (tanh: 300 is typical) + + Returns: + Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid) + """ + if method == 'tanh': + return np.tanh(cp_value / scale) + elif method == 'sigmoid': + return 1.0 / (1.0 + np.exp(-cp_value / scale)) + else: + return cp_value / 100.0 + +def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True): """Read positions and label them with Stockfish evaluations. Args: @@ -18,6 +37,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size: Batch size (not used, kept for compatibility) depth: Stockfish depth verbose: Print detailed error messages + normalize: If True, normalize evals using tanh """ # Check if stockfish exists @@ -75,6 +95,8 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, skipped_invalid = 0 skipped_duplicate = 0 errors = 0 + raw_evals = [] + normalized_evals = [] try: with open(positions_file, 'r') as f: @@ -123,8 +145,15 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, # Clamp to [-2000, 2000] eval_cp = max(-2000, min(2000, eval_cp)) - # Save evaluation - data = {"fen": fen, "eval": eval_cp} + # Normalize evaluation + eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp + + # Track statistics + raw_evals.append(eval_cp) + normalized_evals.append(eval_normalized) + + # Save evaluation (normalized if requested) + data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp} out.write(json.dumps(data) + '\n') out.flush() # Force write to disk evaluated += 1 @@ -142,7 +171,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, finally: engine.quit() - # Print summary + # Print summary and analysis print() print("=" * 60) 
print("LABELING SUMMARY") @@ -164,6 +193,51 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, print(" 4. Stockfish path is correct") return False + # Print distribution analysis + if raw_evals: + raw_evals_arr = np.array(raw_evals) + norm_evals_arr = np.array(normalized_evals) + + print("=" * 60) + print("EVALUATION DISTRIBUTION ANALYSIS") + print("=" * 60) + print() + print("Raw Evaluations (centipawns):") + print(f" Min: {raw_evals_arr.min():.1f}") + print(f" Max: {raw_evals_arr.max():.1f}") + print(f" Mean: {raw_evals_arr.mean():.1f}") + print(f" Median: {np.median(raw_evals_arr):.1f}") + print(f" Std: {raw_evals_arr.std():.1f}") + print() + + print("Normalized Evaluations (tanh):") + print(f" Min: {norm_evals_arr.min():.4f}") + print(f" Max: {norm_evals_arr.max():.4f}") + print(f" Mean: {norm_evals_arr.mean():.4f}") + print(f" Median: {np.median(norm_evals_arr):.4f}") + print(f" Std: {norm_evals_arr.std():.4f}") + print() + + # Distribution buckets + print("Raw Evaluation Buckets (counts):") + buckets = [ + (-float('inf'), -500, "< -5.00"), + (-500, -300, "[-5.00, -3.00)"), + (-300, -100, "[-3.00, -1.00)"), + (-100, 0, "[-1.00, 0.00)"), + (0, 100, "[0.00, 1.00)"), + (100, 300, "[1.00, 3.00)"), + (300, 500, "[3.00, 5.00)"), + (500, float('inf'), "> 5.00"), + ] + for low, high, label in buckets: + count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high)) + pct = 100.0 * count / len(raw_evals_arr) + print(f" {label}: {count:6d} ({pct:5.1f}%)") + + print("=" * 60) + print() + print(f"✓ Labeling complete. 
Output saved to {output_file}") return True @@ -179,6 +253,8 @@ if __name__ == "__main__": help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')") parser.add_argument("--depth", type=int, default=12, help="Stockfish depth (default: 12)") + parser.add_argument("--no-normalize", action="store_true", + help="Disable evaluation normalization (keep raw centipawns)") parser.add_argument("--verbose", action="store_true", help="Print detailed error messages") @@ -192,6 +268,7 @@ if __name__ == "__main__": output_file=args.output_file, stockfish_path=stockfish_path, depth=args.depth, + normalize=not args.no_normalize, verbose=args.verbose ) diff --git a/modules/bot/python/src/train.py b/modules/bot/python/src/train.py index 5fe4176..868fdb2 100644 --- a/modules/bot/python/src/train.py +++ b/modules/bot/python/src/train.py @@ -12,6 +12,7 @@ from tqdm import tqdm import chess from datetime import datetime import re +import numpy as np class NNUEDataset(Dataset): """Dataset of chess positions with evaluations.""" @@ -19,15 +20,28 @@ class NNUEDataset(Dataset): def __init__(self, data_file): self.positions = [] self.evals = [] + self.evals_raw = [] + self.is_normalized = None with open(data_file, 'r') as f: for line in f: try: data = json.loads(line) fen = data['fen'] - eval_cp = data['eval'] + eval_val = data['eval'] self.positions.append(fen) - self.evals.append(eval_cp) + self.evals.append(eval_val) + + # Check if normalized or raw + if self.is_normalized is None: + # If eval is in range [-1, 1], assume normalized + self.is_normalized = abs(eval_val) <= 1.0 + + # Store raw if available + if 'eval_raw' in data: + self.evals_raw.append(data['eval_raw']) + else: + self.evals_raw.append(eval_val) except (json.JSONDecodeError, KeyError): pass @@ -36,9 +50,15 @@ class NNUEDataset(Dataset): def __getitem__(self, idx): fen = self.positions[idx] - eval_cp = self.evals[idx] + eval_val = self.evals[idx] features = fen_to_features(fen) - target = 
torch.sigmoid(torch.tensor(eval_cp / 400.0, dtype=torch.float32)) + + # Use evaluation as-is if normalized, otherwise apply sigmoid scaling + if self.is_normalized: + target = torch.tensor(eval_val, dtype=torch.float32) + else: + target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32)) + return features, target def fen_to_features(fen): @@ -122,7 +142,7 @@ def save_metadata(weights_file, metadata): return metadata_file -def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True): +def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None): """Train the NNUE model. Args: @@ -134,12 +154,28 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 checkpoint: Optional path to checkpoint file to resume from stockfish_depth: Depth used in Stockfish evaluation (for metadata) use_versioning: If True, save as nnue_weights_v{N}.pt with metadata + early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable) """ print("Loading dataset...") dataset = NNUEDataset(data_file) num_positions = len(dataset) print(f"Dataset size: {num_positions}") + print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'}") + + # Print dataset diagnostics + evals_array = np.array(dataset.evals) + print() + print("=" * 60) + print("TRAINING DATASET DIAGNOSTICS") + print("=" * 60) + print(f"Min evaluation: {evals_array.min():.4f}") + print(f"Max evaluation: {evals_array.max():.4f}") + print(f"Mean evaluation: {evals_array.mean():.4f}") + print(f"Median evaluation: {np.median(evals_array):.4f}") + print(f"Std deviation: {evals_array.std():.4f}") + print("=" * 60) + print() # Split 90% train, 10% validation train_size = int(0.9 * len(dataset)) @@ -179,8 +215,11 @@ def 
train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 best_val_loss = float('inf') best_model_state = None + epochs_without_improvement = 0 print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...") + if early_stopping_patience: + print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)") print() training_start_time = datetime.now() @@ -228,6 +267,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 if val_loss < best_val_loss: best_val_loss = val_loss best_model_state = model.state_dict().copy() + epochs_without_improvement = 0 + else: + epochs_without_improvement += 1 + + # Early stopping check + if early_stopping_patience and epochs_without_improvement >= early_stopping_patience: + print(f"Early stopping: no improvement for {early_stopping_patience} epochs") + break # Save best model if best_model_state is not None: @@ -286,6 +333,8 @@ if __name__ == "__main__": help="Batch size (default: 4096)") parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate (default: 1e-3)") + parser.add_argument("--early-stopping", type=int, default=None, + help="Stop if val loss doesn't improve for N epochs (optional)") parser.add_argument("--stockfish-depth", type=int, default=12, help="Stockfish depth used for evaluations (for metadata, default: 12)") parser.add_argument("--no-versioning", action="store_true", @@ -301,5 +350,6 @@ if __name__ == "__main__": lr=args.lr, checkpoint=args.checkpoint, stockfish_depth=args.stockfish_depth, - use_versioning=not args.no_versioning + use_versioning=not args.no_versioning, + early_stopping_patience=args.early_stopping ) diff --git a/modules/bot/python/weights/nnue_weights_v3.pt b/modules/bot/python/weights/nnue_weights_v3.pt new file mode 100644 index 0000000..91dee7c Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v3.pt differ diff --git a/modules/bot/python/weights/nnue_weights_v3_metadata.json 
b/modules/bot/python/weights/nnue_weights_v3_metadata.json new file mode 100644 index 0000000..70ce1e3 --- /dev/null +++ b/modules/bot/python/weights/nnue_weights_v3_metadata.json @@ -0,0 +1,13 @@ +{ + "version": 3, + "date": "2026-04-08T09:43:28.000579", + "num_positions": 71610, + "stockfish_depth": 12, + "epochs": 20, + "batch_size": 4096, + "learning_rate": 0.001, + "final_val_loss": 0.006398905136849695, + "device": "cpu", + "checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt", + "notes": "Win rate vs classical eval: TBD (requires benchmark games)" +} \ No newline at end of file diff --git a/modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala b/modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala index 1f206ff..63825d6 100644 --- a/modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala +++ b/modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala @@ -52,7 +52,7 @@ object FenParser extends GameContextImport: ) else None - /** Parse en passant target square ("-" for none, or algebraic like "e3"). */ + /** Parse en passant target square ("-" for none, or algebraic like "e3"). */ private def parseEnPassant(s: String): Option[Option[Square]] = if s == "-" then Some(None) else Square.fromAlgebraic(s).map(Some(_)) diff --git a/modules/ui/src/main/scala/de/nowchess/ui/Main.scala b/modules/ui/src/main/scala/de/nowchess/ui/Main.scala index d8d504d..9b4cf02 100644 --- a/modules/ui/src/main/scala/de/nowchess/ui/Main.scala +++ b/modules/ui/src/main/scala/de/nowchess/ui/Main.scala @@ -19,7 +19,7 @@ object Main: val book = PolyglotBook("../../modules/bot/codekiddy.bin") - engine.setOpponentBot(NNUEBot(BotDifficulty.Easy, book = Some(book)), Black); + engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black); // Launch ScalaFX GUI in separate thread ChessGUILauncher.launch(engine)