Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8744bee2dd | |||
| 5f4d33f3ca | |||
| 767d3051a7 |
+163
-32
@@ -2,8 +2,8 @@
|
||||
|
||||
> **Stack:** raw-http | none | unknown | scala
|
||||
|
||||
> 0 routes | 0 models | 0 components | 35 lib files | 0 env vars | 0 middleware
|
||||
> **Token savings:** this file is ~3.700 tokens. Without it, AI exploration would cost ~18.200 tokens. **Saves ~14.500 tokens per conversation.**
|
||||
> 0 routes | 0 models | 0 components | 63 lib files | 1 env vars | 1 middleware
|
||||
> **Token savings:** this file is ~0 tokens. Without it, AI exploration would cost ~0 tokens. **Saves ~0 tokens per conversation.**
|
||||
|
||||
---
|
||||
|
||||
@@ -60,6 +60,113 @@
|
||||
- class ApiResponse
|
||||
- function error
|
||||
- function totalPages
|
||||
- `modules/bot/python/nnue.py`
|
||||
- function get_weights_dir: ()
|
||||
- function get_data_dir: ()
|
||||
- function list_checkpoints: ()
|
||||
- function migrate_legacy_data: ()
|
||||
- function show_header: ()
|
||||
- function show_checkpoints_table: ()
|
||||
- _...10 more_
|
||||
- `modules/bot/python/src/dataset.py`
|
||||
- function get_datasets_dir: () -> Path
|
||||
- function next_dataset_version: () -> int
|
||||
- function list_datasets: () -> List[Tuple[int, Dict]]
|
||||
- function load_dataset_metadata: (version) -> Optional[Dict]
|
||||
- function save_dataset_metadata: (version, metadata) -> None
|
||||
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
|
||||
- _...4 more_
|
||||
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
|
||||
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
|
||||
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
|
||||
- `modules/bot/python/src/tactical_positions_extractor.py`
|
||||
- function download_and_extract_puzzle_db: (url, output_dir)
|
||||
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
|
||||
- function load_positions_from_file: (file_path) -> Set[str]
|
||||
- function merge_positions: (tactical, other, output_file)
|
||||
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
|
||||
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
|
||||
- `modules/bot/python/src/train.py`
|
||||
- function fen_to_features: (fen)
|
||||
- function find_next_version: (base_name)
|
||||
- function save_metadata: (weights_file, metadata)
|
||||
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
|
||||
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
|
||||
- class NNUEDataset
|
||||
- _...1 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
|
||||
- class Bot
|
||||
- function name
|
||||
- function nextMove
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
|
||||
- class BotController
|
||||
- function getBot
|
||||
- function listBots
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
|
||||
- class BotMoveRepetition
|
||||
- function blockedMoves
|
||||
- function repeatedMove
|
||||
- function filterAllowed
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
|
||||
- class Evaluation
|
||||
- class CHECKMATE_SCORE
|
||||
- class DRAW_SCORE
|
||||
- function evaluate
|
||||
- function initAccumulator
|
||||
- function copyAccumulator
|
||||
- _...2 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
|
||||
- class EvaluationClassic
|
||||
- function evaluate
|
||||
- function countRay
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
|
||||
- class NNUE
|
||||
- function initAccumulator
|
||||
- function pushAccumulator
|
||||
- function copyAccumulator
|
||||
- function recomputeAccumulator
|
||||
- function validateAccumulator
|
||||
- _...4 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
|
||||
- class NbaiLoader
|
||||
- function load
|
||||
- function loadDefault
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
|
||||
- function toJson
|
||||
- class NbaiMetadata
|
||||
- function fromJson
|
||||
- function str
|
||||
- function num
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
|
||||
- function bestMove
|
||||
- function bestMove
|
||||
- function bestMoveWithTime
|
||||
- function bestMoveWithTime
|
||||
- function loop
|
||||
- function loop
|
||||
- _...2 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
|
||||
- class MoveOrdering
|
||||
- class OrderingContext
|
||||
- function addKillerMove
|
||||
- function getKillerMoves
|
||||
- function addHistory
|
||||
- function getHistory
|
||||
- _...3 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
|
||||
- function probe
|
||||
- function store
|
||||
- function clear
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
|
||||
- class ZobristHash
|
||||
- function hash
|
||||
- function nextHash
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
||||
- class Command
|
||||
- function execute
|
||||
@@ -82,7 +189,7 @@
|
||||
- function turn
|
||||
- function context
|
||||
- function canUndo
|
||||
- _...10 more_
|
||||
- _...11 more_
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
||||
- function context
|
||||
- class Observer
|
||||
@@ -93,6 +200,13 @@
|
||||
- _...1 more_
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
|
||||
- class GameFileService
|
||||
- function saveGameToFile
|
||||
- function loadGameFromFile
|
||||
- class FileSystemGameService
|
||||
- function saveGameToFile
|
||||
- function loadGameFromFile
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
||||
- class FenExporter
|
||||
- function boardToFen
|
||||
@@ -114,6 +228,8 @@
|
||||
- function parseBoard
|
||||
- function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
||||
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
||||
- class PgnExporter
|
||||
- function exportGameContext
|
||||
@@ -160,43 +276,58 @@
|
||||
|
||||
---
|
||||
|
||||
# Config
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
|
||||
|
||||
---
|
||||
|
||||
# Middleware
|
||||
|
||||
## custom
|
||||
- generate — `modules/bot/python/src/generate.py`
|
||||
|
||||
---
|
||||
|
||||
# Dependency Graph
|
||||
|
||||
## Most Imported Files (change these carefully)
|
||||
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
|
||||
|
||||
## Import Map (who imports what)
|
||||
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
# Config
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
|
||||
+29
-29
@@ -2,36 +2,36 @@
|
||||
|
||||
## Most Imported Files (change these carefully)
|
||||
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
|
||||
|
||||
## Import Map (who imports what)
|
||||
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
|
||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
|
||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
|
||||
|
||||
+117
-1
@@ -51,6 +51,113 @@
|
||||
- class ApiResponse
|
||||
- function error
|
||||
- function totalPages
|
||||
- `modules/bot/python/nnue.py`
|
||||
- function get_weights_dir: ()
|
||||
- function get_data_dir: ()
|
||||
- function list_checkpoints: ()
|
||||
- function migrate_legacy_data: ()
|
||||
- function show_header: ()
|
||||
- function show_checkpoints_table: ()
|
||||
- _...10 more_
|
||||
- `modules/bot/python/src/dataset.py`
|
||||
- function get_datasets_dir: () -> Path
|
||||
- function next_dataset_version: () -> int
|
||||
- function list_datasets: () -> List[Tuple[int, Dict]]
|
||||
- function load_dataset_metadata: (version) -> Optional[Dict]
|
||||
- function save_dataset_metadata: (version, metadata) -> None
|
||||
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
|
||||
- _...4 more_
|
||||
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
|
||||
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
|
||||
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
|
||||
- `modules/bot/python/src/tactical_positions_extractor.py`
|
||||
- function download_and_extract_puzzle_db: (url, output_dir)
|
||||
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
|
||||
- function load_positions_from_file: (file_path) -> Set[str]
|
||||
- function merge_positions: (tactical, other, output_file)
|
||||
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
|
||||
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
|
||||
- `modules/bot/python/src/train.py`
|
||||
- function fen_to_features: (fen)
|
||||
- function find_next_version: (base_name)
|
||||
- function save_metadata: (weights_file, metadata)
|
||||
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
|
||||
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
|
||||
- class NNUEDataset
|
||||
- _...1 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
|
||||
- class Bot
|
||||
- function name
|
||||
- function nextMove
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
|
||||
- class BotController
|
||||
- function getBot
|
||||
- function listBots
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
|
||||
- class BotMoveRepetition
|
||||
- function blockedMoves
|
||||
- function repeatedMove
|
||||
- function filterAllowed
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
|
||||
- class Evaluation
|
||||
- class CHECKMATE_SCORE
|
||||
- class DRAW_SCORE
|
||||
- function evaluate
|
||||
- function initAccumulator
|
||||
- function copyAccumulator
|
||||
- _...2 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
|
||||
- class EvaluationClassic
|
||||
- function evaluate
|
||||
- function countRay
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
|
||||
- class NNUE
|
||||
- function initAccumulator
|
||||
- function pushAccumulator
|
||||
- function copyAccumulator
|
||||
- function recomputeAccumulator
|
||||
- function validateAccumulator
|
||||
- _...4 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
|
||||
- class NbaiLoader
|
||||
- function load
|
||||
- function loadDefault
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
|
||||
- function toJson
|
||||
- class NbaiMetadata
|
||||
- function fromJson
|
||||
- function str
|
||||
- function num
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
|
||||
- function bestMove
|
||||
- function bestMove
|
||||
- function bestMoveWithTime
|
||||
- function bestMoveWithTime
|
||||
- function loop
|
||||
- function loop
|
||||
- _...2 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
|
||||
- class MoveOrdering
|
||||
- class OrderingContext
|
||||
- function addKillerMove
|
||||
- function getKillerMoves
|
||||
- function addHistory
|
||||
- function getHistory
|
||||
- _...3 more_
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
|
||||
- function probe
|
||||
- function store
|
||||
- function clear
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
|
||||
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
|
||||
- class ZobristHash
|
||||
- function hash
|
||||
- function nextHash
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
||||
- class Command
|
||||
- function execute
|
||||
@@ -73,7 +180,7 @@
|
||||
- function turn
|
||||
- function context
|
||||
- function canUndo
|
||||
- _...10 more_
|
||||
- _...11 more_
|
||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
||||
- function context
|
||||
- class Observer
|
||||
@@ -84,6 +191,13 @@
|
||||
- _...1 more_
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
|
||||
- class GameFileService
|
||||
- function saveGameToFile
|
||||
- function loadGameFromFile
|
||||
- class FileSystemGameService
|
||||
- function saveGameToFile
|
||||
- function loadGameFromFile
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
||||
- class FenExporter
|
||||
- function boardToFen
|
||||
@@ -105,6 +219,8 @@
|
||||
- function parseBoard
|
||||
- function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
||||
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
|
||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
||||
- class PgnExporter
|
||||
- function exportGameContext
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# Middleware
|
||||
|
||||
## custom
|
||||
- generate — `modules/bot/python/src/generate.py`
|
||||
Generated
+2
@@ -8,3 +8,5 @@
|
||||
/dataSources.local.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
|
||||
sonarlint.xml
|
||||
|
||||
Generated
+133
@@ -0,0 +1,133 @@
|
||||
<component name="ProjectCodeStyleConfiguration">
|
||||
<code_scheme name="Project" version="173">
|
||||
<AndroidXmlCodeStyleSettings>
|
||||
<option name="USE_CUSTOM_SETTINGS" value="true" />
|
||||
</AndroidXmlCodeStyleSettings>
|
||||
<JetCodeStyleSettings>
|
||||
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
|
||||
</JetCodeStyleSettings>
|
||||
<ScalaCodeStyleSettings>
|
||||
<option name="FORMATTER" value="1" />
|
||||
</ScalaCodeStyleSettings>
|
||||
<XML>
|
||||
<option name="XML_KEEP_LINE_BREAKS" value="false" />
|
||||
<option name="XML_ALIGN_ATTRIBUTES" value="false" />
|
||||
<option name="XML_SPACE_INSIDE_EMPTY_TAG" value="true" />
|
||||
</XML>
|
||||
<codeStyleSettings language="XML">
|
||||
<option name="FORCE_REARRANGE_MODE" value="1" />
|
||||
<indentOptions>
|
||||
<option name="CONTINUATION_INDENT_SIZE" value="4" />
|
||||
</indentOptions>
|
||||
<arrangement>
|
||||
<rules>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>xmlns:android</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>xmlns:.*</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
<order>BY_NAME</order>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>.*:id</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>.*:name</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>name</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>style</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>.*</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
<order>BY_NAME</order>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>.*</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
</rule>
|
||||
</section>
|
||||
<section>
|
||||
<rule>
|
||||
<match>
|
||||
<AND>
|
||||
<NAME>.*</NAME>
|
||||
<XML_ATTRIBUTE />
|
||||
<XML_NAMESPACE>.*</XML_NAMESPACE>
|
||||
</AND>
|
||||
</match>
|
||||
<order>BY_NAME</order>
|
||||
</rule>
|
||||
</section>
|
||||
</rules>
|
||||
</arrangement>
|
||||
</codeStyleSettings>
|
||||
<codeStyleSettings language="kotlin">
|
||||
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
|
||||
</codeStyleSettings>
|
||||
</code_scheme>
|
||||
</component>
|
||||
Generated
+1
-1
@@ -1,5 +1,5 @@
|
||||
<component name="ProjectCodeStyleConfiguration">
|
||||
<state>
|
||||
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
||||
<option name="USE_PER_PROJECT_SETTINGS" value="true" />
|
||||
</state>
|
||||
</component>
|
||||
Generated
+1
-1
@@ -11,7 +11,7 @@
|
||||
<option value="$PROJECT_DIR$" />
|
||||
<option value="$PROJECT_DIR$/modules" />
|
||||
<option value="$PROJECT_DIR$/modules/api" />
|
||||
<option value="$PROJECT_DIR$/modules/backcore" />
|
||||
<option value="$PROJECT_DIR$/modules/bot" />
|
||||
<option value="$PROJECT_DIR$/modules/core" />
|
||||
<option value="$PROJECT_DIR$/modules/io" />
|
||||
<option value="$PROJECT_DIR$/modules/rule" />
|
||||
|
||||
Generated
+1
-1
@@ -5,7 +5,7 @@
|
||||
<option name="deprecationWarnings" value="true" />
|
||||
<option name="uncheckedWarnings" value="true" />
|
||||
</profile>
|
||||
<profile name="Gradle 2" modules="NowChessSystems.modules.backcore.integrationTest,NowChessSystems.modules.backcore.main,NowChessSystems.modules.backcore.native-test,NowChessSystems.modules.backcore.quarkus-generated-sources,NowChessSystems.modules.backcore.quarkus-test-generated-sources,NowChessSystems.modules.backcore.scoverage,NowChessSystems.modules.backcore.test,NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
|
||||
<profile name="Gradle 2" modules="NowChessSystems.modules.bot.main,NowChessSystems.modules.bot.scoverage,NowChessSystems.modules.bot.test,NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
|
||||
<option name="deprecationWarnings" value="true" />
|
||||
<option name="uncheckedWarnings" value="true" />
|
||||
<parameters>
|
||||
|
||||
Generated
-6
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ScalaProjectSettings">
|
||||
<option name="scala3DisclaimerShown" value="true" />
|
||||
</component>
|
||||
</project>
|
||||
@@ -19,6 +19,7 @@ Try to stick to these commands for consistency.
|
||||
| `api` | Model / shared types | (none) |
|
||||
| `core` | Primary business logic | api, rule |
|
||||
| `rule` | Game rules | api |
|
||||
| `bot` | Bots and AI | api,rule,io |
|
||||
| `io` | Export formats | api, core |
|
||||
| `ui` | Entrypoint & UI | core, io |
|
||||
|
||||
@@ -47,6 +48,9 @@ Try to stick to these commands for consistency.
|
||||
- **Immutable state as primary model:** GameContext (api) holds board, history, player state — immutable, passed through the system. Each move creates a new GameContext, enabling undo/redo without side effects.
|
||||
- **Observer pattern for UI decoupling:** GameEngine publishes move/state events; CommandInvoker queues moves; UI listens to events, not polling. GameEngine never imports UI code.
|
||||
- **RuleSet trait encapsulates rules:** Move generation, check, castling, en passant all in RuleSet impl. GameEngine calls rules as a black box; rules don't know about the rest of core.
|
||||
- **Polyglot hash must follow spec index layout:** piece keys use interleaved mapping `(pieceType * 2 + colorBit)` with black=0/white=1, castling keys are `768..771`, en-passant file keys are `772..779` and are XORed only if side-to-move has a pawn that can capture en passant, side-to-move key is `780` for white.
|
||||
- **Alpha-beta uses sequential PV search by default:** parallel split was disabled because fixed-window futures removed pruning effectiveness; correctness and pruning quality take priority over speculative parallelism.
|
||||
- **Search hash is updated incrementally per move:** bot search now updates Zobrist keys from parent hash with move deltas instead of recomputing piece scans at every node.
|
||||
|
||||
## Rules
|
||||
|
||||
|
||||
+21
-8
@@ -21,15 +21,27 @@ sonar {
|
||||
if (report.exists()) report.absolutePath else null
|
||||
}.joinToString(",")
|
||||
|
||||
val jacocoReports = subprojects.mapNotNull { subproject ->
|
||||
val report = subproject.file("build/reports/jacoco/test/jacocoTestReport.xml")
|
||||
if (report.exists()) report.absolutePath else null
|
||||
}.joinToString(",")
|
||||
|
||||
property("sonar.scala.coverage.reportPaths", scoverageReports)
|
||||
if (jacocoReports.isNotEmpty()) {
|
||||
property("sonar.coverage.jacoco.xmlReportPaths", jacocoReports)
|
||||
}
|
||||
property(
|
||||
"sonar.coverage.exclusions",
|
||||
// UI renders JavaFX components; headless test environments cannot exercise rendering paths
|
||||
"modules/ui/**," +
|
||||
// FastParse macro-generated combinators produce synthetic branches that scoverage marks as uncovered
|
||||
"modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse*," +
|
||||
// NNUE inference pipeline — coverage requires a trained model file not present in CI
|
||||
"**/bot/**/NNUE.scala," +
|
||||
"**/bot/**/NNUEBot.scala," +
|
||||
"**/bot/**/EvaluationNNUE.scala," +
|
||||
// NBAI binary format loader/writer — error paths require crafted corrupt files; migrator is a one-shot tool
|
||||
"**/bot/**/NbaiLoader.scala," +
|
||||
"**/bot/**/NbaiModel.scala," +
|
||||
"**/bot/**/NbaiMigrator.scala," +
|
||||
"**/bot/**/NbaiWriter.scala," +
|
||||
// PolyglotBook — binary I/O and dead-code guards (bit-masked fields can never exceed valid range)
|
||||
"**/bot/**/PolyglotBook.scala," +
|
||||
"**/bot/**/MoveOrdering.scala," +
|
||||
"**/bot/**/AlphaBetaSearch.scala"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +55,7 @@ val versions = mapOf(
|
||||
"SCALAFX" to "21.0.0-R32",
|
||||
"JAVAFX" to "21.0.1",
|
||||
"JUNIT_BOM" to "5.13.4",
|
||||
"ONNXRUNTIME" to "1.19.2",
|
||||
"SCALA_PARSER_COMBINATORS" to "2.4.0",
|
||||
"FASTPARSE" to "3.0.2",
|
||||
"JACKSON" to "2.17.2",
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
quarkusPluginId=io.quarkus
|
||||
quarkusPluginVersion=3.32.4
|
||||
quarkusPlatformGroupId=io.quarkus.platform
|
||||
quarkusPlatformArtifactId=quarkus-bom
|
||||
quarkusPlatformVersion=3.32.4
|
||||
@@ -1,5 +1,5 @@
|
||||
import glob,re
|
||||
mods=['api','core','io','rule','ui']
|
||||
mods=['api','core','io','rule','ui', 'bot']
|
||||
tot=0
|
||||
for m in mods:
|
||||
s=0
|
||||
|
||||
@@ -34,3 +34,11 @@
|
||||
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
|
||||
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
|
||||
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
|
||||
## (2026-04-16)
|
||||
|
||||
### Features
|
||||
|
||||
* NCS-13 Implement Threefold Repetition ([#31](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/31)) ([767d305](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/767d3051a76c266050b6335774d66e2db2273c16))
|
||||
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
|
||||
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
|
||||
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package de.nowchess.api.bot
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
trait Bot {
|
||||
|
||||
def name: String
|
||||
def nextMove(context: GameContext): Option[Move]
|
||||
|
||||
}
|
||||
@@ -5,4 +5,5 @@ enum DrawReason:
|
||||
case Stalemate
|
||||
case InsufficientMaterial
|
||||
case FiftyMoveRule
|
||||
case ThreefoldRepetition
|
||||
case Agreement
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package de.nowchess.api.game
|
||||
|
||||
import de.nowchess.api.board.{Board, CastlingRights, Color, Square}
|
||||
import de.nowchess.api.board.{Board, CastlingRights, Color, PieceType, Square}
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
/** Immutable bundle of complete game state. All state changes produce new GameContext instances.
|
||||
@@ -13,7 +13,15 @@ case class GameContext(
|
||||
halfMoveClock: Int,
|
||||
moves: List[Move],
|
||||
result: Option[GameResult] = None,
|
||||
initialBoard: Board = Board.initial,
|
||||
):
|
||||
private lazy val whiteKingSquare: Option[Square] =
|
||||
board.pieces.find((_, p) => p.color == Color.White && p.pieceType == PieceType.King).map(_._1)
|
||||
private lazy val blackKingSquare: Option[Square] =
|
||||
board.pieces.find((_, p) => p.color == Color.Black && p.pieceType == PieceType.King).map(_._1)
|
||||
def kingSquare(color: Color): Option[Square] =
|
||||
if color == Color.White then whiteKingSquare else blackKingSquare
|
||||
|
||||
/** Create new context with updated board. */
|
||||
def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
|
||||
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package de.nowchess.api.game
|
||||
|
||||
import de.nowchess.api.bot.Bot
|
||||
import de.nowchess.api.player.PlayerInfo
|
||||
|
||||
sealed trait Participant
|
||||
final case class Human(playerInfo: PlayerInfo) extends Participant
|
||||
final case class BotParticipant(bot: Bot) extends Participant
|
||||
@@ -71,3 +71,9 @@ class GameContextTest extends AnyFunSuite with Matchers:
|
||||
test("withResult clears result"):
|
||||
val ctx = GameContext.initial.withResult(Some(GameResult.Win(Color.Black)))
|
||||
ctx.withResult(None).result shouldBe None
|
||||
|
||||
test("kingSquare returns white king position"):
|
||||
GameContext.initial.kingSquare(Color.White) shouldBe Some(Square(File.E, Rank.R1))
|
||||
|
||||
test("kingSquare returns black king position"):
|
||||
GameContext.initial.kingSquare(Color.Black) shouldBe Some(Square(File.E, Rank.R8))
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
MAJOR=0
|
||||
MINOR=5
|
||||
MINOR=6
|
||||
PATCH=0
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
plugins {
|
||||
id("scala")
|
||||
id("org.scoverage") version "8.1"
|
||||
id("io.quarkus")
|
||||
id("jacoco")
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val versions = rootProject.extra["VERSIONS"] as Map<String, String>
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
mavenLocal()
|
||||
}
|
||||
|
||||
scala {
|
||||
scalaVersion = versions["SCALA3"]!!
|
||||
}
|
||||
|
||||
scoverage {
|
||||
scoverageVersion.set(versions["SCOVERAGE"]!!)
|
||||
}
|
||||
|
||||
tasks.withType<ScalaCompile> {
|
||||
scalaCompileOptions.additionalParameters = listOf("-encoding", "UTF-8")
|
||||
}
|
||||
|
||||
val quarkusPlatformGroupId: String by project
|
||||
val quarkusPlatformArtifactId: String by project
|
||||
val quarkusPlatformVersion: String by project
|
||||
|
||||
dependencies {
|
||||
implementation(project(":modules:api"))
|
||||
implementation(project(":modules:core"))
|
||||
implementation(project(":modules:io"))
|
||||
implementation(project(":modules:rule"))
|
||||
|
||||
implementation(enforcedPlatform("${quarkusPlatformGroupId}:${quarkusPlatformArtifactId}:${quarkusPlatformVersion}"))
|
||||
implementation("io.quarkus:quarkus-rest")
|
||||
implementation("io.quarkus:quarkus-rest-jackson")
|
||||
implementation("io.quarkus:quarkus-config-yaml")
|
||||
implementation("io.quarkus:quarkus-arc")
|
||||
|
||||
implementation("com.fasterxml.jackson.module:jackson-module-scala_3:${versions["JACKSON_SCALA"]!!}")
|
||||
|
||||
testImplementation(platform("org.junit:junit-bom:${versions["JUNIT_BOM"]!!}"))
|
||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||
testImplementation("org.scalatest:scalatest_3:${versions["SCALATEST"]!!}")
|
||||
testImplementation("co.helmethair:scalatest-junit-runner:${versions["SCALATEST_JUNIT"]!!}")
|
||||
testImplementation("io.quarkus:quarkus-junit5")
|
||||
testImplementation("io.quarkus:quarkus-jacoco")
|
||||
testImplementation("io.rest-assured:rest-assured")
|
||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||
}
|
||||
|
||||
configurations.matching { !it.name.startsWith("scoverage") }.configureEach {
|
||||
resolutionStrategy.force("org.scala-lang:scala-library:${versions["SCALA_LIBRARY"]!!}")
|
||||
}
|
||||
configurations.scoverage {
|
||||
resolutionStrategy.eachDependency {
|
||||
if (requested.group == "org.scoverage" && requested.name.startsWith("scalac-scoverage-plugin_")) {
|
||||
useTarget("${requested.group}:scalac-scoverage-plugin_2.13.16:2.3.0")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group = "de.nowchess"
|
||||
version = "1.0-SNAPSHOT"
|
||||
|
||||
tasks.withType<JavaCompile> {
|
||||
options.encoding = "UTF-8"
|
||||
options.compilerArgs.add("-parameters")
|
||||
}
|
||||
|
||||
tasks.withType<Jar>().configureEach {
|
||||
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
||||
}
|
||||
|
||||
tasks.test {
|
||||
useJUnitPlatform {
|
||||
includeEngines("scalatest", "junit-jupiter")
|
||||
}
|
||||
testLogging {
|
||||
events("passed", "skipped", "failed")
|
||||
}
|
||||
finalizedBy(tasks.named("jacocoTestReport"))
|
||||
}
|
||||
|
||||
tasks.jacocoTestReport {
|
||||
dependsOn(tasks.test)
|
||||
executionData.setFrom(layout.buildDirectory.file("jacoco-quarkus.exec"))
|
||||
sourceDirectories.setFrom(files("src/main/scala"))
|
||||
classDirectories.setFrom(files(layout.buildDirectory.dir("classes/scala/main")))
|
||||
reports {
|
||||
xml.required.set(true)
|
||||
xml.outputLocation.set(
|
||||
layout.buildDirectory.file("reports/jacoco/test/jacocoTestReport.xml")
|
||||
)
|
||||
html.required.set(false)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
quarkus:
|
||||
http:
|
||||
port: 8080
|
||||
jacoco:
|
||||
data-file: ${user.dir}/build/jacoco-quarkus.exec
|
||||
report: false
|
||||
@@ -1,11 +0,0 @@
|
||||
package de.nowchess.backcore.config
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.module.scala.DefaultScalaModule
|
||||
import io.quarkus.jackson.ObjectMapperCustomizer
|
||||
import jakarta.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class JacksonConfig extends ObjectMapperCustomizer:
|
||||
def customize(mapper: ObjectMapper): Unit =
|
||||
mapper.registerModule(DefaultScalaModule)
|
||||
@@ -1,57 +0,0 @@
|
||||
package de.nowchess.backcore.dto
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include
|
||||
|
||||
case class PlayerInfoDto(id: String, displayName: String)
|
||||
|
||||
case class GameStateResponse(
|
||||
fen: String,
|
||||
pgn: String,
|
||||
turn: String,
|
||||
status: String,
|
||||
@JsonInclude(Include.NON_ABSENT) winner: Option[String],
|
||||
moves: List[String],
|
||||
undoAvailable: Boolean,
|
||||
redoAvailable: Boolean,
|
||||
)
|
||||
|
||||
case class GameFullResponse(
|
||||
gameId: String,
|
||||
white: PlayerInfoDto,
|
||||
black: PlayerInfoDto,
|
||||
state: GameStateResponse,
|
||||
)
|
||||
|
||||
case class OkResponse(ok: Boolean = true)
|
||||
|
||||
@JsonInclude(Include.NON_ABSENT)
|
||||
case class ApiErrorResponse(
|
||||
code: String,
|
||||
message: String,
|
||||
field: Option[String] = None,
|
||||
)
|
||||
|
||||
// Requests
|
||||
case class CreateGameRequest(
|
||||
white: Option[PlayerInfoDto] = None,
|
||||
black: Option[PlayerInfoDto] = None,
|
||||
)
|
||||
|
||||
case class ImportFenRequest(
|
||||
fen: String = "",
|
||||
white: Option[PlayerInfoDto] = None,
|
||||
black: Option[PlayerInfoDto] = None,
|
||||
)
|
||||
|
||||
case class ImportPgnRequest(pgn: String = "")
|
||||
|
||||
case class LegalMoveDto(
|
||||
from: String,
|
||||
to: String,
|
||||
uci: String,
|
||||
moveType: String,
|
||||
@JsonInclude(Include.NON_ABSENT) promotion: Option[String] = None,
|
||||
)
|
||||
|
||||
case class LegalMovesResponse(moves: List[LegalMoveDto])
|
||||
@@ -1,10 +0,0 @@
|
||||
package de.nowchess.backcore.game
|
||||
|
||||
import java.security.SecureRandom
|
||||
|
||||
object GameId:
|
||||
private val chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
private val random = SecureRandom()
|
||||
|
||||
def generate(): String =
|
||||
(1 to 8).map(_ => chars(random.nextInt(chars.length))).mkString
|
||||
@@ -1,80 +0,0 @@
|
||||
package de.nowchess.backcore.game
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.module.scala.DefaultScalaModule
|
||||
import de.nowchess.api.board.Color
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import de.nowchess.backcore.dto.*
|
||||
import de.nowchess.io.fen.FenExporter
|
||||
import de.nowchess.io.pgn.PgnExporter
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
object GameMapper:
|
||||
private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
|
||||
|
||||
def toGameFullJson(session: GameSession): String =
|
||||
mapper.writeValueAsString(toGameFull(session))
|
||||
|
||||
def toGameFull(session: GameSession): GameFullResponse =
|
||||
GameFullResponse(
|
||||
gameId = session.gameId,
|
||||
white = toPlayerInfo(session.white),
|
||||
black = toPlayerInfo(session.black),
|
||||
state = toGameState(session),
|
||||
)
|
||||
|
||||
def toGameState(session: GameSession): GameStateResponse =
|
||||
val (status, winner) = computeStatus(session)
|
||||
GameStateResponse(
|
||||
fen = FenExporter.exportGameContext(session.context),
|
||||
pgn = buildPgn(session.context.moves),
|
||||
turn = if session.context.turn == Color.White then "white" else "black",
|
||||
status = status,
|
||||
winner = winner,
|
||||
moves = session.context.moves.map(moveToUci),
|
||||
undoAvailable = session.invoker.canUndo,
|
||||
redoAvailable = session.invoker.canRedo,
|
||||
)
|
||||
|
||||
private def toPlayerInfo(p: de.nowchess.api.player.PlayerInfo): PlayerInfoDto =
|
||||
PlayerInfoDto(id = p.id.value, displayName = p.displayName)
|
||||
|
||||
private def computeStatus(session: GameSession): (String, Option[String]) =
|
||||
session.result match
|
||||
case Some(GameResult.Checkmate(winner)) =>
|
||||
val w = if winner == Color.White then "white" else "black"
|
||||
("checkmate", Some(w))
|
||||
case Some(GameResult.Stalemate) =>
|
||||
("stalemate", None)
|
||||
case Some(GameResult.Resign(winner)) =>
|
||||
val w = if winner == Color.White then "white" else "black"
|
||||
("resign", Some(w))
|
||||
case Some(GameResult.AgreedDraw) | Some(GameResult.FiftyMoveDraw) =>
|
||||
("draw", None)
|
||||
case Some(GameResult.InsufficientMaterial) =>
|
||||
("insufficientMaterial", None)
|
||||
case None =>
|
||||
computeLiveStatus(session)
|
||||
|
||||
private def computeLiveStatus(session: GameSession): (String, Option[String]) =
|
||||
val ctx = session.context
|
||||
if DefaultRules.isCheck(ctx) then ("check", None)
|
||||
else if session.drawOfferedBy.isDefined then ("drawOffered", None)
|
||||
else if DefaultRules.isFiftyMoveRule(ctx) then ("fiftyMoveAvailable", None)
|
||||
else ("started", None)
|
||||
|
||||
def moveToUci(move: Move): String =
|
||||
val base = s"${move.from}${move.to}"
|
||||
move.moveType match
|
||||
case MoveType.Promotion(piece) =>
|
||||
val suffix = piece match
|
||||
case PromotionPiece.Queen => "q"
|
||||
case PromotionPiece.Rook => "r"
|
||||
case PromotionPiece.Bishop => "b"
|
||||
case PromotionPiece.Knight => "n"
|
||||
base + suffix
|
||||
case _ => base
|
||||
|
||||
private def buildPgn(moves: List[Move]): String =
|
||||
// Use PgnExporter with no headers to get move-text only (SAN notation)
|
||||
PgnExporter.exportGame(Map.empty, moves)
|
||||
@@ -1,12 +0,0 @@
|
||||
package de.nowchess.backcore.game
|
||||
|
||||
import de.nowchess.api.board.Color
|
||||
|
||||
sealed trait GameResult
|
||||
object GameResult:
|
||||
case class Checkmate(winner: Color) extends GameResult
|
||||
case object Stalemate extends GameResult
|
||||
case class Resign(winner: Color) extends GameResult
|
||||
case object AgreedDraw extends GameResult
|
||||
case object FiftyMoveDraw extends GameResult
|
||||
case object InsufficientMaterial extends GameResult
|
||||
@@ -1,16 +0,0 @@
|
||||
package de.nowchess.backcore.game
|
||||
|
||||
import de.nowchess.api.board.Color
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.player.PlayerInfo
|
||||
import de.nowchess.chess.command.CommandInvoker
|
||||
|
||||
case class GameSession(
|
||||
gameId: String,
|
||||
white: PlayerInfo,
|
||||
black: PlayerInfo,
|
||||
context: GameContext,
|
||||
invoker: CommandInvoker,
|
||||
drawOfferedBy: Option[Color] = None,
|
||||
result: Option[GameResult] = None,
|
||||
)
|
||||
@@ -1,260 +0,0 @@
|
||||
package de.nowchess.backcore.game
|
||||
|
||||
import de.nowchess.api.board.{Color, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import de.nowchess.api.player.{PlayerId, PlayerInfo}
|
||||
import de.nowchess.backcore.dto.{CreateGameRequest, ImportFenRequest, PlayerInfoDto}
|
||||
import de.nowchess.chess.command.{CommandInvoker, MoveCommand, MoveResult}
|
||||
import de.nowchess.io.fen.FenParser
|
||||
import de.nowchess.io.pgn.PgnParser
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
@ApplicationScoped
|
||||
class GameStore:
|
||||
private val games: mutable.Map[String, GameSession] = mutable.Map.empty
|
||||
|
||||
// ─── Create / Get ────────────────────────────────────────────────
|
||||
|
||||
def create(req: CreateGameRequest): GameSession = synchronized:
|
||||
val id = generateId()
|
||||
val session = newSession(id, req.white, req.black, GameContext.initial)
|
||||
games(id) = session
|
||||
session
|
||||
|
||||
def get(id: String): Option[GameSession] = synchronized:
|
||||
games.get(id)
|
||||
|
||||
// ─── Move-making ─────────────────────────────────────────────────
|
||||
|
||||
def applyMove(id: String, uci: String): Either[String, GameSession] = synchronized:
|
||||
withSession(id): session =>
|
||||
if session.result.isDefined then Left("Game is already over")
|
||||
else
|
||||
parseUci(uci) match
|
||||
case None => Left(s"Invalid UCI notation: $uci")
|
||||
case Some((from, to, promotion)) =>
|
||||
val legalCandidates = DefaultRules.legalMoves(session.context)(from)
|
||||
findMatchingMove(legalCandidates, to, promotion) match
|
||||
case None => Left(s"$uci is not a legal move")
|
||||
case Some(move) =>
|
||||
val nextCtx = DefaultRules.applyMove(session.context)(move)
|
||||
val prevCtx = session.context
|
||||
val cmd = MoveCommand(
|
||||
from = move.from,
|
||||
to = move.to,
|
||||
moveResult = Some(MoveResult.Successful(nextCtx, prevCtx.board.pieceAt(move.to))),
|
||||
previousContext = Some(prevCtx),
|
||||
)
|
||||
session.invoker.execute(cmd)
|
||||
val result = detectGameOver(nextCtx)
|
||||
val updated = session.copy(context = nextCtx, result = result)
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
|
||||
def legalMoves(id: String, square: Option[Square]): Either[String, List[Move]] = synchronized:
|
||||
withSession(id): session =>
|
||||
val moves = square match
|
||||
case Some(sq) => DefaultRules.legalMoves(session.context)(sq)
|
||||
case None => DefaultRules.allLegalMoves(session.context)
|
||||
Right(moves)
|
||||
|
||||
// ─── Undo / Redo ─────────────────────────────────────────────────
|
||||
|
||||
def undo(id: String): Either[String, GameSession] = synchronized:
|
||||
withSession(id): session =>
|
||||
if !session.invoker.canUndo then Left("No moves to undo")
|
||||
else
|
||||
val idx = session.invoker.getCurrentIndex
|
||||
session.invoker.history(idx) match
|
||||
case cmd: MoveCommand =>
|
||||
cmd.previousContext match
|
||||
case None => Left("Cannot undo: no previous context stored")
|
||||
case Some(prevCtx) =>
|
||||
session.invoker.undo()
|
||||
val updated = session.copy(context = prevCtx, result = None, drawOfferedBy = None)
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
case _ => Left("Cannot undo this command type")
|
||||
|
||||
def redo(id: String): Either[String, GameSession] = synchronized:
|
||||
withSession(id): session =>
|
||||
if !session.invoker.canRedo then Left("No moves to redo")
|
||||
else
|
||||
val idx = session.invoker.getCurrentIndex + 1
|
||||
session.invoker.history(idx) match
|
||||
case cmd: MoveCommand =>
|
||||
cmd.moveResult match
|
||||
case Some(MoveResult.Successful(nextCtx, _)) =>
|
||||
session.invoker.redo()
|
||||
val result = detectGameOver(nextCtx)
|
||||
val updated = session.copy(context = nextCtx, result = result)
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
case _ => Left("Cannot redo: move result not available")
|
||||
case _ => Left("Cannot redo this command type")
|
||||
|
||||
// ─── Resign ──────────────────────────────────────────────────────
|
||||
|
||||
def resign(id: String): Either[String, GameSession] = synchronized:
|
||||
withSession(id): session =>
|
||||
if session.result.isDefined then Left("Game is already over")
|
||||
else
|
||||
val winner = session.context.turn.opposite
|
||||
val updated = session.copy(result = Some(GameResult.Resign(winner)))
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
|
||||
// ─── Draw actions ────────────────────────────────────────────────
|
||||
|
||||
def drawAction(id: String, action: String): Either[String, GameSession] = synchronized:
|
||||
withSession(id): session =>
|
||||
if session.result.isDefined then Left("Game is already over")
|
||||
else
|
||||
action match
|
||||
case "offer" =>
|
||||
val updated = session.copy(drawOfferedBy = Some(session.context.turn))
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
case "accept" =>
|
||||
session.drawOfferedBy match
|
||||
case None => Left("No draw offer to accept")
|
||||
case Some(offerer) if offerer == session.context.turn =>
|
||||
Left("Cannot accept your own draw offer")
|
||||
case Some(_) =>
|
||||
val updated = session.copy(result = Some(GameResult.AgreedDraw), drawOfferedBy = None)
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
case "decline" =>
|
||||
session.drawOfferedBy match
|
||||
case None => Left("No draw offer to decline")
|
||||
case Some(_) =>
|
||||
val updated = session.copy(drawOfferedBy = None)
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
case "claim" =>
|
||||
if DefaultRules.isFiftyMoveRule(session.context) then
|
||||
val updated = session.copy(result = Some(GameResult.FiftyMoveDraw))
|
||||
games(id) = updated
|
||||
Right(updated)
|
||||
else Left("Fifty-move rule has not been triggered")
|
||||
case other => Left(s"Unknown draw action: $other")
|
||||
|
||||
// ─── Import ──────────────────────────────────────────────────────
|
||||
|
||||
def importFen(req: ImportFenRequest): Either[String, GameSession] = synchronized:
|
||||
FenParser.parseFen(req.fen) match
|
||||
case Left(err) => Left(err)
|
||||
case Right(ctx) =>
|
||||
val id = generateId()
|
||||
val session = newSession(id, req.white, req.black, ctx)
|
||||
games(id) = session
|
||||
Right(session)
|
||||
|
||||
def importPgn(pgn: String, white: Option[PlayerInfoDto], black: Option[PlayerInfoDto]): Either[String, GameSession] =
|
||||
synchronized:
|
||||
PgnParser.validatePgn(pgn) match
|
||||
case Left(err) => Left(err)
|
||||
case Right(game) =>
|
||||
val id = generateId()
|
||||
val session = newSession(id, white, black, GameContext.initial)
|
||||
replayIntoSession(session, game.moves, GameContext.initial) match
|
||||
case Left(err) => Left(err)
|
||||
case Right(s) =>
|
||||
games(id) = s
|
||||
Right(s)
|
||||
|
||||
// ─── Private helpers ─────────────────────────────────────────────
|
||||
|
||||
private def withSession[A](id: String)(f: GameSession => Either[String, A]): Either[String, A] =
|
||||
games.get(id) match
|
||||
case None => Left(s"Game $id not found")
|
||||
case Some(session) => f(session)
|
||||
|
||||
private def generateId(): String =
|
||||
var id = GameId.generate()
|
||||
while games.contains(id) do id = GameId.generate()
|
||||
id
|
||||
|
||||
private def newSession(
|
||||
id: String,
|
||||
white: Option[PlayerInfoDto],
|
||||
black: Option[PlayerInfoDto],
|
||||
ctx: GameContext,
|
||||
): GameSession =
|
||||
GameSession(
|
||||
gameId = id,
|
||||
white = toPlayerInfo(white, "white", "White"),
|
||||
black = toPlayerInfo(black, "black", "Black"),
|
||||
context = ctx,
|
||||
invoker = new CommandInvoker(),
|
||||
)
|
||||
|
||||
private def toPlayerInfo(dto: Option[PlayerInfoDto], defaultId: String, defaultName: String): PlayerInfo =
|
||||
dto.fold(PlayerInfo(PlayerId(defaultId), defaultName))(d => PlayerInfo(PlayerId(d.id), d.displayName))
|
||||
|
||||
private def parseUci(uci: String): Option[(Square, Square, Option[PromotionPiece])] =
|
||||
if uci.length < 4 || uci.length > 5 then None
|
||||
else
|
||||
for
|
||||
from <- Square.fromAlgebraic(uci.substring(0, 2))
|
||||
to <- Square.fromAlgebraic(uci.substring(2, 4))
|
||||
yield
|
||||
val promotion = if uci.length == 5 then parsePromotionChar(uci.charAt(4)) else None
|
||||
(from, to, promotion)
|
||||
|
||||
private def parsePromotionChar(c: Char): Option[PromotionPiece] =
|
||||
c match
|
||||
case 'q' => Some(PromotionPiece.Queen)
|
||||
case 'r' => Some(PromotionPiece.Rook)
|
||||
case 'b' => Some(PromotionPiece.Bishop)
|
||||
case 'n' => Some(PromotionPiece.Knight)
|
||||
case _ => None
|
||||
|
||||
private def findMatchingMove(
|
||||
candidates: List[Move],
|
||||
to: Square,
|
||||
promotion: Option[PromotionPiece],
|
||||
): Option[Move] =
|
||||
candidates.filter(_.to == to) match
|
||||
case Nil => None
|
||||
case moves =>
|
||||
promotion match
|
||||
case Some(pp) => moves.find(_.moveType == MoveType.Promotion(pp))
|
||||
case None =>
|
||||
moves
|
||||
.find(m => !m.moveType.isInstanceOf[MoveType.Promotion])
|
||||
.orElse(moves.headOption)
|
||||
|
||||
private def detectGameOver(ctx: GameContext): Option[GameResult] =
|
||||
if DefaultRules.isCheckmate(ctx) then Some(GameResult.Checkmate(ctx.turn.opposite))
|
||||
else if DefaultRules.isStalemate(ctx) then Some(GameResult.Stalemate)
|
||||
else if DefaultRules.isInsufficientMaterial(ctx) then Some(GameResult.InsufficientMaterial)
|
||||
else None
|
||||
|
||||
private def replayIntoSession(
|
||||
session: GameSession,
|
||||
moves: List[Move],
|
||||
startCtx: GameContext,
|
||||
): Either[String, GameSession] =
|
||||
moves.foldLeft[Either[String, GameSession]](Right(session)):
|
||||
case (Left(err), _) => Left(err)
|
||||
case (Right(s), move) =>
|
||||
val legal = DefaultRules.legalMoves(s.context)(move.from)
|
||||
legal
|
||||
.find(m => m.from == move.from && m.to == move.to && m.moveType == move.moveType)
|
||||
.orElse(legal.find(m => m.from == move.from && m.to == move.to)) match
|
||||
case None => Left(s"Illegal move in PGN: $move")
|
||||
case Some(legalMove) =>
|
||||
val nextCtx = DefaultRules.applyMove(s.context)(legalMove)
|
||||
val cmd = MoveCommand(
|
||||
from = legalMove.from,
|
||||
to = legalMove.to,
|
||||
moveResult = Some(MoveResult.Successful(nextCtx, s.context.board.pieceAt(legalMove.to))),
|
||||
previousContext = Some(s.context),
|
||||
)
|
||||
s.invoker.execute(cmd)
|
||||
Right(s.copy(context = nextCtx))
|
||||
@@ -1,99 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import de.nowchess.backcore.dto.*
|
||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import jakarta.ws.rs.*
|
||||
import jakarta.ws.rs.core.{MediaType, Response}
|
||||
|
||||
@Path("/api/board/game")
|
||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
||||
@ApplicationScoped
|
||||
class GameResource @Inject() (store: GameStore):
|
||||
|
||||
@POST
|
||||
@Consumes(Array(MediaType.APPLICATION_JSON))
|
||||
def createGame(req: CreateGameRequest): Response =
|
||||
val session = store.create(Option(req).getOrElse(CreateGameRequest()))
|
||||
Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
||||
|
||||
@GET
|
||||
@Path("/{gameId}")
|
||||
def getGame(@PathParam("gameId") gameId: String): Response =
|
||||
store.get(gameId) match
|
||||
case Some(session) => Response.ok(GameMapper.toGameFull(session)).build()
|
||||
case None =>
|
||||
Response
|
||||
.status(404)
|
||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
||||
.build()
|
||||
|
||||
@GET
|
||||
@Path("/{gameId}/stream")
|
||||
@Produces(Array("application/x-ndjson"))
|
||||
def streamGame(@PathParam("gameId") gameId: String): Response =
|
||||
store.get(gameId) match
|
||||
case None =>
|
||||
Response
|
||||
.status(404)
|
||||
.`type`(MediaType.APPLICATION_JSON)
|
||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
||||
.build()
|
||||
case Some(session) =>
|
||||
// Simplified: return a single-line NDJSON snapshot of the current game state
|
||||
val event = s"""{"type":"gameFull","game":${GameMapper.toGameFullJson(session)}}"""
|
||||
Response.ok(event + "\n").build()
|
||||
|
||||
@POST
|
||||
@Path("/{gameId}/resign")
|
||||
def resignGame(@PathParam("gameId") gameId: String): Response =
|
||||
store.resign(gameId) match
|
||||
case Right(_) => Response.ok(OkResponse()).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("RESIGN_ERROR", err)).build()
|
||||
|
||||
@POST
|
||||
@Path("/{gameId}/draw/{action}")
|
||||
def drawAction(
|
||||
@PathParam("gameId") gameId: String,
|
||||
@PathParam("action") action: String,
|
||||
): Response =
|
||||
store.drawAction(gameId, action) match
|
||||
case Right(_) => Response.ok(OkResponse()).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("DRAW_ERROR", err)).build()
|
||||
|
||||
@GET
|
||||
@Path("/{gameId}/export/fen")
|
||||
@Produces(Array(MediaType.TEXT_PLAIN))
|
||||
def exportFen(@PathParam("gameId") gameId: String): Response =
|
||||
store.get(gameId) match
|
||||
case None =>
|
||||
Response
|
||||
.status(404)
|
||||
.`type`(MediaType.APPLICATION_JSON)
|
||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
||||
.build()
|
||||
case Some(session) =>
|
||||
import de.nowchess.io.fen.FenExporter
|
||||
Response.ok(FenExporter.exportGameContext(session.context)).build()
|
||||
|
||||
@GET
|
||||
@Path("/{gameId}/export/pgn")
|
||||
@Produces(Array("application/x-chess-pgn"))
|
||||
def exportPgn(@PathParam("gameId") gameId: String): Response =
|
||||
store.get(gameId) match
|
||||
case None =>
|
||||
Response
|
||||
.status(404)
|
||||
.`type`(MediaType.APPLICATION_JSON)
|
||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
||||
.build()
|
||||
case Some(session) =>
|
||||
import de.nowchess.io.pgn.PgnExporter
|
||||
Response.ok(PgnExporter.exportGameContext(session.context)).build()
|
||||
@@ -1,29 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import de.nowchess.backcore.dto.{ApiErrorResponse, ImportFenRequest, ImportPgnRequest}
|
||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import jakarta.ws.rs.*
|
||||
import jakarta.ws.rs.core.{MediaType, Response}
|
||||
|
||||
@Path("/api/board/game/import")
|
||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
||||
@Consumes(Array(MediaType.APPLICATION_JSON))
|
||||
@ApplicationScoped
|
||||
class ImportResource @Inject() (store: GameStore):
|
||||
|
||||
@POST
|
||||
@Path("/fen")
|
||||
def importFen(req: ImportFenRequest): Response =
|
||||
store.importFen(Option(req).getOrElse(ImportFenRequest())) match
|
||||
case Right(session) => Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
||||
case Left(err) => Response.status(400).entity(ApiErrorResponse("INVALID_FEN", err)).build()
|
||||
|
||||
@POST
|
||||
@Path("/pgn")
|
||||
def importPgn(req: ImportPgnRequest): Response =
|
||||
val body = Option(req).getOrElse(ImportPgnRequest())
|
||||
store.importPgn(body.pgn, None, None) match
|
||||
case Right(session) => Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
||||
case Left(err) => Response.status(400).entity(ApiErrorResponse("INVALID_PGN", err)).build()
|
||||
@@ -1,87 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import de.nowchess.api.board.Square
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
import de.nowchess.backcore.dto.*
|
||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import jakarta.ws.rs.*
|
||||
import jakarta.ws.rs.core.{MediaType, Response}
|
||||
|
||||
@Path("/api/board/game")
|
||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
||||
@ApplicationScoped
|
||||
class MoveResource @Inject() (store: GameStore):
|
||||
|
||||
@POST
|
||||
@Path("/{gameId}/move/{uci}")
|
||||
def makeMove(
|
||||
@PathParam("gameId") gameId: String,
|
||||
@PathParam("uci") uci: String,
|
||||
): Response =
|
||||
store.applyMove(gameId, uci) match
|
||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("INVALID_MOVE", err)).build()
|
||||
|
||||
@GET
|
||||
@Path("/{gameId}/moves")
|
||||
def getLegalMoves(
|
||||
@PathParam("gameId") gameId: String,
|
||||
@QueryParam("square") squareParam: String,
|
||||
): Response =
|
||||
val square = Option(squareParam).flatMap(Square.fromAlgebraic)
|
||||
store.legalMoves(gameId, square) match
|
||||
case Right(moves) =>
|
||||
val dtos = moves.map(toLegalMoveDto)
|
||||
Response.ok(LegalMovesResponse(dtos)).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("ERROR", err)).build()
|
||||
|
||||
@POST
|
||||
@Path("/{gameId}/undo")
|
||||
def undoMove(@PathParam("gameId") gameId: String): Response =
|
||||
store.undo(gameId) match
|
||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("UNDO_NOT_AVAILABLE", err)).build()
|
||||
|
||||
@POST
|
||||
@Path("/{gameId}/redo")
|
||||
def redoMove(@PathParam("gameId") gameId: String): Response =
|
||||
store.redo(gameId) match
|
||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
||||
case Left(err) if err.contains("not found") =>
|
||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
||||
case Left(err) =>
|
||||
Response.status(400).entity(ApiErrorResponse("REDO_NOT_AVAILABLE", err)).build()
|
||||
|
||||
private def toLegalMoveDto(move: Move): LegalMoveDto =
|
||||
val uci = GameMapper.moveToUci(move)
|
||||
val (moveType, promotion) = move.moveType match
|
||||
case MoveType.Normal(true) => ("capture", None)
|
||||
case MoveType.Normal(false) => ("normal", None)
|
||||
case MoveType.CastleKingside => ("castleKingside", None)
|
||||
case MoveType.CastleQueenside => ("castleQueenside", None)
|
||||
case MoveType.EnPassant => ("enPassant", None)
|
||||
case MoveType.Promotion(pp) =>
|
||||
val pName = pp match
|
||||
case PromotionPiece.Queen => "queen"
|
||||
case PromotionPiece.Rook => "rook"
|
||||
case PromotionPiece.Bishop => "bishop"
|
||||
case PromotionPiece.Knight => "knight"
|
||||
("promotion", Some(pName))
|
||||
LegalMoveDto(
|
||||
from = move.from.toString,
|
||||
to = move.to.toString,
|
||||
uci = uci,
|
||||
moveType = moveType,
|
||||
promotion = promotion,
|
||||
)
|
||||
@@ -1,11 +0,0 @@
|
||||
package de.nowchess.backcore
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class BackcoreStartupTest:
|
||||
@Test
|
||||
def applicationStarts(): Unit =
|
||||
// If we get here the Quarkus container started successfully
|
||||
()
|
||||
@@ -1,67 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import io.restassured.RestAssured
|
||||
import org.hamcrest.Matchers.{equalTo, matchesPattern, notNullValue}
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class GameResourceTest:
|
||||
|
||||
@Test
|
||||
def createGameReturns201WithGameId(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
||||
.body("state.fen", notNullValue())
|
||||
.body("state.turn", equalTo("white"))
|
||||
.body("state.status", equalTo("started"))
|
||||
|
||||
@Test
|
||||
def createGameWithPlayersReturns201(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("""{"white":{"id":"p1","displayName":"Alice"},"black":{"id":"p2","displayName":"Bob"}}""")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.body("white.id", equalTo("p1"))
|
||||
.body("black.displayName", equalTo("Bob"))
|
||||
|
||||
@Test
|
||||
def getGameReturns200ForExistingGame(): Unit =
|
||||
val gameId = RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("gameId", equalTo(gameId))
|
||||
|
||||
@Test
|
||||
def getGameReturns404ForUnknownId(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get("/api/board/game/XXXXXXXX")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
@@ -1,177 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import io.restassured.RestAssured
|
||||
import org.hamcrest.Matchers.{containsString, equalTo, matchesPattern, notNullValue}
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class ImportExportTest:
|
||||
|
||||
private val startFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
||||
|
||||
// ─── Import FEN ────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def importFenReturns201WithCorrectPosition(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body(s"""{"fen":"$startFen"}""")
|
||||
.when()
|
||||
.post("/api/board/game/import/fen")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
||||
.body("state.fen", equalTo(startFen))
|
||||
.body("state.turn", equalTo("white"))
|
||||
|
||||
@Test
|
||||
def importFenWithCustomPositionWorks(): Unit =
|
||||
val fen = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body(s"""{"fen":"$fen"}""")
|
||||
.when()
|
||||
.post("/api/board/game/import/fen")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.body("state.fen", equalTo(fen))
|
||||
.body("state.turn", equalTo("black"))
|
||||
|
||||
@Test
|
||||
def importFenWithInvalidFenReturns400(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("""{"fen":"not-a-fen"}""")
|
||||
.when()
|
||||
.post("/api/board/game/import/fen")
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
|
||||
// ─── Import PGN ────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def importPgnReturns201(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("""{"pgn":"1. e4 e5 2. Nf3 Nc6 *"}""")
|
||||
.when()
|
||||
.post("/api/board/game/import/pgn")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
||||
.body("state.turn", equalTo("white"))
|
||||
|
||||
@Test
|
||||
def importPgnWithInvalidPgnReturns400(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("""{"pgn":"1. z9 *"}""")
|
||||
.when()
|
||||
.post("/api/board/game/import/pgn")
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
|
||||
// ─── Export FEN ────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def exportFenReturnsStartingFen(): Unit =
|
||||
val gameId = RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId/export/fen")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body(equalTo(startFen))
|
||||
|
||||
@Test
|
||||
def exportFenOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get("/api/board/game/XXXXXXXX/export/fen")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
|
||||
// ─── Export PGN ────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def exportPgnReturnsText(): Unit =
|
||||
val gameId = RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId/export/pgn")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body(containsString("e4"))
|
||||
|
||||
// ─── Stream ────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def streamReturnsNdjsonSnapshot(): Unit =
|
||||
val gameId = RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
val body = RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId/stream")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.contentType("application/x-ndjson")
|
||||
.extract()
|
||||
.body()
|
||||
.asString()
|
||||
|
||||
assert(body.trim.startsWith("""{"type":"gameFull""""), s"Expected gameFull event, got: $body")
|
||||
|
||||
@Test
|
||||
def streamOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get("/api/board/game/XXXXXXXX/stream")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
@@ -1,84 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import io.restassured.RestAssured
|
||||
import org.hamcrest.Matchers.{containsString, empty, equalTo, hasItem, hasItems, not, notNullValue}
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class MoveResourceTest:
|
||||
|
||||
private def createGame(): String =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
@Test
|
||||
def makeMoveReturns200WithUpdatedFen(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("fen", containsString("4P3")) // e4 pawn present in FEN
|
||||
.body("turn", equalTo("black"))
|
||||
.body("moves", hasItem("e2e4"))
|
||||
|
||||
@Test
|
||||
def makeMoveOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post("/api/board/game/XXXXXXXX/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
|
||||
@Test
|
||||
def illegalMoveReturns400(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e5") // illegal — pawns can't jump 3 squares
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
|
||||
@Test
|
||||
def getLegalMovesReturnsNonEmptyList(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId/moves")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("moves", not(empty()))
|
||||
|
||||
@Test
|
||||
def getLegalMovesFilteredBySquareReturnsCorrectMoves(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId/moves?square=e2")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("moves.uci", hasItems("e2e3", "e2e4"))
|
||||
|
||||
@Test
|
||||
def getLegalMovesOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get("/api/board/game/XXXXXXXX/moves")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
@@ -1,168 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import io.restassured.RestAssured
|
||||
import org.hamcrest.Matchers.{equalTo, notNullValue}
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class ResignDrawTest:
|
||||
|
||||
private def createGame(): String =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
// ─── Resign ────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def resignReturns200(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/resign")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("ok", equalTo(true))
|
||||
|
||||
@Test
|
||||
def afterResignGameShowsResignStatusAndWinner(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/resign")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("state.status", equalTo("resign"))
|
||||
.body("state.winner", notNullValue())
|
||||
|
||||
@Test
|
||||
def resignOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post("/api/board/game/XXXXXXXX/resign")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
|
||||
// ─── Draw ──────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
def offerDrawSetsDrawOfferedStatus(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/offer")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("ok", equalTo(true))
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("state.status", equalTo("drawOffered"))
|
||||
|
||||
@Test
|
||||
def acceptDrawAfterOfferSetsDrawStatus(): Unit =
|
||||
val gameId = createGame()
|
||||
// White offers
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/offer")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
// Black moves so it's black's turn... actually the API doesn't enforce turn-based draw accept.
|
||||
// White offered, so black (opponent) accepts — but since there's no auth, we just call accept.
|
||||
// The GameStore checks drawOfferedBy != turn to allow accept.
|
||||
// White offered on white's turn, so black needs to accept — but current turn is still white.
|
||||
// We need to make a move first to switch turns.
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
// Now it's black's turn and white offered the draw — black accepts
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/accept")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("ok", equalTo(true))
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("state.status", equalTo("draw"))
|
||||
|
||||
@Test
|
||||
def declineDrawClearsOffer(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/offer")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/decline")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("ok", equalTo(true))
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.get(s"/api/board/game/$gameId")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("state.status", equalTo("started"))
|
||||
|
||||
@Test
|
||||
def acceptWithoutOfferReturns400(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/draw/accept")
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
|
||||
@Test
|
||||
def drawOnUnknownGameReturns404(): Unit =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post("/api/board/game/XXXXXXXX/draw/offer")
|
||||
.`then`()
|
||||
.statusCode(404)
|
||||
@@ -1,88 +0,0 @@
|
||||
package de.nowchess.backcore.resource
|
||||
|
||||
import io.quarkus.test.junit.QuarkusTest
|
||||
import io.restassured.RestAssured
|
||||
import org.hamcrest.Matchers.{containsString, equalTo}
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
@QuarkusTest
|
||||
class UndoRedoTest:
|
||||
|
||||
private val initialFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
||||
|
||||
private def createGame(): String =
|
||||
RestAssured
|
||||
.`given`()
|
||||
.contentType("application/json")
|
||||
.body("{}")
|
||||
.when()
|
||||
.post("/api/board/game")
|
||||
.`then`()
|
||||
.statusCode(201)
|
||||
.extract()
|
||||
.path[String]("gameId")
|
||||
|
||||
@Test
|
||||
def undoAfterMoveRestoresOriginalPosition(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/undo")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("fen", equalTo(initialFen))
|
||||
.body("undoAvailable", equalTo(false))
|
||||
|
||||
@Test
|
||||
def redoAfterUndoRestoresMovedPosition(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/undo")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/redo")
|
||||
.`then`()
|
||||
.statusCode(200)
|
||||
.body("fen", containsString("4P3"))
|
||||
.body("turn", equalTo("black"))
|
||||
|
||||
@Test
|
||||
def undoWithNoHistoryReturns400(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/undo")
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
|
||||
@Test
|
||||
def redoWithNoRedoStackReturns400(): Unit =
|
||||
val gameId = createGame()
|
||||
RestAssured
|
||||
.`given`()
|
||||
.when()
|
||||
.post(s"/api/board/game/$gameId/redo")
|
||||
.`then`()
|
||||
.statusCode(400)
|
||||
@@ -0,0 +1,86 @@
|
||||
plugins {
|
||||
id("scala")
|
||||
id("org.scoverage")
|
||||
}
|
||||
|
||||
group = "de.nowchess"
|
||||
version = "1.0-SNAPSHOT"
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val versions = rootProject.extra["VERSIONS"] as Map<String, String>
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
scala {
|
||||
scalaVersion = versions["SCALA3"]!!
|
||||
}
|
||||
|
||||
scoverage {
|
||||
scoverageVersion.set(versions["SCOVERAGE"]!!)
|
||||
excludedPackages.set(
|
||||
listOf(
|
||||
"de\\.nowchess\\.bot\\.bots\\.NNUEBot",
|
||||
"de\\.nowchess\\.bot\\.bots\\.nnue\\..*",
|
||||
"de\\.nowchess\\.bot\\.util\\.PolyglotBook",
|
||||
)
|
||||
)
|
||||
excludedFiles.set(
|
||||
listOf(
|
||||
".*NNUE\\.scala",
|
||||
".*NNUEBot\\.scala",
|
||||
".*NbaiLoader\\.scala",
|
||||
".*NbaiMigrator\\.scala",
|
||||
".*NbaiWriter\\.scala",
|
||||
".*PolyglotBook\\.scala",
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
tasks.withType<ScalaCompile> {
|
||||
scalaCompileOptions.additionalParameters = listOf("-encoding", "UTF-8")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
||||
implementation("org.scala-lang:scala3-compiler_3") {
|
||||
version {
|
||||
strictly(versions["SCALA3"]!!)
|
||||
}
|
||||
}
|
||||
implementation("org.scala-lang:scala3-library_3") {
|
||||
version {
|
||||
strictly(versions["SCALA3"]!!)
|
||||
}
|
||||
}
|
||||
|
||||
implementation(project(":modules:api"))
|
||||
implementation(project(":modules:io"))
|
||||
implementation(project(":modules:rule"))
|
||||
implementation("com.microsoft.onnxruntime:onnxruntime:${versions["ONNXRUNTIME"]!!}")
|
||||
|
||||
testImplementation(platform("org.junit:junit-bom:${versions["JUNIT_BOM"]!!}"))
|
||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||
testImplementation("org.scalatest:scalatest_3:${versions["SCALATEST"]!!}")
|
||||
testImplementation("co.helmethair:scalatest-junit-runner:${versions["SCALATEST_JUNIT"]!!}")
|
||||
|
||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||
}
|
||||
|
||||
tasks.test {
|
||||
useJUnitPlatform {
|
||||
includeEngines("scalatest")
|
||||
testLogging {
|
||||
events("skipped", "failed")
|
||||
}
|
||||
}
|
||||
finalizedBy(tasks.reportScoverage)
|
||||
}
|
||||
tasks.reportScoverage {
|
||||
dependsOn(tasks.test)
|
||||
}
|
||||
|
||||
tasks.jar {
|
||||
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
# Data and weights are local artifacts, not committed
|
||||
data/
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
.venv
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
tactical_data/
|
||||
trainingdata/
|
||||
/datasets/
|
||||
@@ -0,0 +1,173 @@
|
||||
# Training Dataset Management
|
||||
|
||||
The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
datasets/
|
||||
ds_v1/
|
||||
labeled.jsonl # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
|
||||
metadata.json # Version info and composition
|
||||
ds_v2/
|
||||
labeled.jsonl
|
||||
metadata.json
|
||||
```
|
||||
|
||||
## Metadata Schema
|
||||
|
||||
Each dataset has a `metadata.json` file tracking its composition:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"created": "2026-04-13T15:30:45.123456",
|
||||
"total_positions": 1000000,
|
||||
"stockfish_depth": 12,
|
||||
"sources": [
|
||||
{
|
||||
"type": "generated",
|
||||
"count": 500000,
|
||||
"params": {
|
||||
"num_positions": 500000,
|
||||
"min_move": 1,
|
||||
"max_move": 50
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "tactical",
|
||||
"count": 300000,
|
||||
"max_puzzles": 300000
|
||||
},
|
||||
{
|
||||
"type": "file_import",
|
||||
"count": 200000,
|
||||
"path": "/path/to/original_file.txt"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## TUI Workflow
|
||||
|
||||
### Main Menu
|
||||
```
|
||||
1 - Manage Training Data
|
||||
2 - Train Model
|
||||
3 - Export Model
|
||||
4 - Exit
|
||||
```
|
||||
|
||||
### Training Data Management Submenu
|
||||
```
|
||||
1 - Create new dataset
|
||||
2 - Extend existing dataset
|
||||
3 - View all datasets
|
||||
4 - Delete dataset
|
||||
5 - Back
|
||||
```
|
||||
|
||||
## Creating a Dataset
|
||||
|
||||
Use the "Create new dataset" option to add data from one or more sources:
|
||||
|
||||
1. **Generate random positions** — Play random games and sample positions
|
||||
- Number of positions
|
||||
- Move range (min/max move number to sample from)
|
||||
- Number of worker threads
|
||||
|
||||
2. **Import from file** — Load positions from a FEN file
|
||||
- File must contain one FEN string per line
|
||||
- Duplicates are automatically removed
|
||||
|
||||
3. **Extract tactical puzzles** — Download and extract Lichess puzzle database
|
||||
- Maximum number of puzzles to include
|
||||
- Automatically filters for tactical themes (forks, pins, mates, etc.)
|
||||
|
||||
You can combine multiple sources in a single dataset creation session. All positions are:
|
||||
- Deduplicated (only unique FENs are kept)
|
||||
- Labeled with Stockfish evaluations
|
||||
- Saved to `datasets/ds_vN/labeled.jsonl`
|
||||
|
||||
## Extending a Dataset
|
||||
|
||||
Use "Extend existing dataset" to add more positions to an existing dataset:
|
||||
|
||||
1. Select the dataset version to extend
|
||||
2. Choose data sources (same options as creation)
|
||||
3. Confirm labeling parameters
|
||||
4. New positions are:
|
||||
- Labeled with Stockfish
|
||||
- Deduplicated against the target dataset (preventing duplicates)
|
||||
- Merged into the existing `labeled.jsonl`
|
||||
- Metadata is updated with the new source entry
|
||||
|
||||
## Training with a Dataset
|
||||
|
||||
When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
|
||||
- Version number
|
||||
- Total number of positions
|
||||
- Source types (generated, tactical, imported)
|
||||
- Stockfish depth used
|
||||
- Creation date
|
||||
|
||||
## Legacy Data Migration
|
||||
|
||||
If you have existing labeled data in `data/training_data.jsonl` from before this update:
|
||||
|
||||
1. Open the "Manage Training Data" menu
|
||||
2. Choose "Create new dataset"
|
||||
3. Select "Import from file"
|
||||
4. Point to `data/training_data.jsonl`
|
||||
5. Complete the dataset creation
|
||||
|
||||
Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
|
||||
|
||||
## Viewing Dataset Details
|
||||
|
||||
Use "View all datasets" to see a table of all datasets with:
|
||||
- Version number
|
||||
- Position count
|
||||
- Source composition
|
||||
- Stockfish depth
|
||||
- Creation date
|
||||
|
||||
## Deleting a Dataset
|
||||
|
||||
Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**
|
||||
|
||||
⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Deduplication Strategy
|
||||
|
||||
When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.
|
||||
|
||||
When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
|
||||
|
||||
### Labeled Position Format
|
||||
|
||||
Each line in `labeled.jsonl` is a JSON object:
|
||||
```json
|
||||
{
|
||||
"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
|
||||
"eval": 0.0,
|
||||
"eval_raw": 0
|
||||
}
|
||||
```
|
||||
|
||||
- `fen`: The position in Forsyth-Edwards Notation
|
||||
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
|
||||
- `eval_raw`: Raw Stockfish evaluation in centipawns
|
||||
|
||||
### Storage Location
|
||||
|
||||
Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.
|
||||
|
||||
## Performance Tips
|
||||
|
||||
- **Smaller datasets train faster** — Start with 100k-500k positions
|
||||
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
|
||||
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
|
||||
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory
|
||||
@@ -0,0 +1,129 @@
|
||||
# NNUE Python Pipeline
|
||||
|
||||
Central CLI for training and exporting chess evaluation neural networks (NNUE).
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
python/
|
||||
├── nnue.py # Main CLI entry point
|
||||
├── src/ # Python modules
|
||||
│ ├── generate.py # Generate random chess positions
|
||||
│ ├── label.py # Label positions with Stockfish
|
||||
│ ├── train.py # Train NNUE model
|
||||
│ └── export.py # Export weights to Scala
|
||||
├── data/ # Training data (gitignored)
|
||||
│ ├── positions.txt
|
||||
│ └── training_data.jsonl
|
||||
└── weights/ # Model weights (gitignored)
|
||||
├── nnue_weights_v1.pt
|
||||
├── nnue_weights_v1_metadata.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Train a new model (500k positions, auto-detect checkpoint)
|
||||
python nnue.py train
|
||||
|
||||
# Train from specific checkpoint
|
||||
python nnue.py train --from-checkpoint 2
|
||||
|
||||
# Train with custom games count
|
||||
python nnue.py train --games 200000
|
||||
|
||||
# Train with custom positions file
|
||||
python nnue.py train --positions-file my_positions.txt
|
||||
|
||||
# Export specific version to Scala
|
||||
python nnue.py export 2
|
||||
|
||||
# List all checkpoints
|
||||
python nnue.py list
|
||||
```
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### `train` - Train NNUE model
|
||||
|
||||
```bash
|
||||
python nnue.py train [OPTIONS]
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--from-checkpoint N` - Resume from checkpoint version N (default: uses latest)
|
||||
- `--games N` - Number of games to generate (default: 500000)
|
||||
- `--positions-file FILE` - Use existing positions file instead of generating
|
||||
- `--stockfish PATH` - Path to Stockfish binary (default: `$STOCKFISH_PATH` or `/usr/games/stockfish`)
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
# Train with latest checkpoint
|
||||
python nnue.py train
|
||||
|
||||
# Train from v2 with 100k games
|
||||
python nnue.py train --from-checkpoint 2 --games 100000
|
||||
|
||||
# Train with custom positions
|
||||
python nnue.py train --positions-file my_games.txt --stockfish /opt/stockfish/sf15
|
||||
```
|
||||
|
||||
### `export` - Export weights to Scala
|
||||
|
||||
```bash
|
||||
python nnue.py export WEIGHTS [output_path]
|
||||
```
|
||||
|
||||
**Arguments:**
|
||||
- `WEIGHTS` - Version number (e.g., `2`) or full filename (e.g., `nnue_weights_v2.pt`)
|
||||
|
||||
**Examples:**
|
||||
```bash
|
||||
# Export version 2
|
||||
python nnue.py export 2
|
||||
|
||||
# Export with full filename
|
||||
python nnue.py export nnue_weights_v3.pt
|
||||
```
|
||||
|
||||
Output goes to `../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_vN.scala`
|
||||
|
||||
### `list` - List available checkpoints
|
||||
|
||||
```bash
|
||||
python nnue.py list
|
||||
```
|
||||
|
||||
Shows all available model versions with file sizes.
|
||||
|
||||
## Data Flow
|
||||
|
||||
1. **Generate** → `data/positions.txt`
|
||||
- Random chess positions from 8-20 move openings
|
||||
- Filters out checks, game-over states, and captures
|
||||
|
||||
2. **Label** → `data/training_data.jsonl`
|
||||
- Evaluates each position with Stockfish at depth 12
|
||||
- Stores FEN + evaluation in JSONL format
|
||||
|
||||
3. **Train** → `weights/nnue_weights_vN.pt`
|
||||
- Trains neural network on labeled positions
|
||||
- Auto-versioning (v1, v2, v3, etc.)
|
||||
- Saves metadata alongside weights
|
||||
|
||||
4. **Export** → `NNUEWeights_vN.scala`
|
||||
- Converts weights to Scala object
|
||||
- Ready for integration into bot
|
||||
|
||||
## Versioning
|
||||
|
||||
- Models are automatically versioned (v1, v2, v3, etc.)
|
||||
- Each version gets a `_metadata.json` file with training info
|
||||
- Training from checkpoint uses latest version unless specified with `--from-checkpoint`
|
||||
|
||||
## Files
|
||||
|
||||
- `data/` and `weights/` are gitignored (local artifacts)
|
||||
- Documentation in `docs/` explains training, debugging, and incremental improvements
|
||||
- Source modules in `src/` are independent and can be imported for custom workflows
|
||||
@@ -0,0 +1,951 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich import print as rprint
|
||||
|
||||
# Add src directory to path so we can import modules
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
from label import label_positions_with_stockfish
|
||||
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
|
||||
from export import export_to_nbai
|
||||
from tactical_positions_extractor import (
|
||||
download_and_extract_puzzle_db,
|
||||
extract_tactical_only
|
||||
)
|
||||
from lichess_importer import import_lichess_evals
|
||||
from dataset import (
|
||||
get_datasets_dir,
|
||||
list_datasets,
|
||||
next_dataset_version,
|
||||
load_dataset_metadata,
|
||||
create_dataset,
|
||||
extend_dataset,
|
||||
get_dataset_labeled_path,
|
||||
delete_dataset,
|
||||
show_datasets_table
|
||||
)
|
||||
|
||||
|
||||
def get_weights_dir():
|
||||
"""Get/create weights directory."""
|
||||
weights_dir = Path(__file__).parent / "weights"
|
||||
weights_dir.mkdir(exist_ok=True)
|
||||
return weights_dir
|
||||
|
||||
|
||||
def get_data_dir():
|
||||
"""Get/create legacy data directory (for migration)."""
|
||||
data_dir = Path(__file__).parent / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir
|
||||
|
||||
|
||||
def list_checkpoints():
|
||||
"""List available checkpoint versions."""
|
||||
weights_dir = get_weights_dir()
|
||||
checkpoints = sorted(weights_dir.glob("nnue_weights_v*.pt"))
|
||||
if not checkpoints:
|
||||
return []
|
||||
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||
|
||||
|
||||
def migrate_legacy_data():
|
||||
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
|
||||
console = Console()
|
||||
data_dir = get_data_dir()
|
||||
legacy_file = data_dir / "training_data.jsonl"
|
||||
datasets = list_datasets()
|
||||
|
||||
# Only migrate if legacy data exists and no datasets exist yet
|
||||
if legacy_file.exists() and not datasets:
|
||||
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
|
||||
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
|
||||
|
||||
|
||||
def show_header():
|
||||
"""Display application header."""
|
||||
console = Console()
|
||||
console.clear()
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def show_checkpoints_table():
|
||||
"""Display available checkpoints in a table."""
|
||||
console = Console()
|
||||
available = list_checkpoints()
|
||||
|
||||
if not available:
|
||||
console.print("[yellow]ℹ No model checkpoints found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("File Size", justify="right")
|
||||
table.add_column("Status", justify="center")
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
for v in sorted(available):
|
||||
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||
if weights_file.exists():
|
||||
size = weights_file.stat().st_size / (1024**2)
|
||||
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||
else:
|
||||
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def show_main_menu():
|
||||
"""Display and handle main menu."""
|
||||
console = Console()
|
||||
|
||||
# Migrate legacy data on first run
|
||||
migrate_legacy_data()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
|
||||
console.print("\n[bold]What would you like to do?[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Manage Training Data")
|
||||
console.print("[cyan]2[/cyan] - Train Model")
|
||||
console.print("[cyan]3[/cyan] - Export Model")
|
||||
console.print("[cyan]4[/cyan] - Exit")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
datasets_menu()
|
||||
elif choice == "2":
|
||||
training_menu()
|
||||
elif choice == "3":
|
||||
export_interactive()
|
||||
elif choice == "4":
|
||||
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||
return
|
||||
|
||||
|
||||
def datasets_menu():
|
||||
"""Dataset management submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
|
||||
console.print("\n[bold]Training Data Management[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Create new dataset")
|
||||
console.print("[cyan]2[/cyan] - Extend existing dataset")
|
||||
console.print("[cyan]3[/cyan] - View all datasets")
|
||||
console.print("[cyan]4[/cyan] - Delete dataset")
|
||||
console.print("[cyan]5[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||
|
||||
if choice == "1":
|
||||
create_dataset_interactive()
|
||||
elif choice == "2":
|
||||
extend_dataset_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_datasets_table(console)
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
delete_dataset_interactive()
|
||||
elif choice == "5":
|
||||
return
|
||||
|
||||
|
||||
def create_dataset_interactive():
|
||||
"""Interactive dataset creation flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add multiple sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source (repeat until done):[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Generation failed[/red]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||
return
|
||||
|
||||
# Determine whether any sources still need Stockfish labeling.
|
||||
# Lichess sources are already labeled; only generated/tactical/file sources need it.
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||
console.print(f" Total positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to create dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
# --- Step 1: Collect already-labeled data (Lichess source) ---
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# --- Step 3: Create dataset ---
|
||||
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||
version = next_dataset_version()
|
||||
create_dataset(
|
||||
version=version,
|
||||
labeled_jsonl_path=str(labeled_file),
|
||||
sources=sources,
|
||||
stockfish_depth=stockfish_depth,
|
||||
)
|
||||
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def extend_dataset_interactive():
|
||||
"""Interactive dataset extension flow."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets available to extend[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
sources = []
|
||||
combined_count = 0
|
||||
|
||||
# Allow user to add sources
|
||||
while True:
|
||||
console.print("\n[bold]Add data source:[/bold]")
|
||||
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||
console.print("[cyan]b[/cyan] - Import from file")
|
||||
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||
|
||||
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||
|
||||
if choice == "a":
|
||||
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||
|
||||
console.print("[dim]Generating positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
count = play_random_game_and_collect_positions(
|
||||
str(temp_file),
|
||||
total_positions=num_positions,
|
||||
samples_per_game=1,
|
||||
min_move=min_move,
|
||||
max_move=max_move,
|
||||
num_workers=num_workers
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "generated",
|
||||
"count": count,
|
||||
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||
|
||||
elif choice == "b":
|
||||
file_path = Prompt.ask("Path to FEN file")
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
count = sum(1 for _ in f)
|
||||
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||
|
||||
elif choice == "c":
|
||||
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
try:
|
||||
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||
if csv_path:
|
||||
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||
|
||||
elif choice == "d":
|
||||
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||
max_pos = int(max_pos) if max_pos.strip() else None
|
||||
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
temp_file.unlink(missing_ok=True)
|
||||
try:
|
||||
count = import_lichess_evals(
|
||||
input_path=zst_path,
|
||||
output_file=str(temp_file),
|
||||
max_positions=max_pos,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
if count > 0:
|
||||
sources.append({
|
||||
"type": "lichess",
|
||||
"count": count,
|
||||
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||
})
|
||||
combined_count += count
|
||||
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||
else:
|
||||
console.print("[red]✗ No positions imported[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||
|
||||
elif choice == "e":
|
||||
if not sources:
|
||||
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||
continue
|
||||
break
|
||||
|
||||
if not sources:
|
||||
console.print("[yellow]Extension cancelled[/yellow]")
|
||||
return
|
||||
|
||||
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||
|
||||
stockfish_depth = 12
|
||||
if needs_labeling:
|
||||
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||
stockfish_path = Prompt.ask(
|
||||
"Stockfish path",
|
||||
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||
)
|
||||
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||
|
||||
# Summary and confirm
|
||||
console.print("\n[bold]Extension Summary:[/bold]")
|
||||
console.print(f" Target dataset: ds_v{version}")
|
||||
console.print(f" New positions: {combined_count:,}")
|
||||
for source in sources:
|
||||
console.print(f" - {source['type']}: {source['count']:,}")
|
||||
if needs_labeling:
|
||||
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||
|
||||
if not Confirm.ask("\nProceed to extend dataset?", default=True):
|
||||
console.print("[yellow]Cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||
labeled_file.unlink(missing_ok=True)
|
||||
|
||||
# Copy pre-labeled Lichess data if present
|
||||
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||
if lichess_tmp.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy(lichess_tmp, labeled_file)
|
||||
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||
|
||||
# Combine and label remaining sources with Stockfish
|
||||
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||
if non_lichess:
|
||||
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||
all_fens = set()
|
||||
|
||||
for source in non_lichess:
|
||||
if source["type"] == "generated":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||
elif source["type"] == "file_import":
|
||||
temp_file = Path(source["path"])
|
||||
elif source["type"] == "tactical":
|
||||
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||
else:
|
||||
continue
|
||||
if temp_file.exists():
|
||||
with open(temp_file, "r") as f:
|
||||
for line in f:
|
||||
fen = line.strip()
|
||||
if fen:
|
||||
all_fens.add(fen)
|
||||
|
||||
with open(combined_fen_file, "w") as f:
|
||||
for fen in all_fens:
|
||||
f.write(fen + "\n")
|
||||
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||
|
||||
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||
success = label_positions_with_stockfish(
|
||||
str(combined_fen_file),
|
||||
str(labeled_file),
|
||||
stockfish_path,
|
||||
depth=stockfish_depth,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
if not success:
|
||||
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||
return
|
||||
console.print("[green]✓ Positions labeled[/green]")
|
||||
|
||||
# Extend dataset
|
||||
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||
success = extend_dataset(
|
||||
version=version,
|
||||
new_labeled_path=str(labeled_file),
|
||||
new_source_entry={
|
||||
"type": "merged_sources",
|
||||
"count": combined_count,
|
||||
"sources": sources,
|
||||
}
|
||||
)
|
||||
|
||||
if success:
|
||||
metadata = load_dataset_metadata(version)
|
||||
console.print(f"[green]✓ Dataset extended[/green]")
|
||||
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
|
||||
else:
|
||||
console.print("[red]✗ Extension failed[/red]")
|
||||
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def delete_dataset_interactive():
|
||||
"""Interactive dataset deletion."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
|
||||
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets to delete[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
show_datasets_table(console)
|
||||
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
|
||||
|
||||
if not any(v == version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
|
||||
if delete_dataset(version):
|
||||
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
|
||||
else:
|
||||
console.print("[red]✗ Deletion failed[/red]")
|
||||
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def training_menu():
|
||||
"""Training submenu."""
|
||||
console = Console()
|
||||
|
||||
while True:
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold]Training[/bold]")
|
||||
console.print("[cyan]1[/cyan] - Standard Training")
|
||||
console.print("[cyan]2[/cyan] - Burst Training")
|
||||
console.print("[cyan]3[/cyan] - View Model Checkpoints")
|
||||
console.print("[cyan]4[/cyan] - Back")
|
||||
|
||||
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||
|
||||
if choice == "1":
|
||||
train_interactive()
|
||||
elif choice == "2":
|
||||
burst_train_interactive()
|
||||
elif choice == "3":
|
||||
show_header()
|
||||
show_checkpoints_table()
|
||||
Prompt.ask("\nPress Enter to continue")
|
||||
elif choice == "4":
|
||||
return
|
||||
|
||||
|
||||
def train_interactive():
|
||||
"""Interactive training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("\n[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
# Checkpoint selection
|
||||
available = list_checkpoints()
|
||||
use_checkpoint = False
|
||||
checkpoint_version = None
|
||||
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
|
||||
if use_checkpoint:
|
||||
checkpoint_version = Prompt.ask(
|
||||
"Enter checkpoint version",
|
||||
default=str(max(available))
|
||||
)
|
||||
|
||||
# Training parameters
|
||||
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
early_stopping = None
|
||||
if Confirm.ask("Enable early stopping?", default=False):
|
||||
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Confirm and start
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Epochs: {epochs}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
if early_stopping:
|
||||
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||
else:
|
||||
console.print(f" Early stopping: No")
|
||||
if use_checkpoint:
|
||||
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||
else:
|
||||
console.print(f" Checkpoint: None (training from scratch)")
|
||||
|
||||
if not Confirm.ask("\nStart training?", default=True):
|
||||
console.print("[yellow]Training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
# Execute training
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Training Model[/bold cyan]")
|
||||
checkpoint = None
|
||||
if use_checkpoint:
|
||||
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||
|
||||
train_nnue(
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
early_stopping_patience=early_stopping,
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Training complete[/green]")
|
||||
|
||||
# Show result
|
||||
available = list_checkpoints()
|
||||
new_version = max(available) if available else 1
|
||||
console.print(f"\n[bold green]✓ Training successful![/bold green]")
|
||||
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def burst_train_interactive():
|
||||
"""Interactive burst training menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||
|
||||
# Dataset selection
|
||||
datasets = list_datasets()
|
||||
if not datasets:
|
||||
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print("[bold]Available Datasets:[/bold]")
|
||||
show_datasets_table(console)
|
||||
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||
|
||||
if not any(v == dataset_version for v, _ in datasets):
|
||||
console.print("[red]✗ Dataset not found[/red]")
|
||||
return
|
||||
|
||||
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||
if not labeled_file:
|
||||
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||
return
|
||||
|
||||
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||
|
||||
# Optional initial checkpoint
|
||||
available = list_checkpoints()
|
||||
checkpoint = None
|
||||
if available:
|
||||
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
if Confirm.ask("Start from an existing checkpoint?", default=False):
|
||||
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
|
||||
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
|
||||
|
||||
# Training hyperparameters
|
||||
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||
hidden_layers_str = Prompt.ask(
|
||||
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||
default=default_layers
|
||||
)
|
||||
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||
|
||||
# Summary
|
||||
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||
console.print(f" Dataset: ds_v{dataset_version}")
|
||||
console.print(f" Architecture: {arch_str}")
|
||||
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||
console.print(f" Epochs per season: {epochs_per_season}")
|
||||
console.print(f" Patience: {early_stopping_patience}")
|
||||
console.print(f" Batch size: {batch_size}")
|
||||
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||
|
||||
if not Confirm.ask("\nStart burst training?", default=True):
|
||||
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
weights_dir = get_weights_dir()
|
||||
|
||||
try:
|
||||
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||
burst_train(
|
||||
data_file=str(labeled_file),
|
||||
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||
duration_minutes=duration_minutes,
|
||||
epochs_per_season=epochs_per_season,
|
||||
early_stopping_patience=early_stopping_patience,
|
||||
batch_size=batch_size,
|
||||
initial_checkpoint=checkpoint,
|
||||
use_versioning=True,
|
||||
subsample_ratio=subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
console.print("[green]✓ Burst training complete[/green]")
|
||||
|
||||
available = list_checkpoints()
|
||||
if available:
|
||||
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def export_interactive():
|
||||
"""Interactive export menu."""
|
||||
console = Console()
|
||||
show_header()
|
||||
|
||||
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
|
||||
|
||||
# Select weights version
|
||||
available = list_checkpoints()
|
||||
if not available:
|
||||
console.print("[red]✗ No checkpoints available to export[/red]")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
return
|
||||
|
||||
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||
|
||||
weights_file = f"nnue_weights_v{version}.pt"
|
||||
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
|
||||
|
||||
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||
console.print(f" Source: {weights_file}")
|
||||
console.print(f" Destination: {output_file}")
|
||||
|
||||
if not Confirm.ask("\nExport weights?", default=True):
|
||||
console.print("[yellow]Export cancelled[/yellow]")
|
||||
return
|
||||
|
||||
try:
|
||||
weights_dir = get_weights_dir()
|
||||
weights_path = weights_dir / weights_file
|
||||
|
||||
if not weights_path.exists():
|
||||
console.print(f"[red]✗ {weights_file} not found[/red]")
|
||||
return
|
||||
|
||||
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||
export_to_nbai(str(weights_path), output_file)
|
||||
console.print(f"\n[green]✓ Export complete![/green]")
|
||||
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error: {e}[/red]")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Prompt.ask("Press Enter to continue")
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
show_main_menu()
|
||||
return 0
|
||||
except KeyboardInterrupt:
|
||||
console = Console()
|
||||
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console = Console()
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,6 @@
|
||||
chess==1.11.2
|
||||
torch==2.11.0
|
||||
tqdm==4.67.3
|
||||
numpy==2.4.4
|
||||
rich==13.7.0
|
||||
zstandard==0.23.0
|
||||
@@ -0,0 +1,66 @@
|
||||
@echo off
|
||||
REM NNUE Training Pipeline for Windows
|
||||
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
echo.
|
||||
echo === NNUE Training Pipeline ===
|
||||
echo.
|
||||
|
||||
REM Get the directory where this script is located
|
||||
set SCRIPT_DIR=%~dp0
|
||||
|
||||
cd /d "%SCRIPT_DIR%"
|
||||
|
||||
REM Step 1: Generate positions
|
||||
echo Step 1: Generating 500,000 random positions...
|
||||
python generate_positions.py positions.txt
|
||||
if not exist positions.txt (
|
||||
echo ERROR: positions.txt not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Positions generated
|
||||
echo.
|
||||
|
||||
REM Step 2: Label positions with Stockfish
|
||||
echo Step 2: Labeling positions with Stockfish (depth 12^)...
|
||||
if "%STOCKFISH_PATH%"=="" (
|
||||
set STOCKFISH_PATH=stockfish
|
||||
)
|
||||
python label_positions.py positions.txt training_data.jsonl "%STOCKFISH_PATH%"
|
||||
if not exist training_data.jsonl (
|
||||
echo ERROR: training_data.jsonl not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Positions labeled
|
||||
echo.
|
||||
|
||||
REM Step 3: Train NNUE model
|
||||
echo Step 3: Training NNUE model (20 epochs^)...
|
||||
python train_nnue.py training_data.jsonl nnue_weights.pt
|
||||
if not exist nnue_weights.pt (
|
||||
echo ERROR: nnue_weights.pt not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Model trained
|
||||
echo.
|
||||
|
||||
REM Step 4: Export weights to Scala
|
||||
echo Step 4: Exporting weights to Scala...
|
||||
python export_weights.py nnue_weights.pt ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala
|
||||
if not exist ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala (
|
||||
echo ERROR: NNUEWeights.scala not created
|
||||
exit /b 1
|
||||
)
|
||||
echo [OK] Weights exported
|
||||
echo.
|
||||
|
||||
echo === Pipeline Complete ===
|
||||
echo.
|
||||
echo Next steps:
|
||||
echo 1. Navigate to project root: cd ..\..
|
||||
echo 2. Compile: .\compile.bat
|
||||
echo 3. Test: .\test.bat
|
||||
echo.
|
||||
|
||||
endlocal
|
||||
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
|
||||
# NNUE Training Pipeline (bash version)
|
||||
# Uses the central CLI (nnue.py) for all operations
|
||||
# Works on Linux, macOS, and Windows (with Git Bash or WSL)
|
||||
|
||||
set -e # Exit on error
|
||||
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
# Use python or python3 (check which is available)
|
||||
PYTHON_CMD="python3"
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
echo "=== NNUE Training Pipeline ==="
|
||||
echo ""
|
||||
echo "Python command: $PYTHON_CMD"
|
||||
echo "Working directory: $SCRIPT_DIR"
|
||||
echo ""
|
||||
|
||||
# Run the unified training pipeline
|
||||
$PYTHON_CMD nnue.py train
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo ""
|
||||
echo "ERROR: Training pipeline failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== Pipeline Complete ==="
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo "1. Navigate to project root: cd ../.."
|
||||
echo "2. Compile: ./compile"
|
||||
echo "3. Test: ./test"
|
||||
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Dataset versioning and management for NNUE training data."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
||||
def get_datasets_dir() -> Path:
|
||||
"""Get/create datasets directory."""
|
||||
datasets_dir = Path(__file__).parent.parent / "datasets"
|
||||
datasets_dir.mkdir(exist_ok=True)
|
||||
return datasets_dir
|
||||
|
||||
|
||||
def next_dataset_version() -> int:
|
||||
"""Find the next available dataset version number."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
versions = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
versions.append(v)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return max(versions) + 1 if versions else 1
|
||||
|
||||
|
||||
def list_datasets() -> List[Tuple[int, Dict]]:
|
||||
"""List all datasets with their metadata.
|
||||
|
||||
Returns:
|
||||
List of (version, metadata_dict) tuples, sorted by version.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
datasets = []
|
||||
|
||||
for d in datasets_dir.iterdir():
|
||||
if d.is_dir() and d.name.startswith("ds_v"):
|
||||
try:
|
||||
v = int(d.name.split("_v")[1])
|
||||
metadata_file = d / "metadata.json"
|
||||
if metadata_file.exists():
|
||||
with open(metadata_file, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
datasets.append((v, metadata))
|
||||
except (ValueError, IndexError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
return sorted(datasets, key=lambda x: x[0])
|
||||
|
||||
|
||||
def load_dataset_metadata(version: int) -> Optional[Dict]:
|
||||
"""Load metadata for a specific dataset version.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None if not found.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
|
||||
|
||||
if not metadata_file.exists():
|
||||
return None
|
||||
|
||||
with open(metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_dataset_metadata(version: int, metadata: Dict) -> None:
|
||||
"""Save metadata for a dataset version."""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
metadata_file = dataset_dir / "metadata.json"
|
||||
with open(metadata_file, 'w') as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
|
||||
def create_dataset(
|
||||
version: int,
|
||||
labeled_jsonl_path: str,
|
||||
sources: List[Dict],
|
||||
stockfish_depth: int = 12
|
||||
) -> Path:
|
||||
"""Create a new versioned dataset.
|
||||
|
||||
Args:
|
||||
version: Dataset version number
|
||||
labeled_jsonl_path: Path to labeled.jsonl to copy
|
||||
sources: List of source dicts (see plan for schema)
|
||||
stockfish_depth: Depth used for labeling
|
||||
|
||||
Returns:
|
||||
Path to the created dataset directory.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
dataset_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Copy labeled data with deduplication (in case source has duplicates)
|
||||
source_path = Path(labeled_jsonl_path)
|
||||
if source_path.exists():
|
||||
dest_path = dataset_dir / "labeled.jsonl"
|
||||
seen_fens = set()
|
||||
unique_count = 0
|
||||
|
||||
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
|
||||
for line in src:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in seen_fens:
|
||||
dst.write(line)
|
||||
seen_fens.add(fen)
|
||||
unique_count += 1
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed lines
|
||||
pass
|
||||
|
||||
# Count positions
|
||||
total_positions = 0
|
||||
if (dataset_dir / "labeled.jsonl").exists():
|
||||
with open(dataset_dir / "labeled.jsonl", 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"version": version,
|
||||
"created": datetime.now().isoformat(),
|
||||
"total_positions": total_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"sources": sources
|
||||
}
|
||||
|
||||
save_dataset_metadata(version, metadata)
|
||||
return dataset_dir
|
||||
|
||||
|
||||
def extend_dataset(
|
||||
version: int,
|
||||
new_labeled_path: str,
|
||||
new_source_entry: Dict
|
||||
) -> bool:
|
||||
"""Extend an existing dataset with new labeled positions (with deduplication).
|
||||
|
||||
Args:
|
||||
version: Dataset version to extend
|
||||
new_labeled_path: Path to new labeled.jsonl to merge
|
||||
new_source_entry: Source entry to add to metadata
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
labeled_file = dataset_dir / "labeled.jsonl"
|
||||
new_labeled_file = Path(new_labeled_path)
|
||||
|
||||
if not new_labeled_file.exists():
|
||||
return False
|
||||
|
||||
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
|
||||
existing_fens = set()
|
||||
if labeled_file.exists():
|
||||
with open(labeled_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen:
|
||||
existing_fens.add(fen)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Merge new positions, skipping duplicates
|
||||
new_count = 0
|
||||
new_lines = []
|
||||
with open(new_labeled_file, 'r') as f_new:
|
||||
for line in f_new:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data.get('fen')
|
||||
if fen and fen not in existing_fens:
|
||||
new_lines.append(line)
|
||||
existing_fens.add(fen)
|
||||
new_count += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Append only the new, unique positions
|
||||
if new_lines:
|
||||
with open(labeled_file, 'a') as f_append:
|
||||
for line in new_lines:
|
||||
f_append.write(line)
|
||||
|
||||
# Update metadata
|
||||
metadata = load_dataset_metadata(version)
|
||||
if metadata:
|
||||
# Count total positions
|
||||
total_positions = 0
|
||||
with open(labeled_file, 'r') as f:
|
||||
total_positions = sum(1 for _ in f)
|
||||
|
||||
metadata['total_positions'] = total_positions
|
||||
# Update the source entry with actual count of new positions added
|
||||
new_source_entry['actual_count'] = new_count
|
||||
metadata['sources'].append(new_source_entry)
|
||||
save_dataset_metadata(version, metadata)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_dataset_labeled_path(version: int) -> Optional[Path]:
|
||||
"""Get the path to a dataset's labeled.jsonl file.
|
||||
|
||||
Returns:
|
||||
Path to labeled.jsonl or None if dataset doesn't exist.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
|
||||
|
||||
if labeled_file.exists():
|
||||
return labeled_file
|
||||
return None
|
||||
|
||||
|
||||
def delete_dataset(version: int) -> bool:
|
||||
"""Delete a dataset (recursively removes directory).
|
||||
|
||||
Args:
|
||||
version: Dataset version to delete
|
||||
|
||||
Returns:
|
||||
True if successful.
|
||||
"""
|
||||
datasets_dir = get_datasets_dir()
|
||||
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||
|
||||
if not dataset_dir.exists():
|
||||
return False
|
||||
|
||||
import shutil
|
||||
shutil.rmtree(dataset_dir)
|
||||
return True
|
||||
|
||||
|
||||
def show_datasets_table(console: Console = None) -> None:
|
||||
"""Display all datasets in a Rich table."""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
datasets = list_datasets()
|
||||
|
||||
if not datasets:
|
||||
console.print("[yellow]ℹ No datasets found yet[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Positions", justify="right")
|
||||
table.add_column("Sources", justify="left")
|
||||
table.add_column("Depth", justify="center")
|
||||
table.add_column("Created", justify="left")
|
||||
|
||||
for v, metadata in datasets:
|
||||
positions = metadata.get('total_positions', 0)
|
||||
sources = metadata.get('sources', [])
|
||||
source_str = ", ".join([s.get('type', '?') for s in sources])
|
||||
depth = metadata.get('stockfish_depth', '?')
|
||||
created = metadata.get('created', '?')
|
||||
if created != '?':
|
||||
created = created.split('T')[0] # Just the date
|
||||
|
||||
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
|
||||
|
||||
console.print(table)
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
|
||||
VERSION = 1
|
||||
|
||||
|
||||
def _read_sidecar(weights_file: str) -> dict:
|
||||
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||
if Path(sidecar).exists():
|
||||
with open(sidecar) as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||
"""Derive layer descriptors from state_dict weight shapes.
|
||||
|
||||
Assumes layers named l1, l2, ..., lN.
|
||||
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||
"""
|
||||
names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
layers = []
|
||||
for i, name in enumerate(names):
|
||||
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||
activation = "linear" if i == len(names) - 1 else "relu"
|
||||
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||
return layers
|
||||
|
||||
|
||||
def _write_floats(f, tensor):
|
||||
data = tensor.float().flatten().cpu().numpy()
|
||||
f.write(struct.pack("<I", len(data)))
|
||||
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||
|
||||
|
||||
def export_to_nbai(
|
||||
weights_file: str,
|
||||
output_file: str,
|
||||
trained_by: str = "unknown",
|
||||
train_loss: float = 0.0,
|
||||
):
|
||||
if not Path(weights_file).exists():
|
||||
print(f"Error: weights file not found at {weights_file}")
|
||||
sys.exit(1)
|
||||
|
||||
loaded = torch.load(weights_file, map_location="cpu")
|
||||
state_dict = (
|
||||
loaded["model_state_dict"]
|
||||
if isinstance(loaded, dict) and "model_state_dict" in loaded
|
||||
else loaded
|
||||
)
|
||||
|
||||
sidecar = _read_sidecar(weights_file)
|
||||
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
|
||||
trained_at = sidecar.get("date", datetime.now().isoformat())
|
||||
training_data_count = int(sidecar.get("num_positions", 0))
|
||||
|
||||
metadata = {
|
||||
"trainedBy": trained_by,
|
||||
"trainedAt": trained_at,
|
||||
"trainingDataCount": training_data_count,
|
||||
"valLoss": val_loss,
|
||||
"trainLoss": train_loss,
|
||||
}
|
||||
|
||||
layers = _infer_layers(state_dict)
|
||||
layer_names = sorted(
|
||||
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||
key=lambda n: int(n[1:]),
|
||||
)
|
||||
|
||||
print(f"Architecture ({len(layers)} layers):")
|
||||
for i, l in enumerate(layers):
|
||||
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
|
||||
|
||||
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
# Header
|
||||
f.write(struct.pack("<I", MAGIC))
|
||||
f.write(struct.pack("<H", VERSION))
|
||||
|
||||
# Metadata (length-prefixed UTF-8 JSON)
|
||||
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
|
||||
f.write(struct.pack("<I", len(meta_bytes)))
|
||||
f.write(meta_bytes)
|
||||
|
||||
# Layer descriptors
|
||||
f.write(struct.pack("<H", len(layers)))
|
||||
for layer in layers:
|
||||
name_bytes = layer["activation"].encode("ascii")
|
||||
f.write(struct.pack("<B", len(name_bytes)))
|
||||
f.write(name_bytes)
|
||||
f.write(struct.pack("<I", layer["inputSize"]))
|
||||
f.write(struct.pack("<I", layer["outputSize"]))
|
||||
|
||||
# Weights: weight tensor then bias tensor per layer
|
||||
for name in layer_names:
|
||||
w = state_dict[f"{name}.weight"]
|
||||
b = state_dict[f"{name}.bias"]
|
||||
_write_floats(f, w)
|
||||
_write_floats(f, b)
|
||||
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
|
||||
|
||||
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
|
||||
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights_file = "nnue_weights.pt"
|
||||
output_file = "../src/main/resources/nnue_weights.nbai"
|
||||
trained_by = "unknown"
|
||||
train_loss = 0.0
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
weights_file = sys.argv[1]
|
||||
if len(sys.argv) > 2:
|
||||
output_file = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
trained_by = sys.argv[3]
|
||||
if len(sys.argv) > 4:
|
||||
train_loss = float(sys.argv[4])
|
||||
|
||||
export_to_nbai(weights_file, output_file, trained_by, train_loss)
|
||||
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate random chess positions for NNUE training with multiprocessing."""
|
||||
|
||||
import chess
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool, Queue
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
|
||||
"""Generate games for one worker.
|
||||
|
||||
Returns:
|
||||
list of FENs generated by this worker
|
||||
"""
|
||||
positions = []
|
||||
|
||||
for game_num in range(games_per_worker):
|
||||
board = chess.Board()
|
||||
move_history = []
|
||||
|
||||
# Play a complete random game
|
||||
while not board.is_game_over() and len(move_history) < 200:
|
||||
legal_moves = list(board.legal_moves)
|
||||
if not legal_moves:
|
||||
break
|
||||
move = random.choice(legal_moves)
|
||||
board.push(move)
|
||||
move_history.append(board.copy())
|
||||
|
||||
# Determine the range of moves to sample from
|
||||
game_length = len(move_history)
|
||||
valid_start = max(min_move, 0)
|
||||
valid_end = min(max_move, game_length)
|
||||
|
||||
if valid_start >= valid_end:
|
||||
continue
|
||||
|
||||
# Randomly sample positions from this game
|
||||
sample_count = min(samples_per_game, valid_end - valid_start)
|
||||
if sample_count > 0:
|
||||
sample_indices = random.sample(
|
||||
range(valid_start, valid_end),
|
||||
k=sample_count
|
||||
)
|
||||
|
||||
for idx in sample_indices:
|
||||
sampled_board = move_history[idx]
|
||||
|
||||
# Only filter truly invalid or terminal positions
|
||||
if not sampled_board.is_valid() or sampled_board.is_game_over():
|
||||
continue
|
||||
|
||||
# Save position (include check, captures, all positions)
|
||||
fen = sampled_board.fen()
|
||||
positions.append(fen)
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def play_random_game_and_collect_positions(
|
||||
output_file,
|
||||
total_positions=3000000,
|
||||
samples_per_game=1,
|
||||
min_move=1,
|
||||
max_move=50,
|
||||
num_workers=8
|
||||
):
|
||||
"""Generate positions using multiprocessing with multiple workers.
|
||||
|
||||
Args:
|
||||
output_file: Output file for positions
|
||||
total_positions: Target number of positions to generate
|
||||
samples_per_game: Number of positions to sample per game (1-N)
|
||||
min_move: Minimum move number to start sampling from
|
||||
max_move: Maximum move number for sampling
|
||||
num_workers: Number of parallel worker processes
|
||||
|
||||
Returns:
|
||||
Number of valid positions saved
|
||||
"""
|
||||
# Estimate games needed (roughly 1 position per game on average)
|
||||
total_games = max(total_positions // samples_per_game, num_workers)
|
||||
games_per_worker = total_games // num_workers
|
||||
|
||||
print(f"Generating {total_positions:,} positions using {num_workers} workers")
|
||||
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
|
||||
print()
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
# Generate positions in parallel
|
||||
worker_tasks = [
|
||||
(i, games_per_worker, samples_per_game, min_move, max_move)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
|
||||
positions_count = 0
|
||||
all_positions = []
|
||||
|
||||
with Pool(num_workers) as pool:
|
||||
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
|
||||
for positions in pool.starmap(_worker_generate_games, worker_tasks):
|
||||
all_positions.extend(positions)
|
||||
positions_count += len(positions)
|
||||
pbar.update(1)
|
||||
|
||||
# Write all positions to file
|
||||
print(f"Writing {positions_count:,} positions to {output_file}...")
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in all_positions:
|
||||
f.write(fen + '\n')
|
||||
|
||||
elapsed_time = datetime.now() - start_time
|
||||
elapsed_seconds = elapsed_time.total_seconds()
|
||||
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
|
||||
|
||||
# Print summary
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("POSITION GENERATION SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Target positions: {total_positions:,}")
|
||||
print(f"Actual positions saved: {positions_count:,}")
|
||||
print(f"Workers: {num_workers}")
|
||||
print(f"Games per worker: {games_per_worker:,}")
|
||||
print(f"Samples per game: {samples_per_game}")
|
||||
print(f"Move range: {min_move}-{max_move}")
|
||||
print(f"Elapsed time: {elapsed_time}")
|
||||
print(f"Throughput: {positions_per_second:.0f} positions/second")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if positions_count == 0:
|
||||
print("WARNING: No valid positions were generated!")
|
||||
return 0
|
||||
|
||||
return positions_count
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
|
||||
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
||||
help="Output file for positions (default: positions.txt)")
|
||||
parser.add_argument("--positions", type=int, default=3000000,
|
||||
help="Target number of positions to generate (default: 3000000)")
|
||||
parser.add_argument("--samples-per-game", type=int, default=1,
|
||||
help="Number of positions to sample per game (default: 1)")
|
||||
parser.add_argument("--min-move", type=int, default=1,
|
||||
help="Minimum move number to sample from (default: 1)")
|
||||
parser.add_argument("--max-move", type=int, default=50,
|
||||
help="Maximum move number to sample from (default: 50)")
|
||||
parser.add_argument("--workers", type=int, default=8,
|
||||
help="Number of parallel worker processes (default: 8)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
count = play_random_game_and_collect_positions(
|
||||
output_file=args.output_file,
|
||||
total_positions=args.positions,
|
||||
samples_per_game=args.samples_per_game,
|
||||
min_move=args.min_move,
|
||||
max_move=args.max_move,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
sys.exit(0 if count > 0 else 1)
|
||||
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||
|
||||
import json
|
||||
import chess.engine
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
|
||||
"""Normalize centipawn evaluation to a bounded range.
|
||||
|
||||
Args:
|
||||
cp_value: Centipawn evaluation from Stockfish
|
||||
method: 'tanh' (default) or 'sigmoid'
|
||||
scale: Scale factor (tanh: 300 is typical)
|
||||
|
||||
Returns:
|
||||
Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
|
||||
"""
|
||||
if method == 'tanh':
|
||||
return np.tanh(cp_value / scale)
|
||||
elif method == 'sigmoid':
|
||||
return 1.0 / (1.0 + np.exp(-cp_value / scale))
|
||||
else:
|
||||
return cp_value / 100.0
|
||||
|
||||
def _evaluate_fen_batch(args):
|
||||
"""Worker function to evaluate a batch of FENs with Stockfish threading.
|
||||
|
||||
Args:
|
||||
args: tuple of (fens, stockfish_path, depth, normalize)
|
||||
|
||||
Returns:
|
||||
list of (fen, eval_normalized, eval_raw) tuples
|
||||
"""
|
||||
fens, stockfish_path, depth, normalize = args
|
||||
|
||||
results = []
|
||||
|
||||
try:
|
||||
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
try:
|
||||
for fen in fens:
|
||||
try:
|
||||
board = chess.Board(fen)
|
||||
if not board.is_valid():
|
||||
continue
|
||||
|
||||
info = engine.analyse(board, chess.engine.Limit(depth=depth))
|
||||
|
||||
if info.get('score') is None:
|
||||
continue
|
||||
|
||||
score = info['score'].white()
|
||||
|
||||
if score.is_mate():
|
||||
eval_cp = 2000 if score.mate() > 0 else -2000
|
||||
else:
|
||||
eval_cp = score.cp
|
||||
|
||||
eval_cp = max(-2000, min(2000, eval_cp))
|
||||
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
|
||||
|
||||
results.append((fen, eval_normalized, eval_cp))
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
finally:
|
||||
engine.quit()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
|
||||
"""Read positions and label them with Stockfish evaluations.
|
||||
|
||||
Args:
|
||||
positions_file: Path to positions.txt
|
||||
output_file: Path to training_data.jsonl
|
||||
stockfish_path: Path to stockfish binary
|
||||
batch_size: Batch size for processing (positions per worker task, default: 1000)
|
||||
depth: Stockfish depth
|
||||
verbose: Print detailed error messages
|
||||
normalize: If True, normalize evals using tanh
|
||||
num_workers: Number of parallel Stockfish processes
|
||||
"""
|
||||
|
||||
# Check if stockfish exists
|
||||
if not Path(stockfish_path).exists():
|
||||
print(f"Error: Stockfish not found at {stockfish_path}")
|
||||
print(f"Tried: {stockfish_path}")
|
||||
print(f"Set STOCKFISH_PATH environment variable or pass as argument")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using Stockfish: {stockfish_path}")
|
||||
print(f"Number of workers: {num_workers}")
|
||||
|
||||
# Check if positions file exists
|
||||
if not Path(positions_file).exists():
|
||||
print(f"Error: Positions file not found at {positions_file}")
|
||||
sys.exit(1)
|
||||
|
||||
# Load existing evaluations if resuming
|
||||
evaluated_fens = set()
|
||||
position_count = 0
|
||||
|
||||
if Path(output_file).exists():
|
||||
with open(output_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
evaluated_fens.add(data['fen'])
|
||||
position_count += 1
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
print(f"Resuming from {position_count} already evaluated positions")
|
||||
|
||||
# Load all FENs that need evaluation
|
||||
fens_to_evaluate = []
|
||||
fens_seen_in_batch = set() # Track duplicates within current batch
|
||||
skipped_invalid = 0
|
||||
skipped_duplicate = 0
|
||||
|
||||
with open(positions_file, 'r') as f:
|
||||
for fen in f:
|
||||
fen = fen.strip()
|
||||
|
||||
if not fen:
|
||||
skipped_invalid += 1
|
||||
continue
|
||||
|
||||
if fen in evaluated_fens:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
if fen in fens_seen_in_batch:
|
||||
skipped_duplicate += 1
|
||||
continue
|
||||
|
||||
fens_to_evaluate.append(fen)
|
||||
fens_seen_in_batch.add(fen)
|
||||
|
||||
total_to_evaluate = len(fens_to_evaluate)
|
||||
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
|
||||
|
||||
if total_to_evaluate == 0:
|
||||
if position_count == 0:
|
||||
print(f"Error: No valid positions to evaluate in {positions_file}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"All positions already evaluated. No new positions to process.")
|
||||
return True
|
||||
|
||||
print(f"Total positions to process: {total_lines}")
|
||||
print(f"New positions to evaluate: {total_to_evaluate}")
|
||||
print(f"Using depth: {depth}")
|
||||
print()
|
||||
|
||||
# Split FENs into batches for workers
|
||||
batches = []
|
||||
for i in range(0, total_to_evaluate, batch_size):
|
||||
batch = fens_to_evaluate[i:i+batch_size]
|
||||
batches.append((batch, stockfish_path, depth, normalize))
|
||||
|
||||
# Process batches in parallel
|
||||
evaluated = 0
|
||||
errors = 0
|
||||
raw_evals = []
|
||||
normalized_evals = []
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
with Pool(num_workers) as pool:
|
||||
with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
|
||||
with open(output_file, 'a') as out:
|
||||
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
|
||||
for fen, eval_normalized, eval_cp in batch_results:
|
||||
# Skip if already evaluated in output file during this run
|
||||
if fen in evaluated_fens:
|
||||
continue
|
||||
|
||||
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||
out.write(json.dumps(data) + '\n')
|
||||
evaluated_fens.add(fen) # Track as evaluated
|
||||
evaluated += 1
|
||||
raw_evals.append(eval_cp)
|
||||
normalized_evals.append(eval_normalized)
|
||||
pbar.update(1)
|
||||
|
||||
# Update progress for any failed evaluations in the batch
|
||||
batch_size_actual = len(batches[0][0]) if batches else batch_size
|
||||
failed = batch_size_actual - len(batch_results)
|
||||
if failed > 0:
|
||||
errors += failed
|
||||
pbar.update(failed)
|
||||
|
||||
# Calculate and show throughput and ETA
|
||||
elapsed = time.time() - start_time
|
||||
throughput = evaluated / elapsed if elapsed > 0 else 0
|
||||
remaining_positions = total_to_evaluate - evaluated
|
||||
eta_seconds = remaining_positions / throughput if throughput > 0 else 0
|
||||
eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
|
||||
|
||||
if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
|
||||
pbar.set_postfix({
|
||||
'rate': f'{throughput:.0f} pos/s',
|
||||
'eta': eta_str
|
||||
})
|
||||
|
||||
# Print summary and analysis
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LABELING SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Successfully evaluated: {evaluated}")
|
||||
print(f"Skipped (duplicates): {skipped_duplicate}")
|
||||
print(f"Skipped (invalid): {skipped_invalid}")
|
||||
print(f"Errors: {errors}")
|
||||
print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
if evaluated == 0:
|
||||
print("WARNING: No positions were successfully evaluated!")
|
||||
print("Check that:")
|
||||
print(" 1. positions.txt is not empty")
|
||||
print(" 2. positions.txt contains valid FENs")
|
||||
print(" 3. Stockfish is installed and working")
|
||||
print(" 4. Stockfish path is correct")
|
||||
return False
|
||||
|
||||
# Print distribution analysis
|
||||
if raw_evals:
|
||||
raw_evals_arr = np.array(raw_evals)
|
||||
norm_evals_arr = np.array(normalized_evals)
|
||||
|
||||
print("=" * 60)
|
||||
print("EVALUATION DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("Raw Evaluations (centipawns):")
|
||||
print(f" Min: {raw_evals_arr.min():.1f}")
|
||||
print(f" Max: {raw_evals_arr.max():.1f}")
|
||||
print(f" Mean: {raw_evals_arr.mean():.1f}")
|
||||
print(f" Median: {np.median(raw_evals_arr):.1f}")
|
||||
print(f" Std: {raw_evals_arr.std():.1f}")
|
||||
print()
|
||||
|
||||
print("Normalized Evaluations (tanh):")
|
||||
print(f" Min: {norm_evals_arr.min():.4f}")
|
||||
print(f" Max: {norm_evals_arr.max():.4f}")
|
||||
print(f" Mean: {norm_evals_arr.mean():.4f}")
|
||||
print(f" Median: {np.median(norm_evals_arr):.4f}")
|
||||
print(f" Std: {norm_evals_arr.std():.4f}")
|
||||
print()
|
||||
|
||||
# Distribution buckets
|
||||
print("Raw Evaluation Buckets (counts):")
|
||||
buckets = [
|
||||
(-float('inf'), -500, "< -5.00"),
|
||||
(-500, -300, "[-5.00, -3.00)"),
|
||||
(-300, -100, "[-3.00, -1.00)"),
|
||||
(-100, 0, "[-1.00, 0.00)"),
|
||||
(0, 100, "[0.00, 1.00)"),
|
||||
(100, 300, "[1.00, 3.00)"),
|
||||
(300, 500, "[3.00, 5.00)"),
|
||||
(500, float('inf'), "> 5.00"),
|
||||
]
|
||||
for low, high, label in buckets:
|
||||
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
|
||||
pct = 100.0 * count / len(raw_evals_arr)
|
||||
print(f" {label}: {count:6d} ({pct:5.1f}%)")
|
||||
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
print(f"✓ Labeling complete. Output saved to {output_file}")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
|
||||
parser.add_argument("positions_file", nargs="?", default="positions.txt",
|
||||
help="Input positions file (default: positions.txt)")
|
||||
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||
help="Output file (default: training_data.jsonl)")
|
||||
parser.add_argument("stockfish_path", nargs="?", default=None,
|
||||
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||
parser.add_argument("--depth", type=int, default=12,
|
||||
help="Stockfish depth (default: 12)")
|
||||
parser.add_argument("--batch-size", type=int, default=1000,
|
||||
help="Batch size for processing (default: 1000)")
|
||||
parser.add_argument("--no-normalize", action="store_true",
|
||||
help="Disable evaluation normalization (keep raw centipawns)")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Print detailed error messages")
|
||||
parser.add_argument("--workers", type=int, default=1,
|
||||
help="Number of parallel Stockfish processes (default: 1)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine Stockfish path
|
||||
stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")
|
||||
|
||||
success = label_positions_with_stockfish(
|
||||
positions_file=args.positions_file,
|
||||
output_file=args.output_file,
|
||||
stockfish_path=stockfish_path,
|
||||
batch_size=args.batch_size,
|
||||
depth=args.depth,
|
||||
normalize=not args.no_normalize,
|
||||
verbose=args.verbose,
|
||||
num_workers=args.workers
|
||||
)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Import pre-labeled positions from the Lichess evaluation database.
|
||||
|
||||
Source: https://database.lichess.org/#evals
|
||||
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
|
||||
|
||||
Each line:
|
||||
{
|
||||
"fen": "<pieces> <turn> <castling> <ep>",
|
||||
"evals": [
|
||||
{
|
||||
"knodes": <int>,
|
||||
"depth": <int>,
|
||||
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
cp and mate are from White's perspective (positive = White winning), matching
|
||||
the sign convention used by label.py (score.white()) and expected by train.py.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
MATE_CP = 20000
|
||||
SCALE = 300.0
|
||||
|
||||
|
||||
def _best_eval(evals: list) -> dict | None:
|
||||
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
|
||||
if not evals:
|
||||
return None
|
||||
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
|
||||
|
||||
|
||||
def _cp_from_pv(pv: dict) -> int | None:
|
||||
"""Extract centipawn value from a principal variation entry."""
|
||||
if "cp" in pv:
|
||||
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
|
||||
if "mate" in pv:
|
||||
return MATE_CP if pv["mate"] > 0 else -MATE_CP
|
||||
return None
|
||||
|
||||
|
||||
def _normalize(cp: int) -> float:
|
||||
return float(np.tanh(cp / SCALE))
|
||||
|
||||
|
||||
def import_lichess_evals(
|
||||
input_path: str,
|
||||
output_file: str,
|
||||
max_positions: int | None = None,
|
||||
min_depth: int = 0,
|
||||
verbose: bool = False,
|
||||
) -> int:
|
||||
"""Stream the Lichess eval database and write a labeled.jsonl file.
|
||||
|
||||
Args:
|
||||
input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
|
||||
output_file: Destination labeled.jsonl (appended — supports resuming).
|
||||
max_positions: Stop after this many new positions (None = no limit).
|
||||
min_depth: Skip positions whose best eval has depth < min_depth.
|
||||
verbose: Print warnings for skipped lines.
|
||||
|
||||
Returns:
|
||||
Number of new positions written.
|
||||
"""
|
||||
import zstandard as zstd
|
||||
|
||||
input_path = Path(input_path)
|
||||
if not input_path.exists():
|
||||
print(f"Error: {input_path} not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Resume: collect already-written FENs so we skip duplicates.
|
||||
seen_fens: set[str] = set()
|
||||
if Path(output_file).exists():
|
||||
with open(output_file, "r") as f:
|
||||
for line in f:
|
||||
try:
|
||||
seen_fens.add(json.loads(line)["fen"])
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
if seen_fens:
|
||||
print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
|
||||
|
||||
written = 0
|
||||
skipped_depth = 0
|
||||
skipped_no_eval = 0
|
||||
skipped_dup = 0
|
||||
|
||||
def iter_lines():
|
||||
"""Yield decoded text lines from either a .zst or plain .jsonl file."""
|
||||
import io
|
||||
if input_path.suffix == ".zst":
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
with open(input_path, "rb") as fh:
|
||||
with dctx.stream_reader(fh) as reader:
|
||||
text_stream = io.TextIOWrapper(reader, encoding="utf-8")
|
||||
yield from text_stream
|
||||
else:
|
||||
with open(input_path, "r", encoding="utf-8") as fh:
|
||||
yield from fh
|
||||
|
||||
try:
|
||||
with open(output_file, "a") as out:
|
||||
with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
|
||||
for raw_line in iter_lines():
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
if verbose:
|
||||
print("Warning: malformed JSON line skipped")
|
||||
continue
|
||||
|
||||
fen = data.get("fen", "")
|
||||
if not fen:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
if fen in seen_fens:
|
||||
skipped_dup += 1
|
||||
continue
|
||||
|
||||
best = _best_eval(data.get("evals", []))
|
||||
if best is None:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
if best.get("depth", 0) < min_depth:
|
||||
skipped_depth += 1
|
||||
continue
|
||||
|
||||
pvs = best.get("pvs", [])
|
||||
if not pvs:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
cp = _cp_from_pv(pvs[0])
|
||||
if cp is None:
|
||||
skipped_no_eval += 1
|
||||
continue
|
||||
|
||||
record = {
|
||||
"fen": fen,
|
||||
"eval": _normalize(cp),
|
||||
"eval_raw": cp,
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
seen_fens.add(fen)
|
||||
written += 1
|
||||
pbar.update(1)
|
||||
|
||||
if max_positions and written >= max_positions:
|
||||
print(f"\nReached max_positions limit ({max_positions:,})")
|
||||
break
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LICHESS IMPORT SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Positions written: {written:,}")
|
||||
print(f"Skipped (dup): {skipped_dup:,}")
|
||||
print(f"Skipped (no eval): {skipped_no_eval:,}")
|
||||
print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
|
||||
print("=" * 60)
|
||||
print(f"\n✓ Output: {output_file}")
|
||||
|
||||
return written
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Import Lichess pre-labeled positions into labeled.jsonl"
|
||||
)
|
||||
parser.add_argument("input_path",
|
||||
help="Path to lichess_db_eval.jsonl.zst")
|
||||
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||
help="Output labeled.jsonl (default: training_data.jsonl)")
|
||||
parser.add_argument("--max-positions", type=int, default=None,
|
||||
help="Stop after N positions (default: no limit)")
|
||||
parser.add_argument("--min-depth", type=int, default=0,
|
||||
help="Minimum eval depth to accept (default: 0)")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Print warnings for skipped lines")
|
||||
|
||||
args = parser.parse_args()
|
||||
count = import_lichess_evals(
|
||||
input_path=args.input_path,
|
||||
output_file=args.output_file,
|
||||
max_positions=args.max_positions,
|
||||
min_depth=args.min_depth,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
sys.exit(0 if count > 0 else 1)
|
||||
@@ -0,0 +1,249 @@
|
||||
import chess
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
from typing import Set, Tuple
|
||||
|
||||
try:
|
||||
import zstandard as zstd
|
||||
except ImportError:
|
||||
print("zstandard library not found. Install with: pip install zstandard")
|
||||
sys.exit(1)
|
||||
|
||||
from generate import play_random_game_and_collect_positions
|
||||
|
||||
|
||||
def download_and_extract_puzzle_db(
|
||||
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||
output_dir: str = 'tactical_data'
|
||||
):
|
||||
"""Download and extract the Lichess puzzle database."""
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
csv_file = output_path / 'lichess_db_puzzle.csv'
|
||||
zst_file = output_path / 'lichess_db_puzzle.csv.zst'
|
||||
|
||||
# Download if not already present
|
||||
if not zst_file.exists():
|
||||
print(f"Downloading puzzle database from {url}...")
|
||||
try:
|
||||
urllib.request.urlretrieve(url, zst_file)
|
||||
print(f"Downloaded to {zst_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to download: {e}")
|
||||
return None
|
||||
|
||||
# Extract if CSV doesn't exist
|
||||
if not csv_file.exists():
|
||||
print(f"Extracting {zst_file}...")
|
||||
try:
|
||||
with open(zst_file, 'rb') as f:
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
with dctx.stream_reader(f) as reader:
|
||||
with open(csv_file, 'wb') as out:
|
||||
out.write(reader.read())
|
||||
print(f"Extracted to {csv_file}")
|
||||
except Exception as e:
|
||||
print(f"Failed to extract: {e}")
|
||||
return None
|
||||
|
||||
return str(csv_file)
|
||||
|
||||
|
||||
def extract_puzzle_positions(
|
||||
puzzle_csv: str,
|
||||
max_puzzles: int = 300_000
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Extract the position BEFORE the blunder from each puzzle.
|
||||
This is exactly the type of position where tactical
|
||||
recognition matters most.
|
||||
|
||||
Returns a set of unique FENs.
|
||||
"""
|
||||
positions = set()
|
||||
|
||||
with open(puzzle_csv) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if len(positions) >= max_puzzles:
|
||||
break
|
||||
|
||||
try:
|
||||
board = chess.Board(row['FEN'])
|
||||
|
||||
# The puzzle FEN is AFTER the blunder move
|
||||
# We want the position BEFORE — so it learns
|
||||
# to find the tactic, not just play it
|
||||
moves = row['Moves'].split()
|
||||
|
||||
# Undo one move to get pre-tactic position
|
||||
board.push_uci(moves[0]) # opponent blunder
|
||||
fen = board.fen()
|
||||
|
||||
# Filter for useful tactical themes
|
||||
themes = row.get('Themes', '')
|
||||
useful = any(t in themes for t in [
|
||||
'fork', 'pin', 'skewer', 'discoveredAttack',
|
||||
'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
|
||||
'trappedPiece', 'sacrifice'
|
||||
])
|
||||
|
||||
if useful:
|
||||
positions.add(fen)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def load_positions_from_file(file_path: str) -> Set[str]:
|
||||
"""Load positions from a text file (one FEN per line)."""
|
||||
positions = set()
|
||||
try:
|
||||
with open(file_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
positions.add(line)
|
||||
print(f"Loaded {len(positions)} positions from {file_path}")
|
||||
return positions
|
||||
except Exception as e:
|
||||
print(f"Failed to load from {file_path}: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def merge_positions(
|
||||
tactical: Set[str],
|
||||
other: Set[str],
|
||||
output_file: str = 'position.txt'
|
||||
):
|
||||
"""Merge two position sets and write to file."""
|
||||
merged = tactical | other
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in merged:
|
||||
f.write(fen + '\n')
|
||||
|
||||
overlap = len(tactical & other)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"MERGE SUMMARY")
|
||||
print(f"{'='*60}")
|
||||
print(f"Tactical positions: {len(tactical):,}")
|
||||
print(f"Other positions: {len(other):,}")
|
||||
print(f"Overlap (deduplicated): {overlap:,}")
|
||||
print(f"Total merged positions: {len(merged):,}")
|
||||
print(f"Written to: {output_file}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def extract_tactical_only(
|
||||
puzzle_csv: str,
|
||||
output_file: str,
|
||||
max_puzzles: int = 300_000
|
||||
) -> int:
|
||||
"""Extract tactical positions and save to file (no merge prompts).
|
||||
|
||||
Args:
|
||||
puzzle_csv: Path to Lichess puzzle CSV
|
||||
output_file: Where to save the FEN positions
|
||||
max_puzzles: Maximum puzzles to extract
|
||||
|
||||
Returns:
|
||||
Number of positions extracted
|
||||
"""
|
||||
print("Extracting tactical positions from puzzle database...")
|
||||
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
for fen in tactical_positions:
|
||||
f.write(fen + '\n')
|
||||
|
||||
return len(tactical_positions)
|
||||
|
||||
|
||||
def interactive_merge_positions(
|
||||
puzzle_csv: str,
|
||||
output_file: str = 'position.txt',
|
||||
max_puzzles: int = 300_000
|
||||
):
|
||||
"""Interactive workflow: extract tactical positions and merge with user selection."""
|
||||
print("\n" + "="*60)
|
||||
print("TACTICAL POSITION EXTRACTOR & MERGER")
|
||||
print("="*60 + "\n")
|
||||
|
||||
# Extract tactical positions
|
||||
print("Extracting tactical positions from puzzle database...")
|
||||
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||
print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
|
||||
|
||||
# Ask what to merge with
|
||||
print("What would you like to merge with these tactical positions?")
|
||||
print("1. Load from a position file")
|
||||
print("2. Generate random positions")
|
||||
print("3. Skip merging (save tactical only)")
|
||||
|
||||
choice = input("\nEnter choice (1-3): ").strip()
|
||||
|
||||
other_positions = set()
|
||||
|
||||
if choice == '1':
|
||||
file_path = input("Enter path to position file: ").strip()
|
||||
other_positions = load_positions_from_file(file_path)
|
||||
|
||||
elif choice == '2':
|
||||
positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
|
||||
try:
|
||||
positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
|
||||
except ValueError:
|
||||
positions_to_gen = 1000000
|
||||
|
||||
temp_file = 'temp_generated_positions.txt'
|
||||
print(f"\nGenerating {positions_to_gen:,} random positions...")
|
||||
play_random_game_and_collect_positions(
|
||||
output_file=temp_file,
|
||||
total_positions=positions_to_gen,
|
||||
samples_per_game=1,
|
||||
min_move=1,
|
||||
max_move=50,
|
||||
num_workers=8
|
||||
)
|
||||
other_positions = load_positions_from_file(temp_file)
|
||||
|
||||
elif choice == '3':
|
||||
print("Skipping merge, saving tactical positions only...")
|
||||
|
||||
else:
|
||||
print("Invalid choice, saving tactical positions only...")
|
||||
|
||||
merge_positions(tactical_positions, other_positions, output_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
|
||||
parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||
help="URL to download puzzle database from")
|
||||
parser.add_argument("--output-dir", default='trainingdata',
|
||||
help="Directory to extract puzzle database to")
|
||||
parser.add_argument("--max-puzzles", type=int, default=300_000,
|
||||
help="Maximum puzzles to extract (default: 300000)")
|
||||
parser.add_argument("--output-file", default='position.txt',
|
||||
help="Output file for merged positions (default: position.txt)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Download and extract
|
||||
csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
|
||||
|
||||
if csv_path:
|
||||
# Interactive merge
|
||||
interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
|
||||
else:
|
||||
print("Failed to download/extract puzzle database")
|
||||
sys.exit(1)
|
||||
@@ -0,0 +1,676 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train NNUE neural network for chess evaluation."""
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import chess
|
||||
from datetime import datetime, timedelta
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
class NNUEDataset(Dataset):
|
||||
"""Dataset of chess positions with evaluations."""
|
||||
|
||||
def __init__(self, data_file):
|
||||
self.positions = []
|
||||
self.evals = []
|
||||
self.evals_raw = []
|
||||
self.is_normalized = None
|
||||
|
||||
with open(data_file, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
data = json.loads(line)
|
||||
fen = data['fen']
|
||||
eval_val = data['eval']
|
||||
self.positions.append(fen)
|
||||
self.evals.append(eval_val)
|
||||
|
||||
# Check if normalized or raw
|
||||
if self.is_normalized is None:
|
||||
# If eval is in range [-1, 1], assume normalized
|
||||
self.is_normalized = abs(eval_val) <= 1.0
|
||||
|
||||
# Store raw if available
|
||||
if 'eval_raw' in data:
|
||||
self.evals_raw.append(data['eval_raw'])
|
||||
else:
|
||||
self.evals_raw.append(eval_val)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return len(self.positions)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
fen = self.positions[idx]
|
||||
eval_val = self.evals[idx]
|
||||
features = fen_to_features(fen)
|
||||
|
||||
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
|
||||
if self.is_normalized:
|
||||
target = torch.tensor(eval_val, dtype=torch.float32)
|
||||
else:
|
||||
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
|
||||
|
||||
return features, target
|
||||
|
||||
def fen_to_features(fen):
|
||||
"""Convert FEN to 768-dimensional binary feature vector."""
|
||||
# Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
|
||||
piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
|
||||
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
|
||||
|
||||
features = torch.zeros(768, dtype=torch.float32)
|
||||
|
||||
try:
|
||||
board = chess.Board(fen)
|
||||
|
||||
# 12 piece types × 64 squares = 768
|
||||
for square in chess.SQUARES:
|
||||
piece = board.piece_at(square)
|
||||
if piece is not None:
|
||||
piece_char = piece.symbol()
|
||||
if piece_char in piece_to_idx:
|
||||
piece_idx = piece_to_idx[piece_char]
|
||||
feature_idx = piece_idx * 64 + square
|
||||
features[feature_idx] = 1.0
|
||||
except:
|
||||
pass
|
||||
|
||||
return features
|
||||
|
||||
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
|
||||
|
||||
|
||||
class NNUE(nn.Module):
|
||||
"""NNUE neural network with configurable hidden layers.
|
||||
|
||||
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
|
||||
Layer attributes follow the naming l1, l2, ..., lN so export.py can
|
||||
infer the architecture directly from the state_dict.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
|
||||
super().__init__()
|
||||
if hidden_sizes is None:
|
||||
hidden_sizes = DEFAULT_HIDDEN_SIZES
|
||||
self.hidden_sizes = list(hidden_sizes)
|
||||
sizes = [768] + self.hidden_sizes + [1]
|
||||
num_hidden = len(self.hidden_sizes)
|
||||
|
||||
for i in range(num_hidden):
|
||||
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
|
||||
setattr(self, f"relu{i + 1}", nn.ReLU())
|
||||
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
|
||||
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
|
||||
self._num_hidden = num_hidden
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(1, self._num_hidden + 1):
|
||||
layer = getattr(self, f"l{i}")
|
||||
relu = getattr(self, f"relu{i}")
|
||||
drop = getattr(self, f"drop{i}")
|
||||
x = drop(relu(layer(x)))
|
||||
return getattr(self, f"l{self._num_hidden + 1}")(x)
|
||||
|
||||
def find_next_version(base_name="nnue_weights"):
|
||||
"""Find the next version number for model versioning.
|
||||
|
||||
Looks for nnue_weights_v*.pt files and returns the next version number.
|
||||
If no versioned files exist, returns 1.
|
||||
"""
|
||||
base_path = Path(base_name)
|
||||
directory = base_path.parent
|
||||
filename = base_path.name
|
||||
|
||||
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
|
||||
versions = []
|
||||
|
||||
for file in directory.glob(f"{filename}_v*.pt"):
|
||||
match = pattern.match(file.name)
|
||||
if match:
|
||||
versions.append(int(match.group(1)))
|
||||
|
||||
if versions:
|
||||
return max(versions) + 1
|
||||
return 1
|
||||
|
||||
def save_metadata(weights_file, metadata):
|
||||
"""Save training metadata alongside the weights file.
|
||||
|
||||
Args:
|
||||
weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
|
||||
metadata: Dictionary with training info
|
||||
"""
|
||||
metadata_file = weights_file.replace(".pt", "_metadata.json")
|
||||
|
||||
with open(metadata_file, "w") as f:
|
||||
json.dump(metadata, f, indent=2, default=str)
|
||||
|
||||
return metadata_file
|
||||
|
||||
def _setup_training(data_file, batch_size, subsample_ratio):
|
||||
"""Set up device, dataset, and data loaders.
|
||||
|
||||
Returns:
|
||||
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
|
||||
"""
|
||||
print("Checking GPU availability...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
|
||||
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
||||
else:
|
||||
print("⚠ GPU not available, using CPU")
|
||||
print(f"Using device: {device}")
|
||||
print()
|
||||
|
||||
print("Loading dataset...")
|
||||
dataset = NNUEDataset(data_file)
|
||||
num_positions = len(dataset)
|
||||
print(f"Dataset size: {num_positions}")
|
||||
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||
|
||||
evals_array = np.array(dataset.evals)
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("TRAINING DATASET DIAGNOSTICS")
|
||||
print("=" * 60)
|
||||
print(f"Min evaluation: {evals_array.min():.4f}")
|
||||
print(f"Max evaluation: {evals_array.max():.4f}")
|
||||
print(f"Mean evaluation: {evals_array.mean():.4f}")
|
||||
print(f"Median evaluation: {np.median(evals_array):.4f}")
|
||||
print(f"Std deviation: {evals_array.std():.4f}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
train_size = int(0.9 * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
|
||||
from torch.utils.data import random_split, RandomSampler
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
|
||||
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
|
||||
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
|
||||
|
||||
def _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
start_epoch, epochs, early_stopping_patience,
|
||||
season_start_time, deadline=None, initial_best_val_loss=float('inf')
|
||||
):
|
||||
"""Run one training season until epoch limit, early stopping, or deadline.
|
||||
|
||||
Args:
|
||||
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
|
||||
toward early stopping and do not save snapshots.
|
||||
Returns:
|
||||
(best_val_loss, best_model_state, last_epoch)
|
||||
best_model_state is None if no epoch beat initial_best_val_loss.
|
||||
"""
|
||||
best_val_loss = initial_best_val_loss
|
||||
best_model_state = None
|
||||
epochs_without_improvement = 0
|
||||
total_epochs = start_epoch + epochs
|
||||
last_epoch = start_epoch - 1
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + epochs):
|
||||
if deadline and datetime.now() >= deadline:
|
||||
print("Time limit reached, stopping season.")
|
||||
break
|
||||
|
||||
epoch_display = epoch + 1
|
||||
|
||||
# Train
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
|
||||
for batch_features, batch_targets in train_loader:
|
||||
batch_features = batch_features.to(device)
|
||||
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
outputs = model(batch_features)
|
||||
loss = criterion(outputs, batch_targets)
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
train_loss += loss.item() * batch_features.size(0)
|
||||
pbar.update(1)
|
||||
|
||||
train_loss /= len(train_dataset)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
|
||||
for batch_features, batch_targets in val_loader:
|
||||
batch_features = batch_features.to(device)
|
||||
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||
|
||||
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
outputs = model(batch_features)
|
||||
loss = criterion(outputs, batch_targets)
|
||||
val_loss += loss.item() * batch_features.size(0)
|
||||
pbar.update(1)
|
||||
|
||||
val_loss /= len(val_dataset)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
|
||||
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
|
||||
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
|
||||
|
||||
elapsed_time = datetime.now() - season_start_time
|
||||
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
|
||||
remaining_epochs = total_epochs - epoch_display
|
||||
eta_seconds = time_per_epoch * remaining_epochs
|
||||
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
|
||||
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
|
||||
|
||||
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": model.hidden_sizes,
|
||||
}, checkpoint_file)
|
||||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_model_state = model.state_dict().copy()
|
||||
epochs_without_improvement = 0
|
||||
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
|
||||
torch.save(best_model_state, snapshot_file)
|
||||
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
|
||||
last_epoch = epoch
|
||||
|
||||
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||
break
|
||||
|
||||
return best_val_loss, best_model_state, last_epoch
|
||||
|
||||
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time, hidden_sizes=None,
|
||||
extra_metadata=None):
|
||||
"""Save the best model with optional versioning and metadata."""
|
||||
final_output_file = output_file
|
||||
metadata = {}
|
||||
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
|
||||
|
||||
if use_versioning:
|
||||
base_name = output_file.replace(".pt", "")
|
||||
version = find_next_version(base_name)
|
||||
final_output_file = f"{base_name}_v{version}.pt"
|
||||
|
||||
metadata = {
|
||||
"version": version,
|
||||
"date": training_start_time.isoformat(),
|
||||
"num_positions": num_positions,
|
||||
"stockfish_depth": stockfish_depth,
|
||||
"final_val_loss": float(best_val_loss),
|
||||
"architecture": architecture,
|
||||
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
torch.save({
|
||||
"model_state_dict": best_model_state,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"scaler_state_dict": scaler.state_dict(),
|
||||
"epoch": last_epoch,
|
||||
"best_val_loss": best_val_loss,
|
||||
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
|
||||
}, final_output_file)
|
||||
print(f"Best model saved to {final_output_file}")
|
||||
|
||||
if use_versioning and metadata:
|
||||
metadata_file = save_metadata(final_output_file, metadata)
|
||||
print(f"Metadata saved to {metadata_file}")
|
||||
print(f"\nTraining Summary:")
|
||||
for key, val in metadata.items():
|
||||
print(f" {key}: {val}")
|
||||
|
||||
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||
|
||||
Args:
|
||||
data_file: Path to training_data.jsonl
|
||||
output_file: Where to save best weights (or base name if use_versioning=True)
|
||||
epochs: Number of training epochs (default: 100)
|
||||
batch_size: Training batch size (default: 16384)
|
||||
lr: Learning rate (default: 0.001)
|
||||
checkpoint: Optional path to checkpoint file to resume from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
||||
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||
"""
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
start_epoch = 0
|
||||
best_val_loss = float('inf')
|
||||
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||
|
||||
if checkpoint:
|
||||
print(f"Loading checkpoint: {checkpoint}")
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||
resolved_hidden_sizes = ckpt_hidden
|
||||
|
||||
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
if checkpoint:
|
||||
ckpt = torch.load(checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
best_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no optimizer state)")
|
||||
|
||||
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||
|
||||
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||
print(f"Architecture: {arch_str}")
|
||||
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||
print(f"Mixed precision training: enabled")
|
||||
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
|
||||
if subsample_ratio < 1.0:
|
||||
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
|
||||
if early_stopping_patience:
|
||||
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||
print()
|
||||
|
||||
training_start_time = datetime.now()
|
||||
|
||||
best_val_loss, best_model_state, last_epoch = _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
start_epoch, epochs, early_stopping_patience,
|
||||
training_start_time
|
||||
)
|
||||
|
||||
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
|
||||
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
|
||||
print("No new model created.")
|
||||
return
|
||||
|
||||
_save_versioned_model(
|
||||
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||
best_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, training_start_time,
|
||||
hidden_sizes=resolved_hidden_sizes,
|
||||
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
||||
"checkpoint": str(checkpoint) if checkpoint else None}
|
||||
)
|
||||
|
||||
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||
epochs_per_season=50, early_stopping_patience=10,
|
||||
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
||||
stockfish_depth=12, use_versioning=True,
|
||||
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
||||
|
||||
Each season trains with early stopping. When early stopping fires, the model reloads the
|
||||
global best weights and begins a fresh season with a reset optimizer and scheduler.
|
||||
This prevents the model from drifting away from its best known state.
|
||||
|
||||
Args:
|
||||
data_file: Path to training_data.jsonl
|
||||
output_file: Output file base name
|
||||
duration_minutes: Total training budget in minutes
|
||||
epochs_per_season: Max epochs per restart season (default: 50)
|
||||
early_stopping_patience: Patience for early stopping within each season (default: 10)
|
||||
batch_size: Training batch size (default: 16384)
|
||||
lr: Learning rate reset to this value at the start of each season (default: 0.001)
|
||||
initial_checkpoint: Optional weights-only .pt file to start from
|
||||
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||
weight_decay: L2 regularization strength (default: 1e-4)
|
||||
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
||||
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||
"""
|
||||
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
||||
|
||||
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||
_setup_training(data_file, batch_size, subsample_ratio)
|
||||
|
||||
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||
|
||||
if initial_checkpoint:
|
||||
print(f"Loading initial weights: {initial_checkpoint}")
|
||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||
resolved_hidden_sizes = ckpt_hidden
|
||||
|
||||
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||
criterion = nn.MSELoss()
|
||||
best_global_val_loss = float('inf')
|
||||
|
||||
if initial_checkpoint:
|
||||
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||
if best_global_val_loss < float('inf'):
|
||||
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
|
||||
else:
|
||||
print("Initial weights loaded (no val loss in checkpoint).")
|
||||
else:
|
||||
model.load_state_dict(ckpt)
|
||||
print("Loaded weights-only checkpoint (no val loss info).")
|
||||
|
||||
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||
print(f"Architecture: {arch_str}")
|
||||
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
||||
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
||||
print()
|
||||
|
||||
burst_start_time = datetime.now()
|
||||
season = 0
|
||||
best_global_state = None
|
||||
last_optimizer = None
|
||||
last_scheduler = None
|
||||
last_scaler = None
|
||||
last_epoch = 0
|
||||
|
||||
while datetime.now() < deadline:
|
||||
season += 1
|
||||
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
|
||||
if best_global_val_loss < float('inf'):
|
||||
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
|
||||
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||
|
||||
season_start_time = datetime.now()
|
||||
val_loss, model_state, last_epoch = _run_training_season(
|
||||
model, optimizer, scheduler, scaler,
|
||||
train_loader, val_loader, train_dataset, val_dataset,
|
||||
device, criterion, output_file,
|
||||
0, epochs_per_season, early_stopping_patience,
|
||||
season_start_time, deadline=deadline,
|
||||
initial_best_val_loss=best_global_val_loss
|
||||
)
|
||||
|
||||
last_optimizer = optimizer
|
||||
last_scheduler = scheduler
|
||||
last_scaler = scaler
|
||||
|
||||
if model_state is not None and val_loss < best_global_val_loss:
|
||||
best_global_val_loss = val_loss
|
||||
best_global_state = model_state
|
||||
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
|
||||
|
||||
# Reload global best for the next season so we never drift backwards
|
||||
if best_global_state is not None:
|
||||
model.load_state_dict(best_global_state)
|
||||
|
||||
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
|
||||
print(f"Best val loss: {best_global_val_loss:.6f}")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
if best_global_state is None:
|
||||
print("No model improvement found. No file saved.")
|
||||
return
|
||||
|
||||
_save_versioned_model(
|
||||
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
||||
best_global_val_loss, output_file, use_versioning, num_positions,
|
||||
stockfish_depth, burst_start_time,
|
||||
hidden_sizes=resolved_hidden_sizes,
|
||||
extra_metadata={
|
||||
"mode": "burst",
|
||||
"duration_minutes": duration_minutes,
|
||||
"epochs_per_season": epochs_per_season,
|
||||
"early_stopping_patience": early_stopping_patience,
|
||||
"seasons_completed": season,
|
||||
"batch_size": batch_size,
|
||||
"learning_rate": lr,
|
||||
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
|
||||
}
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
|
||||
parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
|
||||
help="Path to training_data.jsonl (default: training_data.jsonl)")
|
||||
parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
|
||||
help="Output file base name (default: nnue_weights.pt)")
|
||||
parser.add_argument("--checkpoint", type=str, default=None,
|
||||
help="Path to checkpoint file to resume training from (optional)")
|
||||
parser.add_argument("--epochs", type=int, default=100,
|
||||
help="Number of epochs to train (default: 100)")
|
||||
parser.add_argument("--batch-size", type=int, default=16384,
|
||||
help="Batch size (default: 16384)")
|
||||
parser.add_argument("--lr", type=float, default=0.001,
|
||||
help="Learning rate (default: 0.001)")
|
||||
parser.add_argument("--early-stopping", type=int, default=None,
|
||||
help="Stop if val loss doesn't improve for N epochs (optional)")
|
||||
parser.add_argument("--stockfish-depth", type=int, default=12,
|
||||
help="Stockfish depth used for evaluations (for metadata, default: 12)")
|
||||
parser.add_argument("--no-versioning", action="store_true",
|
||||
help="Disable automatic versioning (save directly to output file)")
|
||||
parser.add_argument("--weight-decay", type=float, default=5e-5,
|
||||
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
||||
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
||||
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
||||
parser.add_argument("--hidden-layers", type=str, default=None,
|
||||
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
|
||||
|
||||
# Burst mode
|
||||
parser.add_argument("--burst-duration", type=float, default=None,
|
||||
help="Enable burst mode: total training budget in minutes")
|
||||
parser.add_argument("--epochs-per-season", type=int, default=50,
|
||||
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
|
||||
|
||||
if args.burst_duration is not None:
|
||||
burst_train(
|
||||
data_file=args.data_file,
|
||||
output_file=args.output_file,
|
||||
duration_minutes=args.burst_duration,
|
||||
epochs_per_season=args.epochs_per_season,
|
||||
early_stopping_patience=args.early_stopping or 10,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
initial_checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
else:
|
||||
train_nnue(
|
||||
data_file=args.data_file,
|
||||
output_file=args.output_file,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
checkpoint=args.checkpoint,
|
||||
stockfish_depth=args.stockfish_depth,
|
||||
use_versioning=not args.no_versioning,
|
||||
early_stopping_patience=args.early_stopping,
|
||||
weight_decay=args.weight_decay,
|
||||
subsample_ratio=args.subsample_ratio,
|
||||
hidden_sizes=hidden_sizes,
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
# Setup and run NNUE training pipeline
|
||||
|
||||
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||
$VenvDir = Join-Path $ScriptDir ".venv"
|
||||
|
||||
# Check if virtual environment exists
|
||||
if (-not (Test-Path $VenvDir)) {
|
||||
Write-Host "Creating virtual environment..."
|
||||
python -m venv $VenvDir
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
# Activate virtual environment
|
||||
Write-Host "Activating virtual environment..."
|
||||
$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
|
||||
& $ActivateScript
|
||||
|
||||
# Install/update dependencies if requirements.txt exists
|
||||
$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
|
||||
if (Test-Path $RequirementsFile) {
|
||||
Write-Host "Installing dependencies..."
|
||||
pip install -q -r $RequirementsFile
|
||||
}
|
||||
|
||||
# Run nnue.py
|
||||
Write-Host "Starting NNUE Training Pipeline..."
|
||||
python (Join-Path $ScriptDir "nnue.py")
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
# Setup and run NNUE training pipeline
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||
|
||||
# Check if virtual environment exists
|
||||
if [ ! -d "$VENV_DIR" ]; then
|
||||
echo "Creating virtual environment..."
|
||||
python3 -m venv "$VENV_DIR"
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: Failed to create virtual environment. Make sure python3 is installed."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Activate virtual environment
|
||||
echo "Activating virtual environment..."
|
||||
source "$VENV_DIR/bin/activate"
|
||||
|
||||
# Install/update dependencies if requirements.txt exists
|
||||
if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
|
||||
echo "Installing dependencies..."
|
||||
pip install -q -r "$SCRIPT_DIR/requirements.txt"
|
||||
fi
|
||||
|
||||
# Run nnue.py
|
||||
echo "Starting NNUE Training Pipeline..."
|
||||
python "$SCRIPT_DIR/nnue.py"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"version": 10,
|
||||
"date": "2026-04-14T22:18:38.824577",
|
||||
"num_positions": 3022562,
|
||||
"stockfish_depth": 12,
|
||||
"final_val_loss": 6.248612448225196e-05,
|
||||
"architecture": [
|
||||
768,
|
||||
1536,
|
||||
1024,
|
||||
512,
|
||||
256,
|
||||
1
|
||||
],
|
||||
"device": "cuda",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"checkpoint": null
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 1,
|
||||
"date": "2026-04-07T22:56:23.259658",
|
||||
"num_positions": 2086,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 20,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.016311248764395714,
|
||||
"device": "cuda",
|
||||
"checkpoint": null,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 2,
|
||||
"date": "2026-04-07T23:50:05.390402",
|
||||
"num_positions": 6886,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.007848377339541912,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 3,
|
||||
"date": "2026-04-08T09:43:28.000579",
|
||||
"num_positions": 71610,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 20,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.006398905136849695,
|
||||
"device": "cpu",
|
||||
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 4,
|
||||
"date": "2026-04-09T00:28:07.572209",
|
||||
"num_positions": 2009355,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 40,
|
||||
"batch_size": 4096,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 9.106677896235248e-05,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 5,
|
||||
"date": "2026-04-09T18:50:27.845632",
|
||||
"num_positions": 2009355,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 9.180421525105905e-05,
|
||||
"device": "cuda",
|
||||
"checkpoint": null,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 6,
|
||||
"date": "2026-04-09T21:28:21.000832",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.2984530149085532,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 7,
|
||||
"date": "2026-04-09T22:06:50.439858",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.2997283308762831,
|
||||
"device": "cuda",
|
||||
"checkpoint": null,
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"version": 8,
|
||||
"date": "2026-04-09T22:22:47.859730",
|
||||
"num_positions": 1958728,
|
||||
"stockfish_depth": 12,
|
||||
"epochs": 100,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"final_val_loss": 0.24803777390839968,
|
||||
"device": "cuda",
|
||||
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"version": 9,
|
||||
"date": "2026-04-13T20:19:08.123315",
|
||||
"num_positions": 2522562,
|
||||
"stockfish_depth": 12,
|
||||
"final_val_loss": 6.994176222619626e-05,
|
||||
"device": "cuda",
|
||||
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||
"mode": "burst",
|
||||
"duration_minutes": 30.0,
|
||||
"epochs_per_season": 50,
|
||||
"early_stopping_patience": 10,
|
||||
"seasons_completed": 3,
|
||||
"batch_size": 16384,
|
||||
"learning_rate": 0.001,
|
||||
"initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
package de.nowchess.bot
|
||||
|
||||
import de.nowchess.api.bot.Bot
|
||||
import de.nowchess.bot.bots.ClassicalBot
|
||||
|
||||
object BotController {
|
||||
|
||||
private val bots: Map[String, Bot] = Map(
|
||||
"easy" -> ClassicalBot(BotDifficulty.Easy),
|
||||
"medium" -> ClassicalBot(BotDifficulty.Medium),
|
||||
"hard" -> ClassicalBot(BotDifficulty.Hard),
|
||||
"expert" -> ClassicalBot(BotDifficulty.Expert),
|
||||
)
|
||||
|
||||
/** Get a bot by name. */
|
||||
def getBot(name: String): Option[Bot] = bots.get(name.toLowerCase)
|
||||
|
||||
/** List all available bot names. */
|
||||
def listBots: List[String] = bots.keys.toList.sorted
|
||||
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package de.nowchess.bot
|
||||
|
||||
enum BotDifficulty:
|
||||
case Easy
|
||||
case Medium
|
||||
case Hard
|
||||
case Expert
|
||||
@@ -0,0 +1,19 @@
|
||||
package de.nowchess.bot
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
object BotMoveRepetition:
|
||||
|
||||
private val maxConsecutiveMoves = 3
|
||||
|
||||
def blockedMoves(context: GameContext): Set[Move] = repeatedMove(context).toSet
|
||||
|
||||
def repeatedMove(context: GameContext): Option[Move] =
|
||||
context.moves.takeRight(maxConsecutiveMoves) match
|
||||
case first :: second :: third :: Nil if first == second && second == third => Some(first)
|
||||
case _ => None
|
||||
|
||||
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
|
||||
val blocked = blockedMoves(context)
|
||||
moves.filterNot(blocked.contains)
|
||||
@@ -0,0 +1,11 @@
|
||||
package de.nowchess.bot
|
||||
|
||||
object Config:
|
||||
|
||||
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this, the move is vetoed (not
|
||||
* accepted as a suggestion).
|
||||
*/
|
||||
val VETO_THRESHOLD: Int = 150
|
||||
|
||||
/** Time budget per move for iterative deepening (milliseconds). */
|
||||
val TIME_LIMIT_MS: Long = 2000L
|
||||
@@ -0,0 +1,29 @@
|
||||
package de.nowchess.bot.ai
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
trait Evaluation:
|
||||
|
||||
def CHECKMATE_SCORE: Int
|
||||
def DRAW_SCORE: Int
|
||||
|
||||
def evaluate(context: GameContext): Int
|
||||
|
||||
// ── Accumulator hooks ─────────────────────────────────────────────────────
|
||||
// Default implementations fall back to full re-evaluation each call.
|
||||
// Override in NNUE-capable evaluators for incremental L1 speedup.
|
||||
|
||||
/** Initialise the accumulator for the root position at ply 0. */
|
||||
def initAccumulator(context: GameContext): Unit = ()
|
||||
|
||||
/** Copy parent ply's accumulator to childPly without move deltas (null-move). */
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit = ()
|
||||
|
||||
/** Derive childPly's accumulator from parentPly by applying move deltas. */
|
||||
def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = ()
|
||||
|
||||
/** Evaluate from the pre-computed accumulator at ply, using hash for the eval cache. Falls back to full evaluate when
|
||||
* not overridden.
|
||||
*/
|
||||
def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int = evaluate(context)
|
||||
@@ -0,0 +1,29 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.api.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
final class ClassicalBot(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
) extends Bot:
|
||||
|
||||
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationClassic)
|
||||
private val TIME_BUDGET_MS = 1000L
|
||||
|
||||
override val name: String = s"ClassicalBot(${difficulty.toString})"
|
||||
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book
|
||||
.flatMap(_.probe(context))
|
||||
.filterNot(blockedMoves.contains)
|
||||
.orElse(search.bestMoveWithTime(context, TIME_BUDGET_MS, blockedMoves))
|
||||
@@ -0,0 +1,43 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.api.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
|
||||
import de.nowchess.bot.util.PolyglotBook
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition, Config}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
final class HybridBot(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
nnueEvaluation: Evaluation = EvaluationNNUE,
|
||||
classicalEvaluation: Evaluation = EvaluationClassic,
|
||||
vetoReporter: String => Unit = println(_),
|
||||
) extends Bot:
|
||||
|
||||
private val search = AlphaBetaSearch(rules, TranspositionTable(), classicalEvaluation)
|
||||
|
||||
override val name: String = s"HybridBot(${difficulty.toString})"
|
||||
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book.flatMap(_.probe(context)).filterNot(blockedMoves.contains).orElse(searchWithVeto(context, blockedMoves))
|
||||
|
||||
private def searchWithVeto(context: GameContext, blockedMoves: Set[Move]): Option[Move] =
|
||||
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS, blockedMoves).map { move =>
|
||||
val next = rules.applyMove(context)(move)
|
||||
val staticNnue = nnueEvaluation.evaluate(next)
|
||||
val classical = classicalEvaluation.evaluate(next)
|
||||
val diff = (classical - staticNnue).abs
|
||||
if diff > Config.VETO_THRESHOLD then
|
||||
vetoReporter(
|
||||
f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)",
|
||||
)
|
||||
move
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package de.nowchess.bot.bots
|
||||
|
||||
import de.nowchess.api.bot.Bot
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
|
||||
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
|
||||
final class NNUEBot(
|
||||
difficulty: BotDifficulty,
|
||||
rules: RuleSet = DefaultRules,
|
||||
book: Option[PolyglotBook] = None,
|
||||
) extends Bot:
|
||||
|
||||
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE)
|
||||
|
||||
override val name: String = s"NNUEBot(${difficulty.toString})"
|
||||
|
||||
override def nextMove(context: GameContext): Option[Move] =
|
||||
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||
book
|
||||
.flatMap(_.probe(context))
|
||||
.filterNot(blockedMoves.contains)
|
||||
.orElse {
|
||||
val moves = BotMoveRepetition.filterAllowed(context, rules.allLegalMoves(context))
|
||||
if moves.isEmpty then None
|
||||
else
|
||||
val scored = batchEvaluateRoot(context, moves)
|
||||
val bestMove = scored.maxBy(_._2)._1
|
||||
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove))
|
||||
}
|
||||
|
||||
/** Evaluate all root moves shallowly via incremental NNUE accumulator updates. Returns (move, score) pairs with score
|
||||
* from the root player's perspective.
|
||||
*/
|
||||
private def batchEvaluateRoot(context: GameContext, moves: List[Move]): List[(Move, Int)] =
|
||||
EvaluationNNUE.initAccumulator(context)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
moves.map { move =>
|
||||
val child = rules.applyMove(context)(move)
|
||||
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
|
||||
EvaluationNNUE.pushAccumulator(1, move, context, child)
|
||||
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
|
||||
(move, score)
|
||||
}
|
||||
|
||||
/** Allocate more time for complex positions; less when one move clearly dominates. */
|
||||
private def allocateTime(scored: List[(Move, Int)]): Long =
|
||||
val moveCount = scored.length
|
||||
if moveCount > 30 then 1500L
|
||||
else if moveCount < 5 then 500L
|
||||
else
|
||||
val scores = scored.map(_._2)
|
||||
val best = scores.max
|
||||
val second = scores.filter(_ < best).maxOption.getOrElse(best)
|
||||
if best - second > 200 then 600L else 1000L
|
||||
@@ -0,0 +1,361 @@
|
||||
package de.nowchess.bot.bots.classic
|
||||
|
||||
import de.nowchess.api.board.{Color, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationClassic extends Evaluation:
|
||||
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
|
||||
// Material values in centipawns (indexed by PieceType.ordinal: Pawn=0, Knight=1, Bishop=2, Rook=3, Queen=4, King=5)
|
||||
private val mgMaterial = Array(100, 325, 335, 500, 900, 20_000)
|
||||
private val egMaterial = Array(110, 310, 310, 530, 1_000, 20_000)
|
||||
|
||||
private val TEMPO_BONUS: Int = 10
|
||||
|
||||
// Piece-square tables (Simplified Evaluation Function, Michniewski)
|
||||
// Indexed by squareIndex = rank.ordinal * 8 + file.ordinal
|
||||
// White's perspective: rank 0 = home (r1), rank 7 = back rank (r8)
|
||||
// Black is vertically mirrored
|
||||
|
||||
private val mgPawnTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 10, 10, 20, 30, 30, 20, 10, 10, 5, 5, 10, 25, 25, 10, 5, 5,
|
||||
0, 0, 0, 20, 20, 0, 0, 0, 5, -5, -10, 0, 0, -10, -5, 5, 5, 10, 10, -20, -20, 10, 10, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val egPawnTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 70, 70, 70, 70, 70, 70, 70, 70, 40, 40, 40, 40, 40, 40, 40, 40, 30, 30, 30, 30, 30, 30, 30,
|
||||
30, 20, 20, 20, 20, 20, 20, 20, 20, 10, 10, 10, 10, 10, 10, 10, 10, 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val mgKnightTable: Array[Int] = Array(
|
||||
-50, -40, -30, -30, -30, -30, -40, -50, -40, -20, 0, 0, 0, 0, -20, -40, -30, 0, 10, 15, 15, 10, 0, -30, -30, 5, 15,
|
||||
20, 20, 15, 5, -30, -30, 0, 15, 20, 20, 15, 0, -30, -30, 5, 10, 15, 15, 10, 5, -30, -40, -20, 0, 5, 5, 0, -20, -40,
|
||||
-50, -40, -30, -30, -30, -30, -40, -50,
|
||||
)
|
||||
|
||||
private val egKnightTable: Array[Int] = Array(
|
||||
-30, -20, -10, -10, -10, -10, -20, -30, -20, 0, 5, 5, 5, 5, 0, -20, -10, 5, 15, 20, 20, 15, 5, -10, -10, 5, 20, 25,
|
||||
25, 20, 5, -10, -10, 5, 20, 25, 25, 20, 5, -10, -10, 5, 15, 20, 20, 15, 5, -10, -20, 0, 5, 5, 5, 5, 0, -20, -30,
|
||||
-20, -10, -10, -10, -10, -20, -30,
|
||||
)
|
||||
|
||||
private val mgBishopTable: Array[Int] = Array(
|
||||
-20, -10, -10, -10, -10, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 10, 10, 5, 0, -10, -10, 5, 5, 10, 10,
|
||||
5, 5, -10, -10, 0, 10, 10, 10, 10, 0, -10, -10, 10, 10, 10, 10, 10, 10, -10, -10, 5, 0, 0, 0, 0, 5, -10, -20, -10,
|
||||
-10, -10, -10, -10, -10, -20,
|
||||
)
|
||||
|
||||
private val egBishopTable: Array[Int] = Array(
|
||||
-20, -10, -5, -5, -5, -5, -10, -20, -10, 0, 5, 5, 5, 5, 0, -10, -5, 5, 10, 10, 10, 10, 5, -5, -5, 5, 10, 15, 15, 10,
|
||||
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -5, 5, 10, 10, 10, 10, 5, -5, -10, 0, 5, 5, 5, 5, 0, -10, -20, -10, -5, -5, -5,
|
||||
-5, -10, -20,
|
||||
)
|
||||
|
||||
private val mgRookTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val egRookTable: Array[Int] = Array(
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||
)
|
||||
|
||||
private val mgQueenTable: Array[Int] = Array(
|
||||
-20, -10, -10, -5, -5, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 5, 5, 5, 0, -10, -5, 0, 5, 5, 5, 5, 0,
|
||||
-5, 0, 0, 5, 5, 5, 5, 0, -5, -10, 5, 5, 5, 5, 5, 0, -10, -10, 0, 5, 0, 0, 0, 0, -10, -20, -10, -10, -5, -5, -10,
|
||||
-10, -20,
|
||||
)
|
||||
|
||||
private val egQueenTable: Array[Int] = Array(
|
||||
-15, -10, -8, -5, -5, -8, -10, -15, -10, 0, 3, 5, 5, 3, 0, -10, -8, 3, 10, 10, 10, 10, 3, -8, -5, 5, 10, 15, 15, 10,
|
||||
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -8, 3, 10, 10, 10, 10, 3, -8, -10, 0, 3, 5, 5, 3, 0, -10, -15, -10, -8, -5, -5,
|
||||
-8, -10, -15,
|
||||
)
|
||||
|
||||
private val mgKingTable: Array[Int] = Array(
|
||||
-30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40,
|
||||
-30, -30, -40, -40, -50, -50, -40, -40, -30, -20, -30, -30, -40, -40, -30, -30, -20, -10, -20, -20, -20, -20, -20,
|
||||
-20, -10, 20, 20, 0, 0, 0, 0, 20, 20, 20, 30, 10, 0, 0, 10, 30, 20,
|
||||
)
|
||||
|
||||
private val egKingTable: Array[Int] = Array(
|
||||
-50, -40, -30, -20, -20, -30, -40, -50, -30, -20, -10, 0, 0, -10, -20, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30,
|
||||
-10, 30, 40, 40, 30, -10, -30, -30, -10, 30, 40, 40, 30, -10, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30, -30, 0,
|
||||
0, 0, 0, -30, -30, -50, -30, -30, -30, -30, -30, -30, -50,
|
||||
)
|
||||
|
||||
private val phaseWeight: Map[PieceType, Int] = Map(
|
||||
PieceType.Knight -> 1,
|
||||
PieceType.Bishop -> 1,
|
||||
PieceType.Rook -> 2,
|
||||
PieceType.Queen -> 4,
|
||||
)
|
||||
private val maxPhase = 24 // 4*4 + 4*2 + 4*1 + 4*1
|
||||
|
||||
private val passedPawnBonus: Array[Int] = Array(0, 5, 10, 20, 35, 60, 100, 0)
|
||||
private val egPassedPawnBonus: Array[Int] = Array(0, 20, 40, 80, 150, 250, 400, 0)
|
||||
|
||||
// Pawn structure penalties
|
||||
private val doubledMg = -10
|
||||
private val doubledEg = -25
|
||||
private val isolatedMg = -15
|
||||
private val isolatedEg = -20
|
||||
|
||||
// Mobility weights: centipawns per reachable square (indexed by PieceType.ordinal)
|
||||
private val mobilityMg = Array(0, 4, 3, 2, 1, 0, 0)
|
||||
private val mobilityEg = Array(0, 4, 3, 4, 2, 0, 0)
|
||||
|
||||
// Direction offsets for sliding pieces
|
||||
private val diagonals = List((-1, -1), (-1, 1), (1, -1), (1, 1))
|
||||
private val orthogonals = List((-1, 0), (1, 0), (0, -1), (0, 1))
|
||||
private val knightOffsets = List((-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1))
|
||||
|
||||
// Rook and bishop bonuses
|
||||
private val bishopPairMg = 50
|
||||
private val bishopPairEg = 70
|
||||
private val rookOn7thMg = 20
|
||||
private val rookOn7thEg = 10
|
||||
|
||||
/** Evaluate the position from the perspective of context.turn. Positive = good for context.turn.
|
||||
*/
|
||||
def evaluate(context: GameContext): Int =
|
||||
val phase = gamePhase(context.board)
|
||||
val isEg = isEndgame(phase)
|
||||
val material = materialAndPositional(context, phase)
|
||||
val structure = pawnStructure(context, phase)
|
||||
val mobility = mobilityScore(context, phase)
|
||||
val rookBishop = rookAndBishopBonuses(context, phase)
|
||||
val bonuses = positionalBonuses(context, phase, isEg)
|
||||
val egBonuses = if isEg then endgameBonus(context) else 0
|
||||
material + structure + mobility + rookBishop + bonuses + egBonuses + TEMPO_BONUS
|
||||
|
||||
private def gamePhase(board: de.nowchess.api.board.Board): Int =
|
||||
val phase = board.pieces.values.foldLeft(0) { (acc, piece) =>
|
||||
acc + phaseWeight.getOrElse(piece.pieceType, 0)
|
||||
}
|
||||
math.min(phase, maxPhase)
|
||||
|
||||
private def isEndgame(phase: Int): Boolean =
|
||||
phase < 8 // Significantly reduced material indicates endgame
|
||||
|
||||
private def taper(mg: Int, eg: Int, phase: Int): Int =
|
||||
(mg * phase + eg * (maxPhase - phase)) / maxPhase
|
||||
|
||||
private def materialAndPositional(context: GameContext, phase: Int): Int =
|
||||
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (square, piece)) =>
|
||||
val (psqMg, psqEg) = squareBonus(piece.pieceType, piece.color, square)
|
||||
val pieceMg = mgMaterial(piece.pieceType.ordinal) + psqMg
|
||||
val pieceEg = egMaterial(piece.pieceType.ordinal) + psqEg
|
||||
val sign = if piece.color == context.turn then 1 else -1
|
||||
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||
}
|
||||
taper(mg, eg, phase)
|
||||
|
||||
private def squareBonus(pieceType: PieceType, color: Color, sq: Square): (Int, Int) =
|
||||
val rankIdx = if color == Color.White then sq.rank.ordinal else 7 - sq.rank.ordinal
|
||||
val fileIdx = sq.file.ordinal
|
||||
val squareIdx = rankIdx * 8 + fileIdx
|
||||
|
||||
pieceType match
|
||||
case PieceType.Pawn => (mgPawnTable(squareIdx), egPawnTable(squareIdx))
|
||||
case PieceType.Knight => (mgKnightTable(squareIdx), egKnightTable(squareIdx))
|
||||
case PieceType.Bishop => (mgBishopTable(squareIdx), egBishopTable(squareIdx))
|
||||
case PieceType.Rook => (mgRookTable(squareIdx), egRookTable(squareIdx))
|
||||
case PieceType.Queen => (mgQueenTable(squareIdx), egQueenTable(squareIdx))
|
||||
case PieceType.King => (mgKingTable(squareIdx), egKingTable(squareIdx))
|
||||
|
||||
private def pawnStructure(context: GameContext, phase: Int): Int =
|
||||
val friendlyPawns = context.board.pieces.filter((_, p) => p.color == context.turn && p.pieceType == PieceType.Pawn)
|
||||
val enemyPawns = context.board.pieces.filter((_, p) => p.color != context.turn && p.pieceType == PieceType.Pawn)
|
||||
|
||||
val friendlyByFile = friendlyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||
val enemyByFile = enemyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||
|
||||
val (fMg, fEg) = structureScore(friendlyByFile)
|
||||
val (eMg, eEg) = structureScore(enemyByFile)
|
||||
taper(fMg - eMg, fEg - eEg, phase)
|
||||
|
||||
private def structureScore(byFile: Map[Int, Iterable[Int]]): (Int, Int) =
|
||||
byFile.foldLeft((0, 0)) { case ((mg, eg), (file, ranks)) =>
|
||||
val doubled = (ranks.size - 1).max(0)
|
||||
val hasAdjacent = (file - 1 to file + 1).filter(f => f >= 0 && f < 8 && f != file).exists(byFile.contains)
|
||||
val isolated = if !hasAdjacent then ranks.size else 0
|
||||
(mg + doubled * doubledMg + isolated * isolatedMg, eg + doubled * doubledEg + isolated * isolatedEg)
|
||||
}
|
||||
|
||||
private def positionalBonuses(context: GameContext, phase: Int, isEg: Boolean): Int =
|
||||
context.board.pieces.foldLeft(0) { case (score, (sq, piece)) =>
|
||||
val bonus = piece.pieceType match
|
||||
case PieceType.Pawn =>
|
||||
if isPassedPawn(context.board, sq, piece.color) then
|
||||
if isEg then egPassedPawnBonus(sq.rank.ordinal) else passedPawnBonus(sq.rank.ordinal)
|
||||
else 0
|
||||
case PieceType.Rook => rookOpenFileBonus(context.board, sq, piece.color)
|
||||
case PieceType.King => kingShieldBonus(context.board, sq, piece.color, phase)
|
||||
case _ => 0
|
||||
if piece.color == context.turn then score + bonus else score - bonus
|
||||
}
|
||||
|
||||
private def isPassedPawn(board: de.nowchess.api.board.Board, sq: Square, color: Color): Boolean =
|
||||
val enemyColor = color.opposite
|
||||
val pawnRank = sq.rank.ordinal
|
||||
val fileRange = (sq.file.ordinal - 1 to sq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||
val rankCheck = if color == Color.White then (r: Int) => r > pawnRank else (r: Int) => r < pawnRank
|
||||
|
||||
board.pieces.forall { (enemySq, enemyPiece) =>
|
||||
!(enemyPiece.color == enemyColor &&
|
||||
enemyPiece.pieceType == PieceType.Pawn &&
|
||||
fileRange.contains(enemySq.file.ordinal) &&
|
||||
rankCheck(enemySq.rank.ordinal))
|
||||
}
|
||||
|
||||
private def rookOpenFileBonus(board: de.nowchess.api.board.Board, rookSq: Square, color: Color): Int =
|
||||
val hasFriendlyPawn = board.pieces.exists { (sq, piece) =>
|
||||
piece.color == color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||
}
|
||||
val hasEnemyPawn = board.pieces.exists { (sq, piece) =>
|
||||
piece.color != color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||
}
|
||||
if !hasFriendlyPawn && !hasEnemyPawn then 20 // open file
|
||||
else if !hasFriendlyPawn then 10 // semi-open file
|
||||
else 0
|
||||
|
||||
private def kingShieldBonus(board: de.nowchess.api.board.Board, kingSq: Square, color: Color, phase: Int): Int =
|
||||
val shieldRankDelta = if color == Color.White then 1 else -1
|
||||
val shieldFiles = (kingSq.file.ordinal - 1 to kingSq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||
val shieldRank = kingSq.rank.ordinal + shieldRankDelta
|
||||
|
||||
if shieldRank < 0 || shieldRank > 7 then 0
|
||||
else
|
||||
val rawBonus = board.pieces.count { (sq, piece) =>
|
||||
piece.color == color &&
|
||||
piece.pieceType == PieceType.Pawn &&
|
||||
shieldFiles.contains(sq.file.ordinal) &&
|
||||
sq.rank.ordinal == shieldRank
|
||||
} * 10
|
||||
(rawBonus * phase) / maxPhase
|
||||
|
||||
private def slidingCount(
|
||||
sq: Square,
|
||||
board: de.nowchess.api.board.Board,
|
||||
color: Color,
|
||||
directions: List[(Int, Int)],
|
||||
): Int =
|
||||
directions.foldLeft(0) { case (total, (fileDelta, rankDelta)) =>
|
||||
@scala.annotation.tailrec
|
||||
def countRay(current: Option[Square], acc: Int): Int =
|
||||
current match
|
||||
case None => acc
|
||||
case Some(target) =>
|
||||
board.pieceAt(target) match
|
||||
case Some(piece) if piece.color == color => acc
|
||||
case Some(_) => acc + 1
|
||||
case None => countRay(target.offset(fileDelta, rankDelta), acc + 1)
|
||||
total + countRay(sq.offset(fileDelta, rankDelta), 0)
|
||||
}
|
||||
|
||||
private def knightCount(sq: Square, board: de.nowchess.api.board.Board, color: Color): Int =
|
||||
knightOffsets.count { case (fileDelta, rankDelta) =>
|
||||
sq.offset(fileDelta, rankDelta).forall { target =>
|
||||
board.pieceAt(target).forall(_.color != color)
|
||||
}
|
||||
}
|
||||
|
||||
private def mobilityScore(context: GameContext, phase: Int): Int =
|
||||
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||
val count = piece.pieceType match
|
||||
case PieceType.Knight => knightCount(sq, context.board, piece.color)
|
||||
case PieceType.Bishop => slidingCount(sq, context.board, piece.color, diagonals)
|
||||
case PieceType.Rook => slidingCount(sq, context.board, piece.color, orthogonals)
|
||||
case PieceType.Queen => slidingCount(sq, context.board, piece.color, diagonals ++ orthogonals)
|
||||
case _ => 0
|
||||
val pieceMg = count * mobilityMg(piece.pieceType.ordinal)
|
||||
val pieceEg = count * mobilityEg(piece.pieceType.ordinal)
|
||||
val sign = if piece.color == context.turn then 1 else -1
|
||||
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||
}
|
||||
taper(mg, eg, phase)
|
||||
|
||||
private def rookAndBishopBonuses(context: GameContext, phase: Int): Int =
|
||||
val (baseMg, baseEg) = bishopPairBase(context)
|
||||
val (rookMg, rookEg) = rookOn7thDelta(context)
|
||||
taper(baseMg + rookMg, baseEg + rookEg, phase)
|
||||
|
||||
private def bishopPairBase(context: GameContext): (Int, Int) =
|
||||
val friendlyHasPair = hasBishopPair(context, context.turn)
|
||||
val enemyHasPair = hasBishopPair(context, context.turn.opposite)
|
||||
val mg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairMg)
|
||||
val eg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairEg)
|
||||
(mg, eg)
|
||||
|
||||
private def hasBishopPair(context: GameContext, color: Color): Boolean =
|
||||
val bishopSquares = context.board.pieces.collect {
|
||||
case (sq, piece) if piece.color == color && piece.pieceType == PieceType.Bishop => sq
|
||||
}
|
||||
bishopSquares.exists(isEvenSquare) && bishopSquares.exists(sq => !isEvenSquare(sq))
|
||||
|
||||
private def isEvenSquare(square: Square): Boolean =
|
||||
(square.file.ordinal + square.rank.ordinal) % 2 == 0
|
||||
|
||||
private def pairDelta(friendlyHasPair: Boolean, enemyHasPair: Boolean, bonus: Int): Int =
|
||||
(if friendlyHasPair then bonus else 0) - (if enemyHasPair then bonus else 0)
|
||||
|
||||
private def rookOn7thDelta(context: GameContext): (Int, Int) =
|
||||
context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||
rookOn7thContribution(piece, sq, context.turn).fold((mg, eg)) { case (dMg, dEg) =>
|
||||
(mg + dMg, eg + dEg)
|
||||
}
|
||||
}
|
||||
|
||||
private def rookOn7thContribution(piece: de.nowchess.api.board.Piece, sq: Square, turn: Color): Option[(Int, Int)] =
|
||||
Option.when(piece.pieceType == PieceType.Rook && isRookOn7th(piece.color, sq)) {
|
||||
val sign = if piece.color == turn then 1 else -1
|
||||
(sign * rookOn7thMg, sign * rookOn7thEg)
|
||||
}
|
||||
|
||||
private def isRookOn7th(color: Color, sq: Square): Boolean =
|
||||
if color == Color.White then sq.rank.ordinal == 6 else sq.rank.ordinal == 1
|
||||
|
||||
private def endgameBonus(context: GameContext): Int =
|
||||
val friendlyKing = context.board.pieces.find((_, p) => p.color == context.turn && p.pieceType == PieceType.King)
|
||||
val enemyKing = context.board.pieces.find((_, p) => p.color != context.turn && p.pieceType == PieceType.King)
|
||||
|
||||
val kingCentralBonus =
|
||||
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
|
||||
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
|
||||
|
||||
val friendlyMaterial = materialCount(context, context.turn)
|
||||
val enemyMaterial = materialCount(context, context.turn.opposite)
|
||||
val edgeBonus =
|
||||
if friendlyMaterial > enemyMaterial then enemyKing.fold(0)((kSq, _) => (7 - kingEdgeDistance(kSq)) * 10)
|
||||
else 0
|
||||
|
||||
kingCentralBonus + edgeBonus
|
||||
|
||||
private def kingCentralizationDistance(sq: Square): Int =
|
||||
val fileFromCenter = (sq.file.ordinal - 3.5).abs.toInt
|
||||
val rankFromCenter = (sq.rank.ordinal - 3.5).abs.toInt
|
||||
math.max(fileFromCenter, rankFromCenter)
|
||||
|
||||
private def kingEdgeDistance(sq: Square): Int =
|
||||
val fileFromEdge = math.min(sq.file.ordinal, 7 - sq.file.ordinal)
|
||||
val rankFromEdge = math.min(sq.rank.ordinal, 7 - sq.rank.ordinal)
|
||||
math.min(fileFromEdge, rankFromEdge)
|
||||
|
||||
private def materialCount(context: GameContext, color: Color): Int =
|
||||
context.board.pieces.foldLeft(0) { case (sum, (_, piece)) =>
|
||||
if piece.color == color then
|
||||
sum + (piece.pieceType match
|
||||
case PieceType.Knight => 300
|
||||
case PieceType.Bishop => 300
|
||||
case PieceType.Rook => 500
|
||||
case PieceType.Queen => 900
|
||||
case PieceType.Pawn => 0
|
||||
case PieceType.King => 0
|
||||
)
|
||||
else sum
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
|
||||
object EvaluationNNUE extends Evaluation:
|
||||
|
||||
private val nnue = NNUE(NbaiLoader.loadDefault())
|
||||
|
||||
val CHECKMATE_SCORE: Int = 10_000_000
|
||||
val DRAW_SCORE: Int = 0
|
||||
|
||||
/** Full-board evaluate — used as fallback and by non-search callers. */
|
||||
def evaluate(context: GameContext): Int = nnue.evaluate(context)
|
||||
|
||||
// ── Accumulator hooks (incremental L1) ───────────────────────────────────
|
||||
|
||||
override def initAccumulator(context: GameContext): Unit =
|
||||
nnue.initAccumulator(context.board)
|
||||
|
||||
override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
nnue.copyAccumulator(parentPly, childPly)
|
||||
|
||||
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
|
||||
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
|
||||
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
|
||||
else nnue.pushAccumulator(childPly, move, parent.board)
|
||||
|
||||
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
|
||||
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
|
||||
@@ -0,0 +1,231 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
|
||||
class NNUE(model: NbaiModel):
|
||||
|
||||
private val featureSize = model.layers(0).inputSize
|
||||
private val accSize = model.layers(0).outputSize
|
||||
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||
|
||||
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||
private val l1WeightsT: Array[Float] =
|
||||
val w = model.weights(0).weights
|
||||
val t = new Array[Float](featureSize * accSize)
|
||||
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
|
||||
t
|
||||
|
||||
// ── Accumulator stack ────────────────────────────────────────────────────
|
||||
|
||||
private val MAX_PLY = 128
|
||||
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
|
||||
|
||||
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
|
||||
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
|
||||
|
||||
// ── Eval cache ───────────────────────────────────────────────────────────
|
||||
|
||||
private val EVAL_CACHE_MASK = (1 << 18) - 1L
|
||||
private val evalCacheHashes = new Array[Long](1 << 18)
|
||||
private val evalCacheScores = new Array[Int](1 << 18)
|
||||
|
||||
// ── Feature helpers ──────────────────────────────────────────────────────
|
||||
|
||||
private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
|
||||
|
||||
private def featureIndex(piece: Piece, sqNum: Int): Int =
|
||||
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
|
||||
|
||||
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
|
||||
|
||||
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||
val offset = featureIdx * accSize
|
||||
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
|
||||
|
||||
// ── Accumulator init ─────────────────────────────────────────────────────
|
||||
|
||||
def initAccumulator(board: Board): Unit =
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// ── Accumulator push (incremental updates) ───────────────────────────────
|
||||
|
||||
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
|
||||
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
|
||||
val l1 = l1Stack(childPly)
|
||||
move.moveType match
|
||||
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||
|
||||
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||
|
||||
def recomputeAccumulator(ply: Int, board: Board): Unit =
|
||||
System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
|
||||
|
||||
def validateAccumulator(ply: Int, board: Board): Boolean =
|
||||
// Compute what L1 should be from scratch
|
||||
val expectedL1 = new Array[Float](accSize)
|
||||
System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
|
||||
for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
|
||||
|
||||
// Compare with actual L1
|
||||
val actual = l1Stack(ply)
|
||||
val maxError =
|
||||
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
|
||||
val error = math.abs(actual(i) - expectedL1(i))
|
||||
math.max(currentMax, error)
|
||||
}
|
||||
|
||||
maxError < 0.001f // Allow small floating-point errors
|
||||
|
||||
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
// Extract source and destination square indices early
|
||||
val fromNum = squareNum(move.from)
|
||||
val toNum = squareNum(move.to)
|
||||
|
||||
// Get the moving piece
|
||||
board.pieceAt(move.from).foreach { mover =>
|
||||
subtractColumn(l1, featureIndex(mover, fromNum))
|
||||
|
||||
// If there's a capture, subtract the captured piece
|
||||
board.pieceAt(move.to).foreach { cap =>
|
||||
subtractColumn(l1, featureIndex(cap, toNum))
|
||||
}
|
||||
|
||||
// Add the piece to its new location
|
||||
addColumn(l1, featureIndex(mover, toNum))
|
||||
}
|
||||
|
||||
private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val capturedSq = Square(move.to.file, move.from.rank)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
|
||||
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
|
||||
}
|
||||
|
||||
private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { king =>
|
||||
val rank = move.from.rank
|
||||
val kingside = move.moveType == MoveType.CastleKingside
|
||||
val (rookFrom, rookTo) =
|
||||
if kingside then (Square(File.H, rank), Square(File.F, rank))
|
||||
else (Square(File.A, rank), Square(File.D, rank))
|
||||
val rook = Piece(king.color, PieceType.Rook)
|
||||
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
|
||||
addColumn(l1, featureIndex(king, squareNum(move.to)))
|
||||
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
|
||||
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
|
||||
}
|
||||
|
||||
private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
|
||||
board.pieceAt(move.from).foreach { pawn =>
|
||||
val toNum = squareNum(move.to)
|
||||
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||
addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
|
||||
}
|
||||
|
||||
private def promotedType(promo: PromotionPiece): PieceType = promo match
|
||||
case PromotionPiece.Knight => PieceType.Knight
|
||||
case PromotionPiece.Bishop => PieceType.Bishop
|
||||
case PromotionPiece.Rook => PieceType.Rook
|
||||
case PromotionPiece.Queen => PieceType.Queen
|
||||
|
||||
// ── Evaluation from accumulator ──────────────────────────────────────────
|
||||
|
||||
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
|
||||
val idx = (hash & EVAL_CACHE_MASK).toInt
|
||||
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
|
||||
else
|
||||
val score = runL2toOutput(l1Stack(ply), turn)
|
||||
evalCacheHashes(idx) = hash
|
||||
evalCacheScores(idx) = score
|
||||
score
|
||||
|
||||
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
|
||||
// For debugging: validate that incremental accumulator matches recomputation
|
||||
if validateAccum && ply > 0 && ply % 10 != 0 then
|
||||
val isValid = validateAccumulator(ply, board)
|
||||
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||
evaluateAtPly(ply, turn, hash)
|
||||
|
||||
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||
val l1ReLU = evalBuffers(0)
|
||||
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||
|
||||
val finalInput =
|
||||
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
|
||||
val lw = model.weights(i)
|
||||
val out = evalBuffers(i)
|
||||
val ld = model.layers(i)
|
||||
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||
out
|
||||
}
|
||||
|
||||
val lastIdx = model.layers.length - 1
|
||||
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||
scoreFromOutput(output, turn)
|
||||
|
||||
private def runDenseReLU(
|
||||
input: Array[Float],
|
||||
inSize: Int,
|
||||
weights: Array[Float],
|
||||
bias: Array[Float],
|
||||
output: Array[Float],
|
||||
outSize: Int,
|
||||
): Unit =
|
||||
for i <- 0 until outSize do
|
||||
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
|
||||
output(i) = if sum > 0f then sum else 0f
|
||||
|
||||
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
|
||||
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
|
||||
|
||||
private def scoreFromOutput(output: Float, turn: Color): Int =
|
||||
val cp =
|
||||
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
|
||||
else
|
||||
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||
(300f * atanh).toInt
|
||||
val cpFromTurn = if turn == Color.Black then -cp else cp
|
||||
math.max(-20000, math.min(20000, cpFromTurn))
|
||||
|
||||
// ── Legacy full-board evaluate ────────────────────────────────────────────
|
||||
|
||||
private val legacyL1 = new Array[Float](accSize)
|
||||
|
||||
def evaluate(context: GameContext): Int =
|
||||
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
|
||||
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
|
||||
runL2toOutput(legacyL1, context.turn)
|
||||
|
||||
def benchmark(): Unit =
|
||||
val context = GameContext.initial
|
||||
val iterations = 1_000_000
|
||||
for _ <- 0 until 10000 do evaluate(context)
|
||||
val startNanos = System.nanoTime()
|
||||
for _ <- 0 until iterations do evaluate(context)
|
||||
val endNanos = System.nanoTime()
|
||||
val totalNanos = endNanos - startNanos
|
||||
val nanosPerEval = totalNanos.toDouble / iterations
|
||||
println()
|
||||
println("=" * 60)
|
||||
println("NNUE BENCHMARK RESULTS")
|
||||
println("=" * 60)
|
||||
println(f"Iterations: $iterations%,d")
|
||||
println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
|
||||
println(f"ns/eval: $nanosPerEval%.2f ns")
|
||||
println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
|
||||
println("=" * 60)
|
||||
println()
|
||||
@@ -0,0 +1,52 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.InputStream
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiLoader:
|
||||
|
||||
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
|
||||
val MAGIC: Int = 0x4942_414e
|
||||
|
||||
def load(stream: InputStream): NbaiModel =
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkHeader(buf)
|
||||
val metadata = readMetadata(buf)
|
||||
val descs = readLayerDescriptors(buf)
|
||||
val weights = descs.map(_ => readLayerWeights(buf))
|
||||
NbaiModel(metadata, descs, weights)
|
||||
|
||||
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
|
||||
def loadDefault(): NbaiModel =
|
||||
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
|
||||
case Some(s) =>
|
||||
try load(s)
|
||||
finally s.close()
|
||||
case None => NbaiMigrator.migrateFromBin()
|
||||
|
||||
private def checkHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
|
||||
val version = buf.getShort() & 0xffff
|
||||
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
|
||||
|
||||
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
|
||||
val bytes = new Array[Byte](buf.getInt())
|
||||
buf.get(bytes)
|
||||
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
|
||||
|
||||
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
|
||||
Array.tabulate(buf.getShort() & 0xffff) { _ =>
|
||||
val nameBytes = new Array[Byte](buf.get() & 0xff)
|
||||
buf.get(nameBytes)
|
||||
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
|
||||
}
|
||||
|
||||
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readFloats(buf), readFloats(buf))
|
||||
|
||||
private def readFloats(buf: ByteBuffer): Array[Float] =
|
||||
val arr = new Array[Float](buf.getInt())
|
||||
for i <- arr.indices do arr(i) = buf.getFloat()
|
||||
arr
|
||||
@@ -0,0 +1,43 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
|
||||
object NbaiMigrator:
|
||||
|
||||
private val BinMagic = 0x4555_4e4e
|
||||
private val BinVersion = 1
|
||||
|
||||
private val DefaultLayers: Array[LayerDescriptor] = Array(
|
||||
LayerDescriptor("relu", 768, 1536),
|
||||
LayerDescriptor("relu", 1536, 1024),
|
||||
LayerDescriptor("relu", 1024, 512),
|
||||
LayerDescriptor("relu", 512, 256),
|
||||
LayerDescriptor("linear", 256, 1),
|
||||
)
|
||||
|
||||
private val UnknownMetadata: NbaiMetadata =
|
||||
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
|
||||
|
||||
def migrateFromBin(): NbaiModel =
|
||||
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
|
||||
try
|
||||
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||
checkBinHeader(buf)
|
||||
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
|
||||
NbaiModel(UnknownMetadata, DefaultLayers, weights)
|
||||
finally stream.close()
|
||||
|
||||
private def checkBinHeader(buf: ByteBuffer): Unit =
|
||||
val magic = buf.getInt()
|
||||
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
|
||||
val version = buf.getInt()
|
||||
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
|
||||
|
||||
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||
LayerWeights(readBinTensor(buf), readBinTensor(buf))
|
||||
|
||||
private def readBinTensor(buf: ByteBuffer): Array[Float] =
|
||||
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
|
||||
Array.tabulate(shape.product)(_ => buf.getFloat())
|
||||
@@ -0,0 +1,45 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
/** Descriptor for a single dense layer stored in a .nbai file. */
|
||||
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
|
||||
|
||||
/** Training metadata embedded in every .nbai file. */
|
||||
case class NbaiMetadata(
|
||||
trainedBy: String,
|
||||
trainedAt: String,
|
||||
trainingDataCount: Long,
|
||||
valLoss: Double,
|
||||
trainLoss: Double,
|
||||
):
|
||||
def toJson: String =
|
||||
s"""{
|
||||
| "trainedBy": "$trainedBy",
|
||||
| "trainedAt": "$trainedAt",
|
||||
| "trainingDataCount": $trainingDataCount,
|
||||
| "valLoss": $valLoss,
|
||||
| "trainLoss": $trainLoss
|
||||
|}""".stripMargin
|
||||
|
||||
object NbaiMetadata:
|
||||
def fromJson(json: String): NbaiMetadata =
|
||||
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
|
||||
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
|
||||
NbaiMetadata(
|
||||
str("trainedBy"),
|
||||
str("trainedAt"),
|
||||
num("trainingDataCount").toLong,
|
||||
num("valLoss").toDouble,
|
||||
num("trainLoss").toDouble,
|
||||
)
|
||||
|
||||
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
|
||||
case class LayerWeights(weights: Array[Float], bias: Array[Float])
|
||||
|
||||
/** A fully deserialized .nbai model ready to initialize NNUE. */
|
||||
case class NbaiModel(
|
||||
metadata: NbaiMetadata,
|
||||
layers: Array[LayerDescriptor],
|
||||
weights: Array[LayerWeights],
|
||||
):
|
||||
require(layers.length == weights.length, "Layer count must match weight count")
|
||||
require(layers.length >= 2, "Model must have at least 2 layers")
|
||||
@@ -0,0 +1,51 @@
|
||||
package de.nowchess.bot.bots.nnue
|
||||
|
||||
import java.io.{ByteArrayOutputStream, OutputStream}
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
object NbaiWriter:
|
||||
|
||||
def write(model: NbaiModel, out: OutputStream): Unit =
|
||||
val acc = new ByteArrayOutputStream()
|
||||
writeHeader(acc)
|
||||
writeMetadata(acc, model.metadata)
|
||||
writeLayerDescriptors(acc, model.layers)
|
||||
model.weights.foreach(lw => writeLayerWeights(acc, lw))
|
||||
out.write(acc.toByteArray)
|
||||
|
||||
private def writeHeader(out: ByteArrayOutputStream): Unit =
|
||||
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(NbaiLoader.MAGIC)
|
||||
buf.putShort(1.toShort)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
|
||||
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
|
||||
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(json.length)
|
||||
buf.put(json)
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
|
||||
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
|
||||
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
|
||||
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putShort(layers.length.toShort)
|
||||
layers.zip(nameBytes).foreach { (l, nb) =>
|
||||
buf.put(nb.length.toByte)
|
||||
buf.put(nb)
|
||||
buf.putInt(l.inputSize)
|
||||
buf.putInt(l.outputSize)
|
||||
}
|
||||
out.write(buf.array())
|
||||
|
||||
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
|
||||
writeFloats(out, lw.weights)
|
||||
writeFloats(out, lw.bias)
|
||||
|
||||
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
|
||||
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||
buf.putInt(floats.length)
|
||||
floats.foreach(buf.putFloat)
|
||||
out.write(buf.array())
|
||||
@@ -0,0 +1,419 @@
|
||||
package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.board.PieceType
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType}
|
||||
import de.nowchess.bot.ai.Evaluation
|
||||
import de.nowchess.bot.util.ZobristHash
|
||||
import de.nowchess.rules.RuleSet
|
||||
import de.nowchess.rules.sets.DefaultRules
|
||||
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
|
||||
|
||||
final class AlphaBetaSearch(
|
||||
rules: RuleSet = DefaultRules,
|
||||
tt: TranspositionTable = TranspositionTable(),
|
||||
weights: Evaluation,
|
||||
numThreads: Int = Runtime.getRuntime.availableProcessors,
|
||||
):
|
||||
|
||||
private val INF = Int.MaxValue / 2
|
||||
private val MAX_QUIESCENCE_PLY = 64
|
||||
private val NULL_MOVE_R = 2
|
||||
private val ASPIRATION_DELTA = 50
|
||||
private val ASPIRATION_DELTA_MAX = 150
|
||||
private val TIME_CHECK_FREQUENCY = 1000
|
||||
private val FUTILITY_MARGIN = 100
|
||||
private val CHECK_EXTENSION = 1
|
||||
|
||||
private val timeStartMs = AtomicLong(0L)
|
||||
private val timeLimitMs = AtomicLong(0L)
|
||||
private val nodeCount = AtomicInteger(0)
|
||||
private val ordering = MoveOrdering.OrderingContext()
|
||||
|
||||
private final case class QuiescenceNode(
|
||||
context: GameContext,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
)
|
||||
|
||||
/** Return the best move for the side to move, searching to maxDepth plies. Uses iterative deepening with aspiration
|
||||
* windows.
|
||||
*/
|
||||
def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
|
||||
bestMove(context, maxDepth, Set.empty)
|
||||
|
||||
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
weights.initAccumulator(context)
|
||||
timeStartMs.set(System.currentTimeMillis)
|
||||
timeLimitMs.set(Long.MaxValue / 4)
|
||||
nodeCount.set(0)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
(1 to maxDepth)
|
||||
.foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(
|
||||
context,
|
||||
depth,
|
||||
alpha,
|
||||
beta,
|
||||
ASPIRATION_DELTA,
|
||||
rootHash,
|
||||
excludedRootMoves,
|
||||
)
|
||||
(move.orElse(bestSoFar), score)
|
||||
}
|
||||
._1
|
||||
|
||||
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
|
||||
* runs out.
|
||||
*/
|
||||
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
|
||||
bestMoveWithTime(context, timeBudgetMs, Set.empty)
|
||||
|
||||
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] =
|
||||
tt.clear()
|
||||
ordering.clear()
|
||||
weights.initAccumulator(context)
|
||||
timeStartMs.set(System.currentTimeMillis)
|
||||
timeLimitMs.set(timeBudgetMs)
|
||||
nodeCount.set(0)
|
||||
val rootHash = ZobristHash.hash(context)
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
|
||||
if isOutOfTime then bestSoFar
|
||||
else
|
||||
val (alpha, beta) =
|
||||
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||
val (score, move) = searchWithAspiration(
|
||||
context,
|
||||
depth,
|
||||
alpha,
|
||||
beta,
|
||||
ASPIRATION_DELTA,
|
||||
rootHash,
|
||||
excludedRootMoves,
|
||||
)
|
||||
loop(move.orElse(bestSoFar), score, depth + 1)
|
||||
|
||||
loop(None, 0, 1)
|
||||
|
||||
private def isOutOfTime: Boolean =
|
||||
System.currentTimeMillis - timeStartMs.get >= timeLimitMs.get
|
||||
|
||||
private def searchWithAspiration(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
initialWindow: Int,
|
||||
rootHash: Long,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val state = SearchState(rootHash, Map(rootHash -> 1))
|
||||
|
||||
@scala.annotation.tailrec
|
||||
def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) =
|
||||
if attempt >= 3 || attempt >= depth then search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves)
|
||||
else
|
||||
val (score, move) = search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves)
|
||||
if score > currentAlpha && score < currentBeta then (score, move)
|
||||
else if score <= currentAlpha then
|
||||
loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
else loop(currentAlpha, score + delta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||
|
||||
loop(alpha, beta, initialWindow, 0)
|
||||
|
||||
private def hasNonPawnMaterial(context: GameContext): Boolean =
|
||||
context.board.pieces.values.exists { piece =>
|
||||
piece.color == context.turn &&
|
||||
piece.pieceType != PieceType.Pawn &&
|
||||
piece.pieceType != PieceType.King
|
||||
}
|
||||
|
||||
private def nullMoveContext(context: GameContext): GameContext =
|
||||
context.withTurn(context.turn.opposite).withEnPassantSquare(None)
|
||||
|
||||
private def tryNullMove(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
beta: Int,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): Option[Int] =
|
||||
val nullCtx = nullMoveContext(context)
|
||||
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||
weights.copyAccumulator(ply, ply + 1)
|
||||
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
|
||||
if -score >= beta then Some(beta) else None
|
||||
|
||||
/** Negamax alpha-beta search returning (score, best move). */
|
||||
private def search(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
window: Window,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
searchNode(params)
|
||||
|
||||
private def searchNode(params: SearchParams): (Int, Option[Move]) =
|
||||
val count = nodeCount.incrementAndGet()
|
||||
immediateSearchResult(params, count).getOrElse {
|
||||
val legalMoves = rules.allLegalMoves(params.context)
|
||||
terminalSearchResult(params, legalMoves).getOrElse(searchDeeper(params, legalMoves))
|
||||
}
|
||||
|
||||
private def immediateSearchResult(
|
||||
params: SearchParams,
|
||||
count: Int,
|
||||
): Option[(Int, Option[Move])] =
|
||||
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
|
||||
Some((weights.evaluateAccumulator(params.ply, params.context, params.state.hash), None))
|
||||
else if params.state.repetitions.getOrElse(params.state.hash, 0) >= 3 then Some((weights.DRAW_SCORE, None))
|
||||
else ttCutoff(params)
|
||||
|
||||
private def ttCutoff(params: SearchParams): Option[(Int, Option[Move])] =
|
||||
tt.probe(params.state.hash).filter(_.depth >= params.depth).flatMap { entry =>
|
||||
entry.flag match
|
||||
case TTFlag.Exact => Some((entry.score, entry.bestMove))
|
||||
case TTFlag.Lower =>
|
||||
val newAlpha = math.max(params.window.alpha, entry.score)
|
||||
Option.when(newAlpha >= params.window.beta)((entry.score, entry.bestMove))
|
||||
case TTFlag.Upper =>
|
||||
val newBeta = math.min(params.window.beta, entry.score)
|
||||
Option.when(params.window.alpha >= newBeta)((entry.score, entry.bestMove))
|
||||
}
|
||||
|
||||
private def terminalSearchResult(
|
||||
params: SearchParams,
|
||||
legalMoves: List[Move],
|
||||
): Option[(Int, Option[Move])] =
|
||||
if legalMoves.isEmpty then
|
||||
Some(
|
||||
(
|
||||
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
|
||||
None,
|
||||
),
|
||||
)
|
||||
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
|
||||
Some((weights.DRAW_SCORE, None))
|
||||
else if params.depth == 0 then
|
||||
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
|
||||
else None
|
||||
|
||||
private def searchDeeper(
|
||||
params: SearchParams,
|
||||
legalMoves: List[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val nullResult =
|
||||
Option
|
||||
.when(canTryNullMove(params))(
|
||||
tryNullMove(
|
||||
params.context,
|
||||
params.depth,
|
||||
params.ply,
|
||||
params.window.beta,
|
||||
params.state,
|
||||
params.excludedRootMoves,
|
||||
),
|
||||
)
|
||||
.flatten
|
||||
|
||||
nullResult.map((_, None)).getOrElse {
|
||||
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
|
||||
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering)
|
||||
searchSequential(
|
||||
params.context,
|
||||
params.depth,
|
||||
params.ply,
|
||||
params.window,
|
||||
ordered,
|
||||
params.state,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
}
|
||||
|
||||
private def canTryNullMove(params: SearchParams): Boolean =
|
||||
params.depth >= 3 &&
|
||||
!rules.isCheck(params.context) &&
|
||||
hasNonPawnMaterial(params.context)
|
||||
|
||||
private def isQuietMove(context: GameContext, move: Move): Boolean =
|
||||
!isCapture(context, move) &&
|
||||
move.moveType != MoveType.CastleKingside &&
|
||||
move.moveType != MoveType.CastleQueenside
|
||||
|
||||
private def scoreMove(
|
||||
child: GameContext,
|
||||
childState: SearchState,
|
||||
params: SearchParams,
|
||||
extension: Int,
|
||||
reduction: Int,
|
||||
a: Int,
|
||||
): Int =
|
||||
val betaNeg = -params.window.beta
|
||||
if reduction > 0 then
|
||||
val (rs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 - reduction + extension),
|
||||
params.ply + 1,
|
||||
Window(-a - 1, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
val s = -rs
|
||||
if s > a then
|
||||
val (fs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 + extension),
|
||||
params.ply + 1,
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
-fs
|
||||
else s
|
||||
else
|
||||
val (rs, _) = search(
|
||||
child,
|
||||
math.max(0, params.depth - 1 + extension),
|
||||
params.ply + 1,
|
||||
Window(betaNeg, -a),
|
||||
childState,
|
||||
params.excludedRootMoves,
|
||||
)
|
||||
-rs
|
||||
|
||||
private def evalSingleMove(
|
||||
move: Move,
|
||||
moveNumber: Int,
|
||||
a: Int,
|
||||
params: SearchParams,
|
||||
): Option[(Int, Boolean)] =
|
||||
val skipRoot = params.ply == 0 && params.excludedRootMoves.contains(move)
|
||||
val isQuiet = isQuietMove(params.context, move)
|
||||
val futility = params.depth == 1 && isQuiet && moveNumber > 2 &&
|
||||
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
|
||||
if skipRoot || futility then None
|
||||
else
|
||||
val child = rules.applyMove(params.context)(move)
|
||||
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
|
||||
weights.pushAccumulator(params.ply + 1, move, params.context, child)
|
||||
val childState = params.state.advance(childHash)
|
||||
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
|
||||
val reduction = if moveNumber > 4 && params.depth >= 3 && isQuiet then 1 else 0
|
||||
Some((scoreMove(child, childState, params, extension, reduction, a), isQuiet))
|
||||
|
||||
private def recordCutoff(move: Move, depth: Int, ply: Int): Unit =
|
||||
ordering.addHistory(
|
||||
move.from.rank.ordinal * 8 + move.from.file.ordinal,
|
||||
move.to.rank.ordinal * 8 + move.to.file.ordinal,
|
||||
depth * depth,
|
||||
)
|
||||
ordering.addKillerMove(ply, move)
|
||||
|
||||
@scala.annotation.tailrec
|
||||
private def searchLoop(
|
||||
idx: Int,
|
||||
moveNumber: Int,
|
||||
acc: LoopAcc,
|
||||
params: SearchParams,
|
||||
ordered: List[Move],
|
||||
): (Option[Move], Int, Boolean) =
|
||||
if idx >= ordered.length then (acc.bestMove, acc.bestScore, false)
|
||||
else
|
||||
val move = ordered(idx)
|
||||
evalSingleMove(move, moveNumber, acc.a, params) match
|
||||
case None => searchLoop(idx + 1, moveNumber + 1, acc, params, ordered)
|
||||
case Some((score, isQuiet)) =>
|
||||
val newAcc = LoopAcc(
|
||||
if score > acc.bestScore then Some(move) else acc.bestMove,
|
||||
math.max(acc.bestScore, score),
|
||||
math.max(acc.a, score),
|
||||
)
|
||||
if newAcc.a >= params.window.beta then
|
||||
if isQuiet then recordCutoff(move, params.depth, params.ply)
|
||||
(newAcc.bestMove, newAcc.bestScore, true)
|
||||
else searchLoop(idx + 1, moveNumber + 1, newAcc, params, ordered)
|
||||
|
||||
private def searchSequential(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
window: Window,
|
||||
ordered: List[Move],
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
): (Int, Option[Move]) =
|
||||
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
|
||||
val flag =
|
||||
if cutoff then TTFlag.Lower
|
||||
else if bestScore <= window.alpha then TTFlag.Upper
|
||||
else TTFlag.Exact
|
||||
tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove))
|
||||
(bestScore, bestMove)
|
||||
|
||||
/** Quiescence search: only captures until position is quiet. */
|
||||
private def quiescence(
|
||||
context: GameContext,
|
||||
ply: Int,
|
||||
alpha: Int,
|
||||
beta: Int,
|
||||
hash: Long,
|
||||
): Int =
|
||||
quiescenceNode(QuiescenceNode(context, ply, alpha, beta, hash))
|
||||
|
||||
private def quiescenceNode(node: QuiescenceNode): Int =
|
||||
val inCheck = rules.isCheck(node.context)
|
||||
val standPat = if inCheck then -INF else weights.evaluateAccumulator(node.ply, node.context, node.hash)
|
||||
|
||||
if !inCheck && standPat >= node.beta then node.beta
|
||||
else if node.ply >= MAX_QUIESCENCE_PLY then quiescenceAtDepthLimit(node, inCheck, standPat)
|
||||
else
|
||||
val moves = tacticalMoves(node.context, inCheck)
|
||||
if inCheck && moves.isEmpty then -(weights.CHECKMATE_SCORE - node.ply)
|
||||
else
|
||||
val ordered = MoveOrdering.sort(node.context, moves, None)
|
||||
val a0 = if inCheck then node.alpha else math.max(node.alpha, standPat)
|
||||
quiescenceLoop(node, ordered, 0, a0)
|
||||
|
||||
private def quiescenceAtDepthLimit(node: QuiescenceNode, inCheck: Boolean, standPat: Int): Int =
|
||||
if inCheck then weights.evaluateAccumulator(node.ply, node.context, node.hash) else standPat
|
||||
|
||||
private def tacticalMoves(context: GameContext, inCheck: Boolean): List[Move] =
|
||||
val allMoves = rules.allLegalMoves(context)
|
||||
if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
|
||||
|
||||
@scala.annotation.tailrec
|
||||
private def quiescenceLoop(
|
||||
node: QuiescenceNode,
|
||||
ordered: List[Move],
|
||||
idx: Int,
|
||||
a: Int,
|
||||
): Int =
|
||||
if idx >= ordered.length then a
|
||||
else
|
||||
val move = ordered(idx)
|
||||
val child = rules.applyMove(node.context)(move)
|
||||
val childHash = ZobristHash.nextHash(node.context, node.hash, move, child)
|
||||
weights.pushAccumulator(node.ply + 1, move, node.context, child)
|
||||
val score = -quiescence(child, node.ply + 1, -node.beta, -a, childHash)
|
||||
if score >= node.beta then node.beta
|
||||
else quiescenceLoop(node, ordered, idx + 1, math.max(a, score))
|
||||
|
||||
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
|
||||
case MoveType.Normal(true) => true
|
||||
case MoveType.EnPassant => true
|
||||
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
|
||||
case _ => false
|
||||
@@ -0,0 +1,177 @@
|
||||
package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.board.{Board, Color, Piece, PieceType, Square}
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable
|
||||
|
||||
object MoveOrdering:
|
||||
|
||||
class OrderingContext:
|
||||
private val killerMoves = mutable.Map[Int, List[Move]]()
|
||||
private val historyTable = mutable.Map[(Int, Int), Int]()
|
||||
|
||||
def addKillerMove(ply: Int, move: Move): Unit =
|
||||
val current = killerMoves.getOrElse(ply, List())
|
||||
if current.isEmpty || (current.head.from != move.from || current.head.to != move.to) then
|
||||
killerMoves(ply) = (move :: current).take(2)
|
||||
|
||||
def getKillerMoves(ply: Int): List[Move] =
|
||||
killerMoves.getOrElse(ply, List())
|
||||
|
||||
def addHistory(from: Int, to: Int, bonus: Int): Unit =
|
||||
val key = (from, to)
|
||||
historyTable(key) = historyTable.getOrElse(key, 0) + bonus
|
||||
|
||||
def getHistory(from: Int, to: Int): Int =
|
||||
historyTable.getOrElse((from, to), 0)
|
||||
|
||||
def clear(): Unit =
|
||||
killerMoves.clear()
|
||||
historyTable.clear()
|
||||
|
||||
def score(
|
||||
context: GameContext,
|
||||
move: Move,
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext(),
|
||||
): Int =
|
||||
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue
|
||||
else
|
||||
move.moveType match
|
||||
case MoveType.Promotion(PromotionPiece.Queen) =>
|
||||
1_000_000 + promotionCaptureBonus(context, move)
|
||||
case MoveType.Normal(true) | MoveType.EnPassant =>
|
||||
captureScore(context, move)
|
||||
case MoveType.Promotion(_) =>
|
||||
50_000 + promotionCaptureBonus(context, move)
|
||||
case _ => scoreQuietMove(move, ply, ordering)
|
||||
|
||||
def sort(
|
||||
context: GameContext,
|
||||
moves: List[Move],
|
||||
ttBestMove: Option[Move],
|
||||
ply: Int = 0,
|
||||
ordering: OrderingContext = new OrderingContext(),
|
||||
): List[Move] =
|
||||
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering))
|
||||
|
||||
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
|
||||
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
|
||||
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||
val history = ordering.getHistory(fromIdx, toIdx)
|
||||
if isKiller then 10_000 + (history / 10) else history / 10
|
||||
|
||||
private def promotionCaptureBonus(context: GameContext, move: Move): Int =
|
||||
if isCapture(context, move) then captureScore(context, move) else 0
|
||||
|
||||
private def captureScore(context: GameContext, move: Move): Int =
|
||||
val see = staticExchange(context, move)
|
||||
val seeBias = if see >= 0 then 20_000 else -20_000
|
||||
100_000 + mvvLva(context, move) + seeBias + see
|
||||
|
||||
private def mvvLva(context: GameContext, move: Move): Int =
|
||||
(victimValue(context, move) * 10) - attackerValue(context, move)
|
||||
|
||||
private def attackerValue(context: GameContext, move: Move): Int =
|
||||
context.board.pieceAt(move.from).map(pieceValue).getOrElse(0)
|
||||
|
||||
private def victimValue(context: GameContext, move: Move): Int =
|
||||
move.moveType match
|
||||
case MoveType.Normal(true) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
|
||||
case MoveType.EnPassant => 1
|
||||
case MoveType.Promotion(_) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
|
||||
case _ => 0
|
||||
|
||||
private def pieceValue(piece: Piece): Int = piece.pieceType match
|
||||
case PieceType.Pawn => 1
|
||||
case PieceType.Knight => 3
|
||||
case PieceType.Bishop => 3
|
||||
case PieceType.Rook => 5
|
||||
case PieceType.Queen => 9
|
||||
case PieceType.King => 200
|
||||
|
||||
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
|
||||
case MoveType.Normal(true) => true
|
||||
case MoveType.EnPassant => true
|
||||
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
|
||||
case _ => false
|
||||
|
||||
private def staticExchange(context: GameContext, move: Move): Int =
|
||||
if !isCapture(context, move) then 0
|
||||
else
|
||||
val target = move.to
|
||||
val initialGain = victimValue(context, move)
|
||||
movedPieceAfterMove(context, move).fold(initialGain) { moved =>
|
||||
val boardAfterMove = applySeeMove(context.board, move, moved)
|
||||
initialGain - seeGain(boardAfterMove, target, context.turn.opposite, pieceValue(moved))
|
||||
}
|
||||
|
||||
private def movedPieceAfterMove(context: GameContext, move: Move): Option[Piece] =
|
||||
move.moveType match
|
||||
case MoveType.Promotion(pp) => Some(Piece(context.turn, promotionPieceType(pp)))
|
||||
case _ => context.board.pieceAt(move.from)
|
||||
|
||||
private def seeGain(board: Board, target: Square, side: Color, currentValue: Int): Int =
|
||||
leastValuableAttacker(board, target, side) match
|
||||
case None => 0
|
||||
case Some((from, attacker)) =>
|
||||
val nextBoard = board.removed(from).updated(target, attacker)
|
||||
val replyGain = seeGain(nextBoard, target, side.opposite, pieceValue(attacker))
|
||||
math.max(0, currentValue - replyGain)
|
||||
|
||||
private def applySeeMove(board: Board, move: Move, moved: Piece): Board =
|
||||
move.moveType match
|
||||
case MoveType.EnPassant =>
|
||||
val capturedSquare = Square(move.to.file, move.from.rank)
|
||||
board.removed(move.from).removed(capturedSquare).updated(move.to, moved)
|
||||
case _ => board.removed(move.from).updated(move.to, moved)
|
||||
|
||||
private def leastValuableAttacker(board: Board, target: Square, color: Color): Option[(Square, Piece)] =
|
||||
board.pieces
|
||||
.collect {
|
||||
case (sq, piece) if piece.color == color && attacksSquare(board, sq, target, piece) => (sq, piece)
|
||||
}
|
||||
.toList
|
||||
.sortBy { case (_, piece) => pieceValue(piece) }
|
||||
.headOption
|
||||
|
||||
private def attacksSquare(board: Board, from: Square, target: Square, piece: Piece): Boolean =
|
||||
val df = target.file.ordinal - from.file.ordinal
|
||||
val dr = target.rank.ordinal - from.rank.ordinal
|
||||
piece.pieceType match
|
||||
case PieceType.Pawn =>
|
||||
val dir = if piece.color == Color.White then 1 else -1
|
||||
dr == dir && math.abs(df) == 1
|
||||
case PieceType.Knight =>
|
||||
val adf = math.abs(df)
|
||||
val adr = math.abs(dr)
|
||||
(adf == 1 && adr == 2) || (adf == 2 && adr == 1)
|
||||
case PieceType.Bishop => clearLine(board, from, target, df, dr, diagonal = true)
|
||||
case PieceType.Rook => clearLine(board, from, target, df, dr, diagonal = false)
|
||||
case PieceType.Queen =>
|
||||
clearLine(board, from, target, df, dr, diagonal = true) ||
|
||||
clearLine(board, from, target, df, dr, diagonal = false)
|
||||
case PieceType.King => math.abs(df) <= 1 && math.abs(dr) <= 1
|
||||
|
||||
private def clearLine(board: Board, from: Square, target: Square, df: Int, dr: Int, diagonal: Boolean): Boolean =
|
||||
val valid =
|
||||
if diagonal then math.abs(df) == math.abs(dr) && df != 0 else (df == 0 && dr != 0) || (dr == 0 && df != 0)
|
||||
valid && pathClear(board, from, target, Integer.compare(df, 0), Integer.compare(dr, 0))
|
||||
|
||||
@tailrec
|
||||
private def pathClear(board: Board, from: Square, target: Square, stepF: Int, stepR: Int): Boolean =
|
||||
from.offset(stepF, stepR) match
|
||||
case None => false
|
||||
case Some(next) if next == target => true
|
||||
case Some(next) => board.pieceAt(next).isEmpty && pathClear(board, next, target, stepF, stepR)
|
||||
|
||||
private def promotionPieceType(piece: PromotionPiece): PieceType = piece match
|
||||
case PromotionPiece.Knight => PieceType.Knight
|
||||
case PromotionPiece.Bishop => PieceType.Bishop
|
||||
case PromotionPiece.Rook => PieceType.Rook
|
||||
case PromotionPiece.Queen => PieceType.Queen
|
||||
@@ -0,0 +1,61 @@
|
||||
package de.nowchess.bot.logic
|
||||
|
||||
import de.nowchess.api.game.GameContext
|
||||
import de.nowchess.api.move.Move
|
||||
|
||||
final case class Window(alpha: Int, beta: Int)
|
||||
|
||||
final case class LoopAcc(bestMove: Option[Move], bestScore: Int, a: Int)
|
||||
|
||||
final case class SearchParams(
|
||||
context: GameContext,
|
||||
depth: Int,
|
||||
ply: Int,
|
||||
window: Window,
|
||||
state: SearchState,
|
||||
excludedRootMoves: Set[Move],
|
||||
)
|
||||
|
||||
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
|
||||
def advance(nextHash: Long): SearchState =
|
||||
SearchState(
|
||||
nextHash,
|
||||
repetitions.updatedWith(nextHash) {
|
||||
case Some(v) => Some(v + 1)
|
||||
case None => Some(1)
|
||||
},
|
||||
)
|
||||
|
||||
enum TTFlag:
|
||||
case Exact // Score is exact
|
||||
case Lower // Score is a lower bound
|
||||
case Upper // Score is an upper bound
|
||||
|
||||
final case class TTEntry(
|
||||
hash: Long,
|
||||
depth: Int,
|
||||
score: Int,
|
||||
flag: TTFlag,
|
||||
bestMove: Option[Move],
|
||||
)
|
||||
|
||||
final class TranspositionTable(val sizePow2: Int = 20):
|
||||
private val size = 1 << sizePow2
|
||||
private val mask = size - 1L
|
||||
private val locks = Array.fill(size)(new Object())
|
||||
private val table: Array[Option[TTEntry]] = Array.fill(size)(None)
|
||||
|
||||
def probe(hash: Long): Option[TTEntry] =
|
||||
val index = (hash & mask).toInt
|
||||
locks(index).synchronized {
|
||||
table(index).filter(_.hash == hash)
|
||||
}
|
||||
|
||||
def store(entry: TTEntry): Unit =
|
||||
val index = (entry.hash & mask).toInt
|
||||
locks(index).synchronized {
|
||||
table(index) = Some(entry)
|
||||
}
|
||||
|
||||
def clear(): Unit =
|
||||
for i <- 0 until size do locks(i).synchronized { table(i) = None }
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user