Compare commits

...

2 Commits

Author SHA1 Message Date
Janis 8744bee2dd feat: NCS-41 Bot Platform (#33)
Co-authored-by: Janis <janis@nowchess.de>
Reviewed-on: #33
Co-authored-by: Janis <janis.e.20@gmx.de>
Co-committed-by: Janis <janis.e.20@gmx.de>
2026-04-19 15:52:08 +02:00
TeamCity 5f4d33f3ca ci: bump version with Build-41 2026-04-16 16:55:00 +00:00
125 changed files with 8644 additions and 429 deletions
+163 -32
View File
@@ -2,8 +2,8 @@
> **Stack:** raw-http | none | unknown | scala
> 0 routes | 0 models | 0 components | 35 lib files | 0 env vars | 0 middleware
> **Token savings:** this file is ~3.700 tokens. Without it, AI exploration would cost ~18.200 tokens. **Saves ~14.500 tokens per conversation.**
> 0 routes | 0 models | 0 components | 63 lib files | 1 env vars | 1 middleware
> **Token savings:** this file is ~0 tokens. Without it, AI exploration would cost ~0 tokens. **Saves ~0 tokens per conversation.**
---
@@ -60,6 +60,113 @@
- class ApiResponse
- function error
- function totalPages
- `modules/bot/python/nnue.py`
- function get_weights_dir: ()
- function get_data_dir: ()
- function list_checkpoints: ()
- function migrate_legacy_data: ()
- function show_header: ()
- function show_checkpoints_table: ()
- _...10 more_
- `modules/bot/python/src/dataset.py`
- function get_datasets_dir: () -> Path
- function next_dataset_version: () -> int
- function list_datasets: () -> List[Tuple[int, Dict]]
- function load_dataset_metadata: (version) -> Optional[Dict]
- function save_dataset_metadata: (version, metadata) -> None
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
- _...4 more_
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
- `modules/bot/python/src/tactical_positions_extractor.py`
- function download_and_extract_puzzle_db: (url, output_dir)
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
- function load_positions_from_file: (file_path) -> Set[str]
- function merge_positions: (tactical, other, output_file)
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
- `modules/bot/python/src/train.py`
- function fen_to_features: (fen)
- function find_next_version: (base_name)
- function save_metadata: (weights_file, metadata)
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
- class NNUEDataset
- _...1 more_
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
- class Bot
- function name
- function nextMove
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
- class BotController
- function getBot
- function listBots
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
- class BotMoveRepetition
- function blockedMoves
- function repeatedMove
- function filterAllowed
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
- class Evaluation
- class CHECKMATE_SCORE
- class DRAW_SCORE
- function evaluate
- function initAccumulator
- function copyAccumulator
- _...2 more_
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
- class EvaluationClassic
- function evaluate
- function countRay
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
- class NNUE
- function initAccumulator
- function pushAccumulator
- function copyAccumulator
- function recomputeAccumulator
- function validateAccumulator
- _...4 more_
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
- class NbaiLoader
- function load
- function loadDefault
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
- function toJson
- class NbaiMetadata
- function fromJson
- function str
- function num
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
- function bestMove
- function bestMove
- function bestMoveWithTime
- function bestMoveWithTime
- function loop
- function loop
- _...2 more_
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
- class MoveOrdering
- class OrderingContext
- function addKillerMove
- function getKillerMoves
- function addHistory
- function getHistory
- _...3 more_
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
- function probe
- function store
- function clear
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
- class ZobristHash
- function hash
- function nextHash
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
- class Command
- function execute
@@ -82,7 +189,7 @@
- function turn
- function context
- function canUndo
- _...10 more_
- _...11 more_
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
- function context
- class Observer
@@ -93,6 +200,13 @@
- _...1 more_
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
- class GameFileService
- function saveGameToFile
- function loadGameFromFile
- class FileSystemGameService
- function saveGameToFile
- function loadGameFromFile
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
- class FenExporter
- function boardToFen
@@ -114,6 +228,8 @@
- function parseBoard
- function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
- class PgnExporter
- function exportGameContext
@@ -160,43 +276,58 @@
---
# Config
## Environment Variables
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
---
# Middleware
## custom
- generate — `modules/bot/python/src/generate.py`
---
# Dependency Graph
## Most Imported Files (change these carefully)
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
## Import Map (who imports what)
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala``modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala``modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala``modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala``modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala``modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
---
+5
View File
@@ -0,0 +1,5 @@
# Config
## Environment Variables
- `STOCKFISH_PATH` **required** — modules/bot/python/nnue.py
+29 -29
View File
@@ -2,36 +2,36 @@
## Most Imported Files (change these carefully)
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **28** files
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **21** files
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **19** files
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **14** files
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **13** files
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **10** files
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **9** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **9** files
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **8** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **7** files
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **4** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **4** files
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **4** files
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala` — imported by **60** files
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala` — imported by **40** files
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala` — imported by **39** files
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala` — imported by **36** files
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala` — imported by **22** files
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala` — imported by **21** files
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala` — imported by **21** files
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala` — imported by **17** files
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala` — imported by **10** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala` — imported by **10** files
- `modules/api/src/main/scala/de/nowchess/api/board/CastlingRights.scala` — imported by **8** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — imported by **8** files
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — imported by **5** files
- `modules/bot/src/main/scala/de/nowchess/bot/BotDifficulty.scala` — imported by **5** files
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — imported by **5** files
- `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala` — imported by **4** files
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala` — imported by **4** files
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala` — imported by **4** files
- `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala` — imported by **4** files
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnParser.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala` — imported by **2** files
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — imported by **2** files
- `modules/core/src/main/scala/de/nowchess/chess/controller/Parser.scala` — imported by **1** files
## Import Map (who imports what)
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala``modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerBranchTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/command/CommandInvokerTest.scala` +23 more
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala` +16 more
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala` +14 more
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala` +9 more
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineGameEndingTest.scala` +8 more
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala``modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`, `modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala` +5 more
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesStateTransitionsTest.scala`, `modules/rule/src/test/scala/de/nowchess/rule/DefaultRulesTest.scala` +4 more
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala``modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineScenarioTest.scala` +4 more
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala`, `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala` +3 more
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala``modules/core/src/main/scala/de/nowchess/chess/engine/GameEngine.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineIntegrationTest.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserCombinators.scala`, `modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse.scala` +2 more
- `modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala``modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`, `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala` +55 more
- `modules/api/src/main/scala/de/nowchess/api/move/Move.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/board/BoardTest.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala` +35 more
- `modules/api/src/main/scala/de/nowchess/api/board/Square.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/main/scala/de/nowchess/api/move/Move.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/api/src/test/scala/de/nowchess/api/move/MoveTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala` +34 more
- `modules/api/src/main/scala/de/nowchess/api/board/Color.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala` +31 more
- `modules/api/src/main/scala/de/nowchess/api/board/Board.scala``modules/api/src/main/scala/de/nowchess/api/game/GameContext.scala`, `modules/api/src/test/scala/de/nowchess/api/game/GameContextTest.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +17 more
- `modules/api/src/main/scala/de/nowchess/api/board/PieceType.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` +16 more
- `modules/api/src/main/scala/de/nowchess/api/board/Piece.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala`, `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +16 more
- `modules/rule/src/main/scala/de/nowchess/rules/sets/DefaultRules.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +12 more
- `modules/rule/src/main/scala/de/nowchess/rules/RuleSet.scala``modules/bot/src/main/scala/de/nowchess/bot/bots/ClassicalBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/HybridBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/bots/NNUEBot.scala`, `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`, `modules/bot/src/test/scala/de/nowchess/bot/AlphaBetaSearchTest.scala` +5 more
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParser.scala``modules/bot/src/test/scala/de/nowchess/bot/PolyglotHashTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/EngineTestHelpers.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineLoadGameTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEngineNotationTest.scala`, `modules/core/src/test/scala/de/nowchess/chess/engine/GameEnginePromotionTest.scala` +5 more
+117 -1
View File
@@ -51,6 +51,113 @@
- class ApiResponse
- function error
- function totalPages
- `modules/bot/python/nnue.py`
- function get_weights_dir: ()
- function get_data_dir: ()
- function list_checkpoints: ()
- function migrate_legacy_data: ()
- function show_header: ()
- function show_checkpoints_table: ()
- _...10 more_
- `modules/bot/python/src/dataset.py`
- function get_datasets_dir: () -> Path
- function next_dataset_version: () -> int
- function list_datasets: () -> List[Tuple[int, Dict]]
- function load_dataset_metadata: (version) -> Optional[Dict]
- function save_dataset_metadata: (version, metadata) -> None
- function create_dataset: (version, labeled_jsonl_path, sources, stockfish_depth) -> Path
- _...4 more_
- `modules/bot/python/src/export.py` — function export_to_nbai: (weights_file, output_file, trained_by, train_loss)
- `modules/bot/python/src/generate.py` — function play_random_game_and_collect_positions: (output_file, total_positions, samples_per_game, min_move, max_move, num_workers)
- `modules/bot/python/src/label.py` — function normalize_evaluation: (cp_value, method, scale), function label_positions_with_stockfish: (positions_file, output_file, stockfish_path, batch_size, depth, verbose, normalize, num_workers)
- `modules/bot/python/src/tactical_positions_extractor.py`
- function download_and_extract_puzzle_db: (url, output_dir)
- function extract_puzzle_positions: (puzzle_csv, max_puzzles) -> Set[str]
- function load_positions_from_file: (file_path) -> Set[str]
- function merge_positions: (tactical, other, output_file)
- function extract_tactical_only: (puzzle_csv, output_file, max_puzzles) -> int
- function interactive_merge_positions: (puzzle_csv, output_file, max_puzzles)
- `modules/bot/python/src/train.py`
- function fen_to_features: (fen)
- function find_next_version: (base_name)
- function save_metadata: (weights_file, metadata)
- function train_nnue: (data_file, output_file, epochs, batch_size, lr, checkpoint, stockfish_depth, use_versioning, early_stopping_patience, weight_decay, subsample_ratio)
- function burst_train: (data_file, output_file, duration_minutes, epochs_per_season, early_stopping_patience, batch_size, lr, initial_checkpoint, stockfish_depth, use_versioning, weight_decay, subsample_ratio)
- class NNUEDataset
- _...1 more_
- `modules/bot/src/main/scala/de/nowchess/bot/Bot.scala`
- class Bot
- function name
- function nextMove
- `modules/bot/src/main/scala/de/nowchess/bot/BotController.scala`
- class BotController
- function getBot
- function listBots
- `modules/bot/src/main/scala/de/nowchess/bot/BotMoveRepetition.scala`
- class BotMoveRepetition
- function blockedMoves
- function repeatedMove
- function filterAllowed
- `modules/bot/src/main/scala/de/nowchess/bot/Config.scala` — class Config
- `modules/bot/src/main/scala/de/nowchess/bot/ai/Evaluation.scala`
- class Evaluation
- class CHECKMATE_SCORE
- class DRAW_SCORE
- function evaluate
- function initAccumulator
- function copyAccumulator
- _...2 more_
- `modules/bot/src/main/scala/de/nowchess/bot/bots/classic/EvaluationClassic.scala`
- class EvaluationClassic
- function evaluate
- function countRay
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/EvaluationNNUE.scala` — class EvaluationNNUE, function evaluate
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NNUE.scala`
- class NNUE
- function initAccumulator
- function pushAccumulator
- function copyAccumulator
- function recomputeAccumulator
- function validateAccumulator
- _...4 more_
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiLoader.scala`
- class NbaiLoader
- function load
- function loadDefault
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiMigrator.scala` — class NbaiMigrator, function migrateFromBin
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiModel.scala`
- function toJson
- class NbaiMetadata
- function fromJson
- function str
- function num
- `modules/bot/src/main/scala/de/nowchess/bot/bots/nnue/NbaiWriter.scala` — class NbaiWriter, function write
- `modules/bot/src/main/scala/de/nowchess/bot/logic/AlphaBetaSearch.scala`
- function bestMove
- function bestMove
- function bestMoveWithTime
- function bestMoveWithTime
- function loop
- function loop
- _...2 more_
- `modules/bot/src/main/scala/de/nowchess/bot/logic/MoveOrdering.scala`
- class MoveOrdering
- class OrderingContext
- function addKillerMove
- function getKillerMoves
- function addHistory
- function getHistory
- _...3 more_
- `modules/bot/src/main/scala/de/nowchess/bot/logic/TranspositionTable.scala`
- function probe
- function store
- function clear
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotBook.scala` — function probe, function select
- `modules/bot/src/main/scala/de/nowchess/bot/util/PolyglotHash.scala` — class PolyglotHash, function hash
- `modules/bot/src/main/scala/de/nowchess/bot/util/ZobristHash.scala`
- class ZobristHash
- function hash
- function nextHash
- `modules/core/src/main/scala/de/nowchess/chess/command/Command.scala`
- class Command
- function execute
@@ -73,7 +180,7 @@
- function turn
- function context
- function canUndo
- _...10 more_
- _...11 more_
- `modules/core/src/main/scala/de/nowchess/chess/observer/Observer.scala`
- function context
- class Observer
@@ -84,6 +191,13 @@
- _...1 more_
- `modules/io/src/main/scala/de/nowchess/io/GameContextExport.scala` — class GameContextExport, function exportGameContext
- `modules/io/src/main/scala/de/nowchess/io/GameContextImport.scala` — class GameContextImport, function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/GameFileService.scala`
- class GameFileService
- function saveGameToFile
- function loadGameFromFile
- class FileSystemGameService
- function saveGameToFile
- function loadGameFromFile
- `modules/io/src/main/scala/de/nowchess/io/fen/FenExporter.scala`
- class FenExporter
- function boardToFen
@@ -105,6 +219,8 @@
- function parseBoard
- function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/fen/FenParserSupport.scala` — function buildSquares
- `modules/io/src/main/scala/de/nowchess/io/json/JsonExporter.scala` — class JsonExporter, function exportGameContext
- `modules/io/src/main/scala/de/nowchess/io/json/JsonParser.scala` — class JsonParser, function importGameContext
- `modules/io/src/main/scala/de/nowchess/io/pgn/PgnExporter.scala`
- class PgnExporter
- function exportGameContext
+4
View File
@@ -0,0 +1,4 @@
# Middleware
## custom
- generate — `modules/bot/python/src/generate.py`
+2
View File
@@ -8,3 +8,5 @@
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
sonarlint.xml
+133
View File
@@ -0,0 +1,133 @@
<component name="ProjectCodeStyleConfiguration">
<code_scheme name="Project" version="173">
<AndroidXmlCodeStyleSettings>
<option name="USE_CUSTOM_SETTINGS" value="true" />
</AndroidXmlCodeStyleSettings>
<JetCodeStyleSettings>
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
</JetCodeStyleSettings>
<ScalaCodeStyleSettings>
<option name="FORMATTER" value="1" />
</ScalaCodeStyleSettings>
<XML>
<option name="XML_KEEP_LINE_BREAKS" value="false" />
<option name="XML_ALIGN_ATTRIBUTES" value="false" />
<option name="XML_SPACE_INSIDE_EMPTY_TAG" value="true" />
</XML>
<codeStyleSettings language="XML">
<option name="FORCE_REARRANGE_MODE" value="1" />
<indentOptions>
<option name="CONTINUATION_INDENT_SIZE" value="4" />
</indentOptions>
<arrangement>
<rules>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:android</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>xmlns:.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:id</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*:name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>name</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>style</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>^$</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>http://schemas.android.com/apk/res/android</XML_NAMESPACE>
</AND>
</match>
</rule>
</section>
<section>
<rule>
<match>
<AND>
<NAME>.*</NAME>
<XML_ATTRIBUTE />
<XML_NAMESPACE>.*</XML_NAMESPACE>
</AND>
</match>
<order>BY_NAME</order>
</rule>
</section>
</rules>
</arrangement>
</codeStyleSettings>
<codeStyleSettings language="kotlin">
<option name="CODE_STYLE_DEFAULTS" value="KOTLIN_OFFICIAL" />
</codeStyleSettings>
</code_scheme>
</component>
+1 -1
View File
@@ -1,5 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
<option name="USE_PER_PROJECT_SETTINGS" value="true" />
</state>
</component>
+1
View File
@@ -11,6 +11,7 @@
<option value="$PROJECT_DIR$" />
<option value="$PROJECT_DIR$/modules" />
<option value="$PROJECT_DIR$/modules/api" />
<option value="$PROJECT_DIR$/modules/bot" />
<option value="$PROJECT_DIR$/modules/core" />
<option value="$PROJECT_DIR$/modules/io" />
<option value="$PROJECT_DIR$/modules/rule" />
+1 -1
View File
@@ -5,7 +5,7 @@
<option name="deprecationWarnings" value="true" />
<option name="uncheckedWarnings" value="true" />
</profile>
<profile name="Gradle 2" modules="NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
<profile name="Gradle 2" modules="NowChessSystems.modules.bot.main,NowChessSystems.modules.bot.scoverage,NowChessSystems.modules.bot.test,NowChessSystems.modules.core.main,NowChessSystems.modules.core.scoverage,NowChessSystems.modules.core.test,NowChessSystems.modules.io.main,NowChessSystems.modules.io.scoverage,NowChessSystems.modules.io.test,NowChessSystems.modules.rule.main,NowChessSystems.modules.rule.scoverage,NowChessSystems.modules.rule.test,NowChessSystems.modules.ui.main,NowChessSystems.modules.ui.scoverage,NowChessSystems.modules.ui.test">
<option name="deprecationWarnings" value="true" />
<option name="uncheckedWarnings" value="true" />
<parameters>
+4
View File
@@ -19,6 +19,7 @@ Try to stick to these commands for consistency.
| `api` | Model / shared types | (none) |
| `core` | Primary business logic | api, rule |
| `rule` | Game rules | api |
| `bot` | Bots and AI | api,rule,io |
| `io` | Export formats | api, core |
| `ui` | Entrypoint & UI | core, io |
@@ -47,6 +48,9 @@ Try to stick to these commands for consistency.
- **Immutable state as primary model:** GameContext (api) holds board, history, player state — immutable, passed through the system. Each move creates a new GameContext, enabling undo/redo without side effects.
- **Observer pattern for UI decoupling:** GameEngine publishes move/state events; CommandInvoker queues moves; UI listens to events, not polling. GameEngine never imports UI code.
- **RuleSet trait encapsulates rules:** Move generation, check, castling, en passant all in RuleSet impl. GameEngine calls rules as a black box; rules don't know about the rest of core.
- **Polyglot hash must follow spec index layout:** piece keys use interleaved mapping `(pieceType * 2 + colorBit)` with black=0/white=1, castling keys are `768..771`, en-passant file keys are `772..779` and are XORed only if side-to-move has a pawn that can capture en passant, side-to-move key is `780` for white.
- **Alpha-beta uses sequential PV search by default:** parallel split was disabled because fixed-window futures removed pruning effectiveness; correctness and pruning quality take priority over speculative parallelism.
- **Search hash is updated incrementally per move:** bot search now updates Zobrist keys from parent hash with move deltas instead of recomputing piece scans at every node.
## Rules
+21
View File
@@ -22,6 +22,26 @@ sonar {
}.joinToString(",")
property("sonar.scala.coverage.reportPaths", scoverageReports)
property(
"sonar.coverage.exclusions",
// UI renders JavaFX components; headless test environments cannot exercise rendering paths
"modules/ui/**," +
// FastParse macro-generated combinators produce synthetic branches that scoverage marks as uncovered
"modules/io/src/main/scala/de/nowchess/io/fen/FenParserFastParse*," +
// NNUE inference pipeline — coverage requires a trained model file not present in CI
"**/bot/**/NNUE.scala," +
"**/bot/**/NNUEBot.scala," +
"**/bot/**/EvaluationNNUE.scala," +
// NBAI binary format loader/writer — error paths require crafted corrupt files; migrator is a one-shot tool
"**/bot/**/NbaiLoader.scala," +
"**/bot/**/NbaiModel.scala," +
"**/bot/**/NbaiMigrator.scala," +
"**/bot/**/NbaiWriter.scala," +
// PolyglotBook — binary I/O and dead-code guards (bit-masked fields can never exceed valid range)
"**/bot/**/PolyglotBook.scala," +
"**/bot/**/MoveOrdering.scala," +
"**/bot/**/AlphaBetaSearch.scala"
)
}
}
@@ -35,6 +55,7 @@ val versions = mapOf(
"SCALAFX" to "21.0.0-R32",
"JAVAFX" to "21.0.1",
"JUNIT_BOM" to "5.13.4",
"ONNXRUNTIME" to "1.19.2",
"SCALA_PARSER_COMBINATORS" to "2.4.0",
"FASTPARSE" to "3.0.2",
"JACKSON" to "2.17.2",
+1 -1
View File
@@ -1,5 +1,5 @@
import glob,re
mods=['api','core','io','rule','ui']
mods=['api','core','io','rule','ui', 'bot']
tot=0
for m in mods:
s=0
+8
View File
@@ -34,3 +34,11 @@
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
## (2026-04-16)
### Features
* NCS-13 Implement Threefold Repetition ([#31](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/31)) ([767d305](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/767d3051a76c266050b6335774d66e2db2273c16))
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
@@ -0,0 +1,11 @@
package de.nowchess.api.bot
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
trait Bot {
def name: String
def nextMove(context: GameContext): Option[Move]
}
@@ -1,6 +1,6 @@
package de.nowchess.api.game
import de.nowchess.api.board.{Board, CastlingRights, Color, Square}
import de.nowchess.api.board.{Board, CastlingRights, Color, PieceType, Square}
import de.nowchess.api.move.Move
/** Immutable bundle of complete game state. All state changes produce new GameContext instances.
@@ -15,6 +15,13 @@ case class GameContext(
result: Option[GameResult] = None,
initialBoard: Board = Board.initial,
):
private lazy val whiteKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.White && p.pieceType == PieceType.King).map(_._1)
private lazy val blackKingSquare: Option[Square] =
board.pieces.find((_, p) => p.color == Color.Black && p.pieceType == PieceType.King).map(_._1)
def kingSquare(color: Color): Option[Square] =
if color == Color.White then whiteKingSquare else blackKingSquare
/** Create new context with updated board. */
def withBoard(newBoard: Board): GameContext = copy(board = newBoard)
@@ -0,0 +1,8 @@
package de.nowchess.api.game
import de.nowchess.api.bot.Bot
import de.nowchess.api.player.PlayerInfo
sealed trait Participant
final case class Human(playerInfo: PlayerInfo) extends Participant
final case class BotParticipant(bot: Bot) extends Participant
@@ -71,3 +71,9 @@ class GameContextTest extends AnyFunSuite with Matchers:
test("withResult clears result"):
val ctx = GameContext.initial.withResult(Some(GameResult.Win(Color.Black)))
ctx.withResult(None).result shouldBe None
test("kingSquare returns white king position"):
GameContext.initial.kingSquare(Color.White) shouldBe Some(Square(File.E, Rank.R1))
test("kingSquare returns black king position"):
GameContext.initial.kingSquare(Color.Black) shouldBe Some(Square(File.E, Rank.R8))
+1 -1
View File
@@ -1,3 +1,3 @@
MAJOR=0
MINOR=5
MINOR=6
PATCH=0
+86
View File
@@ -0,0 +1,86 @@
plugins {
id("scala")
id("org.scoverage")
}
group = "de.nowchess"
version = "1.0-SNAPSHOT"
@Suppress("UNCHECKED_CAST")
val versions = rootProject.extra["VERSIONS"] as Map<String, String>
repositories {
mavenCentral()
}
scala {
scalaVersion = versions["SCALA3"]!!
}
scoverage {
scoverageVersion.set(versions["SCOVERAGE"]!!)
excludedPackages.set(
listOf(
"de\\.nowchess\\.bot\\.bots\\.NNUEBot",
"de\\.nowchess\\.bot\\.bots\\.nnue\\..*",
"de\\.nowchess\\.bot\\.util\\.PolyglotBook",
)
)
excludedFiles.set(
listOf(
".*NNUE\\.scala",
".*NNUEBot\\.scala",
".*NbaiLoader\\.scala",
".*NbaiMigrator\\.scala",
".*NbaiWriter\\.scala",
".*PolyglotBook\\.scala",
)
)
}
tasks.withType<ScalaCompile> {
scalaCompileOptions.additionalParameters = listOf("-encoding", "UTF-8")
}
dependencies {
implementation("org.scala-lang:scala3-compiler_3") {
version {
strictly(versions["SCALA3"]!!)
}
}
implementation("org.scala-lang:scala3-library_3") {
version {
strictly(versions["SCALA3"]!!)
}
}
implementation(project(":modules:api"))
implementation(project(":modules:io"))
implementation(project(":modules:rule"))
implementation("com.microsoft.onnxruntime:onnxruntime:${versions["ONNXRUNTIME"]!!}")
testImplementation(platform("org.junit:junit-bom:${versions["JUNIT_BOM"]!!}"))
testImplementation("org.junit.jupiter:junit-jupiter")
testImplementation("org.scalatest:scalatest_3:${versions["SCALATEST"]!!}")
testImplementation("co.helmethair:scalatest-junit-runner:${versions["SCALATEST_JUNIT"]!!}")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
}
tasks.test {
useJUnitPlatform {
includeEngines("scalatest")
testLogging {
events("skipped", "failed")
}
}
finalizedBy(tasks.reportScoverage)
}
tasks.reportScoverage {
dependsOn(tasks.test)
}
tasks.jar {
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
}
Binary file not shown.
+22
View File
@@ -0,0 +1,22 @@
# Data and weights are local artifacts, not committed
data/
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
.venv
# IDE
.idea/
.vscode/
*.swp
*.swo
tactical_data/
trainingdata/
/datasets/
+173
View File
@@ -0,0 +1,173 @@
# Training Dataset Management
The NNUE training pipeline now features versioned dataset management, similar to model versioning. This prevents data loss and allows you to maintain multiple training configurations.
## Directory Structure
```
datasets/
ds_v1/
labeled.jsonl # Training data: {"fen": "...", "eval": 0.5, "eval_raw": 150}
metadata.json # Version info and composition
ds_v2/
labeled.jsonl
metadata.json
```
## Metadata Schema
Each dataset has a `metadata.json` file tracking its composition:
```json
{
"version": 1,
"created": "2026-04-13T15:30:45.123456",
"total_positions": 1000000,
"stockfish_depth": 12,
"sources": [
{
"type": "generated",
"count": 500000,
"params": {
"num_positions": 500000,
"min_move": 1,
"max_move": 50
}
},
{
"type": "tactical",
"count": 300000,
"max_puzzles": 300000
},
{
"type": "file_import",
"count": 200000,
"path": "/path/to/original_file.txt"
}
]
}
```
## TUI Workflow
### Main Menu
```
1 - Manage Training Data
2 - Train Model
3 - Export Model
4 - Exit
```
### Training Data Management Submenu
```
1 - Create new dataset
2 - Extend existing dataset
3 - View all datasets
4 - Delete dataset
5 - Back
```
## Creating a Dataset
Use the "Create new dataset" option to add data from one or more sources:
1. **Generate random positions** — Play random games and sample positions
- Number of positions
- Move range (min/max move number to sample from)
- Number of worker threads
2. **Import from file** — Load positions from a FEN file
- File must contain one FEN string per line
- Duplicates are automatically removed
3. **Extract tactical puzzles** — Download and extract Lichess puzzle database
- Maximum number of puzzles to include
- Automatically filters for tactical themes (forks, pins, mates, etc.)
You can combine multiple sources in a single dataset creation session. All positions are:
- Deduplicated (only unique FENs are kept)
- Labeled with Stockfish evaluations
- Saved to `datasets/ds_vN/labeled.jsonl`
## Extending a Dataset
Use "Extend existing dataset" to add more positions to an existing dataset:
1. Select the dataset version to extend
2. Choose data sources (same options as creation)
3. Confirm labeling parameters
4. New positions are:
- Labeled with Stockfish
- Deduplicated against the target dataset (preventing duplicates)
- Merged into the existing `labeled.jsonl`
- Metadata is updated with the new source entry
## Training with a Dataset
When you start training (Standard or Burst mode), you'll be prompted to select a dataset version. The TUI will display all available datasets with:
- Version number
- Total number of positions
- Source types (generated, tactical, imported)
- Stockfish depth used
- Creation date
## Legacy Data Migration
If you have existing labeled data in `data/training_data.jsonl` from before this update:
1. Open the "Manage Training Data" menu
2. Choose "Create new dataset"
3. Select "Import from file"
4. Point to `data/training_data.jsonl`
5. Complete the dataset creation
Alternatively, you can manually copy the file to `datasets/ds_v1/labeled.jsonl` and create a `metadata.json` file.
## Viewing Dataset Details
Use "View all datasets" to see a table of all datasets with:
- Version number
- Position count
- Source composition
- Stockfish depth
- Creation date
## Deleting a Dataset
Use "Delete dataset" to remove a dataset and free up disk space. **This action cannot be undone.**
⚠️ The system does not prevent deleting datasets used by model checkpoints. Plan accordingly.
## Technical Details
### Deduplication Strategy
When extending a dataset, positions are deduplicated **within that dataset only**. This allows different datasets to contain overlapping positions if desired.
When creating a new dataset from multiple sources, all sources are combined and deduplicated before labeling.
### Labeled Position Format
Each line in `labeled.jsonl` is a JSON object:
```json
{
"fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
"eval": 0.0,
"eval_raw": 0
}
```
- `fen`: The position in Forsyth-Edwards Notation
- `eval`: Normalized evaluation ([-1, 1] range using tanh)
- `eval_raw`: Raw Stockfish evaluation in centipawns
### Storage Location
Datasets are stored in the `datasets/` directory relative to the script location. The old `data/` directory is preserved for backward compatibility but is not actively used by the new system.
## Performance Tips
- **Smaller datasets train faster** — Start with 100k-500k positions
- **Deduplication matters** — Use the extend functionality to build up your dataset without redundant data
- **Stockfish depth** — Depth 12-14 balances accuracy and labeling speed
- **Workers** — Use 4-8 workers for labeling if your machine supports it; more workers = faster but uses more CPU/memory
+129
View File
@@ -0,0 +1,129 @@
# NNUE Python Pipeline
Central CLI for training and exporting chess evaluation neural networks (NNUE).
## Directory Structure
```
python/
├── nnue.py # Main CLI entry point
├── src/ # Python modules
│ ├── generate.py # Generate random chess positions
│ ├── label.py # Label positions with Stockfish
│ ├── train.py # Train NNUE model
│ └── export.py # Export weights to Scala
├── data/ # Training data (gitignored)
│ ├── positions.txt
│ └── training_data.jsonl
└── weights/ # Model weights (gitignored)
├── nnue_weights_v1.pt
├── nnue_weights_v1_metadata.json
└── ...
```
## Quick Start
```bash
# Train a new model (500k positions, auto-detect checkpoint)
python nnue.py train
# Train from specific checkpoint
python nnue.py train --from-checkpoint 2
# Train with custom games count
python nnue.py train --games 200000
# Train with custom positions file
python nnue.py train --positions-file my_positions.txt
# Export specific version to Scala
python nnue.py export 2
# List all checkpoints
python nnue.py list
```
## CLI Commands
### `train` - Train NNUE model
```bash
python nnue.py train [OPTIONS]
```
**Options:**
- `--from-checkpoint N` - Resume from checkpoint version N (default: uses latest)
- `--games N` - Number of games to generate (default: 500000)
- `--positions-file FILE` - Use existing positions file instead of generating
- `--stockfish PATH` - Path to Stockfish binary (default: `$STOCKFISH_PATH` or `/usr/games/stockfish`)
**Examples:**
```bash
# Train with latest checkpoint
python nnue.py train
# Train from v2 with 100k games
python nnue.py train --from-checkpoint 2 --games 100000
# Train with custom positions
python nnue.py train --positions-file my_games.txt --stockfish /opt/stockfish/sf15
```
### `export` - Export weights to Scala
```bash
python nnue.py export WEIGHTS [output_path]
```
**Arguments:**
- `WEIGHTS` - Version number (e.g., `2`) or full filename (e.g., `nnue_weights_v2.pt`)
**Examples:**
```bash
# Export version 2
python nnue.py export 2
# Export with full filename
python nnue.py export nnue_weights_v3.pt
```
Output goes to `../src/main/scala/de/nowchess/bot/bots/nnue/NNUEWeights_vN.scala`
### `list` - List available checkpoints
```bash
python nnue.py list
```
Shows all available model versions with file sizes.
## Data Flow
1. **Generate**`data/positions.txt`
- Random chess positions from 8-20 move openings
- Filters out checks, game-over states, and captures
2. **Label**`data/training_data.jsonl`
- Evaluates each position with Stockfish at depth 12
- Stores FEN + evaluation in JSONL format
3. **Train**`weights/nnue_weights_vN.pt`
- Trains neural network on labeled positions
- Auto-versioning (v1, v2, v3, etc.)
- Saves metadata alongside weights
4. **Export**`NNUEWeights_vN.scala`
- Converts weights to Scala object
- Ready for integration into bot
## Versioning
- Models are automatically versioned (v1, v2, v3, etc.)
- Each version gets a `_metadata.json` file with training info
- Training from checkpoint uses latest version unless specified with `--from-checkpoint`
## Files
- `data/` and `weights/` are gitignored (local artifacts)
- Documentation in `docs/` explains training, debugging, and incremental improvements
- Source modules in `src/` are independent and can be imported for custom workflows
+951
View File
@@ -0,0 +1,951 @@
#!/usr/bin/env python3
"""Central NNUE pipeline TUI for training and exporting models."""
import os
import shutil
import sys
import tempfile
from pathlib import Path
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich import print as rprint
# Add src directory to path so we can import modules
sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import play_random_game_and_collect_positions
from label import label_positions_with_stockfish
from train import train_nnue, burst_train, DEFAULT_HIDDEN_SIZES
from export import export_to_nbai
from tactical_positions_extractor import (
download_and_extract_puzzle_db,
extract_tactical_only
)
from lichess_importer import import_lichess_evals
from dataset import (
get_datasets_dir,
list_datasets,
next_dataset_version,
load_dataset_metadata,
create_dataset,
extend_dataset,
get_dataset_labeled_path,
delete_dataset,
show_datasets_table
)
def get_weights_dir():
"""Get/create weights directory."""
weights_dir = Path(__file__).parent / "weights"
weights_dir.mkdir(exist_ok=True)
return weights_dir
def get_data_dir():
"""Get/create legacy data directory (for migration)."""
data_dir = Path(__file__).parent / "data"
data_dir.mkdir(exist_ok=True)
return data_dir
def list_checkpoints():
"""List available checkpoint versions."""
weights_dir = get_weights_dir()
checkpoints = sorted(weights_dir.glob("nnue_weights_v*.pt"))
if not checkpoints:
return []
return [int(cp.stem.split("_v")[1]) for cp in checkpoints]
def migrate_legacy_data():
"""On first run, offer to import existing data/training_data.jsonl as ds_v1."""
console = Console()
data_dir = get_data_dir()
legacy_file = data_dir / "training_data.jsonl"
datasets = list_datasets()
# Only migrate if legacy data exists and no datasets exist yet
if legacy_file.exists() and not datasets:
console.print("\n[cyan]Legacy data detected: data/training_data.jsonl[/cyan]")
console.print("[dim]Tip: Use 'Manage Training Data' menu to import it as ds_v1[/dim]")
def show_header():
"""Display application header."""
console = Console()
console.clear()
console.print(
Panel(
"[bold cyan]🧠 NNUE Training Pipeline[/bold cyan]\n"
"[dim]Neural Network Utility Evaluation - Dataset & Model Management[/dim]",
border_style="cyan",
padding=(1, 2),
)
)
def show_checkpoints_table():
"""Display available checkpoints in a table."""
console = Console()
available = list_checkpoints()
if not available:
console.print("[yellow] No model checkpoints found yet[/yellow]")
return
table = Table(title="Available Model Checkpoints", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("File Size", justify="right")
table.add_column("Status", justify="center")
weights_dir = get_weights_dir()
for v in sorted(available):
weights_file = weights_dir / f"nnue_weights_v{v}.pt"
if weights_file.exists():
size = weights_file.stat().st_size / (1024**2)
table.add_row(f"v{v}", f"{size:.1f} MB", "✓ Ready")
else:
table.add_row(f"v{v}", "?", "[red]✗ Missing[/red]")
console.print(table)
def show_main_menu():
"""Display and handle main menu."""
console = Console()
# Migrate legacy data on first run
migrate_legacy_data()
while True:
show_header()
show_checkpoints_table()
console.print("\n[bold]What would you like to do?[/bold]")
console.print("[cyan]1[/cyan] - Manage Training Data")
console.print("[cyan]2[/cyan] - Train Model")
console.print("[cyan]3[/cyan] - Export Model")
console.print("[cyan]4[/cyan] - Exit")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
if choice == "1":
datasets_menu()
elif choice == "2":
training_menu()
elif choice == "3":
export_interactive()
elif choice == "4":
console.print("[yellow]👋 Goodbye![/yellow]")
return
def datasets_menu():
"""Dataset management submenu."""
console = Console()
while True:
show_header()
show_datasets_table(console)
console.print("\n[bold]Training Data Management[/bold]")
console.print("[cyan]1[/cyan] - Create new dataset")
console.print("[cyan]2[/cyan] - Extend existing dataset")
console.print("[cyan]3[/cyan] - View all datasets")
console.print("[cyan]4[/cyan] - Delete dataset")
console.print("[cyan]5[/cyan] - Back")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4", "5"])
if choice == "1":
create_dataset_interactive()
elif choice == "2":
extend_dataset_interactive()
elif choice == "3":
show_header()
show_datasets_table(console)
Prompt.ask("\nPress Enter to continue")
elif choice == "4":
delete_dataset_interactive()
elif choice == "5":
return
def create_dataset_interactive():
"""Interactive dataset creation flow."""
console = Console()
show_header()
console.print("\n[bold cyan]📊 Create New Dataset[/bold cyan]")
sources = []
combined_count = 0
# Allow user to add multiple sources
while True:
console.print("\n[bold]Add data source (repeat until done):[/bold]")
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
num_workers = int(Prompt.ask("Number of workers", default="8"))
console.print("[dim]Generating positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
count = play_random_game_and_collect_positions(
str(temp_file),
total_positions=num_positions,
samples_per_game=1,
min_move=min_move,
max_move=max_move,
num_workers=num_workers
)
if count > 0:
sources.append({
"type": "generated",
"count": count,
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
})
combined_count += count
console.print(f"[green]✓ {count:,} positions generated[/green]")
else:
console.print("[red]✗ Generation failed[/red]")
elif choice == "b":
file_path = Prompt.ask("Path to FEN file")
try:
with open(file_path, 'r') as f:
count = sum(1 for _ in f)
sources.append({"type": "file_import", "count": count, "path": file_path})
combined_count += count
console.print(f"[green]✓ {count:,} positions from file[/green]")
except FileNotFoundError:
console.print(f"[red]✗ File not found: {file_path}[/red]")
elif choice == "c":
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
console.print("[dim]Extracting tactical positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
try:
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
if csv_path:
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
combined_count += count
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
except Exception as e:
console.print(f"[red]✗ Tactical extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
break
if not sources:
console.print("[yellow]Dataset creation cancelled[/yellow]")
return
# Determine whether any sources still need Stockfish labeling.
# Lichess sources are already labeled; only generated/tactical/file sources need it.
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Dataset Summary:[/bold]")
console.print(f" Total positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to create dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
labeled_file.unlink(missing_ok=True)
# --- Step 1: Collect already-labeled data (Lichess source) ---
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
# --- Step 2: Combine unlabeled sources and run Stockfish (if any) ---
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# --- Step 3: Create dataset ---
console.print("\n[bold cyan]Step 3: Creating Dataset[/bold cyan]")
version = next_dataset_version()
create_dataset(
version=version,
labeled_jsonl_path=str(labeled_file),
sources=sources,
stockfish_depth=stockfish_depth,
)
console.print(f"[green]✓ Dataset created: ds_v{version}[/green]")
console.print(f"[bold]Location: {get_datasets_dir() / f'ds_v{version}'}[/bold]")
Prompt.ask("\nPress Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def extend_dataset_interactive():
"""Interactive dataset extension flow."""
console = Console()
show_header()
console.print("\n[bold cyan]📊 Extend Existing Dataset[/bold cyan]")
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets available to extend[/yellow]")
Prompt.ask("Press Enter to continue")
return
show_datasets_table(console)
version = int(Prompt.ask("\nEnter dataset version to extend (e.g., 1)"))
if not any(v == version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
sources = []
combined_count = 0
# Allow user to add sources
while True:
console.print("\n[bold]Add data source:[/bold]")
console.print("[cyan]a[/cyan] - Generate random positions")
console.print("[cyan]b[/cyan] - Import from file")
console.print("[cyan]c[/cyan] - Extract Lichess tactical puzzles")
console.print("[cyan]d[/cyan] - Import Lichess eval database (.jsonl.zst)")
console.print("[cyan]e[/cyan] - Done adding sources")
choice = Prompt.ask("Select", choices=["a", "b", "c", "d", "e"])
if choice == "a":
num_positions = int(Prompt.ask("Number of positions to generate", default="100000"))
min_move = int(Prompt.ask("Minimum move number", default="1"))
max_move = int(Prompt.ask("Maximum move number", default="50"))
num_workers = int(Prompt.ask("Number of workers", default="8"))
console.print("[dim]Generating positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
count = play_random_game_and_collect_positions(
str(temp_file),
total_positions=num_positions,
samples_per_game=1,
min_move=min_move,
max_move=max_move,
num_workers=num_workers
)
if count > 0:
sources.append({
"type": "generated",
"count": count,
"params": {"num_positions": num_positions, "min_move": min_move, "max_move": max_move}
})
combined_count += count
console.print(f"[green]✓ {count:,} positions generated[/green]")
elif choice == "b":
file_path = Prompt.ask("Path to FEN file")
try:
with open(file_path, 'r') as f:
count = sum(1 for _ in f)
sources.append({"type": "file_import", "count": count, "path": file_path})
combined_count += count
console.print(f"[green]✓ {count:,} positions from file[/green]")
except FileNotFoundError:
console.print(f"[red]✗ File not found: {file_path}[/red]")
elif choice == "c":
max_puzzles = int(Prompt.ask("Maximum puzzles to extract", default="300000"))
console.print("[dim]Extracting tactical positions...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
try:
csv_path = download_and_extract_puzzle_db(output_dir=str(Path(__file__).parent / "tactical_data"))
if csv_path:
count = extract_tactical_only(csv_path, str(temp_file), max_puzzles)
sources.append({"type": "tactical", "count": count, "max_puzzles": max_puzzles})
combined_count += count
console.print(f"[green]✓ {count:,} tactical positions extracted[/green]")
except Exception as e:
console.print(f"[red]✗ Extraction failed: {e}[/red]")
elif choice == "d":
zst_path = Prompt.ask("Path to lichess_db_eval.jsonl.zst")
max_pos = Prompt.ask("Max positions to import (blank = no limit)", default="")
max_pos = int(max_pos) if max_pos.strip() else None
min_depth = int(Prompt.ask("Minimum eval depth to accept", default="20"))
console.print("[dim]Importing Lichess evals (this may take a while)...[/dim]")
temp_file = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
temp_file.unlink(missing_ok=True)
try:
count = import_lichess_evals(
input_path=zst_path,
output_file=str(temp_file),
max_positions=max_pos,
min_depth=min_depth,
)
if count > 0:
sources.append({
"type": "lichess",
"count": count,
"params": {"min_depth": min_depth, "max_positions": max_pos},
})
combined_count += count
console.print(f"[green]✓ {count:,} positions imported from Lichess[/green]")
else:
console.print("[red]✗ No positions imported[/red]")
except Exception as e:
console.print(f"[red]✗ Lichess import failed: {e}[/red]")
elif choice == "e":
if not sources:
console.print("[yellow]⚠ No sources added yet[/yellow]")
continue
break
if not sources:
console.print("[yellow]Extension cancelled[/yellow]")
return
needs_labeling = any(s["type"] != "lichess" for s in sources)
stockfish_depth = 12
if needs_labeling:
console.print("\n[bold cyan]🏷️ Labeling Parameters[/bold cyan]")
stockfish_path = Prompt.ask(
"Stockfish path",
default=os.environ.get("STOCKFISH_PATH") or shutil.which("stockfish") or "/usr/bin/stockfish"
)
stockfish_depth = int(Prompt.ask("Stockfish analysis depth", default="12"))
num_workers = int(Prompt.ask("Number of parallel workers", default="1"))
# Summary and confirm
console.print("\n[bold]Extension Summary:[/bold]")
console.print(f" Target dataset: ds_v{version}")
console.print(f" New positions: {combined_count:,}")
for source in sources:
console.print(f" - {source['type']}: {source['count']:,}")
if needs_labeling:
console.print(f" Stockfish depth: {stockfish_depth}")
if not Confirm.ask("\nProceed to extend dataset?", default=True):
console.print("[yellow]Cancelled[/yellow]")
return
try:
labeled_file = Path(tempfile.gettempdir()) / "labeled.jsonl"
labeled_file.unlink(missing_ok=True)
# Copy pre-labeled Lichess data if present
lichess_tmp = Path(tempfile.gettempdir()) / "temp_lichess.jsonl"
if lichess_tmp.exists():
import shutil as _shutil
_shutil.copy(lichess_tmp, labeled_file)
console.print(f"\n[bold cyan]Step 1: Pre-labeled data copied[/bold cyan]")
console.print(f"[green]✓ Lichess positions ready[/green]")
# Combine and label remaining sources with Stockfish
non_lichess = [s for s in sources if s["type"] != "lichess"]
if non_lichess:
console.print("\n[bold cyan]Step 2: Combining unlabeled sources[/bold cyan]")
combined_fen_file = Path(tempfile.gettempdir()) / "combined_positions.txt"
all_fens = set()
for source in non_lichess:
if source["type"] == "generated":
temp_file = Path(tempfile.gettempdir()) / "temp_positions.txt"
elif source["type"] == "file_import":
temp_file = Path(source["path"])
elif source["type"] == "tactical":
temp_file = Path(tempfile.gettempdir()) / "temp_tactical.txt"
else:
continue
if temp_file.exists():
with open(temp_file, "r") as f:
for line in f:
fen = line.strip()
if fen:
all_fens.add(fen)
with open(combined_fen_file, "w") as f:
for fen in all_fens:
f.write(fen + "\n")
console.print(f"[green]✓ Combined {len(all_fens):,} unique unlabeled positions[/green]")
console.print("\n[bold cyan]Step 2b: Labeling with Stockfish[/bold cyan]")
success = label_positions_with_stockfish(
str(combined_fen_file),
str(labeled_file),
stockfish_path,
depth=stockfish_depth,
num_workers=num_workers,
)
if not success:
console.print("[red]✗ Stockfish labeling failed[/red]")
return
console.print("[green]✓ Positions labeled[/green]")
# Extend dataset
console.print("\n[bold cyan]Step 3: Extending Dataset[/bold cyan]")
success = extend_dataset(
version=version,
new_labeled_path=str(labeled_file),
new_source_entry={
"type": "merged_sources",
"count": combined_count,
"sources": sources,
}
)
if success:
metadata = load_dataset_metadata(version)
console.print(f"[green]✓ Dataset extended[/green]")
console.print(f"[bold]Total positions: {metadata['total_positions']:,}[/bold]")
else:
console.print("[red]✗ Extension failed[/red]")
Prompt.ask("\nPress Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def delete_dataset_interactive():
"""Interactive dataset deletion."""
console = Console()
show_header()
console.print("\n[bold cyan]⚠️ Delete Dataset[/bold cyan]")
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets to delete[/yellow]")
Prompt.ask("Press Enter to continue")
return
show_datasets_table(console)
version = int(Prompt.ask("\nEnter dataset version to delete (e.g., 1)"))
if not any(v == version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
if Confirm.ask(f"Delete ds_v{version}? This cannot be undone.", default=False):
if delete_dataset(version):
console.print(f"[green]✓ Dataset ds_v{version} deleted[/green]")
else:
console.print("[red]✗ Deletion failed[/red]")
Prompt.ask("Press Enter to continue")
def training_menu():
"""Training submenu."""
console = Console()
while True:
show_header()
console.print("\n[bold]Training[/bold]")
console.print("[cyan]1[/cyan] - Standard Training")
console.print("[cyan]2[/cyan] - Burst Training")
console.print("[cyan]3[/cyan] - View Model Checkpoints")
console.print("[cyan]4[/cyan] - Back")
choice = Prompt.ask("\nSelect option", choices=["1", "2", "3", "4"])
if choice == "1":
train_interactive()
elif choice == "2":
burst_train_interactive()
elif choice == "3":
show_header()
show_checkpoints_table()
Prompt.ask("\nPress Enter to continue")
elif choice == "4":
return
def train_interactive():
"""Interactive training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]📚 Standard Training Configuration[/bold cyan]")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("\n[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
# Checkpoint selection
available = list_checkpoints()
use_checkpoint = False
checkpoint_version = None
if available:
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
use_checkpoint = Confirm.ask("Start from an existing checkpoint?", default=False)
if use_checkpoint:
checkpoint_version = Prompt.ask(
"Enter checkpoint version",
default=str(max(available))
)
# Training parameters
epochs = int(Prompt.ask("Number of epochs", default="100"))
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
hidden_layers_str = Prompt.ask(
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
default=default_layers
)
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
early_stopping = None
if Confirm.ask("Enable early stopping?", default=False):
early_stopping = int(Prompt.ask("Patience (epochs)", default="5"))
arch_str = "".join(str(s) for s in [768] + hidden_sizes + [1])
# Confirm and start
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Architecture: {arch_str}")
console.print(f" Epochs: {epochs}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
if early_stopping:
console.print(f" Early stopping: Yes (patience: {early_stopping})")
else:
console.print(f" Early stopping: No")
if use_checkpoint:
console.print(f" Checkpoint: v{checkpoint_version}")
else:
console.print(f" Checkpoint: None (training from scratch)")
if not Confirm.ask("\nStart training?", default=True):
console.print("[yellow]Training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
# Execute training
weights_dir = get_weights_dir()
try:
console.print("\n[bold cyan]Training Model[/bold cyan]")
checkpoint = None
if use_checkpoint:
checkpoint = str(weights_dir / f"nnue_weights_v{checkpoint_version}.pt")
train_nnue(
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
epochs=epochs,
batch_size=batch_size,
checkpoint=checkpoint,
use_versioning=True,
early_stopping_patience=early_stopping,
subsample_ratio=subsample_ratio,
hidden_sizes=hidden_sizes,
)
console.print("[green]✓ Training complete[/green]")
# Show result
available = list_checkpoints()
new_version = max(available) if available else 1
console.print(f"\n[bold green]✓ Training successful![/bold green]")
console.print(f"[bold]New checkpoint: v{new_version}[/bold]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def burst_train_interactive():
"""Interactive burst training menu."""
console = Console()
show_header()
console.print("\n[bold cyan]⚡ Burst Training Configuration[/bold cyan]")
console.print("[dim]Repeatedly restarts from the best checkpoint until the time budget expires.[/dim]\n")
# Dataset selection
datasets = list_datasets()
if not datasets:
console.print("[red]✗ No datasets available. Create one first.[/red]")
Prompt.ask("Press Enter to continue")
return
console.print("[bold]Available Datasets:[/bold]")
show_datasets_table(console)
dataset_version = int(Prompt.ask("\nEnter dataset version to train on (e.g., 1)"))
if not any(v == dataset_version for v, _ in datasets):
console.print("[red]✗ Dataset not found[/red]")
return
labeled_file = get_dataset_labeled_path(dataset_version)
if not labeled_file:
console.print("[red]✗ Dataset labeled.jsonl not found[/red]")
return
duration_minutes = float(Prompt.ask("Training budget (minutes)", default="60"))
epochs_per_season = int(Prompt.ask("Max epochs per season", default="50"))
early_stopping_patience = int(Prompt.ask("Early stopping patience (epochs)", default="10"))
# Optional initial checkpoint
available = list_checkpoints()
checkpoint = None
if available:
console.print(f"\n[dim]Available checkpoints: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
if Confirm.ask("Start from an existing checkpoint?", default=False):
version = Prompt.ask("Enter checkpoint version", default=str(max(available)))
checkpoint = str(get_weights_dir() / f"nnue_weights_v{version}.pt")
# Training hyperparameters
batch_size = int(Prompt.ask("Batch size", default="16384"))
subsample_ratio = float(Prompt.ask("Stochastic subsample ratio per epoch (1.0 = all data)", default="1.0"))
default_layers = ",".join(str(s) for s in DEFAULT_HIDDEN_SIZES)
hidden_layers_str = Prompt.ask(
"Hidden layer sizes (comma-separated, e.g. 1536,1024,512,256)",
default=default_layers
)
hidden_sizes = [int(x.strip()) for x in hidden_layers_str.split(",") if x.strip()]
arch_str = "".join(str(s) for s in [768] + hidden_sizes + [1])
# Summary
console.print("\n[bold]Configuration Summary:[/bold]")
console.print(f" Dataset: ds_v{dataset_version}")
console.print(f" Architecture: {arch_str}")
console.print(f" Duration: {duration_minutes:.0f} minutes")
console.print(f" Epochs per season: {epochs_per_season}")
console.print(f" Patience: {early_stopping_patience}")
console.print(f" Batch size: {batch_size}")
console.print(f" Subsample ratio: {subsample_ratio:.0%}")
console.print(f" Checkpoint: {checkpoint or 'None (from scratch)'}")
if not Confirm.ask("\nStart burst training?", default=True):
console.print("[yellow]Burst training cancelled[/yellow]")
Prompt.ask("Press Enter to continue")
return
weights_dir = get_weights_dir()
try:
console.print("\n[bold cyan]Burst Training[/bold cyan]")
burst_train(
data_file=str(labeled_file),
output_file=str(weights_dir / "nnue_weights.pt"),
duration_minutes=duration_minutes,
epochs_per_season=epochs_per_season,
early_stopping_patience=early_stopping_patience,
batch_size=batch_size,
initial_checkpoint=checkpoint,
use_versioning=True,
subsample_ratio=subsample_ratio,
hidden_sizes=hidden_sizes,
)
console.print("[green]✓ Burst training complete[/green]")
available = list_checkpoints()
if available:
console.print(f"[bold]Latest checkpoint: v{max(available)}[/bold]")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def export_interactive():
"""Interactive export menu."""
console = Console()
show_header()
console.print("\n[bold cyan]📦 Export Configuration[/bold cyan]")
# Select weights version
available = list_checkpoints()
if not available:
console.print("[red]✗ No checkpoints available to export[/red]")
Prompt.ask("Press Enter to continue")
return
console.print(f"[dim]Available versions: {', '.join([f'v{v}' for v in sorted(available)])}[/dim]")
version = Prompt.ask("Enter version to export (e.g., 2)")
weights_file = f"nnue_weights_v{version}.pt"
output_file = str(Path(__file__).parent.parent / "src" / "main" / "resources" / "nnue_weights.nbai")
console.print(f"\n[bold]Export Configuration:[/bold]")
console.print(f" Source: {weights_file}")
console.print(f" Destination: {output_file}")
if not Confirm.ask("\nExport weights?", default=True):
console.print("[yellow]Export cancelled[/yellow]")
return
try:
weights_dir = get_weights_dir()
weights_path = weights_dir / weights_file
if not weights_path.exists():
console.print(f"[red]✗ {weights_file} not found[/red]")
return
console.print("\n[bold cyan]Exporting Weights[/bold cyan]")
export_to_nbai(str(weights_path), output_file)
console.print(f"\n[green]✓ Export complete![/green]")
console.print(f"[bold]Weights saved to:[/bold] {output_file}")
Prompt.ask("Press Enter to continue")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
import traceback
traceback.print_exc()
Prompt.ask("Press Enter to continue")
def main():
try:
show_main_menu()
return 0
except KeyboardInterrupt:
console = Console()
console.print("\n[yellow]Interrupted by user[/yellow]")
return 1
except Exception as e:
console = Console()
console.print(f"[red]Error:[/red] {e}")
import traceback
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(main())
+6
View File
@@ -0,0 +1,6 @@
chess==1.11.2
torch==2.11.0
tqdm==4.67.3
numpy==2.4.4
rich==13.7.0
zstandard==0.23.0
+66
View File
@@ -0,0 +1,66 @@
@echo off
REM NNUE Training Pipeline for Windows
setlocal enabledelayedexpansion
echo.
echo === NNUE Training Pipeline ===
echo.
REM Get the directory where this script is located
set SCRIPT_DIR=%~dp0
cd /d "%SCRIPT_DIR%"
REM Step 1: Generate positions
echo Step 1: Generating 500,000 random positions...
python generate_positions.py positions.txt
if not exist positions.txt (
echo ERROR: positions.txt not created
exit /b 1
)
echo [OK] Positions generated
echo.
REM Step 2: Label positions with Stockfish
echo Step 2: Labeling positions with Stockfish (depth 12^)...
if "%STOCKFISH_PATH%"=="" (
set STOCKFISH_PATH=stockfish
)
python label_positions.py positions.txt training_data.jsonl "%STOCKFISH_PATH%"
if not exist training_data.jsonl (
echo ERROR: training_data.jsonl not created
exit /b 1
)
echo [OK] Positions labeled
echo.
REM Step 3: Train NNUE model
echo Step 3: Training NNUE model (20 epochs^)...
python train_nnue.py training_data.jsonl nnue_weights.pt
if not exist nnue_weights.pt (
echo ERROR: nnue_weights.pt not created
exit /b 1
)
echo [OK] Model trained
echo.
REM Step 4: Export weights to Scala
echo Step 4: Exporting weights to Scala...
python export_weights.py nnue_weights.pt ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala
if not exist ..\src\main\scala\de\nowchess\bot\bots\nnue\NNUEWeights.scala (
echo ERROR: NNUEWeights.scala not created
exit /b 1
)
echo [OK] Weights exported
echo.
echo === Pipeline Complete ===
echo.
echo Next steps:
echo 1. Navigate to project root: cd ..\..
echo 2. Compile: .\compile.bat
echo 3. Test: .\test.bat
echo.
endlocal
+39
View File
@@ -0,0 +1,39 @@
#!/bin/bash
# NNUE Training Pipeline (bash version)
# Uses the central CLI (nnue.py) for all operations
# Works on Linux, macOS, and Windows (with Git Bash or WSL)
set -e # Exit on error
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
# Use python or python3 (check which is available)
PYTHON_CMD="python3"
if ! command -v python3 &> /dev/null; then
PYTHON_CMD="python"
fi
echo "=== NNUE Training Pipeline ==="
echo ""
echo "Python command: $PYTHON_CMD"
echo "Working directory: $SCRIPT_DIR"
echo ""
# Run the unified training pipeline
$PYTHON_CMD nnue.py train
if [ $? -ne 0 ]; then
echo ""
echo "ERROR: Training pipeline failed"
exit 1
fi
echo ""
echo "=== Pipeline Complete ==="
echo ""
echo "Next steps:"
echo "1. Navigate to project root: cd ../.."
echo "2. Compile: ./compile"
echo "3. Test: ./test"
+287
View File
@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""Dataset versioning and management for NNUE training data."""
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
from rich.console import Console
from rich.table import Table
def get_datasets_dir() -> Path:
"""Get/create datasets directory."""
datasets_dir = Path(__file__).parent.parent / "datasets"
datasets_dir.mkdir(exist_ok=True)
return datasets_dir
def next_dataset_version() -> int:
"""Find the next available dataset version number."""
datasets_dir = get_datasets_dir()
versions = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
versions.append(v)
except (ValueError, IndexError):
pass
return max(versions) + 1 if versions else 1
def list_datasets() -> List[Tuple[int, Dict]]:
"""List all datasets with their metadata.
Returns:
List of (version, metadata_dict) tuples, sorted by version.
"""
datasets_dir = get_datasets_dir()
datasets = []
for d in datasets_dir.iterdir():
if d.is_dir() and d.name.startswith("ds_v"):
try:
v = int(d.name.split("_v")[1])
metadata_file = d / "metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
datasets.append((v, metadata))
except (ValueError, IndexError, json.JSONDecodeError):
pass
return sorted(datasets, key=lambda x: x[0])
def load_dataset_metadata(version: int) -> Optional[Dict]:
"""Load metadata for a specific dataset version.
Returns:
Metadata dict or None if not found.
"""
datasets_dir = get_datasets_dir()
metadata_file = datasets_dir / f"ds_v{version}" / "metadata.json"
if not metadata_file.exists():
return None
with open(metadata_file, 'r') as f:
return json.load(f)
def save_dataset_metadata(version: int, metadata: Dict) -> None:
"""Save metadata for a dataset version."""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
metadata_file = dataset_dir / "metadata.json"
with open(metadata_file, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
def create_dataset(
version: int,
labeled_jsonl_path: str,
sources: List[Dict],
stockfish_depth: int = 12
) -> Path:
"""Create a new versioned dataset.
Args:
version: Dataset version number
labeled_jsonl_path: Path to labeled.jsonl to copy
sources: List of source dicts (see plan for schema)
stockfish_depth: Depth used for labeling
Returns:
Path to the created dataset directory.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
dataset_dir.mkdir(exist_ok=True)
# Copy labeled data with deduplication (in case source has duplicates)
source_path = Path(labeled_jsonl_path)
if source_path.exists():
dest_path = dataset_dir / "labeled.jsonl"
seen_fens = set()
unique_count = 0
with open(source_path, 'r') as src, open(dest_path, 'w') as dst:
for line in src:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in seen_fens:
dst.write(line)
seen_fens.add(fen)
unique_count += 1
except json.JSONDecodeError:
# Skip malformed lines
pass
# Count positions
total_positions = 0
if (dataset_dir / "labeled.jsonl").exists():
with open(dataset_dir / "labeled.jsonl", 'r') as f:
total_positions = sum(1 for _ in f)
# Create metadata
metadata = {
"version": version,
"created": datetime.now().isoformat(),
"total_positions": total_positions,
"stockfish_depth": stockfish_depth,
"sources": sources
}
save_dataset_metadata(version, metadata)
return dataset_dir
def extend_dataset(
version: int,
new_labeled_path: str,
new_source_entry: Dict
) -> bool:
"""Extend an existing dataset with new labeled positions (with deduplication).
Args:
version: Dataset version to extend
new_labeled_path: Path to new labeled.jsonl to merge
new_source_entry: Source entry to add to metadata
Returns:
True if successful, False otherwise.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
labeled_file = dataset_dir / "labeled.jsonl"
new_labeled_file = Path(new_labeled_path)
if not new_labeled_file.exists():
return False
# Load existing FENs (dedup set) — must load entire file to avoid duplicates
existing_fens = set()
if labeled_file.exists():
with open(labeled_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data.get('fen')
if fen:
existing_fens.add(fen)
except json.JSONDecodeError:
pass
# Merge new positions, skipping duplicates
new_count = 0
new_lines = []
with open(new_labeled_file, 'r') as f_new:
for line in f_new:
try:
data = json.loads(line)
fen = data.get('fen')
if fen and fen not in existing_fens:
new_lines.append(line)
existing_fens.add(fen)
new_count += 1
except json.JSONDecodeError:
pass
# Append only the new, unique positions
if new_lines:
with open(labeled_file, 'a') as f_append:
for line in new_lines:
f_append.write(line)
# Update metadata
metadata = load_dataset_metadata(version)
if metadata:
# Count total positions
total_positions = 0
with open(labeled_file, 'r') as f:
total_positions = sum(1 for _ in f)
metadata['total_positions'] = total_positions
# Update the source entry with actual count of new positions added
new_source_entry['actual_count'] = new_count
metadata['sources'].append(new_source_entry)
save_dataset_metadata(version, metadata)
return True
def get_dataset_labeled_path(version: int) -> Optional[Path]:
"""Get the path to a dataset's labeled.jsonl file.
Returns:
Path to labeled.jsonl or None if dataset doesn't exist.
"""
datasets_dir = get_datasets_dir()
labeled_file = datasets_dir / f"ds_v{version}" / "labeled.jsonl"
if labeled_file.exists():
return labeled_file
return None
def delete_dataset(version: int) -> bool:
"""Delete a dataset (recursively removes directory).
Args:
version: Dataset version to delete
Returns:
True if successful.
"""
datasets_dir = get_datasets_dir()
dataset_dir = datasets_dir / f"ds_v{version}"
if not dataset_dir.exists():
return False
import shutil
shutil.rmtree(dataset_dir)
return True
def show_datasets_table(console: Console = None) -> None:
"""Display all datasets in a Rich table."""
if console is None:
console = Console()
datasets = list_datasets()
if not datasets:
console.print("[yellow] No datasets found yet[/yellow]")
return
table = Table(title="Available Datasets", show_header=True, header_style="bold cyan")
table.add_column("Version", style="dim")
table.add_column("Positions", justify="right")
table.add_column("Sources", justify="left")
table.add_column("Depth", justify="center")
table.add_column("Created", justify="left")
for v, metadata in datasets:
positions = metadata.get('total_positions', 0)
sources = metadata.get('sources', [])
source_str = ", ".join([s.get('type', '?') for s in sources])
depth = metadata.get('stockfish_depth', '?')
created = metadata.get('created', '?')
if created != '?':
created = created.split('T')[0] # Just the date
table.add_row(f"v{v}", f"{positions:,}", source_str, str(depth), created)
console.print(table)
+137
View File
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""Export NNUE weights to .nbai format for runtime loading."""
import json
import struct
import sys
from datetime import datetime
from pathlib import Path
import torch
MAGIC = 0x4942_414E # bytes 'N','B','A','I' as little-endian int32
VERSION = 1
def _read_sidecar(weights_file: str) -> dict:
sidecar = weights_file.replace(".pt", "_metadata.json")
if Path(sidecar).exists():
with open(sidecar) as f:
return json.load(f)
return {}
def _infer_layers(state_dict: dict) -> list[dict]:
"""Derive layer descriptors from state_dict weight shapes.
Assumes layers named l1, l2, ..., lN.
All hidden layers get activation 'relu'; the last gets 'linear'.
"""
names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
layers = []
for i, name in enumerate(names):
out_size, in_size = state_dict[f"{name}.weight"].shape
activation = "linear" if i == len(names) - 1 else "relu"
layers.append({"activation": activation, "inputSize": int(in_size), "outputSize": int(out_size)})
return layers
def _write_floats(f, tensor):
data = tensor.float().flatten().cpu().numpy()
f.write(struct.pack("<I", len(data)))
f.write(struct.pack(f"<{len(data)}f", *data))
def export_to_nbai(
weights_file: str,
output_file: str,
trained_by: str = "unknown",
train_loss: float = 0.0,
):
if not Path(weights_file).exists():
print(f"Error: weights file not found at {weights_file}")
sys.exit(1)
loaded = torch.load(weights_file, map_location="cpu")
state_dict = (
loaded["model_state_dict"]
if isinstance(loaded, dict) and "model_state_dict" in loaded
else loaded
)
sidecar = _read_sidecar(weights_file)
val_loss = float(loaded.get("best_val_loss", sidecar.get("final_val_loss", 0.0))) if isinstance(loaded, dict) else 0.0
trained_at = sidecar.get("date", datetime.now().isoformat())
training_data_count = int(sidecar.get("num_positions", 0))
metadata = {
"trainedBy": trained_by,
"trainedAt": trained_at,
"trainingDataCount": training_data_count,
"valLoss": val_loss,
"trainLoss": train_loss,
}
layers = _infer_layers(state_dict)
layer_names = sorted(
{k.split(".")[0] for k in state_dict if k.endswith(".weight")},
key=lambda n: int(n[1:]),
)
print(f"Architecture ({len(layers)} layers):")
for i, l in enumerate(layers):
print(f" l{i + 1}: {l['inputSize']} -> {l['outputSize']} [{l['activation']}]")
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
with open(output_file, "wb") as f:
# Header
f.write(struct.pack("<I", MAGIC))
f.write(struct.pack("<H", VERSION))
# Metadata (length-prefixed UTF-8 JSON)
meta_bytes = json.dumps(metadata, indent=2).encode("utf-8")
f.write(struct.pack("<I", len(meta_bytes)))
f.write(meta_bytes)
# Layer descriptors
f.write(struct.pack("<H", len(layers)))
for layer in layers:
name_bytes = layer["activation"].encode("ascii")
f.write(struct.pack("<B", len(name_bytes)))
f.write(name_bytes)
f.write(struct.pack("<I", layer["inputSize"]))
f.write(struct.pack("<I", layer["outputSize"]))
# Weights: weight tensor then bias tensor per layer
for name in layer_names:
w = state_dict[f"{name}.weight"]
b = state_dict[f"{name}.bias"]
_write_floats(f, w)
_write_floats(f, b)
print(f" Wrote {name}: weight {tuple(w.shape)}, bias {tuple(b.shape)}")
size_mb = Path(output_file).stat().st_size / (1024 ** 2)
print(f"\nExported to {output_file} ({size_mb:.2f} MB)")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
if __name__ == "__main__":
weights_file = "nnue_weights.pt"
output_file = "../src/main/resources/nnue_weights.nbai"
trained_by = "unknown"
train_loss = 0.0
if len(sys.argv) > 1:
weights_file = sys.argv[1]
if len(sys.argv) > 2:
output_file = sys.argv[2]
if len(sys.argv) > 3:
trained_by = sys.argv[3]
if len(sys.argv) > 4:
train_loss = float(sys.argv[4])
export_to_nbai(weights_file, output_file, trained_by, train_loss)
+171
View File
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""Generate random chess positions for NNUE training with multiprocessing."""
import chess
import random
import sys
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool, Queue
from datetime import datetime
import time
def _worker_generate_games(worker_id, games_per_worker, samples_per_game, min_move, max_move):
"""Generate games for one worker.
Returns:
list of FENs generated by this worker
"""
positions = []
for game_num in range(games_per_worker):
board = chess.Board()
move_history = []
# Play a complete random game
while not board.is_game_over() and len(move_history) < 200:
legal_moves = list(board.legal_moves)
if not legal_moves:
break
move = random.choice(legal_moves)
board.push(move)
move_history.append(board.copy())
# Determine the range of moves to sample from
game_length = len(move_history)
valid_start = max(min_move, 0)
valid_end = min(max_move, game_length)
if valid_start >= valid_end:
continue
# Randomly sample positions from this game
sample_count = min(samples_per_game, valid_end - valid_start)
if sample_count > 0:
sample_indices = random.sample(
range(valid_start, valid_end),
k=sample_count
)
for idx in sample_indices:
sampled_board = move_history[idx]
# Only filter truly invalid or terminal positions
if not sampled_board.is_valid() or sampled_board.is_game_over():
continue
# Save position (include check, captures, all positions)
fen = sampled_board.fen()
positions.append(fen)
return positions
def play_random_game_and_collect_positions(
output_file,
total_positions=3000000,
samples_per_game=1,
min_move=1,
max_move=50,
num_workers=8
):
"""Generate positions using multiprocessing with multiple workers.
Args:
output_file: Output file for positions
total_positions: Target number of positions to generate
samples_per_game: Number of positions to sample per game (1-N)
min_move: Minimum move number to start sampling from
max_move: Maximum move number for sampling
num_workers: Number of parallel worker processes
Returns:
Number of valid positions saved
"""
# Estimate games needed (roughly 1 position per game on average)
total_games = max(total_positions // samples_per_game, num_workers)
games_per_worker = total_games // num_workers
print(f"Generating {total_positions:,} positions using {num_workers} workers")
print(f"Total games: ~{total_games:,} ({games_per_worker:,} per worker)")
print()
start_time = datetime.now()
# Generate positions in parallel
worker_tasks = [
(i, games_per_worker, samples_per_game, min_move, max_move)
for i in range(num_workers)
]
positions_count = 0
all_positions = []
with Pool(num_workers) as pool:
with tqdm(total=num_workers, desc="Workers generating games") as pbar:
for positions in pool.starmap(_worker_generate_games, worker_tasks):
all_positions.extend(positions)
positions_count += len(positions)
pbar.update(1)
# Write all positions to file
print(f"Writing {positions_count:,} positions to {output_file}...")
with open(output_file, 'w') as f:
for fen in all_positions:
f.write(fen + '\n')
elapsed_time = datetime.now() - start_time
elapsed_seconds = elapsed_time.total_seconds()
positions_per_second = positions_count / elapsed_seconds if elapsed_seconds > 0 else 0
# Print summary
print()
print("=" * 60)
print("POSITION GENERATION SUMMARY")
print("=" * 60)
print(f"Target positions: {total_positions:,}")
print(f"Actual positions saved: {positions_count:,}")
print(f"Workers: {num_workers}")
print(f"Games per worker: {games_per_worker:,}")
print(f"Samples per game: {samples_per_game}")
print(f"Move range: {min_move}-{max_move}")
print(f"Elapsed time: {elapsed_time}")
print(f"Throughput: {positions_per_second:.0f} positions/second")
print("=" * 60)
print()
if positions_count == 0:
print("WARNING: No valid positions were generated!")
return 0
return positions_count
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate random chess positions for NNUE training")
parser.add_argument("output_file", nargs="?", default="positions.txt",
help="Output file for positions (default: positions.txt)")
parser.add_argument("--positions", type=int, default=3000000,
help="Target number of positions to generate (default: 3000000)")
parser.add_argument("--samples-per-game", type=int, default=1,
help="Number of positions to sample per game (default: 1)")
parser.add_argument("--min-move", type=int, default=1,
help="Minimum move number to sample from (default: 1)")
parser.add_argument("--max-move", type=int, default=50,
help="Maximum move number to sample from (default: 50)")
parser.add_argument("--workers", type=int, default=8,
help="Number of parallel worker processes (default: 8)")
args = parser.parse_args()
count = play_random_game_and_collect_positions(
output_file=args.output_file,
total_positions=args.positions,
samples_per_game=args.samples_per_game,
min_move=args.min_move,
max_move=args.max_move,
num_workers=args.workers
)
sys.exit(0 if count > 0 else 1)
+326
View File
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""Label positions with Stockfish evaluations and analyze distribution."""
import json
import chess.engine
import sys
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool
from functools import partial
def normalize_evaluation(cp_value, method='tanh', scale=300.0):
"""Normalize centipawn evaluation to a bounded range.
Args:
cp_value: Centipawn evaluation from Stockfish
method: 'tanh' (default) or 'sigmoid'
scale: Scale factor (tanh: 300 is typical)
Returns:
Normalized value in approximately [-1, 1] (tanh) or [0, 1] (sigmoid)
"""
if method == 'tanh':
return np.tanh(cp_value / scale)
elif method == 'sigmoid':
return 1.0 / (1.0 + np.exp(-cp_value / scale))
else:
return cp_value / 100.0
def _evaluate_fen_batch(args):
"""Worker function to evaluate a batch of FENs with Stockfish threading.
Args:
args: tuple of (fens, stockfish_path, depth, normalize)
Returns:
list of (fen, eval_normalized, eval_raw) tuples
"""
fens, stockfish_path, depth, normalize = args
results = []
try:
engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
except Exception:
return []
try:
for fen in fens:
try:
board = chess.Board(fen)
if not board.is_valid():
continue
info = engine.analyse(board, chess.engine.Limit(depth=depth))
if info.get('score') is None:
continue
score = info['score'].white()
if score.is_mate():
eval_cp = 2000 if score.mate() > 0 else -2000
else:
eval_cp = score.cp
eval_cp = max(-2000, min(2000, eval_cp))
eval_normalized = normalize_evaluation(eval_cp) if normalize else eval_cp
results.append((fen, eval_normalized, eval_cp))
except Exception:
continue
finally:
engine.quit()
return results
def label_positions_with_stockfish(positions_file, output_file, stockfish_path, batch_size=1000, depth=12, verbose=False, normalize=True, num_workers=1):
"""Read positions and label them with Stockfish evaluations.
Args:
positions_file: Path to positions.txt
output_file: Path to training_data.jsonl
stockfish_path: Path to stockfish binary
batch_size: Batch size for processing (positions per worker task, default: 1000)
depth: Stockfish depth
verbose: Print detailed error messages
normalize: If True, normalize evals using tanh
num_workers: Number of parallel Stockfish processes
"""
# Check if stockfish exists
if not Path(stockfish_path).exists():
print(f"Error: Stockfish not found at {stockfish_path}")
print(f"Tried: {stockfish_path}")
print(f"Set STOCKFISH_PATH environment variable or pass as argument")
sys.exit(1)
print(f"Using Stockfish: {stockfish_path}")
print(f"Number of workers: {num_workers}")
# Check if positions file exists
if not Path(positions_file).exists():
print(f"Error: Positions file not found at {positions_file}")
sys.exit(1)
# Load existing evaluations if resuming
evaluated_fens = set()
position_count = 0
if Path(output_file).exists():
with open(output_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
evaluated_fens.add(data['fen'])
position_count += 1
except json.JSONDecodeError:
pass
print(f"Resuming from {position_count} already evaluated positions")
# Load all FENs that need evaluation
fens_to_evaluate = []
fens_seen_in_batch = set() # Track duplicates within current batch
skipped_invalid = 0
skipped_duplicate = 0
with open(positions_file, 'r') as f:
for fen in f:
fen = fen.strip()
if not fen:
skipped_invalid += 1
continue
if fen in evaluated_fens:
skipped_duplicate += 1
continue
if fen in fens_seen_in_batch:
skipped_duplicate += 1
continue
fens_to_evaluate.append(fen)
fens_seen_in_batch.add(fen)
total_to_evaluate = len(fens_to_evaluate)
total_lines = position_count + skipped_duplicate + skipped_invalid + total_to_evaluate
if total_to_evaluate == 0:
if position_count == 0:
print(f"Error: No valid positions to evaluate in {positions_file}")
sys.exit(1)
else:
print(f"All positions already evaluated. No new positions to process.")
return True
print(f"Total positions to process: {total_lines}")
print(f"New positions to evaluate: {total_to_evaluate}")
print(f"Using depth: {depth}")
print()
# Split FENs into batches for workers
batches = []
for i in range(0, total_to_evaluate, batch_size):
batch = fens_to_evaluate[i:i+batch_size]
batches.append((batch, stockfish_path, depth, normalize))
# Process batches in parallel
evaluated = 0
errors = 0
raw_evals = []
normalized_evals = []
import time
start_time = time.time()
with Pool(num_workers) as pool:
with tqdm(total=total_lines, initial=position_count, desc="Labeling positions") as pbar:
with open(output_file, 'a') as out:
for batch_idx, batch_results in enumerate(pool.imap_unordered(_evaluate_fen_batch, batches)):
for fen, eval_normalized, eval_cp in batch_results:
# Skip if already evaluated in output file during this run
if fen in evaluated_fens:
continue
data = {"fen": fen, "eval": eval_normalized, "eval_raw": eval_cp}
out.write(json.dumps(data) + '\n')
evaluated_fens.add(fen) # Track as evaluated
evaluated += 1
raw_evals.append(eval_cp)
normalized_evals.append(eval_normalized)
pbar.update(1)
# Update progress for any failed evaluations in the batch
batch_size_actual = len(batches[0][0]) if batches else batch_size
failed = batch_size_actual - len(batch_results)
if failed > 0:
errors += failed
pbar.update(failed)
# Calculate and show throughput and ETA
elapsed = time.time() - start_time
throughput = evaluated / elapsed if elapsed > 0 else 0
remaining_positions = total_to_evaluate - evaluated
eta_seconds = remaining_positions / throughput if throughput > 0 else 0
eta_str = f"{int(eta_seconds // 60)}:{int(eta_seconds % 60):02d}"
if (batch_idx + 1) % max(1, len(batches) // 10) == 0:
pbar.set_postfix({
'rate': f'{throughput:.0f} pos/s',
'eta': eta_str
})
# Print summary and analysis
print()
print("=" * 60)
print("LABELING SUMMARY")
print("=" * 60)
print(f"Successfully evaluated: {evaluated}")
print(f"Skipped (duplicates): {skipped_duplicate}")
print(f"Skipped (invalid): {skipped_invalid}")
print(f"Errors: {errors}")
print(f"Total processed: {evaluated + skipped_duplicate + skipped_invalid + errors}")
print("=" * 60)
print()
if evaluated == 0:
print("WARNING: No positions were successfully evaluated!")
print("Check that:")
print(" 1. positions.txt is not empty")
print(" 2. positions.txt contains valid FENs")
print(" 3. Stockfish is installed and working")
print(" 4. Stockfish path is correct")
return False
# Print distribution analysis
if raw_evals:
raw_evals_arr = np.array(raw_evals)
norm_evals_arr = np.array(normalized_evals)
print("=" * 60)
print("EVALUATION DISTRIBUTION ANALYSIS")
print("=" * 60)
print()
print("Raw Evaluations (centipawns):")
print(f" Min: {raw_evals_arr.min():.1f}")
print(f" Max: {raw_evals_arr.max():.1f}")
print(f" Mean: {raw_evals_arr.mean():.1f}")
print(f" Median: {np.median(raw_evals_arr):.1f}")
print(f" Std: {raw_evals_arr.std():.1f}")
print()
print("Normalized Evaluations (tanh):")
print(f" Min: {norm_evals_arr.min():.4f}")
print(f" Max: {norm_evals_arr.max():.4f}")
print(f" Mean: {norm_evals_arr.mean():.4f}")
print(f" Median: {np.median(norm_evals_arr):.4f}")
print(f" Std: {norm_evals_arr.std():.4f}")
print()
# Distribution buckets
print("Raw Evaluation Buckets (counts):")
buckets = [
(-float('inf'), -500, "< -5.00"),
(-500, -300, "[-5.00, -3.00)"),
(-300, -100, "[-3.00, -1.00)"),
(-100, 0, "[-1.00, 0.00)"),
(0, 100, "[0.00, 1.00)"),
(100, 300, "[1.00, 3.00)"),
(300, 500, "[3.00, 5.00)"),
(500, float('inf'), "> 5.00"),
]
for low, high, label in buckets:
count = np.sum((raw_evals_arr > low) & (raw_evals_arr <= high))
pct = 100.0 * count / len(raw_evals_arr)
print(f" {label}: {count:6d} ({pct:5.1f}%)")
print("=" * 60)
print()
print(f"✓ Labeling complete. Output saved to {output_file}")
return True
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Label chess positions with Stockfish evaluations")
parser.add_argument("positions_file", nargs="?", default="positions.txt",
help="Input positions file (default: positions.txt)")
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
help="Output file (default: training_data.jsonl)")
parser.add_argument("stockfish_path", nargs="?", default=None,
help="Path to Stockfish binary (default: $STOCKFISH_PATH or 'stockfish')")
parser.add_argument("--depth", type=int, default=12,
help="Stockfish depth (default: 12)")
parser.add_argument("--batch-size", type=int, default=1000,
help="Batch size for processing (default: 1000)")
parser.add_argument("--no-normalize", action="store_true",
help="Disable evaluation normalization (keep raw centipawns)")
parser.add_argument("--verbose", action="store_true",
help="Print detailed error messages")
parser.add_argument("--workers", type=int, default=1,
help="Number of parallel Stockfish processes (default: 1)")
args = parser.parse_args()
# Determine Stockfish path
stockfish_path = args.stockfish_path or os.environ.get("STOCKFISH_PATH", "stockfish")
success = label_positions_with_stockfish(
positions_file=args.positions_file,
output_file=args.output_file,
stockfish_path=stockfish_path,
batch_size=args.batch_size,
depth=args.depth,
normalize=not args.no_normalize,
verbose=args.verbose,
num_workers=args.workers
)
sys.exit(0 if success else 1)
+208
View File
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""Import pre-labeled positions from the Lichess evaluation database.
Source: https://database.lichess.org/#evals
Format: lichess_db_eval.jsonl.zst — compressed JSONL, one position per line.
Each line:
{
"fen": "<pieces> <turn> <castling> <ep>",
"evals": [
{
"knodes": <int>,
"depth": <int>,
"pvs": [{"cp": <int>, "line": "..."} | {"mate": <int>, "line": "..."}]
}
]
}
cp and mate are from White's perspective (positive = White winning), matching
the sign convention used by label.py (score.white()) and expected by train.py.
"""
import json
import sys
import numpy as np
from pathlib import Path
from tqdm import tqdm
MATE_CP = 20000
SCALE = 300.0
def _best_eval(evals: list) -> dict | None:
"""Return the highest-depth evaluation entry, using knodes as tiebreaker."""
if not evals:
return None
return max(evals, key=lambda e: (e.get("depth", 0), e.get("knodes", 0)))
def _cp_from_pv(pv: dict) -> int | None:
"""Extract centipawn value from a principal variation entry."""
if "cp" in pv:
return max(-MATE_CP, min(MATE_CP, pv["cp"]))
if "mate" in pv:
return MATE_CP if pv["mate"] > 0 else -MATE_CP
return None
def _normalize(cp: int) -> float:
return float(np.tanh(cp / SCALE))
def import_lichess_evals(
input_path: str,
output_file: str,
max_positions: int | None = None,
min_depth: int = 0,
verbose: bool = False,
) -> int:
"""Stream the Lichess eval database and write a labeled.jsonl file.
Args:
input_path: Path to lichess_db_eval.jsonl.zst (or uncompressed .jsonl).
output_file: Destination labeled.jsonl (appended — supports resuming).
max_positions: Stop after this many new positions (None = no limit).
min_depth: Skip positions whose best eval has depth < min_depth.
verbose: Print warnings for skipped lines.
Returns:
Number of new positions written.
"""
import zstandard as zstd
input_path = Path(input_path)
if not input_path.exists():
print(f"Error: {input_path} not found")
sys.exit(1)
# Resume: collect already-written FENs so we skip duplicates.
seen_fens: set[str] = set()
if Path(output_file).exists():
with open(output_file, "r") as f:
for line in f:
try:
seen_fens.add(json.loads(line)["fen"])
except (json.JSONDecodeError, KeyError):
pass
if seen_fens:
print(f"Resuming — skipping {len(seen_fens):,} already-imported positions")
written = 0
skipped_depth = 0
skipped_no_eval = 0
skipped_dup = 0
def iter_lines():
"""Yield decoded text lines from either a .zst or plain .jsonl file."""
import io
if input_path.suffix == ".zst":
dctx = zstd.ZstdDecompressor()
with open(input_path, "rb") as fh:
with dctx.stream_reader(fh) as reader:
text_stream = io.TextIOWrapper(reader, encoding="utf-8")
yield from text_stream
else:
with open(input_path, "r", encoding="utf-8") as fh:
yield from fh
try:
with open(output_file, "a") as out:
with tqdm(desc="Importing Lichess evals", unit=" pos") as pbar:
for raw_line in iter_lines():
line = raw_line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
if verbose:
print("Warning: malformed JSON line skipped")
continue
fen = data.get("fen", "")
if not fen:
skipped_no_eval += 1
continue
if fen in seen_fens:
skipped_dup += 1
continue
best = _best_eval(data.get("evals", []))
if best is None:
skipped_no_eval += 1
continue
if best.get("depth", 0) < min_depth:
skipped_depth += 1
continue
pvs = best.get("pvs", [])
if not pvs:
skipped_no_eval += 1
continue
cp = _cp_from_pv(pvs[0])
if cp is None:
skipped_no_eval += 1
continue
record = {
"fen": fen,
"eval": _normalize(cp),
"eval_raw": cp,
}
out.write(json.dumps(record) + "\n")
seen_fens.add(fen)
written += 1
pbar.update(1)
if max_positions and written >= max_positions:
print(f"\nReached max_positions limit ({max_positions:,})")
break
except Exception:
raise
print()
print("=" * 60)
print("LICHESS IMPORT SUMMARY")
print("=" * 60)
print(f"Positions written: {written:,}")
print(f"Skipped (dup): {skipped_dup:,}")
print(f"Skipped (no eval): {skipped_no_eval:,}")
print(f"Skipped (depth<{min_depth}): {skipped_depth:,}")
print("=" * 60)
print(f"\n✓ Output: {output_file}")
return written
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Import Lichess pre-labeled positions into labeled.jsonl"
)
parser.add_argument("input_path",
help="Path to lichess_db_eval.jsonl.zst")
parser.add_argument("output_file", nargs="?", default="training_data.jsonl",
help="Output labeled.jsonl (default: training_data.jsonl)")
parser.add_argument("--max-positions", type=int, default=None,
help="Stop after N positions (default: no limit)")
parser.add_argument("--min-depth", type=int, default=0,
help="Minimum eval depth to accept (default: 0)")
parser.add_argument("--verbose", action="store_true",
help="Print warnings for skipped lines")
args = parser.parse_args()
count = import_lichess_evals(
input_path=args.input_path,
output_file=args.output_file,
max_positions=args.max_positions,
min_depth=args.min_depth,
verbose=args.verbose,
)
sys.exit(0 if count > 0 else 1)
@@ -0,0 +1,249 @@
import chess
import csv
import json
import sys
import urllib.request
from pathlib import Path
from typing import Set, Tuple
try:
import zstandard as zstd
except ImportError:
print("zstandard library not found. Install with: pip install zstandard")
sys.exit(1)
from generate import play_random_game_and_collect_positions
def download_and_extract_puzzle_db(
url: str = 'https://database.lichess.org/lichess_db_puzzle.csv.zst',
output_dir: str = 'tactical_data'
):
"""Download and extract the Lichess puzzle database."""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
csv_file = output_path / 'lichess_db_puzzle.csv'
zst_file = output_path / 'lichess_db_puzzle.csv.zst'
# Download if not already present
if not zst_file.exists():
print(f"Downloading puzzle database from {url}...")
try:
urllib.request.urlretrieve(url, zst_file)
print(f"Downloaded to {zst_file}")
except Exception as e:
print(f"Failed to download: {e}")
return None
# Extract if CSV doesn't exist
if not csv_file.exists():
print(f"Extracting {zst_file}...")
try:
with open(zst_file, 'rb') as f:
dctx = zstd.ZstdDecompressor()
with dctx.stream_reader(f) as reader:
with open(csv_file, 'wb') as out:
out.write(reader.read())
print(f"Extracted to {csv_file}")
except Exception as e:
print(f"Failed to extract: {e}")
return None
return str(csv_file)
def extract_puzzle_positions(
puzzle_csv: str,
max_puzzles: int = 300_000
) -> Set[str]:
"""
Extract the position BEFORE the blunder from each puzzle.
This is exactly the type of position where tactical
recognition matters most.
Returns a set of unique FENs.
"""
positions = set()
with open(puzzle_csv) as f:
reader = csv.DictReader(f)
for row in reader:
if len(positions) >= max_puzzles:
break
try:
board = chess.Board(row['FEN'])
# The puzzle FEN is AFTER the blunder move
# We want the position BEFORE — so it learns
# to find the tactic, not just play it
moves = row['Moves'].split()
# Undo one move to get pre-tactic position
board.push_uci(moves[0]) # opponent blunder
fen = board.fen()
# Filter for useful tactical themes
themes = row.get('Themes', '')
useful = any(t in themes for t in [
'fork', 'pin', 'skewer', 'discoveredAttack',
'mate', 'mateIn2', 'mateIn3', 'hangingPiece',
'trappedPiece', 'sacrifice'
])
if useful:
positions.add(fen)
except Exception:
continue
return positions
def load_positions_from_file(file_path: str) -> Set[str]:
"""Load positions from a text file (one FEN per line)."""
positions = set()
try:
with open(file_path) as f:
for line in f:
line = line.strip()
if line:
positions.add(line)
print(f"Loaded {len(positions)} positions from {file_path}")
return positions
except Exception as e:
print(f"Failed to load from {file_path}: {e}")
return set()
def merge_positions(
tactical: Set[str],
other: Set[str],
output_file: str = 'position.txt'
):
"""Merge two position sets and write to file."""
merged = tactical | other
with open(output_file, 'w') as f:
for fen in merged:
f.write(fen + '\n')
overlap = len(tactical & other)
print(f"\n{'='*60}")
print(f"MERGE SUMMARY")
print(f"{'='*60}")
print(f"Tactical positions: {len(tactical):,}")
print(f"Other positions: {len(other):,}")
print(f"Overlap (deduplicated): {overlap:,}")
print(f"Total merged positions: {len(merged):,}")
print(f"Written to: {output_file}")
print(f"{'='*60}\n")
def extract_tactical_only(
puzzle_csv: str,
output_file: str,
max_puzzles: int = 300_000
) -> int:
"""Extract tactical positions and save to file (no merge prompts).
Args:
puzzle_csv: Path to Lichess puzzle CSV
output_file: Where to save the FEN positions
max_puzzles: Maximum puzzles to extract
Returns:
Number of positions extracted
"""
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
with open(output_file, 'w') as f:
for fen in tactical_positions:
f.write(fen + '\n')
return len(tactical_positions)
def interactive_merge_positions(
puzzle_csv: str,
output_file: str = 'position.txt',
max_puzzles: int = 300_000
):
"""Interactive workflow: extract tactical positions and merge with user selection."""
print("\n" + "="*60)
print("TACTICAL POSITION EXTRACTOR & MERGER")
print("="*60 + "\n")
# Extract tactical positions
print("Extracting tactical positions from puzzle database...")
tactical_positions = extract_puzzle_positions(puzzle_csv, max_puzzles)
print(f"Extracted {len(tactical_positions):,} unique tactical positions\n")
# Ask what to merge with
print("What would you like to merge with these tactical positions?")
print("1. Load from a position file")
print("2. Generate random positions")
print("3. Skip merging (save tactical only)")
choice = input("\nEnter choice (1-3): ").strip()
other_positions = set()
if choice == '1':
file_path = input("Enter path to position file: ").strip()
other_positions = load_positions_from_file(file_path)
elif choice == '2':
positions_to_gen = input("How many positions to generate? (default 1000000): ").strip()
try:
positions_to_gen = int(positions_to_gen) if positions_to_gen else 1000000
except ValueError:
positions_to_gen = 1000000
temp_file = 'temp_generated_positions.txt'
print(f"\nGenerating {positions_to_gen:,} random positions...")
play_random_game_and_collect_positions(
output_file=temp_file,
total_positions=positions_to_gen,
samples_per_game=1,
min_move=1,
max_move=50,
num_workers=8
)
other_positions = load_positions_from_file(temp_file)
elif choice == '3':
print("Skipping merge, saving tactical positions only...")
else:
print("Invalid choice, saving tactical positions only...")
merge_positions(tactical_positions, other_positions, output_file)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description="Extract and merge tactical positions")
parser.add_argument("--url", default='https://database.lichess.org/lichess_db_puzzle.csv.zst',
help="URL to download puzzle database from")
parser.add_argument("--output-dir", default='trainingdata',
help="Directory to extract puzzle database to")
parser.add_argument("--max-puzzles", type=int, default=300_000,
help="Maximum puzzles to extract (default: 300000)")
parser.add_argument("--output-file", default='position.txt',
help="Output file for merged positions (default: position.txt)")
args = parser.parse_args()
# Download and extract
csv_path = download_and_extract_puzzle_db(args.url, args.output_dir)
if csv_path:
# Interactive merge
interactive_merge_positions(csv_path, args.output_file, args.max_puzzles)
else:
print("Failed to download/extract puzzle database")
sys.exit(1)
+676
View File
@@ -0,0 +1,676 @@
#!/usr/bin/env python3
"""Train NNUE neural network for chess evaluation."""
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import sys
from pathlib import Path
from tqdm import tqdm
import chess
from datetime import datetime, timedelta
import re
import numpy as np
class NNUEDataset(Dataset):
"""Dataset of chess positions with evaluations."""
def __init__(self, data_file):
self.positions = []
self.evals = []
self.evals_raw = []
self.is_normalized = None
with open(data_file, 'r') as f:
for line in f:
try:
data = json.loads(line)
fen = data['fen']
eval_val = data['eval']
self.positions.append(fen)
self.evals.append(eval_val)
# Check if normalized or raw
if self.is_normalized is None:
# If eval is in range [-1, 1], assume normalized
self.is_normalized = abs(eval_val) <= 1.0
# Store raw if available
if 'eval_raw' in data:
self.evals_raw.append(data['eval_raw'])
else:
self.evals_raw.append(eval_val)
except (json.JSONDecodeError, KeyError):
pass
def __len__(self):
return len(self.positions)
def __getitem__(self, idx):
fen = self.positions[idx]
eval_val = self.evals[idx]
features = fen_to_features(fen)
# Use evaluation as-is if normalized, otherwise apply sigmoid scaling
if self.is_normalized:
target = torch.tensor(eval_val, dtype=torch.float32)
else:
target = torch.sigmoid(torch.tensor(eval_val / 400.0, dtype=torch.float32))
return features, target
def fen_to_features(fen):
"""Convert FEN to 768-dimensional binary feature vector."""
# Piece type to index: pawn=0, knight=1, bishop=2, rook=3, queen=4, king=5
piece_to_idx = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}
features = torch.zeros(768, dtype=torch.float32)
try:
board = chess.Board(fen)
# 12 piece types × 64 squares = 768
for square in chess.SQUARES:
piece = board.piece_at(square)
if piece is not None:
piece_char = piece.symbol()
if piece_char in piece_to_idx:
piece_idx = piece_to_idx[piece_char]
feature_idx = piece_idx * 64 + square
features[feature_idx] = 1.0
except:
pass
return features
DEFAULT_HIDDEN_SIZES = [1536, 1024, 512, 256]
class NNUE(nn.Module):
"""NNUE neural network with configurable hidden layers.
Architecture: 768 → hidden_sizes[0] → ... → hidden_sizes[-1] → 1
Layer attributes follow the naming l1, l2, ..., lN so export.py can
infer the architecture directly from the state_dict.
"""
def __init__(self, hidden_sizes=None, dropout_rate=0.2):
super().__init__()
if hidden_sizes is None:
hidden_sizes = DEFAULT_HIDDEN_SIZES
self.hidden_sizes = list(hidden_sizes)
sizes = [768] + self.hidden_sizes + [1]
num_hidden = len(self.hidden_sizes)
for i in range(num_hidden):
setattr(self, f"l{i + 1}", nn.Linear(sizes[i], sizes[i + 1]))
setattr(self, f"relu{i + 1}", nn.ReLU())
setattr(self, f"drop{i + 1}", nn.Dropout(dropout_rate))
setattr(self, f"l{num_hidden + 1}", nn.Linear(sizes[-2], sizes[-1]))
self._num_hidden = num_hidden
def forward(self, x):
for i in range(1, self._num_hidden + 1):
layer = getattr(self, f"l{i}")
relu = getattr(self, f"relu{i}")
drop = getattr(self, f"drop{i}")
x = drop(relu(layer(x)))
return getattr(self, f"l{self._num_hidden + 1}")(x)
def find_next_version(base_name="nnue_weights"):
"""Find the next version number for model versioning.
Looks for nnue_weights_v*.pt files and returns the next version number.
If no versioned files exist, returns 1.
"""
base_path = Path(base_name)
directory = base_path.parent
filename = base_path.name
pattern = re.compile(rf"{re.escape(filename)}_v(\d+)\.pt")
versions = []
for file in directory.glob(f"{filename}_v*.pt"):
match = pattern.match(file.name)
if match:
versions.append(int(match.group(1)))
if versions:
return max(versions) + 1
return 1
def save_metadata(weights_file, metadata):
"""Save training metadata alongside the weights file.
Args:
weights_file: Path to the .pt file (e.g., nnue_weights_v1.pt)
metadata: Dictionary with training info
"""
metadata_file = weights_file.replace(".pt", "_metadata.json")
with open(metadata_file, "w") as f:
json.dump(metadata, f, indent=2, default=str)
return metadata_file
def _setup_training(data_file, batch_size, subsample_ratio):
"""Set up device, dataset, and data loaders.
Returns:
(device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions)
"""
print("Checking GPU availability...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
print(f" GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
print("⚠ GPU not available, using CPU")
print(f"Using device: {device}")
print()
print("Loading dataset...")
dataset = NNUEDataset(data_file)
num_positions = len(dataset)
print(f"Dataset size: {num_positions}")
print(f"Data normalization: {'Yes (tanh)' if dataset.is_normalized else 'No (raw centipawns)'})")
evals_array = np.array(dataset.evals)
print()
print("=" * 60)
print("TRAINING DATASET DIAGNOSTICS")
print("=" * 60)
print(f"Min evaluation: {evals_array.min():.4f}")
print(f"Max evaluation: {evals_array.max():.4f}")
print(f"Mean evaluation: {evals_array.mean():.4f}")
print(f"Median evaluation: {np.median(evals_array):.4f}")
print(f"Std deviation: {evals_array.std():.4f}")
print("=" * 60)
print()
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
from torch.utils.data import random_split, RandomSampler
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
train_sampler = RandomSampler(train_dataset, replacement=False, num_samples=subsample_size)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=train_sampler,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=8,
pin_memory=True,
persistent_workers=True
)
return device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions
def _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
season_start_time, deadline=None, initial_best_val_loss=float('inf')
):
"""Run one training season until epoch limit, early stopping, or deadline.
Args:
initial_best_val_loss: Baseline to beat — epochs that don't improve on this count
toward early stopping and do not save snapshots.
Returns:
(best_val_loss, best_model_state, last_epoch)
best_model_state is None if no epoch beat initial_best_val_loss.
"""
best_val_loss = initial_best_val_loss
best_model_state = None
epochs_without_improvement = 0
total_epochs = start_epoch + epochs
last_epoch = start_epoch - 1
for epoch in range(start_epoch, start_epoch + epochs):
if deadline and datetime.now() >= deadline:
print("Time limit reached, stopping season.")
break
epoch_display = epoch + 1
# Train
model.train()
train_loss = 0.0
with tqdm(total=len(train_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Train") as pbar:
for batch_features, batch_targets in train_loader:
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device).unsqueeze(1)
optimizer.zero_grad()
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item() * batch_features.size(0)
pbar.update(1)
train_loss /= len(train_dataset)
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
with tqdm(total=len(val_loader), desc=f"Epoch {epoch_display}/{total_epochs} - Val") as pbar:
for batch_features, batch_targets in val_loader:
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device).unsqueeze(1)
with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
outputs = model(batch_features)
loss = criterion(outputs, batch_targets)
val_loss += loss.item() * batch_features.size(0)
pbar.update(1)
val_loss /= len(val_dataset)
scheduler.step()
if torch.cuda.is_available():
gpu_mem_used = torch.cuda.memory_allocated(device) / 1e9
gpu_mem_reserved = torch.cuda.memory_reserved(device) / 1e9
print(f"GPU Memory: {gpu_mem_used:.2f}GB used, {gpu_mem_reserved:.2f}GB reserved")
elapsed_time = datetime.now() - season_start_time
time_per_epoch = elapsed_time.total_seconds() / (epoch + 1)
remaining_epochs = total_epochs - epoch_display
eta_seconds = time_per_epoch * remaining_epochs
eta_str = str(datetime.fromtimestamp(eta_seconds) - datetime.fromtimestamp(0)).split('.')[0]
print(f"Epoch {epoch_display}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f} | ETA: {eta_str}")
checkpoint_file = output_file.replace(".pt", "_checkpoint.pt")
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"best_val_loss": best_val_loss,
"hidden_sizes": model.hidden_sizes,
}, checkpoint_file)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = model.state_dict().copy()
epochs_without_improvement = 0
snapshot_file = output_file.replace(".pt", "_best_snapshot.pt")
torch.save(best_model_state, snapshot_file)
print(f" Best model snapshot saved: {snapshot_file} (val_loss: {val_loss:.6f})")
else:
epochs_without_improvement += 1
last_epoch = epoch
if early_stopping_patience and epochs_without_improvement >= early_stopping_patience:
print(f"Early stopping: no improvement for {early_stopping_patience} epochs")
break
return best_val_loss, best_model_state, last_epoch
def _save_versioned_model(best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time, hidden_sizes=None,
extra_metadata=None):
"""Save the best model with optional versioning and metadata."""
final_output_file = output_file
metadata = {}
architecture = [768] + list(hidden_sizes or DEFAULT_HIDDEN_SIZES) + [1]
if use_versioning:
base_name = output_file.replace(".pt", "")
version = find_next_version(base_name)
final_output_file = f"{base_name}_v{version}.pt"
metadata = {
"version": version,
"date": training_start_time.isoformat(),
"num_positions": num_positions,
"stockfish_depth": stockfish_depth,
"final_val_loss": float(best_val_loss),
"architecture": architecture,
"device": str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
if extra_metadata:
metadata.update(extra_metadata)
torch.save({
"model_state_dict": best_model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict(),
"epoch": last_epoch,
"best_val_loss": best_val_loss,
"hidden_sizes": list(hidden_sizes or DEFAULT_HIDDEN_SIZES),
}, final_output_file)
print(f"Best model saved to {final_output_file}")
if use_versioning and metadata:
metadata_file = save_metadata(final_output_file, metadata)
print(f"Metadata saved to {metadata_file}")
print(f"\nTraining Summary:")
for key, val in metadata.items():
print(f" {key}: {val}")
def train_nnue(data_file, output_file="nnue_weights.pt", epochs=100, batch_size=16384, lr=0.001, checkpoint=None, stockfish_depth=12, use_versioning=True, early_stopping_patience=None, weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
"""Train the NNUE model with GPU optimizations and automatic mixed precision.
Args:
data_file: Path to training_data.jsonl
output_file: Where to save best weights (or base name if use_versioning=True)
epochs: Number of training epochs (default: 100)
batch_size: Training batch size (default: 16384)
lr: Learning rate (default: 0.001)
checkpoint: Optional path to checkpoint file to resume from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
early_stopping_patience: Stop if val loss doesn't improve for N epochs (None to disable)
weight_decay: L2 regularization strength (default: 1e-4, helps prevent overfitting)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0 = all data)
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
"""
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
start_epoch = 0
best_val_loss = float('inf')
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
if checkpoint:
print(f"Loading checkpoint: {checkpoint}")
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
ckpt_hidden = ckpt.get("hidden_sizes")
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
print(f" Using architecture from checkpoint: {ckpt_hidden}")
resolved_hidden_sizes = ckpt_hidden
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
if checkpoint:
ckpt = torch.load(checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1
best_val_loss = ckpt.get("best_val_loss", float('inf'))
print(f"Resumed from epoch {start_epoch} (best val loss so far: {best_val_loss:.6f})")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no optimizer state)")
checkpoint_val_loss = best_val_loss if checkpoint else float('inf')
subsample_size = max(1, int(subsample_ratio * len(train_dataset)))
arch_str = "".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
print(f"Architecture: {arch_str}")
print(f"Training for {epochs} epochs with batch_size={batch_size}, lr={lr}...")
print(f"Learning rate scheduler: Cosine annealing (T_max={epochs})")
print(f"Mixed precision training: enabled")
print(f"Regularization: Dropout (20%) + L2 weight decay ({weight_decay})")
if subsample_ratio < 1.0:
print(f"Stochastic sampling: {subsample_ratio:.0%} of train set per epoch ({subsample_size:,} positions)")
if early_stopping_patience:
print(f"Early stopping enabled (patience: {early_stopping_patience} epochs)")
print()
training_start_time = datetime.now()
best_val_loss, best_model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
start_epoch, epochs, early_stopping_patience,
training_start_time
)
if best_model_state is None or best_val_loss >= checkpoint_val_loss:
print(f"\nNo improvement over checkpoint (best: {best_val_loss:.6f} vs checkpoint: {checkpoint_val_loss:.6f})")
print("No new model created.")
return
_save_versioned_model(
best_model_state, optimizer, scheduler, scaler, last_epoch,
best_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, training_start_time,
hidden_sizes=resolved_hidden_sizes,
extra_metadata={"epochs": epochs, "batch_size": batch_size, "learning_rate": lr,
"checkpoint": str(checkpoint) if checkpoint else None}
)
def burst_train(data_file, output_file="nnue_weights.pt", duration_minutes=60,
epochs_per_season=50, early_stopping_patience=10,
batch_size=16384, lr=0.001, initial_checkpoint=None,
stockfish_depth=12, use_versioning=True,
weight_decay=1e-4, subsample_ratio=1.0, hidden_sizes=None):
"""Train in burst mode: repeatedly restart from the best checkpoint until the time budget expires.
Each season trains with early stopping. When early stopping fires, the model reloads the
global best weights and begins a fresh season with a reset optimizer and scheduler.
This prevents the model from drifting away from its best known state.
Args:
data_file: Path to training_data.jsonl
output_file: Output file base name
duration_minutes: Total training budget in minutes
epochs_per_season: Max epochs per restart season (default: 50)
early_stopping_patience: Patience for early stopping within each season (default: 10)
batch_size: Training batch size (default: 16384)
lr: Learning rate reset to this value at the start of each season (default: 0.001)
initial_checkpoint: Optional weights-only .pt file to start from
stockfish_depth: Depth used in Stockfish evaluation (for metadata)
use_versioning: If True, save as nnue_weights_v{N}.pt with metadata
weight_decay: L2 regularization strength (default: 1e-4)
subsample_ratio: Fraction of training data to sample per epoch (default: 1.0)
hidden_sizes: Hidden layer sizes (default: [1536, 1024, 512, 256])
"""
deadline = datetime.now() + timedelta(minutes=duration_minutes)
device, dataset, train_dataset, val_dataset, train_loader, val_loader, num_positions = \
_setup_training(data_file, batch_size, subsample_ratio)
resolved_hidden_sizes = list(hidden_sizes or DEFAULT_HIDDEN_SIZES)
if initial_checkpoint:
print(f"Loading initial weights: {initial_checkpoint}")
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
ckpt_hidden = ckpt.get("hidden_sizes")
if ckpt_hidden and ckpt_hidden != resolved_hidden_sizes:
print(f" Using architecture from checkpoint: {ckpt_hidden}")
resolved_hidden_sizes = ckpt_hidden
model = NNUE(hidden_sizes=resolved_hidden_sizes).to(device)
criterion = nn.MSELoss()
best_global_val_loss = float('inf')
if initial_checkpoint:
ckpt = torch.load(initial_checkpoint, map_location=device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
model.load_state_dict(ckpt["model_state_dict"])
best_global_val_loss = ckpt.get("best_val_loss", float('inf'))
if best_global_val_loss < float('inf'):
print(f"Resumed from checkpoint (best val loss: {best_global_val_loss:.6f})")
else:
print("Initial weights loaded (no val loss in checkpoint).")
else:
model.load_state_dict(ckpt)
print("Loaded weights-only checkpoint (no val loss info).")
arch_str = "".join(str(s) for s in [768] + resolved_hidden_sizes + [1])
print(f"Architecture: {arch_str}")
print(f"Burst training: {duration_minutes}m budget, {epochs_per_season} epochs/season, patience={early_stopping_patience}")
print(f"Deadline: {deadline.strftime('%H:%M:%S')}")
print()
burst_start_time = datetime.now()
season = 0
best_global_state = None
last_optimizer = None
last_scheduler = None
last_scaler = None
last_epoch = 0
while datetime.now() < deadline:
season += 1
remaining_minutes = (deadline - datetime.now()).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"BURST SEASON {season} | {remaining_minutes:.1f} minutes remaining")
if best_global_val_loss < float('inf'):
print(f"Global best val loss so far: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs_per_season)
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
season_start_time = datetime.now()
val_loss, model_state, last_epoch = _run_training_season(
model, optimizer, scheduler, scaler,
train_loader, val_loader, train_dataset, val_dataset,
device, criterion, output_file,
0, epochs_per_season, early_stopping_patience,
season_start_time, deadline=deadline,
initial_best_val_loss=best_global_val_loss
)
last_optimizer = optimizer
last_scheduler = scheduler
last_scaler = scaler
if model_state is not None and val_loss < best_global_val_loss:
best_global_val_loss = val_loss
best_global_state = model_state
print(f" New global best: {best_global_val_loss:.6f} (season {season})")
# Reload global best for the next season so we never drift backwards
if best_global_state is not None:
model.load_state_dict(best_global_state)
total_minutes = (datetime.now() - burst_start_time).total_seconds() / 60
print(f"\n{'=' * 60}")
print(f"Burst training complete: {season} season(s) in {total_minutes:.1f}m")
print(f"Best val loss: {best_global_val_loss:.6f}")
print(f"{'=' * 60}\n")
if best_global_state is None:
print("No model improvement found. No file saved.")
return
_save_versioned_model(
best_global_state, last_optimizer, last_scheduler, last_scaler, last_epoch,
best_global_val_loss, output_file, use_versioning, num_positions,
stockfish_depth, burst_start_time,
hidden_sizes=resolved_hidden_sizes,
extra_metadata={
"mode": "burst",
"duration_minutes": duration_minutes,
"epochs_per_season": epochs_per_season,
"early_stopping_patience": early_stopping_patience,
"seasons_completed": season,
"batch_size": batch_size,
"learning_rate": lr,
"initial_checkpoint": str(initial_checkpoint) if initial_checkpoint else None,
}
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Train NNUE neural network for chess evaluation")
parser.add_argument("data_file", nargs="?", default="training_data.jsonl",
help="Path to training_data.jsonl (default: training_data.jsonl)")
parser.add_argument("output_file", nargs="?", default="nnue_weights.pt",
help="Output file base name (default: nnue_weights.pt)")
parser.add_argument("--checkpoint", type=str, default=None,
help="Path to checkpoint file to resume training from (optional)")
parser.add_argument("--epochs", type=int, default=100,
help="Number of epochs to train (default: 100)")
parser.add_argument("--batch-size", type=int, default=16384,
help="Batch size (default: 16384)")
parser.add_argument("--lr", type=float, default=0.001,
help="Learning rate (default: 0.001)")
parser.add_argument("--early-stopping", type=int, default=None,
help="Stop if val loss doesn't improve for N epochs (optional)")
parser.add_argument("--stockfish-depth", type=int, default=12,
help="Stockfish depth used for evaluations (for metadata, default: 12)")
parser.add_argument("--no-versioning", action="store_true",
help="Disable automatic versioning (save directly to output file)")
parser.add_argument("--weight-decay", type=float, default=5e-5,
help="L2 regularization strength (default: 1e-4, helps prevent overfitting)")
parser.add_argument("--subsample-ratio", type=float, default=1.0,
help="Fraction of training data to sample per epoch (default: 1.0 = all data)")
parser.add_argument("--hidden-layers", type=str, default=None,
help="Comma-separated hidden layer sizes (default: 1536,1024,512,256)")
# Burst mode
parser.add_argument("--burst-duration", type=float, default=None,
help="Enable burst mode: total training budget in minutes")
parser.add_argument("--epochs-per-season", type=int, default=50,
help="Max epochs per burst season before restarting (default: 50, burst mode only)")
args = parser.parse_args()
hidden_sizes = [int(x) for x in args.hidden_layers.split(",")] if args.hidden_layers else None
if args.burst_duration is not None:
burst_train(
data_file=args.data_file,
output_file=args.output_file,
duration_minutes=args.burst_duration,
epochs_per_season=args.epochs_per_season,
early_stopping_patience=args.early_stopping or 10,
batch_size=args.batch_size,
lr=args.lr,
initial_checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
hidden_sizes=hidden_sizes,
)
else:
train_nnue(
data_file=args.data_file,
output_file=args.output_file,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
checkpoint=args.checkpoint,
stockfish_depth=args.stockfish_depth,
use_versioning=not args.no_versioning,
early_stopping_patience=args.early_stopping,
weight_decay=args.weight_decay,
subsample_ratio=args.subsample_ratio,
hidden_sizes=hidden_sizes,
)
+30
View File
@@ -0,0 +1,30 @@
# Setup and run NNUE training pipeline
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
$VenvDir = Join-Path $ScriptDir ".venv"
# Check if virtual environment exists
if (-not (Test-Path $VenvDir)) {
Write-Host "Creating virtual environment..."
python -m venv $VenvDir
if ($LASTEXITCODE -ne 0) {
Write-Host "Error: Failed to create virtual environment. Make sure python is installed."
exit 1
}
}
# Activate virtual environment
Write-Host "Activating virtual environment..."
$ActivateScript = Join-Path $VenvDir "Scripts\Activate.ps1"
& $ActivateScript
# Install/update dependencies if requirements.txt exists
$RequirementsFile = Join-Path $ScriptDir "requirements.txt"
if (Test-Path $RequirementsFile) {
Write-Host "Installing dependencies..."
pip install -q -r $RequirementsFile
}
# Run nnue.py
Write-Host "Starting NNUE Training Pipeline..."
python (Join-Path $ScriptDir "nnue.py")
+29
View File
@@ -0,0 +1,29 @@
#!/bin/bash
# Setup and run NNUE training pipeline
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_DIR="$SCRIPT_DIR/.venv"
# Check if virtual environment exists
if [ ! -d "$VENV_DIR" ]; then
echo "Creating virtual environment..."
python3 -m venv "$VENV_DIR"
if [ $? -ne 0 ]; then
echo "Error: Failed to create virtual environment. Make sure python3 is installed."
exit 1
fi
fi
# Activate virtual environment
echo "Activating virtual environment..."
source "$VENV_DIR/bin/activate"
# Install/update dependencies if requirements.txt exists
if [ -f "$SCRIPT_DIR/requirements.txt" ]; then
echo "Installing dependencies..."
pip install -q -r "$SCRIPT_DIR/requirements.txt"
fi
# Run nnue.py
echo "Starting NNUE Training Pipeline..."
python "$SCRIPT_DIR/nnue.py"
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
{
"version": 10,
"date": "2026-04-14T22:18:38.824577",
"num_positions": 3022562,
"stockfish_depth": 12,
"final_val_loss": 6.248612448225196e-05,
"architecture": [
768,
1536,
1024,
512,
256,
1
],
"device": "cuda",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"checkpoint": null
}
@@ -0,0 +1,13 @@
{
"version": 1,
"date": "2026-04-07T22:56:23.259658",
"num_positions": 2086,
"stockfish_depth": 12,
"epochs": 20,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 0.016311248764395714,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 2,
"date": "2026-04-07T23:50:05.390402",
"num_positions": 6886,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 0.007848377339541912,
"device": "cuda",
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v1.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 3,
"date": "2026-04-08T09:43:28.000579",
"num_positions": 71610,
"stockfish_depth": 12,
"epochs": 20,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 0.006398905136849695,
"device": "cpu",
"checkpoint": "/home/janis/Workspaces/IntelliJ/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v2.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 4,
"date": "2026-04-09T00:28:07.572209",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 40,
"batch_size": 4096,
"learning_rate": 0.001,
"final_val_loss": 9.106677896235248e-05,
"device": "cuda",
"checkpoint": "/mnt/d/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v3.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 5,
"date": "2026-04-09T18:50:27.845632",
"num_positions": 2009355,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 9.180421525105905e-05,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 6,
"date": "2026-04-09T21:28:21.000832",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2984530149085532,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v5.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 7,
"date": "2026-04-09T22:06:50.439858",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.2997283308762831,
"device": "cuda",
"checkpoint": null,
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,13 @@
{
"version": 8,
"date": "2026-04-09T22:22:47.859730",
"num_positions": 1958728,
"stockfish_depth": 12,
"epochs": 100,
"batch_size": 16384,
"learning_rate": 0.001,
"final_val_loss": 0.24803777390839968,
"device": "cuda",
"checkpoint": "/home/janis/Workspaces/NowChessSystems/modules/bot/python/weights/nnue_weights_v7.pt",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)"
}
Binary file not shown.
@@ -0,0 +1,17 @@
{
"version": 9,
"date": "2026-04-13T20:19:08.123315",
"num_positions": 2522562,
"stockfish_depth": 12,
"final_val_loss": 6.994176222619626e-05,
"device": "cuda",
"notes": "Win rate vs classical eval: TBD (requires benchmark games)",
"mode": "burst",
"duration_minutes": 30.0,
"epochs_per_season": 50,
"early_stopping_patience": 10,
"seasons_completed": 3,
"batch_size": 16384,
"learning_rate": 0.001,
"initial_checkpoint": "/home/janis/Workspaces/NowChess/NowChessSystems/modules/bot/python/weights/nnue_weights_v8.pt"
}
Binary file not shown.
@@ -0,0 +1,21 @@
package de.nowchess.bot
import de.nowchess.api.bot.Bot
import de.nowchess.bot.bots.ClassicalBot
object BotController {
private val bots: Map[String, Bot] = Map(
"easy" -> ClassicalBot(BotDifficulty.Easy),
"medium" -> ClassicalBot(BotDifficulty.Medium),
"hard" -> ClassicalBot(BotDifficulty.Hard),
"expert" -> ClassicalBot(BotDifficulty.Expert),
)
/** Get a bot by name. */
def getBot(name: String): Option[Bot] = bots.get(name.toLowerCase)
/** List all available bot names. */
def listBots: List[String] = bots.keys.toList.sorted
}
@@ -0,0 +1,7 @@
package de.nowchess.bot
enum BotDifficulty:
case Easy
case Medium
case Hard
case Expert
@@ -0,0 +1,19 @@
package de.nowchess.bot
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
object BotMoveRepetition:
private val maxConsecutiveMoves = 3
def blockedMoves(context: GameContext): Set[Move] = repeatedMove(context).toSet
def repeatedMove(context: GameContext): Option[Move] =
context.moves.takeRight(maxConsecutiveMoves) match
case first :: second :: third :: Nil if first == second && second == third => Some(first)
case _ => None
def filterAllowed(context: GameContext, moves: List[Move]): List[Move] =
val blocked = blockedMoves(context)
moves.filterNot(blocked.contains)
@@ -0,0 +1,11 @@
package de.nowchess.bot
object Config:
/** Threshold in centipawns: if classical evaluation differs from NNUE by more than this, the move is vetoed (not
* accepted as a suggestion).
*/
val VETO_THRESHOLD: Int = 150
/** Time budget per move for iterative deepening (milliseconds). */
val TIME_LIMIT_MS: Long = 2000L
@@ -0,0 +1,29 @@
package de.nowchess.bot.ai
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
trait Evaluation:
def CHECKMATE_SCORE: Int
def DRAW_SCORE: Int
def evaluate(context: GameContext): Int
// ── Accumulator hooks ─────────────────────────────────────────────────────
// Default implementations fall back to full re-evaluation each call.
// Override in NNUE-capable evaluators for incremental L1 speedup.
/** Initialise the accumulator for the root position at ply 0. */
def initAccumulator(context: GameContext): Unit = ()
/** Copy parent ply's accumulator to childPly without move deltas (null-move). */
def copyAccumulator(parentPly: Int, childPly: Int): Unit = ()
/** Derive childPly's accumulator from parentPly by applying move deltas. */
def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit = ()
/** Evaluate from the pre-computed accumulator at ply, using hash for the eval cache. Falls back to full evaluate when
* not overridden.
*/
def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int = evaluate(context)
@@ -0,0 +1,29 @@
package de.nowchess.bot.bots
import de.nowchess.api.bot.Bot
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.logic.AlphaBetaSearch
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
final class ClassicalBot(
difficulty: BotDifficulty,
rules: RuleSet = DefaultRules,
book: Option[PolyglotBook] = None,
) extends Bot:
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationClassic)
private val TIME_BUDGET_MS = 1000L
override val name: String = s"ClassicalBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] =
val blockedMoves = BotMoveRepetition.blockedMoves(context)
book
.flatMap(_.probe(context))
.filterNot(blockedMoves.contains)
.orElse(search.bestMoveWithTime(context, TIME_BUDGET_MS, blockedMoves))
@@ -0,0 +1,43 @@
package de.nowchess.bot.bots
import de.nowchess.api.bot.Bot
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.ai.Evaluation
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.{AlphaBetaSearch, TranspositionTable}
import de.nowchess.bot.util.PolyglotBook
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition, Config}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
final class HybridBot(
difficulty: BotDifficulty,
rules: RuleSet = DefaultRules,
book: Option[PolyglotBook] = None,
nnueEvaluation: Evaluation = EvaluationNNUE,
classicalEvaluation: Evaluation = EvaluationClassic,
vetoReporter: String => Unit = println(_),
) extends Bot:
private val search = AlphaBetaSearch(rules, TranspositionTable(), classicalEvaluation)
override val name: String = s"HybridBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] =
val blockedMoves = BotMoveRepetition.blockedMoves(context)
book.flatMap(_.probe(context)).filterNot(blockedMoves.contains).orElse(searchWithVeto(context, blockedMoves))
private def searchWithVeto(context: GameContext, blockedMoves: Set[Move]): Option[Move] =
search.bestMoveWithTime(context, Config.TIME_LIMIT_MS, blockedMoves).map { move =>
val next = rules.applyMove(context)(move)
val staticNnue = nnueEvaluation.evaluate(next)
val classical = classicalEvaluation.evaluate(next)
val diff = (classical - staticNnue).abs
if diff > Config.VETO_THRESHOLD then
vetoReporter(
f"[Veto] ${move.from}->${move.to}: nnue=$staticNnue classical=$classical diff=$diff — flagged but trusted (deep search)",
)
move
}
@@ -0,0 +1,60 @@
package de.nowchess.bot.bots
import de.nowchess.api.bot.Bot
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.bots.nnue.EvaluationNNUE
import de.nowchess.bot.logic.AlphaBetaSearch
import de.nowchess.bot.util.{PolyglotBook, ZobristHash}
import de.nowchess.bot.{BotDifficulty, BotMoveRepetition}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
final class NNUEBot(
difficulty: BotDifficulty,
rules: RuleSet = DefaultRules,
book: Option[PolyglotBook] = None,
) extends Bot:
private val search: AlphaBetaSearch = AlphaBetaSearch(rules, weights = EvaluationNNUE)
override val name: String = s"NNUEBot(${difficulty.toString})"
override def nextMove(context: GameContext): Option[Move] =
val blockedMoves = BotMoveRepetition.blockedMoves(context)
book
.flatMap(_.probe(context))
.filterNot(blockedMoves.contains)
.orElse {
val moves = BotMoveRepetition.filterAllowed(context, rules.allLegalMoves(context))
if moves.isEmpty then None
else
val scored = batchEvaluateRoot(context, moves)
val bestMove = scored.maxBy(_._2)._1
search.bestMoveWithTime(context, allocateTime(scored), blockedMoves).orElse(Some(bestMove))
}
/** Evaluate all root moves shallowly via incremental NNUE accumulator updates. Returns (move, score) pairs with score
* from the root player's perspective.
*/
private def batchEvaluateRoot(context: GameContext, moves: List[Move]): List[(Move, Int)] =
EvaluationNNUE.initAccumulator(context)
val rootHash = ZobristHash.hash(context)
moves.map { move =>
val child = rules.applyMove(context)(move)
val childHash = ZobristHash.nextHash(context, rootHash, move, child)
EvaluationNNUE.pushAccumulator(1, move, context, child)
val score = -EvaluationNNUE.evaluateAccumulator(1, child, childHash)
(move, score)
}
/** Allocate more time for complex positions; less when one move clearly dominates. */
private def allocateTime(scored: List[(Move, Int)]): Long =
val moveCount = scored.length
if moveCount > 30 then 1500L
else if moveCount < 5 then 500L
else
val scores = scored.map(_._2)
val best = scores.max
val second = scores.filter(_ < best).maxOption.getOrElse(best)
if best - second > 200 then 600L else 1000L
@@ -0,0 +1,361 @@
package de.nowchess.bot.bots.classic
import de.nowchess.api.board.{Color, PieceType, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.bot.ai.Evaluation
object EvaluationClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
// Material values in centipawns (indexed by PieceType.ordinal: Pawn=0, Knight=1, Bishop=2, Rook=3, Queen=4, King=5)
private val mgMaterial = Array(100, 325, 335, 500, 900, 20_000)
private val egMaterial = Array(110, 310, 310, 530, 1_000, 20_000)
private val TEMPO_BONUS: Int = 10
// Piece-square tables (Simplified Evaluation Function, Michniewski)
// Indexed by squareIndex = rank.ordinal * 8 + file.ordinal
// White's perspective: rank 0 = home (r1), rank 7 = back rank (r8)
// Black is vertically mirrored
private val mgPawnTable: Array[Int] = Array(
0, 0, 0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 10, 10, 20, 30, 30, 20, 10, 10, 5, 5, 10, 25, 25, 10, 5, 5,
0, 0, 0, 20, 20, 0, 0, 0, 5, -5, -10, 0, 0, -10, -5, 5, 5, 10, 10, -20, -20, 10, 10, 5, 0, 0, 0, 0, 0, 0, 0, 0,
)
private val egPawnTable: Array[Int] = Array(
0, 0, 0, 0, 0, 0, 0, 0, 70, 70, 70, 70, 70, 70, 70, 70, 40, 40, 40, 40, 40, 40, 40, 40, 30, 30, 30, 30, 30, 30, 30,
30, 20, 20, 20, 20, 20, 20, 20, 20, 10, 10, 10, 10, 10, 10, 10, 10, 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,
)
private val mgKnightTable: Array[Int] = Array(
-50, -40, -30, -30, -30, -30, -40, -50, -40, -20, 0, 0, 0, 0, -20, -40, -30, 0, 10, 15, 15, 10, 0, -30, -30, 5, 15,
20, 20, 15, 5, -30, -30, 0, 15, 20, 20, 15, 0, -30, -30, 5, 10, 15, 15, 10, 5, -30, -40, -20, 0, 5, 5, 0, -20, -40,
-50, -40, -30, -30, -30, -30, -40, -50,
)
private val egKnightTable: Array[Int] = Array(
-30, -20, -10, -10, -10, -10, -20, -30, -20, 0, 5, 5, 5, 5, 0, -20, -10, 5, 15, 20, 20, 15, 5, -10, -10, 5, 20, 25,
25, 20, 5, -10, -10, 5, 20, 25, 25, 20, 5, -10, -10, 5, 15, 20, 20, 15, 5, -10, -20, 0, 5, 5, 5, 5, 0, -20, -30,
-20, -10, -10, -10, -10, -20, -30,
)
private val mgBishopTable: Array[Int] = Array(
-20, -10, -10, -10, -10, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 10, 10, 5, 0, -10, -10, 5, 5, 10, 10,
5, 5, -10, -10, 0, 10, 10, 10, 10, 0, -10, -10, 10, 10, 10, 10, 10, 10, -10, -10, 5, 0, 0, 0, 0, 5, -10, -20, -10,
-10, -10, -10, -10, -10, -20,
)
private val egBishopTable: Array[Int] = Array(
-20, -10, -5, -5, -5, -5, -10, -20, -10, 0, 5, 5, 5, 5, 0, -10, -5, 5, 10, 10, 10, 10, 5, -5, -5, 5, 10, 15, 15, 10,
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -5, 5, 10, 10, 10, 10, 5, -5, -10, 0, 5, 5, 5, 5, 0, -10, -20, -10, -5, -5, -5,
-5, -10, -20,
)
private val mgRookTable: Array[Int] = Array(
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
)
private val egRookTable: Array[Int] = Array(
0, 0, 0, 0, 0, 0, 0, 0, 5, 10, 10, 10, 10, 10, 10, 5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0,
0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, -5, 0, 0, 0, 0, 0, 0, -5, 0, 0, 0, 5, 5, 0, 0, 0,
)
private val mgQueenTable: Array[Int] = Array(
-20, -10, -10, -5, -5, -10, -10, -20, -10, 0, 0, 0, 0, 0, 0, -10, -10, 0, 5, 5, 5, 5, 0, -10, -5, 0, 5, 5, 5, 5, 0,
-5, 0, 0, 5, 5, 5, 5, 0, -5, -10, 5, 5, 5, 5, 5, 0, -10, -10, 0, 5, 0, 0, 0, 0, -10, -20, -10, -10, -5, -5, -10,
-10, -20,
)
private val egQueenTable: Array[Int] = Array(
-15, -10, -8, -5, -5, -8, -10, -15, -10, 0, 3, 5, 5, 3, 0, -10, -8, 3, 10, 10, 10, 10, 3, -8, -5, 5, 10, 15, 15, 10,
5, -5, -5, 5, 10, 15, 15, 10, 5, -5, -8, 3, 10, 10, 10, 10, 3, -8, -10, 0, 3, 5, 5, 3, 0, -10, -15, -10, -8, -5, -5,
-8, -10, -15,
)
private val mgKingTable: Array[Int] = Array(
-30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40, -30, -30, -40, -40, -50, -50, -40, -40,
-30, -30, -40, -40, -50, -50, -40, -40, -30, -20, -30, -30, -40, -40, -30, -30, -20, -10, -20, -20, -20, -20, -20,
-20, -10, 20, 20, 0, 0, 0, 0, 20, 20, 20, 30, 10, 0, 0, 10, 30, 20,
)
private val egKingTable: Array[Int] = Array(
-50, -40, -30, -20, -20, -30, -40, -50, -30, -20, -10, 0, 0, -10, -20, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30,
-10, 30, 40, 40, 30, -10, -30, -30, -10, 30, 40, 40, 30, -10, -30, -30, -10, 20, 30, 30, 20, -10, -30, -30, -30, 0,
0, 0, 0, -30, -30, -50, -30, -30, -30, -30, -30, -30, -50,
)
private val phaseWeight: Map[PieceType, Int] = Map(
PieceType.Knight -> 1,
PieceType.Bishop -> 1,
PieceType.Rook -> 2,
PieceType.Queen -> 4,
)
private val maxPhase = 24 // 4*4 + 4*2 + 4*1 + 4*1
private val passedPawnBonus: Array[Int] = Array(0, 5, 10, 20, 35, 60, 100, 0)
private val egPassedPawnBonus: Array[Int] = Array(0, 20, 40, 80, 150, 250, 400, 0)
// Pawn structure penalties
private val doubledMg = -10
private val doubledEg = -25
private val isolatedMg = -15
private val isolatedEg = -20
// Mobility weights: centipawns per reachable square (indexed by PieceType.ordinal)
private val mobilityMg = Array(0, 4, 3, 2, 1, 0, 0)
private val mobilityEg = Array(0, 4, 3, 4, 2, 0, 0)
// Direction offsets for sliding pieces
private val diagonals = List((-1, -1), (-1, 1), (1, -1), (1, 1))
private val orthogonals = List((-1, 0), (1, 0), (0, -1), (0, 1))
private val knightOffsets = List((-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1))
// Rook and bishop bonuses
private val bishopPairMg = 50
private val bishopPairEg = 70
private val rookOn7thMg = 20
private val rookOn7thEg = 10
/** Evaluate the position from the perspective of context.turn. Positive = good for context.turn.
*/
def evaluate(context: GameContext): Int =
val phase = gamePhase(context.board)
val isEg = isEndgame(phase)
val material = materialAndPositional(context, phase)
val structure = pawnStructure(context, phase)
val mobility = mobilityScore(context, phase)
val rookBishop = rookAndBishopBonuses(context, phase)
val bonuses = positionalBonuses(context, phase, isEg)
val egBonuses = if isEg then endgameBonus(context) else 0
material + structure + mobility + rookBishop + bonuses + egBonuses + TEMPO_BONUS
private def gamePhase(board: de.nowchess.api.board.Board): Int =
val phase = board.pieces.values.foldLeft(0) { (acc, piece) =>
acc + phaseWeight.getOrElse(piece.pieceType, 0)
}
math.min(phase, maxPhase)
private def isEndgame(phase: Int): Boolean =
phase < 8 // Significantly reduced material indicates endgame
private def taper(mg: Int, eg: Int, phase: Int): Int =
(mg * phase + eg * (maxPhase - phase)) / maxPhase
private def materialAndPositional(context: GameContext, phase: Int): Int =
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (square, piece)) =>
val (psqMg, psqEg) = squareBonus(piece.pieceType, piece.color, square)
val pieceMg = mgMaterial(piece.pieceType.ordinal) + psqMg
val pieceEg = egMaterial(piece.pieceType.ordinal) + psqEg
val sign = if piece.color == context.turn then 1 else -1
(mg + sign * pieceMg, eg + sign * pieceEg)
}
taper(mg, eg, phase)
private def squareBonus(pieceType: PieceType, color: Color, sq: Square): (Int, Int) =
val rankIdx = if color == Color.White then sq.rank.ordinal else 7 - sq.rank.ordinal
val fileIdx = sq.file.ordinal
val squareIdx = rankIdx * 8 + fileIdx
pieceType match
case PieceType.Pawn => (mgPawnTable(squareIdx), egPawnTable(squareIdx))
case PieceType.Knight => (mgKnightTable(squareIdx), egKnightTable(squareIdx))
case PieceType.Bishop => (mgBishopTable(squareIdx), egBishopTable(squareIdx))
case PieceType.Rook => (mgRookTable(squareIdx), egRookTable(squareIdx))
case PieceType.Queen => (mgQueenTable(squareIdx), egQueenTable(squareIdx))
case PieceType.King => (mgKingTable(squareIdx), egKingTable(squareIdx))
private def pawnStructure(context: GameContext, phase: Int): Int =
val friendlyPawns = context.board.pieces.filter((_, p) => p.color == context.turn && p.pieceType == PieceType.Pawn)
val enemyPawns = context.board.pieces.filter((_, p) => p.color != context.turn && p.pieceType == PieceType.Pawn)
val friendlyByFile = friendlyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
val enemyByFile = enemyPawns.groupMap(s => s._1.file.ordinal)(s => s._1.rank.ordinal)
val (fMg, fEg) = structureScore(friendlyByFile)
val (eMg, eEg) = structureScore(enemyByFile)
taper(fMg - eMg, fEg - eEg, phase)
private def structureScore(byFile: Map[Int, Iterable[Int]]): (Int, Int) =
byFile.foldLeft((0, 0)) { case ((mg, eg), (file, ranks)) =>
val doubled = (ranks.size - 1).max(0)
val hasAdjacent = (file - 1 to file + 1).filter(f => f >= 0 && f < 8 && f != file).exists(byFile.contains)
val isolated = if !hasAdjacent then ranks.size else 0
(mg + doubled * doubledMg + isolated * isolatedMg, eg + doubled * doubledEg + isolated * isolatedEg)
}
private def positionalBonuses(context: GameContext, phase: Int, isEg: Boolean): Int =
context.board.pieces.foldLeft(0) { case (score, (sq, piece)) =>
val bonus = piece.pieceType match
case PieceType.Pawn =>
if isPassedPawn(context.board, sq, piece.color) then
if isEg then egPassedPawnBonus(sq.rank.ordinal) else passedPawnBonus(sq.rank.ordinal)
else 0
case PieceType.Rook => rookOpenFileBonus(context.board, sq, piece.color)
case PieceType.King => kingShieldBonus(context.board, sq, piece.color, phase)
case _ => 0
if piece.color == context.turn then score + bonus else score - bonus
}
private def isPassedPawn(board: de.nowchess.api.board.Board, sq: Square, color: Color): Boolean =
val enemyColor = color.opposite
val pawnRank = sq.rank.ordinal
val fileRange = (sq.file.ordinal - 1 to sq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
val rankCheck = if color == Color.White then (r: Int) => r > pawnRank else (r: Int) => r < pawnRank
board.pieces.forall { (enemySq, enemyPiece) =>
!(enemyPiece.color == enemyColor &&
enemyPiece.pieceType == PieceType.Pawn &&
fileRange.contains(enemySq.file.ordinal) &&
rankCheck(enemySq.rank.ordinal))
}
private def rookOpenFileBonus(board: de.nowchess.api.board.Board, rookSq: Square, color: Color): Int =
val hasFriendlyPawn = board.pieces.exists { (sq, piece) =>
piece.color == color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
}
val hasEnemyPawn = board.pieces.exists { (sq, piece) =>
piece.color != color && piece.pieceType == PieceType.Pawn && sq.file == rookSq.file
}
if !hasFriendlyPawn && !hasEnemyPawn then 20 // open file
else if !hasFriendlyPawn then 10 // semi-open file
else 0
private def kingShieldBonus(board: de.nowchess.api.board.Board, kingSq: Square, color: Color, phase: Int): Int =
val shieldRankDelta = if color == Color.White then 1 else -1
val shieldFiles = (kingSq.file.ordinal - 1 to kingSq.file.ordinal + 1).filter(f => f >= 0 && f < 8)
val shieldRank = kingSq.rank.ordinal + shieldRankDelta
if shieldRank < 0 || shieldRank > 7 then 0
else
val rawBonus = board.pieces.count { (sq, piece) =>
piece.color == color &&
piece.pieceType == PieceType.Pawn &&
shieldFiles.contains(sq.file.ordinal) &&
sq.rank.ordinal == shieldRank
} * 10
(rawBonus * phase) / maxPhase
private def slidingCount(
sq: Square,
board: de.nowchess.api.board.Board,
color: Color,
directions: List[(Int, Int)],
): Int =
directions.foldLeft(0) { case (total, (fileDelta, rankDelta)) =>
@scala.annotation.tailrec
def countRay(current: Option[Square], acc: Int): Int =
current match
case None => acc
case Some(target) =>
board.pieceAt(target) match
case Some(piece) if piece.color == color => acc
case Some(_) => acc + 1
case None => countRay(target.offset(fileDelta, rankDelta), acc + 1)
total + countRay(sq.offset(fileDelta, rankDelta), 0)
}
private def knightCount(sq: Square, board: de.nowchess.api.board.Board, color: Color): Int =
knightOffsets.count { case (fileDelta, rankDelta) =>
sq.offset(fileDelta, rankDelta).forall { target =>
board.pieceAt(target).forall(_.color != color)
}
}
private def mobilityScore(context: GameContext, phase: Int): Int =
val (mg, eg) = context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
val count = piece.pieceType match
case PieceType.Knight => knightCount(sq, context.board, piece.color)
case PieceType.Bishop => slidingCount(sq, context.board, piece.color, diagonals)
case PieceType.Rook => slidingCount(sq, context.board, piece.color, orthogonals)
case PieceType.Queen => slidingCount(sq, context.board, piece.color, diagonals ++ orthogonals)
case _ => 0
val pieceMg = count * mobilityMg(piece.pieceType.ordinal)
val pieceEg = count * mobilityEg(piece.pieceType.ordinal)
val sign = if piece.color == context.turn then 1 else -1
(mg + sign * pieceMg, eg + sign * pieceEg)
}
taper(mg, eg, phase)
private def rookAndBishopBonuses(context: GameContext, phase: Int): Int =
val (baseMg, baseEg) = bishopPairBase(context)
val (rookMg, rookEg) = rookOn7thDelta(context)
taper(baseMg + rookMg, baseEg + rookEg, phase)
private def bishopPairBase(context: GameContext): (Int, Int) =
val friendlyHasPair = hasBishopPair(context, context.turn)
val enemyHasPair = hasBishopPair(context, context.turn.opposite)
val mg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairMg)
val eg = pairDelta(friendlyHasPair, enemyHasPair, bishopPairEg)
(mg, eg)
private def hasBishopPair(context: GameContext, color: Color): Boolean =
val bishopSquares = context.board.pieces.collect {
case (sq, piece) if piece.color == color && piece.pieceType == PieceType.Bishop => sq
}
bishopSquares.exists(isEvenSquare) && bishopSquares.exists(sq => !isEvenSquare(sq))
private def isEvenSquare(square: Square): Boolean =
(square.file.ordinal + square.rank.ordinal) % 2 == 0
private def pairDelta(friendlyHasPair: Boolean, enemyHasPair: Boolean, bonus: Int): Int =
(if friendlyHasPair then bonus else 0) - (if enemyHasPair then bonus else 0)
private def rookOn7thDelta(context: GameContext): (Int, Int) =
context.board.pieces.foldLeft((0, 0)) { case ((mg, eg), (sq, piece)) =>
rookOn7thContribution(piece, sq, context.turn).fold((mg, eg)) { case (dMg, dEg) =>
(mg + dMg, eg + dEg)
}
}
private def rookOn7thContribution(piece: de.nowchess.api.board.Piece, sq: Square, turn: Color): Option[(Int, Int)] =
Option.when(piece.pieceType == PieceType.Rook && isRookOn7th(piece.color, sq)) {
val sign = if piece.color == turn then 1 else -1
(sign * rookOn7thMg, sign * rookOn7thEg)
}
private def isRookOn7th(color: Color, sq: Square): Boolean =
if color == Color.White then sq.rank.ordinal == 6 else sq.rank.ordinal == 1
private def endgameBonus(context: GameContext): Int =
val friendlyKing = context.board.pieces.find((_, p) => p.color == context.turn && p.pieceType == PieceType.King)
val enemyKing = context.board.pieces.find((_, p) => p.color != context.turn && p.pieceType == PieceType.King)
val kingCentralBonus =
friendlyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15) -
enemyKing.fold(0)((kSq, _) => (8 - kingCentralizationDistance(kSq)) * 15)
val friendlyMaterial = materialCount(context, context.turn)
val enemyMaterial = materialCount(context, context.turn.opposite)
val edgeBonus =
if friendlyMaterial > enemyMaterial then enemyKing.fold(0)((kSq, _) => (7 - kingEdgeDistance(kSq)) * 10)
else 0
kingCentralBonus + edgeBonus
private def kingCentralizationDistance(sq: Square): Int =
val fileFromCenter = (sq.file.ordinal - 3.5).abs.toInt
val rankFromCenter = (sq.rank.ordinal - 3.5).abs.toInt
math.max(fileFromCenter, rankFromCenter)
private def kingEdgeDistance(sq: Square): Int =
val fileFromEdge = math.min(sq.file.ordinal, 7 - sq.file.ordinal)
val rankFromEdge = math.min(sq.rank.ordinal, 7 - sq.rank.ordinal)
math.min(fileFromEdge, rankFromEdge)
private def materialCount(context: GameContext, color: Color): Int =
context.board.pieces.foldLeft(0) { case (sum, (_, piece)) =>
if piece.color == color then
sum + (piece.pieceType match
case PieceType.Knight => 300
case PieceType.Bishop => 300
case PieceType.Rook => 500
case PieceType.Queen => 900
case PieceType.Pawn => 0
case PieceType.King => 0
)
else sum
}
@@ -0,0 +1,31 @@
package de.nowchess.bot.bots.nnue
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.bot.ai.Evaluation
object EvaluationNNUE extends Evaluation:
private val nnue = NNUE(NbaiLoader.loadDefault())
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
/** Full-board evaluate — used as fallback and by non-search callers. */
def evaluate(context: GameContext): Int = nnue.evaluate(context)
// ── Accumulator hooks (incremental L1) ───────────────────────────────────
override def initAccumulator(context: GameContext): Unit =
nnue.initAccumulator(context.board)
override def copyAccumulator(parentPly: Int, childPly: Int): Unit =
nnue.copyAccumulator(parentPly, childPly)
override def pushAccumulator(childPly: Int, move: Move, parent: GameContext, child: GameContext): Unit =
// Use incremental updates, but recompute from scratch every 10 plies to prevent accumulation errors
if childPly % 10 == 0 then nnue.recomputeAccumulator(childPly, child.board)
else nnue.pushAccumulator(childPly, move, parent.board)
override def evaluateAccumulator(ply: Int, context: GameContext, hash: Long): Int =
nnue.evaluateAtPlyWithValidation(ply, context.turn, hash, context.board)
@@ -0,0 +1,231 @@
package de.nowchess.bot.bots.nnue
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
class NNUE(model: NbaiModel):
private val featureSize = model.layers(0).inputSize
private val accSize = model.layers(0).outputSize
private val validateAccum = sys.env.contains("NNUE_VALIDATE") // Enable with NNUE_VALIDATE=1
// Column-major L1 weights for cache-friendly sparse & incremental updates.
// l1WeightsT(featureIdx * accSize + outputIdx) = l1Weights(outputIdx * featureSize + featureIdx)
private val l1WeightsT: Array[Float] =
val w = model.weights(0).weights
val t = new Array[Float](featureSize * accSize)
for j <- 0 until featureSize; i <- 0 until accSize do t(j * accSize + i) = w(i * featureSize + j)
t
// ── Accumulator stack ────────────────────────────────────────────────────
private val MAX_PLY = 128
private val l1Stack: Array[Array[Float]] = Array.fill(MAX_PLY + 1)(new Array[Float](accSize))
// Shared evaluation buffers: index i holds the output of layers(i) (all except the scalar output layer).
private val evalBuffers: Array[Array[Float]] = model.layers.init.map(l => new Array[Float](l.outputSize))
// ── Eval cache ───────────────────────────────────────────────────────────
private val EVAL_CACHE_MASK = (1 << 18) - 1L
private val evalCacheHashes = new Array[Long](1 << 18)
private val evalCacheScores = new Array[Int](1 << 18)
// ── Feature helpers ──────────────────────────────────────────────────────
private def squareNum(sq: Square): Int = sq.rank.ordinal * 8 + sq.file.ordinal
private def featureIndex(piece: Piece, sqNum: Int): Int =
val colorOffset = if piece.color == Color.White then 6 else 0
(colorOffset + piece.pieceType.ordinal) * 64 + sqNum
private def addColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
val offset = featureIdx * accSize
for i <- 0 until accSize do l1Pre(i) += l1WeightsT(offset + i)
private def subtractColumn(l1Pre: Array[Float], featureIdx: Int): Unit =
val offset = featureIdx * accSize
for i <- 0 until accSize do l1Pre(i) -= l1WeightsT(offset + i)
// ── Accumulator init ─────────────────────────────────────────────────────
def initAccumulator(board: Board): Unit =
System.arraycopy(model.weights(0).bias, 0, l1Stack(0), 0, accSize)
for (sq, piece) <- board.pieces do addColumn(l1Stack(0), featureIndex(piece, squareNum(sq)))
// ── Accumulator push (incremental updates) ───────────────────────────────
def pushAccumulator(childPly: Int, move: Move, board: Board): Unit =
System.arraycopy(l1Stack(childPly - 1), 0, l1Stack(childPly), 0, accSize)
val l1 = l1Stack(childPly)
move.moveType match
case MoveType.Normal(_) => applyNormalDelta(l1, move, board)
case MoveType.EnPassant => applyEnPassantDelta(l1, move, board)
case MoveType.CastleKingside | MoveType.CastleQueenside => applyCastleDelta(l1, move, board)
case MoveType.Promotion(p) => applyPromotionDelta(l1, move, p, board)
def copyAccumulator(parentPly: Int, childPly: Int): Unit =
System.arraycopy(l1Stack(parentPly), 0, l1Stack(childPly), 0, accSize)
def recomputeAccumulator(ply: Int, board: Board): Unit =
System.arraycopy(model.weights(0).bias, 0, l1Stack(ply), 0, accSize)
for (sq, piece) <- board.pieces do addColumn(l1Stack(ply), featureIndex(piece, squareNum(sq)))
def validateAccumulator(ply: Int, board: Board): Boolean =
// Compute what L1 should be from scratch
val expectedL1 = new Array[Float](accSize)
System.arraycopy(model.weights(0).bias, 0, expectedL1, 0, accSize)
for (sq, piece) <- board.pieces do addColumn(expectedL1, featureIndex(piece, squareNum(sq)))
// Compare with actual L1
val actual = l1Stack(ply)
val maxError =
(0 until accSize).foldLeft(0f) { (currentMax, i) =>
val error = math.abs(actual(i) - expectedL1(i))
math.max(currentMax, error)
}
maxError < 0.001f // Allow small floating-point errors
private def applyNormalDelta(l1: Array[Float], move: Move, board: Board): Unit =
// Extract source and destination square indices early
val fromNum = squareNum(move.from)
val toNum = squareNum(move.to)
// Get the moving piece
board.pieceAt(move.from).foreach { mover =>
subtractColumn(l1, featureIndex(mover, fromNum))
// If there's a capture, subtract the captured piece
board.pieceAt(move.to).foreach { cap =>
subtractColumn(l1, featureIndex(cap, toNum))
}
// Add the piece to its new location
addColumn(l1, featureIndex(mover, toNum))
}
private def applyEnPassantDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { pawn =>
val capturedSq = Square(move.to.file, move.from.rank)
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
board.pieceAt(capturedSq).foreach(cap => subtractColumn(l1, featureIndex(cap, squareNum(capturedSq))))
addColumn(l1, featureIndex(pawn, squareNum(move.to)))
}
private def applyCastleDelta(l1: Array[Float], move: Move, board: Board): Unit =
board.pieceAt(move.from).foreach { king =>
val rank = move.from.rank
val kingside = move.moveType == MoveType.CastleKingside
val (rookFrom, rookTo) =
if kingside then (Square(File.H, rank), Square(File.F, rank))
else (Square(File.A, rank), Square(File.D, rank))
val rook = Piece(king.color, PieceType.Rook)
subtractColumn(l1, featureIndex(king, squareNum(move.from)))
addColumn(l1, featureIndex(king, squareNum(move.to)))
subtractColumn(l1, featureIndex(rook, squareNum(rookFrom)))
addColumn(l1, featureIndex(rook, squareNum(rookTo)))
}
private def applyPromotionDelta(l1: Array[Float], move: Move, promo: PromotionPiece, board: Board): Unit =
board.pieceAt(move.from).foreach { pawn =>
val toNum = squareNum(move.to)
subtractColumn(l1, featureIndex(pawn, squareNum(move.from)))
board.pieceAt(move.to).foreach(cap => subtractColumn(l1, featureIndex(cap, toNum)))
addColumn(l1, featureIndex(Piece(pawn.color, promotedType(promo)), toNum))
}
private def promotedType(promo: PromotionPiece): PieceType = promo match
case PromotionPiece.Knight => PieceType.Knight
case PromotionPiece.Bishop => PieceType.Bishop
case PromotionPiece.Rook => PieceType.Rook
case PromotionPiece.Queen => PieceType.Queen
// ── Evaluation from accumulator ──────────────────────────────────────────
def evaluateAtPly(ply: Int, turn: Color, hash: Long): Int =
val idx = (hash & EVAL_CACHE_MASK).toInt
if evalCacheHashes(idx) == hash then evalCacheScores(idx)
else
val score = runL2toOutput(l1Stack(ply), turn)
evalCacheHashes(idx) = hash
evalCacheScores(idx) = score
score
def evaluateAtPlyWithValidation(ply: Int, turn: Color, hash: Long, board: Board): Int =
// For debugging: validate that incremental accumulator matches recomputation
if validateAccum && ply > 0 && ply % 10 != 0 then
val isValid = validateAccumulator(ply, board)
if !isValid then System.err.println(s"WARNING: NNUE accumulator diverged at ply $ply")
evaluateAtPly(ply, turn, hash)
private def runL2toOutput(l1Pre: Array[Float], turn: Color): Int =
val l1ReLU = evalBuffers(0)
for i <- 0 until accSize do l1ReLU(i) = if l1Pre(i) > 0f then l1Pre(i) else 0f
val finalInput =
(1 until model.layers.length - 1).foldLeft(l1ReLU) { (input, i) =>
val lw = model.weights(i)
val out = evalBuffers(i)
val ld = model.layers(i)
runDenseReLU(input, ld.inputSize, lw.weights, lw.bias, out, ld.outputSize)
out
}
val lastIdx = model.layers.length - 1
val output = runOutputLayer(finalInput, model.layers(lastIdx).inputSize, model.weights(lastIdx))
scoreFromOutput(output, turn)
private def runDenseReLU(
input: Array[Float],
inSize: Int,
weights: Array[Float],
bias: Array[Float],
output: Array[Float],
outSize: Int,
): Unit =
for i <- 0 until outSize do
val sum = (0 until inSize).foldLeft(bias(i))((s, j) => s + input(j) * weights(i * inSize + j))
output(i) = if sum > 0f then sum else 0f
private def runOutputLayer(input: Array[Float], inSize: Int, lw: LayerWeights): Float =
(0 until inSize).foldLeft(lw.bias(0))((sum, j) => sum + input(j) * lw.weights(j))
private def scoreFromOutput(output: Float, turn: Color): Int =
val cp =
if math.abs(output) >= 0.9999f then if output > 0f then 20000 else -20000
else
val atanh = 0.5f * math.log((1f + output) / (1f - output)).toFloat
(300f * atanh).toInt
val cpFromTurn = if turn == Color.Black then -cp else cp
math.max(-20000, math.min(20000, cpFromTurn))
// ── Legacy full-board evaluate ────────────────────────────────────────────
private val legacyL1 = new Array[Float](accSize)
def evaluate(context: GameContext): Int =
System.arraycopy(model.weights(0).bias, 0, legacyL1, 0, accSize)
for (sq, piece) <- context.board.pieces do addColumn(legacyL1, featureIndex(piece, squareNum(sq)))
runL2toOutput(legacyL1, context.turn)
def benchmark(): Unit =
val context = GameContext.initial
val iterations = 1_000_000
for _ <- 0 until 10000 do evaluate(context)
val startNanos = System.nanoTime()
for _ <- 0 until iterations do evaluate(context)
val endNanos = System.nanoTime()
val totalNanos = endNanos - startNanos
val nanosPerEval = totalNanos.toDouble / iterations
println()
println("=" * 60)
println("NNUE BENCHMARK RESULTS")
println("=" * 60)
println(f"Iterations: $iterations%,d")
println(f"Total time: ${totalNanos / 1e9}%.2f seconds")
println(f"ns/eval: $nanosPerEval%.2f ns")
println(f"evals/second: ${1e9 / nanosPerEval}%.0f evals/s")
println("=" * 60)
println()
@@ -0,0 +1,52 @@
package de.nowchess.bot.bots.nnue
import java.io.InputStream
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiLoader:
/** Little-endian encoding of ASCII bytes 'N','B','A','I'. */
val MAGIC: Int = 0x4942_414e
def load(stream: InputStream): NbaiModel =
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkHeader(buf)
val metadata = readMetadata(buf)
val descs = readLayerDescriptors(buf)
val weights = descs.map(_ => readLayerWeights(buf))
NbaiModel(metadata, descs, weights)
/** Tries /nnue_weights.nbai on the classpath; falls back to migrating /nnue_weights.bin. */
def loadDefault(): NbaiModel =
Option(getClass.getResourceAsStream("/nnue_weights.nbai")) match
case Some(s) =>
try load(s)
finally s.close()
case None => NbaiMigrator.migrateFromBin()
private def checkHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != MAGIC then sys.error(s"Invalid NBAI magic: 0x${magic.toHexString}")
val version = buf.getShort() & 0xffff
if version != 1 then sys.error(s"Unsupported NBAI version: $version")
private def readMetadata(buf: ByteBuffer): NbaiMetadata =
val bytes = new Array[Byte](buf.getInt())
buf.get(bytes)
NbaiMetadata.fromJson(new String(bytes, StandardCharsets.UTF_8))
private def readLayerDescriptors(buf: ByteBuffer): Array[LayerDescriptor] =
Array.tabulate(buf.getShort() & 0xffff) { _ =>
val nameBytes = new Array[Byte](buf.get() & 0xff)
buf.get(nameBytes)
LayerDescriptor(new String(nameBytes, StandardCharsets.US_ASCII), buf.getInt(), buf.getInt())
}
private def readLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readFloats(buf), readFloats(buf))
private def readFloats(buf: ByteBuffer): Array[Float] =
val arr = new Array[Float](buf.getInt())
for i <- arr.indices do arr(i) = buf.getFloat()
arr
@@ -0,0 +1,43 @@
package de.nowchess.bot.bots.nnue
import java.nio.{ByteBuffer, ByteOrder}
/** Converts the legacy nnue_weights.bin resource into an NbaiModel. Used as fallback when no .nbai file exists. */
object NbaiMigrator:
private val BinMagic = 0x4555_4e4e
private val BinVersion = 1
private val DefaultLayers: Array[LayerDescriptor] = Array(
LayerDescriptor("relu", 768, 1536),
LayerDescriptor("relu", 1536, 1024),
LayerDescriptor("relu", 1024, 512),
LayerDescriptor("relu", 512, 256),
LayerDescriptor("linear", 256, 1),
)
private val UnknownMetadata: NbaiMetadata =
NbaiMetadata(trainedBy = "unknown", trainedAt = "unknown", trainingDataCount = 0L, valLoss = 0.0, trainLoss = 0.0)
def migrateFromBin(): NbaiModel =
val stream = Option(getClass.getResourceAsStream("/nnue_weights.bin"))
.getOrElse(sys.error("Neither nnue_weights.nbai nor nnue_weights.bin found in resources"))
try
val buf = ByteBuffer.wrap(stream.readAllBytes()).order(ByteOrder.LITTLE_ENDIAN)
checkBinHeader(buf)
val weights = DefaultLayers.map(_ => readBinLayerWeights(buf))
NbaiModel(UnknownMetadata, DefaultLayers, weights)
finally stream.close()
private def checkBinHeader(buf: ByteBuffer): Unit =
val magic = buf.getInt()
if magic != BinMagic then sys.error(s"Invalid bin magic: 0x${magic.toHexString}")
val version = buf.getInt()
if version != BinVersion then sys.error(s"Unsupported bin version: $version")
private def readBinLayerWeights(buf: ByteBuffer): LayerWeights =
LayerWeights(readBinTensor(buf), readBinTensor(buf))
private def readBinTensor(buf: ByteBuffer): Array[Float] =
val shape = Array.tabulate(buf.getInt())(_ => buf.getInt())
Array.tabulate(shape.product)(_ => buf.getFloat())
@@ -0,0 +1,45 @@
package de.nowchess.bot.bots.nnue
/** Descriptor for a single dense layer stored in a .nbai file. */
case class LayerDescriptor(activation: String, inputSize: Int, outputSize: Int)
/** Training metadata embedded in every .nbai file. */
case class NbaiMetadata(
trainedBy: String,
trainedAt: String,
trainingDataCount: Long,
valLoss: Double,
trainLoss: Double,
):
def toJson: String =
s"""{
| "trainedBy": "$trainedBy",
| "trainedAt": "$trainedAt",
| "trainingDataCount": $trainingDataCount,
| "valLoss": $valLoss,
| "trainLoss": $trainLoss
|}""".stripMargin
object NbaiMetadata:
def fromJson(json: String): NbaiMetadata =
def str(key: String) = raw""""$key"\s*:\s*"([^"]*)"""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("")
def num(key: String) = raw""""$key"\s*:\s*([0-9.eE+\-]+)""".r.findFirstMatchIn(json).map(_.group(1)).getOrElse("0")
NbaiMetadata(
str("trainedBy"),
str("trainedAt"),
num("trainingDataCount").toLong,
num("valLoss").toDouble,
num("trainLoss").toDouble,
)
/** Weights and biases for a single layer. Weights are row-major: (outputSize × inputSize). */
case class LayerWeights(weights: Array[Float], bias: Array[Float])
/** A fully deserialized .nbai model ready to initialize NNUE. */
case class NbaiModel(
metadata: NbaiMetadata,
layers: Array[LayerDescriptor],
weights: Array[LayerWeights],
):
require(layers.length == weights.length, "Layer count must match weight count")
require(layers.length >= 2, "Model must have at least 2 layers")
@@ -0,0 +1,51 @@
package de.nowchess.bot.bots.nnue
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
object NbaiWriter:
def write(model: NbaiModel, out: OutputStream): Unit =
val acc = new ByteArrayOutputStream()
writeHeader(acc)
writeMetadata(acc, model.metadata)
writeLayerDescriptors(acc, model.layers)
model.weights.foreach(lw => writeLayerWeights(acc, lw))
out.write(acc.toByteArray)
private def writeHeader(out: ByteArrayOutputStream): Unit =
val buf = ByteBuffer.allocate(6).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(NbaiLoader.MAGIC)
buf.putShort(1.toShort)
out.write(buf.array())
private def writeMetadata(out: ByteArrayOutputStream, meta: NbaiMetadata): Unit =
val json = meta.toJson.getBytes(StandardCharsets.UTF_8)
val buf = ByteBuffer.allocate(4 + json.length).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(json.length)
buf.put(json)
out.write(buf.array())
private def writeLayerDescriptors(out: ByteArrayOutputStream, layers: Array[LayerDescriptor]): Unit =
val nameBytes = layers.map(_.activation.getBytes(StandardCharsets.US_ASCII))
val capacity = 2 + layers.indices.map(i => 1 + nameBytes(i).length + 8).sum
val buf = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN)
buf.putShort(layers.length.toShort)
layers.zip(nameBytes).foreach { (l, nb) =>
buf.put(nb.length.toByte)
buf.put(nb)
buf.putInt(l.inputSize)
buf.putInt(l.outputSize)
}
out.write(buf.array())
private def writeLayerWeights(out: ByteArrayOutputStream, lw: LayerWeights): Unit =
writeFloats(out, lw.weights)
writeFloats(out, lw.bias)
private def writeFloats(out: ByteArrayOutputStream, floats: Array[Float]): Unit =
val buf = ByteBuffer.allocate(4 + floats.length * 4).order(ByteOrder.LITTLE_ENDIAN)
buf.putInt(floats.length)
floats.foreach(buf.putFloat)
out.write(buf.array())
@@ -0,0 +1,419 @@
package de.nowchess.bot.logic
import de.nowchess.api.board.PieceType
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType}
import de.nowchess.bot.ai.Evaluation
import de.nowchess.bot.util.ZobristHash
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
final class AlphaBetaSearch(
rules: RuleSet = DefaultRules,
tt: TranspositionTable = TranspositionTable(),
weights: Evaluation,
numThreads: Int = Runtime.getRuntime.availableProcessors,
):
private val INF = Int.MaxValue / 2
private val MAX_QUIESCENCE_PLY = 64
private val NULL_MOVE_R = 2
private val ASPIRATION_DELTA = 50
private val ASPIRATION_DELTA_MAX = 150
private val TIME_CHECK_FREQUENCY = 1000
private val FUTILITY_MARGIN = 100
private val CHECK_EXTENSION = 1
private val timeStartMs = AtomicLong(0L)
private val timeLimitMs = AtomicLong(0L)
private val nodeCount = AtomicInteger(0)
private val ordering = MoveOrdering.OrderingContext()
private final case class QuiescenceNode(
context: GameContext,
ply: Int,
alpha: Int,
beta: Int,
hash: Long,
)
/** Return the best move for the side to move, searching to maxDepth plies. Uses iterative deepening with aspiration
* windows.
*/
def bestMove(context: GameContext, maxDepth: Int): Option[Move] =
bestMove(context, maxDepth, Set.empty)
def bestMove(context: GameContext, maxDepth: Int, excludedRootMoves: Set[Move]): Option[Move] =
tt.clear()
ordering.clear()
weights.initAccumulator(context)
timeStartMs.set(System.currentTimeMillis)
timeLimitMs.set(Long.MaxValue / 4)
nodeCount.set(0)
val rootHash = ZobristHash.hash(context)
(1 to maxDepth)
.foldLeft((None: Option[Move], 0)) { case ((bestSoFar, prevScore), depth) =>
val (alpha, beta) =
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (score, move) = searchWithAspiration(
context,
depth,
alpha,
beta,
ASPIRATION_DELTA,
rootHash,
excludedRootMoves,
)
(move.orElse(bestSoFar), score)
}
._1
/** Return the best move for the side to move within a time budget (ms). Uses iterative deepening, stopping when time
* runs out.
*/
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long): Option[Move] =
bestMoveWithTime(context, timeBudgetMs, Set.empty)
def bestMoveWithTime(context: GameContext, timeBudgetMs: Long, excludedRootMoves: Set[Move]): Option[Move] =
tt.clear()
ordering.clear()
weights.initAccumulator(context)
timeStartMs.set(System.currentTimeMillis)
timeLimitMs.set(timeBudgetMs)
nodeCount.set(0)
val rootHash = ZobristHash.hash(context)
@scala.annotation.tailrec
def loop(bestSoFar: Option[Move], prevScore: Int, depth: Int): Option[Move] =
if isOutOfTime then bestSoFar
else
val (alpha, beta) =
if depth == 1 then (-INF, INF) else (prevScore - ASPIRATION_DELTA, prevScore + ASPIRATION_DELTA)
val (score, move) = searchWithAspiration(
context,
depth,
alpha,
beta,
ASPIRATION_DELTA,
rootHash,
excludedRootMoves,
)
loop(move.orElse(bestSoFar), score, depth + 1)
loop(None, 0, 1)
private def isOutOfTime: Boolean =
System.currentTimeMillis - timeStartMs.get >= timeLimitMs.get
private def searchWithAspiration(
context: GameContext,
depth: Int,
alpha: Int,
beta: Int,
initialWindow: Int,
rootHash: Long,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val state = SearchState(rootHash, Map(rootHash -> 1))
@scala.annotation.tailrec
def loop(currentAlpha: Int, currentBeta: Int, delta: Int, attempt: Int): (Int, Option[Move]) =
if attempt >= 3 || attempt >= depth then search(context, depth, 0, Window(-INF, INF), state, excludedRootMoves)
else
val (score, move) = search(context, depth, 0, Window(currentAlpha, currentBeta), state, excludedRootMoves)
if score > currentAlpha && score < currentBeta then (score, move)
else if score <= currentAlpha then
loop(score - delta, currentBeta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
else loop(currentAlpha, score + delta, math.min(delta * 2, ASPIRATION_DELTA_MAX), attempt + 1)
loop(alpha, beta, initialWindow, 0)
private def hasNonPawnMaterial(context: GameContext): Boolean =
context.board.pieces.values.exists { piece =>
piece.color == context.turn &&
piece.pieceType != PieceType.Pawn &&
piece.pieceType != PieceType.King
}
private def nullMoveContext(context: GameContext): GameContext =
context.withTurn(context.turn.opposite).withEnPassantSquare(None)
private def tryNullMove(
context: GameContext,
depth: Int,
ply: Int,
beta: Int,
state: SearchState,
excludedRootMoves: Set[Move],
): Option[Int] =
val nullCtx = nullMoveContext(context)
val nullState = state.advance(ZobristHash.hash(nullCtx))
val reductionDepth = math.max(0, depth - 1 - NULL_MOVE_R)
weights.copyAccumulator(ply, ply + 1)
val (score, _) = search(nullCtx, reductionDepth, ply + 1, Window(-beta, -beta + 1), nullState, excludedRootMoves)
if -score >= beta then Some(beta) else None
/** Negamax alpha-beta search returning (score, best move). */
private def search(
context: GameContext,
depth: Int,
ply: Int,
window: Window,
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
searchNode(params)
private def searchNode(params: SearchParams): (Int, Option[Move]) =
val count = nodeCount.incrementAndGet()
immediateSearchResult(params, count).getOrElse {
val legalMoves = rules.allLegalMoves(params.context)
terminalSearchResult(params, legalMoves).getOrElse(searchDeeper(params, legalMoves))
}
private def immediateSearchResult(
params: SearchParams,
count: Int,
): Option[(Int, Option[Move])] =
if count % TIME_CHECK_FREQUENCY == 0 && isOutOfTime then
Some((weights.evaluateAccumulator(params.ply, params.context, params.state.hash), None))
else if params.state.repetitions.getOrElse(params.state.hash, 0) >= 3 then Some((weights.DRAW_SCORE, None))
else ttCutoff(params)
private def ttCutoff(params: SearchParams): Option[(Int, Option[Move])] =
tt.probe(params.state.hash).filter(_.depth >= params.depth).flatMap { entry =>
entry.flag match
case TTFlag.Exact => Some((entry.score, entry.bestMove))
case TTFlag.Lower =>
val newAlpha = math.max(params.window.alpha, entry.score)
Option.when(newAlpha >= params.window.beta)((entry.score, entry.bestMove))
case TTFlag.Upper =>
val newBeta = math.min(params.window.beta, entry.score)
Option.when(params.window.alpha >= newBeta)((entry.score, entry.bestMove))
}
private def terminalSearchResult(
params: SearchParams,
legalMoves: List[Move],
): Option[(Int, Option[Move])] =
if legalMoves.isEmpty then
Some(
(
if rules.isCheckmate(params.context) then -(weights.CHECKMATE_SCORE - params.ply) else weights.DRAW_SCORE,
None,
),
)
else if rules.isInsufficientMaterial(params.context) || rules.isFiftyMoveRule(params.context) then
Some((weights.DRAW_SCORE, None))
else if params.depth == 0 then
Some((quiescence(params.context, params.ply, params.window.alpha, params.window.beta, params.state.hash), None))
else None
private def searchDeeper(
params: SearchParams,
legalMoves: List[Move],
): (Int, Option[Move]) =
val nullResult =
Option
.when(canTryNullMove(params))(
tryNullMove(
params.context,
params.depth,
params.ply,
params.window.beta,
params.state,
params.excludedRootMoves,
),
)
.flatten
nullResult.map((_, None)).getOrElse {
val ttBest = tt.probe(params.state.hash).flatMap(_.bestMove)
val ordered = MoveOrdering.sort(params.context, legalMoves, ttBest, params.ply, ordering)
searchSequential(
params.context,
params.depth,
params.ply,
params.window,
ordered,
params.state,
params.excludedRootMoves,
)
}
private def canTryNullMove(params: SearchParams): Boolean =
params.depth >= 3 &&
!rules.isCheck(params.context) &&
hasNonPawnMaterial(params.context)
private def isQuietMove(context: GameContext, move: Move): Boolean =
!isCapture(context, move) &&
move.moveType != MoveType.CastleKingside &&
move.moveType != MoveType.CastleQueenside
private def scoreMove(
child: GameContext,
childState: SearchState,
params: SearchParams,
extension: Int,
reduction: Int,
a: Int,
): Int =
val betaNeg = -params.window.beta
if reduction > 0 then
val (rs, _) = search(
child,
math.max(0, params.depth - 1 - reduction + extension),
params.ply + 1,
Window(-a - 1, -a),
childState,
params.excludedRootMoves,
)
val s = -rs
if s > a then
val (fs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-fs
else s
else
val (rs, _) = search(
child,
math.max(0, params.depth - 1 + extension),
params.ply + 1,
Window(betaNeg, -a),
childState,
params.excludedRootMoves,
)
-rs
private def evalSingleMove(
move: Move,
moveNumber: Int,
a: Int,
params: SearchParams,
): Option[(Int, Boolean)] =
val skipRoot = params.ply == 0 && params.excludedRootMoves.contains(move)
val isQuiet = isQuietMove(params.context, move)
val futility = params.depth == 1 && isQuiet && moveNumber > 2 &&
weights.evaluateAccumulator(params.ply, params.context, params.state.hash) + FUTILITY_MARGIN < params.window.alpha
if skipRoot || futility then None
else
val child = rules.applyMove(params.context)(move)
val childHash = ZobristHash.nextHash(params.context, params.state.hash, move, child)
weights.pushAccumulator(params.ply + 1, move, params.context, child)
val childState = params.state.advance(childHash)
val extension = if rules.isCheck(child) then CHECK_EXTENSION else 0
val reduction = if moveNumber > 4 && params.depth >= 3 && isQuiet then 1 else 0
Some((scoreMove(child, childState, params, extension, reduction, a), isQuiet))
private def recordCutoff(move: Move, depth: Int, ply: Int): Unit =
ordering.addHistory(
move.from.rank.ordinal * 8 + move.from.file.ordinal,
move.to.rank.ordinal * 8 + move.to.file.ordinal,
depth * depth,
)
ordering.addKillerMove(ply, move)
@scala.annotation.tailrec
private def searchLoop(
idx: Int,
moveNumber: Int,
acc: LoopAcc,
params: SearchParams,
ordered: List[Move],
): (Option[Move], Int, Boolean) =
if idx >= ordered.length then (acc.bestMove, acc.bestScore, false)
else
val move = ordered(idx)
evalSingleMove(move, moveNumber, acc.a, params) match
case None => searchLoop(idx + 1, moveNumber + 1, acc, params, ordered)
case Some((score, isQuiet)) =>
val newAcc = LoopAcc(
if score > acc.bestScore then Some(move) else acc.bestMove,
math.max(acc.bestScore, score),
math.max(acc.a, score),
)
if newAcc.a >= params.window.beta then
if isQuiet then recordCutoff(move, params.depth, params.ply)
(newAcc.bestMove, newAcc.bestScore, true)
else searchLoop(idx + 1, moveNumber + 1, newAcc, params, ordered)
private def searchSequential(
context: GameContext,
depth: Int,
ply: Int,
window: Window,
ordered: List[Move],
state: SearchState,
excludedRootMoves: Set[Move],
): (Int, Option[Move]) =
val params = SearchParams(context, depth, ply, window, state, excludedRootMoves)
val (bestMove, bestScore, cutoff) = searchLoop(0, 0, LoopAcc(None, -INF, window.alpha), params, ordered)
val flag =
if cutoff then TTFlag.Lower
else if bestScore <= window.alpha then TTFlag.Upper
else TTFlag.Exact
tt.store(TTEntry(state.hash, depth, bestScore, flag, bestMove))
(bestScore, bestMove)
/** Quiescence search: only captures until position is quiet. */
private def quiescence(
context: GameContext,
ply: Int,
alpha: Int,
beta: Int,
hash: Long,
): Int =
quiescenceNode(QuiescenceNode(context, ply, alpha, beta, hash))
private def quiescenceNode(node: QuiescenceNode): Int =
val inCheck = rules.isCheck(node.context)
val standPat = if inCheck then -INF else weights.evaluateAccumulator(node.ply, node.context, node.hash)
if !inCheck && standPat >= node.beta then node.beta
else if node.ply >= MAX_QUIESCENCE_PLY then quiescenceAtDepthLimit(node, inCheck, standPat)
else
val moves = tacticalMoves(node.context, inCheck)
if inCheck && moves.isEmpty then -(weights.CHECKMATE_SCORE - node.ply)
else
val ordered = MoveOrdering.sort(node.context, moves, None)
val a0 = if inCheck then node.alpha else math.max(node.alpha, standPat)
quiescenceLoop(node, ordered, 0, a0)
private def quiescenceAtDepthLimit(node: QuiescenceNode, inCheck: Boolean, standPat: Int): Int =
if inCheck then weights.evaluateAccumulator(node.ply, node.context, node.hash) else standPat
private def tacticalMoves(context: GameContext, inCheck: Boolean): List[Move] =
val allMoves = rules.allLegalMoves(context)
if inCheck then allMoves else allMoves.filter(m => isCapture(context, m))
@scala.annotation.tailrec
private def quiescenceLoop(
node: QuiescenceNode,
ordered: List[Move],
idx: Int,
a: Int,
): Int =
if idx >= ordered.length then a
else
val move = ordered(idx)
val child = rules.applyMove(node.context)(move)
val childHash = ZobristHash.nextHash(node.context, node.hash, move, child)
weights.pushAccumulator(node.ply + 1, move, node.context, child)
val score = -quiescence(child, node.ply + 1, -node.beta, -a, childHash)
if score >= node.beta then node.beta
else quiescenceLoop(node, ordered, idx + 1, math.max(a, score))
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
case MoveType.Normal(true) => true
case MoveType.EnPassant => true
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
case _ => false
@@ -0,0 +1,177 @@
package de.nowchess.bot.logic
import de.nowchess.api.board.{Board, Color, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import scala.annotation.tailrec
import scala.collection.mutable
object MoveOrdering:
class OrderingContext:
private val killerMoves = mutable.Map[Int, List[Move]]()
private val historyTable = mutable.Map[(Int, Int), Int]()
def addKillerMove(ply: Int, move: Move): Unit =
val current = killerMoves.getOrElse(ply, List())
if current.isEmpty || (current.head.from != move.from || current.head.to != move.to) then
killerMoves(ply) = (move :: current).take(2)
def getKillerMoves(ply: Int): List[Move] =
killerMoves.getOrElse(ply, List())
def addHistory(from: Int, to: Int, bonus: Int): Unit =
val key = (from, to)
historyTable(key) = historyTable.getOrElse(key, 0) + bonus
def getHistory(from: Int, to: Int): Int =
historyTable.getOrElse((from, to), 0)
def clear(): Unit =
killerMoves.clear()
historyTable.clear()
def score(
context: GameContext,
move: Move,
ttBestMove: Option[Move],
ply: Int = 0,
ordering: OrderingContext = new OrderingContext(),
): Int =
if ttBestMove.exists(m => m.from == move.from && m.to == move.to) then Int.MaxValue
else
move.moveType match
case MoveType.Promotion(PromotionPiece.Queen) =>
1_000_000 + promotionCaptureBonus(context, move)
case MoveType.Normal(true) | MoveType.EnPassant =>
captureScore(context, move)
case MoveType.Promotion(_) =>
50_000 + promotionCaptureBonus(context, move)
case _ => scoreQuietMove(move, ply, ordering)
def sort(
context: GameContext,
moves: List[Move],
ttBestMove: Option[Move],
ply: Int = 0,
ordering: OrderingContext = new OrderingContext(),
): List[Move] =
moves.sortBy(m => -score(context, m, ttBestMove, ply, ordering))
private def scoreQuietMove(move: Move, ply: Int, ordering: OrderingContext): Int =
val isKiller = ordering.getKillerMoves(ply).exists(k => k.from == move.from && k.to == move.to)
val fromIdx = move.from.rank.ordinal * 8 + move.from.file.ordinal
val toIdx = move.to.rank.ordinal * 8 + move.to.file.ordinal
val history = ordering.getHistory(fromIdx, toIdx)
if isKiller then 10_000 + (history / 10) else history / 10
private def promotionCaptureBonus(context: GameContext, move: Move): Int =
if isCapture(context, move) then captureScore(context, move) else 0
private def captureScore(context: GameContext, move: Move): Int =
val see = staticExchange(context, move)
val seeBias = if see >= 0 then 20_000 else -20_000
100_000 + mvvLva(context, move) + seeBias + see
private def mvvLva(context: GameContext, move: Move): Int =
(victimValue(context, move) * 10) - attackerValue(context, move)
private def attackerValue(context: GameContext, move: Move): Int =
context.board.pieceAt(move.from).map(pieceValue).getOrElse(0)
private def victimValue(context: GameContext, move: Move): Int =
move.moveType match
case MoveType.Normal(true) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
case MoveType.EnPassant => 1
case MoveType.Promotion(_) => context.board.pieceAt(move.to).map(pieceValue).getOrElse(0)
case _ => 0
private def pieceValue(piece: Piece): Int = piece.pieceType match
case PieceType.Pawn => 1
case PieceType.Knight => 3
case PieceType.Bishop => 3
case PieceType.Rook => 5
case PieceType.Queen => 9
case PieceType.King => 200
private def isCapture(context: GameContext, move: Move): Boolean = move.moveType match
case MoveType.Normal(true) => true
case MoveType.EnPassant => true
case MoveType.Promotion(_) => context.board.pieceAt(move.to).exists(_.color != context.turn)
case _ => false
private def staticExchange(context: GameContext, move: Move): Int =
if !isCapture(context, move) then 0
else
val target = move.to
val initialGain = victimValue(context, move)
movedPieceAfterMove(context, move).fold(initialGain) { moved =>
val boardAfterMove = applySeeMove(context.board, move, moved)
initialGain - seeGain(boardAfterMove, target, context.turn.opposite, pieceValue(moved))
}
private def movedPieceAfterMove(context: GameContext, move: Move): Option[Piece] =
move.moveType match
case MoveType.Promotion(pp) => Some(Piece(context.turn, promotionPieceType(pp)))
case _ => context.board.pieceAt(move.from)
private def seeGain(board: Board, target: Square, side: Color, currentValue: Int): Int =
leastValuableAttacker(board, target, side) match
case None => 0
case Some((from, attacker)) =>
val nextBoard = board.removed(from).updated(target, attacker)
val replyGain = seeGain(nextBoard, target, side.opposite, pieceValue(attacker))
math.max(0, currentValue - replyGain)
private def applySeeMove(board: Board, move: Move, moved: Piece): Board =
move.moveType match
case MoveType.EnPassant =>
val capturedSquare = Square(move.to.file, move.from.rank)
board.removed(move.from).removed(capturedSquare).updated(move.to, moved)
case _ => board.removed(move.from).updated(move.to, moved)
private def leastValuableAttacker(board: Board, target: Square, color: Color): Option[(Square, Piece)] =
board.pieces
.collect {
case (sq, piece) if piece.color == color && attacksSquare(board, sq, target, piece) => (sq, piece)
}
.toList
.sortBy { case (_, piece) => pieceValue(piece) }
.headOption
private def attacksSquare(board: Board, from: Square, target: Square, piece: Piece): Boolean =
val df = target.file.ordinal - from.file.ordinal
val dr = target.rank.ordinal - from.rank.ordinal
piece.pieceType match
case PieceType.Pawn =>
val dir = if piece.color == Color.White then 1 else -1
dr == dir && math.abs(df) == 1
case PieceType.Knight =>
val adf = math.abs(df)
val adr = math.abs(dr)
(adf == 1 && adr == 2) || (adf == 2 && adr == 1)
case PieceType.Bishop => clearLine(board, from, target, df, dr, diagonal = true)
case PieceType.Rook => clearLine(board, from, target, df, dr, diagonal = false)
case PieceType.Queen =>
clearLine(board, from, target, df, dr, diagonal = true) ||
clearLine(board, from, target, df, dr, diagonal = false)
case PieceType.King => math.abs(df) <= 1 && math.abs(dr) <= 1
private def clearLine(board: Board, from: Square, target: Square, df: Int, dr: Int, diagonal: Boolean): Boolean =
val valid =
if diagonal then math.abs(df) == math.abs(dr) && df != 0 else (df == 0 && dr != 0) || (dr == 0 && df != 0)
valid && pathClear(board, from, target, Integer.compare(df, 0), Integer.compare(dr, 0))
@tailrec
private def pathClear(board: Board, from: Square, target: Square, stepF: Int, stepR: Int): Boolean =
from.offset(stepF, stepR) match
case None => false
case Some(next) if next == target => true
case Some(next) => board.pieceAt(next).isEmpty && pathClear(board, next, target, stepF, stepR)
private def promotionPieceType(piece: PromotionPiece): PieceType = piece match
case PromotionPiece.Knight => PieceType.Knight
case PromotionPiece.Bishop => PieceType.Bishop
case PromotionPiece.Rook => PieceType.Rook
case PromotionPiece.Queen => PieceType.Queen
@@ -0,0 +1,61 @@
package de.nowchess.bot.logic
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
final case class Window(alpha: Int, beta: Int)
final case class LoopAcc(bestMove: Option[Move], bestScore: Int, a: Int)
final case class SearchParams(
context: GameContext,
depth: Int,
ply: Int,
window: Window,
state: SearchState,
excludedRootMoves: Set[Move],
)
final case class SearchState(hash: Long, repetitions: Map[Long, Int]):
def advance(nextHash: Long): SearchState =
SearchState(
nextHash,
repetitions.updatedWith(nextHash) {
case Some(v) => Some(v + 1)
case None => Some(1)
},
)
enum TTFlag:
case Exact // Score is exact
case Lower // Score is a lower bound
case Upper // Score is an upper bound
final case class TTEntry(
hash: Long,
depth: Int,
score: Int,
flag: TTFlag,
bestMove: Option[Move],
)
final class TranspositionTable(val sizePow2: Int = 20):
private val size = 1 << sizePow2
private val mask = size - 1L
private val locks = Array.fill(size)(new Object())
private val table: Array[Option[TTEntry]] = Array.fill(size)(None)
def probe(hash: Long): Option[TTEntry] =
val index = (hash & mask).toInt
locks(index).synchronized {
table(index).filter(_.hash == hash)
}
def store(entry: TTEntry): Unit =
val index = (entry.hash & mask).toInt
locks(index).synchronized {
table(index) = Some(entry)
}
def clear(): Unit =
for i <- 0 until size do locks(i).synchronized { table(i) = None }
@@ -0,0 +1,137 @@
package de.nowchess.bot.util
import de.nowchess.api.board.*
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import java.io.{DataInputStream, FileInputStream}
import scala.collection.mutable
import scala.util.Random
/** Reads a Polyglot opening book (.bin file) and probes it for moves.
*
* Polyglot books are binary files with 16-byte big-endian records:
* - key: 8 bytes (Long) — Zobrist hash of the position
* - move: 2 bytes (Short) — packed as (to_file | to_rank | from_file | from_rank | promotion)
* - weight: 2 bytes (Short) — move weight (higher = preferred)
* - learn: 4 bytes (Int) — learning data (unused)
*/
final class PolyglotBook(path: String):
private val entries: Map[Long, Vector[BookEntry]] =
try {
val r = loadBookFile(path)
println(s"Book loaded successfully. ${r.size} entries found.")
r
} catch
case e: Exception =>
println(s"Error loading book: $e")
// Gracefully fail: return empty map if book cannot be loaded
// This allows the bot to work even if the book file is missing
scala.collection.immutable.Map.empty
/** Probe the book for a move in the given position. Returns a weighted random move, or None if not in book. */
def probe(context: GameContext): Option[Move] =
val hash = PolyglotHash.hash(context)
println(f"0x$hash%016X")
entries.get(hash).flatMap { bookEntries =>
if bookEntries.isEmpty then None
else
val entry = weightedRandom(bookEntries)
decodeMove(entry.move, context)
}
private def loadBookFile(path: String): Map[Long, Vector[BookEntry]] =
val input = DataInputStream(FileInputStream(path))
try
val result = mutable.Map[Long, Vector[BookEntry]]()
while input.available() > 0 do
val key = input.readLong()
val move = input.readShort()
val weight = input.readShort()
input.readInt() // learning data (unused)
val entry = BookEntry(key, move, weight)
result.updateWith(key) {
case Some(entries) => Some(entries :+ entry)
case None => Some(Vector(entry))
}
result.toMap
finally input.close()
/** Decode a packed Polyglot move short into an Option[Move].
*
* Bit layout of the move Short:
* - bits 0-2: to_file (0-7)
* - bits 3-5: to_rank (0-7)
* - bits 6-8: from_file (0-7)
* - bits 9-11: from_rank (0-7)
* - bits 12-14: promotion piece (0=none, 1=knight, 2=bishop, 3=rook, 4=queen)
*/
private def decodeMove(raw: Short, context: GameContext): Option[Move] =
val toFile = raw & 0x07
val toRank = (raw >> 3) & 0x07
val fromFile = (raw >> 6) & 0x07
val fromRank = (raw >> 9) & 0x07
val promotionBits = (raw >> 12) & 0x07
if toFile > 7 || toRank > 7 || fromFile > 7 || fromRank > 7 then None
else
val from = Square(File.values(fromFile), Rank.values(fromRank))
val to = Square(File.values(toFile), Rank.values(toRank))
if isKingMove(context, from) && isRookSquare(to, context) then Some(decodeCastling(from, to))
else
val moveTypeOpt: Option[MoveType] =
if promotionBits > 0 then
promotionBits match
case 1 => Some(MoveType.Promotion(PromotionPiece.Knight))
case 2 => Some(MoveType.Promotion(PromotionPiece.Bishop))
case 3 => Some(MoveType.Promotion(PromotionPiece.Rook))
case 4 => Some(MoveType.Promotion(PromotionPiece.Queen))
case _ => None
else Some(MoveType.Normal(context.board.pieces.contains(to)))
moveTypeOpt.map(moveType => Move(from, to, moveType))
private def isKingMove(context: GameContext, square: Square): Boolean =
context.board.pieces.get(square).exists { piece =>
piece.pieceType == PieceType.King
}
private def isRookSquare(square: Square, context: GameContext): Boolean =
context.board.pieces.get(square).exists { piece =>
piece.pieceType == PieceType.Rook
}
/** Decode castling from king-to-rook square to the standard move.
*
* Polyglot encodes castling as:
* - e1→h1 = White kingside (move to g1)
* - e1→a1 = White queenside (move to c1)
* - e8→h8 = Black kingside (move to g8)
* - e8→a8 = Black queenside (move to c8)
*/
private def decodeCastling(from: Square, to: Square): Move =
if to.file == File.H then Move(from, Square(File.G, to.rank), MoveType.CastleKingside)
else if to.file == File.A then Move(from, Square(File.C, to.rank), MoveType.CastleQueenside)
else
// Fallback (should not happen in a valid book)
Move(from, to, MoveType.Normal())
/** Select a weighted random move from the list of book entries. */
private def weightedRandom(entries: Vector[BookEntry]): BookEntry =
if entries.length == 1 then entries.head
else
val totalWeight = entries.map(_.weight).sum
val pick = Random.nextInt(totalWeight.max(1)) // NOSONAR
@scala.annotation.tailrec
def select(remaining: Int, idx: Int): BookEntry =
if idx >= entries.length then entries.last
else if remaining < entries(idx).weight then entries(idx)
else select(remaining - entries(idx).weight, idx + 1)
select(pick, 0)
private case class BookEntry(key: Long, move: Short, weight: Int)
@@ -0,0 +1,204 @@
package de.nowchess.bot.util
import de.nowchess.api.board.{Color, Piece, PieceType, Square}
import de.nowchess.api.game.GameContext
object PolyglotHash:
/** 781-entry Zobrist random table from the Polyglot spec. */
private val Random: Array[Long] = Array(
0x9d39247e33776d41L, 0x2af7398005aaa5c7L, 0x44db015024623547L, 0x9c15f73e62a76ae2L, 0x75834465489c0c89L,
0x3290ac3a203001bfL, 0x0fbbad1f61042279L, 0xe83a908ff2fb60caL, 0x0d7e765d58755c10L, 0x1a083822ceafe02dL,
0x9605d5f0e25ec3b0L, 0xd021ff5cd13a2ed5L, 0x40bdf15d4a672e32L, 0x011355146fd56395L, 0x5db4832046f3d9e5L,
0x239f8b2d7ff719ccL, 0x05d1a1ae85b49aa1L, 0x679f848f6e8fc971L, 0x7449bbff801fed0bL, 0x7d11cdb1c3b7adf0L,
0x82c7709e781eb7ccL, 0xf3218f1c9510786cL, 0x331478f3af51bbe6L, 0x4bb38de5e7219443L, 0xaa649c6ebcfd50fcL,
0x8dbd98a352afd40bL, 0x87d2074b81d79217L, 0x19f3c751d3e92ae1L, 0xb4ab30f062b19abfL, 0x7b0500ac42047ac4L,
0xc9452ca81a09d85dL, 0x24aa6c514da27500L, 0x4c9f34427501b447L, 0x14a68fd73c910841L, 0xa71b9b83461cbd93L,
0x03488b95b0f1850fL, 0x637b2b34ff93c040L, 0x09d1bc9a3dd90a94L, 0x3575668334a1dd3bL, 0x735e2b97a4c45a23L,
0x18727070f1bd400bL, 0x1fcbacd259bf02e7L, 0xd310a7c2ce9b6555L, 0xbf983fe0fe5d8244L, 0x9f74d14f7454a824L,
0x51ebdc4ab9ba3035L, 0x5c82c505db9ab0faL, 0xfcf7fe8a3430b241L, 0x3253a729b9ba3ddeL, 0x8c74c368081b3075L,
0xb9bc6c87167c33e7L, 0x7ef48f2b83024e20L, 0x11d505d4c351bd7fL, 0x6568fca92c76a243L, 0x4de0b0f40f32a7b8L,
0x96d693460cc37e5dL, 0x42e240cb63689f2fL, 0x6d2bdcdae2919661L, 0x42880b0236e4d951L, 0x5f0f4a5898171bb6L,
0x39f890f579f92f88L, 0x93c5b5f47356388bL, 0x63dc359d8d231b78L, 0xec16ca8aea98ad76L, 0x5355f900c2a82dc7L,
0x07fb9f855a997142L, 0x5093417aa8a7ed5eL, 0x7bcbc38da25a7f3cL, 0x19fc8a768cf4b6d4L, 0x637a7780decfc0d9L,
0x8249a47aee0e41f7L, 0x79ad695501e7d1e8L, 0x14acbaf4777d5776L, 0xf145b6beccdea195L, 0xdabf2ac8201752fcL,
0x24c3c94df9c8d3f6L, 0xbb6e2924f03912eaL, 0x0ce26c0b95c980d9L, 0xa49cd132bfbf7cc4L, 0xe99d662af4243939L,
0x27e6ad7891165c3fL, 0x8535f040b9744ff1L, 0x54b3f4fa5f40d873L, 0x72b12c32127fed2bL, 0xee954d3c7b411f47L,
0x9a85ac909a24eaa1L, 0x70ac4cd9f04f21f5L, 0xf9b89d3e99a075c2L, 0x87b3e2b2b5c907b1L, 0xa366e5b8c54f48b8L,
0xae4a9346cc3f7cf2L, 0x1920c04d47267bbdL, 0x87bf02c6b49e2ae9L, 0x092237ac237f3859L, 0xff07f64ef8ed14d0L,
0x8de8dca9f03cc54eL, 0x9c1633264db49c89L, 0xb3f22c3d0b0b38edL, 0x390e5fb44d01144bL, 0x5bfea5b4712768e9L,
0x1e1032911fa78984L, 0x9a74acb964e78cb3L, 0x4f80f7a035dafb04L, 0x6304d09a0b3738c4L, 0x2171e64683023a08L,
0x5b9b63eb9ceff80cL, 0x506aacf489889342L, 0x1881afc9a3a701d6L, 0x6503080440750644L, 0xdfd395339cdbf4a7L,
0xef927dbcf00c20f2L, 0x7b32f7d1e03680ecL, 0xb9fd7620e7316243L, 0x05a7e8a57db91b77L, 0xb5889c6e15630a75L,
0x4a750a09ce9573f7L, 0xcf464cec899a2f8aL, 0xf538639ce705b824L, 0x3c79a0ff5580ef7fL, 0xede6c87f8477609dL,
0x799e81f05bc93f31L, 0x86536b8cf3428a8cL, 0x97d7374c60087b73L, 0xa246637cff328532L, 0x043fcae60cc0eba0L,
0x920e449535dd359eL, 0x70eb093b15b290ccL, 0x73a1921916591cbdL, 0x56436c9fe1a1aa8dL, 0xefac4b70633b8f81L,
0xbb215798d45df7afL, 0x45f20042f24f1768L, 0x930f80f4e8eb7462L, 0xff6712ffcfd75ea1L, 0xae623fd67468aa70L,
0xdd2c5bc84bc8d8fcL, 0x7eed120d54cf2dd9L, 0x22fe545401165f1cL, 0xc91800e98fb99929L, 0x808bd68e6ac10365L,
0xdec468145b7605f6L, 0x1bede3a3aef53302L, 0x43539603d6c55602L, 0xaa969b5c691ccb7aL, 0xa87832d392efee56L,
0x65942c7b3c7e11aeL, 0xded2d633cad004f6L, 0x21f08570f420e565L, 0xb415938d7da94e3cL, 0x91b859e59ecb6350L,
0x10cff333e0ed804aL, 0x28aed140be0bb7ddL, 0xc5cc1d89724fa456L, 0x5648f680f11a2741L, 0x2d255069f0b7dab3L,
0x9bc5a38ef729abd4L, 0xef2f054308f6a2bcL, 0xaf2042f5cc5c2858L, 0x480412bab7f5be2aL, 0xaef3af4a563dfe43L,
0x19afe59ae451497fL, 0x52593803dff1e840L, 0xf4f076e65f2ce6f0L, 0x11379625747d5af3L, 0xbce5d2248682c115L,
0x9da4243de836994fL, 0x066f70b33fe09017L, 0x4dc4de189b671a1cL, 0x51039ab7712457c3L, 0xc07a3f80c31fb4b4L,
0xb46ee9c5e64a6e7cL, 0xb3819a42abe61c87L, 0x21a007933a522a20L, 0x2df16f761598aa4fL, 0x763c4a1371b368fdL,
0xf793c46702e086a0L, 0xd7288e012aeb8d31L, 0xde336a2a4bc1c44bL, 0x0bf692b38d079f23L, 0x2c604a7a177326b3L,
0x4850e73e03eb6064L, 0xcfc447f1e53c8e1bL, 0xb05ca3f564268d99L, 0x9ae182c8bc9474e8L, 0xa4fc4bd4fc5558caL,
0xe755178d58fc4e76L, 0x69b97db1a4c03dfeL, 0xf9b5b7c4acc67c96L, 0xfc6a82d64b8655fbL, 0x9c684cb6c4d24417L,
0x8ec97d2917456ed0L, 0x6703df9d2924e97eL, 0xc547f57e42a7444eL, 0x78e37644e7cad29eL, 0xfe9a44e9362f05faL,
0x08bd35cc38336615L, 0x9315e5eb3a129aceL, 0x94061b871e04df75L, 0xdf1d9f9d784ba010L, 0x3bba57b68871b59dL,
0xd2b7adeeded1f73fL, 0xf7a255d83bc373f8L, 0xd7f4f2448c0ceb81L, 0xd95be88cd210ffa7L, 0x336f52f8ff4728e7L,
0xa74049dac312ac71L, 0xa2f61bb6e437fdb5L, 0x4f2a5cb07f6a35b3L, 0x87d380bda5bf7859L, 0x16b9f7e06c453a21L,
0x7ba2484c8a0fd54eL, 0xf3a678cad9a2e38cL, 0x39b0bf7dde437ba2L, 0xfcaf55c1bf8a4424L, 0x18fcf680573fa594L,
0x4c0563b89f495ac3L, 0x40e087931a00930dL, 0x8cffa9412eb642c1L, 0x68ca39053261169fL, 0x7a1ee967d27579e2L,
0x9d1d60e5076f5b6fL, 0x3810e399b6f65ba2L, 0x32095b6d4ab5f9b1L, 0x35cab62109dd038aL, 0xa90b24499fcfafb1L,
0x77a225a07cc2c6bdL, 0x513e5e634c70e331L, 0x4361c0ca3f692f12L, 0xd941aca44b20a45bL, 0x528f7c8602c5807bL,
0x52ab92beb9613989L, 0x9d1dfa2efc557f73L, 0x722ff175f572c348L, 0x1d1260a51107fe97L, 0x7a249a57ec0c9ba2L,
0x04208fe9e8f7f2d6L, 0x5a110c6058b920a0L, 0x0cd9a497658a5698L, 0x56fd23c8f9715a4cL, 0x284c847b9d887aaeL,
0x04feabfbbdb619cbL, 0x742e1e651c60ba83L, 0x9a9632e65904ad3cL, 0x881b82a13b51b9e2L, 0x506e6744cd974924L,
0xb0183db56ffc6a79L, 0x0ed9b915c66ed37eL, 0x5e11e86d5873d484L, 0xf678647e3519ac6eL, 0x1b85d488d0f20cc5L,
0xdab9fe6525d89021L, 0x0d151d86adb73615L, 0xa865a54edcc0f019L, 0x93c42566aef98ffbL, 0x99e7afeabe000731L,
0x48cbff086ddf285aL, 0x7f9b6af1ebf78bafL, 0x58627e1a149bba21L, 0x2cd16e2abd791e33L, 0xd363eff5f0977996L,
0x0ce2a38c344a6eedL, 0x1a804aadb9cfa741L, 0x907f30421d78c5deL, 0x501f65edb3034d07L, 0x37624ae5a48fa6e9L,
0x957baf61700cff4eL, 0x3a6c27934e31188aL, 0xd49503536abca345L, 0x088e049589c432e0L, 0xf943aee7febf21b8L,
0x6c3b8e3e336139d3L, 0x364f6ffa464ee52eL, 0xd60f6dcedc314222L, 0x56963b0dca418fc0L, 0x16f50edf91e513afL,
0xef1955914b609f93L, 0x565601c0364e3228L, 0xecb53939887e8175L, 0xbac7a9a18531294bL, 0xb344c470397bba52L,
0x65d34954daf3cebdL, 0xb4b81b3fa97511e2L, 0xb422061193d6f6a7L, 0x071582401c38434dL, 0x7a13f18bbedc4ff5L,
0xbc4097b116c524d2L, 0x59b97885e2f2ea28L, 0x99170a5dc3115544L, 0x6f423357e7c6a9f9L, 0x325928ee6e6f8794L,
0xd0e4366228b03343L, 0x565c31f7de89ea27L, 0x30f5611484119414L, 0xd873db391292ed4fL, 0x7bd94e1d8e17debcL,
0xc7d9f16864a76e94L, 0x947ae053ee56e63cL, 0xc8c93882f9475f5fL, 0x3a9bf55ba91f81caL, 0xd9a11fbb3d9808e4L,
0x0fd22063edc29fcaL, 0xb3f256d8aca0b0b9L, 0xb03031a8b4516e84L, 0x35dd37d5871448afL, 0xe9f6082b05542e4eL,
0xebfafa33d7254b59L, 0x9255abb50d532280L, 0xb9ab4ce57f2d34f3L, 0x693501d628297551L, 0xc62c58f97dd949bfL,
0xcd454f8f19c5126aL, 0xbbe83f4ecc2bdecbL, 0xdc842b7e2819e230L, 0xba89142e007503b8L, 0xa3bc941d0a5061cbL,
0xe9f6760e32cd8021L, 0x09c7e552bc76492fL, 0x852f54934da55cc9L, 0x8107fccf064fcf56L, 0x098954d51fff6580L,
0x23b70edb1955c4bfL, 0xc330de426430f69dL, 0x4715ed43e8a45c0aL, 0xa8d7e4dab780a08dL, 0x0572b974f03ce0bbL,
0xb57d2e985e1419c7L, 0xe8d9ecbe2cf3d73fL, 0x2fe4b17170e59750L, 0x11317ba87905e790L, 0x7fbf21ec8a1f45ecL,
0x1725cabfcb045b00L, 0x964e915cd5e2b207L, 0x3e2b8bcbf016d66dL, 0xbe7444e39328a0acL, 0xf85b2b4fbcde44b7L,
0x49353fea39ba63b1L, 0x1dd01aafcd53486aL, 0x1fca8a92fd719f85L, 0xfc7c95d827357afaL, 0x18a6a990c8b35ebdL,
0xcccb7005c6b9c28dL, 0x3bdbb92c43b17f26L, 0xaa70b5b4f89695a2L, 0xe94c39a54a98307fL, 0xb7a0b174cff6f36eL,
0xd4dba84729af48adL, 0x2e18bc1ad9704a68L, 0x2de0966daf2f8b1cL, 0xb9c11d5b1e43a07eL, 0x64972d68dee33360L,
0x94628d38d0c20584L, 0xdbc0d2b6ab90a559L, 0xd2733c4335c6a72fL, 0x7e75d99d94a70f4dL, 0x6ced1983376fa72bL,
0x97fcaacbf030bc24L, 0x7b77497b32503b12L, 0x8547eddfb81ccb94L, 0x79999cdff70902cbL, 0xcffe1939438e9b24L,
0x829626e3892d95d7L, 0x92fae24291f2b3f1L, 0x63e22c147b9c3403L, 0xc678b6d860284a1cL, 0x5873888850659ae7L,
0x0981dcd296a8736dL, 0x9f65789a6509a440L, 0x9ff38fed72e9052fL, 0xe479ee5b9930578cL, 0xe7f28ecd2d49eecdL,
0x56c074a581ea17feL, 0x5544f7d774b14aefL, 0x7b3f0195fc6f290fL, 0x12153635b2c0cf57L, 0x7f5126dbba5e0ca7L,
0x7a76956c3eafb413L, 0x3d5774a11d31ab39L, 0x8a1b083821f40cb4L, 0x7b4a38e32537df62L, 0x950113646d1d6e03L,
0x4da8979a0041e8a9L, 0x3bc36e078f7515d7L, 0x5d0a12f27ad310d1L, 0x7f9d1a2e1ebe1327L, 0xda3a361b1c5157b1L,
0xdcdd7d20903d0c25L, 0x36833336d068f707L, 0xce68341f79893389L, 0xab9090168dd05f34L, 0x43954b3252dc25e5L,
0xb438c2b67f98e5e9L, 0x10dcd78e3851a492L, 0xdbc27ab5447822bfL, 0x9b3cdb65f82ca382L, 0xb67b7896167b4c84L,
0xbfced1b0048eac50L, 0xa9119b60369ffebdL, 0x1fff7ac80904bf45L, 0xac12fb171817eee7L, 0xaf08da9177dda93dL,
0x1b0cab936e65c744L, 0xb559eb1d04e5e932L, 0xc37b45b3f8d6f2baL, 0xc3a9dc228caac9e9L, 0xf3b8b6675a6507ffL,
0x9fc477de4ed681daL, 0x67378d8eccef96cbL, 0x6dd856d94d259236L, 0xa319ce15b0b4db31L, 0x073973751f12dd5eL,
0x8a8e849eb32781a5L, 0xe1925c71285279f5L, 0x74c04bf1790c0efeL, 0x4dda48153c94938aL, 0x9d266d6a1cc0542cL,
0x7440fb816508c4feL, 0x13328503df48229fL, 0xd6bf7baee43cac40L, 0x4838d65f6ef6748fL, 0x1e152328f3318deaL,
0x8f8419a348f296bfL, 0x72c8834a5957b511L, 0xd7a023a73260b45cL, 0x94ebc8abcfb56daeL, 0x9fc10d0f989993e0L,
0xde68a2355b93cae6L, 0xa44cfe79ae538bbeL, 0x9d1d84fcce371425L, 0x51d2b1ab2ddfb636L, 0x2fd7e4b9e72cd38cL,
0x65ca5b96b7552210L, 0xdd69a0d8ab3b546dL, 0x604d51b25fbf70e2L, 0x73aa8a564fb7ac9eL, 0x1a8c1e992b941148L,
0xaac40a2703d9bea0L, 0x764dbeae7fa4f3a6L, 0x1e99b96e70a9be8bL, 0x2c5e9deb57ef4743L, 0x3a938fee32d29981L,
0x26e6db8ffdf5adfeL, 0x469356c504ec9f9dL, 0xc8763c5b08d1908cL, 0x3f6c6af859d80055L, 0x7f7cc39420a3a545L,
0x9bfb227ebdf4c5ceL, 0x89039d79d6fc5c5cL, 0x8fe88b57305e2ab6L, 0xa09e8c8c35ab96deL, 0xfa7e393983325753L,
0xd6b6d0ecc617c699L, 0xdfea21ea9e7557e3L, 0xb67c1fa481680af8L, 0xca1e3785a9e724e5L, 0x1cfc8bed0d681639L,
0xd18d8549d140caeaL, 0x4ed0fe7e9dc91335L, 0xe4dbf0634473f5d2L, 0x1761f93a44d5aefeL, 0x53898e4c3910da55L,
0x734de8181f6ec39aL, 0x2680b122baa28d97L, 0x298af231c85bafabL, 0x7983eed3740847d5L, 0x66c1a2a1a60cd889L,
0x9e17e49642a3e4c1L, 0xedb454e7badc0805L, 0x50b704cab602c329L, 0x4cc317fb9cddd023L, 0x66b4835d9eafea22L,
0x219b97e26ffc81bdL, 0x261e4e4c0a333a9dL, 0x1fe2cca76517db90L, 0xd7504dfa8816edbbL, 0xb9571fa04dc089c8L,
0x1ddc0325259b27deL, 0xcf3f4688801eb9aaL, 0xf4f5d05c10cab243L, 0x38b6525c21a42b0eL, 0x36f60e2ba4fa6800L,
0xeb3593803173e0ceL, 0x9c4cd6257c5a3603L, 0xaf0c317d32adaa8aL, 0x258e5a80c7204c4bL, 0x8b889d624d44885dL,
0xf4d14597e660f855L, 0xd4347f66ec8941c3L, 0xe699ed85b0dfb40dL, 0x2472f6207c2d0484L, 0xc2a1e7b5b459aeb5L,
0xab4f6451cc1d45ecL, 0x63767572ae3d6174L, 0xa59e0bd101731a28L, 0x116d0016cb948f09L, 0x2cf9c8ca052f6e9fL,
0x0b090a7560a968e3L, 0xabeeddb2dde06ff1L, 0x58efc10b06a2068dL, 0xc6e57a78fbd986e0L, 0x2eab8ca63ce802d7L,
0x14a195640116f336L, 0x7c0828dd624ec390L, 0xd74bbe77e6116ac7L, 0x804456af10f5fb53L, 0xebe9ea2adf4321c7L,
0x03219a39ee587a30L, 0x49787fef17af9924L, 0xa1e9300cd8520548L, 0x5b45e522e4b1b4efL, 0xb49c3b3995091a36L,
0xd4490ad526f14431L, 0x12a8f216af9418c2L, 0x001f837cc7350524L, 0x1877b51e57a764d5L, 0xa2853b80f17f58eeL,
0x993e1de72d36d310L, 0xb3598080ce64a656L, 0x252f59cf0d9f04bbL, 0xd23c8e176d113600L, 0x1bda0492e7e4586eL,
0x21e0bd5026c619bfL, 0x3b097adaf088f94eL, 0x8d14dedb30be846eL, 0xf95cffa23af5f6f4L, 0x3871700761b3f743L,
0xca672b91e9e4fa16L, 0x64c8e531bff53b55L, 0x241260ed4ad1e87dL, 0x106c09b972d2e822L, 0x7fba195410e5ca30L,
0x7884d9bc6cb569d8L, 0x0647dfedcd894a29L, 0x63573ff03e224774L, 0x4fc8e9560f91b123L, 0x1db956e450275779L,
0xb8d91274b9e9d4fbL, 0xa2ebee47e2fbfce1L, 0xd9f1f30ccd97fb09L, 0xefed53d75fd64e6bL, 0x2e6d02c36017f67fL,
0xa9aa4d20db084e9bL, 0xb64be8d8b25396c1L, 0x70cb6af7c2d5bcf0L, 0x98f076a4f7a2322eL, 0xbf84470805e69b5fL,
0x94c3251f06f90cf3L, 0x3e003e616a6591e9L, 0xb925a6cd0421aff3L, 0x61bdd1307c66e300L, 0xbf8d5108e27e0d48L,
0x240ab57a8b888b20L, 0xfc87614baf287e07L, 0xef02cdd06ffdb432L, 0xa1082c0466df6c0aL, 0x8215e577001332c8L,
0xd39bb9c3a48db6cfL, 0x2738259634305c14L, 0x61cf4f94c97df93dL, 0x1b6baca2ae4e125bL, 0x758f450c88572e0bL,
0x959f587d507a8359L, 0xb063e962e045f54dL, 0x60e8ed72c0dff5d1L, 0x7b64978555326f9fL, 0xfd080d236da814baL,
0x8c90fd9b083f4558L, 0x106f72fe81e2c590L, 0x7976033a39f7d952L, 0xa4ec0132764ca04bL, 0x733ea705fae4fa77L,
0xb4d8f77bc3e56167L, 0x9e21f4f903b33fd9L, 0x9d765e419fb69f6dL, 0xd30c088ba61ea5efL, 0x5d94337fbfaf7f5bL,
0x1a4e4822eb4d7a59L, 0x6ffe73e81b637fb3L, 0xddf957bc36d8b9caL, 0x64d0e29eea8838b3L, 0x08dd9bdfd96b9f63L,
0x087e79e5a57d1d13L, 0xe328e230e3e2b3fbL, 0x1c2559e30f0946beL, 0x720bf5f26f4d2eaaL, 0xb0774d261cc609dbL,
0x443f64ec5a371195L, 0x4112cf68649a260eL, 0xd813f2fab7f5c5caL, 0x660d3257380841eeL, 0x59ac2c7873f910a3L,
0xe846963877671a17L, 0x93b633abfa3469f8L, 0xc0c0f5a60ef4cdcfL, 0xcaf21ecd4377b28cL, 0x57277707199b8175L,
0x506c11b9d90e8b1dL, 0xd83cc2687a19255fL, 0x4a29c6465a314cd1L, 0xed2df21216235097L, 0xb5635c95ff7296e2L,
0x22af003ab672e811L, 0x52e762596bf68235L, 0x9aeba33ac6ecc6b0L, 0x944f6de09134dfb6L, 0x6c47bec883a7de39L,
0x6ad047c430a12104L, 0xa5b1cfdba0ab4067L, 0x7c45d833aff07862L, 0x5092ef950a16da0bL, 0x9338e69c052b8e7bL,
0x455a4b4cfe30e3f5L, 0x6b02e63195ad0cf8L, 0x6b17b224bad6bf27L, 0xd1e0ccd25bb9c169L, 0xde0c89a556b9ae70L,
0x50065e535a213cf6L, 0x9c1169fa2777b874L, 0x78edefd694af1eedL, 0x6dc93d9526a50e68L, 0xee97f453f06791edL,
0x32ab0edb696703d3L, 0x3a6853c7e70757a7L, 0x31865ced6120f37dL, 0x67fef95d92607890L, 0x1f2b1d1f15f6dc9cL,
0xb69e38a8965c6b65L, 0xaa9119ff184cccf4L, 0xf43c732873f24c13L, 0xfb4a3d794a9a80d2L, 0x3550c2321fd6109cL,
0x371f77e76bb8417eL, 0x6bfa9aae5ec05779L, 0xcd04f3ff001a4778L, 0xe3273522064480caL, 0x9f91508bffcfc14aL,
0x049a7f41061a9e60L, 0xfcb6be43a9f2fe9bL, 0x08de8a1c7797da9bL, 0x8f9887e6078735a1L, 0xb5b4071dbfc73a66L,
0x230e343dfba08d33L, 0x43ed7f5a0fae657dL, 0x3a88a0fbbcb05c63L, 0x21874b8b4d2dbc4fL, 0x1bdea12e35f6a8c9L,
0x53c065c6c8e63528L, 0xe34a1d250e7a8d6bL, 0xd6b04d3b7651dd7eL, 0x5e90277e7cb39e2dL, 0x2c046f22062dc67dL,
0xb10bb459132d0a26L, 0x3fa9ddfb67e2f199L, 0x0e09b88e1914f7afL, 0x10e8b35af3eeab37L, 0x9eedeca8e272b933L,
0xd4c718bc4ae8ae5fL, 0x81536d601170fc20L, 0x91b534f885818a06L, 0xec8177f83f900978L, 0x190e714fada5156eL,
0xb592bf39b0364963L, 0x89c350c893ae7dc1L, 0xac042e70f8b383f2L, 0xb49b52e587a1ee60L, 0xfb152fe3ff26da89L,
0x3e666e6f69ae2c15L, 0x3b544ebe544c19f9L, 0xe805a1e290cf2456L, 0x24b33c9d7ed25117L, 0xe74733427b72f0c1L,
0x0a804d18b7097475L, 0x57e3306d881edb4fL, 0x4ae7d6a36eb5dbcbL, 0x2d8d5432157064c8L, 0xd1e649de1e7f268bL,
0x8a328a1cedfe552cL, 0x07a3aec79624c7daL, 0x84547ddc3e203c94L, 0x990a98fd5071d263L, 0x1a4ff12616eefc89L,
0xf6f7fd1431714200L, 0x30c05b1ba332f41cL, 0x8d2636b81555a786L, 0x46c9feb55d120902L, 0xccec0a73b49c9921L,
0x4e9d2827355fc492L, 0x19ebb029435dcb0fL, 0x4659d2b743848a2cL, 0x963ef2c96b33be31L, 0x74f85198b05a2e7dL,
0x5a0f544dd2b1fb18L, 0x03727073c2e134b1L, 0xc7f6aa2de59aea61L, 0x352787baa0d7c22fL, 0x9853eab63b5e0b35L,
0xabbdcdd7ed5c0860L, 0xcf05daf5ac8d77b0L, 0x49cad48cebf4a71eL, 0x7a4c10ec2158c4a6L, 0xd9e92aa246bf719eL,
0x13ae978d09fe5557L, 0x730499af921549ffL, 0x4e4b705b92903ba4L, 0xff577222c14f0a3aL, 0x55b6344cf97aafaeL,
0xb862225b055b6960L, 0xcac09afbddd2cdb4L, 0xdaf8e9829fe96b5fL, 0xb5fdfc5d3132c498L, 0x310cb380db6f7503L,
0xe87fbb46217a360eL, 0x2102ae466ebb1148L, 0xf8549e1a3aa5e00dL, 0x07a69afdcc42261aL, 0xc4c118bfe78feaaeL,
0xf9f4892ed96bd438L, 0x1af3dbe25d8f45daL, 0xf5b4b0b0d2deeeb4L, 0x962aceefa82e1c84L, 0x046e3ecaaf453ce9L,
0xf05d129681949a4cL, 0x964781ce734b3c84L, 0x9c2ed44081ce5fbdL, 0x522e23f3925e319eL, 0x177e00f9fc32f791L,
0x2bc60a63a6f3b3f2L, 0x222bbfae61725606L, 0x486289ddcc3d6780L, 0x7dc7785b8efdfc80L, 0x8af38731c02ba980L,
0x1fab64ea29a2ddf7L, 0xe4d9429322cd065aL, 0x9da058c67844f20cL, 0x24c0e332b70019b0L, 0x233003b5a6cfe6adL,
0xd586bd01c5c217f6L, 0x5e5637885f29bc2bL, 0x7eba726d8c94094bL, 0x0a56a5f0bfe39272L, 0xd79476a84ee20d06L,
0x9e4c1269baa4bf37L, 0x17efee45b0dee640L, 0x1d95b0a5fcf90bc6L, 0x93cbe0b699c2585dL, 0x65fa4f227a2b6d79L,
0xd5f9e858292504d5L, 0xc2b5a03f71471a6fL, 0x59300222b4561e00L, 0xce2f8642ca0712dcL, 0x7ca9723fbb2e8988L,
0x2785338347f2ba08L, 0xc61bb3a141e50e8cL, 0x150f361dab9dec26L, 0x9f6a419d382595f4L, 0x64a53dc924fe7ac9L,
0x142de49fff7a7c3dL, 0x0c335248857fa9e7L, 0x0a9c32d5eae45305L, 0xe6c42178c4bbb92eL, 0x71f1ce2490d20b07L,
0xf1bcc3d275afe51aL, 0xe728e8c83c334074L, 0x96fbf83a12884624L, 0x81a1549fd6573da5L, 0x5fa7867caf35e149L,
0x56986e2ef3ed091bL, 0x917f1dd5f8886c61L, 0xd20d8c88c8ffe65fL, 0x31d71dce64b2c310L, 0xf165b587df898190L,
0xa57e6339dd2cf3a0L, 0x1ef6e6dbb1961ec9L, 0x70cc73d90bc26e24L, 0xe21a6b35df0c3ad7L, 0x003a93d8b2806962L,
0x1c99ded33cb890a1L, 0xcf3145de0add4289L, 0xd0e4427a5514fb72L, 0x77c621cc9fb3a483L, 0x67a34dac4356550bL,
0xf8d626aaaf278509L,
)
def hash(context: GameContext): Long =
val piecesHash = context.board.pieces.foldLeft(0L) { case (h, (sq, piece)) =>
h ^ Random(pieceIndex(piece) * 64 + squareIndex(sq))
}
val h1 = if context.castlingRights.whiteKingSide then piecesHash ^ Random(768) else piecesHash
val h2 = if context.castlingRights.whiteQueenSide then h1 ^ Random(769) else h1
val h3 = if context.castlingRights.blackKingSide then h2 ^ Random(770) else h2
val h4 = if context.castlingRights.blackQueenSide then h3 ^ Random(771) else h3
val h5 = context.enPassantSquare.fold(h4) { sq =>
if canCaptureEnPassant(context, sq) then h4 ^ Random(772 + sq.file.ordinal) else h4
}
if context.turn == Color.White then h5 ^ Random(780) else h5
private def pieceIndex(piece: Piece): Int =
val typeIdx = piece.pieceType match
case PieceType.Pawn => 0
case PieceType.Knight => 1
case PieceType.Bishop => 2
case PieceType.Rook => 3
case PieceType.Queen => 4
case PieceType.King => 5
val colorOffset = if piece.color == Color.White then 1 else 0
typeIdx * 2 + colorOffset
private def squareIndex(sq: Square): Int =
sq.file.ordinal + 8 * sq.rank.ordinal
private def canCaptureEnPassant(context: GameContext, epSquare: Square): Boolean =
val pawn = Piece(context.turn, PieceType.Pawn)
val rankDelta = if context.turn == Color.White then -1 else 1
List(-1, 1).exists { fileDelta =>
epSquare
.offset(fileDelta, rankDelta)
.flatMap(context.board.pieces.get)
.contains(pawn)
}
@@ -0,0 +1,125 @@
package de.nowchess.bot.util
import de.nowchess.api.board.{Color, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import scala.util.Random
object ZobristHash:
// 768 entries: 64 squares * 12 piece variants (2 colors * 6 piece types)
private val pieceRands: Array[Long] = Array.ofDim(768)
// Side-to-move: XOR when Black to move
private val sideToMoveRand: Long = Random(0x1badb002L).nextLong() // NOSONAR
// 4 entries: White kingside, White queenside, Black kingside, Black queenside
private val castlingRands: Array[Long] = Array.ofDim(4)
// 8 entries: one per file (a-h)
private val enPassantRands: Array[Long] = Array.ofDim(8)
// Initialize all random values using a seeded RNG for reproducibility
locally:
val rng = Random(0x1badb002L) // NOSONAR
for i <- 0 until 768 do pieceRands(i) = rng.nextLong()
for i <- 0 until 4 do castlingRands(i) = rng.nextLong()
for i <- 0 until 8 do enPassantRands(i) = rng.nextLong()
/** Compute a 64-bit Zobrist hash for a GameContext. */
def hash(context: GameContext): Long =
val piecesHash = context.board.pieces.foldLeft(0L) { case (h, (square, piece)) =>
val squareIndex = square.rank.ordinal * 8 + square.file.ordinal
val colorIndex = if piece.color == Color.White then 0 else 1
val pieceIndex = colorIndex * 6 + piece.pieceType.ordinal
h ^ pieceRands(squareIndex * 12 + pieceIndex)
}
val h1 = if context.turn == Color.Black then piecesHash ^ sideToMoveRand else piecesHash
val h2 = if context.castlingRights.whiteKingSide then h1 ^ castlingRands(0) else h1
val h3 = if context.castlingRights.whiteQueenSide then h2 ^ castlingRands(1) else h2
val h4 = if context.castlingRights.blackKingSide then h3 ^ castlingRands(2) else h3
val h5 = if context.castlingRights.blackQueenSide then h4 ^ castlingRands(3) else h4
context.enPassantSquare.fold(h5)(sq => h5 ^ enPassantRands(sq.file.ordinal))
def nextHash(context: GameContext, currentHash: Long, move: Move, nextContext: GameContext): Long =
val h0 = currentHash ^ sideToMoveRand
val h1 = toggleCastling(h0, context, nextContext)
val h2 = toggleEnPassant(h1, context, nextContext)
move.moveType match
case MoveType.CastleKingside | MoveType.CastleQueenside =>
applyCastleDelta(h2, context.turn, move.moveType == MoveType.CastleKingside)
case MoveType.EnPassant =>
applyEnPassantDelta(h2, context, move)
case MoveType.Promotion(piece) =>
applyPromotionDelta(h2, context, move, piece)
case MoveType.Normal(_) =>
applyNormalDelta(h2, context, move)
private def applyNormalDelta(h0: Long, context: GameContext, move: Move): Long =
context.board.pieceAt(move.from).fold(h0) { mover =>
val h1 = h0 ^ pieceKey(move.from, mover)
val h2 = context.board.pieceAt(move.to).fold(h1)(captured => h1 ^ pieceKey(move.to, captured))
h2 ^ pieceKey(move.to, mover)
}
private def applyPromotionDelta(h0: Long, context: GameContext, move: Move, promoted: PromotionPiece): Long =
context.board.pieceAt(move.from).fold(h0) { pawn =>
val h1 = h0 ^ pieceKey(move.from, pawn)
val h2 = context.board.pieceAt(move.to).fold(h1)(captured => h1 ^ pieceKey(move.to, captured))
h2 ^ pieceKey(move.to, Piece(context.turn, promotedPieceType(promoted)))
}
private def applyEnPassantDelta(h0: Long, context: GameContext, move: Move): Long =
context.board.pieceAt(move.from).fold(h0) { pawn =>
val capturedSquare = Square(move.to.file, move.from.rank)
val h1 = h0 ^ pieceKey(move.from, pawn)
val h2 = context.board.pieceAt(capturedSquare).fold(h1)(captured => h1 ^ pieceKey(capturedSquare, captured))
h2 ^ pieceKey(move.to, pawn)
}
private def applyCastleDelta(h0: Long, color: Color, kingside: Boolean): Long =
val rank = if color == Color.White then Rank.R1 else Rank.R8
val (kingFrom, kingTo, rookFrom, rookTo) =
if kingside then
(
Square(de.nowchess.api.board.File.E, rank),
Square(de.nowchess.api.board.File.G, rank),
Square(de.nowchess.api.board.File.H, rank),
Square(de.nowchess.api.board.File.F, rank),
)
else
(
Square(de.nowchess.api.board.File.E, rank),
Square(de.nowchess.api.board.File.C, rank),
Square(de.nowchess.api.board.File.A, rank),
Square(de.nowchess.api.board.File.D, rank),
)
val king = Piece(color, PieceType.King)
val rook = Piece(color, PieceType.Rook)
h0 ^ pieceKey(kingFrom, king) ^ pieceKey(kingTo, king) ^ pieceKey(rookFrom, rook) ^ pieceKey(rookTo, rook)
private def promotedPieceType(promotion: PromotionPiece): PieceType = promotion match
case PromotionPiece.Knight => PieceType.Knight
case PromotionPiece.Bishop => PieceType.Bishop
case PromotionPiece.Rook => PieceType.Rook
case PromotionPiece.Queen => PieceType.Queen
private def toggleCastling(h0: Long, before: GameContext, after: GameContext): Long =
val h1 =
if before.castlingRights.whiteKingSide != after.castlingRights.whiteKingSide then h0 ^ castlingRands(0) else h0
val h2 =
if before.castlingRights.whiteQueenSide != after.castlingRights.whiteQueenSide then h1 ^ castlingRands(1) else h1
val h3 =
if before.castlingRights.blackKingSide != after.castlingRights.blackKingSide then h2 ^ castlingRands(2) else h2
if before.castlingRights.blackQueenSide != after.castlingRights.blackQueenSide then h3 ^ castlingRands(3) else h3
private def toggleEnPassant(h0: Long, before: GameContext, after: GameContext): Long =
val h1 = before.enPassantSquare.fold(h0)(sq => h0 ^ enPassantRands(sq.file.ordinal))
after.enPassantSquare.fold(h1)(sq => h1 ^ enPassantRands(sq.file.ordinal))
private def pieceKey(square: Square, piece: Piece): Long =
val squareIndex = square.rank.ordinal * 8 + square.file.ordinal
val colorIndex = if piece.color == Color.White then 0 else 1
val pieceIndex = colorIndex * 6 + piece.pieceType.ordinal
pieceRands(squareIndex * 12 + pieceIndex)
@@ -0,0 +1,336 @@
package de.nowchess.bot
import de.nowchess.api.board.{Board, Color, File, Piece, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import de.nowchess.bot.ai.Evaluation
import de.nowchess.bot.bots.classic.EvaluationClassic
import de.nowchess.bot.logic.AlphaBetaSearch
import de.nowchess.rules.RuleSet
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import de.nowchess.rules.sets.DefaultRules
import java.util.concurrent.atomic.AtomicBoolean
class AlphaBetaSearchTest extends AnyFunSuite with Matchers:
private object ZeroEval extends Evaluation:
val CHECKMATE_SCORE: Int = 1_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
test("bestMove on initial position returns a move"):
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 2)
move should not be None
test("bestMove on a position with one legal move returns that move"):
// Create a simple position: White king on h1, Black rook on a2
// (set up so there's only one legal move available)
// For simplicity, just test that a position with forced mate returns a move
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val context = GameContext.initial
val move = search.bestMove(context, maxDepth = 1)
move should not be None
test("bestMoveWithTime skips excluded root moves"):
val blockedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val stubRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(blockedMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
val move = search.bestMoveWithTime(GameContext.initial, 1000L, Set(blockedMove))
move should be(None)
test("bestMove returns None for initial position has no legal moves"):
// Use a stub RuleSet that returns empty legal moves
val stubRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = true
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stubRules, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 2)
move should be(None)
test("transposition table is cleared at start of bestMove"):
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val context = GameContext.initial
// Call bestMove twice and verify both work independently
val move1 = search.bestMove(context, maxDepth = 1)
val move2 = search.bestMove(context, maxDepth = 1)
move1 should be(move2)
test("quiescence captures are ordered"):
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
// A position with multiple captures to verify quiescence orders them
val context = GameContext.initial
val move = search.bestMove(context, maxDepth = 2)
// Just verify it completes without error
move.isDefined should be(true)
test("search respects alpha-beta bounds"):
// This is implicit in the structure, but we test via behavior
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val context = GameContext.initial
val move = search.bestMove(context, maxDepth = 3)
move should not be None
test("iterative deepening finds a move at each depth"):
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val context = GameContext.initial
// Searching to depth 3 should use iterative deepening (depths 1, 2, 3)
val move = search.bestMove(context, maxDepth = 3)
move should not be None
test("stalemate position returns score 0"):
// Create a stalemate stub: white to move, no legal moves, not checkmate
val stalematRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = true
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(stalematRules, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 1)
move should be(None)
test("insufficient material returns score 0"):
val insufficientRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = true
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(insufficientRules, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 1)
move should be(None)
test("fifty move rule returns score 0"):
val fiftyMoveRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = true
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(fiftyMoveRules, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 1)
move should be(None)
test("capture moves are recognized in quiescence search"):
// Create a position with a capture available
val board = Board(
Map(
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
Square(File.E, Rank.R5) -> Piece.BlackPawn,
),
)
val context = GameContext.initial.withBoard(board)
val captureMove = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
val rulesWithCapture = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(captureMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesWithCapture, weights = EvaluationClassic)
val move = search.bestMove(context, maxDepth = 1)
move should be(Some(captureMove))
test("non-capture moves are not included in quiescence"):
val quietMove = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal())
val rulesQuiet = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(quietMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(rulesQuiet, weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 1)
move should be(Some(quietMove)) // bestMove returns the quiet move since it's the only legal move
test("default constructor uses DefaultRules"):
val search = AlphaBetaSearch(weights = EvaluationClassic)
val move = search.bestMove(GameContext.initial, maxDepth = 1)
move should not be None
test("bestMoveWithTime without excluded moves overload"):
val search = AlphaBetaSearch(DefaultRules, weights = EvaluationClassic)
val move = search.bestMoveWithTime(GameContext.initial, 500L)
move should not be None
test("en passant move is treated as capture in quiescence"):
val epMove = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val board = Board(
Map(
Square(File.E, Rank.R5) -> Piece.WhitePawn,
Square(File.D, Rank.R5) -> Piece.BlackPawn,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val epRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(epMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(epRules, weights = EvaluationClassic)
search.bestMove(ctx, maxDepth = 1) should be(Some(epMove))
test("promotion capture move is treated as capture in quiescence"):
val promoCapture = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Queen))
val board = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
Square(File.D, Rank.R8) -> Piece.BlackRook,
),
)
val ctx = GameContext.initial.withBoard(board).withTurn(Color.White)
val promoCaptureRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(promoCapture)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(promoCaptureRules, weights = EvaluationClassic)
search.bestMove(ctx, maxDepth = 1) should be(Some(promoCapture))
test("draw when isInsufficientMaterial with legal moves present"):
val legalMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val drawRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = List(legalMove)
def legalMoves(context: GameContext)(square: Square): List[Move] = List(legalMove)
def allLegalMoves(context: GameContext): List[Move] = List(legalMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = true
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(drawRules, weights = EvaluationClassic)
search.bestMove(GameContext.initial, maxDepth = 2) should be(None)
test("repetition cutoff is reached on forced self-loop positions"):
// Use a no-op move from an empty square so nextHash alternates between a tiny set of hashes.
// This forces repetition counts >= 3 and exercises immediateSearchResult's repetition cutoff.
val loopMove = Move(Square(File.A, Rank.R3), Square(File.A, Rank.R4), MoveType.Normal())
val loopRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = List(loopMove)
def legalMoves(context: GameContext)(square: Square): List[Move] = List(loopMove)
def allLegalMoves(context: GameContext): List[Move] = List(loopMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val search = AlphaBetaSearch(loopRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 8) should be(Some(loopMove))
test("quiescence returns checkmate score when side is in check and has no tactical moves"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
val qRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def allLegalMoves(context: GameContext): List[Move] =
context.moves.length match
case 0 => List(rootMove)
case 1 => List(capMove)
case _ => Nil
def isCheck(context: GameContext): Boolean =
context.moves.length >= 2
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext =
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
val search = AlphaBetaSearch(qRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
test("quiescence depth-limit in-check branch is exercised"):
val rootMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val capMove = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R3), MoveType.Normal(true))
val firstChildCheckCall = AtomicBoolean(true)
val deepQRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def legalMoves(context: GameContext)(square: Square): List[Move] = allLegalMoves(context)
def allLegalMoves(context: GameContext): List[Move] =
if context.moves.isEmpty then List(rootMove) else List(capMove)
def isCheck(context: GameContext): Boolean =
if context.moves.length == 1 && firstChildCheckCall.compareAndSet(true, false) then false
else context.moves.nonEmpty
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext =
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
val search = AlphaBetaSearch(deepQRules, weights = ZeroEval)
search.bestMove(GameContext.initial, maxDepth = 1) should be(Some(rootMove))
@@ -0,0 +1,18 @@
package de.nowchess.bot
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class BotControllerTest extends AnyFunSuite with Matchers:
test("BotController can be instantiated"):
BotController.listBots should not be empty
test("getBot returns known bots by name"):
BotController.getBot("easy") should not be None
BotController.getBot("medium") should not be None
BotController.getBot("hard") should not be None
BotController.getBot("expert") should not be None
test("getBot returns None for unknown bot"):
BotController.getBot("unknown") should be(None)
@@ -0,0 +1,14 @@
package de.nowchess.bot
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class BotDifficultyTest extends AnyFunSuite with Matchers:
test("all difficulty values are defined"):
val difficulties = BotDifficulty.values
difficulties should contain(BotDifficulty.Easy)
difficulties should contain(BotDifficulty.Medium)
difficulties should contain(BotDifficulty.Hard)
difficulties should contain(BotDifficulty.Expert)
difficulties should have length 4
@@ -0,0 +1,30 @@
package de.nowchess.bot
import de.nowchess.api.board.{File, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class BotMoveRepetitionTest extends AnyFunSuite with Matchers:
private val move1 = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
private val move2 = Move(Square(File.D, Rank.R2), Square(File.D, Rank.R4), MoveType.Normal())
test("filterAllowed passes through moves when none are blocked"):
val ctx = GameContext.initial
val allowed = BotMoveRepetition.filterAllowed(ctx, List(move1, move2))
allowed should contain(move1)
allowed should contain(move2)
test("filterAllowed removes the move repeated three times"):
val ctx = GameContext.initial.copy(moves = List(move1, move1, move1))
val allowed = BotMoveRepetition.filterAllowed(ctx, List(move1, move2))
allowed should not contain move1
allowed should contain(move2)
test("filterAllowed keeps all moves when repetition is below threshold"):
val ctx = GameContext.initial.copy(moves = List(move1, move1))
val allowed = BotMoveRepetition.filterAllowed(ctx, List(move1, move2))
allowed should contain(move1)
allowed should contain(move2)
@@ -0,0 +1,98 @@
package de.nowchess.bot
import de.nowchess.api.board.Square
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.Move
import de.nowchess.api.move.MoveType
import de.nowchess.bot.bots.ClassicalBot
import de.nowchess.rules.RuleSet
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import de.nowchess.rules.sets.DefaultRules
class ClassicalBotTest extends AnyFunSuite with Matchers:
test("name returns expected format"):
val botEasy = ClassicalBot(BotDifficulty.Easy)
botEasy.name should include("ClassicalBot")
botEasy.name should include("Easy")
val botMedium = ClassicalBot(BotDifficulty.Medium)
botMedium.name should include("Medium")
test("nextMove on initial position returns a move"):
val bot = ClassicalBot(BotDifficulty.Easy)
val move = bot.nextMove(GameContext.initial)
move should not be None
test("nextMove returns None for position with no legal moves"):
val stubRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = true
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
val move = bot.nextMove(GameContext.initial)
move should be(None)
test("all BotDifficulty values work"):
BotDifficulty.values.foreach { difficulty =>
val bot = ClassicalBot(difficulty)
val move = bot.nextMove(GameContext.initial)
// All difficulties should return a move on the initial position
move should not be None
}
test("custom RuleSet injection works"):
val moveToReturn = Move(
de.nowchess.api.board.Square(de.nowchess.api.board.File.E, de.nowchess.api.board.Rank.R2),
de.nowchess.api.board.Square(de.nowchess.api.board.File.E, de.nowchess.api.board.Rank.R4),
de.nowchess.api.move.MoveType.Normal(),
)
val stubRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(moveToReturn)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
val move = bot.nextMove(GameContext.initial)
move should be(Some(moveToReturn))
test("nextMove skips a move repeated three times in a row"):
val repeatedMove = Move(
de.nowchess.api.board.Square(de.nowchess.api.board.File.E, de.nowchess.api.board.Rank.R2),
de.nowchess.api.board.Square(de.nowchess.api.board.File.E, de.nowchess.api.board.Rank.R4),
MoveType.Normal(),
)
val stubRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(repeatedMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val context = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
val bot = ClassicalBot(BotDifficulty.Easy, stubRules)
bot.nextMove(context) should be(None)
@@ -0,0 +1,142 @@
package de.nowchess.bot
import de.nowchess.api.board.{Board, Color, File, Piece, PieceType, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.board.Board
import de.nowchess.bot.bots.classic.EvaluationClassic
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class EvaluationTest extends AnyFunSuite with Matchers:
test("initial position evaluates to tempo bonus"):
val eval = EvaluationClassic.evaluate(GameContext.initial)
eval should equal(10) // TEMPO_BONUS only
test("remove white queen gives negative evaluation"):
val initial = GameContext.initial
val board = initial.board
val emptySquare = Square(File.D, Rank.R1)
val boardWithoutQueen = board.pieces.filter((sq, _) => sq != emptySquare)
val newContext = initial.withBoard(Board(boardWithoutQueen))
val eval = EvaluationClassic.evaluate(newContext)
eval should be < 0
test("remove black queen gives positive evaluation"):
val initial = GameContext.initial
val board = initial.board
val emptySquare = Square(File.D, Rank.R8)
val boardWithoutQueen = board.pieces.filter((sq, _) => sq != emptySquare)
val newContext = initial.withBoard(Board(boardWithoutQueen))
val eval = EvaluationClassic.evaluate(newContext)
eval should be > 0
test("different piece-square bonuses are applied"):
// Knight on d4 (center) vs knight on a1 (corner) - center should be better
val knightD4Board = Board(Map(Square(File.D, Rank.R4) -> Piece.WhiteKnight))
val knightA1Board = Board(Map(Square(File.A, Rank.R1) -> Piece.WhiteKnight))
val knightD4 = GameContext.initial.withBoard(knightD4Board)
val knightA1 = GameContext.initial.withBoard(knightA1Board)
val eval1 = EvaluationClassic.evaluate(knightD4)
val eval2 = EvaluationClassic.evaluate(knightA1)
eval1 should be > eval2 // d4 (center) is better than a1 (corner) for knight
test("all piece types are in material map"):
PieceType.values.length should be > 0
// Just verify evaluate works with all piece types
val eval = EvaluationClassic.evaluate(GameContext.initial)
eval should not be (EvaluationClassic.CHECKMATE_SCORE)
test("CHECKMATE_SCORE and DRAW_SCORE are accessible"):
EvaluationClassic.CHECKMATE_SCORE should equal(10_000_000)
EvaluationClassic.DRAW_SCORE should equal(0)
test("active knight (center) scores higher than passive knight (corner)"):
val knightD4Board = Board(Map(Square(File.D, Rank.R4) -> Piece.WhiteKnight))
val knightA1Board = Board(Map(Square(File.A, Rank.R1) -> Piece.WhiteKnight))
val knightD4Context = GameContext.initial.withBoard(knightD4Board)
val knightA1Context = GameContext.initial.withBoard(knightA1Board)
val evalD4 = EvaluationClassic.evaluate(knightD4Context)
val evalA1 = EvaluationClassic.evaluate(knightA1Context)
evalD4 should be > evalA1 // Knight on d4 (center, more mobility) should score higher
test("bishop pair scores higher than bishop + knight"):
val bishopPairBoard = Board(
Map(
Square(File.C, Rank.R1) -> Piece.WhiteBishop,
Square(File.F, Rank.R1) -> Piece.WhiteBishop,
),
)
val bishopKnightBoard = Board(
Map(
Square(File.C, Rank.R1) -> Piece.WhiteBishop,
Square(File.B, Rank.R1) -> Piece.WhiteKnight,
),
)
val pairContext = GameContext.initial.withBoard(bishopPairBoard)
val knightContext = GameContext.initial.withBoard(bishopKnightBoard)
val evalPair = EvaluationClassic.evaluate(pairContext)
val evalKnight = EvaluationClassic.evaluate(knightContext)
evalPair should be > evalKnight // Bishop pair should score higher
test("rook on 7th rank scores higher than rook on 4th rank"):
val rook7thBoard = Board(Map(Square(File.A, Rank.R7) -> Piece.WhiteRook))
val rook4thBoard = Board(Map(Square(File.A, Rank.R4) -> Piece.WhiteRook))
val rook7thContext = GameContext.initial.withBoard(rook7thBoard)
val rook4thContext = GameContext.initial.withBoard(rook4thBoard)
val eval7th = EvaluationClassic.evaluate(rook7thContext)
val eval4th = EvaluationClassic.evaluate(rook4thContext)
eval7th should be > eval4th // Rook on 7th rank should score higher
test("enemy rook on 7th rank is penalised"):
// Black rook on rank 2 (7th for black) with white to move — hits the enemy branch
val board = Board(Map(Square(File.A, Rank.R2) -> Piece.BlackRook))
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val eval = EvaluationClassic.evaluate(context)
eval should be < 0 // disadvantageous for white
test("king at edge rank yields zero king-shield bonus"):
// White king on rank 8 — shieldRank would be 9, out of bounds → guard fires
val board = Board(Map(Square(File.H, Rank.R8) -> Piece.WhiteKing, Square(File.H, Rank.R1) -> Piece.BlackKing))
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
// Evaluating does not throw and uses the guard path
noException should be thrownBy EvaluationClassic.evaluate(context)
test("endgame bonus is applied when material is low"):
// Kings + one rook: phase = 2 < 8, triggers endgameBonus with friendly material advantage
val board = Board(
Map(
Square(File.D, Rank.R4) -> Piece.WhiteKing,
Square(File.D, Rank.R6) -> Piece.BlackKing,
Square(File.A, Rank.R1) -> Piece.WhiteRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
noException should be thrownBy EvaluationClassic.evaluate(context)
test("endgame bonus else branch when material is equal"):
// Both sides have a rook: friendlyMaterial == enemyMaterial → edgeBonus = 0
val board = Board(
Map(
Square(File.D, Rank.R4) -> Piece.WhiteKing,
Square(File.D, Rank.R6) -> Piece.BlackKing,
Square(File.A, Rank.R1) -> Piece.WhiteRook,
Square(File.H, Rank.R8) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
noException should be thrownBy EvaluationClassic.evaluate(context)
test("passed pawn bonus is applied in endgame"):
// No enemy pawns anywhere → white pawn on e5 is passed; phase = 0 → endgame → egPassedPawnBonus
val board = Board(
Map(
Square(File.E, Rank.R5) -> Piece.WhitePawn,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val eval = EvaluationClassic.evaluate(context)
eval should be > 0
@@ -0,0 +1,160 @@
package de.nowchess.bot
import de.nowchess.api.board.{File, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType}
import de.nowchess.bot.ai.Evaluation
import de.nowchess.bot.bots.HybridBot
import de.nowchess.bot.util.{PolyglotBook, PolyglotHash}
import de.nowchess.rules.RuleSet
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import java.io.{DataOutputStream, FileOutputStream}
import java.nio.file.Files
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Using
class HybridBotTest extends AnyFunSuite with Matchers:
test("HybridBot name includes difficulty"):
val bot = HybridBot(BotDifficulty.Easy)
bot.name should include("HybridBot")
bot.name should include("Easy")
test("HybridBot nextMove returns a move on the initial position"):
val bot = HybridBot(BotDifficulty.Easy)
val move = bot.nextMove(GameContext.initial)
move should not be None
test("HybridBot nextMove returns None when no legal moves"):
val noMovesRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = Nil
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = true
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val bot = HybridBot(BotDifficulty.Easy, noMovesRules)
val move = bot.nextMove(GameContext.initial)
move should be(None)
test("HybridBot with empty book falls through to search"):
val emptyBook = PolyglotBook("/nonexistent/book.bin")
val bot = HybridBot(BotDifficulty.Easy, book = Some(emptyBook))
val move = bot.nextMove(GameContext.initial)
move should not be None
test("HybridBot skips move repeated three times"):
val repeatedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val onlyMoveRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = Nil
def legalMoves(context: GameContext)(square: Square): List[Move] = Nil
def allLegalMoves(context: GameContext): List[Move] = List(repeatedMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext = context
val ctx = GameContext.initial.copy(moves = List(repeatedMove, repeatedMove, repeatedMove))
val bot = HybridBot(BotDifficulty.Easy, onlyMoveRules)
bot.nextMove(ctx) should be(None)
test("HybridBot uses book move when available"):
val tempFile = Files.createTempFile("hybrid_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
val e2e4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(hash)
dos.writeShort(e2e4)
dos.writeShort(100)
dos.writeInt(0)
}.get
val book = PolyglotBook(tempFile.toString)
val bot = HybridBot(BotDifficulty.Easy, book = Some(book))
bot.nextMove(ctx) should be(Some(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())))
finally Files.deleteIfExists(tempFile)
test("HybridBot reports veto when classical and NNUE differ above threshold"):
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val oneMoveRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = List(forcedMove)
def legalMoves(context: GameContext)(square: Square): List[Move] = List(forcedMove)
def allLegalMoves(context: GameContext): List[Move] = List(forcedMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext =
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
val reported = AtomicBoolean(false)
val bot = HybridBot(
BotDifficulty.Easy,
rules = oneMoveRules,
nnueEvaluation = LowNnue,
classicalEvaluation = HighClassic,
vetoReporter = _ => reported.set(true),
)
bot.nextMove(GameContext.initial) should be(Some(forcedMove))
reported.get should be(true)
test("HybridBot default veto reporter prints when threshold is exceeded"):
val forcedMove = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val oneMoveRules = new RuleSet:
def candidateMoves(context: GameContext)(square: Square): List[Move] = List(forcedMove)
def legalMoves(context: GameContext)(square: Square): List[Move] = List(forcedMove)
def allLegalMoves(context: GameContext): List[Move] = List(forcedMove)
def isCheck(context: GameContext): Boolean = false
def isCheckmate(context: GameContext): Boolean = false
def isStalemate(context: GameContext): Boolean = false
def isInsufficientMaterial(context: GameContext): Boolean = false
def isFiftyMoveRule(context: GameContext): Boolean = false
def isThreefoldRepetition(context: GameContext): Boolean = false
def applyMove(context: GameContext)(move: Move): GameContext =
context.copy(turn = context.turn.opposite, moves = context.moves :+ move)
object LowNnue extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 0
object HighClassic extends Evaluation:
val CHECKMATE_SCORE: Int = 10_000_000
val DRAW_SCORE: Int = 0
def evaluate(context: GameContext): Int = 10_000
val bot = HybridBot(
BotDifficulty.Easy,
rules = oneMoveRules,
nnueEvaluation = LowNnue,
classicalEvaluation = HighClassic,
)
val printed = Console.withOut(new java.io.ByteArrayOutputStream()) {
bot.nextMove(GameContext.initial)
}
printed should be(Some(forcedMove))
@@ -0,0 +1,219 @@
package de.nowchess.bot
import de.nowchess.api.board.{Board, Color, File, Piece, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import de.nowchess.bot.logic.MoveOrdering
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class MoveOrderingTest extends AnyFunSuite with Matchers:
test("queen capture ranks higher than rook capture"):
val board = Board(
Map(
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
Square(File.E, Rank.R5) -> Piece.BlackQueen,
Square(File.E, Rank.R6) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val queenCapture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
val rookCapture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R6), MoveType.Normal(true))
val queenScore = MoveOrdering.score(context, queenCapture, None)
val rookScore = MoveOrdering.score(context, rookCapture, None)
queenScore should be > rookScore
test("quiet move ranks lower than capture"):
val board = Board(
Map(
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
Square(File.E, Rank.R5) -> Piece.BlackPawn,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val capture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
val quiet = Move(Square(File.E, Rank.R4), Square(File.D, Rank.R5))
val captureScore = MoveOrdering.score(context, capture, None)
val quietScore = MoveOrdering.score(context, quiet, None)
captureScore should be > quietScore
test("TT best move ranks first"):
val board = Board(
Map(
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
Square(File.E, Rank.R5) -> Piece.BlackPawn,
Square(File.D, Rank.R5) -> Piece.BlackPawn,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val bestMove = Move(Square(File.E, Rank.R4), Square(File.D, Rank.R5), MoveType.Normal(true))
val otherCapture = Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true))
val bestScore = MoveOrdering.score(context, bestMove, Some(bestMove))
val otherScore = MoveOrdering.score(context, otherCapture, Some(bestMove))
bestScore should equal(Int.MaxValue)
otherScore should be < bestScore
test("promotion to queen ranks high"):
val board = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val promotionQueen =
Move(Square(File.E, Rank.R7), Square(File.E, Rank.R8), MoveType.Promotion(PromotionPiece.Queen))
val promotionKnight =
Move(Square(File.E, Rank.R7), Square(File.E, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
val queenScore = MoveOrdering.score(context, promotionQueen, None)
val knightScore = MoveOrdering.score(context, promotionKnight, None)
queenScore should be > knightScore
queenScore should be > 100_000 // Queen promotion score is > 100_000
test("en passant is treated as capture"):
val board = Board(
Map(
Square(File.E, Rank.R5) -> Piece.WhitePawn,
Square(File.D, Rank.R5) -> Piece.BlackPawn,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val epCapture = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val quiet = Move(Square(File.E, Rank.R5), Square(File.E, Rank.R6))
val epScore = MoveOrdering.score(context, epCapture, None)
val quietScore = MoveOrdering.score(context, quiet, None)
epScore should be > quietScore
test("sort returns moves ordered by score"):
val board = Board(
Map(
Square(File.E, Rank.R4) -> Piece.WhiteQueen,
Square(File.E, Rank.R5) -> Piece.BlackPawn,
Square(File.D, Rank.R5) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val moves = List(
Move(Square(File.E, Rank.R4), Square(File.D, Rank.R5), MoveType.Normal(true)), // Rook capture
Move(Square(File.E, Rank.R4), Square(File.E, Rank.R5), MoveType.Normal(true)), // Pawn capture
Move(Square(File.E, Rank.R4), Square(File.E, Rank.R6)), // Quiet
)
val sorted = MoveOrdering.sort(context, moves, None)
// Rook capture should be first (higher victim value)
sorted.head.to should equal(Square(File.D, Rank.R5))
// Pawn capture should be second
sorted(1).to should equal(Square(File.E, Rank.R5))
// Quiet should be last
sorted.last.to should equal(Square(File.E, Rank.R6))
test("castling move is quiet (not capture)"):
val board = Board(
Map(
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.H, Rank.R1) -> Piece.WhiteRook,
),
)
val context = GameContext.initial.withBoard(board)
val castleMove = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
val score = MoveOrdering.score(context, castleMove, None)
score should equal(0) // Quiet move
test("all MoveType variants are handled in victimValue"):
val board = Board(
Map(
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.H, Rank.R1) -> Piece.WhiteRook,
Square(File.E, Rank.R2) -> Piece.WhitePawn,
),
)
val context = GameContext.initial.withBoard(board)
// Test castling queenside - should have victim value 0
val castleQs = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
val scoreQs = MoveOrdering.score(context, castleQs, None)
scoreQs should equal(0)
test("attackerValue covers all piece types"):
val board = Board(
Map(
Square(File.A, Rank.R1) -> Piece.WhiteRook,
Square(File.B, Rank.R1) -> Piece.WhiteKnight,
Square(File.C, Rank.R1) -> Piece.WhiteBishop,
Square(File.D, Rank.R1) -> Piece.WhiteQueen,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.F, Rank.R2) -> Piece.WhitePawn,
),
)
val context = GameContext.initial.withBoard(board)
// Create captures with each piece type
val rookCapture = Move(Square(File.A, Rank.R1), Square(File.A, Rank.R8), MoveType.Normal(true))
val knightCapture = Move(Square(File.B, Rank.R1), Square(File.A, Rank.R8), MoveType.Normal(true))
val bishopCapture = Move(Square(File.C, Rank.R1), Square(File.A, Rank.R8), MoveType.Normal(true))
val queenCapture = Move(Square(File.D, Rank.R1), Square(File.A, Rank.R8), MoveType.Normal(true))
val kingCapture = Move(Square(File.E, Rank.R1), Square(File.A, Rank.R8), MoveType.Normal(true))
val pawnCapture = Move(Square(File.F, Rank.R2), Square(File.A, Rank.R8), MoveType.Normal(true))
// Just verify all are scored without error
MoveOrdering.score(context, rookCapture, None) should be >= 0
MoveOrdering.score(context, knightCapture, None) should be >= 0
MoveOrdering.score(context, bishopCapture, None) should be >= 0
MoveOrdering.score(context, queenCapture, None) should be >= 0
MoveOrdering.score(context, kingCapture, None) should be >= 0
MoveOrdering.score(context, pawnCapture, None) should be >= 0
test("promotion capture is distinct from quiet promotion"):
val board = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
Square(File.D, Rank.R8) -> Piece.BlackPawn,
),
)
val context = GameContext.initial.withBoard(board)
// Promotion with capture
val promotionWithCapture =
Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Queen))
// Regular queen promotion (no capture)
val quietPromotion =
Move(Square(File.E, Rank.R7), Square(File.E, Rank.R8), MoveType.Promotion(PromotionPiece.Queen))
val score1 = MoveOrdering.score(context, promotionWithCapture, None)
val score2 = MoveOrdering.score(context, quietPromotion, None)
score1 should be > score2
test("non-Queen promotion captures trigger promotionPieceType for Knight, Bishop, Rook"):
val board = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
Square(File.D, Rank.R8) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val knightPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Knight))
val bishopPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Bishop))
val rookPromo = Move(Square(File.E, Rank.R7), Square(File.D, Rank.R8), MoveType.Promotion(PromotionPiece.Rook))
MoveOrdering.score(context, knightPromo, None) should be > 0
MoveOrdering.score(context, bishopPromo, None) should be > 0
MoveOrdering.score(context, rookPromo, None) should be > 0
test("negative SEE capture path is scored below neutral capture baseline"):
val board = Board(
Map(
Square(File.D, Rank.R4) -> Piece.WhiteQueen,
Square(File.D, Rank.R5) -> Piece.BlackPawn,
Square(File.D, Rank.R8) -> Piece.BlackRook,
),
)
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val move = Move(Square(File.D, Rank.R4), Square(File.D, Rank.R5), MoveType.Normal(true))
MoveOrdering.score(context, move, None) should be < 100_000
test("non-capture move keeps fallback scoring at zero"):
val board = Board(Map(Square(File.E, Rank.R1) -> Piece.WhiteKing, Square(File.A, Rank.R8) -> Piece.BlackKing))
val context = GameContext.initial.withBoard(board).withTurn(Color.White)
val castle = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
MoveOrdering.score(context, castle, None) should be(0)
@@ -0,0 +1,153 @@
package de.nowchess.bot
import de.nowchess.api.board.{Color, File, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import de.nowchess.bot.bots.ClassicalBot
import de.nowchess.bot.util.{PolyglotBook, PolyglotHash}
import de.nowchess.rules.sets.DefaultRules
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import java.io.{DataOutputStream, FileOutputStream}
import java.nio.file.Files
import scala.util.Using
class PolyglotBookTest extends AnyFunSuite with Matchers:
test("Book probe returns None for non-existent file"):
val book = PolyglotBook("/nonexistent/path/book.bin")
book.probe(GameContext.initial) shouldEqual None
test("Book probe returns None when position not in book"):
val tempFile = Files.createTempFile("test_book", ".bin")
try
// Write a single entry with a different key
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(12345L) // some random key
dos.writeShort(0) // move
dos.writeShort(100) // weight
dos.writeInt(0) // learn
}.get
val book = PolyglotBook(tempFile.toString)
book.probe(GameContext.initial) shouldEqual None
finally Files.delete(tempFile)
test("Book returns a move when position is in book"):
val tempFile = Files.createTempFile("test_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
// Write an entry: e2-e4 (normal move, non-capture)
// from_file=4, from_rank=1, to_file=4, to_rank=3, promotion=0
val move: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(hash)
dos.writeShort(move)
dos.writeShort(100) // weight
dos.writeInt(0)
}.get
val book = PolyglotBook(tempFile.toString)
val result = book.probe(ctx)
result shouldNot be(None)
result.get.from shouldEqual Square(File.E, Rank.R2)
result.get.to shouldEqual Square(File.E, Rank.R4)
finally Files.delete(tempFile)
test("Weighted random sampling works"):
val tempFile = Files.createTempFile("test_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
// Two moves: e2-e4 with high weight, d2-d4 with low weight
val moveE4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
val moveD4: Short = (3 | (3 << 3) | (3 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(hash)
dos.writeShort(moveE4)
dos.writeShort(900) // high weight
dos.writeInt(0)
dos.writeLong(hash)
dos.writeShort(moveD4)
dos.writeShort(100) // low weight
dos.writeInt(0)
}.get
val book = PolyglotBook(tempFile.toString)
// Sample multiple times; high-weight move should be picked more often
val samples = (0 until 100).map(_ => book.probe(ctx)).flatten
samples.length should be > 0
val e4Count = samples.count(m => m.from == Square(File.E, Rank.R2) && m.to == Square(File.E, Rank.R4))
val d4Count = samples.count(m => m.from == Square(File.D, Rank.R2) && m.to == Square(File.D, Rank.R4))
// With 900:100 weight ratio, e4 should appear more frequently
e4Count should be > d4Count
finally Files.delete(tempFile)
test("ClassicalBot without book falls back to search"):
val ctx = GameContext.initial
val bot = ClassicalBot(BotDifficulty.Easy) // no book
val move = bot.nextMove(ctx)
move shouldNot be(None)
// The move should be legal
val allLegalMoves = DefaultRules.allLegalMoves(ctx)
allLegalMoves should contain(move.get)
test("ClassicalBot with book uses book move"):
val tempFile = Files.createTempFile("test_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
// e2-e4
val moveE4: Short = (4 | (3 << 3) | (4 << 6) | (1 << 9)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(hash)
dos.writeShort(moveE4)
dos.writeShort(100)
dos.writeInt(0)
}.get
val book = PolyglotBook(tempFile.toString)
val botWithBook = ClassicalBot(BotDifficulty.Easy, book = Some(book))
val move = botWithBook.nextMove(ctx)
// Book should return e2-e4
move shouldEqual Some(Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal()))
finally Files.delete(tempFile)
test("Promotion moves are decoded correctly"):
val tempFile = Files.createTempFile("test_book", ".bin")
try
val ctx = GameContext.initial
val hash = PolyglotHash.hash(ctx)
// Pawn promotion: a7-a8=Q
// from_file=0, from_rank=6, to_file=0, to_rank=7, promotion=4 (queen)
val promoteMove: Short = (0 | (7 << 3) | (0 << 6) | (6 << 9) | (4 << 12)).toShort
Using(DataOutputStream(FileOutputStream(tempFile.toFile))) { dos =>
dos.writeLong(hash)
dos.writeShort(promoteMove)
dos.writeShort(100)
dos.writeInt(0)
}.get
val book = PolyglotBook(tempFile.toString)
val move = book.probe(ctx)
move shouldNot be(None)
move.get.moveType match
case MoveType.Promotion(piece) => piece shouldEqual PromotionPiece.Queen
case _ => fail("Expected promotion move")
finally Files.delete(tempFile)
@@ -0,0 +1,58 @@
package de.nowchess.bot
import de.nowchess.api.board.{Color, File, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.bot.util.PolyglotHash
import de.nowchess.io.fen.FenParser
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class PolyglotHashTest extends AnyFunSuite with Matchers:
test("Initial position matches Polyglot reference key"):
val ctx = GameContext.initial
PolyglotHash.hash(ctx) shouldEqual java.lang.Long.parseUnsignedLong("463b96181691fc9c", 16)
test("Known Polyglot FEN vector matches reference key"):
val fen = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
val ctx = FenParser.parseFen(fen).toOption.getOrElse(fail("FEN parse failed"))
PolyglotHash.hash(ctx) shouldEqual java.lang.Long.parseUnsignedLong("823c9b50fd114196", 16)
test("Hash changes when turn changes"):
val ctx = GameContext.initial
val hash1 = PolyglotHash.hash(ctx)
val ctxBlackTurn = ctx.withTurn(Color.Black)
val hash2 = PolyglotHash.hash(ctxBlackTurn)
hash1 should not equal hash2
test("Hash changes when castling rights change"):
val ctx = GameContext.initial
val hash1 = PolyglotHash.hash(ctx)
val noCastling = ctx.withCastlingRights(
de.nowchess.api.board.CastlingRights(false, false, false, false),
)
val hash2 = PolyglotHash.hash(noCastling)
hash1 should not equal hash2
test("En passant file is ignored when no side-to-move pawn can capture"):
val fenWithEp = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq e3 0 1"
val fenWithoutEp = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
val withEp = FenParser.parseFen(fenWithEp).toOption.getOrElse(fail("FEN parse failed"))
val withoutEp = FenParser.parseFen(fenWithoutEp).toOption.getOrElse(fail("FEN parse failed"))
PolyglotHash.hash(withEp) shouldEqual PolyglotHash.hash(withoutEp)
test("Different en passant files produce different hashes when capture is possible"):
val ctx = GameContext.initial
val epFileE = ctx.withEnPassantSquare(Some(Square(File.E, Rank.R3)))
val epFileD = ctx.withEnPassantSquare(Some(Square(File.D, Rank.R3)))
val hash1 = PolyglotHash.hash(epFileE)
val hash2 = PolyglotHash.hash(epFileD)
hash1 should not equal hash2
test("Removing en passant changes hash"):
val ctx = GameContext.initial
val withEP = ctx.withEnPassantSquare(Some(Square(File.E, Rank.R3)))
val hash1 = PolyglotHash.hash(withEP)
val noEP = withEP.withEnPassantSquare(None)
val hash2 = PolyglotHash.hash(noEP)
hash1 should not equal hash2
@@ -0,0 +1,73 @@
package de.nowchess.bot
import de.nowchess.api.board.{File, Rank, Square}
import de.nowchess.api.move.{Move, MoveType}
import de.nowchess.bot.logic.{TTEntry, TTFlag, TranspositionTable}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class TranspositionTableTest extends AnyFunSuite with Matchers:
test("probe on empty table returns None"):
val tt = TranspositionTable(sizePow2 = 4)
tt.probe(12345L) should be(None)
test("store then probe returns the stored entry"):
val tt = TranspositionTable(sizePow2 = 4)
val move = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val entry = TTEntry(hash = 12345L, depth = 3, score = 50, flag = TTFlag.Exact, bestMove = Some(move))
tt.store(entry)
val retrieved = tt.probe(12345L)
retrieved should be(Some(entry))
test("probe returns None when hash differs (collision guard)"):
val tt = TranspositionTable(sizePow2 = 4)
val move = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val entry = TTEntry(hash = 12345L, depth = 3, score = 50, flag = TTFlag.Exact, bestMove = Some(move))
tt.store(entry)
tt.probe(54321L) should be(None)
test("clear removes all entries"):
val tt = TranspositionTable(sizePow2 = 4)
val move = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val entry = TTEntry(hash = 12345L, depth = 3, score = 50, flag = TTFlag.Exact, bestMove = Some(move))
tt.store(entry)
tt.probe(12345L) should be(Some(entry))
tt.clear()
tt.probe(12345L) should be(None)
test("all TTFlag values store and retrieve correctly"):
val tt = TranspositionTable(sizePow2 = 4)
val move = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
TTFlag.values.foreach { flag =>
val entry = TTEntry(hash = 12345L + flag.ordinal, depth = 2, score = 100, flag = flag, bestMove = Some(move))
tt.store(entry)
val retrieved = tt.probe(12345L + flag.ordinal)
retrieved.map(_.flag) should be(Some(flag))
}
test("bestMove = None roundtrips"):
val tt = TranspositionTable(sizePow2 = 4)
val entry = TTEntry(hash = 99999L, depth = 1, score = 0, flag = TTFlag.Upper, bestMove = None)
tt.store(entry)
val retrieved = tt.probe(99999L)
retrieved.map(_.bestMove) should be(Some(None))
test("always-replace overwrites at same slot"):
val tt = TranspositionTable(sizePow2 = 4)
val move1 = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4), MoveType.Normal())
val move2 = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R3), MoveType.Normal())
val entry1 = TTEntry(hash = 12345L, depth = 2, score = 50, flag = TTFlag.Exact, bestMove = Some(move1))
val entry2 = TTEntry(hash = 12345L, depth = 3, score = 100, flag = TTFlag.Lower, bestMove = Some(move2))
tt.store(entry1)
tt.probe(12345L).map(_.score) should be(Some(50))
tt.store(entry2)
tt.probe(12345L).map(_.score) should be(Some(100))
test("size is 1 << sizePow2"):
val tt = TranspositionTable(sizePow2 = 4)
(1 << 4) should equal(16)
val tt2 = TranspositionTable(sizePow2 = 10)
(1 << 10) should equal(1024)
@@ -0,0 +1,162 @@
package de.nowchess.bot
import de.nowchess.api.board.{Board, CastlingRights, Color, File, Piece, Rank, Square}
import de.nowchess.api.game.GameContext
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import de.nowchess.bot.util.ZobristHash
import de.nowchess.rules.sets.DefaultRules
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class ZobristHashTest extends AnyFunSuite with Matchers:
test("hash is deterministic"):
val hash1 = ZobristHash.hash(GameContext.initial)
val hash2 = ZobristHash.hash(GameContext.initial)
hash1 should equal(hash2)
test("hash differs after a pawn move"):
val initial = GameContext.initial
// Move pawn from e2 to e4
val board = initial.board.pieces
val newBoard = board.removed(Square(File.E, Rank.R2)).updated(Square(File.E, Rank.R4), Piece.WhitePawn)
val afterMove = initial.withBoard(Board(newBoard)).withTurn(Color.Black)
val hash1 = ZobristHash.hash(initial)
val hash2 = ZobristHash.hash(afterMove)
hash1 should not equal hash2
test("hash includes castling rights"):
val ctx1 = GameContext.initial
val ctx2 = ctx1.withCastlingRights(CastlingRights.None)
val hash1 = ZobristHash.hash(ctx1)
val hash2 = ZobristHash.hash(ctx2)
hash1 should not equal hash2
test("hash includes en-passant square"):
val ctx1 = GameContext.initial
val ctx2 = ctx1.withEnPassantSquare(Some(Square(File.E, Rank.R3)))
val hash1 = ZobristHash.hash(ctx1)
val hash2 = ZobristHash.hash(ctx2)
hash1 should not equal hash2
test("hash includes side to move"):
val ctx1 = GameContext.initial
val ctx2 = ctx1.withTurn(Color.Black)
val hash1 = ZobristHash.hash(ctx1)
val hash2 = ZobristHash.hash(ctx2)
hash1 should not equal hash2
test("nextHash matches recomputed hash for a normal move"):
val context = GameContext.initial
val move = Move(Square(File.E, Rank.R2), Square(File.E, Rank.R4))
val next = DefaultRules.applyMove(context)(move)
val incremental = ZobristHash.nextHash(context, ZobristHash.hash(context), move, next)
incremental should equal(ZobristHash.hash(next))
test("nextHash matches recomputed hash for promotion and castling"):
val promotionBoard = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.H, Rank.R1) -> Piece.WhiteRook,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val promotionContext = GameContext.initial
.withBoard(promotionBoard)
.withTurn(Color.White)
.withCastlingRights(CastlingRights.All)
val promotionMove = Move(Square(File.E, Rank.R7), Square(File.E, Rank.R8), MoveType.Promotion(PromotionPiece.Queen))
val promotionNext = DefaultRules.applyMove(promotionContext)(promotionMove)
val promotionHash =
ZobristHash.nextHash(promotionContext, ZobristHash.hash(promotionContext), promotionMove, promotionNext)
promotionHash should equal(ZobristHash.hash(promotionNext))
val castleBoard = Board(
Map(
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.H, Rank.R1) -> Piece.WhiteRook,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val castleContext = GameContext.initial
.withBoard(castleBoard)
.withTurn(Color.White)
.withCastlingRights(
CastlingRights(whiteKingSide = true, whiteQueenSide = false, blackKingSide = false, blackQueenSide = false),
)
val castleMove = Move(Square(File.E, Rank.R1), Square(File.G, Rank.R1), MoveType.CastleKingside)
val castleNext = DefaultRules.applyMove(castleContext)(castleMove)
val castleHash = ZobristHash.nextHash(castleContext, ZobristHash.hash(castleContext), castleMove, castleNext)
castleHash should equal(ZobristHash.hash(castleNext))
test("nextHash matches recomputed hash for queenside castling"):
val board = Board(
Map(
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.A, Rank.R1) -> Piece.WhiteRook,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = true, blackKingSide = false, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R1), Square(File.C, Rank.R1), MoveType.CastleQueenside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
test("nextHash matches recomputed hash for en passant"):
val board = Board(
Map(
Square(File.E, Rank.R5) -> Piece.WhitePawn,
Square(File.D, Rank.R5) -> Piece.BlackPawn,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withEnPassantSquare(Some(Square(File.D, Rank.R6)))
val move = Move(Square(File.E, Rank.R5), Square(File.D, Rank.R6), MoveType.EnPassant)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
test("nextHash matches recomputed hash for black kingside castling"):
val board = Board(
Map(
Square(File.E, Rank.R8) -> Piece.BlackKing,
Square(File.H, Rank.R8) -> Piece.BlackRook,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
),
)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.Black)
.withCastlingRights(
CastlingRights(whiteKingSide = false, whiteQueenSide = false, blackKingSide = true, blackQueenSide = false),
)
val move = Move(Square(File.E, Rank.R8), Square(File.G, Rank.R8), MoveType.CastleKingside)
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
test("nextHash matches recomputed hash for knight and rook promotions"):
val board = Board(
Map(
Square(File.E, Rank.R7) -> Piece.WhitePawn,
Square(File.E, Rank.R1) -> Piece.WhiteKing,
Square(File.E, Rank.R8) -> Piece.BlackKing,
),
)
val ctx = GameContext.initial
.withBoard(board)
.withTurn(Color.White)
.withCastlingRights(CastlingRights(false, false, false, false))
for pp <- List(PromotionPiece.Knight, PromotionPiece.Bishop, PromotionPiece.Rook) do
val move = Move(Square(File.E, Rank.R7), Square(File.E, Rank.R8), MoveType.Promotion(pp))
val next = DefaultRules.applyMove(ctx)(move)
ZobristHash.nextHash(ctx, ZobristHash.hash(ctx), move, next) should equal(ZobristHash.hash(next))
+26
View File
@@ -260,3 +260,29 @@
* correct test board positions and captureOutput/withInput interaction ([f0481e2](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/f0481e2561b779df00925b46ee281dc36a795150))
* update main class path in build configuration and adjust VCS directory mapping ([7b1f8b1](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/7b1f8b117623d327232a1a92a8a44d18582e0189))
* update move validation to check for king safety ([#13](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/13)) ([e5e20c5](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/e5e20c566e368b12ca1dc59680c34e9112bf6762))
## (2026-04-16)
### Features
* add GameRules stub with PositionStatus enum ([76d4168](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/76d4168038de23e5d6083d4e8f0504fbf31d15a3))
* add MovedInCheck/Checkmate/Stalemate MoveResult variants (stub dispatch) ([8b7ec57](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/8b7ec57e5ea6ee1615a1883848a426dc07d26364))
* implement GameRules with isInCheck, legalMoves, gameStatus ([94a02ff](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/94a02ff6849436d9496c70a0f16c21666dae8e4e))
* implement legal castling ([#1](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/1)) ([00d326c](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/00d326c1ba67711fbe180f04e1100c3f01dd0254))
* NCS-10 Implement Pawn Promotion ([#12](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/12)) ([13bfc16](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/13bfc16cfe25db78ec607db523ca6d993c13430c))
* NCS-11 50-move rule ([#9](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/9)) ([412ed98](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/412ed986a95703a3b282276540153480ceed229d))
* NCS-13 Implement Threefold Repetition ([#31](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/31)) ([767d305](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/767d3051a76c266050b6335774d66e2db2273c16))
* NCS-14 implemented insufficient moves rule ([#30](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/30)) ([b0399a4](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/b0399a4e489950083066c9538df9a84dcc7a4613))
* NCS-16 Core Separation via Patterns ([#10](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/10)) ([1361dfc](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/1361dfc89553b146864fb8ff3526cf12cf3f293a))
* NCS-17 Implement basic ScalaFX UI ([#14](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/14)) ([3ff8031](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/3ff80318b4f16c59733a46498581a5c27f048287))
* NCS-21 Write Scripts to automate certain tasks ([#15](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/15)) ([8051871](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/80518719d536a087d339fe02530825dc07f8b388))
* NCS-25 Add linters to keep quality up ([#27](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/27)) ([fd4e67d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/fd4e67d4f782a7e955822d90cb909d0a81676fb2))
* NCS-6 Implementing FEN & PGN ([#7](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/7)) ([f28e69d](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/f28e69dc181416aa2f221fdc4b45c2cda5efbf07))
* NCS-9 En passant implementation ([#8](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/8)) ([919beb3](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/919beb3b4bfa8caf2f90976a415fe9b19b7e9747))
* wire check/checkmate/stalemate into processMove and gameLoop ([5264a22](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/5264a225418b885c5e6ea6411b96f85e38837f6c))
### Bug Fixes
* add missing kings to gameLoop capture test board ([aedd787](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/aedd787b77203c2af934751dba7b784eaf165032))
* correct test board positions and captureOutput/withInput interaction ([f0481e2](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/f0481e2561b779df00925b46ee281dc36a795150))
* update main class path in build configuration and adjust VCS directory mapping ([7b1f8b1](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/7b1f8b117623d327232a1a92a8a44d18582e0189))
* update move validation to check for king safety ([#13](https://git.janis-eccarius.de/NowChess/NowChessSystems/issues/13)) ([e5e20c5](https://git.janis-eccarius.de/NowChess/NowChessSystems/commit/e5e20c566e368b12ca1dc59680c34e9112bf6762))
+1
View File
@@ -41,6 +41,7 @@ dependencies {
implementation(project(":modules:api"))
implementation(project(":modules:io"))
implementation(project(":modules:rule"))
implementation(project(":modules:bot"))
testImplementation(platform("org.junit:junit-bom:5.13.4"))
testImplementation("org.junit.jupiter:junit-jupiter")
@@ -1,21 +1,28 @@
package de.nowchess.chess.controller
import de.nowchess.api.board.{File, Rank, Square}
import de.nowchess.api.move.PromotionPiece
object Parser:
/** Parses coordinate notation such as "e2e4" or "g1f3". Returns None for any input that does not match the expected
* format.
/** Parses UCI move notation: "e2e4" (4 chars) or "e7e8q" (5 chars with promotion piece suffix). The promotion suffix
* is q=Queen, r=Rook, b=Bishop, n=Knight. Returns None for invalid input.
*/
def parseMove(input: String): Option[(Square, Square)] =
def parseMove(input: String): Option[(Square, Square, Option[PromotionPiece])] =
val trimmed = input.trim.toLowerCase
Option
.when(trimmed.length == 4)(trimmed)
.flatMap: s =>
trimmed.length match
case 4 =>
for
from <- parseSquare(s.substring(0, 2))
to <- parseSquare(s.substring(2, 4))
yield (from, to)
from <- parseSquare(trimmed.substring(0, 2))
to <- parseSquare(trimmed.substring(2, 4))
yield (from, to, None)
case 5 =>
for
from <- parseSquare(trimmed.substring(0, 2))
to <- parseSquare(trimmed.substring(2, 4))
promo <- parsePromotion(trimmed(4))
yield (from, to, Some(promo))
case _ => None
private def parseSquare(s: String): Option[Square] =
Option
@@ -26,3 +33,10 @@ object Parser:
Option.when(fileIdx >= 0 && fileIdx <= 7 && rankIdx >= 0 && rankIdx <= 7)(
Square(File.values(fileIdx), Rank.values(rankIdx)),
)
private def parsePromotion(c: Char): Option[PromotionPiece] = c match
case 'q' => Some(PromotionPiece.Queen)
case 'r' => Some(PromotionPiece.Rook)
case 'b' => Some(PromotionPiece.Bishop)
case 'n' => Some(PromotionPiece.Knight)
case _ => None
@@ -2,7 +2,8 @@ package de.nowchess.chess.engine
import de.nowchess.api.board.{Board, Color, Piece, PieceType, Square}
import de.nowchess.api.move.{Move, MoveType, PromotionPiece}
import de.nowchess.api.game.{DrawReason, GameContext, GameResult}
import de.nowchess.api.game.{BotParticipant, DrawReason, GameContext, GameResult, Human, Participant}
import de.nowchess.api.player.{PlayerId, PlayerInfo}
import de.nowchess.chess.controller.Parser
import de.nowchess.chess.observer.*
import de.nowchess.chess.command.{CommandInvoker, MoveCommand, MoveResult}
@@ -10,29 +11,29 @@ import de.nowchess.io.{GameContextExport, GameContextImport}
import de.nowchess.rules.RuleSet
import de.nowchess.rules.sets.DefaultRules
import scala.concurrent.{ExecutionContext, Future}
/** Pure game engine that manages game state and notifies observers of state changes. All rule queries delegate to the
* injected RuleSet. All user interactions go through Commands; state changes are broadcast via GameEvents.
*/
class GameEngine(
val initialContext: GameContext = GameContext.initial,
val ruleSet: RuleSet = DefaultRules,
val participants: Map[Color, Participant] = Map(
Color.White -> Human(PlayerInfo(PlayerId("p1"), "Player 1")),
Color.Black -> Human(PlayerInfo(PlayerId("p2"), "Player 2")),
),
) extends Observable:
// Ensure that initialBoard is set correctly for threefold repetition detection
private val contextWithInitialBoard = if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then
initialContext.copy(initialBoard = initialContext.board)
else
initialContext
private val contextWithInitialBoard =
if initialContext.moves.isEmpty && initialContext.board != initialContext.initialBoard then
initialContext.copy(initialBoard = initialContext.board)
else initialContext
@SuppressWarnings(Array("DisableSyntax.var"))
private var currentContext: GameContext = contextWithInitialBoard
private val invoker = new CommandInvoker()
/** Pending promotion: the Move that triggered it (from/to only, moveType filled in later). */
private case class PendingPromotion(from: Square, to: Square, contextBefore: GameContext)
@SuppressWarnings(Array("DisableSyntax.var"))
private var pendingPromotion: Option[PendingPromotion] = None
/** True if a pawn promotion move is pending and needs a piece choice. */
def isPendingPromotion: Boolean = synchronized(pendingPromotion.isDefined)
private implicit val ec: ExecutionContext = ExecutionContext.global
// Synchronized accessors for current state
def board: Board = synchronized(currentContext.board)
@@ -92,11 +93,11 @@ class GameEngine(
s"Invalid move format '$moveInput'. Use coordinate notation, e.g. e2e4.",
),
)
case Some((from, to)) =>
handleParsedMove(from, to)
case Some((from, to, promotionPiece: Option[PromotionPiece])) =>
handleParsedMove(from, to, promotionPiece)
}
private def handleParsedMove(from: Square, to: Square): Unit =
private def handleParsedMove(from: Square, to: Square, promotionPiece: Option[PromotionPiece]): Unit =
currentContext.board.pieceAt(from) match
case None =>
notifyObservers(InvalidMoveEvent(currentContext, "No piece on that square."))
@@ -109,11 +110,18 @@ class GameEngine(
candidates match
case Nil =>
notifyObservers(InvalidMoveEvent(currentContext, "Illegal move."))
case moves if isPromotionMove(piece, to) =>
// Multiple moves (one per promotion piece) — ask user to choose
val contextBefore = currentContext
pendingPromotion = Some(PendingPromotion(from, to, contextBefore))
notifyObservers(PromotionRequiredEvent(currentContext, from, to))
case _ if isPromotionMove(piece, to) =>
if promotionPiece.isEmpty then
notifyObservers(
InvalidMoveEvent(currentContext, "Promotion piece required: append q, r, b, or n to the move."),
)
else
candidates.find(_.moveType == MoveType.Promotion(promotionPiece.get)) match
case None =>
notifyObservers(
InvalidMoveEvent(currentContext, "Error completing promotion: no matching legal move."),
)
case Some(move) => executeMove(move)
case move :: _ =>
executeMove(move)
@@ -123,21 +131,6 @@ class GameEngine(
to.rank.ordinal == promoRank
}
/** Apply a player's promotion piece choice. Must only be called when isPendingPromotion is true.
*/
def completePromotion(piece: PromotionPiece): Unit = synchronized {
pendingPromotion match
case None =>
notifyObservers(InvalidMoveEvent(currentContext, "No promotion pending."))
case Some(pending) =>
pendingPromotion = None
val move = Move(pending.from, pending.to, MoveType.Promotion(piece))
// Verify it's actually legal
val legal = ruleSet.legalMoves(currentContext)(pending.from)
if legal.contains(move) then executeMove(move)
else notifyObservers(InvalidMoveEvent(currentContext, "Error completing promotion."))
}
/** Undo the last move. */
def undo(): Unit = synchronized(performUndo())
@@ -159,7 +152,6 @@ class GameEngine(
private def replayGame(ctx: GameContext): Either[String, Unit] =
val savedContext = currentContext
currentContext = GameContext.initial
pendingPromotion = None
invoker.clear()
if ctx.moves.isEmpty then
@@ -175,14 +167,13 @@ class GameEngine(
result
private def applyReplayMove(move: Move): Either[String, Unit] =
handleParsedMove(move.from, move.to)
move.moveType match
case MoveType.Promotion(pp) if pendingPromotion.isDefined =>
completePromotion(pp)
Right(())
case MoveType.Promotion(_) =>
Left(s"Promotion required for move ${move.from}${move.to}")
case _ => Right(())
val legal = ruleSet.legalMoves(currentContext)(move.from)
val candidate = move.moveType match
case MoveType.Promotion(pp) => legal.find(m => m.to == move.to && m.moveType == MoveType.Promotion(pp))
case _ => legal.find(_.to == move.to)
candidate match
case None => Left("Illegal move.")
case Some(lm) => executeMove(lm); Right(())
/** Export the current game context using the provided exporter. */
def exportGame(exporter: GameContextExport): String = synchronized {
@@ -191,12 +182,10 @@ class GameEngine(
/** Load an arbitrary board position, clearing all history and undo/redo state. */
def loadPosition(newContext: GameContext): Unit = synchronized {
val contextWithInitialBoard = if newContext.moves.isEmpty then
newContext.copy(initialBoard = newContext.board)
else
newContext
val contextWithInitialBoard =
if newContext.moves.isEmpty then newContext.copy(initialBoard = newContext.board)
else newContext
currentContext = contextWithInitialBoard
pendingPromotion = None
invoker.clear()
notifyObservers(BoardResetEvent(currentContext))
}
@@ -208,6 +197,9 @@ class GameEngine(
notifyObservers(BoardResetEvent(currentContext))
}
/** Kick off play when the side to move is a bot (e.g. bot-vs-bot from initial position). */
def startGame(): Unit = synchronized(requestBotMoveIfNeeded())
// ──── Private helpers ────
private def executeMove(move: Move): Unit =
@@ -250,7 +242,9 @@ class GameEngine(
else if ruleSet.isCheck(currentContext) then notifyObservers(CheckDetectedEvent(currentContext))
if currentContext.halfMoveClock >= 100 then notifyObservers(FiftyMoveRuleAvailableEvent(currentContext))
if ruleSet.isThreefoldRepetition(currentContext) then notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext))
if ruleSet.isThreefoldRepetition(currentContext) then
notifyObservers(ThreefoldRepetitionAvailableEvent(currentContext))
else requestBotMoveIfNeeded()
private def translateMoveToNotation(move: Move, boardBefore: Board): String =
move.moveType match
@@ -301,6 +295,47 @@ class GameEngine(
case _ =>
context.board.pieceAt(move.to)
/** Request a move from the opponent bot if it's their turn. Spawns an async task to avoid blocking the engine.
*/
private def requestBotMoveIfNeeded(): Unit =
val pendingBotMove = synchronized {
participants.get(currentContext.turn) match
case Some(BotParticipant(bot)) => Some((bot, currentContext))
case _ => None
}
pendingBotMove.foreach { case (bot, contextAtRequest) =>
Future {
bot.nextMove(contextAtRequest) match
case Some(move) => applyBotMove(move)
case None => handleBotNoMove()
}
}
private def applyBotMove(move: Move): Unit =
synchronized {
val color = currentContext.turn
val from = move.from
val to = move.to
currentContext.board.pieceAt(from) match
case Some(piece) if piece.color == color =>
val legal = ruleSet.legalMoves(currentContext)(from)
legal.find(m => m.to == to && m.moveType == move.moveType) match
case Some(legalMove) => executeMove(legalMove)
case None =>
notifyObservers(InvalidMoveEvent(currentContext, s"Bot move ${from}${to} is illegal"))
case _ =>
notifyObservers(InvalidMoveEvent(currentContext, "Bot move has invalid source square"))
}
private def handleBotNoMove(): Unit =
synchronized {
if ruleSet.isCheckmate(currentContext) then
val winner = currentContext.turn.opposite
notifyObservers(CheckmateEvent(currentContext, winner))
else if ruleSet.isStalemate(currentContext) then notifyObservers(DrawEvent(currentContext, DrawReason.Stalemate))
}
private def performUndo(): Unit =
if invoker.canUndo then
val cmd = invoker.history(invoker.getCurrentIndex)
@@ -1,6 +1,6 @@
package de.nowchess.chess.observer
import de.nowchess.api.board.{Color, Square}
import de.nowchess.api.board.Color
import de.nowchess.api.game.{DrawReason, GameContext}
/** Base trait for all game state events. Events are immutable snapshots of game state changes.
@@ -39,13 +39,6 @@ case class InvalidMoveEvent(
reason: String,
) extends GameEvent
/** Fired when a pawn reaches the back rank and the player must choose a promotion piece. */
case class PromotionRequiredEvent(
context: GameContext,
from: Square,
to: Square,
) extends GameEvent
/** Fired when the board is reset. */
case class BoardResetEvent(
context: GameContext,
@@ -18,8 +18,8 @@ class CommandInvokerBranchTest extends AnyFunSuite with Matchers:
initialShouldFailOnUndo: Boolean = false,
initialShouldFailOnExecute: Boolean = false,
) extends Command:
val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo)
val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute)
val shouldFailOnUndo = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnUndo)
val shouldFailOnExecute = new java.util.concurrent.atomic.AtomicBoolean(initialShouldFailOnExecute)
override def execute(): Boolean = !shouldFailOnExecute.get()
override def undo(): Boolean = !shouldFailOnUndo.get()
override def description: String = "Conditional fail"
@@ -1,25 +1,26 @@
package de.nowchess.chess.controller
import de.nowchess.api.board.{File, Rank, Square}
import de.nowchess.api.move.PromotionPiece
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
class ParserTest extends AnyFunSuite with Matchers:
test("parseMove parses valid 'e2e4'"):
Parser.parseMove("e2e4") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
Parser.parseMove("e2e4") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4), None))
test("parseMove is case-insensitive"):
Parser.parseMove("E2E4") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
Parser.parseMove("E2E4") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4), None))
test("parseMove trims leading and trailing whitespace"):
Parser.parseMove(" e2e4 ") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4)))
Parser.parseMove(" e2e4 ") shouldBe Some((Square(File.E, Rank.R2), Square(File.E, Rank.R4), None))
test("parseMove handles corner squares a1h8"):
Parser.parseMove("a1h8") shouldBe Some((Square(File.A, Rank.R1), Square(File.H, Rank.R8)))
Parser.parseMove("a1h8") shouldBe Some((Square(File.A, Rank.R1), Square(File.H, Rank.R8), None))
test("parseMove handles corner squares h8a1"):
Parser.parseMove("h8a1") shouldBe Some((Square(File.H, Rank.R8), Square(File.A, Rank.R1)))
Parser.parseMove("h8a1") shouldBe Some((Square(File.H, Rank.R8), Square(File.A, Rank.R1), None))
test("parseMove returns None for empty string"):
Parser.parseMove("") shouldBe None
@@ -27,8 +28,8 @@ class ParserTest extends AnyFunSuite with Matchers:
test("parseMove returns None for input shorter than 4 chars"):
Parser.parseMove("e2e") shouldBe None
test("parseMove returns None for input longer than 4 chars"):
Parser.parseMove("e2e44") shouldBe None
test("parseMove returns None for input longer than 5 chars"):
Parser.parseMove("e2e4qq") shouldBe None
test("parseMove returns None when from-file is out of range"):
Parser.parseMove("z2e4") shouldBe None
@@ -41,3 +42,31 @@ class ParserTest extends AnyFunSuite with Matchers:
test("parseMove returns None when to-rank is out of range"):
Parser.parseMove("e2e9") shouldBe None
test("parseMove parses queen promotion 'e7e8q'"):
Parser.parseMove("e7e8q") shouldBe Some(
(Square(File.E, Rank.R7), Square(File.E, Rank.R8), Some(PromotionPiece.Queen)),
)
test("parseMove parses rook promotion 'a7a8r'"):
Parser.parseMove("a7a8r") shouldBe Some(
(Square(File.A, Rank.R7), Square(File.A, Rank.R8), Some(PromotionPiece.Rook)),
)
test("parseMove parses bishop promotion 'e7e8b'"):
Parser.parseMove("e7e8b") shouldBe Some(
(Square(File.E, Rank.R7), Square(File.E, Rank.R8), Some(PromotionPiece.Bishop)),
)
test("parseMove parses knight promotion 'e7e8n'"):
Parser.parseMove("e7e8n") shouldBe Some(
(Square(File.E, Rank.R7), Square(File.E, Rank.R8), Some(PromotionPiece.Knight)),
)
test("parseMove returns None for 5-char input with invalid promotion char"):
Parser.parseMove("e7e8x") shouldBe None
test("parseMove parses black promotion 'e2e1q'"):
Parser.parseMove("e2e1q") shouldBe Some(
(Square(File.E, Rank.R2), Square(File.E, Rank.R1), Some(PromotionPiece.Queen)),
)

Some files were not shown because too many files have changed in this diff Show More