diff --git a/modules/bot/python/nnue.py b/modules/bot/python/nnue.py
index 58bf155..305e94b 100644
--- a/modules/bot/python/nnue.py
+++ b/modules/bot/python/nnue.py
@@ -23,6 +23,7 @@ from tactical_positions_extractor import (
     download_and_extract_puzzle_db,
     extract_tactical_only
 )
+from lichess_importer import import_lichess_evals
 from dataset import (
     get_datasets_dir,
     list_datasets,
@@ -189,9 +190,10 @@ def create_dataset_interactive():
         console.print("[cyan]a[/cyan] - Generate random positions")
         console.print("[cyan]b[/cyan] - Import from file")
         console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
-        console.print("[cyan]d[/cyan] - Done adding sources")
+        console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
+        console.print("[cyan]e[/cyan] - Done adding sources")

-        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
+        choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])

         if choice == "a":
             num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -246,6 +248,34 @@ def create_dataset_interactive():
                 console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")

         elif choice == "d":
+            zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
+            max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
+            max_pos = int(max_pos) if max_pos.strip() else None
+            min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
+            console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
+            temp_file.unlink(missing_ok=True)
+            try:
+                count = import_lichess_evals(
+                    input_path=zst_path,
+                    output_file=str(temp_file),
+                    max_positions=max_pos,
+                    min_depth=min_depth,
+                )
+                if count > 0:
+                    sources.append({
+                        "type": "lichess",
+                        "count": count,
+                        "params": {"min_depth": min_depth, "max_positions": max_pos},
+                    })
+                    combined_count += count
+                    console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
+                else:
+                    console.print("[red]✗ No positions imported[/red]")
+            except Exception as e:
+                console.print(f"[red]✗ Lichess import failed: {e}[/red]")
+
+        elif choice == "e":
             if not sources:
                 console.print("[yellow]⚠ No sources added yet[/yellow]")
                 continue
@@ -255,77 +285,94 @@ def create_dataset_interactive():
         console.print("[yellow]Dataset creation cancelled[/yellow]")
         return

-    # Stockfish labeling parameters
-    console.print("\n[bold cyan]🏷️  Labeling Parameters[/bold cyan]")
-    stockfish_path = Prompt.ask(
-        "Stockfish path",
-        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
-    )
-    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
-    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
+    # Determine whether any sources still need Stockfish labeling.
+    # Lichess sources are already labeled; only generated/tactical/file sources need it.
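+    # e.g. a run that only imported Lichess evals leaves needs_labeling False,
+    # so the Stockfish prompts below are skipped entirely.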
+    needs_labeling = any(s["type"] != "lichess" for s in sources)
+
+    stockfish_depth = 12
+    if needs_labeling:
+        console.print("\n[bold cyan]🏷️  Labeling Parameters[/bold cyan]")
+        stockfish_path = Prompt.ask(
+            "Stockfish path",
+            default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
+        )
+        stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+        num_workers = int(Prompt.ask("Number of parallel workers", default="1"))

     # Summary and confirm
     console.print("\n[bold]Dataset Summary:[/bold]")
     console.print(f"  Total positions: {combined_count:,}")
     for source in sources:
         console.print(f"  - {source['type']}: {source['count']:,}")
-    console.print(f"  Stockfish depth: {stockfish_depth}")
+    if needs_labeling:
+        console.print(f"  Stockfish depth: {stockfish_depth}")

-    if not Confirm.ask("\nProceed to label and create dataset?", default=True):
+    if not Confirm.ask("\nProceed to create dataset?", default=True):
         console.print("[yellow]Cancelled[/yellow]")
         return

     try:
-        # Combine all sources into one FEN file
-        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
-        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
-        all_fens = set()
-
-        for source in sources:
-            if source['type'] == 'generated':
-                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
-            elif source['type'] == 'file_import':
-                temp_file = Path(source['path'])
-            elif source['type'] == 'tactical':
-                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
-
-            if temp_file.exists():
-                with open(temp_file, 'r') as f:
-                    for line in f:
-                        fen = line.strip()
-                        if fen:
-                            all_fens.add(fen)
-
-        with open(combined_fen_file, 'w') as f:
-            for fen in all_fens:
-                f.write(fen + '\n')
-        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
-
-        # Label positions
-        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
         labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
-        success = label_positions_with_stockfish(
-            str(combined_fen_file),
-            str(labeled_file),
-            stockfish_path,
-            depth=stockfish_depth,
-            num_workers=num_workers
-        )
+        labeled_file.unlink(missing_ok=True)

-        if not success:
-            console.print("[red]✗ Labeling failed[/red]")
-            return
+        # --- Step 1: Collect already-labeled data (Lichess source) ---
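+        # temp_lichess.jsonl was written by menu choice "d" above and is already
+        # in the labeled.jsonl record format, so these positions skip Stockfish.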
+        lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
+        if lichess_tmp.exists():
+            shutil.copy(lichess_tmp, labeled_file)
+            console.print("\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
+            console.print("[green]✓ Lichess positions ready[/green]")

-        console.print("[green]✓ Positions labeled[/green]")
+        # --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
+        non_lichess = [s for s in sources if s["type"] != "lichess"]
+        if non_lichess:
+            console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
+            combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
+            all_fens = set()

-        # Save dataset
+            for source in non_lichess:
+                if source["type"] == "generated":
+                    temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+                elif source["type"] == "file_import":
+                    temp_file = Path(source["path"])
+                elif source["type"] == "tactical":
+                    temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+                else:
+                    continue
+
+                if temp_file.exists():
+                    with open(temp_file, "r") as f:
+                        for line in f:
+                            fen = line.strip()
+                            if fen:
+                                all_fens.add(fen)
+
+            with open(combined_fen_file, "w") as f:
+                for fen in all_fens:
+                    f.write(fen + "\n")
+            console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
+
+            console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
+            success = label_positions_with_stockfish(
+                str(combined_fen_file),
+                str(labeled_file),
+                stockfish_path,
+                depth=stockfish_depth,
+                num_workers=num_workers,
+            )
+            if not success:
+                console.print("[red]✗ Stockfish labeling failed[/red]")
+                return
+            console.print("[green]✓ Positions labeled[/green]")

+        # --- Step 3: Create dataset ---
         console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
         version = next_dataset_version()
         create_dataset(
             version=version,
             labeled_jsonl_path=str(labeled_file),
             sources=sources,
-            stockfish_depth=stockfish_depth
+            stockfish_depth=stockfish_depth,
         )
         console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
         console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
@@ -368,9 +415,10 @@ def extend_dataset_interactive():
         console.print("[cyan]a[/cyan] - Generate random positions")
         console.print("[cyan]b[/cyan] - Import from file")
         console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
-        console.print("[cyan]d[/cyan] - Done adding sources")
+        console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
+        console.print("[cyan]e[/cyan] - Done adding sources")

-        choice = Prompt.ask("Select", choices=["a", "b", "c", "d"])
+        choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])

         if choice == "a":
             num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
@@ -423,6 +471,34 @@ def extend_dataset_interactive():
                 console.print(f"[red]✗ Extraction failed: {e}[/red]")

         elif choice == "d":
+            zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
+            max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
+            max_pos = int(max_pos) if max_pos.strip() else None
+            min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
+            console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
+            temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
+            temp_file.unlink(missing_ok=True)
+            try:
+                count = import_lichess_evals(
+                    input_path=zst_path,
+                    output_file=str(temp_file),
+                    max_positions=max_pos,
+                    min_depth=min_depth,
+                )
+                if count > 0:
+                    sources.append({
+                        "type": "lichess",
+                        "count": count,
+                        "params": {"min_depth": min_depth, "max_positions": max_pos},
+                    })
+                    combined_count += count
+                    console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
+                else:
+                    console.print("[red]✗ No positions imported[/red]")
+            except Exception as e:
+                console.print(f"[red]✗ Lichess import failed: {e}[/red]")
+
+        elif choice == "e":
             if not sources:
                 console.print("[yellow]⚠ No sources added yet[/yellow]")
                 continue
@@ -432,14 +508,17 @@ def extend_dataset_interactive():
         console.print("[yellow]Extension cancelled[/yellow]")
         return

-    # Stockfish labeling parameters
-    console.print("\n[bold cyan]🏷️  Labeling Parameters[/bold cyan]")
-    stockfish_path = Prompt.ask(
-        "Stockfish path",
-        default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
-    )
-    stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
-    num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
+    needs_labeling = any(s["type"] != "lichess" for s in sources)
+
+    stockfish_depth = 12
+    if needs_labeling:
+        console.print("\n[bold cyan]🏷️  Labeling Parameters[/bold cyan]")
+        stockfish_path = Prompt.ask(
+            "Stockfish path",
+            default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
+        )
+        stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
+        num_workers = int(Prompt.ask("Number of parallel workers", default="1"))

     # Summary and confirm
     console.print("\n[bold]Extension Summary:[/bold]")
@@ -447,53 +526,65 @@ def extend_dataset_interactive():
     console.print(f"  New positions: {combined_count:,}")
     for source in sources:
         console.print(f"  - {source['type']}: {source['count']:,}")
+    if needs_labeling:
+        console.print(f"  Stockfish depth: {stockfish_depth}")

-    if not Confirm.ask("\nProceed to label and extend dataset?", default=True):
+    if not Confirm.ask("\nProceed to extend dataset?", default=True):
         console.print("[yellow]Cancelled[/yellow]")
         return

     try:
-        # Combine all sources into one FEN file
-        console.print("\n[bold cyan]Step 1: Combining sources[/bold cyan]")
-        combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
-        all_fens = set()
-
-        for source in sources:
-            if source['type'] == 'generated':
-                temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
-            elif source['type'] == 'file_import':
-                temp_file = Path(source['path'])
-            elif source['type'] == 'tactical':
-                temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
-
-            if temp_file.exists():
-                with open(temp_file, 'r') as f:
-                    for line in f:
-                        fen = line.strip()
-                        if fen:
-                            all_fens.add(fen)
-
-        with open(combined_fen_file, 'w') as f:
-            for fen in all_fens:
-                f.write(fen + '\n')
-        console.print(f"[green]✓ Combined {len(all_fens):,} unique positions[/green]")
-
-        # Label positions
-        console.print("\n[bold cyan]Step 2: Labeling Positions[/bold cyan]")
         labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
-        success = label_positions_with_stockfish(
-            str(combined_fen_file),
-            str(labeled_file),
-            stockfish_path,
-            depth=stockfish_depth,
-            num_workers=num_workers
-        )
+        labeled_file.unlink(missing_ok=True)

-        if not success:
-            console.print("[red]✗ Labeling failed[/red]")
-            return
+        # Copy pre-labeled Lichess data if present
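+        # (same contract as in create_dataset_interactive(): menu choice "d"
+        # leaves an already-labeled temp_lichess.jsonl in the temp directory)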
+        lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
+        if lichess_tmp.exists():
+            shutil.copy(lichess_tmp, labeled_file)
+            console.print("\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
+            console.print("[green]✓ Lichess positions ready[/green]")

-        console.print("[green]✓ Positions labeled[/green]")
+        # Combine and label remaining sources with Stockfish
+        non_lichess = [s for s in sources if s["type"] != "lichess"]
+        if non_lichess:
+            console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
+            combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
+            all_fens = set()
+
+            for source in non_lichess:
+                if source["type"] == "generated":
+                    temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
+                elif source["type"] == "file_import":
+                    temp_file = Path(source["path"])
+                elif source["type"] == "tactical":
+                    temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
+                else:
+                    continue
+                if temp_file.exists():
+                    with open(temp_file, "r") as f:
+                        for line in f:
+                            fen = line.strip()
+                            if fen:
+                                all_fens.add(fen)
+
+            with open(combined_fen_file, "w") as f:
+                for fen in all_fens:
+                    f.write(fen + "\n")
+            console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
+
+            console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
+            success = label_positions_with_stockfish(
+                str(combined_fen_file),
+                str(labeled_file),
+                stockfish_path,
+                depth=stockfish_depth,
+                num_workers=num_workers,
+            )
+            if not success:
+                console.print("[red]✗ Stockfish labeling failed[/red]")
+                return
+            console.print("[green]✓ Positions labeled[/green]")

         # Extend dataset
         console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
@@ -502,8 +593,8 @@ def extend_dataset_interactive():
             new_labeled_path=str(labeled_file),
             new_source_entry={
                 "type": "merged_sources",
-                "count": len(all_fens),
-                "sources": sources
+                "count": combined_count,
+                "sources": sources,
             }
         )
diff --git a/modules/bot/python/src/lichess_importer.py b/modules/bot/python/src/lichess_importer.py
new file mode 100644
index 0000000..9fb2c56
--- /dev/null
+++ b/modules/bot/python/src/lichess_importer.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+"""Import pre-labeled positions from the Lichess evaluation database.
+
+Source: https://database.lichess.org/#evals
+Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
+
+Each line:
+    {
+      "fen": "<FEN>",
+      "evals": [
+        {
+          "knodes": <int>,
+          "depth": <int>,
+          "pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
+        }
+      ]
+    }
+
+cp and mate are from White's perspective (positive = White winning), matching
+the sign convention used by label.py (score.white()) and expected by train.py.
+"""
+
+import json
+import sys
+import numpy as np
+from pathlib import Path
+from tqdm import tqdm
+
+MATE_CP = 20000
+SCALE = 300.0
+
+
+def _best_eval(evals: list) -> dict | None:
+    """Return the highest-depth evaluation entry, using knodes as tiebreaker."""
+    if not evals:
+        return None
+    return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
+
+
+def _cp_from_pv(pv: dict) -> int | None:
+    """Extract centipawn value from a principal variation entry."""
+    if "cp" in pv:
+        return max(-MATE_CP, min(MATE_CP, pv["cp"]))
+    if "mate" in pv:
+        return MATE_CP if pv["mate"] > 0 else -MATE_CP
+    return None
+
+
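+# _normalize squashes centipawns into [-1, 1]: tanh(100/300) ≈ +0.32,
+# tanh(-300/300) ≈ -0.76, and clamped mate scores (±MATE_CP) saturate to ≈ ±1.0.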
+def _normalize(cp: int) -> float:
+    return float(np.tanh(cp / SCALE))
+
+
+def import_lichess_evals(
+    input_path: str,
+    output_file: str,
+    max_positions: int | None = None,
+    min_depth: int = 0,
+    verbose: bool = False,
+) -> int:
+    """Stream the Lichess eval database and write a labeled.jsonl file.
+
+    Args:
+        input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
+        output_file: Destination labeled.jsonl (appended — supports resuming).
+        max_positions: Stop after this many new positions (None = no limit).
+        min_depth: Skip positions whose best eval has depth < min_depth.
+        verbose: Print warnings for skipped lines.
+
+    Returns:
+        Number of new positions written.
+    """
+    import zstandard as zstd
+
+    input_path = Path(input_path)
+    if not input_path.exists():
+        print(f"Error: {input_path} not found")
+        sys.exit(1)
+
+    # Resume: collect already-written FENs so we skip duplicates.
+    seen_fens: set[str] = set()
+    if Path(output_file).exists():
+        with open(output_file, "r") as f:
+            for line in f:
+                try:
+                    seen_fens.add(json.loads(line)["fen"])
+                except (json.JSONDecodeError, KeyError):
+                    pass
+        if seen_fens:
+            print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
+
+    written = 0
+    skipped_depth = 0
+    skipped_no_eval = 0
+    skipped_dup = 0
+
+    def iter_lines():
+        """Yield decoded text lines from either a .zst or plain .jsonl file."""
+        import io
+        if input_path.suffix == ".zst":
+            dctx = zstd.ZstdDecompressor()
+            with open(input_path, "rb") as fh:
+                with dctx.stream_reader(fh) as reader:
+                    text_stream = io.TextIOWrapper(reader, encoding="utf-8")
+                    yield from text_stream
+        else:
+            with open(input_path, "r", encoding="utf-8") as fh:
+                yield from fh
+
+    with open(output_file, "a") as out:
+        with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
+            for raw_line in iter_lines():
+                line = raw_line.strip()
+                if not line:
+                    continue
+
+                try:
+                    data = json.loads(line)
+                except json.JSONDecodeError:
+                    if verbose:
+                        print("Warning: malformed JSON line skipped")
+                    continue
+
+                fen = data.get("fen", "")
+                if not fen:
+                    skipped_no_eval += 1
+                    continue
+
+                if fen in seen_fens:
+                    skipped_dup += 1
+                    continue
+
+                best = _best_eval(data.get("evals", []))
+                if best is None:
+                    skipped_no_eval += 1
+                    continue
+
+                if best.get("depth", 0) < min_depth:
+                    skipped_depth += 1
+                    continue
+
+                pvs = best.get("pvs", [])
+                if not pvs:
+                    skipped_no_eval += 1
+                    continue
+
+                cp = _cp_from_pv(pvs[0])
+                if cp is None:
+                    skipped_no_eval += 1
+                    continue
+
+                record = {
+                    "fen": fen,
+                    "eval": _normalize(cp),
+                    "eval_raw": cp,
+                }
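+                # e.g. {"fen": ..., "evals": [{"depth": 36, "pvs": [{"cp": 625, ...}]}]}
+                # becomes {"fen": ..., "eval": 0.9695, "eval_raw": 625}
+                # (illustrative values; tanh(625/300) ≈ 0.9695)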
+                out.write(json.dumps(record) + "\n")
+                seen_fens.add(fen)
+                written += 1
+                pbar.update(1)
+
+                if max_positions and written >= max_positions:
+                    print(f"\nReached max_positions limit ({max_positions:,})")
+                    break
+
+    print()
+    print("=" * 60)
+    print("LICHESS IMPORT SUMMARY")
+    print("=" * 60)
+    print(f"Positions written:   {written:,}")
+    print(f"Skipped (dup):       {skipped_dup:,}")
+    print(f"Skipped (no eval):   {skipped_no_eval:,}")
+    print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
+    print("=" * 60)
+    print(f"\n✓ Output: {output_file}")
+
+    return written
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Import Lichess pre-labeled positions into labeled.jsonl"
+    )
+    parser.add_argument("input_path",
+                        help="Path to lichess_db_eval.jsonl.zst")
+    parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
+                        help="Output labeled.jsonl (default: training_data.jsonl)")
+    parser.add_argument("--max-positions", type=int, default=None,
+                        help="Stop after N positions (default: no limit)")
+    parser.add_argument("--min-depth", type=int, default=0,
+                        help="Minimum eval depth to accept (default: 0)")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Print warnings for skipped lines")
+
+    args = parser.parse_args()
+    count = import_lichess_evals(
+        input_path=args.input_path,
+        output_file=args.output_file,
+        max_positions=args.max_positions,
+        min_depth=args.min_depth,
+        verbose=args.verbose,
+    )
+    sys.exit(0 if count > 0 else 1)
diff --git a/modules/bot/python/weights/nnue_weights_v10.pt b/modules/bot/python/weights/nnue_weights_v10.pt
new file mode 100644
index 0000000..03f2e64
Binary files /dev/null and b/modules/bot/python/weights/nnue_weights_v10.pt differ
diff --git a/modules/bot/python/weights/nnue_weights_v10_metadata.json b/modules/bot/python/weights/nnue_weights_v10_metadata.json
new file mode 100644
index 0000000..7310257
--- /dev/null
+++ b/modules/bot/python/weights/nnue_weights_v10_metadata.json
@@ -0,0 +1,21 @@
+{
+  "version": 10,
+  "date": "2026-04-14T22:18:38.824577",
+  "num_positions": 3022562,
+  "stockfish_depth": 12,
+  "final_val_loss": 6.248612448225196e-05,
+  "architecture": [
+    768,
+    1536,
+    1024,
+    512,
+    256,
+    1
+  ],
+  "device": "cuda",
+  "notes": "Win rate vs classical eval: TBD (requires benchmark games)",
+  "epochs": 100,
+  "batch_size": 16384,
+  "learning_rate": 0.001,
+  "checkpoint": null
+}
\ No newline at end of file
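
For reference, a minimal usage sketch of the new importer; the output path and the one-million cap below are illustrative, not part of the change:

    # Programmatic use, mirroring what the interactive menus do for choice "d":
    from lichess_importer import import_lichess_evals

    count = import_lichess_evals(
        input_path="lichess_db_eval.jsonl.zst",  # from https://database.lichess.org/#evals
        output_file="labeled.jsonl",             # appended to; rerunning resumes and skips duplicates
        max_positions=1_000_000,                 # illustrative cap; None means no limit
        min_depth=20,                            # the interactive flows also default to 20
    )
    print(f"imported {count:,} new positions")

The equivalent CLI invocation is: python lichess_importer.py lichess_db_eval.jsonl.zst labeled.jsonl --max-positions 1000000 --min-depth 20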