feat: Add Lichess eval importer and metadata for NNUE architecture

This commit is contained in: (branch/tag list not captured in this extract)
Commit date: 2026-04-14 22:43:28 +02:00
parent 8db2c8ca7f
commit 95537bc709
4 changed files with 423 additions and 103 deletions
+194 -103
View File
@@ -23,6 +23,7 @@ from tactical_positions_extractor import (
download_and_extract_puzzle_db,
extract_tactical_only
)
from lichess_importer import import_lichess_evals
from dataset import (
get_datasets_dir,
list_datasets,
@@ -189,9 +190,10 @@ def create_dataset_interactive():
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -246,6 +248,34 @@ def create_dataset_interactive():
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
@@ -255,77 +285,94 @@ def create_dataset_interactive():
console.print("[yellow]Dataset creation cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Determine whether any sources still need Stockfish labeling.
# Lichess sources are already labeled; only generated/tactical/file sources need it.
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Dataset Summary:[/bold]")
console.print(f" Total positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
console.print(f" Stockfish depth: {stockfish_depth}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
if not Confirm.ask("\nProceed to create dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
labeled_file.unlink(missing_ok=True)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
# --- Step 1: Collect already-labeled data (Lichess source) ---
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
console.print("[green]✓ Positions labeled[/green]")
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
# Save dataset
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# --- Step 3: Create dataset ---
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
version = next_dataset_version()
create_dataset(
version=version,
labeled_jsonl_path=str(labeled_file),
sources=sources,
stockfish_depth=stockfish_depth
stockfish_depth=stockfish_depth,
)
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
@@ -368,9 +415,10 @@ def extend_dataset_interactive():
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Done adding sources")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -423,6 +471,34 @@ def extend_dataset_interactive():
console.print(f"[red]✗ Extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
@@ -432,14 +508,17 @@ def extend_dataset_interactive():
console.print("[yellow]Extension cancelled[/yellow]")
return
# Stockfish labeling parameters
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Extension Summary:[/bold]")
@@ -447,53 +526,65 @@ def extend_dataset_interactive():
console.print(f" New positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
if not Confirm.ask("\nProceed to extend dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
# Combine all sources into one FEN file
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in sources:
if source['type'] == 'generated':
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source['type'] == 'file_import':
temp_file = Path(source['path'])
elif source['type'] == 'tactical':
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
if temp_file.exists():
with open(temp_file, 'r') as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, 'w') as f:
for fen in all_fens:
f.write(fen + '\n')
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
# Label positions
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers
)
labeled_file.unlink(missing_ok=True)
if not success:
console.print("[red]✗ Labeling failed[/red]")
return
# Copy pre-labeled Lichess data if present
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
console.print("[green]✓ Positions labeled[/green]")
# Combine and label remaining sources with Stockfish
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# Extend dataset
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
@@ -502,8 +593,8 @@ def extend_dataset_interactive():
new_labeled_path=str(labeled_file),
new_source_entry={
"type": "merged_sources",
"count": len(all_fens),
"sources": sources
"count": combined_count,
"sources": sources,
}
)