feat: enhance training and evaluation processes with new parameters and normalization options

This commit is contained in:
2026-04-08 10:21:38 +02:00
committed by Janis
parent 34945e4fb8
commit a5560285fd
8 changed files with 242 additions and 90 deletions
+56 -6
View File
@@ -12,6 +12,7 @@ from tqdm import tqdm
import chess
from datetime import datetime
import re
import numpy as np
class NNUEDataset(Dataset):
"""Dataset of chess positions with evaluations."""
@@ -19,15 +20,28 @@ class NNUEDataset(Dataset):
def __init__(self, data_file):
self.positions = []
self.evals = []
self.evals_raw = []
self.is_normalized = None
with open(data_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data['fen']
eval_cp = data['eval']
eval_val = data['eval']
self.positions.append(fen)
self.evals.append(eval_cp)
self.evals.append(eval_val)
# Check if normalized or raw
if self.is_normalized is None:
# If eval is in range [-1, 1], assume normalized
self.is_normalized = abs(eval_val) <= 1.0
# Store raw if available
if 'eval_raw' in data:
self.evals_raw.append(data['eval_raw'])
else:
self.evals_raw.append(eval_val)
except (json.JSONDecodeError, KeyError):
pass
@@ -36,9 +50,15 @@ class NNUEDataset(Dataset):
def __getitem__(self, idx):
    """Return the ``(features, target)`` pair for the sample at ``idx``.

    The features come from ``fen_to_features`` applied to the stored FEN.
    The target is the stored evaluation as a float32 tensor when the
    dataset was flagged as normalized; otherwise the evaluation is divided
    by 400 and passed through a sigmoid.
    """
    board_fen, score = self.positions[idx], self.evals[idx]
    encoded = fen_to_features(board_fen)

    if not self.is_normalized:
        # Un-normalized values are squashed with the sigmoid scaling.
        scaled = torch.tensor(score / 400.0, dtype=torch.float32)
        return encoded, torch.sigmoid(scaled)

    # Already-normalized values are used unchanged.
    return encoded, torch.tensor(score, dtype=torch.float32)
def fen_to_features(fen):
@@ -122,7 +142,7 @@ def save_metadata(weights_file, metadata):
return metadata_file
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True):
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4096, lr=1e-3, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None):
"""Train the NNUE model.
Args:
@@ -134,12 +154,28 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
"""
print("Loading dataset...")
dataset = NNUEDataset(data_file)
num_positions = len(dataset)
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
# Print dataset diagnostics
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
print("TRAINING DATASET DIAGNOSTICS")
print("=" * 60)
print(f"Min evaluation: {evals_array.min():.4f}")
print(f"Max evaluation: {evals_array.max():.4f}")
print(f"Mean evaluation: {evals_array.mean():.4f}")
print(f"Median evaluation: {np.median(evals_array):.4f}")
print(f"Std deviation: {evals_array.std():.4f}")
print("=" * 60)
print()
# Split 90% train, 10% validation
train_size = int(0.9 * len(dataset))
@@ -179,8 +215,11 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
best_val_loss = float('inf')
best_model_state = None
epochs_without_improvement = 0
print(f"Training for {epochs} epochs (starting from epoch {start_epoch + 1})...")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
@@ -228,6 +267,14 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=20, batch_size=4
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = model.state_dict().copy()
epochs_without_improvement = 0
else:
epochs_without_improvement += 1
# Early stopping check
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break
# Save best model
if best_model_state is not None:
@@ -286,6 +333,8 @@ if __name__ == "__main__":
help="Batch size (default: 4096)")
parser.add_argument("--lr", type=float, default=1e-3,
help="Learning rate (default: 1e-3)")
parser.add_argument("--early-stopping", type=int, default=None,
help="Stop if val loss doesn't improve for N epochs (optional)")
parser.add_argument("--stockfish-depth", type=int, default=12,
help="Stockfish depth used for evaluations (for metadata, default: 12)")
parser.add_argument("--no-versioning", action="store_true",
@@ -301,5 +350,6 @@ if __name__ == "__main__":
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping
)