fix: add thread-safety synchronization to Observable and CommandInvoker
Test-driven fixes for code review blockers NCS-16: **Observable (CRITICAL):** Added synchronized blocks to subscribe, unsubscribe, notifyObservers, and observerCount to prevent race conditions when concurrent threads register observers while notifications are dispatched. **CommandInvoker (IMPORTANT):** Added synchronized blocks to all methods (execute, undo, redo, history, getCurrentIndex, canUndo, canRedo, clear) to ensure atomic access to mutable state (executedCommands, currentIndex). Tests: - Added ObservableThreadSafetyTest: 3 tests for concurrent subscribe/unsubscribe/notify - Added CommandInvokerThreadSafetyTest: 2 tests for concurrent execute/undo/redo - All 54 existing tests remain green - Full build passes with 100% core coverage Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -8,7 +8,7 @@ class CommandInvoker:
|
||||
/** Execute a command and add it to history.
|
||||
* Discards any redo history if not at the end of the stack.
|
||||
*/
|
||||
def execute(command: Command): Boolean =
|
||||
def execute(command: Command): Boolean = synchronized {
|
||||
if command.execute() then
|
||||
// Remove any commands after current index (redo stack is discarded)
|
||||
while currentIndex < executedCommands.size - 1 do
|
||||
@@ -18,9 +18,10 @@ class CommandInvoker:
|
||||
true
|
||||
else
|
||||
false
|
||||
}
|
||||
|
||||
/** Undo the last executed command if possible. */
|
||||
def undo(): Boolean =
|
||||
def undo(): Boolean = synchronized {
|
||||
if currentIndex >= 0 && currentIndex < executedCommands.size then
|
||||
val command = executedCommands(currentIndex)
|
||||
if command.undo() then
|
||||
@@ -30,9 +31,10 @@ class CommandInvoker:
|
||||
false
|
||||
else
|
||||
false
|
||||
}
|
||||
|
||||
/** Redo the next command in history if available. */
|
||||
def redo(): Boolean =
|
||||
def redo(): Boolean = synchronized {
|
||||
if currentIndex + 1 < executedCommands.size then
|
||||
val command = executedCommands(currentIndex + 1)
|
||||
if command.execute() then
|
||||
@@ -42,20 +44,30 @@ class CommandInvoker:
|
||||
false
|
||||
else
|
||||
false
|
||||
}
|
||||
|
||||
/** Get the history of all executed commands. */
|
||||
def history: List[Command] = executedCommands.toList
|
||||
def history: List[Command] = synchronized {
|
||||
executedCommands.toList
|
||||
}
|
||||
|
||||
/** Get the current position in command history. */
|
||||
def getCurrentIndex: Int = currentIndex
|
||||
def getCurrentIndex: Int = synchronized {
|
||||
currentIndex
|
||||
}
|
||||
|
||||
/** Clear all command history. */
|
||||
def clear(): Unit =
|
||||
def clear(): Unit = synchronized {
|
||||
executedCommands.clear()
|
||||
currentIndex = -1
|
||||
}
|
||||
|
||||
/** Check if undo is available. */
|
||||
def canUndo: Boolean = currentIndex >= 0
|
||||
def canUndo: Boolean = synchronized {
|
||||
currentIndex >= 0
|
||||
}
|
||||
|
||||
/** Check if redo is available. */
|
||||
def canRedo: Boolean = currentIndex + 1 < executedCommands.size
|
||||
def canRedo: Boolean = synchronized {
|
||||
currentIndex + 1 < executedCommands.size
|
||||
}
|
||||
|
||||
@@ -67,16 +67,21 @@ trait Observable:
|
||||
private val observers = scala.collection.mutable.Set[Observer]()
|
||||
|
||||
/** Register an observer to receive game events. */
|
||||
def subscribe(observer: Observer): Unit =
|
||||
def subscribe(observer: Observer): Unit = synchronized {
|
||||
observers += observer
|
||||
}
|
||||
|
||||
/** Unregister an observer. */
|
||||
def unsubscribe(observer: Observer): Unit =
|
||||
def unsubscribe(observer: Observer): Unit = synchronized {
|
||||
observers -= observer
|
||||
}
|
||||
|
||||
/** Notify all observers of a game event. */
|
||||
protected def notifyObservers(event: GameEvent): Unit =
|
||||
protected def notifyObservers(event: GameEvent): Unit = synchronized {
|
||||
observers.foreach(_.onGameEvent(event))
|
||||
}
|
||||
|
||||
/** Return current list of observers (for testing). */
|
||||
def observerCount: Int = observers.size
|
||||
def observerCount: Int = synchronized {
|
||||
observers.size
|
||||
}
|
||||
|
||||
+131
@@ -0,0 +1,131 @@
|
||||
package de.nowchess.chess.command
|
||||
|
||||
import de.nowchess.api.board.{Square, File, Rank, Board, Color}
|
||||
import de.nowchess.chess.logic.GameHistory
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import scala.collection.mutable
|
||||
|
||||
class CommandInvokerThreadSafetyTest extends AnyFunSuite with Matchers:
|
||||
|
||||
private def sq(f: File, r: Rank): Square = Square(f, r)
|
||||
|
||||
private def createMoveCommand(from: Square, to: Square): MoveCommand =
|
||||
MoveCommand(
|
||||
from = from,
|
||||
to = to,
|
||||
moveResult = Some(MoveResult.Successful(Board.initial, GameHistory.empty, Color.White, None)),
|
||||
previousBoard = Some(Board.initial),
|
||||
previousHistory = Some(GameHistory.empty),
|
||||
previousTurn = Some(Color.White)
|
||||
)
|
||||
|
||||
test("CommandInvoker is thread-safe for concurrent execute and history reads"):
|
||||
val invoker = new CommandInvoker()
|
||||
@volatile var raceDetected = false
|
||||
val exceptions = mutable.ListBuffer[Exception]()
|
||||
|
||||
// Thread 1: executes commands
|
||||
val executorThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for i <- 1 to 1000 do
|
||||
val cmd = createMoveCommand(
|
||||
sq(File.E, Rank.R2),
|
||||
sq(File.E, Rank.R4)
|
||||
)
|
||||
invoker.execute(cmd)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
exceptions += e
|
||||
raceDetected = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 2: reads history during execution
|
||||
val readerThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 1000 do
|
||||
val _ = invoker.history
|
||||
val _ = invoker.getCurrentIndex
|
||||
Thread.sleep(0) // Yield to increase contention
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
exceptions += e
|
||||
raceDetected = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
executorThread.start()
|
||||
readerThread.start()
|
||||
executorThread.join()
|
||||
readerThread.join()
|
||||
|
||||
exceptions.isEmpty shouldBe true
|
||||
raceDetected shouldBe false
|
||||
|
||||
test("CommandInvoker is thread-safe for concurrent execute, undo, and redo"):
|
||||
val invoker = new CommandInvoker()
|
||||
@volatile var raceDetected = false
|
||||
val exceptions = mutable.ListBuffer[Exception]()
|
||||
|
||||
// Pre-populate with some commands
|
||||
for _ <- 1 to 5 do
|
||||
invoker.execute(createMoveCommand(sq(File.E, Rank.R2), sq(File.E, Rank.R4)))
|
||||
|
||||
// Thread 1: executes new commands
|
||||
val executorThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500 do
|
||||
invoker.execute(createMoveCommand(sq(File.D, Rank.R2), sq(File.D, Rank.R4)))
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
exceptions += e
|
||||
raceDetected = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 2: undoes commands
|
||||
val undoThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500 do
|
||||
if invoker.canUndo then
|
||||
invoker.undo()
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
exceptions += e
|
||||
raceDetected = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 3: redoes commands
|
||||
val redoThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500 do
|
||||
if invoker.canRedo then
|
||||
invoker.redo()
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
exceptions += e
|
||||
raceDetected = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
executorThread.start()
|
||||
undoThread.start()
|
||||
redoThread.start()
|
||||
executorThread.join()
|
||||
undoThread.join()
|
||||
redoThread.join()
|
||||
|
||||
exceptions.isEmpty shouldBe true
|
||||
raceDetected shouldBe false
|
||||
+168
@@ -0,0 +1,168 @@
|
||||
package de.nowchess.chess.observer
|
||||
|
||||
import de.nowchess.api.board.{Board, Color}
|
||||
import de.nowchess.chess.logic.GameHistory
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import scala.collection.mutable
|
||||
|
||||
class ObservableThreadSafetyTest extends AnyFunSuite with Matchers:
|
||||
|
||||
private class TestObservable extends Observable:
|
||||
def testNotifyObservers(event: GameEvent): Unit =
|
||||
notifyObservers(event)
|
||||
|
||||
private class CountingObserver extends Observer:
|
||||
@volatile private var eventCount = 0
|
||||
@volatile private var lastEvent: Option[GameEvent] = None
|
||||
|
||||
def onGameEvent(event: GameEvent): Unit =
|
||||
eventCount += 1
|
||||
lastEvent = Some(event)
|
||||
|
||||
private def createTestEvent(): GameEvent =
|
||||
BoardResetEvent(
|
||||
board = Board.initial,
|
||||
history = GameHistory.empty,
|
||||
turn = Color.White
|
||||
)
|
||||
|
||||
test("Observable is thread-safe for concurrent subscribe and notify"):
|
||||
val observable = new TestObservable()
|
||||
val testEvent = createTestEvent()
|
||||
@volatile var raceConditionCaught = false
|
||||
|
||||
// Thread 1: repeatedly notifies observers with long iteration
|
||||
val notifierThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500000 do
|
||||
observable.testNotifyObservers(testEvent)
|
||||
} catch {
|
||||
case _: java.util.ConcurrentModificationException =>
|
||||
raceConditionCaught = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 2: rapidly subscribes/unsubscribes observers during notify
|
||||
val subscriberThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500000 do
|
||||
val obs = new CountingObserver()
|
||||
observable.subscribe(obs)
|
||||
observable.unsubscribe(obs)
|
||||
} catch {
|
||||
case _: java.util.ConcurrentModificationException =>
|
||||
raceConditionCaught = true
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
notifierThread.start()
|
||||
subscriberThread.start()
|
||||
notifierThread.join()
|
||||
subscriberThread.join()
|
||||
|
||||
raceConditionCaught shouldBe false
|
||||
|
||||
test("Observable is thread-safe for concurrent subscribe, unsubscribe, and notify"):
|
||||
val observable = new TestObservable()
|
||||
val testEvent = createTestEvent()
|
||||
val exceptions = mutable.ListBuffer[Exception]()
|
||||
val observers = mutable.ListBuffer[CountingObserver]()
|
||||
|
||||
// Pre-subscribe some observers
|
||||
for _ <- 1 to 10 do
|
||||
val obs = new CountingObserver()
|
||||
observers += obs
|
||||
observable.subscribe(obs)
|
||||
|
||||
// Thread 1: notifies observers
|
||||
val notifierThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 5000 do
|
||||
observable.testNotifyObservers(testEvent)
|
||||
} catch {
|
||||
case e: Exception => exceptions += e
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 2: subscribes new observers
|
||||
val subscriberThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 5000 do
|
||||
val obs = new CountingObserver()
|
||||
observable.subscribe(obs)
|
||||
} catch {
|
||||
case e: Exception => exceptions += e
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 3: unsubscribes observers
|
||||
val unsubscriberThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for i <- 1 to 5000 do
|
||||
if observers.nonEmpty then
|
||||
val obs = observers(i % observers.size)
|
||||
observable.unsubscribe(obs)
|
||||
} catch {
|
||||
case e: Exception => exceptions += e
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
notifierThread.start()
|
||||
subscriberThread.start()
|
||||
unsubscriberThread.start()
|
||||
notifierThread.join()
|
||||
subscriberThread.join()
|
||||
unsubscriberThread.join()
|
||||
|
||||
exceptions.isEmpty shouldBe true
|
||||
|
||||
test("Observable.observerCount is thread-safe during concurrent modifications"):
|
||||
val observable = new TestObservable()
|
||||
val exceptions = mutable.ListBuffer[Exception]()
|
||||
val countResults = mutable.ListBuffer[Int]()
|
||||
|
||||
// Thread 1: subscribes observers
|
||||
val subscriberThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500 do
|
||||
observable.subscribe(new CountingObserver())
|
||||
} catch {
|
||||
case e: Exception => exceptions += e
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Thread 2: reads observer count
|
||||
val readerThread = new Thread(new Runnable {
|
||||
def run(): Unit = {
|
||||
try {
|
||||
for _ <- 1 to 500 do
|
||||
val count = observable.observerCount
|
||||
countResults += count
|
||||
} catch {
|
||||
case e: Exception => exceptions += e
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
subscriberThread.start()
|
||||
readerThread.start()
|
||||
subscriberThread.join()
|
||||
readerThread.join()
|
||||
|
||||
exceptions.isEmpty shouldBe true
|
||||
// Count should never go backwards
|
||||
for i <- 1 until countResults.size do
|
||||
countResults(i) >= countResults(i - 1) shouldBe true
|
||||
Reference in New Issue
Block a user