feat: add hybrid bot implementation and enhance NNUE training pipeline with tactical data extraction
@@ -63,3 +63,7 @@ tasks.test {
tasks.reportScoverage {
    dependsOn(tasks.test)
}
+
+tasks.jar {
+    duplicatesStrategy = DuplicatesStrategy.EXCLUDE
+}

@@ -17,3 +17,5 @@ ENV/
.vscode/
*.swp
*.swo
+tactical_data/
+trainingdata/

@@ -2,6 +2,7 @@
"""Central NNUE pipeline TUI for training and exporting models."""

import os
+import shutil
import sys
from pathlib import Path
from rich.console import Console
@@ -17,6 +18,10 @@ from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue
from export import export_weights_to_binary
+from tactical_positions_extractor import (
+    download_and_extract_puzzle_db,
+    interactive_merge_positions
+)

def get_data_dir():
    """Get/create data directory."""
@@ -24,6 +29,12 @@ def get_data_dir():
    data_dir.mkdir(exist_ok=True)
    return data_dir

+def get_tactical_data_dir():
+    """Get/create tactical data directory."""
+    data_dir = Path(__file__).parent / "tactical_data"
+    data_dir.mkdir(exist_ok=True)
+    return data_dir
+
def get_weights_dir():
    """Get/create weights directory."""
    weights_dir = Path(__file__).parent / "weights"
@@ -87,20 +98,23 @@ def show_main_menu():
    console.print("\n[bold]What would you like to do?[/bold]")
    console.print("[cyan]1[/cyan] - Train NNUE Model")
    console.print("[cyan]2[/cyan] - Export Weights to Scala")
-    console.print("[cyan]3[/cyan] - View Checkpoints")
-    console.print("[cyan]4[/cyan] - Exit")
+    console.print("[cyan]3[/cyan] - Extract Tactical Positions")
+    console.print("[cyan]4[/cyan] - View Checkpoints")
+    console.print("[cyan]5[/cyan] - Exit")

-    choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
+    choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])

    if choice == "1":
        train_interactive()
    elif choice == "2":
        export_interactive()
    elif choice == "3":
+        extract_tactical_interactive()
+    elif choice == "4":
        show_header()
        show_checkpoints_table()
        Prompt.ask("\nPress Enter to continue")
-    elif choice == "4":
+    elif choice == "5":
        console.print("[yellow]👋 Goodbye![/yellow]")
        return
@@ -146,13 +160,18 @@ def train_interactive():
    if use_existing_labels:
        labels_file = Prompt.ask("Enter path to labels file", default=str(get_data_dir() / "training_data.jsonl"))

-    # Stockfish path
-    default_stockfish = os.environ.get("STOCKFISH_PATH", "/usr/bin/stockfish")
+    # Stockfish path and labeling parameters
+    default_stockfish = os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
    stockfish_path = Prompt.ask("Stockfish path", default=default_stockfish)
+    stockfish_depth = 12
+    num_workers = 1
+    if not use_existing_labels:
+        stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+        num_workers = int(Prompt.ask("Number of parallel workers", default="1"))

    # Training parameters
-    epochs = int(Prompt.ask("Number of epochs", default="20"))
-    batch_size = int(Prompt.ask("Batch size", default="4096"))
+    epochs = int(Prompt.ask("Number of epochs", default="100"))
+    batch_size = int(Prompt.ask("Batch size", default="16384"))
    early_stopping = None
    if Confirm.ask("Enable early stopping?", default=False):
        early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
@@ -175,6 +194,9 @@ def train_interactive():
        console.print(f" Early stopping: Yes (patience: {early_stopping})")
    else:
        console.print(f" Early stopping: No")
+    if not use_existing_labels:
+        console.print(f" Stockfish depth: {stockfish_depth}")
+        console.print(f" Workers: {num_workers}")
    console.print(f" Stockfish: {stockfish_path}")

    if not Confirm.ask("\nStart training?", default=True):
@@ -214,7 +236,8 @@ def train_interactive():
            str(positions_file),
            str(output_file),
            stockfish_path,
-            depth=12
+            depth=stockfish_depth,
+            num_workers=num_workers
        )
        if not success:
            console.print("[red]✗ Position labeling failed[/red]")
@@ -305,6 +328,66 @@ def export_interactive():
        traceback.print_exc()
        Prompt.ask("Press Enter to continue")

+def extract_tactical_interactive():
+    """Interactive tactical positions extraction and merge menu."""
+    console = Console()
+    show_header()
+
+    console.print("\n[bold cyan]♟️ Tactical Positions Extraction & Merge[/bold cyan]")
+
+    # Download and extract options
+    console.print("\n[bold]Lichess Puzzle Database:[/bold]")
+    download_url = Prompt.ask(
+        "Download URL",
+        default="https://database.lichess.org/lichess_db_puzzle.csv.zst"
+    )
+
+    output_dir = Prompt.ask(
+        "Extract to directory",
+        default=str(Path(__file__).parent / "trainingdata")
+    )
+
+    max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
+
+    # Confirm and download
+    console.print("\n[bold]Configuration:[/bold]")
+    console.print(f" Download URL: {download_url}")
+    console.print(f" Extract directory: {output_dir}")
+    console.print(f" Max puzzles: {max_puzzles:,}")
+
+    if not Confirm.ask("\nProceed?", default=True):
+        console.print("[yellow]Cancelled[/yellow]")
+        Prompt.ask("Press Enter to continue")
+        return
+
+    try:
+        console.print("\n[bold cyan]Step 1: Download & Extract[/bold cyan]")
+        csv_path = download_and_extract_puzzle_db(download_url, output_dir)
+
+        if not csv_path:
+            console.print("[red]✗ Failed to download/extract[/red]")
+            Prompt.ask("Press Enter to continue")
+            return
+
+        console.print(f"[green]✓ Ready: {csv_path}[/green]")
+
+        # Interactive merge
+        console.print("\n[bold cyan]Step 2: Extract & Merge[/bold cyan]")
+        output_file = Prompt.ask(
+            "Output file path",
+            default=str(Path(__file__).parent / "data" / "position.txt")
+        )
+
+        interactive_merge_positions(csv_path, output_file, max_puzzles)
+        console.print(f"\n[green]✓ Complete![/green]")
+        Prompt.ask("Press Enter to continue")
+
+    except Exception as e:
+        console.print(f"[red]✗ Error: {e}[/red]")
+        import traceback
+        traceback.print_exc()
+        Prompt.ask("Press Enter to continue")
+
def main():
    try:
        show_main_menu()

@@ -2,4 +2,5 @@ chess==1.11.2
torch==2.11.0
tqdm==4.67.3
numpy==2.4.4
rich==13.7.0
+zstandard==0.23.0

@@ -13,8 +13,9 @@ def export_weights_to_binary(weights_file, output_file):
        print(f"Error: Weights file not found at {weights_file}")
        sys.exit(1)

-    # Load weights
-    state_dict = torch.load(weights_file, map_location='cpu')
+    # Load weights — handle both raw state dicts and full training checkpoints
+    loaded = torch.load(weights_file, map_location='cpu')
+    state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded

    # Debug: print available layers
    print(f"Available layers in {weights_file}:")
@@ -31,7 +32,7 @@ def export_weights_to_binary(weights_file, output_file):
        f.write(struct.pack('<I', 1))  # version 1

        # Write each weight tensor in order
-        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias']:
+        for layer_name in ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias']:
            if layer_name not in state_dict:
                print(f"Error: Missing layer {layer_name}")
                sys.exit(1)

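For context, the unwrap line added above has to accept two on-disk layouts. A minimal, self-contained sketch of both (the file names and the tiny stand-in model are illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)  # stand-in model, for illustration only

    # Format 1: raw state dict (what older training runs saved)
    torch.save(model.state_dict(), "weights_only.pt")

    # Format 2: full checkpoint wrapping the state dict (what train.py now saves)
    torch.save({"model_state_dict": model.state_dict(), "epoch": 42}, "checkpoint.pt")

    # The exporter's unwrap handles both
    for path in ("weights_only.pt", "checkpoint.pt"):
        loaded = torch.load(path, map_location="cpu")
        state_dict = loaded["model_state_dict"] if isinstance(loaded, dict) and "model_state_dict" in loaded else loaded
        print(path, sorted(state_dict.keys()))
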
@@ -1,100 +1,136 @@
#!/usr/bin/env python3
-"""Generate random chess positions for NNUE training with minimal filtering."""
+"""Generate random chess positions for NNUE training with multiprocessing."""

import chess
import random
import sys
from pathlib import Path
from tqdm import tqdm
+from multiprocessing import Pool, Queue
+from datetime import datetime
+import time

+def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
+    """Generate games for one worker.
+
+    Returns:
+        list of FENs generated by this worker
+    """
+    positions = []
+
+    for game_num in range(games_per_worker):
+        board = chess.Board()
+        move_history = []
+
+        # Play a complete random game
+        while not board.is_game_over() and len(move_history) < 200:
+            legal_moves = list(board.legal_moves)
+            if not legal_moves:
+                break
+            move = random.choice(legal_moves)
+            board.push(move)
+            move_history.append(board.copy())
+
+        # Determine the range of moves to sample from
+        game_length = len(move_history)
+        valid_start = max(min_move, 0)
+        valid_end = min(max_move, game_length)
+
+        if valid_start >= valid_end:
+            continue
+
+        # Randomly sample positions from this game
+        sample_count = min(samples_per_game, valid_end - valid_start)
+        if sample_count > 0:
+            sample_indices = random.sample(
+                range(valid_start, valid_end),
+                k=sample_count
+            )
+
+            for idx in sample_indices:
+                sampled_board = move_history[idx]
+
+                # Only filter truly invalid or terminal positions
+                if not sampled_board.is_valid() or sampled_board.is_game_over():
+                    continue

+                # Save position (include check, captures, all positions)
+                fen = sampled_board.fen()
+                positions.append(fen)
+
+    return positions
+
+
def play_random_game_and_collect_positions(
    output_file,
    total_games=500000,
+    total_positions=3000000,
    samples_per_game=1,
    min_move=1,
-    max_move=50
+    max_move=50,
+    num_workers=8
):
-    """Play random games and sample multiple positions from each.
+    """Generate positions using multiprocessing with multiple workers.

    Args:
        output_file: Output file for positions
        total_games: Number of games to play
+        total_positions: Target number of positions to generate
        samples_per_game: Number of positions to sample per game (1-N)
        min_move: Minimum move number to start sampling from
        max_move: Maximum move number for sampling
+        num_workers: Number of parallel worker processes

    Returns:
        Number of valid positions saved
    """
+    # Estimate games needed (roughly 1 position per game on average)
+    total_games = max(total_positions // samples_per_game, num_workers)
+    games_per_worker = total_games // num_workers
+
+    print(f"Generating {total_positions:,} positions using {num_workers} workers")
+    print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
+    print()
+
+    start_time = datetime.now()
+
+    # Generate positions in parallel
+    worker_tasks = [
+        (i, games_per_worker, samples_per_game, min_move, max_move)
+        for i in range(num_workers)
+    ]
+
    positions_count = 0
    filtered_game_over = 0
    filtered_illegal = 0

-    with open(output_file, 'w') as f:
-        with tqdm(total=total_games, desc="Generating positions") as pbar:
-            for game_num in range(total_games):
-                board = chess.Board()
-                move_history = []
-
-                # Play a complete random game
-                while not board.is_game_over() and len(move_history) < 200:
-                    legal_moves = list(board.legal_moves)
-                    if not legal_moves:
-                        break
-                    move = random.choice(legal_moves)
-                    board.push(move)
-                    move_history.append(board.copy())
-
-                # Determine the range of moves to sample from
-                game_length = len(move_history)
-                valid_start = max(min_move, 0)
-                valid_end = min(max_move, game_length)
-
-                if valid_start >= valid_end:
-                    # Game too short
-                    pbar.update(1)
-                    continue
-
-                # Randomly sample positions from this game
-                sample_count = min(samples_per_game, valid_end - valid_start)
-                if sample_count > 0:
-                    sample_indices = random.sample(
-                        range(valid_start, valid_end),
-                        k=sample_count
-                    )
-
-                    for idx in sample_indices:
-                        sampled_board = move_history[idx]
-
-                        # Only filter truly invalid or terminal positions
-                        if not sampled_board.is_valid():
-                            filtered_illegal += 1
-                            continue
-
-                        if sampled_board.is_game_over():
-                            filtered_game_over += 1
-                            continue
-
-                        # Save position (include check, captures, all positions)
-                        fen = sampled_board.fen()
-                        f.write(fen + '\n')
-                        positions_count += 1
+    all_positions = []
+
+    with Pool(num_workers) as pool:
+        with tqdm(total=num_workers, desc="Workers generating games") as pbar:
+            for positions in pool.starmap(_worker_generate_games, worker_tasks):
+                all_positions.extend(positions)
+                positions_count += len(positions)
+                pbar.update(1)
+
+    # Write all positions to file
+    print(f"Writing {positions_count:,} positions to {output_file}...")
+    with open(output_file, 'w') as f:
+        for fen in all_positions:
+            f.write(fen + '\n')
+
+    elapsed_time = datetime.now() - start_time
+    elapsed_seconds = elapsed_time.total_seconds()
+    positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0

    # Print summary
    print()
    print("=" * 60)
    print("POSITION GENERATION SUMMARY")
    print("=" * 60)
    print(f"Total games played: {total_games}")
+    print(f"Target positions: {total_positions:,}")
+    print(f"Actual positions saved: {positions_count:,}")
+    print(f"Workers: {num_workers}")
+    print(f"Games per worker: {games_per_worker:,}")
    print(f"Samples per game: {samples_per_game}")
    print(f"Move range: {min_move}-{max_move}")
-    print(f"Saved positions: {positions_count}")
    print(f"Filtered (game over): {filtered_game_over}")
    print(f"Filtered (illegal): {filtered_illegal}")
    print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
    print(f"(Includes checks, captures, all realistic positions)")
+    print(f"Elapsed time: {elapsed_time}")
+    print(f"Throughput: {positions_per_second:.0f} positions/second")
    print("=" * 60)
    print()

@@ -110,23 +146,26 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
    parser.add_argument("output_file", nargs="?", default="positions.txt",
                        help="Output file for positions (default: positions.txt)")
-    parser.add_argument("--games", type=int, default=5000,
-                        help="Number of games to play (default: 5000)")
+    parser.add_argument("--positions", type=int, default=3000000,
+                        help="Target number of positions to generate (default: 3000000)")
    parser.add_argument("--samples-per-game", type=int, default=1,
                        help="Number of positions to sample per game (default: 1)")
    parser.add_argument("--min-move", type=int, default=1,
                        help="Minimum move number to sample from (default: 1)")
    parser.add_argument("--max-move", type=int, default=50,
                        help="Maximum move number to sample from (default: 50)")
+    parser.add_argument("--workers", type=int, default=8,
+                        help="Number of parallel worker processes (default: 8)")

    args = parser.parse_args()

    count = play_random_game_and_collect_positions(
        output_file=args.output_file,
-        total_games=args.games,
+        total_positions=args.positions,
        samples_per_game=args.samples_per_game,
        min_move=args.min_move,
-        max_move=args.max_move
+        max_move=args.max_move,
+        num_workers=args.workers
    )

    sys.exit(0 if count > 0 else 1)

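Assuming the script is saved as generate.py (the name the pipeline imports it under), a typical invocation with the flags defined above might look like:

    python generate.py positions.txt --positions 1000000 --samples-per-game 2 --workers 8
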
@@ -8,6 +8,8 @@ import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
+from multiprocessing import Pool
+from functools import partial

def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    """Normalize centipawn evaluation to a bounded range.
@@ -27,17 +29,68 @@ def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    else:
        return cp_value / 100.0

-def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True):
+def _evaluate_fen_batch(args):
+    """Worker function to evaluate a batch of FENs with Stockfish threading.
+
+    Args:
+        args: tuple of (fens, stockfish_path, depth, normalize)
+
+    Returns:
+        list of (fen, eval_normalized, eval_raw) tuples
+    """
+    fens, stockfish_path, depth, normalize = args
+
+    results = []
+
+    try:
+        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
+    except Exception:
+        return []
+
+    try:
+        for fen in fens:
+            try:
+                board = chess.Board(fen)
+                if not board.is_valid():
+                    continue
+
+                info = engine.analyse(board, chess.engine.Limit(depth=depth))
+
+                if info.get('score') is None:
+                    continue
+
+                score = info['score'].white()
+
+                if score.is_mate():
+                    eval_cp = 2000 if score.mate() > 0 else -2000
+                else:
+                    eval_cp = score.cp
+
+                eval_cp = max(-2000, min(2000, eval_cp))
+                eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
+
+                results.append((fen, eval_normalized, eval_cp))
+
+            except Exception:
+                continue
+    finally:
+        engine.quit()
+
+    return results
+
+
+def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
    """Read positions and label them with Stockfish evaluations.

    Args:
        positions_file: Path to positions.txt
        output_file: Path to training_data.jsonl
        stockfish_path: Path to stockfish binary
-        batch_size: Batch size (not used, kept for compatibility)
+        batch_size: Batch size for processing (positions per worker task, default: 1000)
        depth: Stockfish depth
        verbose: Print detailed error messages
        normalize: If True, normalize evals using tanh
+        num_workers: Number of parallel Stockfish processes
    """

    # Check if stockfish exists
@@ -48,6 +101,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
        sys.exit(1)

    print(f"Using Stockfish: {stockfish_path}")
+    print(f"Number of workers: {num_workers}")

    # Check if positions file exists
    if not Path(positions_file).exists():
@@ -69,107 +123,87 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
                pass
        print(f"Resuming from {position_count} already evaluated positions")

-    # Count total positions
-    with open(positions_file, 'r') as f:
-        total_lines = sum(1 for _ in f)
+    # Load all FENs that need evaluation
+    fens_to_evaluate = []
+    skipped_invalid = 0
+    skipped_duplicate = 0

-    if total_lines == 0:
-        print(f"Error: Positions file is empty ({positions_file})")
-        sys.exit(1)
+    with open(positions_file, 'r') as f:
+        for fen in f:
+            fen = fen.strip()
+
+            if not fen:
+                skipped_invalid += 1
+                continue
+
+            if fen in evaluated_fens:
+                skipped_duplicate += 1
+                continue
+
+            fens_to_evaluate.append(fen)
+
+    total_to_evaluate = len(fens_to_evaluate)
+    total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
+
+    if total_to_evaluate == 0:
+        if position_count == 0:
+            print(f"Error: No valid positions to evaluate in {positions_file}")
+            sys.exit(1)
+        else:
+            print(f"All positions already evaluated. No new positions to process.")
+            return True

    print(f"Total positions to process: {total_lines}")
+    print(f"New positions to evaluate: {total_to_evaluate}")
    print(f"Using depth: {depth}")
    print()

-    # Initialize engine
-    try:
-        engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
-    except Exception as e:
-        print(f"Error: Could not start Stockfish engine")
-        print(f" Stockfish path: {stockfish_path}")
-        print(f" Error: {e}")
-        sys.exit(1)
+    # Split FENs into batches for workers
+    batches = []
+    for i in range(0, total_to_evaluate, batch_size):
+        batch = fens_to_evaluate[i:i+batch_size]
+        batches.append((batch, stockfish_path, depth, normalize))

-    # Track statistics
+    # Process batches in parallel
    evaluated = 0
-    skipped_invalid = 0
-    skipped_duplicate = 0
    errors = 0
    raw_evals = []
    normalized_evals = []

-    try:
-        with open(positions_file, 'r') as f:
+    import time
+    start_time = time.time()
+
+    with Pool(num_workers) as pool:
+        with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
            with open(output_file, 'a') as out:
-                with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
-                    for fen in f:
-                        fen = fen.strip()
-
-                        # Skip empty lines
-                        if not fen:
-                            skipped_invalid += 1
-                            pbar.update(1)
-                            continue
-
-                        # Skip already evaluated
-                        if fen in evaluated_fens:
-                            skipped_duplicate += 1
-                            pbar.update(1)
-                            continue
-
-                        try:
-                            # Validate FEN
-                            board = chess.Board(fen)
-                            if not board.is_valid():
-                                skipped_invalid += 1
-                                pbar.update(1)
-                                continue
-
-                            # Evaluate at specified depth
-                            info = engine.analyse(board, chess.engine.Limit(depth=depth))
-
-                            if info.get('score') is None:
-                                skipped_invalid += 1
-                                pbar.update(1)
-                                continue
-
-                            score = info['score'].white()
-
-                            # Convert to centipawns
-                            if score.is_mate():
-                                # Use large values for mate scores
-                                eval_cp = 2000 if score.mate() > 0 else -2000
-                            else:
-                                eval_cp = score.cp
-
-                            # Clamp to [-2000, 2000]
-                            eval_cp = max(-2000, min(2000, eval_cp))
-
-                            # Normalize evaluation
-                            eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
-
-                            # Track statistics
-                            raw_evals.append(eval_cp)
-                            normalized_evals.append(eval_normalized)
-
-                            # Save evaluation (normalized if requested)
-                            data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
-                            out.write(json.dumps(data) + '\n')
-                            out.flush()  # Force write to disk
-                            evaluated += 1
-
-                        except Exception as e:
-                            errors += 1
-                            if verbose:
-                                print(f"Error evaluating position: {fen[:50]}...")
-                                print(f" {type(e).__name__}: {e}")
-                            pbar.update(1)
-                            continue
-
-    finally:
-        engine.quit()
+                for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
+                    for fen, eval_normalized, eval_cp in batch_results:
+                        data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
+                        out.write(json.dumps(data) + '\n')
+                        evaluated += 1
+                        raw_evals.append(eval_cp)
+                        normalized_evals.append(eval_normalized)
+                        pbar.update(1)
+
+                    # Update progress for any failed evaluations in the batch
+                    batch_size_actual = len(batches[0][0]) if batches else batch_size
+                    failed = batch_size_actual - len(batch_results)
+                    if failed > 0:
+                        errors += failed
+                        pbar.update(failed)
+
+                    # Calculate and show throughput and ETA
+                    elapsed = time.time() - start_time
+                    throughput = evaluated / elapsed if elapsed > 0 else 0
+                    remaining_positions = total_to_evaluate - evaluated
+                    eta_seconds = remaining_positions / throughput if throughput > 0 else 0
+                    eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
+
+                    if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
+                        pbar.set_postfix({
+                            'rate': f'{throughput:.0f} pos/s',
+                            'eta': eta_str
+                        })

    # Print summary and analysis
    print()
@@ -253,10 +287,14 @@ if __name__ == "__main__":
                        help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
    parser.add_argument("--depth", type=int, default=12,
                        help="Stockfish depth (default: 12)")
+    parser.add_argument("--batch-size", type=int, default=1000,
+                        help="Batch size for processing (default: 1000)")
    parser.add_argument("--no-normalize", action="store_true",
                        help="Disable evaluation normalization (keep raw centipawns)")
    parser.add_argument("--verbose", action="store_true",
                        help="Print detailed error messages")
+    parser.add_argument("--workers", type=int, default=1,
+                        help="Number of parallel Stockfish processes (default: 1)")

    args = parser.parse_args()

@@ -267,9 +305,11 @@ if __name__ == "__main__":
        positions_file=args.positions_file,
        output_file=args.output_file,
        stockfish_path=stockfish_path,
+        batch_size=args.batch_size,
        depth=args.depth,
        normalize=not args.no_normalize,
-        verbose=args.verbose
+        verbose=args.verbose,
+        num_workers=args.workers
    )

    sys.exit(0 if success else 1)

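The normalization used throughout (method='tanh', scale=300, confirmed by the comment in the Scala inference code: eval_normalized = tanh(eval_cp / 300)) and its inverse, as a small sketch:

    import math

    def normalize(cp, scale=300.0):
        # squash centipawns into (-1, 1)
        return math.tanh(cp / scale)

    def denormalize(y, scale=300.0):
        # inverse mapping applied on the Scala side: cp = scale * atanh(y)
        return scale * math.atanh(y)

    print(round(normalize(300), 4))    # 0.7616
    print(round(denormalize(0.7616)))  # ~300 centipawns
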
@@ -0,0 +1,224 @@
import chess
import csv
import json
import sys
import urllib.request
from pathlib import Path
from typing import Set, Tuple

try:
    import zstandard as zstd
except ImportError:
    print("zstandard library not found. Install with: pip install zstandard")
    sys.exit(1)

from generate import play_random_game_and_collect_positions


def download_and_extract_puzzle_db(
    url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
    output_dir: str = 'trainingdata'
):
    """Download and extract the Lichess puzzle database."""
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    csv_file = output_path / 'lichess_db_puzzle.csv'
    zst_file = output_path / 'lichess_db_puzzle.csv.zst'

    # Download if not already present
    if not zst_file.exists():
        print(f"Downloading puzzle database from {url}...")
        try:
            urllib.request.urlretrieve(url, zst_file)
            print(f"Downloaded to {zst_file}")
        except Exception as e:
            print(f"Failed to download: {e}")
            return None

    # Extract if CSV doesn't exist
    if not csv_file.exists():
        print(f"Extracting {zst_file}...")
        try:
            with open(zst_file, 'rb') as f:
                dctx = zstd.ZstdDecompressor()
                with dctx.stream_reader(f) as reader:
                    with open(csv_file, 'wb') as out:
                        out.write(reader.read())
            print(f"Extracted to {csv_file}")
        except Exception as e:
            print(f"Failed to extract: {e}")
            return None

    return str(csv_file)


def extract_puzzle_positions(
    puzzle_csv: str,
    max_puzzles: int = 300_000
) -> Set[str]:
    """
    Extract the position where the tactic must be found from each puzzle.
    This is exactly the type of position where tactical
    recognition matters most.

    Returns a set of unique FENs.
    """
    positions = set()

    with open(puzzle_csv) as f:
        reader = csv.DictReader(f)
        for row in reader:
            if len(positions) >= max_puzzles:
                break

            try:
                board = chess.Board(row['FEN'])

                # The puzzle FEN is the position BEFORE the opponent's blunder;
                # pushing the first move yields the position where the tactic
                # must be found, so the net learns to spot it, not just play it
                moves = row['Moves'].split()

                # Apply the blunder to reach the tactical position
                board.push_uci(moves[0])  # opponent's blunder
                fen = board.fen()

                # Filter for useful tactical themes
                themes = row.get('Themes', '')
                useful = any(t in themes for t in [
                    'fork', 'pin', 'skewer', 'discoveredAttack',
                    'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
                    'trappedPiece', 'sacrifice'
                ])

                if useful:
                    positions.add(fen)

            except Exception:
                continue

    return positions


def load_positions_from_file(file_path: str) -> Set[str]:
    """Load positions from a text file (one FEN per line)."""
    positions = set()
    try:
        with open(file_path) as f:
            for line in f:
                line = line.strip()
                if line:
                    positions.add(line)
        print(f"Loaded {len(positions)} positions from {file_path}")
        return positions
    except Exception as e:
        print(f"Failed to load from {file_path}: {e}")
        return set()


def merge_positions(
    tactical: Set[str],
    other: Set[str],
    output_file: str = 'position.txt'
):
    """Merge two position sets and write to file."""
    merged = tactical | other

    with open(output_file, 'w') as f:
        for fen in merged:
            f.write(fen + '\n')

    overlap = len(tactical & other)
    print(f"\n{'='*60}")
    print("MERGE SUMMARY")
    print(f"{'='*60}")
    print(f"Tactical positions: {len(tactical):,}")
    print(f"Other positions: {len(other):,}")
    print(f"Overlap (deduplicated): {overlap:,}")
    print(f"Total merged positions: {len(merged):,}")
    print(f"Written to: {output_file}")
    print(f"{'='*60}\n")


def interactive_merge_positions(
    puzzle_csv: str,
    output_file: str = 'position.txt',
    max_puzzles: int = 300_000
):
    """Interactive workflow: extract tactical positions and merge with user selection."""
    print("\n" + "="*60)
    print("TACTICAL POSITION EXTRACTOR & MERGER")
    print("="*60 + "\n")

    # Extract tactical positions
    print("Extracting tactical positions from puzzle database...")
    tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
    print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")

    # Ask what to merge with
    print("What would you like to merge with these tactical positions?")
    print("1. Load from a position file")
    print("2. Generate random positions")
    print("3. Skip merging (save tactical only)")

    choice = input("\nEnter choice (1-3): ").strip()

    other_positions = set()

    if choice == '1':
        file_path = input("Enter path to position file: ").strip()
        other_positions = load_positions_from_file(file_path)

    elif choice == '2':
        positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
        try:
            positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
        except ValueError:
            positions_to_gen = 1000000

        temp_file = 'temp_generated_positions.txt'
        print(f"\nGenerating {positions_to_gen:,} random positions...")
        play_random_game_and_collect_positions(
            output_file=temp_file,
            total_positions=positions_to_gen,
            samples_per_game=1,
            min_move=1,
            max_move=50,
            num_workers=8
        )
        other_positions = load_positions_from_file(temp_file)

    elif choice == '3':
        print("Skipping merge, saving tactical positions only...")

    else:
        print("Invalid choice, saving tactical positions only...")

    merge_positions(tactical_positions, other_positions, output_file)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
    parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
                        help="URL to download puzzle database from")
    parser.add_argument("--output-dir", default='trainingdata',
                        help="Directory to extract puzzle database to")
    parser.add_argument("--max-puzzles", type=int, default=300_000,
                        help="Maximum puzzles to extract (default: 300000)")
    parser.add_argument("--output-file", default='position.txt',
                        help="Output file for merged positions (default: position.txt)")

    args = parser.parse_args()

    # Download and extract
    csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)

    if csv_path:
        # Interactive merge
        interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
    else:
        print("Failed to download/extract puzzle database")
        sys.exit(1)

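Standalone usage, built from the argparse block above (paths illustrative):

    python tactical_positions_extractor.py --max-puzzles 300000 \
        --output-dir trainingdata --output-file data/position.txt
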
@@ -87,24 +87,34 @@ def fen_to_features(fen):
    return features

class NNUE(nn.Module):
-    """NNUE neural network architecture."""
+    """NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization."""

-    def __init__(self):
+    def __init__(self, dropout_rate=0.2):
        super().__init__()
-        self.l1 = nn.Linear(768, 256)
+        self.l1 = nn.Linear(768, 1536)
        self.relu1 = nn.ReLU()
-        self.l2 = nn.Linear(256, 32)
+        self.drop1 = nn.Dropout(dropout_rate)
+
+        self.l2 = nn.Linear(1536, 1024)
        self.relu2 = nn.ReLU()
-        self.l3 = nn.Linear(32, 1)
-        self.sigmoid = nn.Sigmoid()
+        self.drop2 = nn.Dropout(dropout_rate)
+
+        self.l3 = nn.Linear(1024, 512)
+        self.relu3 = nn.ReLU()
+        self.drop3 = nn.Dropout(dropout_rate)
+
+        self.l4 = nn.Linear(512, 256)
+        self.relu4 = nn.ReLU()
+        self.drop4 = nn.Dropout(dropout_rate)
+
+        self.l5 = nn.Linear(256, 1)

    def forward(self, x):
-        x = self.l1(x)
-        x = self.relu1(x)
-        x = self.l2(x)
-        x = self.relu2(x)
-        x = self.l3(x)
-        return x
+        x = self.drop1(self.relu1(self.l1(x)))
+        x = self.drop2(self.relu2(self.l2(x)))
+        x = self.drop3(self.relu3(self.l3(x)))
+        x = self.drop4(self.relu4(self.l4(x)))
+        return self.l5(x)

def find_next_version(base_name="nnue_weights"):
    """Find the next version number for model versioning.
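A back-of-envelope check of what the wider architecture costs in parameters (weights plus biases per nn.Linear; dropout adds none), as a sketch:

    layers = [(768, 1536), (1536, 1024), (1024, 512), (512, 256), (256, 1)]
    total = sum(i * o + o for i, o in layers)
    print(total)  # 3411457 parameters, vs 205121 for the old 768->256->32->1 net (~16.6x)
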
@@ -142,21 +152,32 @@ def save_metadata(weights_file, metadata)

    return metadata_file

-def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
-    """Train the NNUE model.
+def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4):
+    """Train the NNUE model with GPU optimizations and automatic mixed precision.

    Args:
        data_file: Path to training_data.jsonl
        output_file: Where to save best weights (or base name if use_versioning=True)
-        epochs: Number of training epochs
-        batch_size: Training batch size
-        lr: Learning rate
+        epochs: Number of training epochs (default: 100)
+        batch_size: Training batch size (default: 16384)
+        lr: Learning rate (default: 0.001)
        checkpoint: Optional path to checkpoint file to resume from
        stockfish_depth: Depth used in Stockfish evaluation (for metadata)
        use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
        early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
+        weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
    """

+    print("Checking GPU availability...")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
+        print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+    else:
+        print("⚠ GPU not available, using CPU")
+    print(f"Using device: {device}")
+    print()
+
    print("Loading dataset...")
    dataset = NNUEDataset(data_file)
    num_positions = len(dataset)
@@ -182,42 +203,63 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
    val_size = len(dataset) - train_size

    from torch.utils.data import random_split
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    generator = torch.Generator().manual_seed(42)
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

-    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-
-    # Device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Using device: {device}")
+    # DataLoader with GPU optimizations: num_workers=8, pin_memory, persistent_workers
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=8,
+        pin_memory=True,
+        persistent_workers=True
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=8,
+        pin_memory=True,
+        persistent_workers=True
+    )

    # Model
    model = NNUE().to(device)
    criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=lr)
+    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

-    # Load checkpoint if provided
-    checkpoint_to_load = checkpoint
-    if checkpoint_to_load is None and Path(output_file).exists():
-        # Auto-detect checkpoint: if output file already exists, use it as checkpoint
-        checkpoint_to_load = output_file
+    # Cosine annealing learning rate scheduler
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+    # Mixed precision training
+    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')

    start_epoch = 0
-    if checkpoint_to_load is not None and Path(checkpoint_to_load).exists():
-        print(f"Loading checkpoint from {checkpoint_to_load}...")
-        try:
-            checkpoint_state = torch.load(checkpoint_to_load, map_location=device)
-            model.load_state_dict(checkpoint_state)
-            print(f"✓ Checkpoint loaded successfully")
-        except Exception as e:
-            print(f"Warning: Could not load checkpoint: {e}")
-            print("Training from scratch instead")
-
    best_val_loss = float('inf')
+    if checkpoint:
+        print(f"Loading checkpoint: {checkpoint}")
+        ckpt = torch.load(checkpoint, map_location=device)
+        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+            model.load_state_dict(ckpt["model_state_dict"])
+            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
+            scaler.load_state_dict(ckpt["scaler_state_dict"])
+            start_epoch = ckpt["epoch"] + 1
+            best_val_loss = ckpt.get("best_val_loss", float('inf'))
+            print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
+        else:
+            model.load_state_dict(ckpt)
+            print("Loaded weights-only checkpoint (no optimizer state)")
+
+    checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
    best_model_state = None
    epochs_without_improvement = 0

-    print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
+    print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
+    print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
+    print(f"Mixed precision training: enabled")
+    print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
    if early_stopping_patience:
        print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
    print()
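For reference, with PyTorch's default eta_min of 0 the schedule created above decays the learning rate as lr_t = lr * (1 + cos(pi * t / T_max)) / 2, so with lr=0.001 and T_max=100 it passes 0.0005 at epoch 50 and approaches 0 by epoch 100.
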
@@ -236,10 +278,15 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
            batch_targets = batch_targets.to(device).unsqueeze(1)

            optimizer.zero_grad()
-            outputs = model(batch_features)
-            loss = criterion(outputs, batch_targets)
-            loss.backward()
-            optimizer.step()
+
+            # Mixed precision forward and backward
+            with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
+                outputs = model(batch_features)
+                loss = criterion(outputs, batch_targets)
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()

            train_loss += loss.item() * batch_features.size(0)
            pbar.update(1)
@@ -255,19 +302,51 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
                batch_features = batch_features.to(device)
                batch_targets = batch_targets.to(device).unsqueeze(1)

-                outputs = model(batch_features)
-                loss = criterion(outputs, batch_targets)
+                with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
+                    outputs = model(batch_features)
+                    loss = criterion(outputs, batch_targets)
                val_loss += loss.item() * batch_features.size(0)
                pbar.update(1)

        val_loss /= len(val_dataset)

-        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
+        # Update learning rate
+        scheduler.step()
+
+        # Print GPU memory usage
+        if torch.cuda.is_available():
+            gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
+            gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
+            print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
+
+        # Calculate and print estimated time remaining
+        elapsed_time = datetime.now() - training_start_time
+        time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
+        remaining_epochs = total_epochs - epoch_display
+        eta_seconds = time_per_epoch * remaining_epochs
+        eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
+
+        print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
+
+        # Save checkpoint after every epoch
+        checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
+        torch.save({
+            "epoch": epoch,
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+            "scheduler_state_dict": scheduler.state_dict(),
+            "scaler_state_dict": scaler.state_dict(),
+            "best_val_loss": best_val_loss,
+        }, checkpoint_file)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict().copy()
            epochs_without_improvement = 0
+            # Save best model snapshot
+            snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
+            torch.save(best_model_state, snapshot_file)
+            print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
        else:
            epochs_without_improvement += 1
@@ -277,6 +356,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
                break

    # Save best model
+    if best_model_state is None or best_val_loss >= checkpoint_val_loss:
+        print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
+        print("No new model created.")
+        return
+
    if best_model_state is not None:
        # Determine final output file with versioning
        final_output_file = output_file
@@ -302,7 +386,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
            "notes": "Win rate vs classical eval: TBD (requires benchmark games)"
        }

-        torch.save(best_model_state, final_output_file)
+        torch.save({
+            "model_state_dict": best_model_state,
+            "optimizer_state_dict": optimizer.state_dict(),
+            "scheduler_state_dict": scheduler.state_dict(),
+            "scaler_state_dict": scaler.state_dict(),
+            "epoch": epoch,
+            "best_val_loss": best_val_loss,
+        }, final_output_file)
        print(f"Best model saved to {final_output_file}")

        # Save metadata if versioning is enabled
@@ -327,18 +418,20 @@ if __name__ == "__main__":
                        help="Output file base name (default: nnue_weights.pt)")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to checkpoint file to resume training from (optional)")
-    parser.add_argument("--epochs", type=int, default=20,
-                        help="Number of epochs to train (default: 20)")
-    parser.add_argument("--batch-size", type=int, default=4096,
-                        help="Batch size (default: 4096)")
-    parser.add_argument("--lr", type=float, default=1e-3,
-                        help="Learning rate (default: 1e-3)")
+    parser.add_argument("--epochs", type=int, default=100,
+                        help="Number of epochs to train (default: 100)")
+    parser.add_argument("--batch-size", type=int, default=16384,
+                        help="Batch size (default: 16384)")
+    parser.add_argument("--lr", type=float, default=0.001,
+                        help="Learning rate (default: 0.001)")
    parser.add_argument("--early-stopping", type=int, default=None,
                        help="Stop if val loss doesn't improve for N epochs (optional)")
    parser.add_argument("--stockfish-depth", type=int, default=12,
                        help="Stockfish depth used for evaluations (for metadata, default: 12)")
    parser.add_argument("--no-versioning", action="store_true",
                        help="Disable automatic versioning (save directly to output file)")
+    parser.add_argument("--weight-decay", type=float, default=1e-4,
+                        help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")

    args = parser.parse_args()

@@ -351,5 +444,6 @@ if __name__ == "__main__":
        checkpoint=args.checkpoint,
        stockfish_depth=args.stockfish_depth,
        use_versioning=not args.no_versioning,
-        early_stopping_patience=args.early_stopping
+        early_stopping_patience=args.early_stopping,
+        weight_decay=args.weight_decay
    )

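Putting the flags together, a resumed run could look like the following (the positional name of the labeled-data file is assumed here, since it is not visible in these hunks):

    python train.py data/training_data.jsonl --epochs 100 --batch-size 16384 \
        --lr 0.001 --early-stopping 5 --checkpoint nnue_weights_checkpoint.pt
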
@@ -0,0 +1,30 @@
# Setup and run NNUE training pipeline

$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
$VenvDir = Join-Path $ScriptDir ".venv"

# Check if virtual environment exists
if (-not (Test-Path $VenvDir)) {
    Write-Host "Creating virtual environment..."
    python -m venv $VenvDir
    if ($LASTEXITCODE -ne 0) {
        Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
        exit 1
    }
}

# Activate virtual environment
Write-Host "Activating virtual environment..."
$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
& $ActivateScript

# Install/update dependencies if requirements.txt exists
$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
if (Test-Path $RequirementsFile) {
    Write-Host "Installing dependencies..."
    pip install -q -r $RequirementsFile
}

# Run nnue.py
Write-Host "Starting NNUE Training Pipeline..."
python (Join-Path $ScriptDir "nnue.py")
@@ -0,0 +1,29 @@
#!/bin/bash
# Setup and run NNUE training pipeline

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/.venv"

# Check if virtual environment exists
if [ ! -d "$VENV_DIR" ]; then
    echo "Creating virtual environment..."
    python3 -m venv "$VENV_DIR"
    if [ $? -ne 0 ]; then
        echo "Error: Failed to create virtual environment. Make sure python3 is installed."
        exit 1
    fi
fi

# Activate virtual environment
echo "Activating virtual environment..."
source "$VENV_DIR/bin/activate"

# Install/update dependencies if requirements.txt exists
if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
    echo "Installing dependencies..."
    pip install -q -r "$SCRIPT_DIR/requirements.txt"
fi

# Run nnue.py
echo "Starting NNUE Training Pipeline..."
python "$SCRIPT_DIR/nnue.py"
@@ -0,0 +1,7 @@
package de.nowchess.bot

object Config:

  /** Threshold in centipawns: if classical evaluation differs from NNUE by more than this,
    * the move is vetoed (not accepted as a suggestion). */
  val VETO_THRESHOLD: Int = 100
@@ -0,0 +1,23 @@
package de.nowchess.bot.bots

import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.logic.HybridSearch
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{Bot, BotDifficulty}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules

final class HybridBot(
    difficulty: BotDifficulty,
    rules: RuleSet = DefaultRules,
    book: Option[PolyglotBook] = None
) extends Bot:

  private val search: HybridSearch = HybridSearch(rules)

  override val name: String = s"HybridBot(${difficulty.toString})"

  override def nextMove(context: GameContext): Option[Move] =
    book.flatMap(_.probe(context))
      .orElse(search.bestMove(context))
@@ -7,9 +7,9 @@ import java.nio.ByteOrder

class NNUE:

-  private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias) = loadWeights()
+  private val (l1Weights, l1Bias, l2Weights, l2Bias, l3Weights, l3Bias, l4Weights, l4Bias, l5Weights, l5Bias) = loadWeights()

-  private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
+  private def loadWeights(): (Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float], Array[Float]) =
    val stream = getClass.getResourceAsStream("/nnue_weights.bin")
    if stream == null then
      throw RuntimeException("NNUE weights file not found in resources")
@@ -35,8 +35,12 @@ class NNUE:
      val l2b = readTensor(buffer)
      val l3w = readTensor(buffer)
      val l3b = readTensor(buffer)
+      val l4w = readTensor(buffer)
+      val l4b = readTensor(buffer)
+      val l5w = readTensor(buffer)
+      val l5b = readTensor(buffer)

-      (l1w, l1b, l2w, l2b, l3w, l3b)
+      (l1w, l1b, l2w, l2b, l3w, l3b, l4w, l4b, l5w, l5b)
    finally stream.close()

  private def readTensor(buffer: ByteBuffer): Array[Float] =
@@ -55,10 +59,12 @@ class NNUE:
      floats(i) = buffer.getFloat()
    floats

-  // Pre-allocated buffers for inference
+  // Pre-allocated buffers for inference (architecture: 768→1536→1024→512→256→1)
  private val features = new Array[Float](768)
-  private val l1Output = new Array[Float](256)
-  private val l2Output = new Array[Float](32)
+  private val l1Output = new Array[Float](1536)
+  private val l2Output = new Array[Float](1024)
+  private val l3Output = new Array[Float](512)
+  private val l4Output = new Array[Float](256)

  /** Convert a position to 768-dimensional binary feature vector.
    * 12 piece types (white pawn to black king) × 64 squares from white's perspective. */
@@ -110,28 +116,43 @@ class NNUE:

  /** Run NNUE inference on the given position.
    * Returns centipawn score from the perspective of the side-to-move.
-    * No allocations in the hot path (uses pre-allocated buffers). */
+    * No allocations in the hot path (uses pre-allocated buffers).
+    * Architecture: 768→1536→1024→512→256→1 */
  def evaluate(context: GameContext): Int =
    val features = positionToFeatures(context.board, context.turn)

-    // Layer 1: Dense(768 -> 256) + ReLU
-    for i <- 0 until 256 do
+    // Layer 1: Dense(768 → 1536) + ReLU
+    for i <- 0 until 1536 do
      var sum = l1Bias(i)
      for j <- 0 until 768 do
        sum += features(j) * l1Weights(i * 768 + j)
      l1Output(i) = if sum > 0f then sum else 0f

-    // Layer 2: Dense(256 -> 32) + ReLU
-    for i <- 0 until 32 do
+    // Layer 2: Dense(1536 → 1024) + ReLU
+    for i <- 0 until 1024 do
      var sum = l2Bias(i)
-      for j <- 0 until 256 do
-        sum += l1Output(j) * l2Weights(i * 256 + j)
+      for j <- 0 until 1536 do
+        sum += l1Output(j) * l2Weights(i * 1536 + j)
      l2Output(i) = if sum > 0f then sum else 0f

-    // Layer 3: Dense(32 -> 1), no activation
-    var output = l3Bias(0)
-    for j <- 0 until 32 do
-      output += l2Output(j) * l3Weights(j)
+    // Layer 3: Dense(1024 → 512) + ReLU
+    for i <- 0 until 512 do
+      var sum = l3Bias(i)
+      for j <- 0 until 1024 do
+        sum += l2Output(j) * l3Weights(i * 1024 + j)
+      l3Output(i) = if sum > 0f then sum else 0f
+
+    // Layer 4: Dense(512 → 256) + ReLU
+    for i <- 0 until 256 do
+      var sum = l4Bias(i)
+      for j <- 0 until 512 do
+        sum += l3Output(j) * l4Weights(i * 512 + j)
+      l4Output(i) = if sum > 0f then sum else 0f
+
+    // Layer 5: Dense(256 → 1), no activation
+    var output = l5Bias(0)
+    for j <- 0 until 256 do
+      output += l4Output(j) * l5Weights(j)

    // Convert from tanh-normalized output back to centipawns
    // Training uses: eval_normalized = tanh(eval_cp / 300)
@@ -145,3 +166,33 @@ class NNUE:
    (300f * atanh).toInt

    math.max(-20000, math.min(20000, cp))
+
+  /** Benchmark: time 1M evaluations and report ns/eval.
+    * This measures the performance of the inference on the starting position. */
+  def benchmark(): Unit =
+    val context = GameContext.initial
+    val iterations = 1_000_000
+
+    // Warm up
+    for _ <- 0 until 10000 do
+      evaluate(context)
+
+    // Actual benchmark
+    val startNanos = System.nanoTime()
+    for _ <- 0 until iterations do
+      evaluate(context)
+    val endNanos = System.nanoTime()
+
+    val totalNanos = endNanos - startNanos
+    val nanosPerEval = totalNanos.toDouble / iterations
+
+    println()
+    println("=" * 60)
+    println("NNUE BENCHMARK RESULTS")
+    println("=" * 60)
+    println(f"Iterations: $iterations%,d")
+    println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
+    println(f"ns/eval: $nanosPerEval%.2f ns")
+    println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
+    println("=" * 60)
+    println()

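Back-of-envelope cost of the bigger net in this hand-rolled inference: 768·1536 + 1536·1024 + 1024·512 + 512·256 + 256 ≈ 3.41M multiply-accumulates per evaluate() call, versus 768·256 + 256·32 + 32 ≈ 0.20M for the old 768→256→32→1 network, roughly a 17x increase per call, which is exactly what the benchmark() added above is there to measure.
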
@@ -0,0 +1,67 @@
package de.nowchess.bot.logic

import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.Config
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import scala.util.boundary
import scala.util.boundary.break

final class HybridSearch(
    rules: RuleSet = DefaultRules
):

  private var vetoCount = 0
  private var approvalCount = 0
  private val TOP_MOVES_TO_VALIDATE = 10

  /** Find the best move by scoring all legal moves with NNUE, then validating the top 10 with classical eval.
    * If a move's classical score is within VETO_THRESHOLD of its NNUE score, it's approved.
    * If all top 10 are vetoed, fall back to the best classical move overall.
    */
  def bestMove(context: GameContext): Option[Move] =
    val legalMoves = rules.allLegalMoves(context)
    if legalMoves.isEmpty then None else findBestMove(legalMoves, context)

  private def findBestMove(legalMoves: List[Move], context: GameContext): Option[Move] =
    // Score all moves with NNUE
    val moveScores = legalMoves.map { move =>
      val nextContext = rules.applyMove(context)(move)
      val nnueScore = EvaluationNNUE.evaluate(nextContext)
      (move, nnueScore, nextContext)
    }

    // Sort by NNUE score descending
    val sortedByNNUE = moveScores.sortBy(_._2).reverse

    // Validate top N moves with classical evaluation
    val topMovesToCheck = sortedByNNUE.take(TOP_MOVES_TO_VALIDATE)

    boundary:
      for (move, nnueScore, nextContext) <- topMovesToCheck do
        val classicalScore = EvaluationClassic.evaluate(nextContext)
        val difference = (classicalScore - nnueScore).abs
        if difference <= Config.VETO_THRESHOLD then
          approvalCount += 1
          println(s"[HybridSearch] Move approved: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference)")
          break(Some(move))
        else
          vetoCount += 1
          println(s"[HybridSearch] Move vetoed: $move (NNUE=$nnueScore, Classical=$classicalScore, diff=$difference > ${Config.VETO_THRESHOLD})")

      // All top 10 were vetoed, fall back to best classical move
      println(s"[HybridSearch] All top 10 NNUE moves vetoed. Falling back to best classical move.")
      val bestByClassical = moveScores
        .map { case (move, _, nextContext) =>
          (move, EvaluationClassic.evaluate(nextContext))
        }
        .maxBy(_._2)

      println(s"[HybridSearch] Fallback move: ${bestByClassical._1} (Classical score=${bestByClassical._2})")
      println(s"[HybridSearch] Stats - Approvals: $approvalCount, Vetoes: $vetoCount")
      Some(bestByClassical._1)

  def getStats: (Int, Int) = (approvalCount, vetoCount)

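To make the veto rule concrete with illustrative numbers: if NNUE scores a candidate at +150 centipawns and the classical evaluation returns +20, the difference |20 - 150| = 130 exceeds VETO_THRESHOLD = 100 and the move is vetoed; a classical score of +80 would give a difference of 70 and the move would be approved.
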
@@ -3,7 +3,7 @@ package de.nowchess.ui
import de.nowchess.api.board.Color.Black
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.BotDifficulty
-import de.nowchess.bot.bots.{ClassicalBot, NNUEBot}
+import de.nowchess.bot.bots.{ClassicalBot, HybridBot, NNUEBot}
import de.nowchess.chess.engine.GameEngine
import de.nowchess.ui.terminal.TerminalUI
import de.nowchess.ui.gui.ChessGUILauncher
@@ -19,7 +19,7 @@ object Main:

  val book = PolyglotBook("../../modules/bot/codekiddy.bin")

-  engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black);
+  engine.setOpponentBot(HybridBot(BotDifficulty.Easy, book = Some(book)), Black);

  // Launch ScalaFX GUI in separate thread
  ChessGUILauncher.launch(engine)
||||
Reference in New Issue
Block a user