feat: Add Lichess eval importer and metadata for NNUE architecture
This commit is contained in:
+194
-103
@@ -23,6 +23,7 @@ from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
extract_tactical_only
|
||||
)
|
||||
from lichess_importer import import_lichess_evals
|
||||
from dataset import (
|
||||
get_datasets_dir,
|
||||
list_datasets,
|
||||
@@ -189,9 +190,10 @@ def create_dataset_interactive():
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
@@ -246,6 +248,34 @@ def create_dataset_interactive():
|
||||
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
@@ -255,77 +285,94 @@ def create_dataset_interactive():
|
||||
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
# Determine whether any sources still need Stockfish labeling.
|
||||
# Lichess sources are already labeled; only generated/tactical/file sources need it.
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||
console.print(f" Total positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and create dataset?", default=True):
|
||||
if not Confirm.ask("\nProceed to create dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
# --- Step 1: Collect already-labeled data (Lichess source) ---
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
# Save dataset
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# --- Step 3: Create dataset ---
|
||||
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||
version = next_dataset_version()
|
||||
create_dataset(
|
||||
version=version,
|
||||
labeled_jsonl_path=str(labeled_file),
|
||||
sources=sources,
|
||||
stockfish_depth=stockfish_depth
|
||||
stockfish_depth=stockfish_depth,
|
||||
)
|
||||
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||
@@ -368,9 +415,10 @@ def extend_dataset_interactive():
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Done adding sources")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
@@ -423,6 +471,34 @@ def extend_dataset_interactive():
|
||||
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
@@ -432,14 +508,17 @@ def extend_dataset_interactive():
|
||||
console.print("[yellow]Extension cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Stockfish labeling parameters
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Extension Summary:[/bold]")
|
||||
@@ -447,53 +526,65 @@ def extend_dataset_interactive():
|
||||
console.print(f" New positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
|
||||
if not Confirm.ask("\nProceed to extend dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
# Combine all sources into one FEN file
|
||||
console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in sources:
|
||||
if source['type'] == 'generated':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source['type'] == 'file_import':
|
||||
temp_file = Path(source['path'])
|
||||
elif source['type'] == 'tactical':
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, 'r') as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, 'w') as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + '\n')
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
|
||||
|
||||
# Label positions
|
||||
console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers
|
||||
)
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
if not success:
|
||||
console.print("[red]✗ Labeling failed[/red]")
|
||||
return
|
||||
# Copy pre-labeled Lichess data if present
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
# Combine and label remaining sources with Stockfish
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Extend dataset
|
||||
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||
@@ -502,8 +593,8 @@ def extend_dataset_interactive():
|
||||
new_labeled_path=str(labeled_file),
|
||||
new_source_entry={
|
||||
"type": "merged_sources",
|
||||
"count": len(all_fens),
|
||||
"sources": sources
|
||||
"count": combined_count,
|
||||
"sources": sources,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user