feat: Add Lichess eval importer and metadata for NNUE architecture

This commit is contained in:
2026-04-14 22:43:28 +02:00
parent 8db2c8ca7f
commit 95537bc709
4 changed files with 423 additions and 103 deletions
+194 -103
View File
@@ -23,6 +23,7 @@ from tactical_positions_extractor import (
download_and_extract_puzzle_db,
extract_tactical_only
)
from lichess_importer import import_lichess_evals
from dataset import (
get_datasets_dir,
list_datasets,
@@ -189,9 +190,10 @@ def create_dataset_interactive():
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -246,6 +248,34 @@ def create_dataset_interactive():
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
@@ -255,77 +285,94 @@ def create_dataset_interactive():
console.print("[yellow]Dataset creation cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Determine whether any sources still need Stockfish labeling.
# Lichess sources are already labeled; only generated/tactical/file sources need it.
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Dataset Summary:[/bold]")
console.print(f" Total positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
console.print(f" Stockfish depth: {stockfish_depth}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
if not Confirm.ask("\nProceed to create dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
labeled_file.unlink(missing_ok=True)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
# --- Step 1: Collect already-labeled data (Lichess source) ---
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
console.print("[green]✓ Positions labeled[/green]")
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
# Save dataset
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# --- Step 3: Create dataset ---
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
version = next_dataset_version()
create_dataset(
version=version,
labeled_jsonl_path=str(labeled_file),
sources=sources,
stockfish_depth=stockfish_depth
stockfish_depth=stockfish_depth,
)
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
@@ -368,9 +415,10 @@ def extend_dataset_interactive():
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -423,6 +471,34 @@ def extend_dataset_interactive():
console.print(f"[red]✗ Extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
@@ -432,14 +508,17 @@ def extend_dataset_interactive():
console.print("[yellow]Extension cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Extension Summary:[/bold]")
@@ -447,53 +526,65 @@ def extend_dataset_interactive():
console.print(f" New positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
if not Confirm.ask("\nProceed to extend dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
labeled_file.unlink(missing_ok=True)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
# Copy pre-labeled Lichess data if present
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
console.print("[green]✓ Positions labeled[/green]")
# Combine and label remaining sources with Stockfish
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# Extend dataset
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
@@ -502,8 +593,8 @@ def extend_dataset_interactive():
new_labeled_path=str(labeled_file),
new_source_entry={
"type": "merged_sources",
"count": len(all_fens),
"sources": sources
"count": combined_count,
"sources": sources,
}
)
+208
View File
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""Import pre-labeled positions from the Lichess evaluation database.
Source: https://database.lichess.org/#evals
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
Each line:
{
"fen": "<pieces> <turn> <castling> <ep>",
"evals": [
{
"knodes": <int>,
"depth": <int>,
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
}
]
}
cp and mate are from White's perspective (positive = White winning), matching
the sign convention used by label.py (score.white()) and expected by train.py.
"""
import json
import sys
import numpy as np
from pathlib import Path
from tqdm import tqdm
# Centipawn magnitude used to represent forced-mate scores (and to clamp
# extreme engine outputs) — keeps mates strictly larger than any normal eval.
MATE_CP = 20000
# Divisor for tanh normalization: eval = tanh(cp / SCALE) maps the usual
# centipawn range smoothly into (-1, 1) for training targets.
SCALE = 300.0
def _best_eval(evals: list) -> dict | None:
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
if not evals:
return None
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
def _cp_from_pv(pv: dict) -> int | None:
    """Centipawn score for a PV entry, clamped to +/-MATE_CP.

    Mate scores collapse to +/-MATE_CP (sign follows the mating side).
    Returns None when the entry carries neither "cp" nor "mate".
    """
    if "cp" in pv:
        value = pv["cp"]
        # Clamp extreme engine outputs into the mate band.
        if value > MATE_CP:
            return MATE_CP
        if value < -MATE_CP:
            return -MATE_CP
        return value
    if "mate" in pv:
        return MATE_CP if pv["mate"] > 0 else -MATE_CP
    return None
def _normalize(cp: int) -> float:
    """Squash a centipawn score into (-1, 1) via tanh(cp / SCALE)."""
    scaled = cp / SCALE
    return float(np.tanh(scaled))
def import_lichess_evals(
    input_path: str,
    output_file: str,
    max_positions: int | None = None,
    min_depth: int = 0,
    verbose: bool = False,
) -> int:
    """Stream the Lichess eval database and append to a labeled.jsonl file.

    Args:
        input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
        output_file: Destination labeled.jsonl (appended — supports resuming).
        max_positions: Stop after this many new positions (None = no limit).
        min_depth: Skip positions whose best eval has depth < min_depth.
        verbose: Print warnings for skipped lines.

    Returns:
        Number of new positions written.

    Raises:
        FileNotFoundError: If *input_path* does not exist.
    """
    import io

    src = Path(input_path)
    if not src.exists():
        # Raise instead of sys.exit(1): this is a library function, and the
        # interactive callers wrap it in `except Exception` — SystemExit
        # would bypass them and kill the whole session.
        raise FileNotFoundError(f"{src} not found")

    # Resume support: collect FENs already present in the output so reruns
    # skip duplicates instead of writing them twice.
    seen_fens: set[str] = set()
    if Path(output_file).exists():
        with open(output_file, "r") as f:
            for line in f:
                try:
                    seen_fens.add(json.loads(line)["fen"])
                except (json.JSONDecodeError, KeyError):
                    pass  # tolerate partial lines left by an aborted run
        if seen_fens:
            print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")

    written = 0
    skipped_depth = 0
    skipped_no_eval = 0
    skipped_dup = 0

    def iter_lines():
        """Yield decoded text lines from either a .zst or plain .jsonl file."""
        if src.suffix == ".zst":
            # Imported lazily so plain-.jsonl imports work without the
            # zstandard package installed.
            import zstandard as zstd
            dctx = zstd.ZstdDecompressor()
            with open(src, "rb") as fh:
                with dctx.stream_reader(fh) as reader:
                    yield from io.TextIOWrapper(reader, encoding="utf-8")
        else:
            with open(src, "r", encoding="utf-8") as fh:
                yield from fh

    with open(output_file, "a") as out:
        with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
            for raw_line in iter_lines():
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    if verbose:
                        print("Warning: malformed JSON line skipped")
                    continue
                fen = data.get("fen", "")
                if not fen:
                    skipped_no_eval += 1
                    continue
                if fen in seen_fens:
                    skipped_dup += 1
                    continue
                best = _best_eval(data.get("evals", []))
                if best is None:
                    skipped_no_eval += 1
                    continue
                if best.get("depth", 0) < min_depth:
                    skipped_depth += 1
                    continue
                pvs = best.get("pvs", [])
                if not pvs:
                    skipped_no_eval += 1
                    continue
                cp = _cp_from_pv(pvs[0])
                if cp is None:
                    skipped_no_eval += 1
                    continue
                record = {
                    "fen": fen,
                    "eval": _normalize(cp),  # tanh-normalized, White-positive
                    "eval_raw": cp,          # raw centipawns (±MATE_CP for mates)
                }
                out.write(json.dumps(record) + "\n")
                seen_fens.add(fen)
                written += 1
                pbar.update(1)
                if max_positions and written >= max_positions:
                    print(f"\nReached max_positions limit ({max_positions:,})")
                    break

    print()
    print("=" * 60)
    print("LICHESS IMPORT SUMMARY")
    print("=" * 60)
    print(f"Positions written: {written:,}")
    print(f"Skipped (dup): {skipped_dup:,}")
    print(f"Skipped (no eval): {skipped_no_eval:,}")
    print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
    print("=" * 60)
    print(f"\n✓ Output: {output_file}")
    return written
if __name__ == "__main__":
    import argparse

    # Thin CLI wrapper around import_lichess_evals(); exits 0 only when at
    # least one position was written.
    ap = argparse.ArgumentParser(
        description="Import Lichess pre-labeled positions into labeled.jsonl"
    )
    ap.add_argument("input_path",
                    help="Path to lichess_db_eval.jsonl.zst")
    ap.add_argument("output_file", nargs="?", default="training_data.jsonl",
                    help="Output labeled.jsonl (default: training_data.jsonl)")
    ap.add_argument("--max-positions", type=int, default=None,
                    help="Stop after N positions (default: no limit)")
    ap.add_argument("--min-depth", type=int, default=0,
                    help="Minimum eval depth to accept (default: 0)")
    ap.add_argument("--verbose", action="store_true",
                    help="Print warnings for skipped lines")
    opts = ap.parse_args()

    imported = import_lichess_evals(
        input_path=opts.input_path,
        output_file=opts.output_file,
        max_positions=opts.max_positions,
        min_depth=opts.min_depth,
        verbose=opts.verbose,
    )
    sys.exit(0 if imported > 0 else 1)
Binary file not shown.
@@ -0,0 +1,21 @@
{
"version": 10,
"date": "2026-04-14T22:18:38.824577",
"num_positions": 3022562,
"stockfish_depth": 12,
"final_val_loss": 6.248612448225196e-05,
"architecture": [
768,
1536,
1024,
512,
256,
1
],
"device": "cuda",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"checkpoint": null
}