dceab0875e
Build & Test (NowChessSystems) TeamCity build finished
Co-authored-by: Janis <janis@nowchess.de> Reviewed-on: #33 Co-authored-by: Janis <janis.e.20@gmx.de> Co-committed-by: Janis <janis.e.20@gmx.de>
288 lines
8.5 KiB
Python
288 lines
8.5 KiB
Python
#!/usr/bin/env python3
|
||
"""Dataset versioning and management for NNUE training data."""
|
||
|
||
import json
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
from typing import Optional, Dict, List, Tuple
|
||
from rich.console import Console
|
||
from rich.table import Table
|
||
|
||
|
||
def get_datasets_dir() -> Path:
    """Return the shared datasets directory, creating it on first use."""
    path = Path(__file__).parent.parent / "datasets"
    path.mkdir(exist_ok=True)
    return path
|
||
|
||
|
||
def next_dataset_version() -> int:
    """Find the next available dataset version number.

    Scans the datasets directory for ``ds_v<N>`` subdirectories and
    returns max(N) + 1, or 1 when no versioned dataset exists yet.
    """
    found = []
    for entry in get_datasets_dir().iterdir():
        if not (entry.is_dir() and entry.name.startswith("ds_v")):
            continue
        try:
            found.append(int(entry.name.split("_v")[1]))
        except (ValueError, IndexError):
            # Directory name merely resembles the pattern; ignore it.
            continue
    return max(found) + 1 if found else 1
|
||
|
||
|
||
def list_datasets() -> List[Tuple[int, Dict]]:
    """List all datasets with their metadata.

    Returns:
        List of (version, metadata_dict) tuples, sorted by version.
        Directories with unparsable names or broken metadata are skipped.
    """
    found: List[Tuple[int, Dict]] = []
    for entry in get_datasets_dir().iterdir():
        if not (entry.is_dir() and entry.name.startswith("ds_v")):
            continue
        try:
            version = int(entry.name.split("_v")[1])
            meta_path = entry / "metadata.json"
            if meta_path.exists():
                with open(meta_path, 'r') as fh:
                    found.append((version, json.load(fh)))
        except (ValueError, IndexError, json.JSONDecodeError):
            continue
    found.sort(key=lambda item: item[0])
    return found
|
||
|
||
|
||
def load_dataset_metadata(version: int) -> Optional[Dict]:
    """Load metadata for a specific dataset version.

    Returns:
        Metadata dict or None if not found.
    """
    meta_path = get_datasets_dir() / f"ds_v{version}" / "metadata.json"
    if not meta_path.exists():
        return None
    with open(meta_path, 'r') as fh:
        return json.load(fh)
|
||
|
||
|
||
def save_dataset_metadata(version: int, metadata: Dict) -> None:
    """Write *metadata* to <datasets>/ds_v<version>/metadata.json."""
    target_dir = get_datasets_dir() / f"ds_v{version}"
    target_dir.mkdir(exist_ok=True)
    with open(target_dir / "metadata.json", 'w') as fh:
        # default=str stringifies otherwise unserializable values (e.g. datetimes).
        json.dump(metadata, fh, indent=2, default=str)
|
||
|
||
|
||
def create_dataset(
    version: int,
    labeled_jsonl_path: str,
    sources: List[Dict],
    stockfish_depth: int = 12
) -> Path:
    """Create a new versioned dataset.

    Args:
        version: Dataset version number
        labeled_jsonl_path: Path to labeled.jsonl to copy
        sources: List of source dicts (see plan for schema)
        stockfish_depth: Depth used for labeling

    Returns:
        Path to the created dataset directory.
    """
    datasets_dir = get_datasets_dir()
    dataset_dir = datasets_dir / f"ds_v{version}"
    dataset_dir.mkdir(exist_ok=True)

    source_path = Path(labeled_jsonl_path)
    dest_path = dataset_dir / "labeled.jsonl"
    total_positions = 0

    if source_path.exists():
        # Copy labeled data with deduplication (in case source has duplicates).
        # Count while writing instead of re-reading the file afterwards.
        seen_fens = set()
        with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
            for line in src:
                try:
                    fen = json.loads(line).get('fen')
                except json.JSONDecodeError:
                    # Skip malformed lines
                    continue
                if fen and fen not in seen_fens:
                    dst.write(line)
                    seen_fens.add(fen)
                    total_positions += 1
    elif dest_path.exists():
        # Source missing but labeled.jsonl already present (e.g. a previous
        # run); count its lines so the metadata stays accurate.
        with open(dest_path, 'r') as f:
            total_positions = sum(1 for _ in f)

    # Create metadata
    metadata = {
        "version": version,
        "created": datetime.now().isoformat(),
        "total_positions": total_positions,
        "stockfish_depth": stockfish_depth,
        "sources": sources
    }

    save_dataset_metadata(version, metadata)
    return dataset_dir
|
||
|
||
|
||
def extend_dataset(
    version: int,
    new_labeled_path: str,
    new_source_entry: Dict
) -> bool:
    """Extend an existing dataset with new labeled positions (with deduplication).

    Args:
        version: Dataset version to extend
        new_labeled_path: Path to new labeled.jsonl to merge
        new_source_entry: Source entry to add to metadata

    Returns:
        True if successful, False otherwise.
    """
    dataset_dir = get_datasets_dir() / f"ds_v{version}"
    if not dataset_dir.exists():
        return False

    labeled_file = dataset_dir / "labeled.jsonl"
    incoming = Path(new_labeled_path)
    if not incoming.exists():
        return False

    # Build the dedup set from every FEN already in the dataset — the whole
    # file must be read or duplicates could slip through.
    known_fens = set()
    if labeled_file.exists():
        with open(labeled_file, 'r') as fh:
            for raw in fh:
                try:
                    fen = json.loads(raw).get('fen')
                except json.JSONDecodeError:
                    continue
                if fen:
                    known_fens.add(fen)

    # Keep only incoming lines whose FEN has not been seen before.
    fresh_lines = []
    with open(incoming, 'r') as fh:
        for raw in fh:
            try:
                fen = json.loads(raw).get('fen')
            except json.JSONDecodeError:
                continue
            if fen and fen not in known_fens:
                known_fens.add(fen)
                fresh_lines.append(raw)

    # Append only the new, unique positions.
    if fresh_lines:
        with open(labeled_file, 'a') as fh:
            fh.writelines(fresh_lines)

    # Refresh the stored metadata to reflect the merge.
    metadata = load_dataset_metadata(version)
    if metadata:
        with open(labeled_file, 'r') as fh:
            metadata['total_positions'] = sum(1 for _ in fh)
        # Record how many genuinely new positions this source contributed.
        new_source_entry['actual_count'] = len(fresh_lines)
        metadata['sources'].append(new_source_entry)
        save_dataset_metadata(version, metadata)

    return True
|
||
|
||
|
||
def get_dataset_labeled_path(version: int) -> Optional[Path]:
    """Get the path to a dataset's labeled.jsonl file.

    Returns:
        Path to labeled.jsonl or None if dataset doesn't exist.
    """
    candidate = get_datasets_dir() / f"ds_v{version}" / "labeled.jsonl"
    return candidate if candidate.exists() else None
|
||
|
||
|
||
def delete_dataset(version: int) -> bool:
    """Delete a dataset (recursively removes directory).

    Args:
        version: Dataset version to delete

    Returns:
        True if successful.
    """
    import shutil

    target = get_datasets_dir() / f"ds_v{version}"
    if not target.exists():
        return False
    shutil.rmtree(target)
    return True
|
||
|
||
|
||
def show_datasets_table(console: Optional[Console] = None) -> None:
    """Display all datasets in a Rich table.

    Args:
        console: Rich console to print to; a new one is created when None.
        (Annotation fixed: the default is None, so the parameter is Optional.)
    """
    if console is None:
        console = Console()

    datasets = list_datasets()

    if not datasets:
        console.print("[yellow]ℹ No datasets found yet[/yellow]")
        return

    table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
    table.add_column("Version", style="dim")
    table.add_column("Positions", justify="right")
    table.add_column("Sources", justify="left")
    table.add_column("Depth", justify="center")
    table.add_column("Created", justify="left")

    for v, metadata in datasets:
        positions = metadata.get('total_positions', 0)
        sources = metadata.get('sources', [])
        source_str = ", ".join([s.get('type', '?') for s in sources])
        depth = metadata.get('stockfish_depth', '?')
        created = metadata.get('created', '?')
        if created != '?':
            created = created.split('T')[0]  # ISO timestamp -> just the date

        table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)

    console.print(table)
|