feat: Add Lichess eval importer and metadata for NNUE architecture
This commit is contained in:
+194
-103
@@ -23,6 +23,7 @@ from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
extract_tactical_only
|
||||
)
|
||||
from lichess_importer import import_lichess_evals
|
||||
from dataset import (
|
||||
get_datasets_dir,
|
||||
list_datasets,
|
||||
@@ -189,9 +190,10 @@ def create_dataset_interactive():
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
@@ -246,6 +248,34 @@ def create_dataset_interactive():
|
||||
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
@@ -255,77 +285,94 @@ def create_dataset_interactive():
|
||||
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
# Determine whether any sources still need Stockfish labeling.
|
||||
# Lichess sources are already labeled; only generated/tactical/file sources need it.
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||
console.print(f" Total positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
|
||||
if not Confirm.ask("\nProceed to create dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
# --- Step 1: Collect already-labeled data (Lichess source) ---
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
# Save dataset
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# --- Step 3: Create dataset ---
|
||||
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||
version = next_dataset_version()
|
||||
create_dataset(
|
||||
version=version,
|
||||
labeled_jsonl_path=str(labeled_file),
|
||||
sources=sources,
|
||||
stockfish_depth=stockfish_depth
|
||||
stockfish_depth=stockfish_depth,
|
||||
)
|
||||
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||
@@ -368,9 +415,10 @@ def extend_dataset_interactive():
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
@@ -423,6 +471,34 @@ def extend_dataset_interactive():
|
||||
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
@@ -432,14 +508,17 @@ def extend_dataset_interactive():
|
||||
console.print("[yellow]Extension cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Extension Summary:[/bold]")
|
||||
@@ -447,53 +526,65 @@ def extend_dataset_interactive():
|
||||
console.print(f" New positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
|
||||
if not Confirm.ask("\nProceed to extend dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
# Copy pre-labeled Lichess data if present
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
# Combine and label remaining sources with Stockfish
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Extend dataset
|
||||
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||
@@ -502,8 +593,8 @@ def extend_dataset_interactive():
|
||||
new_labeled_path=str(labeled_file),
|
||||
new_source_entry={
|
||||
"type": "merged_sources",
|
||||
"count": len(all_fens),
|
||||
"sources": sources
|
||||
"count": combined_count,
|
||||
"sources": sources,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Import pre-labeled positions from the Lichess evaluation database.
|
||||
|
||||
Source: https://database.lichess.org/#evals
|
||||
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
|
||||
|
||||
Each line:
|
||||
{
|
||||
"fen": "<pieces> <turn> <castling> <ep>",
|
||||
"evals": [
|
||||
{
|
||||
"knodes": <int>,
|
||||
"depth": <int>,
|
||||
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
cp and mate are from White's perspective (positive = White winning), matching
|
||||
the sign convention used by label.py (score.white()) and expected by train.py.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
MATE_CP = 20000
|
||||
SCALE = 300.0
|
||||
|
||||
|
||||
def _best_eval(evals: list) -> dict | None:
|
||||
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
|
||||
if not evals:
|
||||
return None
|
||||
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
|
||||
|
||||
|
||||
def _cp_from_pv(pv: dict) -> int | None:
|
||||
"""Extract centipawn value from a principal variation entry."""
|
||||
if "cp" in pv:
|
||||
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
|
||||
if "mate" in pv:
|
||||
return MATE_CP if pv["mate"] > 0 else -MATE_CP
|
||||
return None
|
||||
|
||||
|
||||
def _normalize(cp: int) -> float:
|
||||
return float(np.tanh(cp / SCALE))
|
||||
|
||||
|
||||
def import_lichess_evals(
    input_path: str,
    output_file: str,
    max_positions: int | None = None,
    min_depth: int = 0,
    verbose: bool = False,
) -> int:
    """Stream the Lichess eval database and write a labeled.jsonl file.

    Args:
        input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
        output_file: Destination labeled.jsonl (appended — supports resuming).
        max_positions: Stop after this many new positions (None = no limit).
        min_depth: Skip positions whose best eval has depth < min_depth.
        verbose: Print warnings for skipped lines.

    Returns:
        Number of new positions written.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        # NOTE: terminates the whole process — a missing database is fatal
        # for every caller of this importer.
        print(f"Error: {input_path} not found")
        sys.exit(1)

    # Resume: collect already-written FENs so we skip duplicates.
    seen_fens: set[str] = set()
    if Path(output_file).exists():
        with open(output_file, "r") as f:
            for line in f:
                try:
                    seen_fens.add(json.loads(line)["fen"])
                except (json.JSONDecodeError, KeyError):
                    # Tolerate partial/corrupt lines from an interrupted run.
                    pass
    if seen_fens:
        print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")

    written = 0
    skipped_depth = 0
    skipped_no_eval = 0
    skipped_dup = 0

    def iter_lines():
        """Yield decoded text lines from either a .zst or plain .jsonl file."""
        import io
        if input_path.suffix == ".zst":
            # Imported lazily so plain-.jsonl inputs work without the
            # zstandard package installed.
            import zstandard as zstd
            dctx = zstd.ZstdDecompressor()
            with open(input_path, "rb") as fh:
                with dctx.stream_reader(fh) as reader:
                    text_stream = io.TextIOWrapper(reader, encoding="utf-8")
                    yield from text_stream
        else:
            with open(input_path, "r", encoding="utf-8") as fh:
                yield from fh

    # Append mode preserves prior output so interrupted imports can resume.
    with open(output_file, "a") as out:
        with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
            for raw_line in iter_lines():
                line = raw_line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    if verbose:
                        print("Warning: malformed JSON line skipped")
                    continue

                fen = data.get("fen", "")
                if not fen:
                    skipped_no_eval += 1
                    continue

                if fen in seen_fens:
                    skipped_dup += 1
                    continue

                best = _best_eval(data.get("evals", []))
                if best is None:
                    skipped_no_eval += 1
                    continue

                if best.get("depth", 0) < min_depth:
                    skipped_depth += 1
                    continue

                pvs = best.get("pvs", [])
                if not pvs:
                    skipped_no_eval += 1
                    continue

                # pvs[0] is the engine's principal (best) line.
                cp = _cp_from_pv(pvs[0])
                if cp is None:
                    skipped_no_eval += 1
                    continue

                record = {
                    "fen": fen,
                    "eval": _normalize(cp),
                    "eval_raw": cp,
                }
                out.write(json.dumps(record) + "\n")
                seen_fens.add(fen)
                written += 1
                pbar.update(1)

                if max_positions and written >= max_positions:
                    print(f"\nReached max_positions limit ({max_positions:,})")
                    break

    print()
    print("=" * 60)
    print("LICHESS IMPORT SUMMARY")
    print("=" * 60)
    print(f"Positions written: {written:,}")
    print(f"Skipped (dup): {skipped_dup:,}")
    print(f"Skipped (no eval): {skipped_no_eval:,}")
    print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
    print("=" * 60)
    print(f"\n✓ Output: {output_file}")

    return written
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # CLI wrapper: stream a Lichess eval dump into a labeled.jsonl file.
    cli = argparse.ArgumentParser(
        description="Import Lichess pre-labeled positions into labeled.jsonl"
    )
    cli.add_argument(
        "input_path",
        help="Path to lichess_db_eval.jsonl.zst",
    )
    cli.add_argument(
        "output_file",
        nargs="?",
        default="training_data.jsonl",
        help="Output labeled.jsonl (default: training_data.jsonl)",
    )
    cli.add_argument(
        "--max-positions",
        type=int,
        default=None,
        help="Stop after N positions (default: no limit)",
    )
    cli.add_argument(
        "--min-depth",
        type=int,
        default=0,
        help="Minimum eval depth to accept (default: 0)",
    )
    cli.add_argument(
        "--verbose",
        action="store_true",
        help="Print warnings for skipped lines",
    )

    opts = cli.parse_args()
    imported = import_lichess_evals(
        input_path=opts.input_path,
        output_file=opts.output_file,
        max_positions=opts.max_positions,
        min_depth=opts.min_depth,
        verbose=opts.verbose,
    )
    # Exit status 0 only when at least one position was imported.
    sys.exit(0 if imported > 0 else 1)
|
||||
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"version": 10,
|
||||
"date": "2026-04-14T22:18:38.824577",
|
||||
"num_positions": 3022562,
|
||||
"stockfish_depth": 12,
|
||||
"final_val_loss": 6.248612448225196e-05,
|
||||
"architecture": [
|
||||
768,
|
||||
1536,
|
||||
1024,
|
||||
512,
|
||||
256,
|
||||
1
|
||||
],
|
||||
"device": "cuda",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"checkpoint": null
|
||||
}
|
||||
Reference in New Issue
Block a user