feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction

2026-04-09 22:36:09 +02:00
parent 0bf1d52132
commit 5d4cf5f13c
16 changed files with 940 additions and 245 deletions
build.gradle.kts  +4
@@ -63,3 +63,7 @@ tasks.test {
 tasks.reportScoverage {
     dependsOn(tasks.test)
 }
+
+tasks.jar {
+    duplicatesStrategy = DuplicatesStrategy.EXCLUDE
+}
.gitignore  +2
@@ -17,3 +17,5 @@ ENV/
 .vscode/
 *.swp
 *.swo
+tactical_data/
+trainingdata/
nnue.py  +92 -9
@@ -2,6 +2,7 @@
"""Central NNUE pipeline TUI for training and exporting models.""" """Central NNUE pipeline TUI for training and exporting models."""
import os import os
import shutil
import sys import sys
from pathlib import Path from pathlib import Path
from rich.console import Console from rich.console import Console
@@ -17,6 +18,10 @@ from generate import play_random_game_and_collect_positions
 from label import label_positions_with_stockfish
 from train import train_nnue
 from export import export_weights_to_binary
+from tactical_positions_extractor import (
+    download_and_extract_puzzle_db,
+    interactive_merge_positions
+)
 
 def get_data_dir():
     """Get/create data directory."""
@@ -24,6 +29,12 @@ def get_data_dir():
     data_dir.mkdir(exist_ok=True)
     return data_dir
 
+def get_tactical_data_dir():
+    """Get/create tactical data directory."""
+    data_dir = Path(__file__).parent / "tactical_data"
+    data_dir.mkdir(exist_ok=True)
+    return data_dir
+
 def get_weights_dir():
     """Get/create weights directory."""
     weights_dir = Path(__file__).parent / "weights"
@@ -87,20 +98,23 @@ def show_main_menu():
console.print("\n[bold]What would you like to do?[/bold]") console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Train NNUE Model") console.print("[cyan]1[/cyan] - Train NNUE Model")
console.print("[cyan]2[/cyan] - Export Weights to Scala") console.print("[cyan]2[/cyan] - Export Weights to Scala")
console.print("[cyan]3[/cyan] - View Checkpoints") console.print("[cyan]3[/cyan] - Extract Tactical Positions")
console.print("[cyan]4[/cyan] - Exit") console.print("[cyan]4[/cyan] - View Checkpoints")
console.print("[cyan]5[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"]) choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
if choice == "1": if choice == "1":
train_interactive() train_interactive()
elif choice == "2": elif choice == "2":
export_interactive() export_interactive()
elif choice == "3": elif choice == "3":
extract_tactical_interactive()
elif choice == "4":
show_header() show_header()
show_checkpoints_table() show_checkpoints_table()
Prompt.ask("\nPress Enter to continue") Prompt.ask("\nPress Enter to continue")
elif choice == "4": elif choice == "5":
console.print("[yellow]👋 Goodbye![/yellow]") console.print("[yellow]👋 Goodbye![/yellow]")
return return
@@ -146,13 +160,18 @@ def train_interactive():
     if use_existing_labels:
         labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))
 
-    # Stockfish path
-    default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
+    # Stockfish path and labeling parameters
+    default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
     stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
+    stockfish_depth = 12
+    num_workers = 1
+    if not use_existing_labels:
+        stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+        num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
 
     # Training parameters
-    epochs = int(Prompt.ask("Number of epochs", default="20"))
-    batch_size = int(Prompt.ask("Batch size", default="4096"))
+    epochs = int(Prompt.ask("Number of epochs", default="100"))
+    batch_size = int(Prompt.ask("Batch size", default="16384"))
     early_stopping = None
     if Confirm.ask("Enable early stopping?", default=False):
         early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -175,6 +194,9 @@ def train_interactive():
console.print(f" Early stopping: Yes (patience: {early_stopping})") console.print(f" Early stopping: Yes (patience: {early_stopping})")
else: else:
console.print(f" Early stopping: No") console.print(f" Early stopping: No")
if not use_existing_labels:
console.print(f" Stockfish depth: {stockfish_depth}")
console.print(f" Workers: {num_workers}")
console.print(f" Stockfish: {stockfish_path}") console.print(f" Stockfish: {stockfish_path}")
if not Confirm.ask("\nStart training?", default=True): if not Confirm.ask("\nStart training?", default=True):
@@ -214,7 +236,8 @@ def train_interactive():
                 str(positions_file),
                 str(output_file),
                 stockfish_path,
-                depth=12
+                depth=stockfish_depth,
+                num_workers=num_workers
             )
             if not success:
                 console.print("[red]✗ Position labeling failed[/red]")
@@ -305,6 +328,66 @@ def export_interactive():
         traceback.print_exc()
         Prompt.ask("Press Enter to continue")
 
+def extract_tactical_interactive():
+    """Interactive tactical positions extraction and merge menu."""
+    console = Console()
+    show_header()
+    console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
+
+    # Download and extract options
+    console.print("\n[bold]Lichess Puzzle Database:[/bold]")
+    download_url = Prompt.ask(
+        "Download URL",
+        default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
+    )
+    output_dir = Prompt.ask(
+        "Extract to directory",
+        default=str(Path(__file__).parent / "trainingdata")
+    )
+    max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
+
+    # Confirm and download
+    console.print("\n[bold]Configuration:[/bold]")
+    console.print(f"  Download URL: {download_url}")
+    console.print(f"  Extract directory: {output_dir}")
+    console.print(f"  Max puzzles: {max_puzzles:,}")
+
+    if not Confirm.ask("\nProceed?", default=True):
+        console.print("[yellow]Cancelled[/yellow]")
+        Prompt.ask("Press Enter to continue")
+        return
+
+    try:
+        console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
+        csv_path = download_and_extract_puzzle_db(download_url, output_dir)
+        if not csv_path:
+            console.print("[red]✗ Failed to download/extract[/red]")
+            Prompt.ask("Press Enter to continue")
+            return
+        console.print(f"[green]✓ Ready: {csv_path}[/green]")
+
+        # Interactive merge
+        console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
+        output_file = Prompt.ask(
+            "Output file path",
+            default=str(Path(__file__).parent / "data" / "position.txt")
+        )
+        interactive_merge_positions(csv_path, output_file, max_puzzles)
+        console.print(f"\n[green]✓ Complete![/green]")
+        Prompt.ask("Press Enter to continue")
+    except Exception as e:
+        console.print(f"[red]✗ Error: {e}[/red]")
+        import traceback
+        traceback.print_exc()
+        Prompt.ask("Press Enter to continue")
 
 def main():
     try:
         show_main_menu()
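For orientation, the TUI's menu options map onto the same functions the module imports at the top, so the pipeline can also be scripted; a minimal non-interactive sketch (paths and parameter values below are illustrative, not part of this commit):

# Sketch: driving the pipeline without the TUI, using the functions nnue.py imports.
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary

play_random_game_and_collect_positions("positions.txt", total_positions=100_000, num_workers=4)
label_positions_with_stockfish("positions.txt", "training_data.jsonl",
                               "/usr/bin/stockfish", depth=12, num_workers=4)
train_nnue("training_data.jsonl", output_file="nnue_weights.pt",
           epochs=100, batch_size=16384)
export_weights_to_binary("weights/nnue_weights_v1.pt", "nnue_weights.bin")  # illustrative paths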
requirements.txt  +2 -1
@@ -2,4 +2,5 @@ chess==1.11.2
 torch==2.11.0
 tqdm==4.67.3
 numpy==2.4.4
 rich==13.7.0
+zstandard==0.23.0
export.py  +4 -3
@@ -13,8 +13,9 @@ def export_weights_to_binary(weights_file, output_file):
print(f"Error: Weights file not found at {weights_file}") print(f"Error: Weights file not found at {weights_file}")
sys.exit(1) sys.exit(1)
# Load weights # Load weights — handle both raw state dicts and full training checkpoints
state_dict = torch.load(weights_file, map_location='cpu') loaded = torch.load(weights_file, map_location='cpu')
state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded
# Debug: print available layers # Debug: print available layers
print(f"Available layers in {weights_file}:") print(f"Available layers in {weights_file}:")
@@ -31,7 +32,7 @@ def export_weights_to_binary(weights_file, output_file):
         f.write(struct.pack('<I', 1))  # version 1
 
         # Write each weight tensor in order
-        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias']:
+        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias']:
             if layer_name not in state_dict:
                 print(f"Error: Missing layer {layer_name}")
                 sys.exit(1)
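The writer above emits a little-endian uint32 version header followed by the ten tensors; a read-back sketch for sanity checks, under the assumption (not confirmed by this diff) that each tensor is stored as a uint32 element count followed by that many float32 values:

# Sketch: read back nnue_weights.bin written by export_weights_to_binary.
# ASSUMPTION: each tensor is <uint32 count><count x float32>, little-endian;
# only the uint32 version header is visible in this diff.
import struct

def read_weights(path="nnue_weights.bin"):
    tensors = []
    with open(path, "rb") as f:
        (version,) = struct.unpack("<I", f.read(4))  # version 1 header
        assert version == 1
        for _ in range(10):  # l1..l5, weight then bias
            (count,) = struct.unpack("<I", f.read(4))
            tensors.append(struct.unpack(f"<{count}f", f.read(4 * count)))
    return tensors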
generate.py  +106 -67
@@ -1,100 +1,136 @@
 #!/usr/bin/env python3
-"""Generate random chess positions for NNUE training with minimal filtering."""
+"""Generate random chess positions for NNUE training with multiprocessing."""
 
 import chess
 import random
 import sys
 from pathlib import Path
 from tqdm import tqdm
+from multiprocessing import Pool, Queue
+from datetime import datetime
+import time
+
+
+def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
+    """Generate games for one worker.
+
+    Returns:
+        list of FENs generated by this worker
+    """
+    positions = []
+    for game_num in range(games_per_worker):
+        board = chess.Board()
+        move_history = []
+
+        # Play a complete random game
+        while not board.is_game_over() and len(move_history) < 200:
+            legal_moves = list(board.legal_moves)
+            if not legal_moves:
+                break
+            move = random.choice(legal_moves)
+            board.push(move)
+            move_history.append(board.copy())
+
+        # Determine the range of moves to sample from
+        game_length = len(move_history)
+        valid_start = max(min_move, 0)
+        valid_end = min(max_move, game_length)
+        if valid_start >= valid_end:
+            continue
+
+        # Randomly sample positions from this game
+        sample_count = min(samples_per_game, valid_end - valid_start)
+        if sample_count > 0:
+            sample_indices = random.sample(
+                range(valid_start, valid_end),
+                k=sample_count
+            )
+            for idx in sample_indices:
+                sampled_board = move_history[idx]
+                # Only filter truly invalid or terminal positions
+                if not sampled_board.is_valid() or sampled_board.is_game_over():
+                    continue
+                # Save position (include check, captures, all positions)
+                fen = sampled_board.fen()
+                positions.append(fen)
+    return positions
+
+
 def play_random_game_and_collect_positions(
     output_file,
-    total_games=500000,
+    total_positions=3000000,
     samples_per_game=1,
     min_move=1,
-    max_move=50
+    max_move=50,
+    num_workers=8
 ):
-    """Play random games and sample multiple positions from each.
+    """Generate positions using multiprocessing with multiple workers.
 
     Args:
         output_file: Output file for positions
-        total_games: Number of games to play
+        total_positions: Target number of positions to generate
         samples_per_game: Number of positions to sample per game (1-N)
         min_move: Minimum move number to start sampling from
         max_move: Maximum move number for sampling
+        num_workers: Number of parallel worker processes
 
     Returns:
         Number of valid positions saved
     """
+    # Estimate games needed (roughly 1 position per game on average)
+    total_games = max(total_positions // samples_per_game, num_workers)
+    games_per_worker = total_games // num_workers
+
+    print(f"Generating {total_positions:,} positions using {num_workers} workers")
+    print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
+    print()
+
+    start_time = datetime.now()
+
+    # Generate positions in parallel
+    worker_tasks = [
+        (i, games_per_worker, samples_per_game, min_move, max_move)
+        for i in range(num_workers)
+    ]
 
     positions_count = 0
-    filtered_game_over = 0
-    filtered_illegal = 0
+    all_positions = []
 
-    with open(output_file, 'w') as f:
-        with tqdm(total=total_games, desc="Generating positions") as pbar:
-            for game_num in range(total_games):
-                board = chess.Board()
-                move_history = []
-
-                # Play a complete random game
-                while not board.is_game_over() and len(move_history) < 200:
-                    legal_moves = list(board.legal_moves)
-                    if not legal_moves:
-                        break
-                    move = random.choice(legal_moves)
-                    board.push(move)
-                    move_history.append(board.copy())
-
-                # Determine the range of moves to sample from
-                game_length = len(move_history)
-                valid_start = max(min_move, 0)
-                valid_end = min(max_move, game_length)
-                if valid_start >= valid_end:
-                    # Game too short
-                    pbar.update(1)
-                    continue
-
-                # Randomly sample positions from this game
-                sample_count = min(samples_per_game, valid_end - valid_start)
-                if sample_count > 0:
-                    sample_indices = random.sample(
-                        range(valid_start, valid_end),
-                        k=sample_count
-                    )
-                    for idx in sample_indices:
-                        sampled_board = move_history[idx]
-                        # Only filter truly invalid or terminal positions
-                        if not sampled_board.is_valid():
-                            filtered_illegal += 1
-                            continue
-                        if sampled_board.is_game_over():
-                            filtered_game_over += 1
-                            continue
-                        # Save position (include check, captures, all positions)
-                        fen = sampled_board.fen()
-                        f.write(fen + '\n')
-                        positions_count += 1
-
+    with Pool(num_workers) as pool:
+        with tqdm(total=num_workers, desc="Workers generating games") as pbar:
+            for positions in pool.starmap(_worker_generate_games, worker_tasks):
+                all_positions.extend(positions)
+                positions_count += len(positions)
                 pbar.update(1)
+
+    # Write all positions to file
+    print(f"Writing {positions_count:,} positions to {output_file}...")
+    with open(output_file, 'w') as f:
+        for fen in all_positions:
+            f.write(fen + '\n')
+
+    elapsed_time = datetime.now() - start_time
+    elapsed_seconds = elapsed_time.total_seconds()
+    positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
 
     # Print summary
     print()
     print("=" * 60)
     print("POSITION GENERATION SUMMARY")
     print("=" * 60)
-    print(f"Total games played: {total_games}")
+    print(f"Target positions: {total_positions:,}")
+    print(f"Actual positions saved: {positions_count:,}")
+    print(f"Workers: {num_workers}")
+    print(f"Games per worker: {games_per_worker:,}")
     print(f"Samples per game: {samples_per_game}")
     print(f"Move range: {min_move}-{max_move}")
-    print(f"Saved positions: {positions_count}")
-    print(f"Filtered (game over): {filtered_game_over}")
-    print(f"Filtered (illegal): {filtered_illegal}")
-    print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
-    print(f"(Includes checks, captures, all realistic positions)")
+    print(f"Elapsed time: {elapsed_time}")
+    print(f"Throughput: {positions_per_second:.0f} positions/second")
     print("=" * 60)
     print()
@@ -110,23 +146,26 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training") parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
parser.add_argument("output_file", nargs="?", default="positions.txt", parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)") help="Output file for positions (default: positions.txt)")
parser.add_argument("--games", type=int, default=5000, parser.add_argument("--positions", type=int, default=3000000,
help="Number of games to play (default: 5000)") help="Target number of positions to generate (default: 3000000)")
parser.add_argument("--samples-per-game", type=int, default=1, parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)") help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1, parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)") help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50, parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)") help="Maximum move number to sample from (default: 50)")
parser.add_argument("--workers", type=int, default=8,
help="Number of parallel worker processes (default: 8)")
args = parser.parse_args() args = parser.parse_args()
count = play_random_game_and_collect_positions( count = play_random_game_and_collect_positions(
output_file=args.output_file, output_file=args.output_file,
total_games=args.games, total_positions=args.positions,
samples_per_game=args.samples_per_game, samples_per_game=args.samples_per_game,
min_move=args.min_move, min_move=args.min_move,
max_move=args.max_move max_move=args.max_move,
num_workers=args.workers
) )
sys.exit(0 if count > 0 else 1) sys.exit(0 if count > 0 else 1)
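A generated file can be validated before it goes to the labeler; a small sanity-check sketch (file name illustrative):

# Sketch: sanity-check a generated positions file with python-chess.
import chess

with open("positions.txt") as f:
    fens = [line.strip() for line in f if line.strip()]
bad = [fen for fen in fens if not chess.Board(fen).is_valid()]
print(f"{len(fens)} positions, {len(bad)} invalid")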
label.py  +129 -89
@@ -8,6 +8,8 @@ import os
 import numpy as np
 from pathlib import Path
 from tqdm import tqdm
+from multiprocessing import Pool
+from functools import partial
 
 def normalize_evaluation(cp_value, method='tanh', scale=300.0):
     """Normalize centipawn evaluation to a bounded range.
@@ -27,17 +29,68 @@ def normalize_evaluation(cp_value, method='tanh', scale=300.0):
     else:
         return cp_value / 100.0
 
-def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True):
+def _evaluate_fen_batch(args):
+    """Worker function: evaluate a batch of FENs with a dedicated Stockfish engine.
+
+    Args:
+        args: tuple of (fens, stockfish_path, depth, normalize)
+
+    Returns:
+        list of (fen, eval_normalized, eval_raw) tuples
+    """
+    fens, stockfish_path, depth, normalize = args
+    results = []
+    try:
+        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
+    except Exception:
+        return []
+    try:
+        for fen in fens:
+            try:
+                board = chess.Board(fen)
+                if not board.is_valid():
+                    continue
+                info = engine.analyse(board, chess.engine.Limit(depth=depth))
+                if info.get('score') is None:
+                    continue
+                score = info['score'].white()
+                if score.is_mate():
+                    eval_cp = 2000 if score.mate() > 0 else -2000
+                else:
+                    eval_cp = score.cp
+                eval_cp = max(-2000, min(2000, eval_cp))
+                eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
+                results.append((fen, eval_normalized, eval_cp))
+            except Exception:
+                continue
+    finally:
+        engine.quit()
+    return results
+
+def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
     """Read positions and label them with Stockfish evaluations.
 
     Args:
         positions_file: Path to positions.txt
         output_file: Path to training_data.jsonl
         stockfish_path: Path to stockfish binary
-        batch_size: Batch size (not used, kept for compatibility)
+        batch_size: Batch size for processing (positions per worker task, default: 1000)
         depth: Stockfish depth
         verbose: Print detailed error messages
         normalize: If True, normalize evals using tanh
+        num_workers: Number of parallel Stockfish processes
     """
 
     # Check if stockfish exists
@@ -48,6 +101,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
         sys.exit(1)
 
     print(f"Using Stockfish: {stockfish_path}")
+    print(f"Number of workers: {num_workers}")
 
     # Check if positions file exists
     if not Path(positions_file).exists():
@@ -69,107 +123,87 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
                 pass
         print(f"Resuming from {position_count} already evaluated positions")
 
-    # Count total positions
-    with open(positions_file, 'r') as f:
-        total_lines = sum(1 for _ in f)
-
-    if total_lines == 0:
-        print(f"Error: Positions file is empty ({positions_file})")
-        sys.exit(1)
+    # Load all FENs that need evaluation
+    fens_to_evaluate = []
+    skipped_invalid = 0
+    skipped_duplicate = 0
+    with open(positions_file, 'r') as f:
+        for fen in f:
+            fen = fen.strip()
+            if not fen:
+                skipped_invalid += 1
+                continue
+            if fen in evaluated_fens:
+                skipped_duplicate += 1
+                continue
+            fens_to_evaluate.append(fen)
+
+    total_to_evaluate = len(fens_to_evaluate)
+    total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
+
+    if total_to_evaluate == 0:
+        if position_count == 0:
+            print(f"Error: No valid positions to evaluate in {positions_file}")
+            sys.exit(1)
+        else:
+            print(f"All positions already evaluated. No new positions to process.")
+            return True
 
     print(f"Total positions to process: {total_lines}")
+    print(f"New positions to evaluate: {total_to_evaluate}")
     print(f"Using depth: {depth}")
     print()
 
-    # Initialize engine
-    try:
-        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
-    except Exception as e:
-        print(f"Error: Could not start Stockfish engine")
-        print(f"  Stockfish path: {stockfish_path}")
-        print(f"  Error: {e}")
-        sys.exit(1)
+    # Split FENs into batches for workers
+    batches = []
+    for i in range(0, total_to_evaluate, batch_size):
+        batch = fens_to_evaluate[i:i+batch_size]
+        batches.append((batch, stockfish_path, depth, normalize))
 
-    # Track statistics
+    # Process batches in parallel
     evaluated = 0
-    skipped_invalid = 0
-    skipped_duplicate = 0
     errors = 0
     raw_evals = []
     normalized_evals = []
 
-    try:
-        with open(positions_file, 'r') as f:
-            with open(output_file, 'a') as out:
-                with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
-                    for fen in f:
-                        fen = fen.strip()
-
-                        # Skip empty lines
-                        if not fen:
-                            skipped_invalid += 1
-                            pbar.update(1)
-                            continue
-
-                        # Skip already evaluated
-                        if fen in evaluated_fens:
-                            skipped_duplicate += 1
-                            pbar.update(1)
-                            continue
-
-                        try:
-                            # Validate FEN
-                            board = chess.Board(fen)
-                            if not board.is_valid():
-                                skipped_invalid += 1
-                                pbar.update(1)
-                                continue
-
-                            # Evaluate at specified depth
-                            info = engine.analyse(board, chess.engine.Limit(depth=depth))
-                            if info.get('score') is None:
-                                skipped_invalid += 1
-                                pbar.update(1)
-                                continue
-
-                            score = info['score'].white()
-
-                            # Convert to centipawns
-                            if score.is_mate():
-                                # Use large values for mate scores
-                                eval_cp = 2000 if score.mate() > 0 else -2000
-                            else:
-                                eval_cp = score.cp
-
-                            # Clamp to [-2000, 2000]
-                            eval_cp = max(-2000, min(2000, eval_cp))
-
-                            # Normalize evaluation
-                            eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
-
-                            # Track statistics
-                            raw_evals.append(eval_cp)
-                            normalized_evals.append(eval_normalized)
-
-                            # Save evaluation (normalized if requested)
-                            data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
-                            out.write(json.dumps(data) + '\n')
-                            out.flush()  # Force write to disk
-                            evaluated += 1
-
-                        except Exception as e:
-                            errors += 1
-                            if verbose:
-                                print(f"Error evaluating position: {fen[:50]}...")
-                                print(f"  {type(e).__name__}: {e}")
-                            pbar.update(1)
-                            continue
-
-                        pbar.update(1)
-    finally:
-        engine.quit()
+    import time
+    start_time = time.time()
+    with Pool(num_workers) as pool:
+        with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
+            with open(output_file, 'a') as out:
+                # imap (ordered) so batch_idx below lines up with batches[batch_idx]
+                for batch_idx, batch_results in enumerate(pool.imap(_evaluate_fen_batch, batches)):
+                    for fen, eval_normalized, eval_cp in batch_results:
+                        data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
+                        out.write(json.dumps(data) + '\n')
+                        evaluated += 1
+                        raw_evals.append(eval_cp)
+                        normalized_evals.append(eval_normalized)
+                        pbar.update(1)
+
+                    # Update progress for any failed evaluations in the batch
+                    batch_size_actual = len(batches[batch_idx][0])
+                    failed = batch_size_actual - len(batch_results)
+                    if failed > 0:
+                        errors += failed
+                        pbar.update(failed)
+
+                    # Calculate and show throughput and ETA
+                    elapsed = time.time() - start_time
+                    throughput = evaluated / elapsed if elapsed > 0 else 0
+                    remaining_positions = total_to_evaluate - evaluated
+                    eta_seconds = remaining_positions / throughput if throughput > 0 else 0
+                    eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
+                    if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
+                        pbar.set_postfix({
+                            'rate': f'{throughput:.0f} pos/s',
+                            'eta': eta_str
+                        })
 
     # Print summary and analysis
     print()
@@ -253,10 +287,14 @@ if __name__ == "__main__":
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')") help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12, parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)") help="Stockfish depth (default: 12)")
parser.add_argument("--batch-size", type=int, default=20,
help="Batch size for processing (default: 1000)")
parser.add_argument("--no-normalize", action="store_true", parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)") help="Disable evaluation normalization (keep raw centipawns)")
parser.add_argument("--verbose", action="store_true", parser.add_argument("--verbose", action="store_true",
help="Print detailed error messages") help="Print detailed error messages")
parser.add_argument("--workers", type=int, default=1,
help="Number of parallel Stockfish processes (default: 1)")
args = parser.parse_args() args = parser.parse_args()
@@ -267,9 +305,11 @@ if __name__ == "__main__":
         positions_file=args.positions_file,
         output_file=args.output_file,
         stockfish_path=stockfish_path,
+        batch_size=args.batch_size,
         depth=args.depth,
         normalize=not args.no_normalize,
-        verbose=args.verbose
+        verbose=args.verbose,
+        num_workers=args.workers
     )
     sys.exit(0 if success else 1)
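The stored "eval" field is the tanh-normalized score that the Scala side later inverts with 300 * atanh; a round-trip sketch of that transform, assuming the default scale of 300:

# Sketch: round-trip of the tanh normalization used for labels.
# label.py stores y = tanh(cp / 300); NNUE.scala inverts with cp = 300 * atanh(y).
import math

for cp in (-2000, -300, 0, 150, 2000):
    y = math.tanh(cp / 300.0)
    back = 300.0 * math.atanh(max(-0.999999, min(0.999999, y)))  # clamp before atanh
    print(f"cp={cp:6d} -> y={y:+.4f} -> back={back:8.1f}")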
tactical_positions_extractor.py  (new file)
@@ -0,0 +1,224 @@
+import chess
+import csv
+import json
+import sys
+import urllib.request
+from pathlib import Path
+from typing import Set, Tuple
+
+try:
+    import zstandard as zstd
+except ImportError:
+    print("zstandard library not found. Install with: pip install zstandard")
+    sys.exit(1)
+
+from generate import play_random_game_and_collect_positions
+
+
+def download_and_extract_puzzle_db(
+    url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
+    output_dir: str = 'trainingdata'
+):
+    """Download and extract the Lichess puzzle database."""
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    csv_file = output_path / 'lichess_db_puzzle.csv'
+    zst_file = output_path / 'lichess_db_puzzle.csv.zst'
+
+    # Download if not already present
+    if not zst_file.exists():
+        print(f"Downloading puzzle database from {url}...")
+        try:
+            urllib.request.urlretrieve(url, zst_file)
+            print(f"Downloaded to {zst_file}")
+        except Exception as e:
+            print(f"Failed to download: {e}")
+            return None
+
+    # Extract if CSV doesn't exist
+    if not csv_file.exists():
+        print(f"Extracting {zst_file}...")
+        try:
+            with open(zst_file, 'rb') as f:
+                dctx = zstd.ZstdDecompressor()
+                with dctx.stream_reader(f) as reader:
+                    with open(csv_file, 'wb') as out:
+                        out.write(reader.read())
+            print(f"Extracted to {csv_file}")
+        except Exception as e:
+            print(f"Failed to extract: {e}")
+            return None
+
+    return str(csv_file)
+
+
+def extract_puzzle_positions(
+    puzzle_csv: str,
+    max_puzzles: int = 300_000
+) -> Set[str]:
+    """
+    Extract, from each puzzle, the position in which the tactic must be
+    found (the position reached after the opponent's losing move).
+    This is exactly the type of position where tactical
+    recognition matters most.
+
+    Returns a set of unique FENs.
+    """
+    positions = set()
+    with open(puzzle_csv) as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            if len(positions) >= max_puzzles:
+                break
+            try:
+                board = chess.Board(row['FEN'])
+                # The CSV FEN is the position BEFORE the opponent's losing
+                # move; moves[0] is that move. Applying it yields the
+                # position where the engine has to find the tactic.
+                moves = row['Moves'].split()
+                board.push_uci(moves[0])  # apply the opponent's blunder
+                fen = board.fen()
+
+                # Filter for useful tactical themes
+                themes = row.get('Themes', '')
+                useful = any(t in themes for t in [
+                    'fork', 'pin', 'skewer', 'discoveredAttack',
+                    'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
+                    'trappedPiece', 'sacrifice'
+                ])
+                if useful:
+                    positions.add(fen)
+            except Exception:
+                continue
+    return positions
+
+
+def load_positions_from_file(file_path: str) -> Set[str]:
+    """Load positions from a text file (one FEN per line)."""
+    positions = set()
+    try:
+        with open(file_path) as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    positions.add(line)
+        print(f"Loaded {len(positions)} positions from {file_path}")
+        return positions
+    except Exception as e:
+        print(f"Failed to load from {file_path}: {e}")
+        return set()
+
+
+def merge_positions(
+    tactical: Set[str],
+    other: Set[str],
+    output_file: str = 'position.txt'
+):
+    """Merge two position sets and write to file."""
+    merged = tactical | other
+    with open(output_file, 'w') as f:
+        for fen in merged:
+            f.write(fen + '\n')
+
+    overlap = len(tactical & other)
+    print(f"\n{'='*60}")
+    print(f"MERGE SUMMARY")
+    print(f"{'='*60}")
+    print(f"Tactical positions: {len(tactical):,}")
+    print(f"Other positions: {len(other):,}")
+    print(f"Overlap (deduplicated): {overlap:,}")
+    print(f"Total merged positions: {len(merged):,}")
+    print(f"Written to: {output_file}")
+    print(f"{'='*60}\n")
+
+
+def interactive_merge_positions(
+    puzzle_csv: str,
+    output_file: str = 'position.txt',
+    max_puzzles: int = 300_000
+):
+    """Interactive workflow: extract tactical positions and merge with user selection."""
+    print("\n" + "="*60)
+    print("TACTICAL POSITION EXTRACTOR & MERGER")
+    print("="*60 + "\n")
+
+    # Extract tactical positions
+    print("Extracting tactical positions from puzzle database...")
+    tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
+    print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
+
+    # Ask what to merge with
+    print("What would you like to merge with these tactical positions?")
+    print("1. Load from a position file")
+    print("2. Generate random positions")
+    print("3. Skip merging (save tactical only)")
+
+    choice = input("\nEnter choice (1-3): ").strip()
+    other_positions = set()
+
+    if choice == '1':
+        file_path = input("Enter path to position file: ").strip()
+        other_positions = load_positions_from_file(file_path)
+    elif choice == '2':
+        positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
+        try:
+            positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
+        except ValueError:
+            positions_to_gen = 1000000
+        temp_file = 'temp_generated_positions.txt'
+        print(f"\nGenerating {positions_to_gen:,} random positions...")
+        play_random_game_and_collect_positions(
+            output_file=temp_file,
+            total_positions=positions_to_gen,
+            samples_per_game=1,
+            min_move=1,
+            max_move=50,
+            num_workers=8
+        )
+        other_positions = load_positions_from_file(temp_file)
+    elif choice == '3':
+        print("Skipping merge, saving tactical positions only...")
+    else:
+        print("Invalid choice, saving tactical positions only...")
+
+    merge_positions(tactical_positions, other_positions, output_file)
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
+    parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
+                        help="URL to download puzzle database from")
+    parser.add_argument("--output-dir", default='trainingdata',
+                        help="Directory to extract puzzle database to")
+    parser.add_argument("--max-puzzles", type=int, default=300_000,
+                        help="Maximum puzzles to extract (default: 300000)")
+    parser.add_argument("--output-file", default='position.txt',
+                        help="Output file for merged positions (default: position.txt)")
+    args = parser.parse_args()
+
+    # Download and extract
+    csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
+    if csv_path:
+        # Interactive merge
+        interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
+    else:
+        print("Failed to download/extract puzzle database")
+        sys.exit(1)
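Note that the theme filter matches substrings of the whole Themes field, so the 'mateIn2'/'mateIn3' entries are already covered by 'mate'; a small demonstration (theme strings illustrative):

# Sketch: the theme filter is substring-based, so 'mate' already covers 'mateIn2'.
USEFUL = ['fork', 'pin', 'skewer', 'discoveredAttack',
          'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
          'trappedPiece', 'sacrifice']

for themes in ("mateIn2 short", "endgame fork", "advantage opening"):
    print(themes, "->", any(t in themes for t in USEFUL))
# mateIn2 short -> True (via 'mate'); endgame fork -> True; advantage opening -> False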
train.py  +151 -57
@@ -87,24 +87,34 @@ def fen_to_features(fen):
     return features
 
 class NNUE(nn.Module):
-    """NNUE neural network architecture."""
+    """NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization."""
 
-    def __init__(self):
+    def __init__(self, dropout_rate=0.2):
         super().__init__()
-        self.l1 = nn.Linear(768, 256)
+        self.l1 = nn.Linear(768, 1536)
         self.relu1 = nn.ReLU()
-        self.l2 = nn.Linear(256, 32)
+        self.drop1 = nn.Dropout(dropout_rate)
+        self.l2 = nn.Linear(1536, 1024)
         self.relu2 = nn.ReLU()
-        self.l3 = nn.Linear(32, 1)
-        self.sigmoid = nn.Sigmoid()
+        self.drop2 = nn.Dropout(dropout_rate)
+        self.l3 = nn.Linear(1024, 512)
+        self.relu3 = nn.ReLU()
+        self.drop3 = nn.Dropout(dropout_rate)
+        self.l4 = nn.Linear(512, 256)
+        self.relu4 = nn.ReLU()
+        self.drop4 = nn.Dropout(dropout_rate)
+        self.l5 = nn.Linear(256, 1)
 
     def forward(self, x):
-        x = self.l1(x)
-        x = self.relu1(x)
-        x = self.l2(x)
-        x = self.relu2(x)
-        x = self.l3(x)
-        return x
+        x = self.drop1(self.relu1(self.l1(x)))
+        x = self.drop2(self.relu2(self.l2(x)))
+        x = self.drop3(self.relu3(self.l3(x)))
+        x = self.drop4(self.relu4(self.l4(x)))
+        return self.l5(x)
 
 def find_next_version(base_name="nnue_weights"):
     """Find the next version number for model versioning.
@@ -142,21 +152,32 @@ def save_metadata(weights_file, metadata):
     return metadata_file
 
-def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
-    """Train the NNUE model.
+def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4):
+    """Train the NNUE model with GPU optimizations and automatic mixed precision.
 
     Args:
         data_file: Path to training_data.jsonl
         output_file: Where to save best weights (or base name if use_versioning=True)
-        epochs: Number of training epochs
-        batch_size: Training batch size
-        lr: Learning rate
+        epochs: Number of training epochs (default: 100)
+        batch_size: Training batch size (default: 16384)
+        lr: Learning rate (default: 0.001)
         checkpoint: Optional path to checkpoint file to resume from
         stockfish_depth: Depth used in Stockfish evaluation (for metadata)
         use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
         early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
+        weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
     """
+    print("Checking GPU availability...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
+        print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+    else:
+        print("⚠ GPU not available, using CPU")
+    print(f"Using device: {device}")
+    print()
+
     print("Loading dataset...")
     dataset = NNUEDataset(data_file)
     num_positions = len(dataset)
@@ -182,42 +203,63 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
     val_size = len(dataset) - train_size
     from torch.utils.data import random_split
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    generator = torch.Generator().manual_seed(42)
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
 
-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-
-    # Device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Using device: {device}")
+    # DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=8,
+        pin_memory=True,
+        persistent_workers=True
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=8,
+        pin_memory=True,
+        persistent_workers=True
+    )
 
     # Model
     model = NNUE().to(device)
     criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=lr)
+    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
 
-    # Load checkpoint if provided
-    checkpoint_to_load = checkpoint
-    if checkpoint_to_load is None and Path(output_file).exists():
-        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
-        checkpoint_to_load = output_file
+    # Cosine annealing learning rate scheduler
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+    # Mixed precision training
+    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
 
     start_epoch = 0
-    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
-        print(f"Loading checkpoint from {checkpoint_to_load}...")
-        try:
-            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
-            model.load_state_dict(checkpoint_state)
-            print(f"✓ Checkpoint loaded successfully")
-        except Exception as e:
-            print(f"Warning: Could not load checkpoint: {e}")
-            print("Training from scratch instead")
-
     best_val_loss = float('inf')
+    if checkpoint:
+        print(f"Loading checkpoint: {checkpoint}")
+        ckpt = torch.load(checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            model.load_state_dict(ckpt["model_state_dict"])
+            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
+            scaler.load_state_dict(ckpt["scaler_state_dict"])
+            start_epoch = ckpt["epoch"] + 1
+            best_val_loss = ckpt.get("best_val_loss", float('inf'))
+            print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
+        else:
+            model.load_state_dict(ckpt)
+            print("Loaded weights-only checkpoint (no optimizer state)")
+    checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
+
     best_model_state = None
     epochs_without_improvement = 0
 
-    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
+    print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
+    print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
+    print(f"Mixed precision training: enabled")
+    print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
     if early_stopping_patience:
         print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
     print()
@@ -236,10 +278,15 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
             batch_targets = batch_targets.to(device).unsqueeze(1)
 
             optimizer.zero_grad()
-            outputs = model(batch_features)
-            loss = criterion(outputs, batch_targets)
-            loss.backward()
-            optimizer.step()
+
+            # Mixed precision forward and backward
+            with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
+                outputs = model(batch_features)
+                loss = criterion(outputs, batch_targets)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
 
             train_loss += loss.item() * batch_features.size(0)
             pbar.update(1)
@@ -255,19 +302,51 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
                 batch_features = batch_features.to(device)
                 batch_targets = batch_targets.to(device).unsqueeze(1)
 
-                outputs = model(batch_features)
-                loss = criterion(outputs, batch_targets)
+                with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
+                    outputs = model(batch_features)
+                    loss = criterion(outputs, batch_targets)
 
                 val_loss += loss.item() * batch_features.size(0)
                 pbar.update(1)
 
         val_loss /= len(val_dataset)
 
-        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
+        # Update learning rate
+        scheduler.step()
 
-        if val_loss < best_val_loss:
+        # Print GPU memory usage
+        if torch.cuda.is_available():
+            gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
+            gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
+            print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
+
+        # Calculate and print estimated time remaining
+        elapsed_time = datetime.now() - training_start_time
+        time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
+        remaining_epochs = total_epochs - epoch_display
+        eta_seconds = time_per_epoch * remaining_epochs
+        eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
+        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
+
+        # Save checkpoint after every epoch
+        checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
+        torch.save({
+            "epoch": epoch,
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+            "scheduler_state_dict": scheduler.state_dict(),
+            "scaler_state_dict": scaler.state_dict(),
+            "best_val_loss": best_val_loss,
+        }, checkpoint_file)
+
+        if val_loss < best_val_loss:
             best_val_loss = val_loss
             best_model_state = model.state_dict().copy()
             epochs_without_improvement = 0
+            # Save best model snapshot
+            snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
+            torch.save(best_model_state, snapshot_file)
+            print(f"  Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
         else:
             epochs_without_improvement += 1
@@ -277,6 +356,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
                 break
 
     # Save best model
+    if best_model_state is None or best_val_loss >= checkpoint_val_loss:
+        print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
+        print("No new model created.")
+        return
+
     if best_model_state is not None:
         # Determine final output file with versioning
         final_output_file = output_file
@@ -302,7 +386,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
"notes": "Win rate vs classical eval: TBD (requires benchmark games)" "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
} }
torch.save(best_model_state, final_output_file) torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": epoch,
"best_val_loss": best_val_loss,
}, final_output_file)
print(f"Best model saved to {final_output_file}") print(f"Best model saved to {final_output_file}")
# Save metadata if versioning is enabled # Save metadata if versioning is enabled
@@ -327,18 +418,20 @@ if __name__ == "__main__":
help="Output file base name (default: nnue_weights.pt)") help="Output file base name (default: nnue_weights.pt)")
parser.add_argument("--checkpoint", type=str, default=None, parser.add_argument("--checkpoint", type=str, default=None,
help="Path to checkpoint file to resume training from (optional)") help="Path to checkpoint file to resume training from (optional)")
parser.add_argument("--epochs", type=int, default=20, parser.add_argument("--epochs", type=int, default=100,
help="Number of epochs to train (default: 20)") help="Number of epochs to train (default: 100)")
parser.add_argument("--batch-size", type=int, default=4096, parser.add_argument("--batch-size", type=int, default=16384,
help="Batch size (default: 4096)") help="Batch size (default: 16384)")
parser.add_argument("--lr", type=float, default=1e-3, parser.add_argument("--lr", type=float, default=0.001,
help="Learning rate (default: 1e-3)") help="Learning rate (default: 0.001)")
parser.add_argument("--early-stopping", type=int, default=None, parser.add_argument("--early-stopping", type=int, default=None,
help="Stop if val loss doesn't improve for N epochs (optional)") help="Stop if val loss doesn't improve for N epochs (optional)")
parser.add_argument("--stockfish-depth", type=int, default=12, parser.add_argument("--stockfish-depth", type=int, default=12,
help="Stockfish depth used for evaluations (for metadata, default: 12)") help="Stockfish depth used for evaluations (for metadata, default: 12)")
parser.add_argument("--no-versioning", action="store_true", parser.add_argument("--no-versioning", action="store_true",
help="Disable automatic versioning (save directly to output file)") help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
args = parser.parse_args() args = parser.parse_args()
@@ -351,5 +444,6 @@ if __name__ == "__main__":
         checkpoint=args.checkpoint,
         stockfish_depth=args.stockfish_depth,
         use_versioning=not args.no_versioning,
-        early_stopping_patience=args.early_stopping
+        early_stopping_patience=args.early_stopping,
+        weight_decay=args.weight_decay
     )
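Because the final save is now a full checkpoint dict rather than a bare state dict, consumers need the same unwrapping that export.py does; a minimal inference-loading sketch (file name illustrative):

# Sketch: load a checkpoint saved by train_nnue for inference.
import torch
from train import NNUE

ckpt = torch.load("weights/nnue_weights_v1.pt", map_location="cpu")  # illustrative path
state = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
model = NNUE()
model.load_state_dict(state)
model.eval()  # disables the Dropout layers for deterministic evaluation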
+30  (new PowerShell launch script)
@@ -0,0 +1,30 @@
+# Setup and run NNUE training pipeline
+$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
+$VenvDir = Join-Path $ScriptDir ".venv"
+
+# Check if virtual environment exists
+if (-not (Test-Path $VenvDir)) {
+    Write-Host "Creating virtual environment..."
+    python -m venv $VenvDir
+    if ($LASTEXITCODE -ne 0) {
+        Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
+        exit 1
+    }
+}
+
+# Activate virtual environment
+Write-Host "Activating virtual environment..."
+$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
+& $ActivateScript
+
+# Install/update dependencies if requirements.txt exists
+$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
+if (Test-Path $RequirementsFile) {
+    Write-Host "Installing dependencies..."
+    pip install -q -r $RequirementsFile
+}
+
+# Run nnue.py
+Write-Host "Starting NNUE Training Pipeline..."
+python (Join-Path $ScriptDir "nnue.py")
+29  (new Bash launch script)
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Setup and run NNUE training pipeline
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VENV_DIR="$SCRIPT_DIR/.venv"
+
+# Check if virtual environment exists
+if [ ! -d "$VENV_DIR" ]; then
+    echo "Creating virtual environment..."
+    python3 -m venv "$VENV_DIR"
+    if [ $? -ne 0 ]; then
+        echo "Error: Failed to create virtual environment. Make sure python3 is installed."
+        exit 1
+    fi
+fi
+
+# Activate virtual environment
+echo "Activating virtual environment..."
+source "$VENV_DIR/bin/activate"
+
+# Install/update dependencies if requirements.txt exists
+if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
+    echo "Installing dependencies..."
+    pip install -q -r "$SCRIPT_DIR/requirements.txt"
+fi
+
+# Run nnue.py
+echo "Starting NNUE Training Pipeline..."
+python "$SCRIPT_DIR/nnue.py"
Config.scala  (new file)
@@ -0,0 +1,7 @@
+package de.nowchess.bot
+
+object Config:
+  /** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
+    * the move is vetoed (not accepted as a suggestion). */
+  val VETO_THRESHOLD: Int = 100
HybridBot.scala  (new file)
@@ -0,0 +1,23 @@
+package de.nowchess.bot.bots
+
+import de.nowchess.api.game.GameContext
+import de.nowchess.api.move.Move
+import de.nowchess.bot.logic.HybridSearch
+import de.nowchess.bot.util.PolyglotBook
+import de.nowchess.bot.{Bot, BotDifficulty}
+import de.nowchess.rules.RuleSet
+import de.nowchess.rules.sets.DefaultRules
+
+final class HybridBot(
+    difficulty: BotDifficulty,
+    rules: RuleSet = DefaultRules,
+    book: Option[PolyglotBook] = None
+) extends Bot:
+
+  private val search: HybridSearch = HybridSearch(rules)
+
+  override val name: String = s"HybridBot(${difficulty.toString})"
+
+  override def nextMove(context: GameContext): Option[Move] =
+    book.flatMap(_.probe(context))
+      .orElse(search.bestMove(context))
NNUE.scala
@@ -7,9 +7,9 @@ import java.nio.ByteOrder
 class NNUE:
-  private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias) = loadWeights()
+  private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights()
 
-  private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
+  private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
     val stream = getClass.getResourceAsStream("/nnue_weights.bin")
     if stream == null then
       throw RuntimeException("NNUE weights file not found in resources")
@@ -35,8 +35,12 @@ class NNUE:
       val l2b = readTensor(buffer)
       val l3w = readTensor(buffer)
       val l3b = readTensor(buffer)
-      (l1w, l1b, l2w, l2b, l3w, l3b)
+      val l4w = readTensor(buffer)
+      val l4b = readTensor(buffer)
+      val l5w = readTensor(buffer)
+      val l5b = readTensor(buffer)
+      (l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
     finally stream.close()
 
   private def readTensor(buffer: ByteBuffer): Array[Float] =
@@ -55,10 +59,12 @@ class NNUE:
       floats(i) = buffer.getFloat()
     floats
 
-  // Pre-allocated buffers for inference
+  // Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1)
   private val features = new Array[Float](768)
-  private val l1Output = new Array[Float](256)
-  private val l2Output = new Array[Float](32)
+  private val l1Output = new Array[Float](1536)
+  private val l2Output = new Array[Float](1024)
+  private val l3Output = new Array[Float](512)
+  private val l4Output = new Array[Float](256)
 
   /** Convert a position to 768-dimensional binary feature vector.
    *  12 piece types (white pawn to black king) × 64 squares from white's perspective. */
@@ -110,28 +116,43 @@ class NNUE:
   /** Run NNUE inference on the given position.
    *  Returns centipawn score from the perspective of the side-to-move.
-   *  No allocations in the hot path (uses pre-allocated buffers). */
+   *  No allocations in the hot path (uses pre-allocated buffers).
+   *  Architecture: 768→1536→1024→512→256→1 */
   def evaluate(context: GameContext): Int =
     val features = positionToFeatures(context.board, context.turn)
 
-    // Layer 1: Dense(768 -> 256) + ReLU
-    for i <- 0 until 256 do
+    // Layer 1: Dense(768 → 1536) + ReLU
+    for i <- 0 until 1536 do
       var sum = l1Bias(i)
       for j <- 0 until 768 do
         sum += features(j) * l1Weights(i * 768 + j)
       l1Output(i) = if sum > 0f then sum else 0f
 
-    // Layer 2: Dense(256 -> 32) + ReLU
-    for i <- 0 until 32 do
+    // Layer 2: Dense(1536 → 1024) + ReLU
+    for i <- 0 until 1024 do
       var sum = l2Bias(i)
-      for j <- 0 until 256 do
-        sum += l1Output(j) * l2Weights(i * 256 + j)
+      for j <- 0 until 1536 do
+        sum += l1Output(j) * l2Weights(i * 1536 + j)
       l2Output(i) = if sum > 0f then sum else 0f
 
-    // Layer 3: Dense(32 -> 1), no activation
-    var output = l3Bias(0)
-    for j <- 0 until 32 do
-      output += l2Output(j) * l3Weights(j)
+    // Layer 3: Dense(1024 → 512) + ReLU
+    for i <- 0 until 512 do
+      var sum = l3Bias(i)
+      for j <- 0 until 1024 do
+        sum += l2Output(j) * l3Weights(i * 1024 + j)
+      l3Output(i) = if sum > 0f then sum else 0f
+
+    // Layer 4: Dense(512 → 256) + ReLU
+    for i <- 0 until 256 do
+      var sum = l4Bias(i)
+      for j <- 0 until 512 do
+        sum += l3Output(j) * l4Weights(i * 512 + j)
+      l4Output(i) = if sum > 0f then sum else 0f
+
+    // Layer 5: Dense(256 → 1), no activation
+    var output = l5Bias(0)
+    for j <- 0 until 256 do
+      output += l4Output(j) * l5Weights(j)
 
     // Convert from tanh-normalized output back to centipawns
     // Training uses: eval_normalized = tanh(eval_cp / 300)
@@ -145,3 +166,33 @@ class NNUE:
       (300f * atanh).toInt
     math.max(-20000, math.min(20000, cp))
 
+  /** Benchmark: time 1M evaluations and report ns/eval.
+   *  This measures the performance of the inference on the starting position. */
+  def benchmark(): Unit =
+    val context = GameContext.initial
+    val iterations = 1_000_000
+
+    // Warm up
+    for _ <- 0 until 10000 do
+      evaluate(context)
+
+    // Actual benchmark
+    val startNanos = System.nanoTime()
+    for _ <- 0 until iterations do
+      evaluate(context)
+    val endNanos = System.nanoTime()
+
+    val totalNanos = endNanos - startNanos
+    val nanosPerEval = totalNanos.toDouble / iterations
+
+    println()
+    println("=" * 60)
+    println("NNUE BENCHMARK RESULTS")
+    println("=" * 60)
+    println(f"Iterations: $iterations%,d")
+    println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
+    println(f"ns/eval: $nanosPerEval%.2f ns")
+    println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
+    println("=" * 60)
+    println()
HybridSearch.scala  (new file)
@@ -0,0 +1,67 @@
+package de.nowchess.bot.logic
+
+import de.nowchess.api.game.GameContext
+import de.nowchess.api.move.Move
+import de.nowchess.bot.Config
+import de.nowchess.bot.bots.classic.EvaluationClassic
+import de.nowchess.bot.bots.nnue.EvaluationNNUE
+import de.nowchess.rules.RuleSet
+import de.nowchess.rules.sets.DefaultRules
+
+import scala.util.boundary
+import scala.util.boundary.break
+
+final class HybridSearch(
+    rules: RuleSet = DefaultRules
+):
+
+  private var vetoCount = 0
+  private var approvalCount = 0
+  private val TOP_MOVES_TO_VALIDATE = 10
+
+  /** Find the best move by scoring all legal moves with NNUE, then validating the
+   *  top TOP_MOVES_TO_VALIDATE (currently 10) with the classical eval.
+   *  If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
+   *  If all validated moves are vetoed, fall back to the best classical move overall.
+   */
+  def bestMove(context: GameContext): Option[Move] =
+    val legalMoves = rules.allLegalMoves(context)
+    if legalMoves.isEmpty then None else findBestMove(legalMoves, context)
+
+  private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
+    // Score all moves with NNUE
+    val moveScores = legalMoves.map { move =>
+      val nextContext = rules.applyMove(context)(move)
+      val nnueScore = EvaluationNNUE.evaluate(nextContext)
+      (move, nnueScore, nextContext)
+    }
+
+    // Sort by NNUE score descending
+    val sortedByNNUE = moveScores.sortBy(_._2).reverse
+
+    // Validate top N moves with classical evaluation
+    val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)
+
+    boundary:
+      for (move, nnueScore, nextContext) <- topMovesToCheck do
+        val classicalScore = EvaluationClassic.evaluate(nextContext)
+        val difference = (classicalScore - nnueScore).abs
+        if difference <= Config.VETO_THRESHOLD then
+          approvalCount += 1
+          println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
+          break(Some(move))
+        else
+          vetoCount += 1
+          println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")
+
+      // All top 10 were vetoed, fall back to best classical move
+      println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
+      val bestByClassical = moveScores
+        .map { case (move, _, nextContext) =>
+          (move, EvaluationClassic.evaluate(nextContext))
+        }
+        .maxBy(_._2)
+      println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
+      println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
+      Some(bestByClassical._1)
+
+  def getStats: (Int, Int) = (approvalCount, vetoCount)
Main.scala
@@ -3,7 +3,7 @@ package de.nowchess.ui
 import de.nowchess.api.board.Color.Black
 import de.nowchess.bot.util.PolyglotBook
 import de.nowchess.bot.BotDifficulty
-import de.nowchess.bot.bots.{ClassicalBot, NNUEBot}
+import de.nowchess.bot.bots.{ClassicalBot, HybridBot, NNUEBot}
 import de.nowchess.chess.engine.GameEngine
 import de.nowchess.ui.terminal.TerminalUI
 import de.nowchess.ui.gui.ChessGUILauncher
@@ -19,7 +19,7 @@ object Main:
     val book = PolyglotBook("../../modules/bot/codekiddy.bin")
-    engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black);
+    engine.setOpponentBot(HybridBot(BotDifficulty.Easy, book = Some(book)), Black);
 
     // Launch ScalaFX GUI in separate thread
     ChessGUILauncher.launch(engine)