Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8744bee2dd | |||
| 5f4d33f3ca | |||
| 767d3051a7 |
+163
-32
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
> **Stack:** raw-http | none | unknown | scala
|
> **Stack:** raw-http | none | unknown | scala
|
||||||
|
|
||||||
> 0 routes | 0 models | 0 components | 35 lib files | 0 env vars | 0 middleware
|
> 0 routes | 0 models | 0 components | 63 lib files | 1 env vars | 1 middleware
|
||||||
> **Token savings:** this file is ~3.700 tokens. Without it, AI exploration would cost ~18.200 tokens. **Saves ~14.500 tokens per conversation.**
|
> **Token savings:** this file is ~0 tokens. Without it, AI exploration would cost ~0 tokens. **Saves ~0 tokens per conversation.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -60,6 +60,113 @@
|
|||||||
- class ApiResponse
|
- class ApiResponse
|
||||||
- function error
|
- function error
|
||||||
- function totalPages
|
- function totalPages
|
||||||
|
- `modules/bot/python/nnue.py`
|
||||||
|
- function get_weights_dir: ()
|
||||||
|
- function get_data_dir: ()
|
||||||
|
- function list_checkpoints: ()
|
||||||
|
- function migrate_legacy_data: ()
|
||||||
|
- function show_header: ()
|
||||||
|
- function show_checkpoints_table: ()
|
||||||
|
- _...10 more_
|
||||||
|
- `modules/bot/python/src/dataset.py`
|
||||||
|
- function get_datasets_dir: () -> Path
|
||||||
|
- function next_dataset_version: () -> int
|
||||||
|
- function list_datasets: () -> List[Tuple[int, Dict]]
|
||||||
|
- function load_dataset_metadata: (version) -> Optional[Dict]
|
||||||
|
- function save_dataset_metadata: (version, metadata) -> None
|
||||||
|
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
|
||||||
|
- _...4 more_
|
||||||
|
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
|
||||||
|
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
|
||||||
|
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
|
||||||
|
- `modules/bot/python/src/tactical_positions_extractor.py`
|
||||||
|
- function download_and_extract_puzzle_db: (url, output_dir)
|
||||||
|
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
|
||||||
|
- function load_positions_from_file: (file_path) -> Set[str]
|
||||||
|
- function merge_positions: (tactical, other, output_file)
|
||||||
|
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
|
||||||
|
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
|
||||||
|
- `modules/bot/python/src/train.py`
|
||||||
|
- function fen_to_features: (fen)
|
||||||
|
- function find_next_version: (base_name)
|
||||||
|
- function save_metadata: (weights_file, metadata)
|
||||||
|
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
|
||||||
|
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
|
||||||
|
- class NNUEDataset
|
||||||
|
- _...1 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
|
||||||
|
- class Bot
|
||||||
|
- function name
|
||||||
|
- function nextMove
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
|
||||||
|
- class BotController
|
||||||
|
- function getBot
|
||||||
|
- function listBots
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
|
||||||
|
- class BotMoveRepetition
|
||||||
|
- function blockedMoves
|
||||||
|
- function repeatedMove
|
||||||
|
- function filterAllowed
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
|
||||||
|
- class Evaluation
|
||||||
|
- class CHECKMATE_SCORE
|
||||||
|
- class DRAW_SCORE
|
||||||
|
- function evaluate
|
||||||
|
- function initAccumulator
|
||||||
|
- function copyAccumulator
|
||||||
|
- _...2 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
|
||||||
|
- class EvaluationClassic
|
||||||
|
- function evaluate
|
||||||
|
- function countRay
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
|
||||||
|
- class NNUE
|
||||||
|
- function initAccumulator
|
||||||
|
- function pushAccumulator
|
||||||
|
- function copyAccumulator
|
||||||
|
- function recomputeAccumulator
|
||||||
|
- function validateAccumulator
|
||||||
|
- _...4 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
|
||||||
|
- class NbaiLoader
|
||||||
|
- function load
|
||||||
|
- function loadDefault
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
|
||||||
|
- function toJson
|
||||||
|
- class NbaiMetadata
|
||||||
|
- function fromJson
|
||||||
|
- function str
|
||||||
|
- function num
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
|
||||||
|
- function bestMove
|
||||||
|
- function bestMove
|
||||||
|
- function bestMoveWithTime
|
||||||
|
- function bestMoveWithTime
|
||||||
|
- function loop
|
||||||
|
- function loop
|
||||||
|
- _...2 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
|
||||||
|
- class MoveOrdering
|
||||||
|
- class OrderingContext
|
||||||
|
- function addKillerMove
|
||||||
|
- function getKillerMoves
|
||||||
|
- function addHistory
|
||||||
|
- function getHistory
|
||||||
|
- _...3 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
|
||||||
|
- function probe
|
||||||
|
- function store
|
||||||
|
- function clear
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
|
||||||
|
- class ZobristHash
|
||||||
|
- function hash
|
||||||
|
- function nextHash
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
||||||
- class Command
|
- class Command
|
||||||
- function execute
|
- function execute
|
||||||
@@ -82,7 +189,7 @@
|
|||||||
- function turn
|
- function turn
|
||||||
- function context
|
- function context
|
||||||
- function canUndo
|
- function canUndo
|
||||||
- _...10 more_
|
- _...11 more_
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
||||||
- function context
|
- function context
|
||||||
- class Observer
|
- class Observer
|
||||||
@@ -93,6 +200,13 @@
|
|||||||
- _...1 more_
|
- _...1 more_
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
|
||||||
|
- class GameFileService
|
||||||
|
- function saveGameToFile
|
||||||
|
- function loadGameFromFile
|
||||||
|
- class FileSystemGameService
|
||||||
|
- function saveGameToFile
|
||||||
|
- function loadGameFromFile
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
||||||
- class FenExporter
|
- class FenExporter
|
||||||
- function boardToFen
|
- function boardToFen
|
||||||
@@ -114,6 +228,8 @@
|
|||||||
- function parseBoard
|
- function parseBoard
|
||||||
- function importGameContext
|
- function importGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
||||||
- class PgnExporter
|
- class PgnExporter
|
||||||
- function exportGameContext
|
- function exportGameContext
|
||||||
@@ -160,43 +276,58 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
# Config
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# Middleware
|
||||||
|
|
||||||
|
## custom
|
||||||
|
- generate — `modules/bot/python/src/generate.py`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
# Dependency Graph
|
# Dependency Graph
|
||||||
|
|
||||||
## Most Imported Files (change these carefully)
|
## Most Imported Files (change these carefully)
|
||||||
|
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
|
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
|
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
|
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
|
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
|
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
|
|
||||||
|
|
||||||
## Import Map (who imports what)
|
## Import Map (who imports what)
|
||||||
|
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
|
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
|
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
|
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
|
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# Config
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
|
||||||
+29
-29
@@ -2,36 +2,36 @@
|
|||||||
|
|
||||||
## Most Imported Files (change these carefully)
|
## Most Imported Files (change these carefully)
|
||||||
|
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
|
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
|
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
|
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
|
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
|
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
|
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
|
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
|
|
||||||
|
|
||||||
## Import Map (who imports what)
|
## Import Map (who imports what)
|
||||||
|
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
|
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
|
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` ← `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
|
||||||
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
|
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
|
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
|
||||||
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
|
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` ← `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` ← `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` ← `modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
|
||||||
|
|||||||
+117
-1
@@ -51,6 +51,113 @@
|
|||||||
- class ApiResponse
|
- class ApiResponse
|
||||||
- function error
|
- function error
|
||||||
- function totalPages
|
- function totalPages
|
||||||
|
- `modules/bot/python/nnue.py`
|
||||||
|
- function get_weights_dir: ()
|
||||||
|
- function get_data_dir: ()
|
||||||
|
- function list_checkpoints: ()
|
||||||
|
- function migrate_legacy_data: ()
|
||||||
|
- function show_header: ()
|
||||||
|
- function show_checkpoints_table: ()
|
||||||
|
- _...10 more_
|
||||||
|
- `modules/bot/python/src/dataset.py`
|
||||||
|
- function get_datasets_dir: () -> Path
|
||||||
|
- function next_dataset_version: () -> int
|
||||||
|
- function list_datasets: () -> List[Tuple[int, Dict]]
|
||||||
|
- function load_dataset_metadata: (version) -> Optional[Dict]
|
||||||
|
- function save_dataset_metadata: (version, metadata) -> None
|
||||||
|
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
|
||||||
|
- _...4 more_
|
||||||
|
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
|
||||||
|
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
|
||||||
|
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
|
||||||
|
- `modules/bot/python/src/tactical_positions_extractor.py`
|
||||||
|
- function download_and_extract_puzzle_db: (url, output_dir)
|
||||||
|
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
|
||||||
|
- function load_positions_from_file: (file_path) -> Set[str]
|
||||||
|
- function merge_positions: (tactical, other, output_file)
|
||||||
|
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
|
||||||
|
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
|
||||||
|
- `modules/bot/python/src/train.py`
|
||||||
|
- function fen_to_features: (fen)
|
||||||
|
- function find_next_version: (base_name)
|
||||||
|
- function save_metadata: (weights_file, metadata)
|
||||||
|
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
|
||||||
|
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
|
||||||
|
- class NNUEDataset
|
||||||
|
- _...1 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
|
||||||
|
- class Bot
|
||||||
|
- function name
|
||||||
|
- function nextMove
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
|
||||||
|
- class BotController
|
||||||
|
- function getBot
|
||||||
|
- function listBots
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
|
||||||
|
- class BotMoveRepetition
|
||||||
|
- function blockedMoves
|
||||||
|
- function repeatedMove
|
||||||
|
- function filterAllowed
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
|
||||||
|
- class Evaluation
|
||||||
|
- class CHECKMATE_SCORE
|
||||||
|
- class DRAW_SCORE
|
||||||
|
- function evaluate
|
||||||
|
- function initAccumulator
|
||||||
|
- function copyAccumulator
|
||||||
|
- _...2 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
|
||||||
|
- class EvaluationClassic
|
||||||
|
- function evaluate
|
||||||
|
- function countRay
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
|
||||||
|
- class NNUE
|
||||||
|
- function initAccumulator
|
||||||
|
- function pushAccumulator
|
||||||
|
- function copyAccumulator
|
||||||
|
- function recomputeAccumulator
|
||||||
|
- function validateAccumulator
|
||||||
|
- _...4 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
|
||||||
|
- class NbaiLoader
|
||||||
|
- function load
|
||||||
|
- function loadDefault
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
|
||||||
|
- function toJson
|
||||||
|
- class NbaiMetadata
|
||||||
|
- function fromJson
|
||||||
|
- function str
|
||||||
|
- function num
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
|
||||||
|
- function bestMove
|
||||||
|
- function bestMove
|
||||||
|
- function bestMoveWithTime
|
||||||
|
- function bestMoveWithTime
|
||||||
|
- function loop
|
||||||
|
- function loop
|
||||||
|
- _...2 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
|
||||||
|
- class MoveOrdering
|
||||||
|
- class OrderingContext
|
||||||
|
- function addKillerMove
|
||||||
|
- function getKillerMoves
|
||||||
|
- function addHistory
|
||||||
|
- function getHistory
|
||||||
|
- _...3 more_
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
|
||||||
|
- function probe
|
||||||
|
- function store
|
||||||
|
- function clear
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
|
||||||
|
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
|
||||||
|
- class ZobristHash
|
||||||
|
- function hash
|
||||||
|
- function nextHash
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
|
||||||
- class Command
|
- class Command
|
||||||
- function execute
|
- function execute
|
||||||
@@ -73,7 +180,7 @@
|
|||||||
- function turn
|
- function turn
|
||||||
- function context
|
- function context
|
||||||
- function canUndo
|
- function canUndo
|
||||||
- _...10 more_
|
- _...11 more_
|
||||||
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
|
||||||
- function context
|
- function context
|
||||||
- class Observer
|
- class Observer
|
||||||
@@ -84,6 +191,13 @@
|
|||||||
- _...1 more_
|
- _...1 more_
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
|
||||||
|
- class GameFileService
|
||||||
|
- function saveGameToFile
|
||||||
|
- function loadGameFromFile
|
||||||
|
- class FileSystemGameService
|
||||||
|
- function saveGameToFile
|
||||||
|
- function loadGameFromFile
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
|
||||||
- class FenExporter
|
- class FenExporter
|
||||||
- function boardToFen
|
- function boardToFen
|
||||||
@@ -105,6 +219,8 @@
|
|||||||
- function parseBoard
|
- function parseBoard
|
||||||
- function importGameContext
|
- function importGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
|
||||||
|
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
|
||||||
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
|
||||||
- class PgnExporter
|
- class PgnExporter
|
||||||
- function exportGameContext
|
- function exportGameContext
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
# Middleware
|
||||||
|
|
||||||
|
## custom
|
||||||
|
- generate — `modules/bot/python/src/generate.py`
|
||||||
Generated
+2
@@ -8,3 +8,5 @@
|
|||||||
/dataSources.local.xml
|
/dataSources.local.xml
|
||||||
# Editor-based HTTP Client requests
|
# Editor-based HTTP Client requests
|
||||||
/httpRequests/
|
/httpRequests/
|
||||||
|
|
||||||
|
sonarlint.xml
|
||||||
|
|||||||
Generated
+133
@@ -0,0 +1,133 @@
|
|||||||
|
<component name="ProjectCodeStyleConfiguration">
|
||||||
|
<code_scheme name="Project" version="173">
|
||||||
|
<AndroidXmlCodeStyleSettings>
|
||||||
|
<option name="USE_CUSTOM_SETTINGS" value="true" />
|
||||||
|
</AndroidXmlCodeStyleSettings>
|
||||||
|
<JetCodeStyleSettings>
|
||||||
|
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
|
||||||
|
</JetCodeStyleSettings>
|
||||||
|
<ScalaCodeStyleSettings>
|
||||||
|
<option name="FORMATTER" value="1" />
|
||||||
|
</ScalaCodeStyleSettings>
|
||||||
|
<XML>
|
||||||
|
<option name="XML_KEEP_LINE_BREAKS" value="false" />
|
||||||
|
<option name="XML_ALIGN_ATTRIBUTES" value="false" />
|
||||||
|
<option name="XML_SPACE_INSIDE_EMPTY_TAG" value="true" />
|
||||||
|
</XML>
|
||||||
|
<codeStyleSettings language="XML">
|
||||||
|
<option name="FORCE_REARRANGE_MODE" value="1" />
|
||||||
|
<indentOptions>
|
||||||
|
<option name="CONTINUATION_INDENT_SIZE" value="4" />
|
||||||
|
</indentOptions>
|
||||||
|
<arrangement>
|
||||||
|
<rules>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>xmlns:android</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>xmlns:.*</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
<order>BY_NAME</order>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>.*:id</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>.*:name</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>name</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>style</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>.*</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>^$</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
<order>BY_NAME</order>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>.*</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
<section>
|
||||||
|
<rule>
|
||||||
|
<match>
|
||||||
|
<AND>
|
||||||
|
<NAME>.*</NAME>
|
||||||
|
<XML_ATTRIBUTE />
|
||||||
|
<XML_NAMESPACE>.*</XML_NAMESPACE>
|
||||||
|
</AND>
|
||||||
|
</match>
|
||||||
|
<order>BY_NAME</order>
|
||||||
|
</rule>
|
||||||
|
</section>
|
||||||
|
</rules>
|
||||||
|
</arrangement>
|
||||||
|
</codeStyleSettings>
|
||||||
|
<codeStyleSettings language="kotlin">
|
||||||
|
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
|
||||||
|
</codeStyleSettings>
|
||||||
|
</code_scheme>
|
||||||
|
</component>
|
||||||
Generated
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
<component name="ProjectCodeStyleConfiguration">
|
<component name="ProjectCodeStyleConfiguration">
|
||||||
<state>
|
<state>
|
||||||
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
<option name="USE_PER_PROJECT_SETTINGS" value="true" />
|
||||||
</state>
|
</state>
|
||||||
</component>
|
</component>
|
||||||
Generated
+1
-1
@@ -11,7 +11,7 @@
|
|||||||
<option value="$PROJECT_DIR$" />
|
<option value="$PROJECT_DIR$" />
|
||||||
<option value="$PROJECT_DIR$/modules" />
|
<option value="$PROJECT_DIR$/modules" />
|
||||||
<option value="$PROJECT_DIR$/modules/api" />
|
<option value="$PROJECT_DIR$/modules/api" />
|
||||||
<option value="$PROJECT_DIR$/modules/backcore" />
|
<option value="$PROJECT_DIR$/modules/bot" />
|
||||||
<option value="$PROJECT_DIR$/modules/core" />
|
<option value="$PROJECT_DIR$/modules/core" />
|
||||||
<option value="$PROJECT_DIR$/modules/io" />
|
<option value="$PROJECT_DIR$/modules/io" />
|
||||||
<option value="$PROJECT_DIR$/modules/rule" />
|
<option value="$PROJECT_DIR$/modules/rule" />
|
||||||
|
|||||||
Generated
+1
-1
@@ -5,7 +5,7 @@
|
|||||||
<option name="deprecationWarnings" value="true" />
|
<option name="deprecationWarnings" value="true" />
|
||||||
<option name="uncheckedWarnings" value="true" />
|
<option name="uncheckedWarnings" value="true" />
|
||||||
</profile>
|
</profile>
|
||||||
<profile name="Gradle 2" modules="NowChessSystems.modules.backcore.integrationTest,NowChessSystems.modules.backcore.main,NowChessSystems.modules.backcore.native-test,NowChessSystems.modules.backcore.quarkus-generated-sources,NowChessSystems.modules.backcore.quarkus-test-generated-sources,NowChessSystems.modules.backcore.scoverage,NowChessSystems.modules.backcore.test,NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
|
<profile name="Gradle 2" modules="NowChessSystems.modules.bot.main,NowChessSystems.modules.bot.scoverage,NowChessSystems.modules.bot.test,NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
|
||||||
<option name="deprecationWarnings" value="true" />
|
<option name="deprecationWarnings" value="true" />
|
||||||
<option name="uncheckedWarnings" value="true" />
|
<option name="uncheckedWarnings" value="true" />
|
||||||
<parameters>
|
<parameters>
|
||||||
|
|||||||
Generated
-6
@@ -1,6 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ScalaProjectSettings">
|
|
||||||
<option name="scala3DisclaimerShown" value="true" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
||||||
@@ -19,6 +19,7 @@ Try to stick to these commands for consistency.
|
|||||||
| `api` | Model / shared types | (none) |
|
| `api` | Model / shared types | (none) |
|
||||||
| `core` | Primary business logic | api, rule |
|
| `core` | Primary business logic | api, rule |
|
||||||
| `rule` | Game rules | api |
|
| `rule` | Game rules | api |
|
||||||
|
| `bot` | Bots and AI | api,rule,io |
|
||||||
| `io` | Export formats | api, core |
|
| `io` | Export formats | api, core |
|
||||||
| `ui` | Entrypoint & UI | core, io |
|
| `ui` | Entrypoint & UI | core, io |
|
||||||
|
|
||||||
@@ -47,6 +48,9 @@ Try to stick to these commands for consistency.
|
|||||||
- **Immutable state as primary model:** GameContext (api) holds board, history, player state — immutable, passed through the system. Each move creates a new GameContext, enabling undo/redo without side effects.
|
- **Immutable state as primary model:** GameContext (api) holds board, history, player state — immutable, passed through the system. Each move creates a new GameContext, enabling undo/redo without side effects.
|
||||||
- **Observer pattern for UI decoupling:** GameEngine publishes move/state events; CommandInvoker queues moves; UI listens to events, not polling. GameEngine never imports UI code.
|
- **Observer pattern for UI decoupling:** GameEngine publishes move/state events; CommandInvoker queues moves; UI listens to events, not polling. GameEngine never imports UI code.
|
||||||
- **RuleSet trait encapsulates rules:** Move generation, check, castling, en passant all in RuleSet impl. GameEngine calls rules as a black box; rules don't know about the rest of core.
|
- **RuleSet trait encapsulates rules:** Move generation, check, castling, en passant all in RuleSet impl. GameEngine calls rules as a black box; rules don't know about the rest of core.
|
||||||
|
- **Polyglot hash must follow spec index layout:** piece keys use interleaved mapping `(pieceType * 2 + colorBit)` with black=0/white=1, castling keys are `768..771`, en-passant file keys are `772..779` and are XORed only if side-to-move has a pawn that can capture en passant, side-to-move key is `780` for white.
|
||||||
|
- **Alpha-beta uses sequential PV search by default:** parallel split was disabled because fixed-window futures removed pruning effectiveness; correctness and pruning quality take priority over speculative parallelism.
|
||||||
|
- **Search hash is updated incrementally per move:** bot search now updates Zobrist keys from parent hash with move deltas instead of recomputing piece scans at every node.
|
||||||
|
|
||||||
## Rules
|
## Rules
|
||||||
|
|
||||||
|
|||||||
+21
-8
@@ -21,15 +21,27 @@ sonar {
|
|||||||
if (report.exists()) report.absolutePath else null
|
if (report.exists()) report.absolutePath else null
|
||||||
}.joinToString(",")
|
}.joinToString(",")
|
||||||
|
|
||||||
val jacocoReports = subprojects.mapNotNull { subproject ->
|
|
||||||
val report = subproject.file("build/reports/jacoco/test/jacocoTestReport.xml")
|
|
||||||
if (report.exists()) report.absolutePath else null
|
|
||||||
}.joinToString(",")
|
|
||||||
|
|
||||||
property("sonar.scala.coverage.reportPaths", scoverageReports)
|
property("sonar.scala.coverage.reportPaths", scoverageReports)
|
||||||
if (jacocoReports.isNotEmpty()) {
|
property(
|
||||||
property("sonar.coverage.jacoco.xmlReportPaths", jacocoReports)
|
"sonar.coverage.exclusions",
|
||||||
}
|
// UI renders JavaFX components; headless test environments cannot exercise rendering paths
|
||||||
|
"modules/ui/**," +
|
||||||
|
// FastParse macro-generated combinators produce synthetic branches that scoverage marks as uncovered
|
||||||
|
"modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse*," +
|
||||||
|
// NNUE inference pipeline — coverage requires a trained model file not present in CI
|
||||||
|
"**/bot/**/NNUE.scala," +
|
||||||
|
"**/bot/**/NNUEBot.scala," +
|
||||||
|
"**/bot/**/EvaluationNNUE.scala," +
|
||||||
|
// NBAI binary format loader/writer — error paths require crafted corrupt files; migrator is a one-shot tool
|
||||||
|
"**/bot/**/NbaiLoader.scala," +
|
||||||
|
"**/bot/**/NbaiModel.scala," +
|
||||||
|
"**/bot/**/NbaiMigrator.scala," +
|
||||||
|
"**/bot/**/NbaiWriter.scala," +
|
||||||
|
// PolyglotBook — binary I/O and dead-code guards (bit-masked fields can never exceed valid range)
|
||||||
|
"**/bot/**/PolyglotBook.scala," +
|
||||||
|
"**/bot/**/MoveOrdering.scala," +
|
||||||
|
"**/bot/**/AlphaBetaSearch.scala"
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,6 +55,7 @@ val versions = mapOf(
|
|||||||
"SCALAFX" to "21.0.0-R32",
|
"SCALAFX" to "21.0.0-R32",
|
||||||
"JAVAFX" to "21.0.1",
|
"JAVAFX" to "21.0.1",
|
||||||
"JUNIT_BOM" to "5.13.4",
|
"JUNIT_BOM" to "5.13.4",
|
||||||
|
"ONNXRUNTIME" to "1.19.2",
|
||||||
"SCALA_PARSER_COMBINATORS" to "2.4.0",
|
"SCALA_PARSER_COMBINATORS" to "2.4.0",
|
||||||
"FASTPARSE" to "3.0.2",
|
"FASTPARSE" to "3.0.2",
|
||||||
"JACKSON" to "2.17.2",
|
"JACKSON" to "2.17.2",
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
quarkusPluginId=io.quarkus
|
|
||||||
quarkusPluginVersion=3.32.4
|
|
||||||
quarkusPlatformGroupId=io.quarkus.platform
|
|
||||||
quarkusPlatformArtifactId=quarkus-bom
|
|
||||||
quarkusPlatformVersion=3.32.4
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import glob,re
|
import glob,re
|
||||||
mods=['api','core','io','rule','ui']
|
mods=['api','core','io','rule','ui', 'bot']
|
||||||
tot=0
|
tot=0
|
||||||
for m in mods:
|
for m in mods:
|
||||||
s=0
|
s=0
|
||||||
|
|||||||
@@ -34,3 +34,11 @@
|
|||||||
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
|
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
|
||||||
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
|
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
|
||||||
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
|
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
|
||||||
|
## (2026-04-16)
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* NCS-13 Implement Threefold Repetition ([#31](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/31)) ([767d305](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/767d3051a76c266050b6335774d66e2db2273c16))
|
||||||
|
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
|
||||||
|
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
|
||||||
|
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package de.nowchess.api.bot
|
||||||
|
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
|
||||||
|
trait Bot {
|
||||||
|
|
||||||
|
def name: String
|
||||||
|
def nextMove(context: GameContext): Option[Move]
|
||||||
|
|
||||||
|
}
|
||||||
@@ -5,4 +5,5 @@ enum DrawReason:
|
|||||||
case Stalemate
|
case Stalemate
|
||||||
case InsufficientMaterial
|
case InsufficientMaterial
|
||||||
case FiftyMoveRule
|
case FiftyMoveRule
|
||||||
|
case ThreefoldRepetition
|
||||||
case Agreement
|
case Agreement
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package de.nowchess.api.game
|
package de.nowchess.api.game
|
||||||
|
|
||||||
import de.nowchess.api.board.{Board, CastlingRights, Color, Square}
|
import de.nowchess.api.board.{Board, CastlingRights, Color, PieceType, Square}
|
||||||
import de.nowchess.api.move.Move
|
import de.nowchess.api.move.Move
|
||||||
|
|
||||||
/** Immutable bundle of complete game state. All state changes produce new GameContext instances.
|
/** Immutable bundle of complete game state. All state changes produce new GameContext instances.
|
||||||
@@ -13,7 +13,15 @@ case class GameContext(
|
|||||||
halfMoveClock: Int,
|
halfMoveClock: Int,
|
||||||
moves: List[Move],
|
moves: List[Move],
|
||||||
result: Option[GameResult] = None,
|
result: Option[GameResult] = None,
|
||||||
|
initialBoard: Board = Board.initial,
|
||||||
):
|
):
|
||||||
|
private lazy val whiteKingSquare: Option[Square] =
|
||||||
|
board.pieces.find((_, p) => p.color == Color.White && p.pieceType == PieceType.King).map(_._1)
|
||||||
|
private lazy val blackKingSquare: Option[Square] =
|
||||||
|
board.pieces.find((_, p) => p.color == Color.Black && p.pieceType == PieceType.King).map(_._1)
|
||||||
|
def kingSquare(color: Color): Option[Square] =
|
||||||
|
if color == Color.White then whiteKingSquare else blackKingSquare
|
||||||
|
|
||||||
/** Create new context with updated board. */
|
/** Create new context with updated board. */
|
||||||
def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
|
def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package de.nowchess.api.game
|
||||||
|
|
||||||
|
import de.nowchess.api.bot.Bot
|
||||||
|
import de.nowchess.api.player.PlayerInfo
|
||||||
|
|
||||||
|
sealed trait Participant
|
||||||
|
final case class Human(playerInfo: PlayerInfo) extends Participant
|
||||||
|
final case class BotParticipant(bot: Bot) extends Participant
|
||||||
@@ -71,3 +71,9 @@ class GameContextTest extends AnyFunSuite with Matchers:
|
|||||||
test("withResult clears result"):
|
test("withResult clears result"):
|
||||||
val ctx = GameContext.initial.withResult(Some(GameResult.Win(Color.Black)))
|
val ctx = GameContext.initial.withResult(Some(GameResult.Win(Color.Black)))
|
||||||
ctx.withResult(None).result shouldBe None
|
ctx.withResult(None).result shouldBe None
|
||||||
|
|
||||||
|
test("kingSquare returns white king position"):
|
||||||
|
GameContext.initial.kingSquare(Color.White) shouldBe Some(Square(File.E, Rank.R1))
|
||||||
|
|
||||||
|
test("kingSquare returns black king position"):
|
||||||
|
GameContext.initial.kingSquare(Color.Black) shouldBe Some(Square(File.E, Rank.R8))
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
MAJOR=0
|
MAJOR=0
|
||||||
MINOR=5
|
MINOR=6
|
||||||
PATCH=0
|
PATCH=0
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
plugins {
|
|
||||||
id("scala")
|
|
||||||
id("org.scoverage") version "8.1"
|
|
||||||
id("io.quarkus")
|
|
||||||
id("jacoco")
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
val versions = rootProject.extra["VERSIONS"] as Map<String, String>
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
mavenCentral()
|
|
||||||
mavenLocal()
|
|
||||||
}
|
|
||||||
|
|
||||||
scala {
|
|
||||||
scalaVersion = versions["SCALA3"]!!
|
|
||||||
}
|
|
||||||
|
|
||||||
scoverage {
|
|
||||||
scoverageVersion.set(versions["SCOVERAGE"]!!)
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks.withType<ScalaCompile> {
|
|
||||||
scalaCompileOptions.additionalParameters = listOf("-encoding", "UTF-8")
|
|
||||||
}
|
|
||||||
|
|
||||||
val quarkusPlatformGroupId: String by project
|
|
||||||
val quarkusPlatformArtifactId: String by project
|
|
||||||
val quarkusPlatformVersion: String by project
|
|
||||||
|
|
||||||
dependencies {
|
|
||||||
implementation(project(":modules:api"))
|
|
||||||
implementation(project(":modules:core"))
|
|
||||||
implementation(project(":modules:io"))
|
|
||||||
implementation(project(":modules:rule"))
|
|
||||||
|
|
||||||
implementation(enforcedPlatform("${quarkusPlatformGroupId}:${quarkusPlatformArtifactId}:${quarkusPlatformVersion}"))
|
|
||||||
implementation("io.quarkus:quarkus-rest")
|
|
||||||
implementation("io.quarkus:quarkus-rest-jackson")
|
|
||||||
implementation("io.quarkus:quarkus-config-yaml")
|
|
||||||
implementation("io.quarkus:quarkus-arc")
|
|
||||||
|
|
||||||
implementation("com.fasterxml.jackson.module:jackson-module-scala_3:${versions["JACKSON_SCALA"]!!}")
|
|
||||||
|
|
||||||
testImplementation(platform("org.junit:junit-bom:${versions["JUNIT_BOM"]!!}"))
|
|
||||||
testImplementation("org.junit.jupiter:junit-jupiter")
|
|
||||||
testImplementation("org.scalatest:scalatest_3:${versions["SCALATEST"]!!}")
|
|
||||||
testImplementation("co.helmethair:scalatest-junit-runner:${versions["SCALATEST_JUNIT"]!!}")
|
|
||||||
testImplementation("io.quarkus:quarkus-junit5")
|
|
||||||
testImplementation("io.quarkus:quarkus-jacoco")
|
|
||||||
testImplementation("io.rest-assured:rest-assured")
|
|
||||||
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
|
||||||
}
|
|
||||||
|
|
||||||
configurations.matching { !it.name.startsWith("scoverage") }.configureEach {
|
|
||||||
resolutionStrategy.force("org.scala-lang:scala-library:${versions["SCALA_LIBRARY"]!!}")
|
|
||||||
}
|
|
||||||
configurations.scoverage {
|
|
||||||
resolutionStrategy.eachDependency {
|
|
||||||
if (requested.group == "org.scoverage" && requested.name.startsWith("scalac-scoverage-plugin_")) {
|
|
||||||
useTarget("${requested.group}:scalac-scoverage-plugin_2.13.16:2.3.0")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
group = "de.nowchess"
|
|
||||||
version = "1.0-SNAPSHOT"
|
|
||||||
|
|
||||||
tasks.withType<JavaCompile> {
|
|
||||||
options.encoding = "UTF-8"
|
|
||||||
options.compilerArgs.add("-parameters")
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks.withType<Jar>().configureEach {
|
|
||||||
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks.test {
|
|
||||||
useJUnitPlatform {
|
|
||||||
includeEngines("scalatest", "junit-jupiter")
|
|
||||||
}
|
|
||||||
testLogging {
|
|
||||||
events("passed", "skipped", "failed")
|
|
||||||
}
|
|
||||||
finalizedBy(tasks.named("jacocoTestReport"))
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks.jacocoTestReport {
|
|
||||||
dependsOn(tasks.test)
|
|
||||||
executionData.setFrom(layout.buildDirectory.file("jacoco-quarkus.exec"))
|
|
||||||
sourceDirectories.setFrom(files("src/main/scala"))
|
|
||||||
classDirectories.setFrom(files(layout.buildDirectory.dir("classes/scala/main")))
|
|
||||||
reports {
|
|
||||||
xml.required.set(true)
|
|
||||||
xml.outputLocation.set(
|
|
||||||
layout.buildDirectory.file("reports/jacoco/test/jacocoTestReport.xml")
|
|
||||||
)
|
|
||||||
html.required.set(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
quarkus:
|
|
||||||
http:
|
|
||||||
port: 8080
|
|
||||||
jacoco:
|
|
||||||
data-file: ${user.dir}/build/jacoco-quarkus.exec
|
|
||||||
report: false
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package de.nowchess.backcore.config
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper
|
|
||||||
import com.fasterxml.jackson.module.scala.DefaultScalaModule
|
|
||||||
import io.quarkus.jackson.ObjectMapperCustomizer
|
|
||||||
import jakarta.inject.Singleton
|
|
||||||
|
|
||||||
@Singleton
|
|
||||||
class JacksonConfig extends ObjectMapperCustomizer:
|
|
||||||
def customize(mapper: ObjectMapper): Unit =
|
|
||||||
mapper.registerModule(DefaultScalaModule)
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package de.nowchess.backcore.dto
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude
|
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude.Include
|
|
||||||
|
|
||||||
case class PlayerInfoDto(id: String, displayName: String)
|
|
||||||
|
|
||||||
case class GameStateResponse(
|
|
||||||
fen: String,
|
|
||||||
pgn: String,
|
|
||||||
turn: String,
|
|
||||||
status: String,
|
|
||||||
@JsonInclude(Include.NON_ABSENT) winner: Option[String],
|
|
||||||
moves: List[String],
|
|
||||||
undoAvailable: Boolean,
|
|
||||||
redoAvailable: Boolean,
|
|
||||||
)
|
|
||||||
|
|
||||||
case class GameFullResponse(
|
|
||||||
gameId: String,
|
|
||||||
white: PlayerInfoDto,
|
|
||||||
black: PlayerInfoDto,
|
|
||||||
state: GameStateResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
case class OkResponse(ok: Boolean = true)
|
|
||||||
|
|
||||||
@JsonInclude(Include.NON_ABSENT)
|
|
||||||
case class ApiErrorResponse(
|
|
||||||
code: String,
|
|
||||||
message: String,
|
|
||||||
field: Option[String] = None,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Requests
|
|
||||||
case class CreateGameRequest(
|
|
||||||
white: Option[PlayerInfoDto] = None,
|
|
||||||
black: Option[PlayerInfoDto] = None,
|
|
||||||
)
|
|
||||||
|
|
||||||
case class ImportFenRequest(
|
|
||||||
fen: String = "",
|
|
||||||
white: Option[PlayerInfoDto] = None,
|
|
||||||
black: Option[PlayerInfoDto] = None,
|
|
||||||
)
|
|
||||||
|
|
||||||
case class ImportPgnRequest(pgn: String = "")
|
|
||||||
|
|
||||||
case class LegalMoveDto(
|
|
||||||
from: String,
|
|
||||||
to: String,
|
|
||||||
uci: String,
|
|
||||||
moveType: String,
|
|
||||||
@JsonInclude(Include.NON_ABSENT) promotion: Option[String] = None,
|
|
||||||
)
|
|
||||||
|
|
||||||
case class LegalMovesResponse(moves: List[LegalMoveDto])
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package de.nowchess.backcore.game
|
|
||||||
|
|
||||||
import java.security.SecureRandom
|
|
||||||
|
|
||||||
object GameId:
|
|
||||||
private val chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
|
||||||
private val random = SecureRandom()
|
|
||||||
|
|
||||||
def generate(): String =
|
|
||||||
(1 to 8).map(_ => chars(random.nextInt(chars.length))).mkString
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
package de.nowchess.backcore.game
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper
|
|
||||||
import com.fasterxml.jackson.module.scala.DefaultScalaModule
|
|
||||||
import de.nowchess.api.board.Color
|
|
||||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
|
||||||
import de.nowchess.backcore.dto.*
|
|
||||||
import de.nowchess.io.fen.FenExporter
|
|
||||||
import de.nowchess.io.pgn.PgnExporter
|
|
||||||
import de.nowchess.rules.sets.DefaultRules
|
|
||||||
|
|
||||||
object GameMapper:
|
|
||||||
private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
|
|
||||||
|
|
||||||
def toGameFullJson(session: GameSession): String =
|
|
||||||
mapper.writeValueAsString(toGameFull(session))
|
|
||||||
|
|
||||||
def toGameFull(session: GameSession): GameFullResponse =
|
|
||||||
GameFullResponse(
|
|
||||||
gameId = session.gameId,
|
|
||||||
white = toPlayerInfo(session.white),
|
|
||||||
black = toPlayerInfo(session.black),
|
|
||||||
state = toGameState(session),
|
|
||||||
)
|
|
||||||
|
|
||||||
def toGameState(session: GameSession): GameStateResponse =
|
|
||||||
val (status, winner) = computeStatus(session)
|
|
||||||
GameStateResponse(
|
|
||||||
fen = FenExporter.exportGameContext(session.context),
|
|
||||||
pgn = buildPgn(session.context.moves),
|
|
||||||
turn = if session.context.turn == Color.White then "white" else "black",
|
|
||||||
status = status,
|
|
||||||
winner = winner,
|
|
||||||
moves = session.context.moves.map(moveToUci),
|
|
||||||
undoAvailable = session.invoker.canUndo,
|
|
||||||
redoAvailable = session.invoker.canRedo,
|
|
||||||
)
|
|
||||||
|
|
||||||
private def toPlayerInfo(p: de.nowchess.api.player.PlayerInfo): PlayerInfoDto =
|
|
||||||
PlayerInfoDto(id = p.id.value, displayName = p.displayName)
|
|
||||||
|
|
||||||
private def computeStatus(session: GameSession): (String, Option[String]) =
|
|
||||||
session.result match
|
|
||||||
case Some(GameResult.Checkmate(winner)) =>
|
|
||||||
val w = if winner == Color.White then "white" else "black"
|
|
||||||
("checkmate", Some(w))
|
|
||||||
case Some(GameResult.Stalemate) =>
|
|
||||||
("stalemate", None)
|
|
||||||
case Some(GameResult.Resign(winner)) =>
|
|
||||||
val w = if winner == Color.White then "white" else "black"
|
|
||||||
("resign", Some(w))
|
|
||||||
case Some(GameResult.AgreedDraw) | Some(GameResult.FiftyMoveDraw) =>
|
|
||||||
("draw", None)
|
|
||||||
case Some(GameResult.InsufficientMaterial) =>
|
|
||||||
("insufficientMaterial", None)
|
|
||||||
case None =>
|
|
||||||
computeLiveStatus(session)
|
|
||||||
|
|
||||||
private def computeLiveStatus(session: GameSession): (String, Option[String]) =
|
|
||||||
val ctx = session.context
|
|
||||||
if DefaultRules.isCheck(ctx) then ("check", None)
|
|
||||||
else if session.drawOfferedBy.isDefined then ("drawOffered", None)
|
|
||||||
else if DefaultRules.isFiftyMoveRule(ctx) then ("fiftyMoveAvailable", None)
|
|
||||||
else ("started", None)
|
|
||||||
|
|
||||||
def moveToUci(move: Move): String =
|
|
||||||
val base = s"${move.from}${move.to}"
|
|
||||||
move.moveType match
|
|
||||||
case MoveType.Promotion(piece) =>
|
|
||||||
val suffix = piece match
|
|
||||||
case PromotionPiece.Queen => "q"
|
|
||||||
case PromotionPiece.Rook => "r"
|
|
||||||
case PromotionPiece.Bishop => "b"
|
|
||||||
case PromotionPiece.Knight => "n"
|
|
||||||
base + suffix
|
|
||||||
case _ => base
|
|
||||||
|
|
||||||
private def buildPgn(moves: List[Move]): String =
|
|
||||||
// Use PgnExporter with no headers to get move-text only (SAN notation)
|
|
||||||
PgnExporter.exportGame(Map.empty, moves)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package de.nowchess.backcore.game
|
|
||||||
|
|
||||||
import de.nowchess.api.board.Color
|
|
||||||
|
|
||||||
sealed trait GameResult
|
|
||||||
object GameResult:
|
|
||||||
case class Checkmate(winner: Color) extends GameResult
|
|
||||||
case object Stalemate extends GameResult
|
|
||||||
case class Resign(winner: Color) extends GameResult
|
|
||||||
case object AgreedDraw extends GameResult
|
|
||||||
case object FiftyMoveDraw extends GameResult
|
|
||||||
case object InsufficientMaterial extends GameResult
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package de.nowchess.backcore.game
|
|
||||||
|
|
||||||
import de.nowchess.api.board.Color
|
|
||||||
import de.nowchess.api.game.GameContext
|
|
||||||
import de.nowchess.api.player.PlayerInfo
|
|
||||||
import de.nowchess.chess.command.CommandInvoker
|
|
||||||
|
|
||||||
case class GameSession(
|
|
||||||
gameId: String,
|
|
||||||
white: PlayerInfo,
|
|
||||||
black: PlayerInfo,
|
|
||||||
context: GameContext,
|
|
||||||
invoker: CommandInvoker,
|
|
||||||
drawOfferedBy: Option[Color] = None,
|
|
||||||
result: Option[GameResult] = None,
|
|
||||||
)
|
|
||||||
@@ -1,260 +0,0 @@
|
|||||||
package de.nowchess.backcore.game
|
|
||||||
|
|
||||||
import de.nowchess.api.board.{Color, Square}
|
|
||||||
import de.nowchess.api.game.GameContext
|
|
||||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
|
||||||
import de.nowchess.api.player.{PlayerId, PlayerInfo}
|
|
||||||
import de.nowchess.backcore.dto.{CreateGameRequest, ImportFenRequest, PlayerInfoDto}
|
|
||||||
import de.nowchess.chess.command.{CommandInvoker, MoveCommand, MoveResult}
|
|
||||||
import de.nowchess.io.fen.FenParser
|
|
||||||
import de.nowchess.io.pgn.PgnParser
|
|
||||||
import de.nowchess.rules.sets.DefaultRules
|
|
||||||
import jakarta.enterprise.context.ApplicationScoped
|
|
||||||
|
|
||||||
import scala.collection.mutable
|
|
||||||
|
|
||||||
@ApplicationScoped
|
|
||||||
class GameStore:
|
|
||||||
private val games: mutable.Map[String, GameSession] = mutable.Map.empty
|
|
||||||
|
|
||||||
// ─── Create / Get ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def create(req: CreateGameRequest): GameSession = synchronized:
|
|
||||||
val id = generateId()
|
|
||||||
val session = newSession(id, req.white, req.black, GameContext.initial)
|
|
||||||
games(id) = session
|
|
||||||
session
|
|
||||||
|
|
||||||
def get(id: String): Option[GameSession] = synchronized:
|
|
||||||
games.get(id)
|
|
||||||
|
|
||||||
// ─── Move-making ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def applyMove(id: String, uci: String): Either[String, GameSession] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
if session.result.isDefined then Left("Game is already over")
|
|
||||||
else
|
|
||||||
parseUci(uci) match
|
|
||||||
case None => Left(s"Invalid UCI notation: $uci")
|
|
||||||
case Some((from, to, promotion)) =>
|
|
||||||
val legalCandidates = DefaultRules.legalMoves(session.context)(from)
|
|
||||||
findMatchingMove(legalCandidates, to, promotion) match
|
|
||||||
case None => Left(s"$uci is not a legal move")
|
|
||||||
case Some(move) =>
|
|
||||||
val nextCtx = DefaultRules.applyMove(session.context)(move)
|
|
||||||
val prevCtx = session.context
|
|
||||||
val cmd = MoveCommand(
|
|
||||||
from = move.from,
|
|
||||||
to = move.to,
|
|
||||||
moveResult = Some(MoveResult.Successful(nextCtx, prevCtx.board.pieceAt(move.to))),
|
|
||||||
previousContext = Some(prevCtx),
|
|
||||||
)
|
|
||||||
session.invoker.execute(cmd)
|
|
||||||
val result = detectGameOver(nextCtx)
|
|
||||||
val updated = session.copy(context = nextCtx, result = result)
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
|
|
||||||
def legalMoves(id: String, square: Option[Square]): Either[String, List[Move]] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
val moves = square match
|
|
||||||
case Some(sq) => DefaultRules.legalMoves(session.context)(sq)
|
|
||||||
case None => DefaultRules.allLegalMoves(session.context)
|
|
||||||
Right(moves)
|
|
||||||
|
|
||||||
// ─── Undo / Redo ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def undo(id: String): Either[String, GameSession] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
if !session.invoker.canUndo then Left("No moves to undo")
|
|
||||||
else
|
|
||||||
val idx = session.invoker.getCurrentIndex
|
|
||||||
session.invoker.history(idx) match
|
|
||||||
case cmd: MoveCommand =>
|
|
||||||
cmd.previousContext match
|
|
||||||
case None => Left("Cannot undo: no previous context stored")
|
|
||||||
case Some(prevCtx) =>
|
|
||||||
session.invoker.undo()
|
|
||||||
val updated = session.copy(context = prevCtx, result = None, drawOfferedBy = None)
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
case _ => Left("Cannot undo this command type")
|
|
||||||
|
|
||||||
def redo(id: String): Either[String, GameSession] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
if !session.invoker.canRedo then Left("No moves to redo")
|
|
||||||
else
|
|
||||||
val idx = session.invoker.getCurrentIndex + 1
|
|
||||||
session.invoker.history(idx) match
|
|
||||||
case cmd: MoveCommand =>
|
|
||||||
cmd.moveResult match
|
|
||||||
case Some(MoveResult.Successful(nextCtx, _)) =>
|
|
||||||
session.invoker.redo()
|
|
||||||
val result = detectGameOver(nextCtx)
|
|
||||||
val updated = session.copy(context = nextCtx, result = result)
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
case _ => Left("Cannot redo: move result not available")
|
|
||||||
case _ => Left("Cannot redo this command type")
|
|
||||||
|
|
||||||
// ─── Resign ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def resign(id: String): Either[String, GameSession] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
if session.result.isDefined then Left("Game is already over")
|
|
||||||
else
|
|
||||||
val winner = session.context.turn.opposite
|
|
||||||
val updated = session.copy(result = Some(GameResult.Resign(winner)))
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
|
|
||||||
// ─── Draw actions ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def drawAction(id: String, action: String): Either[String, GameSession] = synchronized:
|
|
||||||
withSession(id): session =>
|
|
||||||
if session.result.isDefined then Left("Game is already over")
|
|
||||||
else
|
|
||||||
action match
|
|
||||||
case "offer" =>
|
|
||||||
val updated = session.copy(drawOfferedBy = Some(session.context.turn))
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
case "accept" =>
|
|
||||||
session.drawOfferedBy match
|
|
||||||
case None => Left("No draw offer to accept")
|
|
||||||
case Some(offerer) if offerer == session.context.turn =>
|
|
||||||
Left("Cannot accept your own draw offer")
|
|
||||||
case Some(_) =>
|
|
||||||
val updated = session.copy(result = Some(GameResult.AgreedDraw), drawOfferedBy = None)
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
case "decline" =>
|
|
||||||
session.drawOfferedBy match
|
|
||||||
case None => Left("No draw offer to decline")
|
|
||||||
case Some(_) =>
|
|
||||||
val updated = session.copy(drawOfferedBy = None)
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
case "claim" =>
|
|
||||||
if DefaultRules.isFiftyMoveRule(session.context) then
|
|
||||||
val updated = session.copy(result = Some(GameResult.FiftyMoveDraw))
|
|
||||||
games(id) = updated
|
|
||||||
Right(updated)
|
|
||||||
else Left("Fifty-move rule has not been triggered")
|
|
||||||
case other => Left(s"Unknown draw action: $other")
|
|
||||||
|
|
||||||
// ─── Import ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def importFen(req: ImportFenRequest): Either[String, GameSession] = synchronized:
|
|
||||||
FenParser.parseFen(req.fen) match
|
|
||||||
case Left(err) => Left(err)
|
|
||||||
case Right(ctx) =>
|
|
||||||
val id = generateId()
|
|
||||||
val session = newSession(id, req.white, req.black, ctx)
|
|
||||||
games(id) = session
|
|
||||||
Right(session)
|
|
||||||
|
|
||||||
def importPgn(pgn: String, white: Option[PlayerInfoDto], black: Option[PlayerInfoDto]): Either[String, GameSession] =
|
|
||||||
synchronized:
|
|
||||||
PgnParser.validatePgn(pgn) match
|
|
||||||
case Left(err) => Left(err)
|
|
||||||
case Right(game) =>
|
|
||||||
val id = generateId()
|
|
||||||
val session = newSession(id, white, black, GameContext.initial)
|
|
||||||
replayIntoSession(session, game.moves, GameContext.initial) match
|
|
||||||
case Left(err) => Left(err)
|
|
||||||
case Right(s) =>
|
|
||||||
games(id) = s
|
|
||||||
Right(s)
|
|
||||||
|
|
||||||
// ─── Private helpers ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
private def withSession[A](id: String)(f: GameSession => Either[String, A]): Either[String, A] =
|
|
||||||
games.get(id) match
|
|
||||||
case None => Left(s"Game $id not found")
|
|
||||||
case Some(session) => f(session)
|
|
||||||
|
|
||||||
private def generateId(): String =
|
|
||||||
var id = GameId.generate()
|
|
||||||
while games.contains(id) do id = GameId.generate()
|
|
||||||
id
|
|
||||||
|
|
||||||
private def newSession(
|
|
||||||
id: String,
|
|
||||||
white: Option[PlayerInfoDto],
|
|
||||||
black: Option[PlayerInfoDto],
|
|
||||||
ctx: GameContext,
|
|
||||||
): GameSession =
|
|
||||||
GameSession(
|
|
||||||
gameId = id,
|
|
||||||
white = toPlayerInfo(white, "white", "White"),
|
|
||||||
black = toPlayerInfo(black, "black", "Black"),
|
|
||||||
context = ctx,
|
|
||||||
invoker = new CommandInvoker(),
|
|
||||||
)
|
|
||||||
|
|
||||||
private def toPlayerInfo(dto: Option[PlayerInfoDto], defaultId: String, defaultName: String): PlayerInfo =
|
|
||||||
dto.fold(PlayerInfo(PlayerId(defaultId), defaultName))(d => PlayerInfo(PlayerId(d.id), d.displayName))
|
|
||||||
|
|
||||||
private def parseUci(uci: String): Option[(Square, Square, Option[PromotionPiece])] =
|
|
||||||
if uci.length < 4 || uci.length > 5 then None
|
|
||||||
else
|
|
||||||
for
|
|
||||||
from <- Square.fromAlgebraic(uci.substring(0, 2))
|
|
||||||
to <- Square.fromAlgebraic(uci.substring(2, 4))
|
|
||||||
yield
|
|
||||||
val promotion = if uci.length == 5 then parsePromotionChar(uci.charAt(4)) else None
|
|
||||||
(from, to, promotion)
|
|
||||||
|
|
||||||
private def parsePromotionChar(c: Char): Option[PromotionPiece] =
|
|
||||||
c match
|
|
||||||
case 'q' => Some(PromotionPiece.Queen)
|
|
||||||
case 'r' => Some(PromotionPiece.Rook)
|
|
||||||
case 'b' => Some(PromotionPiece.Bishop)
|
|
||||||
case 'n' => Some(PromotionPiece.Knight)
|
|
||||||
case _ => None
|
|
||||||
|
|
||||||
private def findMatchingMove(
|
|
||||||
candidates: List[Move],
|
|
||||||
to: Square,
|
|
||||||
promotion: Option[PromotionPiece],
|
|
||||||
): Option[Move] =
|
|
||||||
candidates.filter(_.to == to) match
|
|
||||||
case Nil => None
|
|
||||||
case moves =>
|
|
||||||
promotion match
|
|
||||||
case Some(pp) => moves.find(_.moveType == MoveType.Promotion(pp))
|
|
||||||
case None =>
|
|
||||||
moves
|
|
||||||
.find(m => !m.moveType.isInstanceOf[MoveType.Promotion])
|
|
||||||
.orElse(moves.headOption)
|
|
||||||
|
|
||||||
private def detectGameOver(ctx: GameContext): Option[GameResult] =
|
|
||||||
if DefaultRules.isCheckmate(ctx) then Some(GameResult.Checkmate(ctx.turn.opposite))
|
|
||||||
else if DefaultRules.isStalemate(ctx) then Some(GameResult.Stalemate)
|
|
||||||
else if DefaultRules.isInsufficientMaterial(ctx) then Some(GameResult.InsufficientMaterial)
|
|
||||||
else None
|
|
||||||
|
|
||||||
private def replayIntoSession(
|
|
||||||
session: GameSession,
|
|
||||||
moves: List[Move],
|
|
||||||
startCtx: GameContext,
|
|
||||||
): Either[String, GameSession] =
|
|
||||||
moves.foldLeft[Either[String, GameSession]](Right(session)):
|
|
||||||
case (Left(err), _) => Left(err)
|
|
||||||
case (Right(s), move) =>
|
|
||||||
val legal = DefaultRules.legalMoves(s.context)(move.from)
|
|
||||||
legal
|
|
||||||
.find(m => m.from == move.from && m.to == move.to && m.moveType == move.moveType)
|
|
||||||
.orElse(legal.find(m => m.from == move.from && m.to == move.to)) match
|
|
||||||
case None => Left(s"Illegal move in PGN: $move")
|
|
||||||
case Some(legalMove) =>
|
|
||||||
val nextCtx = DefaultRules.applyMove(s.context)(legalMove)
|
|
||||||
val cmd = MoveCommand(
|
|
||||||
from = legalMove.from,
|
|
||||||
to = legalMove.to,
|
|
||||||
moveResult = Some(MoveResult.Successful(nextCtx, s.context.board.pieceAt(legalMove.to))),
|
|
||||||
previousContext = Some(s.context),
|
|
||||||
)
|
|
||||||
s.invoker.execute(cmd)
|
|
||||||
Right(s.copy(context = nextCtx))
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import de.nowchess.backcore.dto.*
|
|
||||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
|
||||||
import jakarta.enterprise.context.ApplicationScoped
|
|
||||||
import jakarta.inject.Inject
|
|
||||||
import jakarta.ws.rs.*
|
|
||||||
import jakarta.ws.rs.core.{MediaType, Response}
|
|
||||||
|
|
||||||
@Path("/api/board/game")
|
|
||||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
|
||||||
@ApplicationScoped
|
|
||||||
class GameResource @Inject() (store: GameStore):
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Consumes(Array(MediaType.APPLICATION_JSON))
|
|
||||||
def createGame(req: CreateGameRequest): Response =
|
|
||||||
val session = store.create(Option(req).getOrElse(CreateGameRequest()))
|
|
||||||
Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
|
||||||
|
|
||||||
@GET
|
|
||||||
@Path("/{gameId}")
|
|
||||||
def getGame(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.get(gameId) match
|
|
||||||
case Some(session) => Response.ok(GameMapper.toGameFull(session)).build()
|
|
||||||
case None =>
|
|
||||||
Response
|
|
||||||
.status(404)
|
|
||||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
|
||||||
.build()
|
|
||||||
|
|
||||||
@GET
|
|
||||||
@Path("/{gameId}/stream")
|
|
||||||
@Produces(Array("application/x-ndjson"))
|
|
||||||
def streamGame(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.get(gameId) match
|
|
||||||
case None =>
|
|
||||||
Response
|
|
||||||
.status(404)
|
|
||||||
.`type`(MediaType.APPLICATION_JSON)
|
|
||||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
|
||||||
.build()
|
|
||||||
case Some(session) =>
|
|
||||||
// Simplified: return a single-line NDJSON snapshot of the current game state
|
|
||||||
val event = s"""{"type":"gameFull","game":${GameMapper.toGameFullJson(session)}}"""
|
|
||||||
Response.ok(event + "\n").build()
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/{gameId}/resign")
|
|
||||||
def resignGame(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.resign(gameId) match
|
|
||||||
case Right(_) => Response.ok(OkResponse()).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("RESIGN_ERROR", err)).build()
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/{gameId}/draw/{action}")
|
|
||||||
def drawAction(
|
|
||||||
@PathParam("gameId") gameId: String,
|
|
||||||
@PathParam("action") action: String,
|
|
||||||
): Response =
|
|
||||||
store.drawAction(gameId, action) match
|
|
||||||
case Right(_) => Response.ok(OkResponse()).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("DRAW_ERROR", err)).build()
|
|
||||||
|
|
||||||
@GET
|
|
||||||
@Path("/{gameId}/export/fen")
|
|
||||||
@Produces(Array(MediaType.TEXT_PLAIN))
|
|
||||||
def exportFen(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.get(gameId) match
|
|
||||||
case None =>
|
|
||||||
Response
|
|
||||||
.status(404)
|
|
||||||
.`type`(MediaType.APPLICATION_JSON)
|
|
||||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
|
||||||
.build()
|
|
||||||
case Some(session) =>
|
|
||||||
import de.nowchess.io.fen.FenExporter
|
|
||||||
Response.ok(FenExporter.exportGameContext(session.context)).build()
|
|
||||||
|
|
||||||
@GET
|
|
||||||
@Path("/{gameId}/export/pgn")
|
|
||||||
@Produces(Array("application/x-chess-pgn"))
|
|
||||||
def exportPgn(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.get(gameId) match
|
|
||||||
case None =>
|
|
||||||
Response
|
|
||||||
.status(404)
|
|
||||||
.`type`(MediaType.APPLICATION_JSON)
|
|
||||||
.entity(ApiErrorResponse("GAME_NOT_FOUND", s"Game $gameId not found"))
|
|
||||||
.build()
|
|
||||||
case Some(session) =>
|
|
||||||
import de.nowchess.io.pgn.PgnExporter
|
|
||||||
Response.ok(PgnExporter.exportGameContext(session.context)).build()
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import de.nowchess.backcore.dto.{ApiErrorResponse, ImportFenRequest, ImportPgnRequest}
|
|
||||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
|
||||||
import jakarta.enterprise.context.ApplicationScoped
|
|
||||||
import jakarta.inject.Inject
|
|
||||||
import jakarta.ws.rs.*
|
|
||||||
import jakarta.ws.rs.core.{MediaType, Response}
|
|
||||||
|
|
||||||
@Path("/api/board/game/import")
|
|
||||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
|
||||||
@Consumes(Array(MediaType.APPLICATION_JSON))
|
|
||||||
@ApplicationScoped
|
|
||||||
class ImportResource @Inject() (store: GameStore):
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/fen")
|
|
||||||
def importFen(req: ImportFenRequest): Response =
|
|
||||||
store.importFen(Option(req).getOrElse(ImportFenRequest())) match
|
|
||||||
case Right(session) => Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
|
||||||
case Left(err) => Response.status(400).entity(ApiErrorResponse("INVALID_FEN", err)).build()
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/pgn")
|
|
||||||
def importPgn(req: ImportPgnRequest): Response =
|
|
||||||
val body = Option(req).getOrElse(ImportPgnRequest())
|
|
||||||
store.importPgn(body.pgn, None, None) match
|
|
||||||
case Right(session) => Response.status(201).entity(GameMapper.toGameFull(session)).build()
|
|
||||||
case Left(err) => Response.status(400).entity(ApiErrorResponse("INVALID_PGN", err)).build()
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import de.nowchess.api.board.Square
|
|
||||||
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
|
||||||
import de.nowchess.backcore.dto.*
|
|
||||||
import de.nowchess.backcore.game.{GameMapper, GameStore}
|
|
||||||
import jakarta.enterprise.context.ApplicationScoped
|
|
||||||
import jakarta.inject.Inject
|
|
||||||
import jakarta.ws.rs.*
|
|
||||||
import jakarta.ws.rs.core.{MediaType, Response}
|
|
||||||
|
|
||||||
@Path("/api/board/game")
|
|
||||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
|
||||||
@ApplicationScoped
|
|
||||||
class MoveResource @Inject() (store: GameStore):
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/{gameId}/move/{uci}")
|
|
||||||
def makeMove(
|
|
||||||
@PathParam("gameId") gameId: String,
|
|
||||||
@PathParam("uci") uci: String,
|
|
||||||
): Response =
|
|
||||||
store.applyMove(gameId, uci) match
|
|
||||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("INVALID_MOVE", err)).build()
|
|
||||||
|
|
||||||
@GET
|
|
||||||
@Path("/{gameId}/moves")
|
|
||||||
def getLegalMoves(
|
|
||||||
@PathParam("gameId") gameId: String,
|
|
||||||
@QueryParam("square") squareParam: String,
|
|
||||||
): Response =
|
|
||||||
val square = Option(squareParam).flatMap(Square.fromAlgebraic)
|
|
||||||
store.legalMoves(gameId, square) match
|
|
||||||
case Right(moves) =>
|
|
||||||
val dtos = moves.map(toLegalMoveDto)
|
|
||||||
Response.ok(LegalMovesResponse(dtos)).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("ERROR", err)).build()
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/{gameId}/undo")
|
|
||||||
def undoMove(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.undo(gameId) match
|
|
||||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("UNDO_NOT_AVAILABLE", err)).build()
|
|
||||||
|
|
||||||
@POST
|
|
||||||
@Path("/{gameId}/redo")
|
|
||||||
def redoMove(@PathParam("gameId") gameId: String): Response =
|
|
||||||
store.redo(gameId) match
|
|
||||||
case Right(session) => Response.ok(GameMapper.toGameState(session)).build()
|
|
||||||
case Left(err) if err.contains("not found") =>
|
|
||||||
Response.status(404).entity(ApiErrorResponse("GAME_NOT_FOUND", err)).build()
|
|
||||||
case Left(err) =>
|
|
||||||
Response.status(400).entity(ApiErrorResponse("REDO_NOT_AVAILABLE", err)).build()
|
|
||||||
|
|
||||||
private def toLegalMoveDto(move: Move): LegalMoveDto =
|
|
||||||
val uci = GameMapper.moveToUci(move)
|
|
||||||
val (moveType, promotion) = move.moveType match
|
|
||||||
case MoveType.Normal(true) => ("capture", None)
|
|
||||||
case MoveType.Normal(false) => ("normal", None)
|
|
||||||
case MoveType.CastleKingside => ("castleKingside", None)
|
|
||||||
case MoveType.CastleQueenside => ("castleQueenside", None)
|
|
||||||
case MoveType.EnPassant => ("enPassant", None)
|
|
||||||
case MoveType.Promotion(pp) =>
|
|
||||||
val pName = pp match
|
|
||||||
case PromotionPiece.Queen => "queen"
|
|
||||||
case PromotionPiece.Rook => "rook"
|
|
||||||
case PromotionPiece.Bishop => "bishop"
|
|
||||||
case PromotionPiece.Knight => "knight"
|
|
||||||
("promotion", Some(pName))
|
|
||||||
LegalMoveDto(
|
|
||||||
from = move.from.toString,
|
|
||||||
to = move.to.toString,
|
|
||||||
uci = uci,
|
|
||||||
moveType = moveType,
|
|
||||||
promotion = promotion,
|
|
||||||
)
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package de.nowchess.backcore
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class BackcoreStartupTest:
|
|
||||||
@Test
|
|
||||||
def applicationStarts(): Unit =
|
|
||||||
// If we get here the Quarkus container started successfully
|
|
||||||
()
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import io.restassured.RestAssured
|
|
||||||
import org.hamcrest.Matchers.{equalTo, matchesPattern, notNullValue}
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class GameResourceTest:
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def createGameReturns201WithGameId(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
|
||||||
.body("state.fen", notNullValue())
|
|
||||||
.body("state.turn", equalTo("white"))
|
|
||||||
.body("state.status", equalTo("started"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def createGameWithPlayersReturns201(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("""{"white":{"id":"p1","displayName":"Alice"},"black":{"id":"p2","displayName":"Bob"}}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.body("white.id", equalTo("p1"))
|
|
||||||
.body("black.displayName", equalTo("Bob"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def getGameReturns200ForExistingGame(): Unit =
|
|
||||||
val gameId = RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("gameId", equalTo(gameId))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def getGameReturns404ForUnknownId(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get("/api/board/game/XXXXXXXX")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import io.restassured.RestAssured
|
|
||||||
import org.hamcrest.Matchers.{containsString, equalTo, matchesPattern, notNullValue}
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class ImportExportTest:
|
|
||||||
|
|
||||||
private val startFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
|
|
||||||
// ─── Import FEN ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def importFenReturns201WithCorrectPosition(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body(s"""{"fen":"$startFen"}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/import/fen")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
|
||||||
.body("state.fen", equalTo(startFen))
|
|
||||||
.body("state.turn", equalTo("white"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def importFenWithCustomPositionWorks(): Unit =
|
|
||||||
val fen = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body(s"""{"fen":"$fen"}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/import/fen")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.body("state.fen", equalTo(fen))
|
|
||||||
.body("state.turn", equalTo("black"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def importFenWithInvalidFenReturns400(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("""{"fen":"not-a-fen"}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/import/fen")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
|
|
||||||
// ─── Import PGN ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def importPgnReturns201(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("""{"pgn":"1. e4 e5 2. Nf3 Nc6 *"}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/import/pgn")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.body("gameId", matchesPattern("[A-Za-z0-9]{8}"))
|
|
||||||
.body("state.turn", equalTo("white"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def importPgnWithInvalidPgnReturns400(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("""{"pgn":"1. z9 *"}""")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/import/pgn")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
|
|
||||||
// ─── Export FEN ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def exportFenReturnsStartingFen(): Unit =
|
|
||||||
val gameId = RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId/export/fen")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body(equalTo(startFen))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def exportFenOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get("/api/board/game/XXXXXXXX/export/fen")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
|
|
||||||
// ─── Export PGN ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def exportPgnReturnsText(): Unit =
|
|
||||||
val gameId = RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId/export/pgn")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body(containsString("e4"))
|
|
||||||
|
|
||||||
// ─── Stream ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def streamReturnsNdjsonSnapshot(): Unit =
|
|
||||||
val gameId = RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
val body = RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId/stream")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.contentType("application/x-ndjson")
|
|
||||||
.extract()
|
|
||||||
.body()
|
|
||||||
.asString()
|
|
||||||
|
|
||||||
assert(body.trim.startsWith("""{"type":"gameFull""""), s"Expected gameFull event, got: $body")
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def streamOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get("/api/board/game/XXXXXXXX/stream")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import io.restassured.RestAssured
|
|
||||||
import org.hamcrest.Matchers.{containsString, empty, equalTo, hasItem, hasItems, not, notNullValue}
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class MoveResourceTest:
|
|
||||||
|
|
||||||
private def createGame(): String =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def makeMoveReturns200WithUpdatedFen(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("fen", containsString("4P3")) // e4 pawn present in FEN
|
|
||||||
.body("turn", equalTo("black"))
|
|
||||||
.body("moves", hasItem("e2e4"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def makeMoveOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/XXXXXXXX/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def illegalMoveReturns400(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e5") // illegal — pawns can't jump 3 squares
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def getLegalMovesReturnsNonEmptyList(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId/moves")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("moves", not(empty()))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def getLegalMovesFilteredBySquareReturnsCorrectMoves(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId/moves?square=e2")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("moves.uci", hasItems("e2e3", "e2e4"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def getLegalMovesOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get("/api/board/game/XXXXXXXX/moves")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import io.restassured.RestAssured
|
|
||||||
import org.hamcrest.Matchers.{equalTo, notNullValue}
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class ResignDrawTest:
|
|
||||||
|
|
||||||
private def createGame(): String =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
// ─── Resign ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def resignReturns200(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/resign")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("ok", equalTo(true))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def afterResignGameShowsResignStatusAndWinner(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/resign")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("state.status", equalTo("resign"))
|
|
||||||
.body("state.winner", notNullValue())
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def resignOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/XXXXXXXX/resign")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
|
|
||||||
// ─── Draw ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def offerDrawSetsDrawOfferedStatus(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/offer")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("ok", equalTo(true))
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("state.status", equalTo("drawOffered"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def acceptDrawAfterOfferSetsDrawStatus(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
// White offers
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/offer")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
// Black moves so it's black's turn... actually the API doesn't enforce turn-based draw accept.
|
|
||||||
// White offered, so black (opponent) accepts — but since there's no auth, we just call accept.
|
|
||||||
// The GameStore checks drawOfferedBy != turn to allow accept.
|
|
||||||
// White offered on white's turn, so black needs to accept — but current turn is still white.
|
|
||||||
// We need to make a move first to switch turns.
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
// Now it's black's turn and white offered the draw — black accepts
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/accept")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("ok", equalTo(true))
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("state.status", equalTo("draw"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def declineDrawClearsOffer(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/offer")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/decline")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("ok", equalTo(true))
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.get(s"/api/board/game/$gameId")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("state.status", equalTo("started"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def acceptWithoutOfferReturns400(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/draw/accept")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def drawOnUnknownGameReturns404(): Unit =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game/XXXXXXXX/draw/offer")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(404)
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package de.nowchess.backcore.resource
|
|
||||||
|
|
||||||
import io.quarkus.test.junit.QuarkusTest
|
|
||||||
import io.restassured.RestAssured
|
|
||||||
import org.hamcrest.Matchers.{containsString, equalTo}
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
|
|
||||||
@QuarkusTest
|
|
||||||
class UndoRedoTest:
|
|
||||||
|
|
||||||
private val initialFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
||||||
|
|
||||||
private def createGame(): String =
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.contentType("application/json")
|
|
||||||
.body("{}")
|
|
||||||
.when()
|
|
||||||
.post("/api/board/game")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(201)
|
|
||||||
.extract()
|
|
||||||
.path[String]("gameId")
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def undoAfterMoveRestoresOriginalPosition(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/undo")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("fen", equalTo(initialFen))
|
|
||||||
.body("undoAvailable", equalTo(false))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def redoAfterUndoRestoresMovedPosition(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/move/e2e4")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/undo")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/redo")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(200)
|
|
||||||
.body("fen", containsString("4P3"))
|
|
||||||
.body("turn", equalTo("black"))
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def undoWithNoHistoryReturns400(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/undo")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
def redoWithNoRedoStackReturns400(): Unit =
|
|
||||||
val gameId = createGame()
|
|
||||||
RestAssured
|
|
||||||
.`given`()
|
|
||||||
.when()
|
|
||||||
.post(s"/api/board/game/$gameId/redo")
|
|
||||||
.`then`()
|
|
||||||
.statusCode(400)
|
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
plugins {
|
||||||
|
id("scala")
|
||||||
|
id("org.scoverage")
|
||||||
|
}
|
||||||
|
|
||||||
|
group = "de.nowchess"
|
||||||
|
version = "1.0-SNAPSHOT"
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
val versions = rootProject.extra["VERSIONS"] as Map<String, String>
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
scala {
|
||||||
|
scalaVersion = versions["SCALA3"]!!
|
||||||
|
}
|
||||||
|
|
||||||
|
scoverage {
|
||||||
|
scoverageVersion.set(versions["SCOVERAGE"]!!)
|
||||||
|
excludedPackages.set(
|
||||||
|
listOf(
|
||||||
|
"de\\.nowchess\\.bot\\.bots\\.NNUEBot",
|
||||||
|
"de\\.nowchess\\.bot\\.bots\\.nnue\\..*",
|
||||||
|
"de\\.nowchess\\.bot\\.util\\.PolyglotBook",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
excludedFiles.set(
|
||||||
|
listOf(
|
||||||
|
".*NNUE\\.scala",
|
||||||
|
".*NNUEBot\\.scala",
|
||||||
|
".*NbaiLoader\\.scala",
|
||||||
|
".*NbaiMigrator\\.scala",
|
||||||
|
".*NbaiWriter\\.scala",
|
||||||
|
".*PolyglotBook\\.scala",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.withType<ScalaCompile> {
|
||||||
|
scalaCompileOptions.additionalParameters = listOf("-encoding", "UTF-8")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
|
||||||
|
implementation("org.scala-lang:scala3-compiler_3") {
|
||||||
|
version {
|
||||||
|
strictly(versions["SCALA3"]!!)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
implementation("org.scala-lang:scala3-library_3") {
|
||||||
|
version {
|
||||||
|
strictly(versions["SCALA3"]!!)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
implementation(project(":modules:api"))
|
||||||
|
implementation(project(":modules:io"))
|
||||||
|
implementation(project(":modules:rule"))
|
||||||
|
implementation("com.microsoft.onnxruntime:onnxruntime:${versions["ONNXRUNTIME"]!!}")
|
||||||
|
|
||||||
|
testImplementation(platform("org.junit:junit-bom:${versions["JUNIT_BOM"]!!}"))
|
||||||
|
testImplementation("org.junit.jupiter:junit-jupiter")
|
||||||
|
testImplementation("org.scalatest:scalatest_3:${versions["SCALATEST"]!!}")
|
||||||
|
testImplementation("co.helmethair:scalatest-junit-runner:${versions["SCALATEST_JUNIT"]!!}")
|
||||||
|
|
||||||
|
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.test {
|
||||||
|
useJUnitPlatform {
|
||||||
|
includeEngines("scalatest")
|
||||||
|
testLogging {
|
||||||
|
events("skipped", "failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finalizedBy(tasks.reportScoverage)
|
||||||
|
}
|
||||||
|
tasks.reportScoverage {
|
||||||
|
dependsOn(tasks.test)
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.jar {
|
||||||
|
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
|||||||
|
# Data and weights are local artifacts, not committed
|
||||||
|
data/
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
tactical_data/
|
||||||
|
trainingdata/
|
||||||
|
/datasets/
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
# Training Dataset Management
|
||||||
|
|
||||||
|
The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
datasets/
|
||||||
|
ds_v1/
|
||||||
|
labeled.jsonl # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
|
||||||
|
metadata.json # Version info and composition
|
||||||
|
ds_v2/
|
||||||
|
labeled.jsonl
|
||||||
|
metadata.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metadata Schema
|
||||||
|
|
||||||
|
Each dataset has a `metadata.json` file tracking its composition:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"created": "2026-04-13T15:30:45.123456",
|
||||||
|
"total_positions": 1000000,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"sources": [
|
||||||
|
{
|
||||||
|
"type": "generated",
|
||||||
|
"count": 500000,
|
||||||
|
"params": {
|
||||||
|
"num_positions": 500000,
|
||||||
|
"min_move": 1,
|
||||||
|
"max_move": 50
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tactical",
|
||||||
|
"count": 300000,
|
||||||
|
"max_puzzles": 300000
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "file_import",
|
||||||
|
"count": 200000,
|
||||||
|
"path": "/path/to/original_file.txt"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## TUI Workflow
|
||||||
|
|
||||||
|
### Main Menu
|
||||||
|
```
|
||||||
|
1 - Manage Training Data
|
||||||
|
2 - Train Model
|
||||||
|
3 - Export Model
|
||||||
|
4 - Exit
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Data Management Submenu
|
||||||
|
```
|
||||||
|
1 - Create new dataset
|
||||||
|
2 - Extend existing dataset
|
||||||
|
3 - View all datasets
|
||||||
|
4 - Delete dataset
|
||||||
|
5 - Back
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating a Dataset
|
||||||
|
|
||||||
|
Use the "Create new dataset" option to add data from one or more sources:
|
||||||
|
|
||||||
|
1. **Generate random positions** — Play random games and sample positions
|
||||||
|
- Number of positions
|
||||||
|
- Move range (min/max move number to sample from)
|
||||||
|
- Number of worker threads
|
||||||
|
|
||||||
|
2. **Import from file** — Load positions from a FEN file
|
||||||
|
- File must contain one FEN string per line
|
||||||
|
- Duplicates are automatically removed
|
||||||
|
|
||||||
|
3. **Extract tactical puzzles** — Download and extract Lichess puzzle database
|
||||||
|
- Maximum number of puzzles to include
|
||||||
|
- Automatically filters for tactical themes (forks, pins, mates, etc.)
|
||||||
|
|
||||||
|
You can combine multiple sources in a single dataset creation session. All positions are:
|
||||||
|
- Deduplicated (only unique FENs are kept)
|
||||||
|
- Labeled with Stockfish evaluations
|
||||||
|
- Saved to `datasets/ds_vN/labeled.jsonl`
|
||||||
|
|
||||||
|
## Extending a Dataset
|
||||||
|
|
||||||
|
Use "Extend existing dataset" to add more positions to an existing dataset:
|
||||||
|
|
||||||
|
1. Select the dataset version to extend
|
||||||
|
2. Choose data sources (same options as creation)
|
||||||
|
3. Confirm labeling parameters
|
||||||
|
4. New positions are:
|
||||||
|
- Labeled with Stockfish
|
||||||
|
- Deduplicated against the target dataset (preventing duplicates)
|
||||||
|
- Merged into the existing `labeled.jsonl`
|
||||||
|
- Metadata is updated with the new source entry
|
||||||
|
|
||||||
|
## Training with a Dataset
|
||||||
|
|
||||||
|
When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
|
||||||
|
- Version number
|
||||||
|
- Total number of positions
|
||||||
|
- Source types (generated, tactical, imported)
|
||||||
|
- Stockfish depth used
|
||||||
|
- Creation date
|
||||||
|
|
||||||
|
## Legacy Data Migration
|
||||||
|
|
||||||
|
If you have existing labeled data in `data/training_data.jsonl` from before this update:
|
||||||
|
|
||||||
|
1. Open the "Manage Training Data" menu
|
||||||
|
2. Choose "Create new dataset"
|
||||||
|
3. Select "Import from file"
|
||||||
|
4. Point to `data/training_data.jsonl`
|
||||||
|
5. Complete the dataset creation
|
||||||
|
|
||||||
|
Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
|
||||||
|
|
||||||
|
## Viewing Dataset Details
|
||||||
|
|
||||||
|
Use "View all datasets" to see a table of all datasets with:
|
||||||
|
- Version number
|
||||||
|
- Position count
|
||||||
|
- Source composition
|
||||||
|
- Stockfish depth
|
||||||
|
- Creation date
|
||||||
|
|
||||||
|
## Deleting a Dataset
|
||||||
|
|
||||||
|
Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**
|
||||||
|
|
||||||
|
⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### Deduplication Strategy
|
||||||
|
|
||||||
|
When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.
|
||||||
|
|
||||||
|
When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
|
||||||
|
|
||||||
|
### Labeled Position Format
|
||||||
|
|
||||||
|
Each line in `labeled.jsonl` is a JSON object:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
|
||||||
|
"eval": 0.0,
|
||||||
|
"eval_raw": 0
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `fen`: The position in Forsyth-Edwards Notation
|
||||||
|
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
|
||||||
|
- `eval_raw`: Raw Stockfish evaluation in centipawns
|
||||||
|
|
||||||
|
### Storage Location
|
||||||
|
|
||||||
|
Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.
|
||||||
|
|
||||||
|
## Performance Tips
|
||||||
|
|
||||||
|
- **Smaller datasets train faster** — Start with 100k-500k positions
|
||||||
|
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
|
||||||
|
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
|
||||||
|
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
# NNUE Python Pipeline
|
||||||
|
|
||||||
|
Central CLI for training and exporting chess evaluation neural networks (NNUE).
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
python/
|
||||||
|
├── nnue.py # Main CLI entry point
|
||||||
|
├── src/ # Python modules
|
||||||
|
│ ├── generate.py # Generate random chess positions
|
||||||
|
│ ├── label.py # Label positions with Stockfish
|
||||||
|
│ ├── train.py # Train NNUE model
|
||||||
|
│ └── export.py # Export weights to Scala
|
||||||
|
├── data/ # Training data (gitignored)
|
||||||
|
│ ├── positions.txt
|
||||||
|
│ └── training_data.jsonl
|
||||||
|
└── weights/ # Model weights (gitignored)
|
||||||
|
├── nnue_weights_v1.pt
|
||||||
|
├── nnue_weights_v1_metadata.json
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Train a new model (500k positions, auto-detect checkpoint)
|
||||||
|
python nnue.py train
|
||||||
|
|
||||||
|
# Train from specific checkpoint
|
||||||
|
python nnue.py train --from-checkpoint 2
|
||||||
|
|
||||||
|
# Train with custom games count
|
||||||
|
python nnue.py train --games 200000
|
||||||
|
|
||||||
|
# Train with custom positions file
|
||||||
|
python nnue.py train --positions-file my_positions.txt
|
||||||
|
|
||||||
|
# Export specific version to Scala
|
||||||
|
python nnue.py export 2
|
||||||
|
|
||||||
|
# List all checkpoints
|
||||||
|
python nnue.py list
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI Commands
|
||||||
|
|
||||||
|
### `train` - Train NNUE model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python nnue.py train [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--from-checkpoint N` - Resume from checkpoint version N (default: uses latest)
|
||||||
|
- `--games N` - Number of games to generate (default: 500000)
|
||||||
|
- `--positions-file FILE` - Use existing positions file instead of generating
|
||||||
|
- `--stockfish PATH` - Path to Stockfish binary (default: `$STOCKFISH_PATH` or `/usr/games/stockfish`)
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
```bash
|
||||||
|
# Train with latest checkpoint
|
||||||
|
python nnue.py train
|
||||||
|
|
||||||
|
# Train from v2 with 100k games
|
||||||
|
python nnue.py train --from-checkpoint 2 --games 100000
|
||||||
|
|
||||||
|
# Train with custom positions
|
||||||
|
python nnue.py train --positions-file my_games.txt --stockfish /opt/stockfish/sf15
|
||||||
|
```
|
||||||
|
|
||||||
|
### `export` - Export weights to Scala
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python nnue.py export WEIGHTS [output_path]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `WEIGHTS` - Version number (e.g., `2`) or full filename (e.g., `nnue_weights_v2.pt`)
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
```bash
|
||||||
|
# Export version 2
|
||||||
|
python nnue.py export 2
|
||||||
|
|
||||||
|
# Export with full filename
|
||||||
|
python nnue.py export nnue_weights_v3.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
Output goes to `../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_vN.scala`
|
||||||
|
|
||||||
|
### `list` - List available checkpoints
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python nnue.py list
|
||||||
|
```
|
||||||
|
|
||||||
|
Shows all available model versions with file sizes.
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
1. **Generate** → `data/positions.txt`
|
||||||
|
- Random chess positions from 8-20 move openings
|
||||||
|
- Filters out checks, game-over states, and captures
|
||||||
|
|
||||||
|
2. **Label** → `data/training_data.jsonl`
|
||||||
|
- Evaluates each position with Stockfish at depth 12
|
||||||
|
- Stores FEN + evaluation in JSONL format
|
||||||
|
|
||||||
|
3. **Train** → `weights/nnue_weights_vN.pt`
|
||||||
|
- Trains neural network on labeled positions
|
||||||
|
- Auto-versioning (v1, v2, v3, etc.)
|
||||||
|
- Saves metadata alongside weights
|
||||||
|
|
||||||
|
4. **Export** → `NNUEWeights_vN.scala`
|
||||||
|
- Converts weights to Scala object
|
||||||
|
- Ready for integration into bot
|
||||||
|
|
||||||
|
## Versioning
|
||||||
|
|
||||||
|
- Models are automatically versioned (v1, v2, v3, etc.)
|
||||||
|
- Each version gets a `_metadata.json` file with training info
|
||||||
|
- Training from checkpoint uses latest version unless specified with `--from-checkpoint`
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
- `data/` and `weights/` are gitignored (local artifacts)
|
||||||
|
- Documentation in `docs/` explains training, debugging, and incremental improvements
|
||||||
|
- Source modules in `src/` are independent and can be imported for custom workflows
|
||||||
@@ -0,0 +1,951 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Central NNUE pipeline TUI for training and exporting models."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.prompt import Prompt, Confirm
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
|
# Add src directory to path so we can import modules
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||||
|
|
||||||
|
from generate import play_random_game_and_collect_positions
|
||||||
|
from label import label_positions_with_stockfish
|
||||||
|
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
|
||||||
|
from export import export_to_nbai
|
||||||
|
from tactical_positions_extractor import (
|
||||||
|
download_and_extract_puzzle_db,
|
||||||
|
extract_tactical_only
|
||||||
|
)
|
||||||
|
from lichess_importer import import_lichess_evals
|
||||||
|
from dataset import (
|
||||||
|
get_datasets_dir,
|
||||||
|
list_datasets,
|
||||||
|
next_dataset_version,
|
||||||
|
load_dataset_metadata,
|
||||||
|
create_dataset,
|
||||||
|
extend_dataset,
|
||||||
|
get_dataset_labeled_path,
|
||||||
|
delete_dataset,
|
||||||
|
show_datasets_table
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_weights_dir():
|
||||||
|
"""Get/create weights directory."""
|
||||||
|
weights_dir = Path(__file__).parent / "weights"
|
||||||
|
weights_dir.mkdir(exist_ok=True)
|
||||||
|
return weights_dir
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_dir():
|
||||||
|
"""Get/create legacy data directory (for migration)."""
|
||||||
|
data_dir = Path(__file__).parent / "data"
|
||||||
|
data_dir.mkdir(exist_ok=True)
|
||||||
|
return data_dir
|
||||||
|
|
||||||
|
|
||||||
|
def list_checkpoints():
|
||||||
|
"""List available checkpoint versions."""
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
checkpoints = sorted(weights_dir.glob("nnue_weights_v*.pt"))
|
||||||
|
if not checkpoints:
|
||||||
|
return []
|
||||||
|
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_legacy_data():
|
||||||
|
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
|
||||||
|
console = Console()
|
||||||
|
data_dir = get_data_dir()
|
||||||
|
legacy_file = data_dir / "training_data.jsonl"
|
||||||
|
datasets = list_datasets()
|
||||||
|
|
||||||
|
# Only migrate if legacy data exists and no datasets exist yet
|
||||||
|
if legacy_file.exists() and not datasets:
|
||||||
|
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
|
||||||
|
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
def show_header():
|
||||||
|
"""Display application header."""
|
||||||
|
console = Console()
|
||||||
|
console.clear()
|
||||||
|
console.print(
|
||||||
|
Panel(
|
||||||
|
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
|
||||||
|
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
|
||||||
|
border_style="cyan",
|
||||||
|
padding=(1, 2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def show_checkpoints_table():
|
||||||
|
"""Display available checkpoints in a table."""
|
||||||
|
console = Console()
|
||||||
|
available = list_checkpoints()
|
||||||
|
|
||||||
|
if not available:
|
||||||
|
console.print("[yellow]ℹ No model checkpoints found yet[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
|
||||||
|
table.add_column("Version", style="dim")
|
||||||
|
table.add_column("File Size", justify="right")
|
||||||
|
table.add_column("Status", justify="center")
|
||||||
|
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
for v in sorted(available):
|
||||||
|
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
|
||||||
|
if weights_file.exists():
|
||||||
|
size = weights_file.stat().st_size / (1024**2)
|
||||||
|
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
|
||||||
|
else:
|
||||||
|
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
def show_main_menu():
|
||||||
|
"""Display and handle main menu."""
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# Migrate legacy data on first run
|
||||||
|
migrate_legacy_data()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
show_header()
|
||||||
|
show_checkpoints_table()
|
||||||
|
|
||||||
|
console.print("\n[bold]What would you like to do?[/bold]")
|
||||||
|
console.print("[cyan]1[/cyan] - Manage Training Data")
|
||||||
|
console.print("[cyan]2[/cyan] - Train Model")
|
||||||
|
console.print("[cyan]3[/cyan] - Export Model")
|
||||||
|
console.print("[cyan]4[/cyan] - Exit")
|
||||||
|
|
||||||
|
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||||
|
|
||||||
|
if choice == "1":
|
||||||
|
datasets_menu()
|
||||||
|
elif choice == "2":
|
||||||
|
training_menu()
|
||||||
|
elif choice == "3":
|
||||||
|
export_interactive()
|
||||||
|
elif choice == "4":
|
||||||
|
console.print("[yellow]👋 Goodbye![/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def datasets_menu():
|
||||||
|
"""Dataset management submenu."""
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
show_header()
|
||||||
|
show_datasets_table(console)
|
||||||
|
|
||||||
|
console.print("\n[bold]Training Data Management[/bold]")
|
||||||
|
console.print("[cyan]1[/cyan] - Create new dataset")
|
||||||
|
console.print("[cyan]2[/cyan] - Extend existing dataset")
|
||||||
|
console.print("[cyan]3[/cyan] - View all datasets")
|
||||||
|
console.print("[cyan]4[/cyan] - Delete dataset")
|
||||||
|
console.print("[cyan]5[/cyan] - Back")
|
||||||
|
|
||||||
|
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
|
||||||
|
|
||||||
|
if choice == "1":
|
||||||
|
create_dataset_interactive()
|
||||||
|
elif choice == "2":
|
||||||
|
extend_dataset_interactive()
|
||||||
|
elif choice == "3":
|
||||||
|
show_header()
|
||||||
|
show_datasets_table(console)
|
||||||
|
Prompt.ask("\nPress Enter to continue")
|
||||||
|
elif choice == "4":
|
||||||
|
delete_dataset_interactive()
|
||||||
|
elif choice == "5":
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_interactive():
|
||||||
|
"""Interactive dataset creation flow."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
|
||||||
|
|
||||||
|
sources = []
|
||||||
|
combined_count = 0
|
||||||
|
|
||||||
|
# Allow user to add multiple sources
|
||||||
|
while True:
|
||||||
|
console.print("\n[bold]Add data source (repeat until done):[/bold]")
|
||||||
|
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||||
|
console.print("[cyan]b[/cyan] - Import from file")
|
||||||
|
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||||
|
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||||
|
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||||
|
|
||||||
|
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||||
|
|
||||||
|
if choice == "a":
|
||||||
|
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||||
|
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||||
|
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||||
|
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||||
|
|
||||||
|
console.print("[dim]Generating positions...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||||
|
count = play_random_game_and_collect_positions(
|
||||||
|
str(temp_file),
|
||||||
|
total_positions=num_positions,
|
||||||
|
samples_per_game=1,
|
||||||
|
min_move=min_move,
|
||||||
|
max_move=max_move,
|
||||||
|
num_workers=num_workers
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
sources.append({
|
||||||
|
"type": "generated",
|
||||||
|
"count": count,
|
||||||
|
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||||
|
})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[red]✗ Generation failed[/red]")
|
||||||
|
|
||||||
|
elif choice == "b":
|
||||||
|
file_path = Prompt.ask("Path to FEN file")
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
count = sum(1 for _ in f)
|
||||||
|
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||||
|
except FileNotFoundError:
|
||||||
|
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||||
|
|
||||||
|
elif choice == "c":
|
||||||
|
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||||
|
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||||
|
try:
|
||||||
|
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||||
|
if csv_path:
|
||||||
|
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||||
|
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
|
||||||
|
|
||||||
|
elif choice == "d":
|
||||||
|
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||||
|
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||||
|
max_pos = int(max_pos) if max_pos.strip() else None
|
||||||
|
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||||
|
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||||
|
temp_file.unlink(missing_ok=True)
|
||||||
|
try:
|
||||||
|
count = import_lichess_evals(
|
||||||
|
input_path=zst_path,
|
||||||
|
output_file=str(temp_file),
|
||||||
|
max_positions=max_pos,
|
||||||
|
min_depth=min_depth,
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
sources.append({
|
||||||
|
"type": "lichess",
|
||||||
|
"count": count,
|
||||||
|
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||||
|
})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[red]✗ No positions imported[/red]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||||
|
|
||||||
|
elif choice == "e":
|
||||||
|
if not sources:
|
||||||
|
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
if not sources:
|
||||||
|
console.print("[yellow]Dataset creation cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine whether any sources still need Stockfish labeling.
|
||||||
|
# Lichess sources are already labeled; only generated/tactical/file sources need it.
|
||||||
|
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||||
|
|
||||||
|
stockfish_depth = 12
|
||||||
|
if needs_labeling:
|
||||||
|
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||||
|
stockfish_path = Prompt.ask(
|
||||||
|
"Stockfish path",
|
||||||
|
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||||
|
)
|
||||||
|
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||||
|
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||||
|
|
||||||
|
# Summary and confirm
|
||||||
|
console.print("\n[bold]Dataset Summary:[/bold]")
|
||||||
|
console.print(f" Total positions: {combined_count:,}")
|
||||||
|
for source in sources:
|
||||||
|
console.print(f" - {source['type']}: {source['count']:,}")
|
||||||
|
if needs_labeling:
|
||||||
|
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nProceed to create dataset?", default=True):
|
||||||
|
console.print("[yellow]Cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||||
|
labeled_file.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
# --- Step 1: Collect already-labeled data (Lichess source) ---
|
||||||
|
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||||
|
if lichess_tmp.exists():
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy(lichess_tmp, labeled_file)
|
||||||
|
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||||
|
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||||
|
|
||||||
|
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
|
||||||
|
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||||
|
if non_lichess:
|
||||||
|
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||||
|
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||||
|
all_fens = set()
|
||||||
|
|
||||||
|
for source in non_lichess:
|
||||||
|
if source["type"] == "generated":
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||||
|
elif source["type"] == "file_import":
|
||||||
|
temp_file = Path(source["path"])
|
||||||
|
elif source["type"] == "tactical":
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if temp_file.exists():
|
||||||
|
with open(temp_file, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
fen = line.strip()
|
||||||
|
if fen:
|
||||||
|
all_fens.add(fen)
|
||||||
|
|
||||||
|
with open(combined_fen_file, "w") as f:
|
||||||
|
for fen in all_fens:
|
||||||
|
f.write(fen + "\n")
|
||||||
|
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||||
|
success = label_positions_with_stockfish(
|
||||||
|
str(combined_fen_file),
|
||||||
|
str(labeled_file),
|
||||||
|
stockfish_path,
|
||||||
|
depth=stockfish_depth,
|
||||||
|
num_workers=num_workers,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||||
|
return
|
||||||
|
console.print("[green]✓ Positions labeled[/green]")
|
||||||
|
|
||||||
|
# --- Step 3: Create dataset ---
|
||||||
|
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
|
||||||
|
version = next_dataset_version()
|
||||||
|
create_dataset(
|
||||||
|
version=version,
|
||||||
|
labeled_jsonl_path=str(labeled_file),
|
||||||
|
sources=sources,
|
||||||
|
stockfish_depth=stockfish_depth,
|
||||||
|
)
|
||||||
|
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
|
||||||
|
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
|
||||||
|
|
||||||
|
Prompt.ask("\nPress Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def extend_dataset_interactive():
|
||||||
|
"""Interactive dataset extension flow."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
|
||||||
|
|
||||||
|
datasets = list_datasets()
|
||||||
|
if not datasets:
|
||||||
|
console.print("[yellow]ℹ No datasets available to extend[/yellow]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
show_datasets_table(console)
|
||||||
|
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
|
||||||
|
|
||||||
|
if not any(v == version for v, _ in datasets):
|
||||||
|
console.print("[red]✗ Dataset not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
sources = []
|
||||||
|
combined_count = 0
|
||||||
|
|
||||||
|
# Allow user to add sources
|
||||||
|
while True:
|
||||||
|
console.print("\n[bold]Add data source:[/bold]")
|
||||||
|
console.print("[cyan]a[/cyan] - Generate random positions")
|
||||||
|
console.print("[cyan]b[/cyan] - Import from file")
|
||||||
|
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
|
||||||
|
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
|
||||||
|
console.print("[cyan]e[/cyan] - Done adding sources")
|
||||||
|
|
||||||
|
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
|
||||||
|
|
||||||
|
if choice == "a":
|
||||||
|
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
|
||||||
|
min_move = int(Prompt.ask("Minimum move number", default="1"))
|
||||||
|
max_move = int(Prompt.ask("Maximum move number", default="50"))
|
||||||
|
num_workers = int(Prompt.ask("Number of workers", default="8"))
|
||||||
|
|
||||||
|
console.print("[dim]Generating positions...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||||
|
count = play_random_game_and_collect_positions(
|
||||||
|
str(temp_file),
|
||||||
|
total_positions=num_positions,
|
||||||
|
samples_per_game=1,
|
||||||
|
min_move=min_move,
|
||||||
|
max_move=max_move,
|
||||||
|
num_workers=num_workers
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
sources.append({
|
||||||
|
"type": "generated",
|
||||||
|
"count": count,
|
||||||
|
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
|
||||||
|
})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions generated[/green]")
|
||||||
|
|
||||||
|
elif choice == "b":
|
||||||
|
file_path = Prompt.ask("Path to FEN file")
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
count = sum(1 for _ in f)
|
||||||
|
sources.append({"type": "file_import", "count": count, "path": file_path})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions from file[/green]")
|
||||||
|
except FileNotFoundError:
|
||||||
|
console.print(f"[red]✗ File not found: {file_path}[/red]")
|
||||||
|
|
||||||
|
elif choice == "c":
|
||||||
|
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
|
||||||
|
console.print("[dim]Extracting tactical positions...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||||
|
try:
|
||||||
|
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
|
||||||
|
if csv_path:
|
||||||
|
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
|
||||||
|
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Extraction failed: {e}[/red]")
|
||||||
|
|
||||||
|
elif choice == "d":
|
||||||
|
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
|
||||||
|
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
|
||||||
|
max_pos = int(max_pos) if max_pos.strip() else None
|
||||||
|
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
|
||||||
|
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||||
|
temp_file.unlink(missing_ok=True)
|
||||||
|
try:
|
||||||
|
count = import_lichess_evals(
|
||||||
|
input_path=zst_path,
|
||||||
|
output_file=str(temp_file),
|
||||||
|
max_positions=max_pos,
|
||||||
|
min_depth=min_depth,
|
||||||
|
)
|
||||||
|
if count > 0:
|
||||||
|
sources.append({
|
||||||
|
"type": "lichess",
|
||||||
|
"count": count,
|
||||||
|
"params": {"min_depth": min_depth, "max_positions": max_pos},
|
||||||
|
})
|
||||||
|
combined_count += count
|
||||||
|
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[red]✗ No positions imported[/red]")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
|
||||||
|
|
||||||
|
elif choice == "e":
|
||||||
|
if not sources:
|
||||||
|
console.print("[yellow]⚠ No sources added yet[/yellow]")
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
if not sources:
|
||||||
|
console.print("[yellow]Extension cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
needs_labeling = any(s["type"] != "lichess" for s in sources)
|
||||||
|
|
||||||
|
stockfish_depth = 12
|
||||||
|
if needs_labeling:
|
||||||
|
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
|
||||||
|
stockfish_path = Prompt.ask(
|
||||||
|
"Stockfish path",
|
||||||
|
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
|
||||||
|
)
|
||||||
|
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
|
||||||
|
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
|
||||||
|
|
||||||
|
# Summary and confirm
|
||||||
|
console.print("\n[bold]Extension Summary:[/bold]")
|
||||||
|
console.print(f" Target dataset: ds_v{version}")
|
||||||
|
console.print(f" New positions: {combined_count:,}")
|
||||||
|
for source in sources:
|
||||||
|
console.print(f" - {source['type']}: {source['count']:,}")
|
||||||
|
if needs_labeling:
|
||||||
|
console.print(f" Stockfish depth: {stockfish_depth}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nProceed to extend dataset?", default=True):
|
||||||
|
console.print("[yellow]Cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
|
||||||
|
labeled_file.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
# Copy pre-labeled Lichess data if present
|
||||||
|
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
|
||||||
|
if lichess_tmp.exists():
|
||||||
|
import shutil as _shutil
|
||||||
|
_shutil.copy(lichess_tmp, labeled_file)
|
||||||
|
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
|
||||||
|
console.print(f"[green]✓ Lichess positions ready[/green]")
|
||||||
|
|
||||||
|
# Combine and label remaining sources with Stockfish
|
||||||
|
non_lichess = [s for s in sources if s["type"] != "lichess"]
|
||||||
|
if non_lichess:
|
||||||
|
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
|
||||||
|
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
|
||||||
|
all_fens = set()
|
||||||
|
|
||||||
|
for source in non_lichess:
|
||||||
|
if source["type"] == "generated":
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
|
||||||
|
elif source["type"] == "file_import":
|
||||||
|
temp_file = Path(source["path"])
|
||||||
|
elif source["type"] == "tactical":
|
||||||
|
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if temp_file.exists():
|
||||||
|
with open(temp_file, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
fen = line.strip()
|
||||||
|
if fen:
|
||||||
|
all_fens.add(fen)
|
||||||
|
|
||||||
|
with open(combined_fen_file, "w") as f:
|
||||||
|
for fen in all_fens:
|
||||||
|
f.write(fen + "\n")
|
||||||
|
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
|
||||||
|
success = label_positions_with_stockfish(
|
||||||
|
str(combined_fen_file),
|
||||||
|
str(labeled_file),
|
||||||
|
stockfish_path,
|
||||||
|
depth=stockfish_depth,
|
||||||
|
num_workers=num_workers,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
console.print("[red]✗ Stockfish labeling failed[/red]")
|
||||||
|
return
|
||||||
|
console.print("[green]✓ Positions labeled[/green]")
|
||||||
|
|
||||||
|
# Extend dataset
|
||||||
|
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
|
||||||
|
success = extend_dataset(
|
||||||
|
version=version,
|
||||||
|
new_labeled_path=str(labeled_file),
|
||||||
|
new_source_entry={
|
||||||
|
"type": "merged_sources",
|
||||||
|
"count": combined_count,
|
||||||
|
"sources": sources,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
metadata = load_dataset_metadata(version)
|
||||||
|
console.print(f"[green]✓ Dataset extended[/green]")
|
||||||
|
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
|
||||||
|
else:
|
||||||
|
console.print("[red]✗ Extension failed[/red]")
|
||||||
|
|
||||||
|
Prompt.ask("\nPress Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def delete_dataset_interactive():
|
||||||
|
"""Interactive dataset deletion."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
|
||||||
|
|
||||||
|
datasets = list_datasets()
|
||||||
|
if not datasets:
|
||||||
|
console.print("[yellow]ℹ No datasets to delete[/yellow]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
show_datasets_table(console)
|
||||||
|
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
|
||||||
|
|
||||||
|
if not any(v == version for v, _ in datasets):
|
||||||
|
console.print("[red]✗ Dataset not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
|
||||||
|
if delete_dataset(version):
|
||||||
|
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[red]✗ Deletion failed[/red]")
|
||||||
|
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def training_menu():
|
||||||
|
"""Training submenu."""
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold]Training[/bold]")
|
||||||
|
console.print("[cyan]1[/cyan] - Standard Training")
|
||||||
|
console.print("[cyan]2[/cyan] - Burst Training")
|
||||||
|
console.print("[cyan]3[/cyan] - View Model Checkpoints")
|
||||||
|
console.print("[cyan]4[/cyan] - Back")
|
||||||
|
|
||||||
|
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
|
||||||
|
|
||||||
|
if choice == "1":
|
||||||
|
train_interactive()
|
||||||
|
elif choice == "2":
|
||||||
|
burst_train_interactive()
|
||||||
|
elif choice == "3":
|
||||||
|
show_header()
|
||||||
|
show_checkpoints_table()
|
||||||
|
Prompt.ask("\nPress Enter to continue")
|
||||||
|
elif choice == "4":
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def train_interactive():
|
||||||
|
"""Interactive training menu."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
|
||||||
|
|
||||||
|
# Dataset selection
|
||||||
|
datasets = list_datasets()
|
||||||
|
if not datasets:
|
||||||
|
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("\n[bold]Available Datasets:[/bold]")
|
||||||
|
show_datasets_table(console)
|
||||||
|
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||||
|
|
||||||
|
if not any(v == dataset_version for v, _ in datasets):
|
||||||
|
console.print("[red]✗ Dataset not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||||
|
if not labeled_file:
|
||||||
|
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Checkpoint selection
|
||||||
|
available = list_checkpoints()
|
||||||
|
use_checkpoint = False
|
||||||
|
checkpoint_version = None
|
||||||
|
|
||||||
|
if available:
|
||||||
|
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||||
|
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
|
||||||
|
if use_checkpoint:
|
||||||
|
checkpoint_version = Prompt.ask(
|
||||||
|
"Enter checkpoint version",
|
||||||
|
default=str(max(available))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
epochs = int(Prompt.ask("Number of epochs", default="100"))
|
||||||
|
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||||
|
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||||
|
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||||
|
hidden_layers_str = Prompt.ask(
|
||||||
|
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||||
|
default=default_layers
|
||||||
|
)
|
||||||
|
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||||
|
early_stopping = None
|
||||||
|
if Confirm.ask("Enable early stopping?", default=False):
|
||||||
|
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
|
||||||
|
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||||
|
|
||||||
|
# Confirm and start
|
||||||
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
|
console.print(f" Dataset: ds_v{dataset_version}")
|
||||||
|
console.print(f" Architecture: {arch_str}")
|
||||||
|
console.print(f" Epochs: {epochs}")
|
||||||
|
console.print(f" Batch size: {batch_size}")
|
||||||
|
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||||
|
if early_stopping:
|
||||||
|
console.print(f" Early stopping: Yes (patience: {early_stopping})")
|
||||||
|
else:
|
||||||
|
console.print(f" Early stopping: No")
|
||||||
|
if use_checkpoint:
|
||||||
|
console.print(f" Checkpoint: v{checkpoint_version}")
|
||||||
|
else:
|
||||||
|
console.print(f" Checkpoint: None (training from scratch)")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nStart training?", default=True):
|
||||||
|
console.print("[yellow]Training cancelled[/yellow]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute training
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
console.print("\n[bold cyan]Training Model[/bold cyan]")
|
||||||
|
checkpoint = None
|
||||||
|
if use_checkpoint:
|
||||||
|
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
|
||||||
|
|
||||||
|
train_nnue(
|
||||||
|
data_file=str(labeled_file),
|
||||||
|
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||||
|
epochs=epochs,
|
||||||
|
batch_size=batch_size,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
use_versioning=True,
|
||||||
|
early_stopping_patience=early_stopping,
|
||||||
|
subsample_ratio=subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
|
)
|
||||||
|
console.print("[green]✓ Training complete[/green]")
|
||||||
|
|
||||||
|
# Show result
|
||||||
|
available = list_checkpoints()
|
||||||
|
new_version = max(available) if available else 1
|
||||||
|
console.print(f"\n[bold green]✓ Training successful![/bold green]")
|
||||||
|
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def burst_train_interactive():
|
||||||
|
"""Interactive burst training menu."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
|
||||||
|
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
|
||||||
|
|
||||||
|
# Dataset selection
|
||||||
|
datasets = list_datasets()
|
||||||
|
if not datasets:
|
||||||
|
console.print("[red]✗ No datasets available. Create one first.[/red]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("[bold]Available Datasets:[/bold]")
|
||||||
|
show_datasets_table(console)
|
||||||
|
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
|
||||||
|
|
||||||
|
if not any(v == dataset_version for v, _ in datasets):
|
||||||
|
console.print("[red]✗ Dataset not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
labeled_file = get_dataset_labeled_path(dataset_version)
|
||||||
|
if not labeled_file:
|
||||||
|
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
|
||||||
|
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
|
||||||
|
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
|
||||||
|
|
||||||
|
# Optional initial checkpoint
|
||||||
|
available = list_checkpoints()
|
||||||
|
checkpoint = None
|
||||||
|
if available:
|
||||||
|
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||||
|
if Confirm.ask("Start from an existing checkpoint?", default=False):
|
||||||
|
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
|
||||||
|
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
|
||||||
|
|
||||||
|
# Training hyperparameters
|
||||||
|
batch_size = int(Prompt.ask("Batch size", default="16384"))
|
||||||
|
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
|
||||||
|
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
|
||||||
|
hidden_layers_str = Prompt.ask(
|
||||||
|
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
|
||||||
|
default=default_layers
|
||||||
|
)
|
||||||
|
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + hidden_sizes + [1])
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
console.print("\n[bold]Configuration Summary:[/bold]")
|
||||||
|
console.print(f" Dataset: ds_v{dataset_version}")
|
||||||
|
console.print(f" Architecture: {arch_str}")
|
||||||
|
console.print(f" Duration: {duration_minutes:.0f} minutes")
|
||||||
|
console.print(f" Epochs per season: {epochs_per_season}")
|
||||||
|
console.print(f" Patience: {early_stopping_patience}")
|
||||||
|
console.print(f" Batch size: {batch_size}")
|
||||||
|
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
|
||||||
|
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nStart burst training?", default=True):
|
||||||
|
console.print("[yellow]Burst training cancelled[/yellow]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
console.print("\n[bold cyan]Burst Training[/bold cyan]")
|
||||||
|
burst_train(
|
||||||
|
data_file=str(labeled_file),
|
||||||
|
output_file=str(weights_dir / "nnue_weights.pt"),
|
||||||
|
duration_minutes=duration_minutes,
|
||||||
|
epochs_per_season=epochs_per_season,
|
||||||
|
early_stopping_patience=early_stopping_patience,
|
||||||
|
batch_size=batch_size,
|
||||||
|
initial_checkpoint=checkpoint,
|
||||||
|
use_versioning=True,
|
||||||
|
subsample_ratio=subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
|
)
|
||||||
|
console.print("[green]✓ Burst training complete[/green]")
|
||||||
|
|
||||||
|
available = list_checkpoints()
|
||||||
|
if available:
|
||||||
|
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def export_interactive():
|
||||||
|
"""Interactive export menu."""
|
||||||
|
console = Console()
|
||||||
|
show_header()
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
|
||||||
|
|
||||||
|
# Select weights version
|
||||||
|
available = list_checkpoints()
|
||||||
|
if not available:
|
||||||
|
console.print("[red]✗ No checkpoints available to export[/red]")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
|
||||||
|
version = Prompt.ask("Enter version to export (e.g., 2)")
|
||||||
|
|
||||||
|
weights_file = f"nnue_weights_v{version}.pt"
|
||||||
|
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
|
||||||
|
|
||||||
|
console.print(f"\n[bold]Export Configuration:[/bold]")
|
||||||
|
console.print(f" Source: {weights_file}")
|
||||||
|
console.print(f" Destination: {output_file}")
|
||||||
|
|
||||||
|
if not Confirm.ask("\nExport weights?", default=True):
|
||||||
|
console.print("[yellow]Export cancelled[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
weights_dir = get_weights_dir()
|
||||||
|
weights_path = weights_dir / weights_file
|
||||||
|
|
||||||
|
if not weights_path.exists():
|
||||||
|
console.print(f"[red]✗ {weights_file} not found[/red]")
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
|
||||||
|
export_to_nbai(str(weights_path), output_file)
|
||||||
|
console.print(f"\n[green]✓ Export complete![/green]")
|
||||||
|
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error: {e}[/red]")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
Prompt.ask("Press Enter to continue")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
show_main_menu()
|
||||||
|
return 0
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
console = Console()
|
||||||
|
console.print("\n[yellow]Interrupted by user[/yellow]")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
console = Console()
|
||||||
|
console.print(f"[red]Error:[/red] {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
chess==1.11.2
|
||||||
|
torch==2.11.0
|
||||||
|
tqdm==4.67.3
|
||||||
|
numpy==2.4.4
|
||||||
|
rich==13.7.0
|
||||||
|
zstandard==0.23.0
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
@echo off
|
||||||
|
REM NNUE Training Pipeline for Windows
|
||||||
|
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo === NNUE Training Pipeline ===
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM Get the directory where this script is located
|
||||||
|
set SCRIPT_DIR=%~dp0
|
||||||
|
|
||||||
|
cd /d "%SCRIPT_DIR%"
|
||||||
|
|
||||||
|
REM Step 1: Generate positions
|
||||||
|
echo Step 1: Generating 500,000 random positions...
|
||||||
|
python generate_positions.py positions.txt
|
||||||
|
if not exist positions.txt (
|
||||||
|
echo ERROR: positions.txt not created
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
echo [OK] Positions generated
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM Step 2: Label positions with Stockfish
|
||||||
|
echo Step 2: Labeling positions with Stockfish (depth 12^)...
|
||||||
|
if "%STOCKFISH_PATH%"=="" (
|
||||||
|
set STOCKFISH_PATH=stockfish
|
||||||
|
)
|
||||||
|
python label_positions.py positions.txt training_data.jsonl "%STOCKFISH_PATH%"
|
||||||
|
if not exist training_data.jsonl (
|
||||||
|
echo ERROR: training_data.jsonl not created
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
echo [OK] Positions labeled
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM Step 3: Train NNUE model
|
||||||
|
echo Step 3: Training NNUE model (20 epochs^)...
|
||||||
|
python train_nnue.py training_data.jsonl nnue_weights.pt
|
||||||
|
if not exist nnue_weights.pt (
|
||||||
|
echo ERROR: nnue_weights.pt not created
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
echo [OK] Model trained
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM Step 4: Export weights to Scala
|
||||||
|
echo Step 4: Exporting weights to Scala...
|
||||||
|
python export_weights.py nnue_weights.pt ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala
|
||||||
|
if not exist ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala (
|
||||||
|
echo ERROR: NNUEWeights.scala not created
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
echo [OK] Weights exported
|
||||||
|
echo.
|
||||||
|
|
||||||
|
echo === Pipeline Complete ===
|
||||||
|
echo.
|
||||||
|
echo Next steps:
|
||||||
|
echo 1. Navigate to project root: cd ..\..
|
||||||
|
echo 2. Compile: .\compile.bat
|
||||||
|
echo 3. Test: .\test.bat
|
||||||
|
echo.
|
||||||
|
|
||||||
|
endlocal
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# NNUE Training Pipeline (bash version)
|
||||||
|
# Uses the central CLI (nnue.py) for all operations
|
||||||
|
# Works on Linux, macOS, and Windows (with Git Bash or WSL)
|
||||||
|
|
||||||
|
set -e # Exit on error
|
||||||
|
|
||||||
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
# Use python or python3 (check which is available)
|
||||||
|
PYTHON_CMD="python3"
|
||||||
|
if ! command -v python3 &> /dev/null; then
|
||||||
|
PYTHON_CMD="python"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "=== NNUE Training Pipeline ==="
|
||||||
|
echo ""
|
||||||
|
echo "Python command: $PYTHON_CMD"
|
||||||
|
echo "Working directory: $SCRIPT_DIR"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Run the unified training pipeline
|
||||||
|
$PYTHON_CMD nnue.py train
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo ""
|
||||||
|
echo "ERROR: Training pipeline failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Pipeline Complete ==="
|
||||||
|
echo ""
|
||||||
|
echo "Next steps:"
|
||||||
|
echo "1. Navigate to project root: cd ../.."
|
||||||
|
echo "2. Compile: ./compile"
|
||||||
|
echo "3. Test: ./test"
|
||||||
@@ -0,0 +1,287 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Dataset versioning and management for NNUE training data."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, List, Tuple
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
|
||||||
|
def get_datasets_dir() -> Path:
|
||||||
|
"""Get/create datasets directory."""
|
||||||
|
datasets_dir = Path(__file__).parent.parent / "datasets"
|
||||||
|
datasets_dir.mkdir(exist_ok=True)
|
||||||
|
return datasets_dir
|
||||||
|
|
||||||
|
|
||||||
|
def next_dataset_version() -> int:
|
||||||
|
"""Find the next available dataset version number."""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
versions = []
|
||||||
|
|
||||||
|
for d in datasets_dir.iterdir():
|
||||||
|
if d.is_dir() and d.name.startswith("ds_v"):
|
||||||
|
try:
|
||||||
|
v = int(d.name.split("_v")[1])
|
||||||
|
versions.append(v)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return max(versions) + 1 if versions else 1
|
||||||
|
|
||||||
|
|
||||||
|
def list_datasets() -> List[Tuple[int, Dict]]:
|
||||||
|
"""List all datasets with their metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (version, metadata_dict) tuples, sorted by version.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
datasets = []
|
||||||
|
|
||||||
|
for d in datasets_dir.iterdir():
|
||||||
|
if d.is_dir() and d.name.startswith("ds_v"):
|
||||||
|
try:
|
||||||
|
v = int(d.name.split("_v")[1])
|
||||||
|
metadata_file = d / "metadata.json"
|
||||||
|
if metadata_file.exists():
|
||||||
|
with open(metadata_file, 'r') as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
datasets.append((v, metadata))
|
||||||
|
except (ValueError, IndexError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return sorted(datasets, key=lambda x: x[0])
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_metadata(version: int) -> Optional[Dict]:
|
||||||
|
"""Load metadata for a specific dataset version.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Metadata dict or None if not found.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
|
||||||
|
|
||||||
|
if not metadata_file.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(metadata_file, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def save_dataset_metadata(version: int, metadata: Dict) -> None:
|
||||||
|
"""Save metadata for a dataset version."""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||||
|
dataset_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
metadata_file = dataset_dir / "metadata.json"
|
||||||
|
with open(metadata_file, 'w') as f:
|
||||||
|
json.dump(metadata, f, indent=2, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(
|
||||||
|
version: int,
|
||||||
|
labeled_jsonl_path: str,
|
||||||
|
sources: List[Dict],
|
||||||
|
stockfish_depth: int = 12
|
||||||
|
) -> Path:
|
||||||
|
"""Create a new versioned dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: Dataset version number
|
||||||
|
labeled_jsonl_path: Path to labeled.jsonl to copy
|
||||||
|
sources: List of source dicts (see plan for schema)
|
||||||
|
stockfish_depth: Depth used for labeling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the created dataset directory.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||||
|
dataset_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Copy labeled data with deduplication (in case source has duplicates)
|
||||||
|
source_path = Path(labeled_jsonl_path)
|
||||||
|
if source_path.exists():
|
||||||
|
dest_path = dataset_dir / "labeled.jsonl"
|
||||||
|
seen_fens = set()
|
||||||
|
unique_count = 0
|
||||||
|
|
||||||
|
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
|
||||||
|
for line in src:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
fen = data.get('fen')
|
||||||
|
if fen and fen not in seen_fens:
|
||||||
|
dst.write(line)
|
||||||
|
seen_fens.add(fen)
|
||||||
|
unique_count += 1
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Skip malformed lines
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Count positions
|
||||||
|
total_positions = 0
|
||||||
|
if (dataset_dir / "labeled.jsonl").exists():
|
||||||
|
with open(dataset_dir / "labeled.jsonl", 'r') as f:
|
||||||
|
total_positions = sum(1 for _ in f)
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"version": version,
|
||||||
|
"created": datetime.now().isoformat(),
|
||||||
|
"total_positions": total_positions,
|
||||||
|
"stockfish_depth": stockfish_depth,
|
||||||
|
"sources": sources
|
||||||
|
}
|
||||||
|
|
||||||
|
save_dataset_metadata(version, metadata)
|
||||||
|
return dataset_dir
|
||||||
|
|
||||||
|
|
||||||
|
def extend_dataset(
|
||||||
|
version: int,
|
||||||
|
new_labeled_path: str,
|
||||||
|
new_source_entry: Dict
|
||||||
|
) -> bool:
|
||||||
|
"""Extend an existing dataset with new labeled positions (with deduplication).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: Dataset version to extend
|
||||||
|
new_labeled_path: Path to new labeled.jsonl to merge
|
||||||
|
new_source_entry: Source entry to add to metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||||
|
|
||||||
|
if not dataset_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
labeled_file = dataset_dir / "labeled.jsonl"
|
||||||
|
new_labeled_file = Path(new_labeled_path)
|
||||||
|
|
||||||
|
if not new_labeled_file.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
|
||||||
|
existing_fens = set()
|
||||||
|
if labeled_file.exists():
|
||||||
|
with open(labeled_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
fen = data.get('fen')
|
||||||
|
if fen:
|
||||||
|
existing_fens.add(fen)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Merge new positions, skipping duplicates
|
||||||
|
new_count = 0
|
||||||
|
new_lines = []
|
||||||
|
with open(new_labeled_file, 'r') as f_new:
|
||||||
|
for line in f_new:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
fen = data.get('fen')
|
||||||
|
if fen and fen not in existing_fens:
|
||||||
|
new_lines.append(line)
|
||||||
|
existing_fens.add(fen)
|
||||||
|
new_count += 1
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Append only the new, unique positions
|
||||||
|
if new_lines:
|
||||||
|
with open(labeled_file, 'a') as f_append:
|
||||||
|
for line in new_lines:
|
||||||
|
f_append.write(line)
|
||||||
|
|
||||||
|
# Update metadata
|
||||||
|
metadata = load_dataset_metadata(version)
|
||||||
|
if metadata:
|
||||||
|
# Count total positions
|
||||||
|
total_positions = 0
|
||||||
|
with open(labeled_file, 'r') as f:
|
||||||
|
total_positions = sum(1 for _ in f)
|
||||||
|
|
||||||
|
metadata['total_positions'] = total_positions
|
||||||
|
# Update the source entry with actual count of new positions added
|
||||||
|
new_source_entry['actual_count'] = new_count
|
||||||
|
metadata['sources'].append(new_source_entry)
|
||||||
|
save_dataset_metadata(version, metadata)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_labeled_path(version: int) -> Optional[Path]:
|
||||||
|
"""Get the path to a dataset's labeled.jsonl file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to labeled.jsonl or None if dataset doesn't exist.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
|
||||||
|
|
||||||
|
if labeled_file.exists():
|
||||||
|
return labeled_file
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_dataset(version: int) -> bool:
|
||||||
|
"""Delete a dataset (recursively removes directory).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: Dataset version to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful.
|
||||||
|
"""
|
||||||
|
datasets_dir = get_datasets_dir()
|
||||||
|
dataset_dir = datasets_dir / f"ds_v{version}"
|
||||||
|
|
||||||
|
if not dataset_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(dataset_dir)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def show_datasets_table(console: Console = None) -> None:
|
||||||
|
"""Display all datasets in a Rich table."""
|
||||||
|
if console is None:
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
datasets = list_datasets()
|
||||||
|
|
||||||
|
if not datasets:
|
||||||
|
console.print("[yellow]ℹ No datasets found yet[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
|
||||||
|
table.add_column("Version", style="dim")
|
||||||
|
table.add_column("Positions", justify="right")
|
||||||
|
table.add_column("Sources", justify="left")
|
||||||
|
table.add_column("Depth", justify="center")
|
||||||
|
table.add_column("Created", justify="left")
|
||||||
|
|
||||||
|
for v, metadata in datasets:
|
||||||
|
positions = metadata.get('total_positions', 0)
|
||||||
|
sources = metadata.get('sources', [])
|
||||||
|
source_str = ", ".join([s.get('type', '?') for s in sources])
|
||||||
|
depth = metadata.get('stockfish_depth', '?')
|
||||||
|
created = metadata.get('created', '?')
|
||||||
|
if created != '?':
|
||||||
|
created = created.split('T')[0] # Just the date
|
||||||
|
|
||||||
|
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Export NNUE weights to .nbai format for runtime loading."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _read_sidecar(weights_file: str) -> dict:
|
||||||
|
sidecar = weights_file.replace(".pt", "_metadata.json")
|
||||||
|
if Path(sidecar).exists():
|
||||||
|
with open(sidecar) as f:
|
||||||
|
return json.load(f)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_layers(state_dict: dict) -> list[dict]:
|
||||||
|
"""Derive layer descriptors from state_dict weight shapes.
|
||||||
|
|
||||||
|
Assumes layers named l1, l2, ..., lN.
|
||||||
|
All hidden layers get activation 'relu'; the last gets 'linear'.
|
||||||
|
"""
|
||||||
|
names = sorted(
|
||||||
|
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||||
|
key=lambda n: int(n[1:]),
|
||||||
|
)
|
||||||
|
layers = []
|
||||||
|
for i, name in enumerate(names):
|
||||||
|
out_size, in_size = state_dict[f"{name}.weight"].shape
|
||||||
|
activation = "linear" if i == len(names) - 1 else "relu"
|
||||||
|
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
|
||||||
|
return layers
|
||||||
|
|
||||||
|
|
||||||
|
def _write_floats(f, tensor):
|
||||||
|
data = tensor.float().flatten().cpu().numpy()
|
||||||
|
f.write(struct.pack("<I", len(data)))
|
||||||
|
f.write(struct.pack(f"<{len(data)}f", *data))
|
||||||
|
|
||||||
|
|
||||||
|
def export_to_nbai(
|
||||||
|
weights_file: str,
|
||||||
|
output_file: str,
|
||||||
|
trained_by: str = "unknown",
|
||||||
|
train_loss: float = 0.0,
|
||||||
|
):
|
||||||
|
if not Path(weights_file).exists():
|
||||||
|
print(f"Error: weights file not found at {weights_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
loaded = torch.load(weights_file, map_location="cpu")
|
||||||
|
state_dict = (
|
||||||
|
loaded["model_state_dict"]
|
||||||
|
if isinstance(loaded, dict) and "model_state_dict" in loaded
|
||||||
|
else loaded
|
||||||
|
)
|
||||||
|
|
||||||
|
sidecar = _read_sidecar(weights_file)
|
||||||
|
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
|
||||||
|
trained_at = sidecar.get("date", datetime.now().isoformat())
|
||||||
|
training_data_count = int(sidecar.get("num_positions", 0))
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"trainedBy": trained_by,
|
||||||
|
"trainedAt": trained_at,
|
||||||
|
"trainingDataCount": training_data_count,
|
||||||
|
"valLoss": val_loss,
|
||||||
|
"trainLoss": train_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
layers = _infer_layers(state_dict)
|
||||||
|
layer_names = sorted(
|
||||||
|
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
|
||||||
|
key=lambda n: int(n[1:]),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Architecture ({len(layers)} layers):")
|
||||||
|
for i, l in enumerate(layers):
|
||||||
|
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
|
||||||
|
|
||||||
|
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(output_file, "wb") as f:
|
||||||
|
# Header
|
||||||
|
f.write(struct.pack("<I", MAGIC))
|
||||||
|
f.write(struct.pack("<H", VERSION))
|
||||||
|
|
||||||
|
# Metadata (length-prefixed UTF-8 JSON)
|
||||||
|
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
|
||||||
|
f.write(struct.pack("<I", len(meta_bytes)))
|
||||||
|
f.write(meta_bytes)
|
||||||
|
|
||||||
|
# Layer descriptors
|
||||||
|
f.write(struct.pack("<H", len(layers)))
|
||||||
|
for layer in layers:
|
||||||
|
name_bytes = layer["activation"].encode("ascii")
|
||||||
|
f.write(struct.pack("<B", len(name_bytes)))
|
||||||
|
f.write(name_bytes)
|
||||||
|
f.write(struct.pack("<I", layer["inputSize"]))
|
||||||
|
f.write(struct.pack("<I", layer["outputSize"]))
|
||||||
|
|
||||||
|
# Weights: weight tensor then bias tensor per layer
|
||||||
|
for name in layer_names:
|
||||||
|
w = state_dict[f"{name}.weight"]
|
||||||
|
b = state_dict[f"{name}.bias"]
|
||||||
|
_write_floats(f, w)
|
||||||
|
_write_floats(f, b)
|
||||||
|
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
|
||||||
|
|
||||||
|
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
|
||||||
|
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
|
||||||
|
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
weights_file = "nnue_weights.pt"
|
||||||
|
output_file = "../src/main/resources/nnue_weights.nbai"
|
||||||
|
trained_by = "unknown"
|
||||||
|
train_loss = 0.0
|
||||||
|
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
weights_file = sys.argv[1]
|
||||||
|
if len(sys.argv) > 2:
|
||||||
|
output_file = sys.argv[2]
|
||||||
|
if len(sys.argv) > 3:
|
||||||
|
trained_by = sys.argv[3]
|
||||||
|
if len(sys.argv) > 4:
|
||||||
|
train_loss = float(sys.argv[4])
|
||||||
|
|
||||||
|
export_to_nbai(weights_file, output_file, trained_by, train_loss)
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Generate random chess positions for NNUE training with multiprocessing."""
|
||||||
|
|
||||||
|
import chess
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Pool, Queue
|
||||||
|
from datetime import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
|
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
|
||||||
|
"""Generate games for one worker.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of FENs generated by this worker
|
||||||
|
"""
|
||||||
|
positions = []
|
||||||
|
|
||||||
|
for game_num in range(games_per_worker):
|
||||||
|
board = chess.Board()
|
||||||
|
move_history = []
|
||||||
|
|
||||||
|
# Play a complete random game
|
||||||
|
while not board.is_game_over() and len(move_history) < 200:
|
||||||
|
legal_moves = list(board.legal_moves)
|
||||||
|
if not legal_moves:
|
||||||
|
break
|
||||||
|
move = random.choice(legal_moves)
|
||||||
|
board.push(move)
|
||||||
|
move_history.append(board.copy())
|
||||||
|
|
||||||
|
# Determine the range of moves to sample from
|
||||||
|
game_length = len(move_history)
|
||||||
|
valid_start = max(min_move, 0)
|
||||||
|
valid_end = min(max_move, game_length)
|
||||||
|
|
||||||
|
if valid_start >= valid_end:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Randomly sample positions from this game
|
||||||
|
sample_count = min(samples_per_game, valid_end - valid_start)
|
||||||
|
if sample_count > 0:
|
||||||
|
sample_indices = random.sample(
|
||||||
|
range(valid_start, valid_end),
|
||||||
|
k=sample_count
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx in sample_indices:
|
||||||
|
sampled_board = move_history[idx]
|
||||||
|
|
||||||
|
# Only filter truly invalid or terminal positions
|
||||||
|
if not sampled_board.is_valid() or sampled_board.is_game_over():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save position (include check, captures, all positions)
|
||||||
|
fen = sampled_board.fen()
|
||||||
|
positions.append(fen)
|
||||||
|
|
||||||
|
return positions
|
||||||
|
|
||||||
|
|
||||||
|
def play_random_game_and_collect_positions(
|
||||||
|
output_file,
|
||||||
|
total_positions=3000000,
|
||||||
|
samples_per_game=1,
|
||||||
|
min_move=1,
|
||||||
|
max_move=50,
|
||||||
|
num_workers=8
|
||||||
|
):
|
||||||
|
"""Generate positions using multiprocessing with multiple workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Output file for positions
|
||||||
|
total_positions: Target number of positions to generate
|
||||||
|
samples_per_game: Number of positions to sample per game (1-N)
|
||||||
|
min_move: Minimum move number to start sampling from
|
||||||
|
max_move: Maximum move number for sampling
|
||||||
|
num_workers: Number of parallel worker processes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of valid positions saved
|
||||||
|
"""
|
||||||
|
# Estimate games needed (roughly 1 position per game on average)
|
||||||
|
total_games = max(total_positions // samples_per_game, num_workers)
|
||||||
|
games_per_worker = total_games // num_workers
|
||||||
|
|
||||||
|
print(f"Generating {total_positions:,} positions using {num_workers} workers")
|
||||||
|
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
# Generate positions in parallel
|
||||||
|
worker_tasks = [
|
||||||
|
(i, games_per_worker, samples_per_game, min_move, max_move)
|
||||||
|
for i in range(num_workers)
|
||||||
|
]
|
||||||
|
|
||||||
|
positions_count = 0
|
||||||
|
all_positions = []
|
||||||
|
|
||||||
|
with Pool(num_workers) as pool:
|
||||||
|
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
|
||||||
|
for positions in pool.starmap(_worker_generate_games, worker_tasks):
|
||||||
|
all_positions.extend(positions)
|
||||||
|
positions_count += len(positions)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Write all positions to file
|
||||||
|
print(f"Writing {positions_count:,} positions to {output_file}...")
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
for fen in all_positions:
|
||||||
|
f.write(fen + '\n')
|
||||||
|
|
||||||
|
elapsed_time = datetime.now() - start_time
|
||||||
|
elapsed_seconds = elapsed_time.total_seconds()
|
||||||
|
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("POSITION GENERATION SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Target positions: {total_positions:,}")
|
||||||
|
print(f"Actual positions saved: {positions_count:,}")
|
||||||
|
print(f"Workers: {num_workers}")
|
||||||
|
print(f"Games per worker: {games_per_worker:,}")
|
||||||
|
print(f"Samples per game: {samples_per_game}")
|
||||||
|
print(f"Move range: {min_move}-{max_move}")
|
||||||
|
print(f"Elapsed time: {elapsed_time}")
|
||||||
|
print(f"Throughput: {positions_per_second:.0f} positions/second")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if positions_count == 0:
|
||||||
|
print("WARNING: No valid positions were generated!")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return positions_count
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
|
||||||
|
parser.add_argument("output_file", nargs="?", default="positions.txt",
|
||||||
|
help="Output file for positions (default: positions.txt)")
|
||||||
|
parser.add_argument("--positions", type=int, default=3000000,
|
||||||
|
help="Target number of positions to generate (default: 3000000)")
|
||||||
|
parser.add_argument("--samples-per-game", type=int, default=1,
|
||||||
|
help="Number of positions to sample per game (default: 1)")
|
||||||
|
parser.add_argument("--min-move", type=int, default=1,
|
||||||
|
help="Minimum move number to sample from (default: 1)")
|
||||||
|
parser.add_argument("--max-move", type=int, default=50,
|
||||||
|
help="Maximum move number to sample from (default: 50)")
|
||||||
|
parser.add_argument("--workers", type=int, default=8,
|
||||||
|
help="Number of parallel worker processes (default: 8)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
count = play_random_game_and_collect_positions(
|
||||||
|
output_file=args.output_file,
|
||||||
|
total_positions=args.positions,
|
||||||
|
samples_per_game=args.samples_per_game,
|
||||||
|
min_move=args.min_move,
|
||||||
|
max_move=args.max_move,
|
||||||
|
num_workers=args.workers
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.exit(0 if count > 0 else 1)
|
||||||
@@ -0,0 +1,326 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Label positions with Stockfish evaluations and analyze distribution."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import chess.engine
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
|
||||||
|
"""Normalize centipawn evaluation to a bounded range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cp_value: Centipawn evaluation from Stockfish
|
||||||
|
method: 'tanh' (default) or 'sigmoid'
|
||||||
|
scale: Scale factor (tanh: 300 is typical)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
|
||||||
|
"""
|
||||||
|
if method == 'tanh':
|
||||||
|
return np.tanh(cp_value / scale)
|
||||||
|
elif method == 'sigmoid':
|
||||||
|
return 1.0 / (1.0 + np.exp(-cp_value / scale))
|
||||||
|
else:
|
||||||
|
return cp_value / 100.0
|
||||||
|
|
||||||
|
def _evaluate_fen_batch(args):
|
||||||
|
"""Worker function to evaluate a batch of FENs with Stockfish threading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: tuple of (fens, stockfish_path, depth, normalize)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (fen, eval_normalized, eval_raw) tuples
|
||||||
|
"""
|
||||||
|
fens, stockfish_path, depth, normalize = args
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
for fen in fens:
|
||||||
|
try:
|
||||||
|
board = chess.Board(fen)
|
||||||
|
if not board.is_valid():
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = engine.analyse(board, chess.engine.Limit(depth=depth))
|
||||||
|
|
||||||
|
if info.get('score') is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
score = info['score'].white()
|
||||||
|
|
||||||
|
if score.is_mate():
|
||||||
|
eval_cp = 2000 if score.mate() > 0 else -2000
|
||||||
|
else:
|
||||||
|
eval_cp = score.cp
|
||||||
|
|
||||||
|
eval_cp = max(-2000, min(2000, eval_cp))
|
||||||
|
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
|
||||||
|
|
||||||
|
results.append((fen, eval_normalized, eval_cp))
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
finally:
|
||||||
|
engine.quit()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
|
||||||
|
"""Read positions and label them with Stockfish evaluations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positions_file: Path to positions.txt
|
||||||
|
output_file: Path to training_data.jsonl
|
||||||
|
stockfish_path: Path to stockfish binary
|
||||||
|
batch_size: Batch size for processing (positions per worker task, default: 1000)
|
||||||
|
depth: Stockfish depth
|
||||||
|
verbose: Print detailed error messages
|
||||||
|
normalize: If True, normalize evals using tanh
|
||||||
|
num_workers: Number of parallel Stockfish processes
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if stockfish exists
|
||||||
|
if not Path(stockfish_path).exists():
|
||||||
|
print(f"Error: Stockfish not found at {stockfish_path}")
|
||||||
|
print(f"Tried: {stockfish_path}")
|
||||||
|
print(f"Set STOCKFISH_PATH environment variable or pass as argument")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Using Stockfish: {stockfish_path}")
|
||||||
|
print(f"Number of workers: {num_workers}")
|
||||||
|
|
||||||
|
# Check if positions file exists
|
||||||
|
if not Path(positions_file).exists():
|
||||||
|
print(f"Error: Positions file not found at {positions_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Load existing evaluations if resuming
|
||||||
|
evaluated_fens = set()
|
||||||
|
position_count = 0
|
||||||
|
|
||||||
|
if Path(output_file).exists():
|
||||||
|
with open(output_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
evaluated_fens.add(data['fen'])
|
||||||
|
position_count += 1
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
print(f"Resuming from {position_count} already evaluated positions")
|
||||||
|
|
||||||
|
# Load all FENs that need evaluation
|
||||||
|
fens_to_evaluate = []
|
||||||
|
fens_seen_in_batch = set() # Track duplicates within current batch
|
||||||
|
skipped_invalid = 0
|
||||||
|
skipped_duplicate = 0
|
||||||
|
|
||||||
|
with open(positions_file, 'r') as f:
|
||||||
|
for fen in f:
|
||||||
|
fen = fen.strip()
|
||||||
|
|
||||||
|
if not fen:
|
||||||
|
skipped_invalid += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if fen in evaluated_fens:
|
||||||
|
skipped_duplicate += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if fen in fens_seen_in_batch:
|
||||||
|
skipped_duplicate += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
fens_to_evaluate.append(fen)
|
||||||
|
fens_seen_in_batch.add(fen)
|
||||||
|
|
||||||
|
total_to_evaluate = len(fens_to_evaluate)
|
||||||
|
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
|
||||||
|
|
||||||
|
if total_to_evaluate == 0:
|
||||||
|
if position_count == 0:
|
||||||
|
print(f"Error: No valid positions to evaluate in {positions_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print(f"All positions already evaluated. No new positions to process.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(f"Total positions to process: {total_lines}")
|
||||||
|
print(f"New positions to evaluate: {total_to_evaluate}")
|
||||||
|
print(f"Using depth: {depth}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Split FENs into batches for workers
|
||||||
|
batches = []
|
||||||
|
for i in range(0, total_to_evaluate, batch_size):
|
||||||
|
batch = fens_to_evaluate[i:i+batch_size]
|
||||||
|
batches.append((batch, stockfish_path, depth, normalize))
|
||||||
|
|
||||||
|
# Process batches in parallel
|
||||||
|
evaluated = 0
|
||||||
|
errors = 0
|
||||||
|
raw_evals = []
|
||||||
|
normalized_evals = []
|
||||||
|
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with Pool(num_workers) as pool:
|
||||||
|
with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
|
||||||
|
with open(output_file, 'a') as out:
|
||||||
|
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
|
||||||
|
for fen, eval_normalized, eval_cp in batch_results:
|
||||||
|
# Skip if already evaluated in output file during this run
|
||||||
|
if fen in evaluated_fens:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
|
||||||
|
out.write(json.dumps(data) + '\n')
|
||||||
|
evaluated_fens.add(fen) # Track as evaluated
|
||||||
|
evaluated += 1
|
||||||
|
raw_evals.append(eval_cp)
|
||||||
|
normalized_evals.append(eval_normalized)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Update progress for any failed evaluations in the batch
|
||||||
|
batch_size_actual = len(batches[0][0]) if batches else batch_size
|
||||||
|
failed = batch_size_actual - len(batch_results)
|
||||||
|
if failed > 0:
|
||||||
|
errors += failed
|
||||||
|
pbar.update(failed)
|
||||||
|
|
||||||
|
# Calculate and show throughput and ETA
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
throughput = evaluated / elapsed if elapsed > 0 else 0
|
||||||
|
remaining_positions = total_to_evaluate - evaluated
|
||||||
|
eta_seconds = remaining_positions / throughput if throughput > 0 else 0
|
||||||
|
eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
|
||||||
|
|
||||||
|
if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
|
||||||
|
pbar.set_postfix({
|
||||||
|
'rate': f'{throughput:.0f} pos/s',
|
||||||
|
'eta': eta_str
|
||||||
|
})
|
||||||
|
|
||||||
|
# Print summary and analysis
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("LABELING SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Successfully evaluated: {evaluated}")
|
||||||
|
print(f"Skipped (duplicates): {skipped_duplicate}")
|
||||||
|
print(f"Skipped (invalid): {skipped_invalid}")
|
||||||
|
print(f"Errors: {errors}")
|
||||||
|
print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if evaluated == 0:
|
||||||
|
print("WARNING: No positions were successfully evaluated!")
|
||||||
|
print("Check that:")
|
||||||
|
print(" 1. positions.txt is not empty")
|
||||||
|
print(" 2. positions.txt contains valid FENs")
|
||||||
|
print(" 3. Stockfish is installed and working")
|
||||||
|
print(" 4. Stockfish path is correct")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Print distribution analysis
|
||||||
|
if raw_evals:
|
||||||
|
raw_evals_arr = np.array(raw_evals)
|
||||||
|
norm_evals_arr = np.array(normalized_evals)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("EVALUATION DISTRIBUTION ANALYSIS")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
print("Raw Evaluations (centipawns):")
|
||||||
|
print(f" Min: {raw_evals_arr.min():.1f}")
|
||||||
|
print(f" Max: {raw_evals_arr.max():.1f}")
|
||||||
|
print(f" Mean: {raw_evals_arr.mean():.1f}")
|
||||||
|
print(f" Median: {np.median(raw_evals_arr):.1f}")
|
||||||
|
print(f" Std: {raw_evals_arr.std():.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Normalized Evaluations (tanh):")
|
||||||
|
print(f" Min: {norm_evals_arr.min():.4f}")
|
||||||
|
print(f" Max: {norm_evals_arr.max():.4f}")
|
||||||
|
print(f" Mean: {norm_evals_arr.mean():.4f}")
|
||||||
|
print(f" Median: {np.median(norm_evals_arr):.4f}")
|
||||||
|
print(f" Std: {norm_evals_arr.std():.4f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Distribution buckets
|
||||||
|
print("Raw Evaluation Buckets (counts):")
|
||||||
|
buckets = [
|
||||||
|
(-float('inf'), -500, "< -5.00"),
|
||||||
|
(-500, -300, "[-5.00, -3.00)"),
|
||||||
|
(-300, -100, "[-3.00, -1.00)"),
|
||||||
|
(-100, 0, "[-1.00, 0.00)"),
|
||||||
|
(0, 100, "[0.00, 1.00)"),
|
||||||
|
(100, 300, "[1.00, 3.00)"),
|
||||||
|
(300, 500, "[3.00, 5.00)"),
|
||||||
|
(500, float('inf'), "> 5.00"),
|
||||||
|
]
|
||||||
|
for low, high, label in buckets:
|
||||||
|
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
|
||||||
|
pct = 100.0 * count / len(raw_evals_arr)
|
||||||
|
print(f" {label}: {count:6d} ({pct:5.1f}%)")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
print(f"✓ Labeling complete. Output saved to {output_file}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
|
||||||
|
parser.add_argument("positions_file", nargs="?", default="positions.txt",
|
||||||
|
help="Input positions file (default: positions.txt)")
|
||||||
|
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||||
|
help="Output file (default: training_data.jsonl)")
|
||||||
|
parser.add_argument("stockfish_path", nargs="?", default=None,
|
||||||
|
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
|
||||||
|
parser.add_argument("--depth", type=int, default=12,
|
||||||
|
help="Stockfish depth (default: 12)")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=1000,
|
||||||
|
help="Batch size for processing (default: 1000)")
|
||||||
|
parser.add_argument("--no-normalize", action="store_true",
|
||||||
|
help="Disable evaluation normalization (keep raw centipawns)")
|
||||||
|
parser.add_argument("--verbose", action="store_true",
|
||||||
|
help="Print detailed error messages")
|
||||||
|
parser.add_argument("--workers", type=int, default=1,
|
||||||
|
help="Number of parallel Stockfish processes (default: 1)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Determine Stockfish path
|
||||||
|
stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")
|
||||||
|
|
||||||
|
success = label_positions_with_stockfish(
|
||||||
|
positions_file=args.positions_file,
|
||||||
|
output_file=args.output_file,
|
||||||
|
stockfish_path=stockfish_path,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
depth=args.depth,
|
||||||
|
normalize=not args.no_normalize,
|
||||||
|
verbose=args.verbose,
|
||||||
|
num_workers=args.workers
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
@@ -0,0 +1,208 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Import pre-labeled positions from the Lichess evaluation database.
|
||||||
|
|
||||||
|
Source: https://database.lichess.org/#evals
|
||||||
|
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
|
||||||
|
|
||||||
|
Each line:
|
||||||
|
{
|
||||||
|
"fen": "<pieces> <turn> <castling> <ep>",
|
||||||
|
"evals": [
|
||||||
|
{
|
||||||
|
"knodes": <int>,
|
||||||
|
"depth": <int>,
|
||||||
|
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
cp and mate are from White's perspective (positive = White winning), matching
|
||||||
|
the sign convention used by label.py (score.white()) and expected by train.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
MATE_CP = 20000
|
||||||
|
SCALE = 300.0
|
||||||
|
|
||||||
|
|
||||||
|
def _best_eval(evals: list) -> dict | None:
|
||||||
|
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
|
||||||
|
if not evals:
|
||||||
|
return None
|
||||||
|
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
|
||||||
|
|
||||||
|
|
||||||
|
def _cp_from_pv(pv: dict) -> int | None:
|
||||||
|
"""Extract centipawn value from a principal variation entry."""
|
||||||
|
if "cp" in pv:
|
||||||
|
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
|
||||||
|
if "mate" in pv:
|
||||||
|
return MATE_CP if pv["mate"] > 0 else -MATE_CP
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(cp: int) -> float:
|
||||||
|
return float(np.tanh(cp / SCALE))
|
||||||
|
|
||||||
|
|
||||||
|
def import_lichess_evals(
|
||||||
|
input_path: str,
|
||||||
|
output_file: str,
|
||||||
|
max_positions: int | None = None,
|
||||||
|
min_depth: int = 0,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Stream the Lichess eval database and write a labeled.jsonl file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
|
||||||
|
output_file: Destination labeled.jsonl (appended — supports resuming).
|
||||||
|
max_positions: Stop after this many new positions (None = no limit).
|
||||||
|
min_depth: Skip positions whose best eval has depth < min_depth.
|
||||||
|
verbose: Print warnings for skipped lines.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of new positions written.
|
||||||
|
"""
|
||||||
|
import zstandard as zstd
|
||||||
|
|
||||||
|
input_path = Path(input_path)
|
||||||
|
if not input_path.exists():
|
||||||
|
print(f"Error: {input_path} not found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Resume: collect already-written FENs so we skip duplicates.
|
||||||
|
seen_fens: set[str] = set()
|
||||||
|
if Path(output_file).exists():
|
||||||
|
with open(output_file, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
seen_fens.add(json.loads(line)["fen"])
|
||||||
|
except (json.JSONDecodeError, KeyError):
|
||||||
|
pass
|
||||||
|
if seen_fens:
|
||||||
|
print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
skipped_depth = 0
|
||||||
|
skipped_no_eval = 0
|
||||||
|
skipped_dup = 0
|
||||||
|
|
||||||
|
def iter_lines():
|
||||||
|
"""Yield decoded text lines from either a .zst or plain .jsonl file."""
|
||||||
|
import io
|
||||||
|
if input_path.suffix == ".zst":
|
||||||
|
dctx = zstd.ZstdDecompressor()
|
||||||
|
with open(input_path, "rb") as fh:
|
||||||
|
with dctx.stream_reader(fh) as reader:
|
||||||
|
text_stream = io.TextIOWrapper(reader, encoding="utf-8")
|
||||||
|
yield from text_stream
|
||||||
|
else:
|
||||||
|
with open(input_path, "r", encoding="utf-8") as fh:
|
||||||
|
yield from fh
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(output_file, "a") as out:
|
||||||
|
with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
|
||||||
|
for raw_line in iter_lines():
|
||||||
|
line = raw_line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
if verbose:
|
||||||
|
print("Warning: malformed JSON line skipped")
|
||||||
|
continue
|
||||||
|
|
||||||
|
fen = data.get("fen", "")
|
||||||
|
if not fen:
|
||||||
|
skipped_no_eval += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if fen in seen_fens:
|
||||||
|
skipped_dup += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
best = _best_eval(data.get("evals", []))
|
||||||
|
if best is None:
|
||||||
|
skipped_no_eval += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if best.get("depth", 0) < min_depth:
|
||||||
|
skipped_depth += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
pvs = best.get("pvs", [])
|
||||||
|
if not pvs:
|
||||||
|
skipped_no_eval += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
cp = _cp_from_pv(pvs[0])
|
||||||
|
if cp is None:
|
||||||
|
skipped_no_eval += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"fen": fen,
|
||||||
|
"eval": _normalize(cp),
|
||||||
|
"eval_raw": cp,
|
||||||
|
}
|
||||||
|
out.write(json.dumps(record) + "\n")
|
||||||
|
seen_fens.add(fen)
|
||||||
|
written += 1
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
if max_positions and written >= max_positions:
|
||||||
|
print(f"\nReached max_positions limit ({max_positions:,})")
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("LICHESS IMPORT SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Positions written: {written:,}")
|
||||||
|
print(f"Skipped (dup): {skipped_dup:,}")
|
||||||
|
print(f"Skipped (no eval): {skipped_no_eval:,}")
|
||||||
|
print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"\n✓ Output: {output_file}")
|
||||||
|
|
||||||
|
return written
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Import Lichess pre-labeled positions into labeled.jsonl"
|
||||||
|
)
|
||||||
|
parser.add_argument("input_path",
|
||||||
|
help="Path to lichess_db_eval.jsonl.zst")
|
||||||
|
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
|
||||||
|
help="Output labeled.jsonl (default: training_data.jsonl)")
|
||||||
|
parser.add_argument("--max-positions", type=int, default=None,
|
||||||
|
help="Stop after N positions (default: no limit)")
|
||||||
|
parser.add_argument("--min-depth", type=int, default=0,
|
||||||
|
help="Minimum eval depth to accept (default: 0)")
|
||||||
|
parser.add_argument("--verbose", action="store_true",
|
||||||
|
help="Print warnings for skipped lines")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
count = import_lichess_evals(
|
||||||
|
input_path=args.input_path,
|
||||||
|
output_file=args.output_file,
|
||||||
|
max_positions=args.max_positions,
|
||||||
|
min_depth=args.min_depth,
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
sys.exit(0 if count > 0 else 1)
|
||||||
@@ -0,0 +1,249 @@
|
|||||||
|
import chess
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Set, Tuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
import zstandard as zstd
|
||||||
|
except ImportError:
|
||||||
|
print("zstandard library not found. Install with: pip install zstandard")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from generate import play_random_game_and_collect_positions
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_extract_puzzle_db(
|
||||||
|
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||||
|
output_dir: str = 'tactical_data'
|
||||||
|
):
|
||||||
|
"""Download and extract the Lichess puzzle database."""
|
||||||
|
output_path = Path(output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
csv_file = output_path / 'lichess_db_puzzle.csv'
|
||||||
|
zst_file = output_path / 'lichess_db_puzzle.csv.zst'
|
||||||
|
|
||||||
|
# Download if not already present
|
||||||
|
if not zst_file.exists():
|
||||||
|
print(f"Downloading puzzle database from {url}...")
|
||||||
|
try:
|
||||||
|
urllib.request.urlretrieve(url, zst_file)
|
||||||
|
print(f"Downloaded to {zst_file}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to download: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract if CSV doesn't exist
|
||||||
|
if not csv_file.exists():
|
||||||
|
print(f"Extracting {zst_file}...")
|
||||||
|
try:
|
||||||
|
with open(zst_file, 'rb') as f:
|
||||||
|
dctx = zstd.ZstdDecompressor()
|
||||||
|
with dctx.stream_reader(f) as reader:
|
||||||
|
with open(csv_file, 'wb') as out:
|
||||||
|
out.write(reader.read())
|
||||||
|
print(f"Extracted to {csv_file}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to extract: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return str(csv_file)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_puzzle_positions(
|
||||||
|
puzzle_csv: str,
|
||||||
|
max_puzzles: int = 300_000
|
||||||
|
) -> Set[str]:
|
||||||
|
"""
|
||||||
|
Extract the position BEFORE the blunder from each puzzle.
|
||||||
|
This is exactly the type of position where tactical
|
||||||
|
recognition matters most.
|
||||||
|
|
||||||
|
Returns a set of unique FENs.
|
||||||
|
"""
|
||||||
|
positions = set()
|
||||||
|
|
||||||
|
with open(puzzle_csv) as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
if len(positions) >= max_puzzles:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
board = chess.Board(row['FEN'])
|
||||||
|
|
||||||
|
# The puzzle FEN is AFTER the blunder move
|
||||||
|
# We want the position BEFORE — so it learns
|
||||||
|
# to find the tactic, not just play it
|
||||||
|
moves = row['Moves'].split()
|
||||||
|
|
||||||
|
# Undo one move to get pre-tactic position
|
||||||
|
board.push_uci(moves[0]) # opponent blunder
|
||||||
|
fen = board.fen()
|
||||||
|
|
||||||
|
# Filter for useful tactical themes
|
||||||
|
themes = row.get('Themes', '')
|
||||||
|
useful = any(t in themes for t in [
|
||||||
|
'fork', 'pin', 'skewer', 'discoveredAttack',
|
||||||
|
'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
|
||||||
|
'trappedPiece', 'sacrifice'
|
||||||
|
])
|
||||||
|
|
||||||
|
if useful:
|
||||||
|
positions.add(fen)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return positions
|
||||||
|
|
||||||
|
|
||||||
|
def load_positions_from_file(file_path: str) -> Set[str]:
|
||||||
|
"""Load positions from a text file (one FEN per line)."""
|
||||||
|
positions = set()
|
||||||
|
try:
|
||||||
|
with open(file_path) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
positions.add(line)
|
||||||
|
print(f"Loaded {len(positions)} positions from {file_path}")
|
||||||
|
return positions
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load from {file_path}: {e}")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_positions(
|
||||||
|
tactical: Set[str],
|
||||||
|
other: Set[str],
|
||||||
|
output_file: str = 'position.txt'
|
||||||
|
):
|
||||||
|
"""Merge two position sets and write to file."""
|
||||||
|
merged = tactical | other
|
||||||
|
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
for fen in merged:
|
||||||
|
f.write(fen + '\n')
|
||||||
|
|
||||||
|
overlap = len(tactical & other)
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"MERGE SUMMARY")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"Tactical positions: {len(tactical):,}")
|
||||||
|
print(f"Other positions: {len(other):,}")
|
||||||
|
print(f"Overlap (deduplicated): {overlap:,}")
|
||||||
|
print(f"Total merged positions: {len(merged):,}")
|
||||||
|
print(f"Written to: {output_file}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tactical_only(
|
||||||
|
puzzle_csv: str,
|
||||||
|
output_file: str,
|
||||||
|
max_puzzles: int = 300_000
|
||||||
|
) -> int:
|
||||||
|
"""Extract tactical positions and save to file (no merge prompts).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
puzzle_csv: Path to Lichess puzzle CSV
|
||||||
|
output_file: Where to save the FEN positions
|
||||||
|
max_puzzles: Maximum puzzles to extract
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of positions extracted
|
||||||
|
"""
|
||||||
|
print("Extracting tactical positions from puzzle database...")
|
||||||
|
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||||
|
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
for fen in tactical_positions:
|
||||||
|
f.write(fen + '\n')
|
||||||
|
|
||||||
|
return len(tactical_positions)
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_merge_positions(
|
||||||
|
puzzle_csv: str,
|
||||||
|
output_file: str = 'position.txt',
|
||||||
|
max_puzzles: int = 300_000
|
||||||
|
):
|
||||||
|
"""Interactive workflow: extract tactical positions and merge with user selection."""
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("TACTICAL POSITION EXTRACTOR & MERGER")
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
# Extract tactical positions
|
||||||
|
print("Extracting tactical positions from puzzle database...")
|
||||||
|
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
|
||||||
|
print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
|
||||||
|
|
||||||
|
# Ask what to merge with
|
||||||
|
print("What would you like to merge with these tactical positions?")
|
||||||
|
print("1. Load from a position file")
|
||||||
|
print("2. Generate random positions")
|
||||||
|
print("3. Skip merging (save tactical only)")
|
||||||
|
|
||||||
|
choice = input("\nEnter choice (1-3): ").strip()
|
||||||
|
|
||||||
|
other_positions = set()
|
||||||
|
|
||||||
|
if choice == '1':
|
||||||
|
file_path = input("Enter path to position file: ").strip()
|
||||||
|
other_positions = load_positions_from_file(file_path)
|
||||||
|
|
||||||
|
elif choice == '2':
|
||||||
|
positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
|
||||||
|
try:
|
||||||
|
positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
|
||||||
|
except ValueError:
|
||||||
|
positions_to_gen = 1000000
|
||||||
|
|
||||||
|
temp_file = 'temp_generated_positions.txt'
|
||||||
|
print(f"\nGenerating {positions_to_gen:,} random positions...")
|
||||||
|
play_random_game_and_collect_positions(
|
||||||
|
output_file=temp_file,
|
||||||
|
total_positions=positions_to_gen,
|
||||||
|
samples_per_game=1,
|
||||||
|
min_move=1,
|
||||||
|
max_move=50,
|
||||||
|
num_workers=8
|
||||||
|
)
|
||||||
|
other_positions = load_positions_from_file(temp_file)
|
||||||
|
|
||||||
|
elif choice == '3':
|
||||||
|
print("Skipping merge, saving tactical positions only...")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Invalid choice, saving tactical positions only...")
|
||||||
|
|
||||||
|
merge_positions(tactical_positions, other_positions, output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
|
||||||
|
parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
|
||||||
|
help="URL to download puzzle database from")
|
||||||
|
parser.add_argument("--output-dir", default='trainingdata',
|
||||||
|
help="Directory to extract puzzle database to")
|
||||||
|
parser.add_argument("--max-puzzles", type=int, default=300_000,
|
||||||
|
help="Maximum puzzles to extract (default: 300000)")
|
||||||
|
parser.add_argument("--output-file", default='position.txt',
|
||||||
|
help="Output file for merged positions (default: position.txt)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Download and extract
|
||||||
|
csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
|
||||||
|
|
||||||
|
if csv_path:
|
||||||
|
# Interactive merge
|
||||||
|
interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
|
||||||
|
else:
|
||||||
|
print("Failed to download/extract puzzle database")
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,676 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Train NNUE neural network for chess evaluation."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
import chess
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class NNUEDataset(Dataset):
|
||||||
|
"""Dataset of chess positions with evaluations."""
|
||||||
|
|
||||||
|
def __init__(self, data_file):
|
||||||
|
self.positions = []
|
||||||
|
self.evals = []
|
||||||
|
self.evals_raw = []
|
||||||
|
self.is_normalized = None
|
||||||
|
|
||||||
|
with open(data_file, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
fen = data['fen']
|
||||||
|
eval_val = data['eval']
|
||||||
|
self.positions.append(fen)
|
||||||
|
self.evals.append(eval_val)
|
||||||
|
|
||||||
|
# Check if normalized or raw
|
||||||
|
if self.is_normalized is None:
|
||||||
|
# If eval is in range [-1, 1], assume normalized
|
||||||
|
self.is_normalized = abs(eval_val) <= 1.0
|
||||||
|
|
||||||
|
# Store raw if available
|
||||||
|
if 'eval_raw' in data:
|
||||||
|
self.evals_raw.append(data['eval_raw'])
|
||||||
|
else:
|
||||||
|
self.evals_raw.append(eval_val)
|
||||||
|
except (json.JSONDecodeError, KeyError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.positions)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
fen = self.positions[idx]
|
||||||
|
eval_val = self.evals[idx]
|
||||||
|
features = fen_to_features(fen)
|
||||||
|
|
||||||
|
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
|
||||||
|
if self.is_normalized:
|
||||||
|
target = torch.tensor(eval_val, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
|
||||||
|
|
||||||
|
return features, target
|
||||||
|
|
||||||
|
def fen_to_features(fen):
|
||||||
|
"""Convert FEN to 768-dimensional binary feature vector."""
|
||||||
|
# Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
|
||||||
|
piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
|
||||||
|
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
|
||||||
|
|
||||||
|
features = torch.zeros(768, dtype=torch.float32)
|
||||||
|
|
||||||
|
try:
|
||||||
|
board = chess.Board(fen)
|
||||||
|
|
||||||
|
# 12 piece types × 64 squares = 768
|
||||||
|
for square in chess.SQUARES:
|
||||||
|
piece = board.piece_at(square)
|
||||||
|
if piece is not None:
|
||||||
|
piece_char = piece.symbol()
|
||||||
|
if piece_char in piece_to_idx:
|
||||||
|
piece_idx = piece_to_idx[piece_char]
|
||||||
|
feature_idx = piece_idx * 64 + square
|
||||||
|
features[feature_idx] = 1.0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
|
||||||
|
|
||||||
|
|
||||||
|
class NNUE(nn.Module):
|
||||||
|
"""NNUE neural network with configurable hidden layers.
|
||||||
|
|
||||||
|
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
|
||||||
|
Layer attributes follow the naming l1, l2, ..., lN so export.py can
|
||||||
|
infer the architecture directly from the state_dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
|
||||||
|
super().__init__()
|
||||||
|
if hidden_sizes is None:
|
||||||
|
hidden_sizes = DEFAULT_HIDDEN_SIZES
|
||||||
|
self.hidden_sizes = list(hidden_sizes)
|
||||||
|
sizes = [768] + self.hidden_sizes + [1]
|
||||||
|
num_hidden = len(self.hidden_sizes)
|
||||||
|
|
||||||
|
for i in range(num_hidden):
|
||||||
|
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
|
||||||
|
setattr(self, f"relu{i + 1}", nn.ReLU())
|
||||||
|
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
|
||||||
|
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
|
||||||
|
self._num_hidden = num_hidden
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for i in range(1, self._num_hidden + 1):
|
||||||
|
layer = getattr(self, f"l{i}")
|
||||||
|
relu = getattr(self, f"relu{i}")
|
||||||
|
drop = getattr(self, f"drop{i}")
|
||||||
|
x = drop(relu(layer(x)))
|
||||||
|
return getattr(self, f"l{self._num_hidden + 1}")(x)
|
||||||
|
|
||||||
|
def find_next_version(base_name="nnue_weights"):
|
||||||
|
"""Find the next version number for model versioning.
|
||||||
|
|
||||||
|
Looks for nnue_weights_v*.pt files and returns the next version number.
|
||||||
|
If no versioned files exist, returns 1.
|
||||||
|
"""
|
||||||
|
base_path = Path(base_name)
|
||||||
|
directory = base_path.parent
|
||||||
|
filename = base_path.name
|
||||||
|
|
||||||
|
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
|
||||||
|
versions = []
|
||||||
|
|
||||||
|
for file in directory.glob(f"{filename}_v*.pt"):
|
||||||
|
match = pattern.match(file.name)
|
||||||
|
if match:
|
||||||
|
versions.append(int(match.group(1)))
|
||||||
|
|
||||||
|
if versions:
|
||||||
|
return max(versions) + 1
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def save_metadata(weights_file, metadata):
|
||||||
|
"""Save training metadata alongside the weights file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
|
||||||
|
metadata: Dictionary with training info
|
||||||
|
"""
|
||||||
|
metadata_file = weights_file.replace(".pt", "_metadata.json")
|
||||||
|
|
||||||
|
with open(metadata_file, "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2, default=str)
|
||||||
|
|
||||||
|
return metadata_file
|
||||||
|
|
||||||
|
def _setup_training(data_file, batch_size, subsample_ratio):
|
||||||
|
"""Set up device, dataset, and data loaders.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
|
||||||
|
"""
|
||||||
|
print("Checking GPU availability...")
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
||||||
|
else:
|
||||||
|
print("⚠ GPU not available, using CPU")
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("Loading dataset...")
|
||||||
|
dataset = NNUEDataset(data_file)
|
||||||
|
num_positions = len(dataset)
|
||||||
|
print(f"Dataset size: {num_positions}")
|
||||||
|
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
|
||||||
|
|
||||||
|
evals_array = np.array(dataset.evals)
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("TRAINING DATASET DIAGNOSTICS")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Min evaluation: {evals_array.min():.4f}")
|
||||||
|
print(f"Max evaluation: {evals_array.max():.4f}")
|
||||||
|
print(f"Mean evaluation: {evals_array.mean():.4f}")
|
||||||
|
print(f"Median evaluation: {np.median(evals_array):.4f}")
|
||||||
|
print(f"Std deviation: {evals_array.std():.4f}")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
train_size = int(0.9 * len(dataset))
|
||||||
|
val_size = len(dataset) - train_size
|
||||||
|
|
||||||
|
from torch.utils.data import random_split, RandomSampler
|
||||||
|
generator = torch.Generator().manual_seed(42)
|
||||||
|
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
|
||||||
|
|
||||||
|
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||||
|
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sampler=train_sampler,
|
||||||
|
num_workers=8,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=8,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
|
||||||
|
|
||||||
|
def _run_training_season(
|
||||||
|
model, optimizer, scheduler, scaler,
|
||||||
|
train_loader, val_loader, train_dataset, val_dataset,
|
||||||
|
device, criterion, output_file,
|
||||||
|
start_epoch, epochs, early_stopping_patience,
|
||||||
|
season_start_time, deadline=None, initial_best_val_loss=float('inf')
|
||||||
|
):
|
||||||
|
"""Run one training season until epoch limit, early stopping, or deadline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
|
||||||
|
toward early stopping and do not save snapshots.
|
||||||
|
Returns:
|
||||||
|
(best_val_loss, best_model_state, last_epoch)
|
||||||
|
best_model_state is None if no epoch beat initial_best_val_loss.
|
||||||
|
"""
|
||||||
|
best_val_loss = initial_best_val_loss
|
||||||
|
best_model_state = None
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
total_epochs = start_epoch + epochs
|
||||||
|
last_epoch = start_epoch - 1
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, start_epoch + epochs):
|
||||||
|
if deadline and datetime.now() >= deadline:
|
||||||
|
print("Time limit reached, stopping season.")
|
||||||
|
break
|
||||||
|
|
||||||
|
epoch_display = epoch + 1
|
||||||
|
|
||||||
|
# Train
|
||||||
|
model.train()
|
||||||
|
train_loss = 0.0
|
||||||
|
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
|
||||||
|
for batch_features, batch_targets in train_loader:
|
||||||
|
batch_features = batch_features.to(device)
|
||||||
|
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||||
|
outputs = model(batch_features)
|
||||||
|
loss = criterion(outputs, batch_targets)
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
train_loss += loss.item() * batch_features.size(0)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
train_loss /= len(train_dataset)
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
model.eval()
|
||||||
|
val_loss = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
|
||||||
|
for batch_features, batch_targets in val_loader:
|
||||||
|
batch_features = batch_features.to(device)
|
||||||
|
batch_targets = batch_targets.to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
|
||||||
|
outputs = model(batch_features)
|
||||||
|
loss = criterion(outputs, batch_targets)
|
||||||
|
val_loss += loss.item() * batch_features.size(0)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
val_loss /= len(val_dataset)
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
|
||||||
|
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
|
||||||
|
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
|
||||||
|
|
||||||
|
elapsed_time = datetime.now() - season_start_time
|
||||||
|
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
|
||||||
|
remaining_epochs = total_epochs - epoch_display
|
||||||
|
eta_seconds = time_per_epoch * remaining_epochs
|
||||||
|
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
|
||||||
|
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
|
||||||
|
|
||||||
|
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
|
||||||
|
torch.save({
|
||||||
|
"epoch": epoch,
|
||||||
|
"model_state_dict": model.state_dict(),
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"scheduler_state_dict": scheduler.state_dict(),
|
||||||
|
"scaler_state_dict": scaler.state_dict(),
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
"hidden_sizes": model.hidden_sizes,
|
||||||
|
}, checkpoint_file)
|
||||||
|
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
best_model_state = model.state_dict().copy()
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
|
||||||
|
torch.save(best_model_state, snapshot_file)
|
||||||
|
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
|
||||||
|
else:
|
||||||
|
epochs_without_improvement += 1
|
||||||
|
|
||||||
|
last_epoch = epoch
|
||||||
|
|
||||||
|
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
|
||||||
|
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
|
||||||
|
break
|
||||||
|
|
||||||
|
return best_val_loss, best_model_state, last_epoch
|
||||||
|
|
||||||
|
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||||
|
best_val_loss, output_file, use_versioning, num_positions,
|
||||||
|
stockfish_depth, training_start_time, hidden_sizes=None,
|
||||||
|
extra_metadata=None):
|
||||||
|
"""Save the best model with optional versioning and metadata."""
|
||||||
|
final_output_file = output_file
|
||||||
|
metadata = {}
|
||||||
|
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
|
||||||
|
|
||||||
|
if use_versioning:
|
||||||
|
base_name = output_file.replace(".pt", "")
|
||||||
|
version = find_next_version(base_name)
|
||||||
|
final_output_file = f"{base_name}_v{version}.pt"
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"version": version,
|
||||||
|
"date": training_start_time.isoformat(),
|
||||||
|
"num_positions": num_positions,
|
||||||
|
"stockfish_depth": stockfish_depth,
|
||||||
|
"final_val_loss": float(best_val_loss),
|
||||||
|
"architecture": architecture,
|
||||||
|
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
|
if extra_metadata:
|
||||||
|
metadata.update(extra_metadata)
|
||||||
|
|
||||||
|
torch.save({
|
||||||
|
"model_state_dict": best_model_state,
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"scheduler_state_dict": scheduler.state_dict(),
|
||||||
|
"scaler_state_dict": scaler.state_dict(),
|
||||||
|
"epoch": last_epoch,
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
|
||||||
|
}, final_output_file)
|
||||||
|
print(f"Best model saved to {final_output_file}")
|
||||||
|
|
||||||
|
if use_versioning and metadata:
|
||||||
|
metadata_file = save_metadata(final_output_file, metadata)
|
||||||
|
print(f"Metadata saved to {metadata_file}")
|
||||||
|
print(f"\nTraining Summary:")
|
||||||
|
for key, val in metadata.items():
|
||||||
|
print(f" {key}: {val}")
|
||||||
|
|
||||||
|
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||||
|
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_file: Path to training_data.jsonl
|
||||||
|
output_file: Where to save best weights (or base name if use_versioning=True)
|
||||||
|
epochs: Number of training epochs (default: 100)
|
||||||
|
batch_size: Training batch size (default: 16384)
|
||||||
|
lr: Learning rate (default: 0.001)
|
||||||
|
checkpoint: Optional path to checkpoint file to resume from
|
||||||
|
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||||
|
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||||
|
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
|
||||||
|
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
|
||||||
|
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
|
||||||
|
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||||
|
"""
|
||||||
|
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||||
|
_setup_training(data_file, batch_size, subsample_ratio)
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||||
|
|
||||||
|
if checkpoint:
|
||||||
|
print(f"Loading checkpoint: {checkpoint}")
|
||||||
|
ckpt = torch.load(checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||||
|
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||||
|
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||||
|
resolved_hidden_sizes = ckpt_hidden
|
||||||
|
|
||||||
|
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||||
|
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||||
|
|
||||||
|
if checkpoint:
|
||||||
|
ckpt = torch.load(checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
|
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||||
|
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||||
|
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
||||||
|
start_epoch = ckpt["epoch"] + 1
|
||||||
|
best_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||||
|
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
|
||||||
|
else:
|
||||||
|
model.load_state_dict(ckpt)
|
||||||
|
print("Loaded weights-only checkpoint (no optimizer state)")
|
||||||
|
|
||||||
|
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
|
||||||
|
|
||||||
|
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||||
|
print(f"Architecture: {arch_str}")
|
||||||
|
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
|
||||||
|
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
|
||||||
|
print(f"Mixed precision training: enabled")
|
||||||
|
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
|
||||||
|
if subsample_ratio < 1.0:
|
||||||
|
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
|
||||||
|
if early_stopping_patience:
|
||||||
|
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
training_start_time = datetime.now()
|
||||||
|
|
||||||
|
best_val_loss, best_model_state, last_epoch = _run_training_season(
|
||||||
|
model, optimizer, scheduler, scaler,
|
||||||
|
train_loader, val_loader, train_dataset, val_dataset,
|
||||||
|
device, criterion, output_file,
|
||||||
|
start_epoch, epochs, early_stopping_patience,
|
||||||
|
training_start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
|
||||||
|
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
|
||||||
|
print("No new model created.")
|
||||||
|
return
|
||||||
|
|
||||||
|
_save_versioned_model(
|
||||||
|
best_model_state, optimizer, scheduler, scaler, last_epoch,
|
||||||
|
best_val_loss, output_file, use_versioning, num_positions,
|
||||||
|
stockfish_depth, training_start_time,
|
||||||
|
hidden_sizes=resolved_hidden_sizes,
|
||||||
|
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
|
||||||
|
"checkpoint": str(checkpoint) if checkpoint else None}
|
||||||
|
)
|
||||||
|
|
||||||
|
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
|
||||||
|
epochs_per_season=50, early_stopping_patience=10,
|
||||||
|
batch_size=16384, lr=0.001, initial_checkpoint=None,
|
||||||
|
stockfish_depth=12, use_versioning=True,
|
||||||
|
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
|
||||||
|
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
|
||||||
|
|
||||||
|
Each season trains with early stopping. When early stopping fires, the model reloads the
|
||||||
|
global best weights and begins a fresh season with a reset optimizer and scheduler.
|
||||||
|
This prevents the model from drifting away from its best known state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_file: Path to training_data.jsonl
|
||||||
|
output_file: Output file base name
|
||||||
|
duration_minutes: Total training budget in minutes
|
||||||
|
epochs_per_season: Max epochs per restart season (default: 50)
|
||||||
|
early_stopping_patience: Patience for early stopping within each season (default: 10)
|
||||||
|
batch_size: Training batch size (default: 16384)
|
||||||
|
lr: Learning rate reset to this value at the start of each season (default: 0.001)
|
||||||
|
initial_checkpoint: Optional weights-only .pt file to start from
|
||||||
|
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
|
||||||
|
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
|
||||||
|
weight_decay: L2 regularization strength (default: 1e-4)
|
||||||
|
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
|
||||||
|
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
|
||||||
|
"""
|
||||||
|
deadline = datetime.now() + timedelta(minutes=duration_minutes)
|
||||||
|
|
||||||
|
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
|
||||||
|
_setup_training(data_file, batch_size, subsample_ratio)
|
||||||
|
|
||||||
|
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
|
||||||
|
|
||||||
|
if initial_checkpoint:
|
||||||
|
print(f"Loading initial weights: {initial_checkpoint}")
|
||||||
|
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
ckpt_hidden = ckpt.get("hidden_sizes")
|
||||||
|
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
|
||||||
|
print(f" Using architecture from checkpoint: {ckpt_hidden}")
|
||||||
|
resolved_hidden_sizes = ckpt_hidden
|
||||||
|
|
||||||
|
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
best_global_val_loss = float('inf')
|
||||||
|
|
||||||
|
if initial_checkpoint:
|
||||||
|
ckpt = torch.load(initial_checkpoint, map_location=device)
|
||||||
|
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
||||||
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
|
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
|
||||||
|
if best_global_val_loss < float('inf'):
|
||||||
|
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
|
||||||
|
else:
|
||||||
|
print("Initial weights loaded (no val loss in checkpoint).")
|
||||||
|
else:
|
||||||
|
model.load_state_dict(ckpt)
|
||||||
|
print("Loaded weights-only checkpoint (no val loss info).")
|
||||||
|
|
||||||
|
arch_str = " → ".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
|
||||||
|
print(f"Architecture: {arch_str}")
|
||||||
|
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
|
||||||
|
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
burst_start_time = datetime.now()
|
||||||
|
season = 0
|
||||||
|
best_global_state = None
|
||||||
|
last_optimizer = None
|
||||||
|
last_scheduler = None
|
||||||
|
last_scaler = None
|
||||||
|
last_epoch = 0
|
||||||
|
|
||||||
|
while datetime.now() < deadline:
|
||||||
|
season += 1
|
||||||
|
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
|
||||||
|
if best_global_val_loss < float('inf'):
|
||||||
|
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
|
||||||
|
print(f"{'=' * 60}\n")
|
||||||
|
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
|
||||||
|
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
|
||||||
|
|
||||||
|
season_start_time = datetime.now()
|
||||||
|
val_loss, model_state, last_epoch = _run_training_season(
|
||||||
|
model, optimizer, scheduler, scaler,
|
||||||
|
train_loader, val_loader, train_dataset, val_dataset,
|
||||||
|
device, criterion, output_file,
|
||||||
|
0, epochs_per_season, early_stopping_patience,
|
||||||
|
season_start_time, deadline=deadline,
|
||||||
|
initial_best_val_loss=best_global_val_loss
|
||||||
|
)
|
||||||
|
|
||||||
|
last_optimizer = optimizer
|
||||||
|
last_scheduler = scheduler
|
||||||
|
last_scaler = scaler
|
||||||
|
|
||||||
|
if model_state is not None and val_loss < best_global_val_loss:
|
||||||
|
best_global_val_loss = val_loss
|
||||||
|
best_global_state = model_state
|
||||||
|
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
|
||||||
|
|
||||||
|
# Reload global best for the next season so we never drift backwards
|
||||||
|
if best_global_state is not None:
|
||||||
|
model.load_state_dict(best_global_state)
|
||||||
|
|
||||||
|
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
|
||||||
|
print(f"Best val loss: {best_global_val_loss:.6f}")
|
||||||
|
print(f"{'=' * 60}\n")
|
||||||
|
|
||||||
|
if best_global_state is None:
|
||||||
|
print("No model improvement found. No file saved.")
|
||||||
|
return
|
||||||
|
|
||||||
|
_save_versioned_model(
|
||||||
|
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
|
||||||
|
best_global_val_loss, output_file, use_versioning, num_positions,
|
||||||
|
stockfish_depth, burst_start_time,
|
||||||
|
hidden_sizes=resolved_hidden_sizes,
|
||||||
|
extra_metadata={
|
||||||
|
"mode": "burst",
|
||||||
|
"duration_minutes": duration_minutes,
|
||||||
|
"epochs_per_season": epochs_per_season,
|
||||||
|
"early_stopping_patience": early_stopping_patience,
|
||||||
|
"seasons_completed": season,
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"learning_rate": lr,
|
||||||
|
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
|
||||||
|
parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
|
||||||
|
help="Path to training_data.jsonl (default: training_data.jsonl)")
|
||||||
|
parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
|
||||||
|
help="Output file base name (default: nnue_weights.pt)")
|
||||||
|
parser.add_argument("--checkpoint", type=str, default=None,
|
||||||
|
help="Path to checkpoint file to resume training from (optional)")
|
||||||
|
parser.add_argument("--epochs", type=int, default=100,
|
||||||
|
help="Number of epochs to train (default: 100)")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=16384,
|
||||||
|
help="Batch size (default: 16384)")
|
||||||
|
parser.add_argument("--lr", type=float, default=0.001,
|
||||||
|
help="Learning rate (default: 0.001)")
|
||||||
|
parser.add_argument("--early-stopping", type=int, default=None,
|
||||||
|
help="Stop if val loss doesn't improve for N epochs (optional)")
|
||||||
|
parser.add_argument("--stockfish-depth", type=int, default=12,
|
||||||
|
help="Stockfish depth used for evaluations (for metadata, default: 12)")
|
||||||
|
parser.add_argument("--no-versioning", action="store_true",
|
||||||
|
help="Disable automatic versioning (save directly to output file)")
|
||||||
|
parser.add_argument("--weight-decay", type=float, default=5e-5,
|
||||||
|
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
|
||||||
|
parser.add_argument("--subsample-ratio", type=float, default=1.0,
|
||||||
|
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
|
||||||
|
parser.add_argument("--hidden-layers", type=str, default=None,
|
||||||
|
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
|
||||||
|
|
||||||
|
# Burst mode
|
||||||
|
parser.add_argument("--burst-duration", type=float, default=None,
|
||||||
|
help="Enable burst mode: total training budget in minutes")
|
||||||
|
parser.add_argument("--epochs-per-season", type=int, default=50,
|
||||||
|
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
|
||||||
|
|
||||||
|
if args.burst_duration is not None:
|
||||||
|
burst_train(
|
||||||
|
data_file=args.data_file,
|
||||||
|
output_file=args.output_file,
|
||||||
|
duration_minutes=args.burst_duration,
|
||||||
|
epochs_per_season=args.epochs_per_season,
|
||||||
|
early_stopping_patience=args.early_stopping or 10,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
lr=args.lr,
|
||||||
|
initial_checkpoint=args.checkpoint,
|
||||||
|
stockfish_depth=args.stockfish_depth,
|
||||||
|
use_versioning=not args.no_versioning,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
subsample_ratio=args.subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_nnue(
|
||||||
|
data_file=args.data_file,
|
||||||
|
output_file=args.output_file,
|
||||||
|
epochs=args.epochs,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
lr=args.lr,
|
||||||
|
checkpoint=args.checkpoint,
|
||||||
|
stockfish_depth=args.stockfish_depth,
|
||||||
|
use_versioning=not args.no_versioning,
|
||||||
|
early_stopping_patience=args.early_stopping,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
subsample_ratio=args.subsample_ratio,
|
||||||
|
hidden_sizes=hidden_sizes,
|
||||||
|
)
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
# Setup and run NNUE training pipeline
|
||||||
|
|
||||||
|
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
|
||||||
|
$VenvDir = Join-Path $ScriptDir ".venv"
|
||||||
|
|
||||||
|
# Check if virtual environment exists
|
||||||
|
if (-not (Test-Path $VenvDir)) {
|
||||||
|
Write-Host "Creating virtual environment..."
|
||||||
|
python -m venv $VenvDir
|
||||||
|
if ($LASTEXITCODE -ne 0) {
|
||||||
|
Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Activate virtual environment
|
||||||
|
Write-Host "Activating virtual environment..."
|
||||||
|
$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
|
||||||
|
& $ActivateScript
|
||||||
|
|
||||||
|
# Install/update dependencies if requirements.txt exists
|
||||||
|
$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
|
||||||
|
if (Test-Path $RequirementsFile) {
|
||||||
|
Write-Host "Installing dependencies..."
|
||||||
|
pip install -q -r $RequirementsFile
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run nnue.py
|
||||||
|
Write-Host "Starting NNUE Training Pipeline..."
|
||||||
|
python (Join-Path $ScriptDir "nnue.py")
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Setup and run NNUE training pipeline
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||||
|
|
||||||
|
# Check if virtual environment exists
|
||||||
|
if [ ! -d "$VENV_DIR" ]; then
|
||||||
|
echo "Creating virtual environment..."
|
||||||
|
python3 -m venv "$VENV_DIR"
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Error: Failed to create virtual environment. Make sure python3 is installed."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Activate virtual environment
|
||||||
|
echo "Activating virtual environment..."
|
||||||
|
source "$VENV_DIR/bin/activate"
|
||||||
|
|
||||||
|
# Install/update dependencies if requirements.txt exists
|
||||||
|
if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
|
||||||
|
echo "Installing dependencies..."
|
||||||
|
pip install -q -r "$SCRIPT_DIR/requirements.txt"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run nnue.py
|
||||||
|
echo "Starting NNUE Training Pipeline..."
|
||||||
|
python "$SCRIPT_DIR/nnue.py"
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"version": 10,
|
||||||
|
"date": "2026-04-14T22:18:38.824577",
|
||||||
|
"num_positions": 3022562,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"final_val_loss": 6.248612448225196e-05,
|
||||||
|
"architecture": [
|
||||||
|
768,
|
||||||
|
1536,
|
||||||
|
1024,
|
||||||
|
512,
|
||||||
|
256,
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"device": "cuda",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"checkpoint": null
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"date": "2026-04-07T22:56:23.259658",
|
||||||
|
"num_positions": 2086,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 20,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.016311248764395714,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": null,
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 2,
|
||||||
|
"date": "2026-04-07T23:50:05.390402",
|
||||||
|
"num_positions": 6886,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.007848377339541912,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 3,
|
||||||
|
"date": "2026-04-08T09:43:28.000579",
|
||||||
|
"num_positions": 71610,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 20,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.006398905136849695,
|
||||||
|
"device": "cpu",
|
||||||
|
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 4,
|
||||||
|
"date": "2026-04-09T00:28:07.572209",
|
||||||
|
"num_positions": 2009355,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 40,
|
||||||
|
"batch_size": 4096,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 9.106677896235248e-05,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 5,
|
||||||
|
"date": "2026-04-09T18:50:27.845632",
|
||||||
|
"num_positions": 2009355,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 9.180421525105905e-05,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": null,
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 6,
|
||||||
|
"date": "2026-04-09T21:28:21.000832",
|
||||||
|
"num_positions": 1958728,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.2984530149085532,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 7,
|
||||||
|
"date": "2026-04-09T22:06:50.439858",
|
||||||
|
"num_positions": 1958728,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.2997283308762831,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": null,
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"version": 8,
|
||||||
|
"date": "2026-04-09T22:22:47.859730",
|
||||||
|
"num_positions": 1958728,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"epochs": 100,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"final_val_loss": 0.24803777390839968,
|
||||||
|
"device": "cuda",
|
||||||
|
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"version": 9,
|
||||||
|
"date": "2026-04-13T20:19:08.123315",
|
||||||
|
"num_positions": 2522562,
|
||||||
|
"stockfish_depth": 12,
|
||||||
|
"final_val_loss": 6.994176222619626e-05,
|
||||||
|
"device": "cuda",
|
||||||
|
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
|
||||||
|
"mode": "burst",
|
||||||
|
"duration_minutes": 30.0,
|
||||||
|
"epochs_per_season": 50,
|
||||||
|
"early_stopping_patience": 10,
|
||||||
|
"seasons_completed": 3,
|
||||||
|
"batch_size": 16384,
|
||||||
|
"learning_rate": 0.001,
|
||||||
|
"initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
|
||||||
|
}
|
||||||
Binary file not shown.
@@ -0,0 +1,21 @@
|
|||||||
|
package de.nowchess.bot
|
||||||
|
|
||||||
|
import de.nowchess.api.bot.Bot
|
||||||
|
import de.nowchess.bot.bots.ClassicalBot
|
||||||
|
|
||||||
|
object BotController {
|
||||||
|
|
||||||
|
private val bots: Map[String, Bot] = Map(
|
||||||
|
"easy" -> ClassicalBot(BotDifficulty.Easy),
|
||||||
|
"medium" -> ClassicalBot(BotDifficulty.Medium),
|
||||||
|
"hard" -> ClassicalBot(BotDifficulty.Hard),
|
||||||
|
"expert" -> ClassicalBot(BotDifficulty.Expert),
|
||||||
|
)
|
||||||
|
|
||||||
|
/** Get a bot by name. */
|
||||||
|
def getBot(name: String): Option[Bot] = bots.get(name.toLowerCase)
|
||||||
|
|
||||||
|
/** List all available bot names. */
|
||||||
|
def listBots: List[String] = bots.keys.toList.sorted
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package de.nowchess.bot
|
||||||
|
|
||||||
|
enum BotDifficulty:
|
||||||
|
case Easy
|
||||||
|
case Medium
|
||||||
|
case Hard
|
||||||
|
case Expert
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package de.nowchess.bot
|
||||||
|
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
|
||||||
|
object BotMoveRepetition:
|
||||||
|
|
||||||
|
private val maxConsecutiveMoves = 3
|
||||||
|
|
||||||
|
def blockedMoves(context: GameContext): Set[Move] = repeatedMove(context).toSet
|
||||||
|
|
||||||
|
def repeatedMove(context: GameContext): Option[Move] =
|
||||||
|
context.moves.takeRight(maxConsecutiveMoves) match
|
||||||
|
case first :: second :: third :: Nil if first == second && second == third => Some(first)
|
||||||
|
case _ => None
|
||||||
|
|
||||||
|
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
|
||||||
|
val blocked = blockedMoves(context)
|
||||||
|
moves.filterNot(blocked.contains)
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package de.nowchess.bot
|
||||||
|
|
||||||
|
object Config:
|
||||||
|
|
||||||
|
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this, the move is vetoed (not
|
||||||
|
* accepted as a suggestion).
|
||||||
|
*/
|
||||||
|
val VETO_THRESHOLD: Int = 150
|
||||||
|
|
||||||
|
/** Time budget per move for iterative deepening (milliseconds). */
|
||||||
|
val TIME_LIMIT_MS: Long = 2000L
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package de.nowchess.bot.ai
|
||||||
|
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
|
||||||
|
trait Evaluation:
|
||||||
|
|
||||||
|
def CHECKMATE_SCORE: Int
|
||||||
|
def DRAW_SCORE: Int
|
||||||
|
|
||||||
|
def evaluate(context: GameContext): Int
|
||||||
|
|
||||||
|
// ── Accumulator hooks ─────────────────────────────────────────────────────
|
||||||
|
// Default implementations fall back to full re-evaluation each call.
|
||||||
|
// Override in NNUE-capable evaluators for incremental L1 speedup.
|
||||||
|
|
||||||
|
/** Initialise the accumulator for the root position at ply 0. */
|
||||||
|
def initAccumulator(context: GameContext): Unit = ()
|
||||||
|
|
||||||
|
/** Copy parent ply's accumulator to childPly without move deltas (null-move). */
|
||||||
|
def copyAccumulator(parentPly: Int, childPly: Int): Unit = ()
|
||||||
|
|
||||||
|
/** Derive childPly's accumulator from parentPly by applying move deltas. */
|
||||||
|
def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = ()
|
||||||
|
|
||||||
|
/** Evaluate from the pre-computed accumulator at ply, using hash for the eval cache. Falls back to full evaluate when
|
||||||
|
* not overridden.
|
||||||
|
*/
|
||||||
|
def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int = evaluate(context)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package de.nowchess.bot.bots
|
||||||
|
|
||||||
|
import de.nowchess.api.bot.Bot
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||||
|
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||||
|
import de.nowchess.bot.util.PolyglotBook
|
||||||
|
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||||
|
import de.nowchess.rules.RuleSet
|
||||||
|
import de.nowchess.rules.sets.DefaultRules
|
||||||
|
|
||||||
|
final class ClassicalBot(
|
||||||
|
difficulty: BotDifficulty,
|
||||||
|
rules: RuleSet = DefaultRules,
|
||||||
|
book: Option[PolyglotBook] = None,
|
||||||
|
) extends Bot:
|
||||||
|
|
||||||
|
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationClassic)
|
||||||
|
private val TIME_BUDGET_MS = 1000L
|
||||||
|
|
||||||
|
override val name: String = s"ClassicalBot(${difficulty.toString})"
|
||||||
|
|
||||||
|
override def nextMove(context: GameContext): Option[Move] =
|
||||||
|
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||||
|
book
|
||||||
|
.flatMap(_.probe(context))
|
||||||
|
.filterNot(blockedMoves.contains)
|
||||||
|
.orElse(search.bestMoveWithTime(context, TIME_BUDGET_MS, blockedMoves))
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package de.nowchess.bot.bots
|
||||||
|
|
||||||
|
import de.nowchess.api.bot.Bot
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
import de.nowchess.bot.ai.Evaluation
|
||||||
|
import de.nowchess.bot.bots.classic.EvaluationClassic
|
||||||
|
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||||
|
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
|
||||||
|
import de.nowchess.bot.util.PolyglotBook
|
||||||
|
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition, Config}
|
||||||
|
import de.nowchess.rules.RuleSet
|
||||||
|
import de.nowchess.rules.sets.DefaultRules
|
||||||
|
|
||||||
|
final class HybridBot(
|
||||||
|
difficulty: BotDifficulty,
|
||||||
|
rules: RuleSet = DefaultRules,
|
||||||
|
book: Option[PolyglotBook] = None,
|
||||||
|
nnueEvaluation: Evaluation = EvaluationNNUE,
|
||||||
|
classicalEvaluation: Evaluation = EvaluationClassic,
|
||||||
|
vetoReporter: String => Unit = println(_),
|
||||||
|
) extends Bot:
|
||||||
|
|
||||||
|
private val search = AlphaBetaSearch(rules, TranspositionTable(), classicalEvaluation)
|
||||||
|
|
||||||
|
override val name: String = s"HybridBot(${difficulty.toString})"
|
||||||
|
|
||||||
|
override def nextMove(context: GameContext): Option[Move] =
|
||||||
|
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||||
|
book.flatMap(_.probe(context)).filterNot(blockedMoves.contains).orElse(searchWithVeto(context, blockedMoves))
|
||||||
|
|
||||||
|
private def searchWithVeto(context: GameContext, blockedMoves: Set[Move]): Option[Move] =
|
||||||
|
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS, blockedMoves).map { move =>
|
||||||
|
val next = rules.applyMove(context)(move)
|
||||||
|
val staticNnue = nnueEvaluation.evaluate(next)
|
||||||
|
val classical = classicalEvaluation.evaluate(next)
|
||||||
|
val diff = (classical - staticNnue).abs
|
||||||
|
if diff > Config.VETO_THRESHOLD then
|
||||||
|
vetoReporter(
|
||||||
|
f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)",
|
||||||
|
)
|
||||||
|
move
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package de.nowchess.bot.bots
|
||||||
|
|
||||||
|
import de.nowchess.api.bot.Bot
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
import de.nowchess.bot.bots.nnue.EvaluationNNUE
|
||||||
|
import de.nowchess.bot.logic.AlphaBetaSearch
|
||||||
|
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
|
||||||
|
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
|
||||||
|
import de.nowchess.rules.RuleSet
|
||||||
|
import de.nowchess.rules.sets.DefaultRules
|
||||||
|
|
||||||
|
final class NNUEBot(
|
||||||
|
difficulty: BotDifficulty,
|
||||||
|
rules: RuleSet = DefaultRules,
|
||||||
|
book: Option[PolyglotBook] = None,
|
||||||
|
) extends Bot:
|
||||||
|
|
||||||
|
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE)
|
||||||
|
|
||||||
|
override val name: String = s"NNUEBot(${difficulty.toString})"
|
||||||
|
|
||||||
|
override def nextMove(context: GameContext): Option[Move] =
|
||||||
|
val blockedMoves = BotMoveRepetition.blockedMoves(context)
|
||||||
|
book
|
||||||
|
.flatMap(_.probe(context))
|
||||||
|
.filterNot(blockedMoves.contains)
|
||||||
|
.orElse {
|
||||||
|
val moves = BotMoveRepetition.filterAllowed(context, rules.allLegalMoves(context))
|
||||||
|
if moves.isEmpty then None
|
||||||
|
else
|
||||||
|
val scored = batchEvaluateRoot(context, moves)
|
||||||
|
val bestMove = scored.maxBy(_._2)._1
|
||||||
|
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove))
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Evaluate all root moves shallowly via incremental NNUE accumulator updates. Returns (move, score) pairs with score
|
||||||
|
* from the root player's perspective.
|
||||||
|
*/
|
||||||
|
private def batchEvaluateRoot(context: GameContext, moves: List[Move]): List[(Move, Int)] =
|
||||||
|
EvaluationNNUE.initAccumulator(context)
|
||||||
|
val rootHash = ZobristHash.hash(context)
|
||||||
|
moves.map { move =>
|
||||||
|
val child = rules.applyMove(context)(move)
|
||||||
|
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
|
||||||
|
EvaluationNNUE.pushAccumulator(1, move, context, child)
|
||||||
|
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
|
||||||
|
(move, score)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Allocate more time for complex positions; less when one move clearly dominates. */
|
||||||
|
private def allocateTime(scored: List[(Move, Int)]): Long =
|
||||||
|
val moveCount = scored.length
|
||||||
|
if moveCount > 30 then 1500L
|
||||||
|
else if moveCount < 5 then 500L
|
||||||
|
else
|
||||||
|
val scores = scored.map(_._2)
|
||||||
|
val best = scores.max
|
||||||
|
val second = scores.filter(_ < best).maxOption.getOrElse(best)
|
||||||
|
if best - second > 200 then 600L else 1000L
|
||||||
@@ -0,0 +1,361 @@
|
|||||||
|
package de.nowchess.bot.bots.classic
|
||||||
|
|
||||||
|
import de.nowchess.api.board.{Color, PieceType, Square}
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.bot.ai.Evaluation
|
||||||
|
|
||||||
|
object EvaluationClassic extends Evaluation:
|
||||||
|
|
||||||
|
val CHECKMATE_SCORE: Int = 10_000_000
|
||||||
|
val DRAW_SCORE: Int = 0
|
||||||
|
|
||||||
|
// Material values in centipawns (indexed by PieceType.ordinal: Pawn=0, Knight=1, Bishop=2, Rook=3, Queen=4, King=5)
|
||||||
|
private val mgMaterial = Array(100, 325, 335, 500, 900, 20_000)
|
||||||
|
private val egMaterial = Array(110, 310, 310, 530, 1_000, 20_000)
|
||||||
|
|
||||||
|
private val TEMPO_BONUS: Int = 10
|
||||||
|
|
||||||
|
// Piece-square tables (Simplified Evaluation Function, Michniewski)
|
||||||
|
// Indexed by squareIndex = rank.ordinal * 8 + file.ordinal
|
||||||
|
// White's perspective: rank 0 = home (r1), rank 7 = back rank (r8)
|
||||||
|
// Black is vertically mirrored
|
||||||
|
|
||||||
|
private val mgPawnTable: Array[Int] = Array(
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 10, 10, 20, 30, 30, 20, 10, 10, 5, 5, 10, 25, 25, 10, 5, 5,
|
||||||
|
0, 0, 0, 20, 20, 0, 0, 0, 5, -5, -10, 0, 0, -10, -5, 5, 5, 10, 10, -20, -20, 10, 10, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egPawnTable: Array[Int] = Array(
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 70, 70, 70, 70, 70, 70, 70, 70, 40, 40, 40, 40, 40, 40, 40, 40, 30, 30, 30, 30, 30, 30, 30,
|
||||||
|
30, 20, 20, 20, 20, 20, 20, 20, 20, 10, 10, 10, 10, 10, 10, 10, 10, 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val mgKnightTable: Array[Int] = Array(
|
||||||
|
-50, -40, -30, -30, -30, -30, -40, -50, -40, -20, 0, 0, 0, 0, -20, -40, -30, 0, 10, 15, 15, 10, 0, -30, -30, 5, 15,
|
||||||
|
20, 20, 15, 5, -30, -30, 0, 15, 20, 20, 15, 0, -30, -30, 5, 10, 15, 15, 10, 5, -30, -40, -20, 0, 5, 5, 0, -20, -40,
|
||||||
|
-50, -40, -30, -30, -30, -30, -40, -50,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egKnightTable: Array[Int] = Array(
|
||||||
|
-30, -20, -10, -10, -10, -10, -20, -30, -20, 0, 5, 5, 5, 5, 0, -20, -10, 5, 15, 20, 20, 15, 5, -10, -10, 5, 20, 25,
|
||||||
|
25, 20, 5, -10, -10, 5, 20, 25, 25, 20, 5, -10, -10, 5, 15, 20, 20, 15, 5, -10, -20, 0, 5, 5, 5, 5, 0, -20, -30,
|
||||||
|
-20, -10, -10, -10, -10, -20, -30,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val mgBishopTable: Array[Int] = Array(
|
||||||
|
-20, -10, -10, -10, -10, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 10, 10, 5, 0, -10, -10, 5, 5, 10, 10,
|
||||||
|
5, 5, -10, -10, 0, 10, 10, 10, 10, 0, -10, -10, 10, 10, 10, 10, 10, 10, -10, -10, 5, 0, 0, 0, 0, 5, -10, -20, -10,
|
||||||
|
-10, -10, -10, -10, -10, -20,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egBishopTable: Array[Int] = Array(
|
||||||
|
-20, -10, -5, -5, -5, -5, -10, -20, -10, 0, 5, 5, 5, 5, 0, -10, -5, 5, 10, 10, 10, 10, 5, -5, -5, 5, 10, 15, 15, 10,
|
||||||
|
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -5, 5, 10, 10, 10, 10, 5, -5, -10, 0, 5, 5, 5, 5, 0, -10, -20, -10, -5, -5, -5,
|
||||||
|
-5, -10, -20,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val mgRookTable: Array[Int] = Array(
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||||
|
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egRookTable: Array[Int] = Array(
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
|
||||||
|
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val mgQueenTable: Array[Int] = Array(
|
||||||
|
-20, -10, -10, -5, -5, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 5, 5, 5, 0, -10, -5, 0, 5, 5, 5, 5, 0,
|
||||||
|
-5, 0, 0, 5, 5, 5, 5, 0, -5, -10, 5, 5, 5, 5, 5, 0, -10, -10, 0, 5, 0, 0, 0, 0, -10, -20, -10, -10, -5, -5, -10,
|
||||||
|
-10, -20,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egQueenTable: Array[Int] = Array(
|
||||||
|
-15, -10, -8, -5, -5, -8, -10, -15, -10, 0, 3, 5, 5, 3, 0, -10, -8, 3, 10, 10, 10, 10, 3, -8, -5, 5, 10, 15, 15, 10,
|
||||||
|
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -8, 3, 10, 10, 10, 10, 3, -8, -10, 0, 3, 5, 5, 3, 0, -10, -15, -10, -8, -5, -5,
|
||||||
|
-8, -10, -15,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val mgKingTable: Array[Int] = Array(
|
||||||
|
-30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40,
|
||||||
|
-30, -30, -40, -40, -50, -50, -40, -40, -30, -20, -30, -30, -40, -40, -30, -30, -20, -10, -20, -20, -20, -20, -20,
|
||||||
|
-20, -10, 20, 20, 0, 0, 0, 0, 20, 20, 20, 30, 10, 0, 0, 10, 30, 20,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val egKingTable: Array[Int] = Array(
|
||||||
|
-50, -40, -30, -20, -20, -30, -40, -50, -30, -20, -10, 0, 0, -10, -20, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30,
|
||||||
|
-10, 30, 40, 40, 30, -10, -30, -30, -10, 30, 40, 40, 30, -10, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30, -30, 0,
|
||||||
|
0, 0, 0, -30, -30, -50, -30, -30, -30, -30, -30, -30, -50,
|
||||||
|
)
|
||||||
|
|
||||||
|
private val phaseWeight: Map[PieceType, Int] = Map(
|
||||||
|
PieceType.Knight -> 1,
|
||||||
|
PieceType.Bishop -> 1,
|
||||||
|
PieceType.Rook -> 2,
|
||||||
|
PieceType.Queen -> 4,
|
||||||
|
)
|
||||||
|
private val maxPhase = 24 // 4*4 + 4*2 + 4*1 + 4*1
|
||||||
|
|
||||||
|
private val passedPawnBonus: Array[Int] = Array(0, 5, 10, 20, 35, 60, 100, 0)
|
||||||
|
private val egPassedPawnBonus: Array[Int] = Array(0, 20, 40, 80, 150, 250, 400, 0)
|
||||||
|
|
||||||
|
// Pawn structure penalties
|
||||||
|
private val doubledMg = -10
|
||||||
|
private val doubledEg = -25
|
||||||
|
private val isolatedMg = -15
|
||||||
|
private val isolatedEg = -20
|
||||||
|
|
||||||
|
// Mobility weights: centipawns per reachable square (indexed by PieceType.ordinal)
|
||||||
|
private val mobilityMg = Array(0, 4, 3, 2, 1, 0, 0)
|
||||||
|
private val mobilityEg = Array(0, 4, 3, 4, 2, 0, 0)
|
||||||
|
|
||||||
|
// Direction offsets for sliding pieces
|
||||||
|
private val diagonals = List((-1, -1), (-1, 1), (1, -1), (1, 1))
|
||||||
|
private val orthogonals = List((-1, 0), (1, 0), (0, -1), (0, 1))
|
||||||
|
private val knightOffsets = List((-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1))
|
||||||
|
|
||||||
|
// Rook and bishop bonuses
|
||||||
|
private val bishopPairMg = 50
|
||||||
|
private val bishopPairEg = 70
|
||||||
|
private val rookOn7thMg = 20
|
||||||
|
private val rookOn7thEg = 10
|
||||||
|
|
||||||
|
/** Evaluate the position from the perspective of context.turn. Positive = good for context.turn.
|
||||||
|
*/
|
||||||
|
def evaluate(context: GameContext): Int =
|
||||||
|
val phase = gamePhase(context.board)
|
||||||
|
val isEg = isEndgame(phase)
|
||||||
|
val material = materialAndPositional(context, phase)
|
||||||
|
val structure = pawnStructure(context, phase)
|
||||||
|
val mobility = mobilityScore(context, phase)
|
||||||
|
val rookBishop = rookAndBishopBonuses(context, phase)
|
||||||
|
val bonuses = positionalBonuses(context, phase, isEg)
|
||||||
|
val egBonuses = if isEg then endgameBonus(context) else 0
|
||||||
|
material + structure + mobility + rookBishop + bonuses + egBonuses + TEMPO_BONUS
|
||||||
|
|
||||||
|
private def gamePhase(board: de.nowchess.api.board.Board): Int =
|
||||||
|
val phase = board.pieces.values.foldLeft(0) { (acc, piece) =>
|
||||||
|
acc + phaseWeight.getOrElse(piece.pieceType, 0)
|
||||||
|
}
|
||||||
|
math.min(phase, maxPhase)
|
||||||
|
|
||||||
|
private def isEndgame(phase: Int): Boolean =
|
||||||
|
phase < 8 // Significantly reduced material indicates endgame
|
||||||
|
|
||||||
|
private def taper(mg: Int, eg: Int, phase: Int): Int =
|
||||||
|
(mg * phase + eg * (maxPhase - phase)) / maxPhase
|
||||||
|
|
||||||
|
private def materialAndPositional(context: GameContext, phase: Int): Int =
|
||||||
|
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (square, piece)) =>
|
||||||
|
val (psqMg, psqEg) = squareBonus(piece.pieceType, piece.color, square)
|
||||||
|
val pieceMg = mgMaterial(piece.pieceType.ordinal) + psqMg
|
||||||
|
val pieceEg = egMaterial(piece.pieceType.ordinal) + psqEg
|
||||||
|
val sign = if piece.color == context.turn then 1 else -1
|
||||||
|
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||||
|
}
|
||||||
|
taper(mg, eg, phase)
|
||||||
|
|
||||||
|
private def squareBonus(pieceType: PieceType, color: Color, sq: Square): (Int, Int) =
|
||||||
|
val rankIdx = if color == Color.White then sq.rank.ordinal else 7 - sq.rank.ordinal
|
||||||
|
val fileIdx = sq.file.ordinal
|
||||||
|
val squareIdx = rankIdx * 8 + fileIdx
|
||||||
|
|
||||||
|
pieceType match
|
||||||
|
case PieceType.Pawn => (mgPawnTable(squareIdx), egPawnTable(squareIdx))
|
||||||
|
case PieceType.Knight => (mgKnightTable(squareIdx), egKnightTable(squareIdx))
|
||||||
|
case PieceType.Bishop => (mgBishopTable(squareIdx), egBishopTable(squareIdx))
|
||||||
|
case PieceType.Rook => (mgRookTable(squareIdx), egRookTable(squareIdx))
|
||||||
|
case PieceType.Queen => (mgQueenTable(squareIdx), egQueenTable(squareIdx))
|
||||||
|
case PieceType.King => (mgKingTable(squareIdx), egKingTable(squareIdx))
|
||||||
|
|
||||||
|
private def pawnStructure(context: GameContext, phase: Int): Int =
|
||||||
|
val friendlyPawns = context.board.pieces.filter((_, p) => p.color == context.turn && p.pieceType == PieceType.Pawn)
|
||||||
|
val enemyPawns = context.board.pieces.filter((_, p) => p.color != context.turn && p.pieceType == PieceType.Pawn)
|
||||||
|
|
||||||
|
val friendlyByFile = friendlyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||||
|
val enemyByFile = enemyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
|
||||||
|
|
||||||
|
val (fMg, fEg) = structureScore(friendlyByFile)
|
||||||
|
val (eMg, eEg) = structureScore(enemyByFile)
|
||||||
|
taper(fMg - eMg, fEg - eEg, phase)
|
||||||
|
|
||||||
|
private def structureScore(byFile: Map[Int, Iterable[Int]]): (Int, Int) =
|
||||||
|
byFile.foldLeft((0, 0)) { case ((mg, eg), (file, ranks)) =>
|
||||||
|
val doubled = (ranks.size - 1).max(0)
|
||||||
|
val hasAdjacent = (file - 1 to file + 1).filter(f => f >= 0 && f < 8 && f != file).exists(byFile.contains)
|
||||||
|
val isolated = if !hasAdjacent then ranks.size else 0
|
||||||
|
(mg + doubled * doubledMg + isolated * isolatedMg, eg + doubled * doubledEg + isolated * isolatedEg)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def positionalBonuses(context: GameContext, phase: Int, isEg: Boolean): Int =
|
||||||
|
context.board.pieces.foldLeft(0) { case (score, (sq, piece)) =>
|
||||||
|
val bonus = piece.pieceType match
|
||||||
|
case PieceType.Pawn =>
|
||||||
|
if isPassedPawn(context.board, sq, piece.color) then
|
||||||
|
if isEg then egPassedPawnBonus(sq.rank.ordinal) else passedPawnBonus(sq.rank.ordinal)
|
||||||
|
else 0
|
||||||
|
case PieceType.Rook => rookOpenFileBonus(context.board, sq, piece.color)
|
||||||
|
case PieceType.King => kingShieldBonus(context.board, sq, piece.color, phase)
|
||||||
|
case _ => 0
|
||||||
|
if piece.color == context.turn then score + bonus else score - bonus
|
||||||
|
}
|
||||||
|
|
||||||
|
private def isPassedPawn(board: de.nowchess.api.board.Board, sq: Square, color: Color): Boolean =
|
||||||
|
val enemyColor = color.opposite
|
||||||
|
val pawnRank = sq.rank.ordinal
|
||||||
|
val fileRange = (sq.file.ordinal - 1 to sq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||||
|
val rankCheck = if color == Color.White then (r: Int) => r > pawnRank else (r: Int) => r < pawnRank
|
||||||
|
|
||||||
|
board.pieces.forall { (enemySq, enemyPiece) =>
|
||||||
|
!(enemyPiece.color == enemyColor &&
|
||||||
|
enemyPiece.pieceType == PieceType.Pawn &&
|
||||||
|
fileRange.contains(enemySq.file.ordinal) &&
|
||||||
|
rankCheck(enemySq.rank.ordinal))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def rookOpenFileBonus(board: de.nowchess.api.board.Board, rookSq: Square, color: Color): Int =
|
||||||
|
val hasFriendlyPawn = board.pieces.exists { (sq, piece) =>
|
||||||
|
piece.color == color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||||
|
}
|
||||||
|
val hasEnemyPawn = board.pieces.exists { (sq, piece) =>
|
||||||
|
piece.color != color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
|
||||||
|
}
|
||||||
|
if !hasFriendlyPawn && !hasEnemyPawn then 20 // open file
|
||||||
|
else if !hasFriendlyPawn then 10 // semi-open file
|
||||||
|
else 0
|
||||||
|
|
||||||
|
private def kingShieldBonus(board: de.nowchess.api.board.Board, kingSq: Square, color: Color, phase: Int): Int =
|
||||||
|
val shieldRankDelta = if color == Color.White then 1 else -1
|
||||||
|
val shieldFiles = (kingSq.file.ordinal - 1 to kingSq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
|
||||||
|
val shieldRank = kingSq.rank.ordinal + shieldRankDelta
|
||||||
|
|
||||||
|
if shieldRank < 0 || shieldRank > 7 then 0
|
||||||
|
else
|
||||||
|
val rawBonus = board.pieces.count { (sq, piece) =>
|
||||||
|
piece.color == color &&
|
||||||
|
piece.pieceType == PieceType.Pawn &&
|
||||||
|
shieldFiles.contains(sq.file.ordinal) &&
|
||||||
|
sq.rank.ordinal == shieldRank
|
||||||
|
} * 10
|
||||||
|
(rawBonus * phase) / maxPhase
|
||||||
|
|
||||||
|
private def slidingCount(
|
||||||
|
sq: Square,
|
||||||
|
board: de.nowchess.api.board.Board,
|
||||||
|
color: Color,
|
||||||
|
directions: List[(Int, Int)],
|
||||||
|
): Int =
|
||||||
|
directions.foldLeft(0) { case (total, (fileDelta, rankDelta)) =>
|
||||||
|
@scala.annotation.tailrec
|
||||||
|
def countRay(current: Option[Square], acc: Int): Int =
|
||||||
|
current match
|
||||||
|
case None => acc
|
||||||
|
case Some(target) =>
|
||||||
|
board.pieceAt(target) match
|
||||||
|
case Some(piece) if piece.color == color => acc
|
||||||
|
case Some(_) => acc + 1
|
||||||
|
case None => countRay(target.offset(fileDelta, rankDelta), acc + 1)
|
||||||
|
total + countRay(sq.offset(fileDelta, rankDelta), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def knightCount(sq: Square, board: de.nowchess.api.board.Board, color: Color): Int =
|
||||||
|
knightOffsets.count { case (fileDelta, rankDelta) =>
|
||||||
|
sq.offset(fileDelta, rankDelta).forall { target =>
|
||||||
|
board.pieceAt(target).forall(_.color != color)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def mobilityScore(context: GameContext, phase: Int): Int =
|
||||||
|
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||||
|
val count = piece.pieceType match
|
||||||
|
case PieceType.Knight => knightCount(sq, context.board, piece.color)
|
||||||
|
case PieceType.Bishop => slidingCount(sq, context.board, piece.color, diagonals)
|
||||||
|
case PieceType.Rook => slidingCount(sq, context.board, piece.color, orthogonals)
|
||||||
|
case PieceType.Queen => slidingCount(sq, context.board, piece.color, diagonals ++ orthogonals)
|
||||||
|
case _ => 0
|
||||||
|
val pieceMg = count * mobilityMg(piece.pieceType.ordinal)
|
||||||
|
val pieceEg = count * mobilityEg(piece.pieceType.ordinal)
|
||||||
|
val sign = if piece.color == context.turn then 1 else -1
|
||||||
|
(mg + sign * pieceMg, eg + sign * pieceEg)
|
||||||
|
}
|
||||||
|
taper(mg, eg, phase)
|
||||||
|
|
||||||
|
private def rookAndBishopBonuses(context: GameContext, phase: Int): Int =
|
||||||
|
val (baseMg, baseEg) = bishopPairBase(context)
|
||||||
|
val (rookMg, rookEg) = rookOn7thDelta(context)
|
||||||
|
taper(baseMg + rookMg, baseEg + rookEg, phase)
|
||||||
|
|
||||||
|
private def bishopPairBase(context: GameContext): (Int, Int) =
|
||||||
|
val friendlyHasPair = hasBishopPair(context, context.turn)
|
||||||
|
val enemyHasPair = hasBishopPair(context, context.turn.opposite)
|
||||||
|
val mg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairMg)
|
||||||
|
val eg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairEg)
|
||||||
|
(mg, eg)
|
||||||
|
|
||||||
|
private def hasBishopPair(context: GameContext, color: Color): Boolean =
|
||||||
|
val bishopSquares = context.board.pieces.collect {
|
||||||
|
case (sq, piece) if piece.color == color && piece.pieceType == PieceType.Bishop => sq
|
||||||
|
}
|
||||||
|
bishopSquares.exists(isEvenSquare) && bishopSquares.exists(sq => !isEvenSquare(sq))
|
||||||
|
|
||||||
|
private def isEvenSquare(square: Square): Boolean =
|
||||||
|
(square.file.ordinal + square.rank.ordinal) % 2 == 0
|
||||||
|
|
||||||
|
private def pairDelta(friendlyHasPair: Boolean, enemyHasPair: Boolean, bonus: Int): Int =
|
||||||
|
(if friendlyHasPair then bonus else 0) - (if enemyHasPair then bonus else 0)
|
||||||
|
|
||||||
|
private def rookOn7thDelta(context: GameContext): (Int, Int) =
|
||||||
|
context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
|
||||||
|
rookOn7thContribution(piece, sq, context.turn).fold((mg, eg)) { case (dMg, dEg) =>
|
||||||
|
(mg + dMg, eg + dEg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private def rookOn7thContribution(piece: de.nowchess.api.board.Piece, sq: Square, turn: Color): Option[(Int, Int)] =
|
||||||
|
Option.when(piece.pieceType == PieceType.Rook && isRookOn7th(piece.color, sq)) {
|
||||||
|
val sign = if piece.color == turn then 1 else -1
|
||||||
|
(sign * rookOn7thMg, sign * rookOn7thEg)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def isRookOn7th(color: Color, sq: Square): Boolean =
|
||||||
|
if color == Color.White then sq.rank.ordinal == 6 else sq.rank.ordinal == 1
|
||||||
|
|
||||||
|
private def endgameBonus(context: GameContext): Int =
|
||||||
|
val friendlyKing = context.board.pieces.find((_, p) => p.color == context.turn && p.pieceType == PieceType.King)
|
||||||
|
val enemyKing = context.board.pieces.find((_, p) => p.color != context.turn && p.pieceType == PieceType.King)
|
||||||
|
|
||||||
|
val kingCentralBonus =
|
||||||
|
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
|
||||||
|
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
|
||||||
|
|
||||||
|
val friendlyMaterial = materialCount(context, context.turn)
|
||||||
|
val enemyMaterial = materialCount(context, context.turn.opposite)
|
||||||
|
val edgeBonus =
|
||||||
|
if friendlyMaterial > enemyMaterial then enemyKing.fold(0)((kSq, _) => (7 - kingEdgeDistance(kSq)) * 10)
|
||||||
|
else 0
|
||||||
|
|
||||||
|
kingCentralBonus + edgeBonus
|
||||||
|
|
||||||
|
private def kingCentralizationDistance(sq: Square): Int =
|
||||||
|
val fileFromCenter = (sq.file.ordinal - 3.5).abs.toInt
|
||||||
|
val rankFromCenter = (sq.rank.ordinal - 3.5).abs.toInt
|
||||||
|
math.max(fileFromCenter, rankFromCenter)
|
||||||
|
|
||||||
|
private def kingEdgeDistance(sq: Square): Int =
|
||||||
|
val fileFromEdge = math.min(sq.file.ordinal, 7 - sq.file.ordinal)
|
||||||
|
val rankFromEdge = math.min(sq.rank.ordinal, 7 - sq.rank.ordinal)
|
||||||
|
math.min(fileFromEdge, rankFromEdge)
|
||||||
|
|
||||||
|
private def materialCount(context: GameContext, color: Color): Int =
|
||||||
|
context.board.pieces.foldLeft(0) { case (sum, (_, piece)) =>
|
||||||
|
if piece.color == color then
|
||||||
|
sum + (piece.pieceType match
|
||||||
|
case PieceType.Knight => 300
|
||||||
|
case PieceType.Bishop => 300
|
||||||
|
case PieceType.Rook => 500
|
||||||
|
case PieceType.Queen => 900
|
||||||
|
case PieceType.Pawn => 0
|
||||||
|
case PieceType.King => 0
|
||||||
|
)
|
||||||
|
else sum
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
import de.nowchess.bot.ai.Evaluation
|
||||||
|
|
||||||
|
object EvaluationNNUE extends Evaluation:
|
||||||
|
|
||||||
|
private val nnue = NNUE(NbaiLoader.loadDefault())
|
||||||
|
|
||||||
|
val CHECKMATE_SCORE: Int = 10_000_000
|
||||||
|
val DRAW_SCORE: Int = 0
|
||||||
|
|
||||||
|
/** Full-board evaluate — used as fallback and by non-search callers. */
|
||||||
|
def evaluate(context: GameContext): Int = nnue.evaluate(context)
|
||||||
|
|
||||||
|
// ── Accumulator hooks (incremental L1) ───────────────────────────────────
|
||||||
|
|
||||||
|
override def initAccumulator(context: GameContext): Unit =
|
||||||
|
nnue.initAccumulator(context.board)
|
||||||
|
|
||||||
|
override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||||
|
nnue.copyAccumulator(parentPly, childPly)
|
||||||
|
|
||||||
|
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
|
||||||
|
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
|
||||||
|
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
|
||||||
|
else nnue.pushAccumulator(childPly, move, parent.board)
|
||||||
|
|
||||||
|
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
|
||||||
|
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||||
|
|
||||||
|
class NNUE(model: NbaiModel):
|
||||||
|
|
||||||
|
private val featureSize = model.layers(0).inputSize
|
||||||
|
private val accSize = model.layers(0).outputSize
|
||||||
|
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
|
||||||
|
|
||||||
|
// Column-major L1 weights for cache-friendly sparse & incremental updates.
|
||||||
|
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
|
||||||
|
private val l1WeightsT: Array[Float] =
|
||||||
|
val w = model.weights(0).weights
|
||||||
|
val t = new Array[Float](featureSize * accSize)
|
||||||
|
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
|
||||||
|
t
|
||||||
|
|
||||||
|
// ── Accumulator stack ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
private val MAX_PLY = 128
|
||||||
|
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
|
||||||
|
|
||||||
|
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
|
||||||
|
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
|
||||||
|
|
||||||
|
// ── Eval cache ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
private val EVAL_CACHE_MASK = (1 << 18) - 1L
|
||||||
|
private val evalCacheHashes = new Array[Long](1 << 18)
|
||||||
|
private val evalCacheScores = new Array[Int](1 << 18)
|
||||||
|
|
||||||
|
// ── Feature helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
|
||||||
|
|
||||||
|
private def featureIndex(piece: Piece, sqNum: Int): Int =
|
||||||
|
val colorOffset = if piece.color == Color.White then 6 else 0
|
||||||
|
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
|
||||||
|
|
||||||
|
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||||
|
val offset = featureIdx * accSize
|
||||||
|
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
|
||||||
|
|
||||||
|
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
|
||||||
|
val offset = featureIdx * accSize
|
||||||
|
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
|
||||||
|
|
||||||
|
// ── Accumulator init ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def initAccumulator(board: Board): Unit =
|
||||||
|
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
|
||||||
|
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
|
||||||
|
|
||||||
|
// ── Accumulator push (incremental updates) ───────────────────────────────
|
||||||
|
|
||||||
|
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
|
||||||
|
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
|
||||||
|
val l1 = l1Stack(childPly)
|
||||||
|
move.moveType match
|
||||||
|
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
|
||||||
|
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
|
||||||
|
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
|
||||||
|
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
|
||||||
|
|
||||||
|
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
|
||||||
|
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
|
||||||
|
|
||||||
|
def recomputeAccumulator(ply: Int, board: Board): Unit =
|
||||||
|
System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
|
||||||
|
for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
|
||||||
|
|
||||||
|
def validateAccumulator(ply: Int, board: Board): Boolean =
|
||||||
|
// Compute what L1 should be from scratch
|
||||||
|
val expectedL1 = new Array[Float](accSize)
|
||||||
|
System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
|
||||||
|
for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
|
||||||
|
|
||||||
|
// Compare with actual L1
|
||||||
|
val actual = l1Stack(ply)
|
||||||
|
val maxError =
|
||||||
|
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
|
||||||
|
val error = math.abs(actual(i) - expectedL1(i))
|
||||||
|
math.max(currentMax, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxError < 0.001f // Allow small floating-point errors
|
||||||
|
|
||||||
|
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||||
|
// Extract source and destination square indices early
|
||||||
|
val fromNum = squareNum(move.from)
|
||||||
|
val toNum = squareNum(move.to)
|
||||||
|
|
||||||
|
// Get the moving piece
|
||||||
|
board.pieceAt(move.from).foreach { mover =>
|
||||||
|
subtractColumn(l1, featureIndex(mover, fromNum))
|
||||||
|
|
||||||
|
// If there's a capture, subtract the captured piece
|
||||||
|
board.pieceAt(move.to).foreach { cap =>
|
||||||
|
subtractColumn(l1, featureIndex(cap, toNum))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the piece to its new location
|
||||||
|
addColumn(l1, featureIndex(mover, toNum))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||||
|
board.pieceAt(move.from).foreach { pawn =>
|
||||||
|
val capturedSq = Square(move.to.file, move.from.rank)
|
||||||
|
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||||
|
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
|
||||||
|
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
|
||||||
|
board.pieceAt(move.from).foreach { king =>
|
||||||
|
val rank = move.from.rank
|
||||||
|
val kingside = move.moveType == MoveType.CastleKingside
|
||||||
|
val (rookFrom, rookTo) =
|
||||||
|
if kingside then (Square(File.H, rank), Square(File.F, rank))
|
||||||
|
else (Square(File.A, rank), Square(File.D, rank))
|
||||||
|
val rook = Piece(king.color, PieceType.Rook)
|
||||||
|
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
|
||||||
|
addColumn(l1, featureIndex(king, squareNum(move.to)))
|
||||||
|
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
|
||||||
|
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
|
||||||
|
board.pieceAt(move.from).foreach { pawn =>
|
||||||
|
val toNum = squareNum(move.to)
|
||||||
|
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
|
||||||
|
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
|
||||||
|
addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def promotedType(promo: PromotionPiece): PieceType = promo match
|
||||||
|
case PromotionPiece.Knight => PieceType.Knight
|
||||||
|
case PromotionPiece.Bishop => PieceType.Bishop
|
||||||
|
case PromotionPiece.Rook => PieceType.Rook
|
||||||
|
case PromotionPiece.Queen => PieceType.Queen
|
||||||
|
|
||||||
|
// ── Evaluation from accumulator ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
|
||||||
|
val idx = (hash & EVAL_CACHE_MASK).toInt
|
||||||
|
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
|
||||||
|
else
|
||||||
|
val score = runL2toOutput(l1Stack(ply), turn)
|
||||||
|
evalCacheHashes(idx) = hash
|
||||||
|
evalCacheScores(idx) = score
|
||||||
|
score
|
||||||
|
|
||||||
|
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
|
||||||
|
// For debugging: validate that incremental accumulator matches recomputation
|
||||||
|
if validateAccum && ply > 0 && ply % 10 != 0 then
|
||||||
|
val isValid = validateAccumulator(ply, board)
|
||||||
|
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
|
||||||
|
evaluateAtPly(ply, turn, hash)
|
||||||
|
|
||||||
|
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
|
||||||
|
val l1ReLU = evalBuffers(0)
|
||||||
|
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
|
||||||
|
|
||||||
|
val finalInput =
|
||||||
|
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
|
||||||
|
val lw = model.weights(i)
|
||||||
|
val out = evalBuffers(i)
|
||||||
|
val ld = model.layers(i)
|
||||||
|
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
val lastIdx = model.layers.length - 1
|
||||||
|
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
|
||||||
|
scoreFromOutput(output, turn)
|
||||||
|
|
||||||
|
private def runDenseReLU(
|
||||||
|
input: Array[Float],
|
||||||
|
inSize: Int,
|
||||||
|
weights: Array[Float],
|
||||||
|
bias: Array[Float],
|
||||||
|
output: Array[Float],
|
||||||
|
outSize: Int,
|
||||||
|
): Unit =
|
||||||
|
for i <- 0 until outSize do
|
||||||
|
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
|
||||||
|
output(i) = if sum > 0f then sum else 0f
|
||||||
|
|
||||||
|
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
|
||||||
|
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
|
||||||
|
|
||||||
|
private def scoreFromOutput(output: Float, turn: Color): Int =
|
||||||
|
val cp =
|
||||||
|
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
|
||||||
|
else
|
||||||
|
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
|
||||||
|
(300f * atanh).toInt
|
||||||
|
val cpFromTurn = if turn == Color.Black then -cp else cp
|
||||||
|
math.max(-20000, math.min(20000, cpFromTurn))
|
||||||
|
|
||||||
|
// ── Legacy full-board evaluate ────────────────────────────────────────────
|
||||||
|
|
||||||
|
private val legacyL1 = new Array[Float](accSize)
|
||||||
|
|
||||||
|
def evaluate(context: GameContext): Int =
|
||||||
|
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
|
||||||
|
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
|
||||||
|
runL2toOutput(legacyL1, context.turn)
|
||||||
|
|
||||||
|
def benchmark(): Unit =
|
||||||
|
val context = GameContext.initial
|
||||||
|
val iterations = 1_000_000
|
||||||
|
for _ <- 0 until 10000 do evaluate(context)
|
||||||
|
val startNanos = System.nanoTime()
|
||||||
|
for _ <- 0 until iterations do evaluate(context)
|
||||||
|
val endNanos = System.nanoTime()
|
||||||
|
val totalNanos = endNanos - startNanos
|
||||||
|
val nanosPerEval = totalNanos.toDouble / iterations
|
||||||
|
println()
|
||||||
|
println("=" * 60)
|
||||||
|
println("NNUE BENCHMARK RESULTS")
|
||||||
|
println("=" * 60)
|
||||||
|
println(f"Iterations: $iterations%,d")
|
||||||
|
println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
|
||||||
|
println(f"ns/eval: $nanosPerEval%.2f ns")
|
||||||
|
println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
|
||||||
|
println("=" * 60)
|
||||||
|
println()
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
import java.io.InputStream
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
|
object NbaiLoader:
|
||||||
|
|
||||||
|
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
|
||||||
|
val MAGIC: Int = 0x4942_414e
|
||||||
|
|
||||||
|
def load(stream: InputStream): NbaiModel =
|
||||||
|
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
checkHeader(buf)
|
||||||
|
val metadata = readMetadata(buf)
|
||||||
|
val descs = readLayerDescriptors(buf)
|
||||||
|
val weights = descs.map(_ => readLayerWeights(buf))
|
||||||
|
NbaiModel(metadata, descs, weights)
|
||||||
|
|
||||||
|
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
|
||||||
|
def loadDefault(): NbaiModel =
|
||||||
|
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
|
||||||
|
case Some(s) =>
|
||||||
|
try load(s)
|
||||||
|
finally s.close()
|
||||||
|
case None => NbaiMigrator.migrateFromBin()
|
||||||
|
|
||||||
|
private def checkHeader(buf: ByteBuffer): Unit =
|
||||||
|
val magic = buf.getInt()
|
||||||
|
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
|
||||||
|
val version = buf.getShort() & 0xffff
|
||||||
|
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
|
||||||
|
|
||||||
|
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
|
||||||
|
val bytes = new Array[Byte](buf.getInt())
|
||||||
|
buf.get(bytes)
|
||||||
|
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
|
||||||
|
|
||||||
|
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
|
||||||
|
Array.tabulate(buf.getShort() & 0xffff) { _ =>
|
||||||
|
val nameBytes = new Array[Byte](buf.get() & 0xff)
|
||||||
|
buf.get(nameBytes)
|
||||||
|
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
|
||||||
|
}
|
||||||
|
|
||||||
|
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||||
|
LayerWeights(readFloats(buf), readFloats(buf))
|
||||||
|
|
||||||
|
private def readFloats(buf: ByteBuffer): Array[Float] =
|
||||||
|
val arr = new Array[Float](buf.getInt())
|
||||||
|
for i <- arr.indices do arr(i) = buf.getFloat()
|
||||||
|
arr
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
|
||||||
|
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
|
||||||
|
object NbaiMigrator:
|
||||||
|
|
||||||
|
private val BinMagic = 0x4555_4e4e
|
||||||
|
private val BinVersion = 1
|
||||||
|
|
||||||
|
private val DefaultLayers: Array[LayerDescriptor] = Array(
|
||||||
|
LayerDescriptor("relu", 768, 1536),
|
||||||
|
LayerDescriptor("relu", 1536, 1024),
|
||||||
|
LayerDescriptor("relu", 1024, 512),
|
||||||
|
LayerDescriptor("relu", 512, 256),
|
||||||
|
LayerDescriptor("linear", 256, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
private val UnknownMetadata: NbaiMetadata =
|
||||||
|
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
|
||||||
|
|
||||||
|
def migrateFromBin(): NbaiModel =
|
||||||
|
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
|
||||||
|
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
|
||||||
|
try
|
||||||
|
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
checkBinHeader(buf)
|
||||||
|
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
|
||||||
|
NbaiModel(UnknownMetadata, DefaultLayers, weights)
|
||||||
|
finally stream.close()
|
||||||
|
|
||||||
|
private def checkBinHeader(buf: ByteBuffer): Unit =
|
||||||
|
val magic = buf.getInt()
|
||||||
|
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
|
||||||
|
val version = buf.getInt()
|
||||||
|
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
|
||||||
|
|
||||||
|
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
|
||||||
|
LayerWeights(readBinTensor(buf), readBinTensor(buf))
|
||||||
|
|
||||||
|
private def readBinTensor(buf: ByteBuffer): Array[Float] =
|
||||||
|
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
|
||||||
|
Array.tabulate(shape.product)(_ => buf.getFloat())
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
/** Descriptor for a single dense layer stored in a .nbai file. */
|
||||||
|
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
|
||||||
|
|
||||||
|
/** Training metadata embedded in every .nbai file. */
|
||||||
|
case class NbaiMetadata(
|
||||||
|
trainedBy: String,
|
||||||
|
trainedAt: String,
|
||||||
|
trainingDataCount: Long,
|
||||||
|
valLoss: Double,
|
||||||
|
trainLoss: Double,
|
||||||
|
):
|
||||||
|
def toJson: String =
|
||||||
|
s"""{
|
||||||
|
| "trainedBy": "$trainedBy",
|
||||||
|
| "trainedAt": "$trainedAt",
|
||||||
|
| "trainingDataCount": $trainingDataCount,
|
||||||
|
| "valLoss": $valLoss,
|
||||||
|
| "trainLoss": $trainLoss
|
||||||
|
|}""".stripMargin
|
||||||
|
|
||||||
|
object NbaiMetadata:
|
||||||
|
def fromJson(json: String): NbaiMetadata =
|
||||||
|
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
|
||||||
|
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
|
||||||
|
NbaiMetadata(
|
||||||
|
str("trainedBy"),
|
||||||
|
str("trainedAt"),
|
||||||
|
num("trainingDataCount").toLong,
|
||||||
|
num("valLoss").toDouble,
|
||||||
|
num("trainLoss").toDouble,
|
||||||
|
)
|
||||||
|
|
||||||
|
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
|
||||||
|
case class LayerWeights(weights: Array[Float], bias: Array[Float])
|
||||||
|
|
||||||
|
/** A fully deserialized .nbai model ready to initialize NNUE. */
|
||||||
|
case class NbaiModel(
|
||||||
|
metadata: NbaiMetadata,
|
||||||
|
layers: Array[LayerDescriptor],
|
||||||
|
weights: Array[LayerWeights],
|
||||||
|
):
|
||||||
|
require(layers.length == weights.length, "Layer count must match weight count")
|
||||||
|
require(layers.length >= 2, "Model must have at least 2 layers")
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package de.nowchess.bot.bots.nnue
|
||||||
|
|
||||||
|
import java.io.{ByteArrayOutputStream, OutputStream}
|
||||||
|
import java.nio.{ByteBuffer, ByteOrder}
|
||||||
|
import java.nio.charset.StandardCharsets
|
||||||
|
|
||||||
|
object NbaiWriter:
|
||||||
|
|
||||||
|
def write(model: NbaiModel, out: OutputStream): Unit =
|
||||||
|
val acc = new ByteArrayOutputStream()
|
||||||
|
writeHeader(acc)
|
||||||
|
writeMetadata(acc, model.metadata)
|
||||||
|
writeLayerDescriptors(acc, model.layers)
|
||||||
|
model.weights.foreach(lw => writeLayerWeights(acc, lw))
|
||||||
|
out.write(acc.toByteArray)
|
||||||
|
|
||||||
|
private def writeHeader(out: ByteArrayOutputStream): Unit =
|
||||||
|
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
buf.putInt(NbaiLoader.MAGIC)
|
||||||
|
buf.putShort(1.toShort)
|
||||||
|
out.write(buf.array())
|
||||||
|
|
||||||
|
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
|
||||||
|
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
|
||||||
|
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
buf.putInt(json.length)
|
||||||
|
buf.put(json)
|
||||||
|
out.write(buf.array())
|
||||||
|
|
||||||
|
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
|
||||||
|
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
|
||||||
|
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
|
||||||
|
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
buf.putShort(layers.length.toShort)
|
||||||
|
layers.zip(nameBytes).foreach { (l, nb) =>
|
||||||
|
buf.put(nb.length.toByte)
|
||||||
|
buf.put(nb)
|
||||||
|
buf.putInt(l.inputSize)
|
||||||
|
buf.putInt(l.outputSize)
|
||||||
|
}
|
||||||
|
out.write(buf.array())
|
||||||
|
|
||||||
|
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
|
||||||
|
writeFloats(out, lw.weights)
|
||||||
|
writeFloats(out, lw.bias)
|
||||||
|
|
||||||
|
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
|
||||||
|
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
|
||||||
|
buf.putInt(floats.length)
|
||||||
|
floats.foreach(buf.putFloat)
|
||||||
|
out.write(buf.array())
|
||||||
@@ -0,0 +1,419 @@
|
|||||||
|
package de.nowchess.bot.logic
|
||||||
|
|
||||||
|
import de.nowchess.api.board.PieceType
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.{Move, MoveType}
|
||||||
|
import de.nowchess.bot.ai.Evaluation
|
||||||
|
import de.nowchess.bot.util.ZobristHash
|
||||||
|
import de.nowchess.rules.RuleSet
|
||||||
|
import de.nowchess.rules.sets.DefaultRules
|
||||||
|
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
|
||||||
|
|
||||||
|
final class AlphaBetaSearch(
|
||||||
|
rules: RuleSet = DefaultRules,
|
||||||
|
tt: TranspositionTable = TranspositionTable(),
|
||||||
|
weights: Evaluation,
|
||||||
|
numThreads: Int = Runtime.getRuntime.availableProcessors,
|
||||||
|
):
|
||||||
|
|
||||||
|
private val INF = Int.MaxValue / 2
|
||||||
|
private val MAX_QUIESCENCE_PLY = 64
|
||||||
|
private val NULL_MOVE_R = 2
|
||||||
|
private val ASPIRATION_DELTA = 50
|
||||||
|
private val ASPIRATION_DELTA_MAX = 150
|
||||||
|
private val TIME_CHECK_FREQUENCY = 1000
|
||||||
|
private val FUTILITY_MARGIN = 100
|
||||||
|
private val CHECK_EXTENSION = 1
|
||||||
|
|
||||||
|
private val timeStartMs = AtomicLong(0L)
|
||||||
|
private val timeLimitMs = AtomicLong(0L)
|
||||||
|
private val nodeCount = AtomicInteger(0)
|
||||||
|
private val ordering = MoveOrdering.OrderingContext()
|
||||||
|
|
||||||
|
private final case class QuiescenceNode(
|
||||||
|
context: GameContext,
|
||||||
|
ply: Int,
|
||||||
|
alpha: Int,
|
||||||
|
beta: Int,
|
||||||
|
hash: Long,
|
||||||
|
)
|
||||||
|
|
||||||
|
/** Return the best move for the side to move, searching to maxDepth plies. Uses iterative deepening with aspiration
|
||||||
|
* windows.
|
||||||
|
*/
|
||||||
|
def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
|
||||||
|
bestMove(context, maxDepth, Set.empty)
|
||||||
|
|
||||||
|
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] =
|
||||||
|
tt.clear()
|
||||||
|
ordering.clear()
|
||||||
|
weights.initAccumulator(context)
|
||||||
|
timeStartMs.set(System.currentTimeMillis)
|
||||||
|
timeLimitMs.set(Long.MaxValue / 4)
|
||||||
|
nodeCount.set(0)
|
||||||
|
val rootHash = ZobristHash.hash(context)
|
||||||
|
(1 to maxDepth)
|
||||||
|
.foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
|
||||||
|
val (alpha, beta) =
|
||||||
|
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||||
|
val (score, move) = searchWithAspiration(
|
||||||
|
context,
|
||||||
|
depth,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
ASPIRATION_DELTA,
|
||||||
|
rootHash,
|
||||||
|
excludedRootMoves,
|
||||||
|
)
|
||||||
|
(move.orElse(bestSoFar), score)
|
||||||
|
}
|
||||||
|
._1
|
||||||
|
|
||||||
|
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
|
||||||
|
* runs out.
|
||||||
|
*/
|
||||||
|
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
|
||||||
|
bestMoveWithTime(context, timeBudgetMs, Set.empty)
|
||||||
|
|
||||||
|
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] =
|
||||||
|
tt.clear()
|
||||||
|
ordering.clear()
|
||||||
|
weights.initAccumulator(context)
|
||||||
|
timeStartMs.set(System.currentTimeMillis)
|
||||||
|
timeLimitMs.set(timeBudgetMs)
|
||||||
|
nodeCount.set(0)
|
||||||
|
val rootHash = ZobristHash.hash(context)
|
||||||
|
|
||||||
|
@scala.annotation.tailrec
|
||||||
|
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
|
||||||
|
if isOutOfTime then bestSoFar
|
||||||
|
else
|
||||||
|
val (alpha, beta) =
|
||||||
|
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
|
||||||
|
val (score, move) = searchWithAspiration(
|
||||||
|
context,
|
||||||
|
depth,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
ASPIRATION_DELTA,
|
||||||
|
rootHash,
|
||||||
|
excludedRootMoves,
|
||||||
|
)
|
||||||
|
loop(move.orElse(bestSoFar), score, depth + 1)
|
||||||
|
|
||||||
|
loop(None, 0, 1)
|
||||||
|
|
||||||
|
private def isOutOfTime: Boolean =
|
||||||
|
System.currentTimeMillis - timeStartMs.get >= timeLimitMs.get
|
||||||
|
|
||||||
|
private def searchWithAspiration(
|
||||||
|
context: GameContext,
|
||||||
|
depth: Int,
|
||||||
|
alpha: Int,
|
||||||
|
beta: Int,
|
||||||
|
initialWindow: Int,
|
||||||
|
rootHash: Long,
|
||||||
|
excludedRootMoves: Set[Move],
|
||||||
|
): (Int, Option[Move]) =
|
||||||
|
val state = SearchState(rootHash, Map(rootHash -> 1))
|
||||||
|
|
||||||
|
@scala.annotation.tailrec
|
||||||
|
def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) =
|
||||||
|
if attempt >= 3 || attempt >= depth then search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves)
|
||||||
|
else
|
||||||
|
val (score, move) = search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves)
|
||||||
|
if score > currentAlpha && score < currentBeta then (score, move)
|
||||||
|
else if score <= currentAlpha then
|
||||||
|
loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||||
|
else loop(currentAlpha, score + delta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
|
||||||
|
|
||||||
|
loop(alpha, beta, initialWindow, 0)
|
||||||
|
|
||||||
|
private def hasNonPawnMaterial(context: GameContext): Boolean =
|
||||||
|
context.board.pieces.values.exists { piece =>
|
||||||
|
piece.color == context.turn &&
|
||||||
|
piece.pieceType != PieceType.Pawn &&
|
||||||
|
piece.pieceType != PieceType.King
|
||||||
|
}
|
||||||
|
|
||||||
|
private def nullMoveContext(context: GameContext): GameContext =
|
||||||
|
context.withTurn(context.turn.opposite).withEnPassantSquare(None)
|
||||||
|
|
||||||
|
private def tryNullMove(
|
||||||
|
context: GameContext,
|
||||||
|
depth: Int,
|
||||||
|
ply: Int,
|
||||||
|
beta: Int,
|
||||||
|
state: SearchState,
|
||||||
|
excludedRootMoves: Set[Move],
|
||||||
|
): Option[Int] =
|
||||||
|
val nullCtx = nullMoveContext(context)
|
||||||
|
val nullState = state.advance(ZobristHash.hash(nullCtx))
|
||||||
|
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
|
||||||
|
weights.copyAccumulator(ply, ply + 1)
|
||||||
|
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
|
||||||
|
if -score >= beta then Some(beta) else None
|
||||||
|
|
||||||
|
/** Negamax alpha-beta search returning (score, best move). */
|
||||||
|
private def search(
|
||||||
|
context: GameContext,
|
||||||
|
depth: Int,
|
||||||
|
ply: Int,
|
||||||
|
window: Window,
|
||||||
|
state: SearchState,
|
||||||
|
excludedRootMoves: Set[Move],
|
||||||
|
): (Int, Option[Move]) =
|
||||||
|
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||||
|
searchNode(params)
|
||||||
|
|
||||||
|
private def searchNode(params: SearchParams): (Int, Option[Move]) =
|
||||||
|
val count = nodeCount.incrementAndGet()
|
||||||
|
immediateSearchResult(params, count).getOrElse {
|
||||||
|
val legalMoves = rules.allLegalMoves(params.context)
|
||||||
|
terminalSearchResult(params, legalMoves).getOrElse(searchDeeper(params, legalMoves))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def immediateSearchResult(
|
||||||
|
params: SearchParams,
|
||||||
|
count: Int,
|
||||||
|
): Option[(Int, Option[Move])] =
|
||||||
|
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
|
||||||
|
Some((weights.evaluateAccumulator(params.ply, params.context, params.state.hash), None))
|
||||||
|
else if params.state.repetitions.getOrElse(params.state.hash, 0) >= 3 then Some((weights.DRAW_SCORE, None))
|
||||||
|
else ttCutoff(params)
|
||||||
|
|
||||||
|
private def ttCutoff(params: SearchParams): Option[(Int, Option[Move])] =
|
||||||
|
tt.probe(params.state.hash).filter(_.depth >= params.depth).flatMap { entry =>
|
||||||
|
entry.flag match
|
||||||
|
case TTFlag.Exact => Some((entry.score, entry.bestMove))
|
||||||
|
case TTFlag.Lower =>
|
||||||
|
val newAlpha = math.max(params.window.alpha, entry.score)
|
||||||
|
Option.when(newAlpha >= params.window.beta)((entry.score, entry.bestMove))
|
||||||
|
case TTFlag.Upper =>
|
||||||
|
val newBeta = math.min(params.window.beta, entry.score)
|
||||||
|
Option.when(params.window.alpha >= newBeta)((entry.score, entry.bestMove))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def terminalSearchResult(
|
||||||
|
params: SearchParams,
|
||||||
|
legalMoves: List[Move],
|
||||||
|
): Option[(Int, Option[Move])] =
|
||||||
|
if legalMoves.isEmpty then
|
||||||
|
Some(
|
||||||
|
(
|
||||||
|
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
|
||||||
|
Some((weights.DRAW_SCORE, None))
|
||||||
|
else if params.depth == 0 then
|
||||||
|
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
|
||||||
|
else None
|
||||||
|
|
||||||
|
private def searchDeeper(
|
||||||
|
params: SearchParams,
|
||||||
|
legalMoves: List[Move],
|
||||||
|
): (Int, Option[Move]) =
|
||||||
|
val nullResult =
|
||||||
|
Option
|
||||||
|
.when(canTryNullMove(params))(
|
||||||
|
tryNullMove(
|
||||||
|
params.context,
|
||||||
|
params.depth,
|
||||||
|
params.ply,
|
||||||
|
params.window.beta,
|
||||||
|
params.state,
|
||||||
|
params.excludedRootMoves,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.flatten
|
||||||
|
|
||||||
|
nullResult.map((_, None)).getOrElse {
|
||||||
|
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
|
||||||
|
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering)
|
||||||
|
searchSequential(
|
||||||
|
params.context,
|
||||||
|
params.depth,
|
||||||
|
params.ply,
|
||||||
|
params.window,
|
||||||
|
ordered,
|
||||||
|
params.state,
|
||||||
|
params.excludedRootMoves,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def canTryNullMove(params: SearchParams): Boolean =
|
||||||
|
params.depth >= 3 &&
|
||||||
|
!rules.isCheck(params.context) &&
|
||||||
|
hasNonPawnMaterial(params.context)
|
||||||
|
|
||||||
|
private def isQuietMove(context: GameContext, move: Move): Boolean =
|
||||||
|
!isCapture(context, move) &&
|
||||||
|
move.moveType != MoveType.CastleKingside &&
|
||||||
|
move.moveType != MoveType.CastleQueenside
|
||||||
|
|
||||||
|
private def scoreMove(
|
||||||
|
child: GameContext,
|
||||||
|
childState: SearchState,
|
||||||
|
params: SearchParams,
|
||||||
|
extension: Int,
|
||||||
|
reduction: Int,
|
||||||
|
a: Int,
|
||||||
|
): Int =
|
||||||
|
val betaNeg = -params.window.beta
|
||||||
|
if reduction > 0 then
|
||||||
|
val (rs, _) = search(
|
||||||
|
child,
|
||||||
|
math.max(0, params.depth - 1 - reduction + extension),
|
||||||
|
params.ply + 1,
|
||||||
|
Window(-a - 1, -a),
|
||||||
|
childState,
|
||||||
|
params.excludedRootMoves,
|
||||||
|
)
|
||||||
|
val s = -rs
|
||||||
|
if s > a then
|
||||||
|
val (fs, _) = search(
|
||||||
|
child,
|
||||||
|
math.max(0, params.depth - 1 + extension),
|
||||||
|
params.ply + 1,
|
||||||
|
Window(betaNeg, -a),
|
||||||
|
childState,
|
||||||
|
params.excludedRootMoves,
|
||||||
|
)
|
||||||
|
-fs
|
||||||
|
else s
|
||||||
|
else
|
||||||
|
val (rs, _) = search(
|
||||||
|
child,
|
||||||
|
math.max(0, params.depth - 1 + extension),
|
||||||
|
params.ply + 1,
|
||||||
|
Window(betaNeg, -a),
|
||||||
|
childState,
|
||||||
|
params.excludedRootMoves,
|
||||||
|
)
|
||||||
|
-rs
|
||||||
|
|
||||||
|
private def evalSingleMove(
|
||||||
|
move: Move,
|
||||||
|
moveNumber: Int,
|
||||||
|
a: Int,
|
||||||
|
params: SearchParams,
|
||||||
|
): Option[(Int, Boolean)] =
|
||||||
|
val skipRoot = params.ply == 0 && params.excludedRootMoves.contains(move)
|
||||||
|
val isQuiet = isQuietMove(params.context, move)
|
||||||
|
val futility = params.depth == 1 && isQuiet && moveNumber > 2 &&
|
||||||
|
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
|
||||||
|
if skipRoot || futility then None
|
||||||
|
else
|
||||||
|
val child = rules.applyMove(params.context)(move)
|
||||||
|
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
|
||||||
|
weights.pushAccumulator(params.ply + 1, move, params.context, child)
|
||||||
|
val childState = params.state.advance(childHash)
|
||||||
|
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
|
||||||
|
val reduction = if moveNumber > 4 && params.depth >= 3 && isQuiet then 1 else 0
|
||||||
|
Some((scoreMove(child, childState, params, extension, reduction, a), isQuiet))
|
||||||
|
|
||||||
|
private def recordCutoff(move: Move, depth: Int, ply: Int): Unit =
|
||||||
|
ordering.addHistory(
|
||||||
|
move.from.rank.ordinal * 8 + move.from.file.ordinal,
|
||||||
|
move.to.rank.ordinal * 8 + move.to.file.ordinal,
|
||||||
|
depth * depth,
|
||||||
|
)
|
||||||
|
ordering.addKillerMove(ply, move)
|
||||||
|
|
||||||
|
@scala.annotation.tailrec
|
||||||
|
private def searchLoop(
|
||||||
|
idx: Int,
|
||||||
|
moveNumber: Int,
|
||||||
|
acc: LoopAcc,
|
||||||
|
params: SearchParams,
|
||||||
|
ordered: List[Move],
|
||||||
|
): (Option[Move], Int, Boolean) =
|
||||||
|
if idx >= ordered.length then (acc.bestMove, acc.bestScore, false)
|
||||||
|
else
|
||||||
|
val move = ordered(idx)
|
||||||
|
evalSingleMove(move, moveNumber, acc.a, params) match
|
||||||
|
case None => searchLoop(idx + 1, moveNumber + 1, acc, params, ordered)
|
||||||
|
case Some((score, isQuiet)) =>
|
||||||
|
val newAcc = LoopAcc(
|
||||||
|
if score > acc.bestScore then Some(move) else acc.bestMove,
|
||||||
|
math.max(acc.bestScore, score),
|
||||||
|
math.max(acc.a, score),
|
||||||
|
)
|
||||||
|
if newAcc.a >= params.window.beta then
|
||||||
|
if isQuiet then recordCutoff(move, params.depth, params.ply)
|
||||||
|
(newAcc.bestMove, newAcc.bestScore, true)
|
||||||
|
else searchLoop(idx + 1, moveNumber + 1, newAcc, params, ordered)
|
||||||
|
|
||||||
|
private def searchSequential(
|
||||||
|
context: GameContext,
|
||||||
|
depth: Int,
|
||||||
|
ply: Int,
|
||||||
|
window: Window,
|
||||||
|
ordered: List[Move],
|
||||||
|
state: SearchState,
|
||||||
|
excludedRootMoves: Set[Move],
|
||||||
|
): (Int, Option[Move]) =
|
||||||
|
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
|
||||||
|
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
|
||||||
|
val flag =
|
||||||
|
if cutoff then TTFlag.Lower
|
||||||
|
else if bestScore <= window.alpha then TTFlag.Upper
|
||||||
|
else TTFlag.Exact
|
||||||
|
tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove))
|
||||||
|
(bestScore, bestMove)
|
||||||
|
|
||||||
|
/** Quiescence search: only captures until position is quiet. */
|
||||||
|
private def quiescence(
|
||||||
|
context: GameContext,
|
||||||
|
ply: Int,
|
||||||
|
alpha: Int,
|
||||||
|
beta: Int,
|
||||||
|
hash: Long,
|
||||||
|
): Int =
|
||||||
|
quiescenceNode(QuiescenceNode(context, ply, alpha, beta, hash))
|
||||||
|
|
||||||
|
private def quiescenceNode(node: QuiescenceNode): Int =
|
||||||
|
val inCheck = rules.isCheck(node.context)
|
||||||
|
val standPat = if inCheck then -INF else weights.evaluateAccumulator(node.ply, node.context, node.hash)
|
||||||
|
|
||||||
|
if !inCheck && standPat >= node.beta then node.beta
|
||||||
|
else if node.ply >= MAX_QUIESCENCE_PLY then quiescenceAtDepthLimit(node, inCheck, standPat)
|
||||||
|
else
|
||||||
|
val moves = tacticalMoves(node.context, inCheck)
|
||||||
|
if inCheck && moves.isEmpty then -(weights.CHECKMATE_SCORE - node.ply)
|
||||||
|
else
|
||||||
|
val ordered = MoveOrdering.sort(node.context, moves, None)
|
||||||
|
val a0 = if inCheck then node.alpha else math.max(node.alpha, standPat)
|
||||||
|
quiescenceLoop(node, ordered, 0, a0)
|
||||||
|
|
||||||
|
private def quiescenceAtDepthLimit(node: QuiescenceNode, inCheck: Boolean, standPat: Int): Int =
|
||||||
|
if inCheck then weights.evaluateAccumulator(node.ply, node.context, node.hash) else standPat
|
||||||
|
|
||||||
|
private def tacticalMoves(context: GameContext, inCheck: Boolean): List[Move] =
|
||||||
|
val allMoves = rules.allLegalMoves(context)
|
||||||
|
if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
|
||||||
|
|
||||||
|
@scala.annotation.tailrec
|
||||||
|
private def quiescenceLoop(
|
||||||
|
node: QuiescenceNode,
|
||||||
|
ordered: List[Move],
|
||||||
|
idx: Int,
|
||||||
|
a: Int,
|
||||||
|
): Int =
|
||||||
|
if idx >= ordered.length then a
|
||||||
|
else
|
||||||
|
val move = ordered(idx)
|
||||||
|
val child = rules.applyMove(node.context)(move)
|
||||||
|
val childHash = ZobristHash.nextHash(node.context, node.hash, move, child)
|
||||||
|
weights.pushAccumulator(node.ply + 1, move, node.context, child)
|
||||||
|
val score = -quiescence(child, node.ply + 1, -node.beta, -a, childHash)
|
||||||
|
if score >= node.beta then node.beta
|
||||||
|
else quiescenceLoop(node, ordered, idx + 1, math.max(a, score))
|
||||||
|
|
||||||
|
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
|
||||||
|
case MoveType.Normal(true) => true
|
||||||
|
case MoveType.EnPassant => true
|
||||||
|
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
|
||||||
|
case _ => false
|
||||||
@@ -0,0 +1,177 @@
|
|||||||
|
package de.nowchess.bot.logic
|
||||||
|
|
||||||
|
import de.nowchess.api.board.{Board, Color, Piece, PieceType, Square}
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
|
||||||
|
|
||||||
|
import scala.annotation.tailrec
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
object MoveOrdering:
|
||||||
|
|
||||||
|
class OrderingContext:
|
||||||
|
private val killerMoves = mutable.Map[Int, List[Move]]()
|
||||||
|
private val historyTable = mutable.Map[(Int, Int), Int]()
|
||||||
|
|
||||||
|
def addKillerMove(ply: Int, move: Move): Unit =
|
||||||
|
val current = killerMoves.getOrElse(ply, List())
|
||||||
|
if current.isEmpty || (current.head.from != move.from || current.head.to != move.to) then
|
||||||
|
killerMoves(ply) = (move :: current).take(2)
|
||||||
|
|
||||||
|
def getKillerMoves(ply: Int): List[Move] =
|
||||||
|
killerMoves.getOrElse(ply, List())
|
||||||
|
|
||||||
|
def addHistory(from: Int, to: Int, bonus: Int): Unit =
|
||||||
|
val key = (from, to)
|
||||||
|
historyTable(key) = historyTable.getOrElse(key, 0) + bonus
|
||||||
|
|
||||||
|
def getHistory(from: Int, to: Int): Int =
|
||||||
|
historyTable.getOrElse((from, to), 0)
|
||||||
|
|
||||||
|
def clear(): Unit =
|
||||||
|
killerMoves.clear()
|
||||||
|
historyTable.clear()
|
||||||
|
|
||||||
|
def score(
|
||||||
|
context: GameContext,
|
||||||
|
move: Move,
|
||||||
|
ttBestMove: Option[Move],
|
||||||
|
ply: Int = 0,
|
||||||
|
ordering: OrderingContext = new OrderingContext(),
|
||||||
|
): Int =
|
||||||
|
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue
|
||||||
|
else
|
||||||
|
move.moveType match
|
||||||
|
case MoveType.Promotion(PromotionPiece.Queen) =>
|
||||||
|
1_000_000 + promotionCaptureBonus(context, move)
|
||||||
|
case MoveType.Normal(true) | MoveType.EnPassant =>
|
||||||
|
captureScore(context, move)
|
||||||
|
case MoveType.Promotion(_) =>
|
||||||
|
50_000 + promotionCaptureBonus(context, move)
|
||||||
|
case _ => scoreQuietMove(move, ply, ordering)
|
||||||
|
|
||||||
|
def sort(
|
||||||
|
context: GameContext,
|
||||||
|
moves: List[Move],
|
||||||
|
ttBestMove: Option[Move],
|
||||||
|
ply: Int = 0,
|
||||||
|
ordering: OrderingContext = new OrderingContext(),
|
||||||
|
): List[Move] =
|
||||||
|
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering))
|
||||||
|
|
||||||
|
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
|
||||||
|
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
|
||||||
|
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
|
||||||
|
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
|
||||||
|
val history = ordering.getHistory(fromIdx, toIdx)
|
||||||
|
if isKiller then 10_000 + (history / 10) else history / 10
|
||||||
|
|
||||||
|
private def promotionCaptureBonus(context: GameContext, move: Move): Int =
|
||||||
|
if isCapture(context, move) then captureScore(context, move) else 0
|
||||||
|
|
||||||
|
private def captureScore(context: GameContext, move: Move): Int =
|
||||||
|
val see = staticExchange(context, move)
|
||||||
|
val seeBias = if see >= 0 then 20_000 else -20_000
|
||||||
|
100_000 + mvvLva(context, move) + seeBias + see
|
||||||
|
|
||||||
|
private def mvvLva(context: GameContext, move: Move): Int =
|
||||||
|
(victimValue(context, move) * 10) - attackerValue(context, move)
|
||||||
|
|
||||||
|
private def attackerValue(context: GameContext, move: Move): Int =
|
||||||
|
context.board.pieceAt(move.from).map(pieceValue).getOrElse(0)
|
||||||
|
|
||||||
|
private def victimValue(context: GameContext, move: Move): Int =
|
||||||
|
move.moveType match
|
||||||
|
case MoveType.Normal(true) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
|
||||||
|
case MoveType.EnPassant => 1
|
||||||
|
case MoveType.Promotion(_) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
|
||||||
|
case _ => 0
|
||||||
|
|
||||||
|
private def pieceValue(piece: Piece): Int = piece.pieceType match
|
||||||
|
case PieceType.Pawn => 1
|
||||||
|
case PieceType.Knight => 3
|
||||||
|
case PieceType.Bishop => 3
|
||||||
|
case PieceType.Rook => 5
|
||||||
|
case PieceType.Queen => 9
|
||||||
|
case PieceType.King => 200
|
||||||
|
|
||||||
|
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
|
||||||
|
case MoveType.Normal(true) => true
|
||||||
|
case MoveType.EnPassant => true
|
||||||
|
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
|
||||||
|
case _ => false
|
||||||
|
|
||||||
|
private def staticExchange(context: GameContext, move: Move): Int =
|
||||||
|
if !isCapture(context, move) then 0
|
||||||
|
else
|
||||||
|
val target = move.to
|
||||||
|
val initialGain = victimValue(context, move)
|
||||||
|
movedPieceAfterMove(context, move).fold(initialGain) { moved =>
|
||||||
|
val boardAfterMove = applySeeMove(context.board, move, moved)
|
||||||
|
initialGain - seeGain(boardAfterMove, target, context.turn.opposite, pieceValue(moved))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def movedPieceAfterMove(context: GameContext, move: Move): Option[Piece] =
|
||||||
|
move.moveType match
|
||||||
|
case MoveType.Promotion(pp) => Some(Piece(context.turn, promotionPieceType(pp)))
|
||||||
|
case _ => context.board.pieceAt(move.from)
|
||||||
|
|
||||||
|
private def seeGain(board: Board, target: Square, side: Color, currentValue: Int): Int =
|
||||||
|
leastValuableAttacker(board, target, side) match
|
||||||
|
case None => 0
|
||||||
|
case Some((from, attacker)) =>
|
||||||
|
val nextBoard = board.removed(from).updated(target, attacker)
|
||||||
|
val replyGain = seeGain(nextBoard, target, side.opposite, pieceValue(attacker))
|
||||||
|
math.max(0, currentValue - replyGain)
|
||||||
|
|
||||||
|
private def applySeeMove(board: Board, move: Move, moved: Piece): Board =
|
||||||
|
move.moveType match
|
||||||
|
case MoveType.EnPassant =>
|
||||||
|
val capturedSquare = Square(move.to.file, move.from.rank)
|
||||||
|
board.removed(move.from).removed(capturedSquare).updated(move.to, moved)
|
||||||
|
case _ => board.removed(move.from).updated(move.to, moved)
|
||||||
|
|
||||||
|
private def leastValuableAttacker(board: Board, target: Square, color: Color): Option[(Square, Piece)] =
|
||||||
|
board.pieces
|
||||||
|
.collect {
|
||||||
|
case (sq, piece) if piece.color == color && attacksSquare(board, sq, target, piece) => (sq, piece)
|
||||||
|
}
|
||||||
|
.toList
|
||||||
|
.sortBy { case (_, piece) => pieceValue(piece) }
|
||||||
|
.headOption
|
||||||
|
|
||||||
|
private def attacksSquare(board: Board, from: Square, target: Square, piece: Piece): Boolean =
|
||||||
|
val df = target.file.ordinal - from.file.ordinal
|
||||||
|
val dr = target.rank.ordinal - from.rank.ordinal
|
||||||
|
piece.pieceType match
|
||||||
|
case PieceType.Pawn =>
|
||||||
|
val dir = if piece.color == Color.White then 1 else -1
|
||||||
|
dr == dir && math.abs(df) == 1
|
||||||
|
case PieceType.Knight =>
|
||||||
|
val adf = math.abs(df)
|
||||||
|
val adr = math.abs(dr)
|
||||||
|
(adf == 1 && adr == 2) || (adf == 2 && adr == 1)
|
||||||
|
case PieceType.Bishop => clearLine(board, from, target, df, dr, diagonal = true)
|
||||||
|
case PieceType.Rook => clearLine(board, from, target, df, dr, diagonal = false)
|
||||||
|
case PieceType.Queen =>
|
||||||
|
clearLine(board, from, target, df, dr, diagonal = true) ||
|
||||||
|
clearLine(board, from, target, df, dr, diagonal = false)
|
||||||
|
case PieceType.King => math.abs(df) <= 1 && math.abs(dr) <= 1
|
||||||
|
|
||||||
|
private def clearLine(board: Board, from: Square, target: Square, df: Int, dr: Int, diagonal: Boolean): Boolean =
|
||||||
|
val valid =
|
||||||
|
if diagonal then math.abs(df) == math.abs(dr) && df != 0 else (df == 0 && dr != 0) || (dr == 0 && df != 0)
|
||||||
|
valid && pathClear(board, from, target, Integer.compare(df, 0), Integer.compare(dr, 0))
|
||||||
|
|
||||||
|
@tailrec
|
||||||
|
private def pathClear(board: Board, from: Square, target: Square, stepF: Int, stepR: Int): Boolean =
|
||||||
|
from.offset(stepF, stepR) match
|
||||||
|
case None => false
|
||||||
|
case Some(next) if next == target => true
|
||||||
|
case Some(next) => board.pieceAt(next).isEmpty && pathClear(board, next, target, stepF, stepR)
|
||||||
|
|
||||||
|
private def promotionPieceType(piece: PromotionPiece): PieceType = piece match
|
||||||
|
case PromotionPiece.Knight => PieceType.Knight
|
||||||
|
case PromotionPiece.Bishop => PieceType.Bishop
|
||||||
|
case PromotionPiece.Rook => PieceType.Rook
|
||||||
|
case PromotionPiece.Queen => PieceType.Queen
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package de.nowchess.bot.logic
|
||||||
|
|
||||||
|
import de.nowchess.api.game.GameContext
|
||||||
|
import de.nowchess.api.move.Move
|
||||||
|
|
||||||
|
final case class Window(alpha: Int, beta: Int)
|
||||||
|
|
||||||
|
final case class LoopAcc(bestMove: Option[Move], bestScore: Int, a: Int)
|
||||||
|
|
||||||
|
final case class SearchParams(
|
||||||
|
context: GameContext,
|
||||||
|
depth: Int,
|
||||||
|
ply: Int,
|
||||||
|
window: Window,
|
||||||
|
state: SearchState,
|
||||||
|
excludedRootMoves: Set[Move],
|
||||||
|
)
|
||||||
|
|
||||||
|
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
|
||||||
|
def advance(nextHash: Long): SearchState =
|
||||||
|
SearchState(
|
||||||
|
nextHash,
|
||||||
|
repetitions.updatedWith(nextHash) {
|
||||||
|
case Some(v) => Some(v + 1)
|
||||||
|
case None => Some(1)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
enum TTFlag:
|
||||||
|
case Exact // Score is exact
|
||||||
|
case Lower // Score is a lower bound
|
||||||
|
case Upper // Score is an upper bound
|
||||||
|
|
||||||
|
final case class TTEntry(
|
||||||
|
hash: Long,
|
||||||
|
depth: Int,
|
||||||
|
score: Int,
|
||||||
|
flag: TTFlag,
|
||||||
|
bestMove: Option[Move],
|
||||||
|
)
|
||||||
|
|
||||||
|
final class TranspositionTable(val sizePow2: Int = 20):
|
||||||
|
private val size = 1 << sizePow2
|
||||||
|
private val mask = size - 1L
|
||||||
|
private val locks = Array.fill(size)(new Object())
|
||||||
|
private val table: Array[Option[TTEntry]] = Array.fill(size)(None)
|
||||||
|
|
||||||
|
def probe(hash: Long): Option[TTEntry] =
|
||||||
|
val index = (hash & mask).toInt
|
||||||
|
locks(index).synchronized {
|
||||||
|
table(index).filter(_.hash == hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
def store(entry: TTEntry): Unit =
|
||||||
|
val index = (entry.hash & mask).toInt
|
||||||
|
locks(index).synchronized {
|
||||||
|
table(index) = Some(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear(): Unit =
|
||||||
|
for i <- 0 until size do locks(i).synchronized { table(i) = None }
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user