feat: enhance training and evaluation processes with new parameters and normalization options
This commit is contained in:
@@ -129,11 +129,17 @@ def train_interactive():
|
||||
use_existing = Confirm.ask("Use existing positions file?", default=False)
|
||||
positions_file = None
|
||||
num_games = 500000
|
||||
samples_per_game = 1
|
||||
min_move = 1
|
||||
max_move = 50
|
||||
|
||||
if use_existing:
|
||||
positions_file = Prompt.ask("Enter path to positions file", default=str(get_data_dir() / "positions.txt"))
|
||||
else:
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="500000"))
|
||||
num_games = int(Prompt.ask("Number of games to generate", default="5000"))
|
||||
samples_per_game = int(Prompt.ask("Positions to sample per game", default="1"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
|
||||
use_existing_labels = Confirm.ask("Use existing labels file?", default=False)
|
||||
labels_file = None
|
||||
@@ -147,6 +153,9 @@ def train_interactive():
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="20"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="4096"))
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
@@ -154,9 +163,18 @@ def train_interactive():
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(" Checkpoint: None (training from scratch)")
|
||||
console.print(f" Games: {num_games:,}")
|
||||
if not use_existing:
|
||||
console.print(f" Games: {num_games:,}")
|
||||
console.print(f" Samples per game: {samples_per_game}")
|
||||
console.print(f" Move range: {min_move}-{max_move}")
|
||||
else:
|
||||
console.print(f" Positions file: {positions_file}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
console.print(f" Stockfish: {stockfish_path}")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
@@ -174,7 +192,9 @@ def train_interactive():
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(data_dir / "positions.txt"),
|
||||
total_games=num_games,
|
||||
filter_captures=True
|
||||
samples_per_game=samples_per_game,
|
||||
min_move=min_move,
|
||||
max_move=max_move
|
||||
)
|
||||
if count == 0:
|
||||
console.print("[red]✗ No valid positions generated[/red]")
|
||||
@@ -219,7 +239,8 @@ def train_interactive():
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate 500,000 random chess positions for NNUE training."""
|
||||
"""Generate random chess positions for NNUE training with minimal filtering."""
|
||||
|
||||
import chess
|
||||
import random
|
||||
@@ -7,89 +7,78 @@ import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# Standard piece values for capture filtering
|
||||
PIECE_VALUES = {
|
||||
chess.PAWN: 1,
|
||||
chess.KNIGHT: 3,
|
||||
chess.BISHOP: 3,
|
||||
chess.ROOK: 5,
|
||||
chess.QUEEN: 9,
|
||||
}
|
||||
def play_random_game_and_collect_positions(
|
||||
output_file,
|
||||
total_games=500000,
|
||||
samples_per_game=1,
|
||||
min_move=1,
|
||||
max_move=50
|
||||
):
|
||||
"""Play random games and sample multiple positions from each.
|
||||
|
||||
def has_winning_or_equal_capture(board):
    """Report whether the side to move has a winning or equal capture.

    A capture counts as winning or equal when the PIECE_VALUES entry of the
    captured piece is greater than or equal to that of the capturing piece.
    Piece types missing from PIECE_VALUES are valued at 0.

    Args:
        board: Position to inspect; its legal moves are scanned.

    Returns:
        True if at least one legal capture satisfies victim >= attacker,
        False otherwise (including positions that only have losing captures).
    """
    for move in board.legal_moves:
        if not board.is_capture(move):
            continue

        attacker_piece = board.piece_at(move.from_square)
        if attacker_piece is None:
            continue

        if board.is_en_passant(move):
            # The destination square is empty for en passant, so piece_at()
            # would return None and the capture would be missed; the victim
            # of an en passant capture is always a pawn.
            victim_type = chess.PAWN
        else:
            victim_piece = board.piece_at(move.to_square)
            if victim_piece is None:
                continue
            victim_type = victim_piece.piece_type

        attacker_value = PIECE_VALUES.get(attacker_piece.piece_type, 0)
        victim_value = PIECE_VALUES.get(victim_type, 0)

        # victim >= attacker means the trade is winning or equal
        if victim_value >= attacker_value:
            return True

    return False
|
||||
|
||||
def play_random_game_and_collect_positions(output_file, total_games=500000, filter_captures=True):
|
||||
"""Play random games and save positions after 8-20 random moves.
|
||||
Args:
|
||||
output_file: Output file for positions
|
||||
total_games: Number of games to play
|
||||
samples_per_game: Number of positions to sample per game (1-N)
|
||||
min_move: Minimum move number to start sampling from
|
||||
max_move: Maximum move number for sampling
|
||||
|
||||
Returns:
|
||||
Number of valid positions saved
|
||||
"""
|
||||
positions_count = 0
|
||||
filtered_check = 0
|
||||
filtered_captures = 0
|
||||
filtered_game_over = 0
|
||||
filtered_illegal = 0
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
with tqdm(total=total_games, desc="Generating positions") as pbar:
|
||||
for game_num in range(total_games):
|
||||
board = chess.Board()
|
||||
move_history = []
|
||||
|
||||
# Play 8-20 random opening moves
|
||||
num_moves = random.randint(8, 20)
|
||||
|
||||
for move_num in range(num_moves):
|
||||
if board.is_game_over():
|
||||
break
|
||||
|
||||
# Play a complete random game
|
||||
while not board.is_game_over() and len(move_history) < 200:
|
||||
legal_moves = list(board.legal_moves)
|
||||
if not legal_moves:
|
||||
break
|
||||
|
||||
move = random.choice(legal_moves)
|
||||
board.push(move)
|
||||
move_history.append(board.copy())
|
||||
|
||||
# Skip if game over
|
||||
if board.is_game_over():
|
||||
filtered_game_over += 1
|
||||
# Determine the range of moves to sample from
|
||||
game_length = len(move_history)
|
||||
valid_start = max(min_move, 0)
|
||||
valid_end = min(max_move, game_length)
|
||||
|
||||
if valid_start >= valid_end:
|
||||
# Game too short
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
# Skip if in check
|
||||
if board.is_check():
|
||||
filtered_check += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
# Randomly sample positions from this game
|
||||
sample_count = min(samples_per_game, valid_end - valid_start)
|
||||
if sample_count > 0:
|
||||
sample_indices = random.sample(
|
||||
range(valid_start, valid_end),
|
||||
k=sample_count
|
||||
)
|
||||
|
||||
# Check if there are winning or equal captures (if filtering enabled)
|
||||
if filter_captures:
|
||||
if has_winning_or_equal_capture(board):
|
||||
filtered_captures += 1
|
||||
pbar.update(1)
|
||||
continue
|
||||
for idx in sample_indices:
|
||||
sampled_board = move_history[idx]
|
||||
|
||||
# Save valid position
|
||||
fen = board.fen()
|
||||
f.write(fen + '\n')
|
||||
positions_count += 1
|
||||
# Only filter truly invalid or terminal positions
|
||||
if not sampled_board.is_valid():
|
||||
filtered_illegal += 1
|
||||
continue
|
||||
|
||||
if sampled_board.is_game_over():
|
||||
filtered_game_over += 1
|
||||
continue
|
||||
|
||||
# Save position (include check, captures, all positions)
|
||||
fen = sampled_board.fen()
|
||||
f.write(fen + '\n')
|
||||
positions_count += 1
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
@@ -98,23 +87,19 @@ def play_random_game_and_collect_positions(output_file, total_games=500000, filt
|
||||
print("=" * 60)
|
||||
print("POSITION GENERATION SUMMARY")
|
||||
print("=" * 60)
|
||||
total_filtered = filtered_check + filtered_captures + filtered_game_over
|
||||
print(f"Total games: {total_games}")
|
||||
print(f"Total games played: {total_games}")
|
||||
print(f"Samples per game: {samples_per_game}")
|
||||
print(f"Move range: {min_move}-{max_move}")
|
||||
print(f"Saved positions: {positions_count}")
|
||||
print(f"Filtered (in check): {filtered_check}")
|
||||
print(f"Filtered (winning+ cap): {filtered_captures}")
|
||||
print(f"Filtered (game over): {filtered_game_over}")
|
||||
print(f"Total filtered: {total_filtered}")
|
||||
print(f"Filtered (illegal): {filtered_illegal}")
|
||||
print(f"Acceptance rate: {positions_count / total_games * 100:.2f}%")
|
||||
print(f"(Keeps positions with only losing/bad captures)")
|
||||
print(f"(Includes checks, captures, all realistic positions)")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if positions_count == 0:
|
||||
print("WARNING: No valid positions were generated!")
|
||||
print("This might happen if:")
|
||||
print(" - Most positions have checks or game-over states")
|
||||
print(" - Try using: --no-filter-captures to accept positions with winning captures")
|
||||
return 0
|
||||
|
||||
return positions_count
|
||||
@@ -126,16 +111,22 @@ if __name__ == "__main__":
|
||||
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
||||
help="Output file for positions (default: positions.txt)")
|
||||
parser.add_argument("--games", type=int, default=5000,
|
||||
help="Number of games to play (default: 500000)")
|
||||
parser.add_argument("--no-filter-captures", action="store_true",
|
||||
help="Include positions with winning/equal captures (increases output)")
|
||||
help="Number of games to play (default: 5000)")
|
||||
parser.add_argument("--samples-per-game", type=int, default=1,
|
||||
help="Number of positions to sample per game (default: 1)")
|
||||
parser.add_argument("--min-move", type=int, default=1,
|
||||
help="Minimum move number to sample from (default: 1)")
|
||||
parser.add_argument("--max-move", type=int, default=50,
|
||||
help="Maximum move number to sample from (default: 50)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
count = play_random_game_and_collect_positions(
|
||||
output_file=args.output_file,
|
||||
total_games=args.games,
|
||||
filter_captures=not args.no_filter_captures
|
||||
samples_per_game=args.samples_per_game,
|
||||
min_move=args.min_move,
|
||||
max_move=args.max_move
|
||||
)
|
||||
|
||||
sys.exit(0 if count > 0 else 1)
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Label positions with Stockfish evaluations."""
|
||||
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||
|
||||
import json
|
||||
import chess.engine
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False):
|
||||
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
    """Normalize a centipawn evaluation to a bounded range.

    Args:
        cp_value: Centipawn evaluation (a number or a NumPy array).
        method: 'tanh' (default) or 'sigmoid'. Any other value falls back to
            the plain conversion ``cp_value / 100.0``, which is unbounded and
            ignores ``scale``.
        scale: Divisor applied to ``cp_value`` before squashing.

    Returns:
        A value in [-1, 1] for 'tanh', in [0, 1] for 'sigmoid', or
        ``cp_value / 100.0`` for any other method.
    """
    if method == 'tanh':
        return np.tanh(cp_value / scale)
    if method == 'sigmoid':
        # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). This form never evaluates
        # exp() of a large argument, so extreme inputs saturate to 0 or 1
        # instead of triggering a floating-point overflow warning.
        return 0.5 * (1.0 + np.tanh(0.5 * (cp_value / scale)))
    return cp_value / 100.0
|
||||
|
||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=100, depth=12, verbose=False, normalize=True):
|
||||
"""Read positions and label them with Stockfish evaluations.
|
||||
|
||||
Args:
|
||||
@@ -18,6 +37,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
batch_size: Batch size (not used, kept for compatibility)
|
||||
depth: Stockfish depth
|
||||
verbose: Print detailed error messages
|
||||
normalize: If True, normalize evals using tanh
|
||||
"""
|
||||
|
||||
# Check if stockfish exists
|
||||
@@ -75,6 +95,8 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
skipped_invalid = 0
|
||||
skipped_duplicate = 0
|
||||
errors = 0
|
||||
raw_evals = []
|
||||
normalized_evals = []
|
||||
|
||||
try:
|
||||
with open(positions_file, 'r') as f:
|
||||
@@ -123,8 +145,15 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
# Clamp to [-2000, 2000]
|
||||
eval_cp = max(-2000, min(2000, eval_cp))
|
||||
|
||||
# Save evaluation
|
||||
data = {"fen": fen, "eval": eval_cp}
|
||||
# Normalize evaluation
|
||||
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
|
||||
|
||||
# Track statistics
|
||||
raw_evals.append(eval_cp)
|
||||
normalized_evals.append(eval_normalized)
|
||||
|
||||
# Save evaluation (normalized if requested)
|
||||
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||
out.write(json.dumps(data) + '\n')
|
||||
out.flush() # Force write to disk
|
||||
evaluated += 1
|
||||
@@ -142,7 +171,7 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
finally:
|
||||
engine.quit()
|
||||
|
||||
# Print summary
|
||||
# Print summary and analysis
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LABELING SUMMARY")
|
||||
@@ -164,6 +193,51 @@ def label_positions_with_stockfish(positions_file, output_file, stockfish_path,
|
||||
print(" 4. Stockfish path is correct")
|
||||
return False
|
||||
|
||||
# Print distribution analysis
|
||||
if raw_evals:
|
||||
raw_evals_arr = np.array(raw_evals)
|
||||
norm_evals_arr = np.array(normalized_evals)
|
||||
|
||||
print("=" * 60)
|
||||
print("EVALUATION DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("Raw Evaluations (centipawns):")
|
||||
print(f" Min: {raw_evals_arr.min():.1f}")
|
||||
print(f" Max: {raw_evals_arr.max():.1f}")
|
||||
print(f" Mean: {raw_evals_arr.mean():.1f}")
|
||||
print(f" Median: {np.median(raw_evals_arr):.1f}")
|
||||
print(f" Std: {raw_evals_arr.std():.1f}")
|
||||
print()
|
||||
|
||||
print("Normalized Evaluations (tanh):")
|
||||
print(f" Min: {norm_evals_arr.min():.4f}")
|
||||
print(f" Max: {norm_evals_arr.max():.4f}")
|
||||
print(f" Mean: {norm_evals_arr.mean():.4f}")
|
||||
print(f" Median: {np.median(norm_evals_arr):.4f}")
|
||||
print(f" Std: {norm_evals_arr.std():.4f}")
|
||||
print()
|
||||
|
||||
# Distribution buckets
|
||||
print("Raw Evaluation Buckets (counts):")
|
||||
buckets = [
|
||||
(-float('inf'), -500, "< -5.00"),
|
||||
(-500, -300, "[-5.00, -3.00)"),
|
||||
(-300, -100, "[-3.00, -1.00)"),
|
||||
(-100, 0, "[-1.00, 0.00)"),
|
||||
(0, 100, "[0.00, 1.00)"),
|
||||
(100, 300, "[1.00, 3.00)"),
|
||||
(300, 500, "[3.00, 5.00)"),
|
||||
(500, float('inf'), "> 5.00"),
|
||||
]
|
||||
for low, high, label in buckets:
|
||||
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
|
||||
pct = 100.0 * count / len(raw_evals_arr)
|
||||
print(f" {label}: {count:6d} ({pct:5.1f}%)")
|
||||
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
print(f"✓ Labeling complete. Output saved to {output_file}")
|
||||
return True
|
||||
|
||||
@@ -179,6 +253,8 @@ if __name__ == "__main__":
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||
parser.add_argument("--depth", type=int, default=12,
|
||||
help="Stockfish depth (default: 12)")
|
||||
parser.add_argument("--no-normalize", action="store_true",
|
||||
help="Disable evaluation normalization (keep raw centipawns)")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Print detailed error messages")
|
||||
|
||||
@@ -192,6 +268,7 @@ if __name__ == "__main__":
|
||||
output_file=args.output_file,
|
||||
stockfish_path=stockfish_path,
|
||||
depth=args.depth,
|
||||
normalize=not args.no_normalize,
|
||||
verbose=args.verbose
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
class NNUEDataset(Dataset):
|
||||
"""Dataset of chess positions with evaluations."""
|
||||
@@ -19,15 +20,28 @@ class NNUEDataset(Dataset):
|
||||
def __init__(self, data_file):
|
||||
self.positions = []
|
||||
self.evals = []
|
||||
self.evals_raw = []
|
||||
self.is_normalized = None
|
||||
|
||||
with open(data_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data['fen']
|
||||
eval_cp = data['eval']
|
||||
eval_val = data['eval']
|
||||
self.positions.append(fen)
|
||||
self.evals.append(eval_cp)
|
||||
self.evals.append(eval_val)
|
||||
|
||||
# Check if normalized or raw
|
||||
if self.is_normalized is None:
|
||||
# If eval is in range [-1, 1], assume normalized
|
||||
self.is_normalized = abs(eval_val) <= 1.0
|
||||
|
||||
# Store raw if available
|
||||
if 'eval_raw' in data:
|
||||
self.evals_raw.append(data['eval_raw'])
|
||||
else:
|
||||
self.evals_raw.append(eval_val)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
@@ -36,9 +50,15 @@ class NNUEDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
    """Return the ``(features, target)`` pair for the sample at ``idx``.

    Features are produced by ``fen_to_features`` from the stored FEN. The
    target is the stored evaluation unchanged when ``self.is_normalized`` is
    truthy; otherwise it is ``sigmoid(eval / 400)``.

    NOTE(review): stored evaluations accepted as normalized may be negative
    (the check is |eval| <= 1), while the sigmoid branch yields values in
    (0, 1) -- confirm the model output and loss match both ranges.
    """
    position = self.positions[idx]
    score = self.evals[idx]
    features = fen_to_features(position)

    if not self.is_normalized:
        # Raw score: squash into (0, 1) after dividing by 400
        squashed = torch.sigmoid(torch.tensor(score / 400.0, dtype=torch.float32))
        return features, squashed

    # Already-normalized score is used as the target as-is
    return features, torch.tensor(score, dtype=torch.float32)
|
||||
|
||||
def fen_to_features(fen):
|
||||
@@ -122,7 +142,7 @@ def save_metadata(weights_file, metadata):
|
||||
|
||||
return metadata_file
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
|
||||
"""Train the NNUE model.
|
||||
|
||||
Args:
|
||||
@@ -134,12 +154,28 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
||||
checkpoint: Optional path to checkpoint file to resume from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||
"""
|
||||
|
||||
print("Loading dataset...")
|
||||
dataset = NNUEDataset(data_file)
|
||||
num_positions = len(dataset)
|
||||
print(f"Dataset size: {num_positions}")
|
||||
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||
|
||||
# Print dataset diagnostics
|
||||
evals_array = np.array(dataset.evals)
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("TRAINING DATASET DIAGNOSTICS")
|
||||
print("=" * 60)
|
||||
print(f"Min evaluation: {evals_array.min():.4f}")
|
||||
print(f"Max evaluation: {evals_array.max():.4f}")
|
||||
print(f"Mean evaluation: {evals_array.mean():.4f}")
|
||||
print(f"Median evaluation: {np.median(evals_array):.4f}")
|
||||
print(f"Std deviation: {evals_array.std():.4f}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Split 90% train, 10% validation
|
||||
train_size = int(0.9 * len(dataset))
|
||||
@@ -179,8 +215,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
||||
|
||||
best_val_loss = float('inf')
|
||||
best_model_state = None
|
||||
epochs_without_improvement = 0
|
||||
|
||||
print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
|
||||
if early_stopping_patience:
|
||||
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||
print()
|
||||
|
||||
training_start_time = datetime.now()
|
||||
@@ -228,6 +267,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_state = model.state_dict().copy()
|
||||
epochs_without_improvement = 0
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
|
||||
# Early stopping check
|
||||
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||
break
|
||||
|
||||
# Save best model
|
||||
if best_model_state is not None:
|
||||
@@ -286,6 +333,8 @@ if __name__ == "__main__":
|
||||
help="Batch size (default: 4096)")
|
||||
parser.add_argument("--lr", type=float, default=1e-3,
|
||||
help="Learning rate (default: 1e-3)")
|
||||
parser.add_argument("--early-stopping", type=int, default=None,
|
||||
help="Stop if val loss doesn't improve for N epochs (optional)")
|
||||
parser.add_argument("--stockfish-depth", type=int, default=12,
|
||||
help="Stockfish depth used for evaluations (for metadata, default: 12)")
|
||||
parser.add_argument("--no-versioning", action="store_true",
|
||||
@@ -301,5 +350,6 @@ if __name__ == "__main__":
|
||||
lr=args.lr,
|
||||
checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning
|
||||
use_versioning=not args.no_versioning,
|
||||
early_stopping_patience=args.early_stopping
|
||||
)
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 3,
|
||||
"date": "2026-04-08T09:43:28.000579",
|
||||
"num_positions": 71610,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 20,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.006398905136849695,
|
||||
"device": "cpu",
|
||||
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
@@ -52,7 +52,7 @@ object FenParser extends GameContextImport:
|
||||
)
|
||||
else None
|
||||
|
||||
/** Parse the en passant target field ("-" for none, or algebraic like "e3").
  *
  * Returns `Some(None)` for "-", `Some(Some(square))` when
  * `Square.fromAlgebraic` accepts the text, and `None` when it does not.
  */
private def parseEnPassant(s: String): Option[Option[Square]] =
  s match
    case "-"       => Some(None)
    case algebraic => Square.fromAlgebraic(algebraic).map(Some(_))
|
||||
|
||||
@@ -19,7 +19,7 @@ object Main:
|
||||
|
||||
val book = PolyglotBook("../../modules/bot/codekiddy.bin")
|
||||
|
||||
engine.setOpponentBot(NNUEBot(BotDifficulty.Easy, book = Some(book)), Black);
|
||||
engine.setOpponentBot(ClassicalBot(BotDifficulty.Easy, book = Some(book)), Black);
|
||||
|
||||
// Launch ScalaFX GUI in separate thread
|
||||
ChessGUILauncher.launch(engine)
|
||||
|
||||
Reference in New Issue
Block a user