feat: Allow configurable hidden layer sizes for NNUE architecture
This commit is contained in:
@@ -17,7 +17,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src"))
|
|||||||
|
|
||||||
from generate import play_random_game_and_collect_positions
|
from generate import play_random_game_and_collect_positions
|
||||||
from label import label_positions_with_stockfish
|
from label import label_positions_with_stockfish
|
||||||
from train import train_nnue, burst_train
|
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
|
||||||
from export import export_to_nbai
|
from export import export_to_nbai
|
||||||
from tactical_positions_extractor import (
|
from tactical_positions_extractor import (
|
||||||
download_and_extract_puzzle_db,
|
download_and_extract_puzzle_db,
|
||||||
@@ -624,13 +624,22 @@ def train_interactive():
|
|||||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||||
|
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||||
|
hidden_layers_str = Prompt.ask(
|
||||||
|
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||||
|
default=default_layers
|
||||||
|
)
|
||||||
|
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||||
early_stopping = None
|
early_stopping = None
|
||||||
if Confirm.ask("Enable early stopping?", default=False):
|
if Confirm.ask("Enable early stopping?", default=False):
|
||||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||||
|
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||||
|
|
||||||
# Confirm and start
|
# Confirm and start
|
||||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
console.print(f" Dataset: ds_v{dataset_version}")
|
console.print(f" Dataset: ds_v{dataset_version}")
|
||||||
|
console.print(f" Architecture: {arch_str}")
|
||||||
console.print(f" Epochs: {epochs}")
|
console.print(f" Epochs: {epochs}")
|
||||||
console.print(f" Batch size: {batch_size}")
|
console.print(f" Batch size: {batch_size}")
|
||||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||||
@@ -665,7 +674,8 @@ def train_interactive():
|
|||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
use_versioning=True,
|
use_versioning=True,
|
||||||
early_stopping_patience=early_stopping,
|
early_stopping_patience=early_stopping,
|
||||||
subsample_ratio=subsample_ratio
|
subsample_ratio=subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
)
|
)
|
||||||
console.print("[green]✓ Training complete[/green]")
|
console.print("[green]✓ Training complete[/green]")
|
||||||
|
|
||||||
@@ -727,10 +737,18 @@ def burst_train_interactive():
|
|||||||
# Training hyperparameters
|
# Training hyperparameters
|
||||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||||
|
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||||
|
hidden_layers_str = Prompt.ask(
|
||||||
|
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||||
|
default=default_layers
|
||||||
|
)
|
||||||
|
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
console.print(f" Dataset: ds_v{dataset_version}")
|
console.print(f" Dataset: ds_v{dataset_version}")
|
||||||
|
console.print(f" Architecture: {arch_str}")
|
||||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||||
console.print(f" Epochs per season: {epochs_per_season}")
|
console.print(f" Epochs per season: {epochs_per_season}")
|
||||||
console.print(f" Patience: {early_stopping_patience}")
|
console.print(f" Patience: {early_stopping_patience}")
|
||||||
@@ -757,6 +775,7 @@ def burst_train_interactive():
|
|||||||
initial_checkpoint=checkpoint,
|
initial_checkpoint=checkpoint,
|
||||||
use_versioning=True,
|
use_versioning=True,
|
||||||
subsample_ratio=subsample_ratio,
|
subsample_ratio=subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
)
|
)
|
||||||
console.print("[green]✓ Burst training complete[/green]")
|
console.print("[green]✓ Burst training complete[/green]")
|
||||||
|
|
||||||
|
|||||||
@@ -86,35 +86,39 @@ def fen_to_features(fen):
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
|
||||||
|
|
||||||
|
|
||||||
class NNUE(nn.Module):
|
class NNUE(nn.Module):
|
||||||
"""NNUE neural network architecture: 768→1536→1024→512→256→1 with dropout regularization."""
|
"""NNUE neural network with configurable hidden layers.
|
||||||
|
|
||||||
def __init__(self, dropout_rate=0.2):
|
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
|
||||||
|
Layer attributes follow the naming l1, l2, ..., lN so export.py can
|
||||||
|
infer the architecture directly from the state_dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.l1 = nn.Linear(768, 1536)
|
if hidden_sizes is None:
|
||||||
self.relu1 = nn.ReLU()
|
hidden_sizes = DEFAULT_HIDDEN_SIZES
|
||||||
self.drop1 = nn.Dropout(dropout_rate)
|
self.hidden_sizes = list(hidden_sizes)
|
||||||
|
sizes = [768] + self.hidden_sizes + [1]
|
||||||
|
num_hidden = len(self.hidden_sizes)
|
||||||
|
|
||||||
self.l2 = nn.Linear(1536, 1024)
|
for i in range(num_hidden):
|
||||||
self.relu2 = nn.ReLU()
|
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
|
||||||
self.drop2 = nn.Dropout(dropout_rate)
|
setattr(self, f"relu{i + 1}", nn.ReLU())
|
||||||
|
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
|
||||||
self.l3 = nn.Linear(1024, 512)
|
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
|
||||||
self.relu3 = nn.ReLU()
|
self._num_hidden = num_hidden
|
||||||
self.drop3 = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
self.l4 = nn.Linear(512, 256)
|
|
||||||
self.relu4 = nn.ReLU()
|
|
||||||
self.drop4 = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
self.l5 = nn.Linear(256, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.drop1(self.relu1(self.l1(x)))
|
for i in range(1, self._num_hidden + 1):
|
||||||
x = self.drop2(self.relu2(self.l2(x)))
|
layer = getattr(self, f"l{i}")
|
||||||
x = self.drop3(self.relu3(self.l3(x)))
|
relu = getattr(self, f"relu{i}")
|
||||||
x = self.drop4(self.relu4(self.l4(x)))
|
drop = getattr(self, f"drop{i}")
|
||||||
return self.l5(x)
|
x = drop(relu(layer(x)))
|
||||||
|
return getattr(self, f"l{self._num_hidden + 1}")(x)
|
||||||
|
|
||||||
def find_next_version(base_name="nnue_weights"):
|
def find_next_version(base_name="nnue_weights"):
|
||||||
"""Find the next version number for model versioning.
|
"""Find the next version number for model versioning.
|
||||||
@@ -306,6 +310,7 @@ def _run_training_season(
|
|||||||
"scheduler_state_dict": scheduler.state_dict(),
|
"scheduler_state_dict": scheduler.state_dict(),
|
||||||
"scaler_state_dict": scaler.state_dict(),
|
"scaler_state_dict": scaler.state_dict(),
|
||||||
"best_val_loss": best_val_loss,
|
"best_val_loss": best_val_loss,
|
||||||
|
"hidden_sizes": model.hidden_sizes,
|
||||||
}, checkpoint_file)
|
}, checkpoint_file)
|
||||||
|
|
||||||
if val_loss < best_val_loss:
|
if val_loss < best_val_loss:
|
||||||
@@ -328,10 +333,12 @@ def _run_training_season(
|
|||||||
|
|
||||||
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||||
best_val_loss, output_file, use_versioning, num_positions,
|
best_val_loss, output_file, use_versioning, num_positions,
|
||||||
stockfish_depth, training_start_time, extra_metadata=None):
|
stockfish_depth, training_start_time, hidden_sizes=None,
|
||||||
|
extra_metadata=None):
|
||||||
"""Save the best model with optional versioning and metadata."""
|
"""Save the best model with optional versioning and metadata."""
|
||||||
final_output_file = output_file
|
final_output_file = output_file
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
|
||||||
|
|
||||||
if use_versioning:
|
if use_versioning:
|
||||||
base_name = output_file.replace(".pt", "")
|
base_name = output_file.replace(".pt", "")
|
||||||
@@ -344,6 +351,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
|||||||
"num_positions": num_positions,
|
"num_positions": num_positions,
|
||||||
"stockfish_depth": stockfish_depth,
|
"stockfish_depth": stockfish_depth,
|
||||||
"final_val_loss": float(best_val_loss),
|
"final_val_loss": float(best_val_loss),
|
||||||
|
"architecture": architecture,
|
||||||
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
}
|
}
|
||||||
@@ -357,6 +365,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
|||||||
"scaler_state_dict": scaler.state_dict(),
|
"scaler_state_dict": scaler.state_dict(),
|
||||||
"epoch": last_epoch,
|
"epoch": last_epoch,
|
||||||
"best_val_loss": best_val_loss,
|
"best_val_loss": best_val_loss,
|
||||||
|
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
|
||||||
}, final_output_file)
|
}, final_output_file)
|
||||||
print(f"Best model saved to {final_output_file}")
|
print(f"Best model saved to {final_output_file}")
|
||||||
|
|
||||||
@@ -367,7 +376,7 @@ def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_e
|
|||||||
for key, val in metadata.items():
|
for key, val in metadata.items():
|
||||||
print(f" {key}: {val}")
|
print(f" {key}: {val}")
|
||||||
|
|
||||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0):
|
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||||
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -382,20 +391,31 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
|||||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||||
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
||||||
|
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||||
"""
|
"""
|
||||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||||
_setup_training(data_file, batch_size, subsample_ratio)
|
_setup_training(data_file, batch_size, subsample_ratio)
|
||||||
|
|
||||||
model = NNUE().to(device)
|
start_epoch = 0
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||||
|
|
||||||
|
if checkpoint:
|
||||||
|
print(f"Loading checkpoint: {checkpoint}")
|
||||||
|
ckpt = torch.load(checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||||
|
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||||
|
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||||
|
resolved_hidden_sizes = ckpt_hidden
|
||||||
|
|
||||||
|
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||||
|
|
||||||
start_epoch = 0
|
|
||||||
best_val_loss = float('inf')
|
|
||||||
if checkpoint:
|
if checkpoint:
|
||||||
print(f"Loading checkpoint: {checkpoint}")
|
|
||||||
ckpt = torch.load(checkpoint, map_location=device)
|
ckpt = torch.load(checkpoint, map_location=device)
|
||||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
model.load_state_dict(ckpt["model_state_dict"])
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
@@ -412,6 +432,8 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
|||||||
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||||
|
|
||||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||||
|
print(f"Architecture: {arch_str}")
|
||||||
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||||
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||||
print(f"Mixed precision training: enabled")
|
print(f"Mixed precision training: enabled")
|
||||||
@@ -441,6 +463,7 @@ def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=
|
|||||||
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||||
best_val_loss, output_file, use_versioning, num_positions,
|
best_val_loss, output_file, use_versioning, num_positions,
|
||||||
stockfish_depth, training_start_time,
|
stockfish_depth, training_start_time,
|
||||||
|
hidden_sizes=resolved_hidden_sizes,
|
||||||
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
||||||
"checkpoint": str(checkpoint) if checkpoint else None}
|
"checkpoint": str(checkpoint) if checkpoint else None}
|
||||||
)
|
)
|
||||||
@@ -449,7 +472,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
|||||||
epochs_per_season=50, early_stopping_patience=10,
|
epochs_per_season=50, early_stopping_patience=10,
|
||||||
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
||||||
stockfish_depth=12, use_versioning=True,
|
stockfish_depth=12, use_versioning=True,
|
||||||
weight_decay=1e-4, subsample_ratio=1.0):
|
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||||
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
||||||
|
|
||||||
Each season trains with early stopping. When early stopping fires, the model reloads the
|
Each season trains with early stopping. When early stopping fires, the model reloads the
|
||||||
@@ -469,18 +492,29 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
|||||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||||
weight_decay: L2 regularization strength (default: 1e-4)
|
weight_decay: L2 regularization strength (default: 1e-4)
|
||||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
||||||
|
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||||
"""
|
"""
|
||||||
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
||||||
|
|
||||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||||
_setup_training(data_file, batch_size, subsample_ratio)
|
_setup_training(data_file, batch_size, subsample_ratio)
|
||||||
|
|
||||||
model = NNUE().to(device)
|
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||||
|
|
||||||
|
if initial_checkpoint:
|
||||||
|
print(f"Loading initial weights: {initial_checkpoint}")
|
||||||
|
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||||
|
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||||
|
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||||
|
resolved_hidden_sizes = ckpt_hidden
|
||||||
|
|
||||||
|
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
best_global_val_loss = float('inf')
|
best_global_val_loss = float('inf')
|
||||||
|
|
||||||
if initial_checkpoint:
|
if initial_checkpoint:
|
||||||
print(f"Loading initial weights: {initial_checkpoint}")
|
|
||||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
model.load_state_dict(ckpt["model_state_dict"])
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
@@ -493,6 +527,8 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
|||||||
model.load_state_dict(ckpt)
|
model.load_state_dict(ckpt)
|
||||||
print("Loaded weights-only checkpoint (no val loss info).")
|
print("Loaded weights-only checkpoint (no val loss info).")
|
||||||
|
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||||
|
print(f"Architecture: {arch_str}")
|
||||||
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
||||||
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
||||||
print()
|
print()
|
||||||
@@ -555,6 +591,7 @@ def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
|||||||
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
||||||
best_global_val_loss, output_file, use_versioning, num_positions,
|
best_global_val_loss, output_file, use_versioning, num_positions,
|
||||||
stockfish_depth, burst_start_time,
|
stockfish_depth, burst_start_time,
|
||||||
|
hidden_sizes=resolved_hidden_sizes,
|
||||||
extra_metadata={
|
extra_metadata={
|
||||||
"mode": "burst",
|
"mode": "burst",
|
||||||
"duration_minutes": duration_minutes,
|
"duration_minutes": duration_minutes,
|
||||||
@@ -593,6 +630,8 @@ if __name__ == "__main__":
|
|||||||
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
||||||
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
||||||
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
||||||
|
parser.add_argument("--hidden-layers", type=str, default=None,
|
||||||
|
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
|
||||||
|
|
||||||
# Burst mode
|
# Burst mode
|
||||||
parser.add_argument("--burst-duration", type=float, default=None,
|
parser.add_argument("--burst-duration", type=float, default=None,
|
||||||
@@ -602,6 +641,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
|
||||||
|
|
||||||
if args.burst_duration is not None:
|
if args.burst_duration is not None:
|
||||||
burst_train(
|
burst_train(
|
||||||
data_file=args.data_file,
|
data_file=args.data_file,
|
||||||
@@ -616,6 +657,7 @@ if __name__ == "__main__":
|
|||||||
use_versioning=not args.no_versioning,
|
use_versioning=not args.no_versioning,
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
subsample_ratio=args.subsample_ratio,
|
subsample_ratio=args.subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_nnue(
|
train_nnue(
|
||||||
@@ -630,4 +672,5 @@ if __name__ == "__main__":
|
|||||||
early_stopping_patience=args.early_stopping,
|
early_stopping_patience=args.early_stopping,
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
subsample_ratio=args.subsample_ratio,
|
subsample_ratio=args.subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user