diff --git a/modules/bot/build.gradle.kts b/modules/bot/build.gradle.kts index 208a860..9dfcd68 100644 --- a/modules/bot/build.gradle.kts +++ b/modules/bot/build.gradle.kts @@ -63,3 +63,7 @@ tasks.test { tasks.reportScoverage { dependsOn(tasks.test) } + +tasks.jar { + duplicatesStrategy = DuplicatesStrategy.EXCLUDE +} diff --git a/modules/bot/python/.gitignore b/modules/bot/python/.gitignore index b5e7531..54e198c 100644 --- a/modules/bot/python/.gitignore +++ b/modules/bot/python/.gitignore @@ -17,3 +17,5 @@ ENV/ .vscode/ *.swp *.swo +tactical_data/ +trainingdata/ diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py index 2766461..a531a1f 100644 --- a/modules/bot/python/nnue.py +++ b/modules/bot/python/nnue.py @@ -2,6 +2,7 @@ """Central NNUE pipeline TUI for training and exporting models.""" import os +import shutil import sys from pathlib import Path from rich.console import Console @@ -17,6 +18,10 @@ from generate import play_random_game_and_collect_positions from label import label_positions_with_stockfish from train import train_nnue from export import export_weights_to_binary +from tactical_positions_extractor import ( + download_and_extract_puzzle_db, + interactive_merge_positions +) def get_data_dir(): """Get/create data directory.""" @@ -24,6 +29,12 @@ def get_data_dir(): data_dir.mkdir(exist_ok=True) return data_dir +def get_tactical_data_dir(): + """Get/create data directory.""" + data_dir = Path(__file__).parent / "tactical_data" + data_dir.mkdir(exist_ok=True) + return data_dir + def get_weights_dir(): """Get/create weights directory.""" weights_dir = Path(__file__).parent / "weights" @@ -87,20 +98,23 @@ def show_main_menu(): console.print("\n[bold]What would you like to do?[/bold]") console.print("[cyan]1[/cyan] - Train NNUE Model") console.print("[cyan]2[/cyan] - Export Weights to Scala") - console.print("[cyan]3[/cyan] - View Checkpoints") - console.print("[cyan]4[/cyan] - Exit") + console.print("[cyan]3[/cyan] - Extract Tactical 
Positions") + console.print("[cyan]4[/cyan] - View Checkpoints") + console.print("[cyan]5[/cyan] - Exit") - choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"]) + choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"]) if choice == "1": train_interactive() elif choice == "2": export_interactive() elif choice == "3": + extract_tactical_interactive() + elif choice == "4": show_header() show_checkpoints_table() Prompt.ask("\nPress Enter to continue") - elif choice == "4": + elif choice == "5": console.print("[yellow]👋 Goodbye![/yellow]") return @@ -146,13 +160,18 @@ def train_interactive(): if use_existing_labels: labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl")) - # Stockfish path - default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish") + # Stockfish path and labeling parameters + default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish" stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish) + stockfish_depth = 12 + num_workers = 1 + if not use_existing_labels: + stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12")) + num_workers = int(Prompt.ask("Number of parallel workers", default="1")) # Training parameters - epochs = int(Prompt.ask("Number of epochs", default="20")) - batch_size = int(Prompt.ask("Batch size", default="4096")) + epochs = int(Prompt.ask("Number of epochs", default="100")) + batch_size = int(Prompt.ask("Batch size", default="16384")) early_stopping = None if Confirm.ask("Enable early stopping?", default=False): early_stopping = int(Prompt.ask("Patience (epochs)", default="5")) @@ -175,6 +194,9 @@ def train_interactive(): console.print(f" Early stopping: Yes (patience: {early_stopping})") else: console.print(f" Early stopping: No") + if not use_existing_labels: + console.print(f" Stockfish depth: {stockfish_depth}") + console.print(f" Workers: 
{num_workers}") console.print(f" Stockfish: {stockfish_path}") if not Confirm.ask("\nStart training?", default=True): @@ -214,7 +236,8 @@ def train_interactive(): str(positions_file), str(output_file), stockfish_path, - depth=12 + depth=stockfish_depth, + num_workers=num_workers ) if not success: console.print("[red]✗ Position labeling failed[/red]") @@ -305,6 +328,66 @@ def export_interactive(): traceback.print_exc() Prompt.ask("Press Enter to continue") +def extract_tactical_interactive(): + """Interactive tactical positions extraction and merge menu.""" + console = Console() + show_header() + + console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]") + + # Download and extract options + console.print("\n[bold]Lichess Puzzle Database:[/bold]") + download_url = Prompt.ask( + "Download URL", + default="https://database.lichess.org/lichess_db_puzzle.csv.zst" + ) + + output_dir = Prompt.ask( + "Extract to directory", + default=str(Path(__file__).parent / "trainingdata") + ) + + max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000")) + + # Confirm and download + console.print("\n[bold]Configuration:[/bold]") + console.print(f" Download URL: {download_url}") + console.print(f" Extract directory: {output_dir}") + console.print(f" Max puzzles: {max_puzzles:,}") + + if not Confirm.ask("\nProceed?", default=True): + console.print("[yellow]Cancelled[/yellow]") + Prompt.ask("Press Enter to continue") + return + + try: + console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]") + csv_path = download_and_extract_puzzle_db(download_url, output_dir) + + if not csv_path: + console.print("[red]✗ Failed to download/extract[/red]") + Prompt.ask("Press Enter to continue") + return + + console.print(f"[green]✓ Ready: {csv_path}[/green]") + + # Interactive merge + console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]") + output_file = Prompt.ask( + "Output file path", + default=str(Path(__file__).parent / "data" / 
"position.txt") + ) + + interactive_merge_positions(csv_path, output_file, max_puzzles) + console.print(f"\n[green]✓ Complete![/green]") + Prompt.ask("Press Enter to continue") + + except Exception as e: + console.print(f"[red]✗ Error: {e}[/red]") + import traceback + traceback.print_exc() + Prompt.ask("Press Enter to continue") + def main(): try: show_main_menu() diff --git a/modules/bot/python/requirements.txt b/modules/bot/python/requirements.txt index b0bf304..feccc69 100644 --- a/modules/bot/python/requirements.txt +++ b/modules/bot/python/requirements.txt @@ -2,4 +2,5 @@ chess==1.11.2 torch==2.11.0 tqdm==4.67.3 numpy==2.4.4 -rich==13.7.0 \ No newline at end of file +rich==13.7.0 +zstandard==0.23.0 \ No newline at end of file diff --git a/modules/bot/python/src/export.py b/modules/bot/python/src/export.py index df5454d..f20370b 100644 --- a/modules/bot/python/src/export.py +++ b/modules/bot/python/src/export.py @@ -13,8 +13,9 @@ def export_weights_to_binary(weights_file, output_file): print(f"Error: Weights file not found at {weights_file}") sys.exit(1) - # Load weights - state_dict = torch.load(weights_file, map_location='cpu') + # Load weights — handle both raw state dicts and full training checkpoints + loaded = torch.load(weights_file, map_location='cpu') + state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded # Debug: print available layers print(f"Available layers in {weights_file}:") @@ -31,7 +32,7 @@ def export_weights_to_binary(weights_file, output_file): f.write(struct.pack('= valid_end: + continue + + # Randomly sample positions from this game + sample_count = min(samples_per_game, valid_end - valid_start) + if sample_count > 0: + sample_indices = random.sample( + range(valid_start, valid_end), + k=sample_count + ) + + for idx in sample_indices: + sampled_board = move_history[idx] + + # Only filter truly invalid or terminal positions + if not sampled_board.is_valid() or 
sampled_board.is_game_over(): + continue + + # Save position (include check, captures, all positions) + fen = sampled_board.fen() + positions.append(fen) + + return positions + def play_random_game_and_collect_positions( output_file, - total_games=500000, + total_positions=3000000, samples_per_game=1, min_move=1, - max_move=50 + max_move=50, + num_workers=8 ): - """Play random games and sample multiple positions from each. + """Generate positions using multiprocessing with multiple workers. Args: output_file: Output file for positions - total_games: Number of games to play + total_positions: Target number of positions to generate samples_per_game: Number of positions to sample per game (1-N) min_move: Minimum move number to start sampling from max_move: Maximum move number for sampling + num_workers: Number of parallel worker processes Returns: Number of valid positions saved """ + # Estimate games needed (roughly 1 position per game on average) + total_games = max(total_positions // samples_per_game, num_workers) + games_per_worker = total_games // num_workers + + print(f"Generating {total_positions:,} positions using {num_workers} workers") + print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)") + print() + + start_time = datetime.now() + + # Generate positions in parallel + worker_tasks = [ + (i, games_per_worker, samples_per_game, min_move, max_move) + for i in range(num_workers) + ] + positions_count = 0 - filtered_game_over = 0 - filtered_illegal = 0 - - with open(output_file, 'w') as f: - with tqdm(total=total_games, desc="Generating positions") as pbar: - for game_num in range(total_games): - board = chess.Board() - move_history = [] - - # Play a complete random game - while not board.is_game_over() and len(move_history) < 200: - legal_moves = list(board.legal_moves) - if not legal_moves: - break - move = random.choice(legal_moves) - board.push(move) - move_history.append(board.copy()) - - # Determine the range of moves to sample from - 
game_length = len(move_history) - valid_start = max(min_move, 0) - valid_end = min(max_move, game_length) - - if valid_start >= valid_end: - # Game too short - pbar.update(1) - continue - - # Randomly sample positions from this game - sample_count = min(samples_per_game, valid_end - valid_start) - if sample_count > 0: - sample_indices = random.sample( - range(valid_start, valid_end), - k=sample_count - ) - - for idx in sample_indices: - sampled_board = move_history[idx] - - # Only filter truly invalid or terminal positions - if not sampled_board.is_valid(): - filtered_illegal += 1 - continue - - if sampled_board.is_game_over(): - filtered_game_over += 1 - continue - - # Save position (include check, captures, all positions) - fen = sampled_board.fen() - f.write(fen + '\n') - positions_count += 1 + all_positions = [] + with Pool(num_workers) as pool: + with tqdm(total=num_workers, desc="Workers generating games") as pbar: + for positions in pool.starmap(_worker_generate_games, worker_tasks): + all_positions.extend(positions) + positions_count += len(positions) pbar.update(1) + # Write all positions to file + print(f"Writing {positions_count:,} positions to {output_file}...") + with open(output_file, 'w') as f: + for fen in all_positions: + f.write(fen + '\n') + + elapsed_time = datetime.now() - start_time + elapsed_seconds = elapsed_time.total_seconds() + positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0 + # Print summary print() print("=" * 60) print("POSITION GENERATION SUMMARY") print("=" * 60) - print(f"Total games played: {total_games}") + print(f"Target positions: {total_positions:,}") + print(f"Actual positions saved: {positions_count:,}") + print(f"Workers: {num_workers}") + print(f"Games per worker: {games_per_worker:,}") print(f"Samples per game: {samples_per_game}") print(f"Move range: {min_move}-{max_move}") - print(f"Saved positions: {positions_count}") - print(f"Filtered (game over): {filtered_game_over}") - 
print(f"Filtered (illegal): {filtered_illegal}") - print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%") - print(f"(Includes checks, captures, all realistic positions)") + print(f"Elapsed time: {elapsed_time}") + print(f"Throughput: {positions_per_second:.0f} positions/second") print("=" * 60) print() @@ -110,23 +146,26 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training") parser.add_argument("output_file", nargs="?", default="positions.txt", help="Output file for positions (default: positions.txt)") - parser.add_argument("--games", type=int, default=5000, - help="Number of games to play (default: 5000)") + parser.add_argument("--positions", type=int, default=3000000, + help="Target number of positions to generate (default: 3000000)") parser.add_argument("--samples-per-game", type=int, default=1, help="Number of positions to sample per game (default: 1)") parser.add_argument("--min-move", type=int, default=1, help="Minimum move number to sample from (default: 1)") parser.add_argument("--max-move", type=int, default=50, help="Maximum move number to sample from (default: 50)") + parser.add_argument("--workers", type=int, default=8, + help="Number of parallel worker processes (default: 8)") args = parser.parse_args() count = play_random_game_and_collect_positions( output_file=args.output_file, - total_games=args.games, + total_positions=args.positions, samples_per_game=args.samples_per_game, min_move=args.min_move, - max_move=args.max_move + max_move=args.max_move, + num_workers=args.workers ) sys.exit(0 if count > 0 else 1) diff --git a/modules/bot/python/src/label.py b/modules/bot/python/src/label.py index 6a87c0a..67b832b 100644 --- a/modules/bot/python/src/label.py +++ b/modules/bot/python/src/label.py @@ -8,6 +8,8 @@ import os import numpy as np from pathlib import Path from tqdm import tqdm +from multiprocessing import Pool +from functools import partial def 
normalize_evaluation(cp_value, method='tanh', scale=300.0): """Normalize centipawn evaluation to a bounded range. @@ -27,17 +29,68 @@ def normalize_evaluation(cp_value, method='tanh', scale=300.0): else: return cp_value / 100.0 -def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True): +def _evaluate_fen_batch(args): + """Worker function to evaluate a batch of FENs with Stockfish threading. + + Args: + args: tuple of (fens, stockfish_path, depth, normalize) + + Returns: + list of (fen, eval_normalized, eval_raw) tuples + """ + fens, stockfish_path, depth, normalize = args + + results = [] + + try: + engine = chess.engine.SimpleEngine.popen_uci(stockfish_path) + except Exception: + return [] + + try: + for fen in fens: + try: + board = chess.Board(fen) + if not board.is_valid(): + continue + + info = engine.analyse(board, chess.engine.Limit(depth=depth)) + + if info.get('score') is None: + continue + + score = info['score'].white() + + if score.is_mate(): + eval_cp = 2000 if score.mate() > 0 else -2000 + else: + eval_cp = score.cp + + eval_cp = max(-2000, min(2000, eval_cp)) + eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp + + results.append((fen, eval_normalized, eval_cp)) + + except Exception: + continue + finally: + engine.quit() + + return results + + +def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1): """Read positions and label them with Stockfish evaluations. 
Args: positions_file: Path to positions.txt output_file: Path to training_data.jsonl stockfish_path: Path to stockfish binary - batch_size: Batch size (not used, kept for compatibility) + batch_size: Batch size for processing (positions per worker task, default: 1000) depth: Stockfish depth verbose: Print detailed error messages normalize: If True, normalize evals using tanh + num_workers: Number of parallel Stockfish processes """ # Check if stockfish exists @@ -48,6 +101,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, sys.exit(1) print(f"Using Stockfish: {stockfish_path}") + print(f"Number of workers: {num_workers}") # Check if positions file exists if not Path(positions_file).exists(): @@ -69,107 +123,87 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path, pass print(f"Resuming from {position_count} already evaluated positions") - # Count total positions - with open(positions_file, 'r') as f: - total_lines = sum(1 for _ in f) + # Load all FENs that need evaluation + fens_to_evaluate = [] + skipped_invalid = 0 + skipped_duplicate = 0 - if total_lines == 0: - print(f"Error: Positions file is empty ({positions_file})") - sys.exit(1) + with open(positions_file, 'r') as f: + for fen in f: + fen = fen.strip() + + if not fen: + skipped_invalid += 1 + continue + + if fen in evaluated_fens: + skipped_duplicate += 1 + continue + + fens_to_evaluate.append(fen) + + total_to_evaluate = len(fens_to_evaluate) + total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate + + if total_to_evaluate == 0: + if position_count == 0: + print(f"Error: No valid positions to evaluate in {positions_file}") + sys.exit(1) + else: + print(f"All positions already evaluated. 
No new positions to process.") + return True print(f"Total positions to process: {total_lines}") + print(f"New positions to evaluate: {total_to_evaluate}") print(f"Using depth: {depth}") print() - # Initialize engine - try: - engine = chess.engine.SimpleEngine.popen_uci(stockfish_path) - except Exception as e: - print(f"Error: Could not start Stockfish engine") - print(f" Stockfish path: {stockfish_path}") - print(f" Error: {e}") - sys.exit(1) + # Split FENs into batches for workers + batches = [] + for i in range(0, total_to_evaluate, batch_size): + batch = fens_to_evaluate[i:i+batch_size] + batches.append((batch, stockfish_path, depth, normalize)) - # Track statistics + # Process batches in parallel evaluated = 0 - skipped_invalid = 0 - skipped_duplicate = 0 errors = 0 raw_evals = [] normalized_evals = [] - try: - with open(positions_file, 'r') as f: + import time + start_time = time.time() + + with Pool(num_workers) as pool: + with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar: with open(output_file, 'a') as out: - with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar: - for fen in f: - fen = fen.strip() - - # Skip empty lines - if not fen: - skipped_invalid += 1 - pbar.update(1) - continue - - # Skip already evaluated - if fen in evaluated_fens: - skipped_duplicate += 1 - pbar.update(1) - continue - - try: - # Validate FEN - board = chess.Board(fen) - if not board.is_valid(): - skipped_invalid += 1 - pbar.update(1) - continue - - # Evaluate at specified depth - info = engine.analyse(board, chess.engine.Limit(depth=depth)) - - if info.get('score') is None: - skipped_invalid += 1 - pbar.update(1) - continue - - score = info['score'].white() - - # Convert to centipawns - if score.is_mate(): - # Use large values for mate scores - eval_cp = 2000 if score.mate() > 0 else -2000 - else: - eval_cp = score.cp - - # Clamp to [-2000, 2000] - eval_cp = max(-2000, min(2000, eval_cp)) - - # Normalize 
evaluation - eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp - - # Track statistics - raw_evals.append(eval_cp) - normalized_evals.append(eval_normalized) - - # Save evaluation (normalized if requested) - data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp} - out.write(json.dumps(data) + '\n') - out.flush() # Force write to disk - evaluated += 1 - - except Exception as e: - errors += 1 - if verbose: - print(f"Error evaluating position: {fen[:50]}...") - print(f" {type(e).__name__}: {e}") - pbar.update(1) - continue - + for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)): + for fen, eval_normalized, eval_cp in batch_results: + data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp} + out.write(json.dumps(data) + '\n') + evaluated += 1 + raw_evals.append(eval_cp) + normalized_evals.append(eval_normalized) pbar.update(1) - finally: - engine.quit() + # Update progress for any failed evaluations in the batch + batch_size_actual = len(batches[0][0]) if batches else batch_size + failed = batch_size_actual - len(batch_results) + if failed > 0: + errors += failed + pbar.update(failed) + + # Calculate and show throughput and ETA + elapsed = time.time() - start_time + throughput = evaluated / elapsed if elapsed > 0 else 0 + remaining_positions = total_to_evaluate - evaluated + eta_seconds = remaining_positions / throughput if throughput > 0 else 0 + eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}" + + if (batch_idx + 1) % max(1, len(batches) // 10) == 0: + pbar.set_postfix({ + 'rate': f'{throughput:.0f} pos/s', + 'eta': eta_str + }) # Print summary and analysis print() @@ -253,10 +287,14 @@ if __name__ == "__main__": help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')") parser.add_argument("--depth", type=int, default=12, help="Stockfish depth (default: 12)") + parser.add_argument("--batch-size", type=int, default=1000, + help="Batch size for 
processing (default: 1000)") parser.add_argument("--no-normalize", action="store_true", help="Disable evaluation normalization (keep raw centipawns)") parser.add_argument("--verbose", action="store_true", help="Print detailed error messages") + parser.add_argument("--workers", type=int, default=1, + help="Number of parallel Stockfish processes (default: 1)") args = parser.parse_args() @@ -267,9 +305,11 @@ if __name__ == "__main__": positions_file=args.positions_file, output_file=args.output_file, stockfish_path=stockfish_path, + batch_size=args.batch_size, depth=args.depth, normalize=not args.no_normalize, - verbose=args.verbose + verbose=args.verbose, + num_workers=args.workers ) sys.exit(0 if success else 1) diff --git a/modules/bot/python/src/tactical_positions_extractor.py b/modules/bot/python/src/tactical_positions_extractor.py new file mode 100644 index 0000000..fd30efd --- /dev/null +++ b/modules/bot/python/src/tactical_positions_extractor.py @@ -0,0 +1,224 @@ +import chess +import csv +import json +import sys +import urllib.request +from pathlib import Path +from typing import Set, Tuple + +try: + import zstandard as zstd +except ImportError: + print("zstandard library not found. 
Install with: pip install zstandard") + sys.exit(1) + +from generate import play_random_game_and_collect_positions + + +def download_and_extract_puzzle_db( + url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst', + output_dir: str = 'trainingdata' +): + """Download and extract the Lichess puzzle database.""" + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + csv_file = output_path / 'lichess_db_puzzle.csv' + zst_file = output_path / 'lichess_db_puzzle.csv.zst' + + # Download if not already present + if not zst_file.exists(): + print(f"Downloading puzzle database from {url}...") + try: + urllib.request.urlretrieve(url, zst_file) + print(f"Downloaded to {zst_file}") + except Exception as e: + print(f"Failed to download: {e}") + return None + + # Extract if CSV doesn't exist + if not csv_file.exists(): + print(f"Extracting {zst_file}...") + try: + with open(zst_file, 'rb') as f: + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(f) as reader: + with open(csv_file, 'wb') as out: + out.write(reader.read()) + print(f"Extracted to {csv_file}") + except Exception as e: + print(f"Failed to extract: {e}") + return None + + return str(csv_file) + + +def extract_puzzle_positions( + puzzle_csv: str, + max_puzzles: int = 300_000 +) -> Set[str]: + """ + Extract the position BEFORE the blunder from each puzzle. + This is exactly the type of position where tactical + recognition matters most. + + Returns a set of unique FENs. 
+ """ + positions = set() + + with open(puzzle_csv) as f: + reader = csv.DictReader(f) + for row in reader: + if len(positions) >= max_puzzles: + break + + try: + board = chess.Board(row['FEN']) + + # The puzzle FEN is AFTER the blunder move + # We want the position BEFORE — so it learns + # to find the tactic, not just play it + moves = row['Moves'].split() + + # Undo one move to get pre-tactic position + board.push_uci(moves[0]) # opponent blunder + fen = board.fen() + + # Filter for useful tactical themes + themes = row.get('Themes', '') + useful = any(t in themes for t in [ + 'fork', 'pin', 'skewer', 'discoveredAttack', + 'mate', 'mateIn2', 'mateIn3', 'hangingPiece', + 'trappedPiece', 'sacrifice' + ]) + + if useful: + positions.add(fen) + + except Exception: + continue + + return positions + + +def load_positions_from_file(file_path: str) -> Set[str]: + """Load positions from a text file (one FEN per line).""" + positions = set() + try: + with open(file_path) as f: + for line in f: + line = line.strip() + if line: + positions.add(line) + print(f"Loaded {len(positions)} positions from {file_path}") + return positions + except Exception as e: + print(f"Failed to load from {file_path}: {e}") + return set() + + +def merge_positions( + tactical: Set[str], + other: Set[str], + output_file: str = 'position.txt' +): + """Merge two position sets and write to file.""" + merged = tactical | other + + with open(output_file, 'w') as f: + for fen in merged: + f.write(fen + '\n') + + overlap = len(tactical & other) + print(f"\n{'='*60}") + print(f"MERGE SUMMARY") + print(f"{'='*60}") + print(f"Tactical positions: {len(tactical):,}") + print(f"Other positions: {len(other):,}") + print(f"Overlap (deduplicated): {overlap:,}") + print(f"Total merged positions: {len(merged):,}") + print(f"Written to: {output_file}") + print(f"{'='*60}\n") + + +def interactive_merge_positions( + puzzle_csv: str, + output_file: str = 'position.txt', + max_puzzles: int = 300_000 +): + """Interactive 
workflow: extract tactical positions and merge with user selection.""" + print("\n" + "="*60) + print("TACTICAL POSITION EXTRACTOR & MERGER") + print("="*60 + "\n") + + # Extract tactical positions + print("Extracting tactical positions from puzzle database...") + tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles) + print(f"Extracted {len(tactical_positions):,} unique tactical positions\n") + + # Ask what to merge with + print("What would you like to merge with these tactical positions?") + print("1. Load from a position file") + print("2. Generate random positions") + print("3. Skip merging (save tactical only)") + + choice = input("\nEnter choice (1-3): ").strip() + + other_positions = set() + + if choice == '1': + file_path = input("Enter path to position file: ").strip() + other_positions = load_positions_from_file(file_path) + + elif choice == '2': + positions_to_gen = input("How many positions to generate? (default 1000000): ").strip() + try: + positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000 + except ValueError: + positions_to_gen = 1000000 + + temp_file = 'temp_generated_positions.txt' + print(f"\nGenerating {positions_to_gen:,} random positions...") + play_random_game_and_collect_positions( + output_file=temp_file, + total_positions=positions_to_gen, + samples_per_game=1, + min_move=1, + max_move=50, + num_workers=8 + ) + other_positions = load_positions_from_file(temp_file) + + elif choice == '3': + print("Skipping merge, saving tactical positions only...") + + else: + print("Invalid choice, saving tactical positions only...") + + merge_positions(tactical_positions, other_positions, output_file) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="Extract and merge tactical positions") + parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst', + help="URL to download puzzle database from") + parser.add_argument("--output-dir", 
default='trainingdata', + help="Directory to extract puzzle database to") + parser.add_argument("--max-puzzles", type=int, default=300_000, + help="Maximum puzzles to extract (default: 300000)") + parser.add_argument("--output-file", default='position.txt', + help="Output file for merged positions (default: position.txt)") + + args = parser.parse_args() + + # Download and extract + csv_path = download_and_extract_puzzle_db(args.url, args.output_dir) + + if csv_path: + # Interactive merge + interactive_merge_positions(csv_path, args.output_file, args.max_puzzles) + else: + print("Failed to download/extract puzzle database") + sys.exit(1) \ No newline at end of file diff --git a/modules/bot/python/src/train.py b/modules/bot/python/src/train.py index 868fdb2..4fba99a 100644 --- a/modules/bot/python/src/train.py +++ b/modules/bot/python/src/train.py @@ -87,24 +87,34 @@ def fen_to_features(fen): return features class NNUE(nn.Module): - """NNUE neural network architecture.""" + """NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization.""" - def __init__(self): + def __init__(self, dropout_rate=0.2): super().__init__() - self.l1 = nn.Linear(768, 256) + self.l1 = nn.Linear(768, 1536) self.relu1 = nn.ReLU() - self.l2 = nn.Linear(256, 32) + self.drop1 = nn.Dropout(dropout_rate) + + self.l2 = nn.Linear(1536, 1024) self.relu2 = nn.ReLU() - self.l3 = nn.Linear(32, 1) - self.sigmoid = nn.Sigmoid() + self.drop2 = nn.Dropout(dropout_rate) + + self.l3 = nn.Linear(1024, 512) + self.relu3 = nn.ReLU() + self.drop3 = nn.Dropout(dropout_rate) + + self.l4 = nn.Linear(512, 256) + self.relu4 = nn.ReLU() + self.drop4 = nn.Dropout(dropout_rate) + + self.l5 = nn.Linear(256, 1) def forward(self, x): - x = self.l1(x) - x = self.relu1(x) - x = self.l2(x) - x = self.relu2(x) - x = self.l3(x) - return x + x = self.drop1(self.relu1(self.l1(x))) + x = self.drop2(self.relu2(self.l2(x))) + x = self.drop3(self.relu3(self.l3(x))) + x = self.drop4(self.relu4(self.l4(x))) + 
return self.l5(x) def find_next_version(base_name="nnue_weights"): """Find the next version number for model versioning. @@ -142,21 +152,32 @@ def save_metadata(weights_file, metadata): return metadata_file -def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None): - """Train the NNUE model. +def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4): + """Train the NNUE model with GPU optimizations and automatic mixed precision. Args: data_file: Path to training_data.jsonl output_file: Where to save best weights (or base name if use_versioning=True) - epochs: Number of training epochs - batch_size: Training batch size - lr: Learning rate + epochs: Number of training epochs (default: 100) + batch_size: Training batch size (default: 16384) + lr: Learning rate (default: 0.001) checkpoint: Optional path to checkpoint file to resume from stockfish_depth: Depth used in Stockfish evaluation (for metadata) use_versioning: If True, save as nnue_weights_v{N}.pt with metadata early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable) + weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting) """ + print("Checking GPU availability...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + print(f"✓ GPU available: {torch.cuda.get_device_name(0)}") + print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") + else: + print("⚠ GPU not available, using CPU") + print(f"Using device: {device}") + print() + print("Loading dataset...") dataset = NNUEDataset(data_file) num_positions = len(dataset) @@ -182,42 +203,63 @@ def train_nnue(data_file, output_file="nnue_weights.pt", 
epochs=20, batch_size=4 val_size = len(dataset) - train_size from torch.utils.data import random_split - train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) + generator = torch.Generator().manual_seed(42) + train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator) - train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) - - # Device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") + # DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=8, + pin_memory=True, + persistent_workers=True + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=8, + pin_memory=True, + persistent_workers=True + ) # Model model = NNUE().to(device) criterion = nn.MSELoss() - optimizer = optim.Adam(model.parameters(), lr=lr) + optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) - # Load checkpoint if provided - checkpoint_to_load = checkpoint - if checkpoint_to_load is None and Path(output_file).exists(): - # Auto-detect checkpoint: if output file already exists, use it as checkpoint - checkpoint_to_load = output_file + # Cosine annealing learning rate scheduler + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + # Mixed precision training + scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu') start_epoch = 0 - if checkpoint_to_load is not None and Path(checkpoint_to_load).exists(): - print(f"Loading checkpoint from {checkpoint_to_load}...") - try: - checkpoint_state = torch.load(checkpoint_to_load, map_location=device) - model.load_state_dict(checkpoint_state) - print(f"✓ Checkpoint loaded 
successfully") - except Exception as e: - print(f"Warning: Could not load checkpoint: {e}") - print("Training from scratch instead") - best_val_loss = float('inf') + best_val_loss = float('inf') + if checkpoint: + print(f"Loading checkpoint: {checkpoint}") + ckpt = torch.load(checkpoint, map_location=device) + if isinstance(ckpt, dict) and "model_state_dict" in ckpt: + model.load_state_dict(ckpt["model_state_dict"]) + optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + scheduler.load_state_dict(ckpt["scheduler_state_dict"]) + scaler.load_state_dict(ckpt["scaler_state_dict"]) + start_epoch = ckpt["epoch"] + 1 + best_val_loss = ckpt.get("best_val_loss", float('inf')) + print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})") + else: + model.load_state_dict(ckpt) + print("Loaded weights-only checkpoint (no optimizer state)") + + checkpoint_val_loss = best_val_loss if checkpoint else float('inf') best_model_state = None epochs_without_improvement = 0 - print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...") + print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...") + print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})") + print(f"Mixed precision training: enabled") + print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})") if early_stopping_patience: print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)") print() @@ -236,10 +278,15 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 batch_targets = batch_targets.to(device).unsqueeze(1) optimizer.zero_grad() - outputs = model(batch_features) - loss = criterion(outputs, batch_targets) - loss.backward() - optimizer.step() + + # Mixed precision forward and backward + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'): + outputs = model(batch_features) + loss = criterion(outputs, batch_targets) + + scaler.scale(loss).backward() + scaler.step(optimizer) + 
scaler.update() train_loss += loss.item() * batch_features.size(0) pbar.update(1) @@ -255,19 +302,51 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 batch_features = batch_features.to(device) batch_targets = batch_targets.to(device).unsqueeze(1) - outputs = model(batch_features) - loss = criterion(outputs, batch_targets) + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'): + outputs = model(batch_features) + loss = criterion(outputs, batch_targets) val_loss += loss.item() * batch_features.size(0) pbar.update(1) val_loss /= len(val_dataset) - print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}") + # Update learning rate + scheduler.step() - if val_loss < best_val_loss: + # Print GPU memory usage + if torch.cuda.is_available(): + gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9 + gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9 + print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved") + + # Calculate and print estimated time remaining + elapsed_time = datetime.now() - training_start_time + time_per_epoch = elapsed_time.total_seconds() / (epoch + 1) + remaining_epochs = total_epochs - epoch_display + eta_seconds = time_per_epoch * remaining_epochs + eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0] + + print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}") + + # Save checkpoint after every epoch + checkpoint_file = output_file.replace(".pt", "_checkpoint.pt") + torch.save({ + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "scaler_state_dict": scaler.state_dict(), + "best_val_loss": best_val_loss, + }, checkpoint_file) + + if val_loss < best_val_loss: best_val_loss = val_loss best_model_state = model.state_dict().copy() 
epochs_without_improvement = 0 + # Save best model snapshot + snapshot_file = output_file.replace(".pt", "_best_snapshot.pt") + torch.save(best_model_state, snapshot_file) + print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})") else: epochs_without_improvement += 1 @@ -277,6 +356,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 break # Save best model + if best_model_state is None or best_val_loss >= checkpoint_val_loss: + print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})") + print("No new model created.") + return + if best_model_state is not None: # Determine final output file with versioning final_output_file = output_file @@ -302,7 +386,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4 "notes": "Win rate vs classical eval: TBD (requires benchmark games)" } - torch.save(best_model_state, final_output_file) + torch.save({ + "model_state_dict": best_model_state, + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "scaler_state_dict": scaler.state_dict(), + "epoch": epoch, + "best_val_loss": best_val_loss, + }, final_output_file) print(f"Best model saved to {final_output_file}") # Save metadata if versioning is enabled @@ -327,18 +418,20 @@ if __name__ == "__main__": help="Output file base name (default: nnue_weights.pt)") parser.add_argument("--checkpoint", type=str, default=None, help="Path to checkpoint file to resume training from (optional)") - parser.add_argument("--epochs", type=int, default=20, - help="Number of epochs to train (default: 20)") - parser.add_argument("--batch-size", type=int, default=4096, - help="Batch size (default: 4096)") - parser.add_argument("--lr", type=float, default=1e-3, - help="Learning rate (default: 1e-3)") + parser.add_argument("--epochs", type=int, default=100, + help="Number of epochs to train (default: 100)") + 
parser.add_argument("--batch-size", type=int, default=16384, + help="Batch size (default: 16384)") + parser.add_argument("--lr", type=float, default=0.001, + help="Learning rate (default: 0.001)") parser.add_argument("--early-stopping", type=int, default=None, help="Stop if val loss doesn't improve for N epochs (optional)") parser.add_argument("--stockfish-depth", type=int, default=12, help="Stockfish depth used for evaluations (for metadata, default: 12)") parser.add_argument("--no-versioning", action="store_true", help="Disable automatic versioning (save directly to output file)") + parser.add_argument("--weight-decay", type=float, default=1e-4, + help="L2 regularization strength (default: 1e-4, helps prevent overfitting)") args = parser.parse_args() @@ -351,5 +444,6 @@ if __name__ == "__main__": checkpoint=args.checkpoint, stockfish_depth=args.stockfish_depth, use_versioning=not args.no_versioning, - early_stopping_patience=args.early_stopping + early_stopping_patience=args.early_stopping, + weight_decay=args.weight_decay ) diff --git a/modules/bot/python/start.ps1 b/modules/bot/python/start.ps1 new file mode 100644 index 0000000..95122e9 --- /dev/null +++ b/modules/bot/python/start.ps1 @@ -0,0 +1,30 @@ +# Setup and run NNUE training pipeline + +$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +$VenvDir = Join-Path $ScriptDir ".venv" + +# Check if virtual environment exists +if (-not (Test-Path $VenvDir)) { + Write-Host "Creating virtual environment..." + python -m venv $VenvDir + if ($LASTEXITCODE -ne 0) { + Write-Host "Error: Failed to create virtual environment. Make sure python is installed." + exit 1 + } +} + +# Activate virtual environment +Write-Host "Activating virtual environment..." 
+$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1" +& $ActivateScript + +# Install/update dependencies if requirements.txt exists +$RequirementsFile = Join-Path $ScriptDir "requirements.txt" +if (Test-Path $RequirementsFile) { + Write-Host "Installing dependencies..." + pip install -q -r $RequirementsFile +} + +# Run nnue.py +Write-Host "Starting NNUE Training Pipeline..." +python (Join-Path $ScriptDir "nnue.py") diff --git a/modules/bot/python/start.sh b/modules/bot/python/start.sh new file mode 100644 index 0000000..011a905 --- /dev/null +++ b/modules/bot/python/start.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Setup and run NNUE training pipeline + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VENV_DIR="$SCRIPT_DIR/.venv" + +# Check if virtual environment exists +if [ ! -d "$VENV_DIR" ]; then + echo "Creating virtual environment..." + python3 -m venv "$VENV_DIR" + if [ $? -ne 0 ]; then + echo "Error: Failed to create virtual environment. Make sure python3 is installed." + exit 1 + fi +fi + +# Activate virtual environment +echo "Activating virtual environment..." +source "$VENV_DIR/bin/activate" + +# Install/update dependencies if requirements.txt exists +if [ -f "$SCRIPT_DIR/requirements.txt" ]; then + echo "Installing dependencies..." + pip install -q -r "$SCRIPT_DIR/requirements.txt" +fi + +# Run nnue.py +echo "Starting NNUE Training Pipeline..." +python "$SCRIPT_DIR/nnue.py" diff --git a/modules/bot/src/main/scala/de/nowchess/bot/Config.scala b/modules/bot/src/main/scala/de/nowchess/bot/Config.scala new file mode 100644 index 0000000..2cde199 --- /dev/null +++ b/modules/bot/src/main/scala/de/nowchess/bot/Config.scala @@ -0,0 +1,7 @@ +package de.nowchess.bot + +object Config: + + /** Threshold in centipawns: if classical evaluation differs from NNUE by more than this, + * the move is vetoed (not accepted as a suggestion). 
*/ + val VETO_THRESHOLD: Int = 100 diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala new file mode 100644 index 0000000..65ab115 --- /dev/null +++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala @@ -0,0 +1,23 @@ +package de.nowchess.bot.bots + +import de.nowchess.api.game.GameContext +import de.nowchess.api.move.Move +import de.nowchess.bot.logic.HybridSearch +import de.nowchess.bot.util.PolyglotBook +import de.nowchess.bot.{Bot, BotDifficulty} +import de.nowchess.rules.RuleSet +import de.nowchess.rules.sets.DefaultRules + +final class HybridBot( + difficulty: BotDifficulty, + rules: RuleSet = DefaultRules, + book: Option[PolyglotBook] = None +) extends Bot: + + private val search: HybridSearch = HybridSearch(rules) + + override val name: String = s"HybridBot(${difficulty.toString})" + + override def nextMove(context: GameContext): Option[Move] = + book.flatMap(_.probe(context)) + .orElse(search.bestMove(context)) diff --git a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala index 5a912d3..fc1746e 100644 --- a/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala +++ b/modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala @@ -7,9 +7,9 @@ import java.nio.ByteOrder class NNUE: - private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias) = loadWeights() + private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights() - private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) = + private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) = val stream = getClass.getResourceAsStream("/nnue_weights.bin") if stream == null then throw 
RuntimeException("NNUE weights file not found in resources") @@ -35,8 +35,12 @@ class NNUE: val l2b = readTensor(buffer) val l3w = readTensor(buffer) val l3b = readTensor(buffer) + val l4w = readTensor(buffer) + val l4b = readTensor(buffer) + val l5w = readTensor(buffer) + val l5b = readTensor(buffer) - (l1w, l1b, l2w, l2b, l3w, l3b) + (l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b) finally stream.close() private def readTensor(buffer: ByteBuffer): Array[Float] = @@ -55,10 +59,12 @@ class NNUE: floats(i) = buffer.getFloat() floats - // Pre-allocated buffers for inference + // Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1) private val features = new Array[Float](768) - private val l1Output = new Array[Float](256) - private val l2Output = new Array[Float](32) + private val l1Output = new Array[Float](1536) + private val l2Output = new Array[Float](1024) + private val l3Output = new Array[Float](512) + private val l4Output = new Array[Float](256) /** Convert a position to 768-dimensional binary feature vector. * 12 piece types (white pawn to black king) × 64 squares from white's perspective. */ @@ -110,28 +116,43 @@ class NNUE: /** Run NNUE inference on the given position. * Returns centipawn score from the perspective of the side-to-move. - * No allocations in the hot path (uses pre-allocated buffers). */ + * No allocations in the hot path (uses pre-allocated buffers). 
+ * Architecture: 768→1536→1024→512→256→1 */ def evaluate(context: GameContext): Int = val features = positionToFeatures(context.board, context.turn) - // Layer 1: Dense(768 -> 256) + ReLU - for i <- 0 until 256 do + // Layer 1: Dense(768 → 1536) + ReLU + for i <- 0 until 1536 do var sum = l1Bias(i) for j <- 0 until 768 do sum += features(j) * l1Weights(i * 768 + j) l1Output(i) = if sum > 0f then sum else 0f - // Layer 2: Dense(256 -> 32) + ReLU - for i <- 0 until 32 do + // Layer 2: Dense(1536 → 1024) + ReLU + for i <- 0 until 1024 do var sum = l2Bias(i) - for j <- 0 until 256 do - sum += l1Output(j) * l2Weights(i * 256 + j) + for j <- 0 until 1536 do + sum += l1Output(j) * l2Weights(i * 1536 + j) l2Output(i) = if sum > 0f then sum else 0f - // Layer 3: Dense(32 -> 1), no activation - var output = l3Bias(0) - for j <- 0 until 32 do - output += l2Output(j) * l3Weights(j) + // Layer 3: Dense(1024 → 512) + ReLU + for i <- 0 until 512 do + var sum = l3Bias(i) + for j <- 0 until 1024 do + sum += l2Output(j) * l3Weights(i * 1024 + j) + l3Output(i) = if sum > 0f then sum else 0f + + // Layer 4: Dense(512 → 256) + ReLU + for i <- 0 until 256 do + var sum = l4Bias(i) + for j <- 0 until 512 do + sum += l3Output(j) * l4Weights(i * 512 + j) + l4Output(i) = if sum > 0f then sum else 0f + + // Layer 5: Dense(256 → 1), no activation + var output = l5Bias(0) + for j <- 0 until 256 do + output += l4Output(j) * l5Weights(j) // Convert from tanh-normalized output back to centipawns // Training uses: eval_normalized = tanh(eval_cp / 300) @@ -145,3 +166,33 @@ class NNUE: (300f * atanh).toInt math.max(-20000, math.min(20000, cp)) + + /** Benchmark: time 1M evaluations and report ns/eval. + * This measures the performance of the inference on the starting position. 
*/ + def benchmark(): Unit = + val context = GameContext.initial + val iterations = 1_000_000 + + // Warm up + for _ <- 0 until 10000 do + evaluate(context) + + // Actual benchmark + val startNanos = System.nanoTime() + for _ <- 0 until iterations do + evaluate(context) + val endNanos = System.nanoTime() + + val totalNanos = endNanos - startNanos + val nanosPerEval = totalNanos.toDouble / iterations + + println() + println("=" * 60) + println("NNUE BENCHMARK RESULTS") + println("=" * 60) + println(f"Iterations: $iterations%,d") + println(f"Total time: ${totalNanos / 1e9}%.2f seconds") + println(f"ns/eval: $nanosPerEval%.2f ns") + println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s") + println("=" * 60) + println() diff --git a/modules/bot/src/main/scala/de/nowchess/bot/logic/HybridSearch.scala b/modules/bot/src/main/scala/de/nowchess/bot/logic/HybridSearch.scala new file mode 100644 index 0000000..af8a078 --- /dev/null +++ b/modules/bot/src/main/scala/de/nowchess/bot/logic/HybridSearch.scala @@ -0,0 +1,67 @@ +package de.nowchess.bot.logic + +import de.nowchess.api.game.GameContext +import de.nowchess.api.move.Move +import de.nowchess.bot.Config +import de.nowchess.bot.bots.classic.EvaluationClassic +import de.nowchess.bot.bots.nnue.EvaluationNNUE +import de.nowchess.rules.RuleSet +import de.nowchess.rules.sets.DefaultRules +import scala.util.boundary +import scala.util.boundary.break + +final class HybridSearch( + rules: RuleSet = DefaultRules +): + + private var vetoCount = 0 + private var approvalCount = 0 + private val TOP_MOVES_TO_VALIDATE = 10 + + /** Find the best move by scoring all legal moves with NNUE, then validating top 10 with classical eval. + * If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved. + * If all top 10 are vetoed, fall back to the best classical move overall. 
+ */ + def bestMove(context: GameContext): Option[Move] = + val legalMoves = rules.allLegalMoves(context) + if legalMoves.isEmpty then None else findBestMove(legalMoves, context) + + private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] = + // Score all moves with NNUE + val moveScores = legalMoves.map { move => + val nextContext = rules.applyMove(context)(move) + val nnueScore = EvaluationNNUE.evaluate(nextContext) + (move, nnueScore, nextContext) + } + + // Sort by NNUE score descending + val sortedByNNUE = moveScores.sortBy(_._2).reverse + + // Validate top N moves with classical evaluation + val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE) + + boundary: + for (move, nnueScore, nextContext) <- topMovesToCheck do + val classicalScore = EvaluationClassic.evaluate(nextContext) + val difference = (classicalScore - nnueScore).abs + if difference <= Config.VETO_THRESHOLD then + approvalCount += 1 + println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)") + break(Some(move)) + else + vetoCount += 1 + println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})") + + // All top 10 were vetoed, fall back to best classical move + println(s"[HybridSearch] All top 10 NNUE moves vetoed. 
Falling back to best classical move.") + val bestByClassical = moveScores + .map { case (move, _, nextContext) => + (move, EvaluationClassic.evaluate(nextContext)) + } + .maxBy(_._2) + + println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})") + println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount") + Some(bestByClassical._1) + + def getStats: (Int, Int) = (approvalCount, vetoCount) diff --git a/modules/ui/src/main/scala/de/nowchess/ui/Main.scala b/modules/ui/src/main/scala/de/nowchess/ui/Main.scala index 9b4cf02..48ff272 100644 --- a/modules/ui/src/main/scala/de/nowchess/ui/Main.scala +++ b/modules/ui/src/main/scala/de/nowchess/ui/Main.scala @@ -3,7 +3,7 @@ package de.nowchess.ui import de.nowchess.api.board.Color.Black import de.nowchess.bot.util.PolyglotBook import de.nowchess.bot.BotDifficulty -import de.nowchess.bot.bots.{ClassicalBot, NNUEBot} +import de.nowchess.bot.bots.{ClassicalBot, HybridBot, NNUEBot} import de.nowchess.chess.engine.GameEngine import de.nowchess.ui.terminal.TerminalUI import de.nowchess.ui.gui.ChessGUILauncher @@ -19,7 +19,7 @@ object Main: val book = PolyglotBook("../../modules/bot/codekiddy.bin") - engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black); + engine.setOpponentBot(HybridBot(BotDifficulty.Easy, book = Some(book)), Black); // Launch ScalaFX GUI in separate thread ChessGUILauncher.launch(engine)