diff --git a/modules/coordinator/src/main/resources/application.yml b/modules/coordinator/src/main/resources/application.yml index 4dc8652..a14ae84 100644 --- a/modules/coordinator/src/main/resources/application.yml +++ b/modules/coordinator/src/main/resources/application.yml @@ -45,6 +45,8 @@ nowchess: k8s-namespace: default k8s-rollout-name: nowchess-core k8s-rollout-label-selector: "app=nowchess-core" + startup-validation-timeout: 15s + failover-wait-timeout: 30s --- # dev profile diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/config/CoordinatorConfig.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/config/CoordinatorConfig.scala index 5a95715..c9a033e 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/config/CoordinatorConfig.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/config/CoordinatorConfig.scala @@ -56,3 +56,9 @@ trait CoordinatorConfig: @WithName("k8s-rollout-label-selector") def k8sRolloutLabelSelector: String + + @WithName("startup-validation-timeout") + def startupValidationTimeout: Duration + + @WithName("failover-wait-timeout") + def failoverWaitTimeout: Duration diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala index 832c52c..cd0a45f 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala @@ -9,6 +9,7 @@ import de.nowchess.coordinator.proto.{CoordinatorServiceGrpc, *} import io.grpc.stub.StreamObserver import com.fasterxml.jackson.databind.ObjectMapper import org.jboss.logging.Logger +import java.util.concurrent.ConcurrentHashMap @GrpcService @Singleton @@ -21,8 +22,9 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp private var failoverService: FailoverService = uninitialized // scalafix:on DisableSyntax.var - private val mapper = ObjectMapper() - private val log = Logger.getLogger(classOf[CoordinatorGrpcServer]) + private val mapper = ObjectMapper() + private val log = Logger.getLogger(classOf[CoordinatorGrpcServer]) + private val activeStreams = ConcurrentHashMap.newKeySet[String]() override def heartbeatStream( responseObserver: StreamObserver[CoordinatorCommand], @@ -38,6 +40,7 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp lastInstanceId = frame.getInstanceId if !firstFrameSeen then firstFrameSeen = true + activeStreams.add(frame.getInstanceId) log.infof( "First heartbeat from instance %s (host=%s http=%d grpc=%d)", frame.getInstanceId, @@ -60,10 +63,19 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp override def onError(t: Throwable): Unit = log.warnf(t, "Heartbeat stream error for instance %s", lastInstanceId) - if lastInstanceId.nonEmpty then failoverService.onInstanceStreamDropped(lastInstanceId) + if lastInstanceId.nonEmpty then + activeStreams.remove(lastInstanceId) + failoverService + .onInstanceStreamDropped(lastInstanceId) + .subscribe() + .`with`( + _ => (), + ex => log.warnf(ex, "Failover for %s failed", lastInstanceId), + ) override def onCompleted: Unit = log.infof("Heartbeat stream completed for instance %s", lastInstanceId) + activeStreams.remove(lastInstanceId) override def batchResubscribeGames( request: BatchResubscribeRequest, @@ -108,7 +120,18 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp val instanceId = request.getInstanceId log.infof("Drain request for instance %s", instanceId) val gamesBefore = instanceRegistry.getInstance(instanceId).map(_.subscriptionCount).getOrElse(0) - failoverService.onInstanceStreamDropped(instanceId) - val response = DrainInstanceResponse.newBuilder().setGamesMigrated(gamesBefore).build() - responseObserver.onNext(response) - responseObserver.onCompleted() + failoverService + .onInstanceStreamDropped(instanceId) + .subscribe() + .`with`( + _ => + val response = DrainInstanceResponse.newBuilder().setGamesMigrated(gamesBefore).build() + responseObserver.onNext(response) + responseObserver.onCompleted(), + ex => + log.warnf(ex, "Drain failed for %s", instanceId) + responseObserver.onError(ex), + ) + + def hasActiveStream(instanceId: String): Boolean = + activeStreams.contains(instanceId) diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala index f02a7ef..427b6e4 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala @@ -70,7 +70,13 @@ class CoordinatorResource: @Produces(Array(MediaType.APPLICATION_JSON)) def triggerFailover(@PathParam("instanceId") instanceId: String): scala.collection.Map[String, String] = log.infof("Manual failover triggered for instance %s", instanceId) - failoverService.onInstanceStreamDropped(instanceId) + failoverService + .onInstanceStreamDropped(instanceId) + .subscribe() + .`with`( + _ => (), + ex => log.warnf(ex, "Manual failover for %s failed", instanceId), + ) Map("status" -> "failover_started", "instanceId" -> instanceId) @POST diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala index fb31858..8089fba 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala @@ -8,6 +8,9 @@ import scala.compiletime.uninitialized import org.jboss.logging.Logger import de.nowchess.coordinator.dto.InstanceMetadata import de.nowchess.coordinator.grpc.CoreGrpcClient +import de.nowchess.coordinator.config.CoordinatorConfig +import io.smallrye.mutiny.Uni +import java.time.Duration @ApplicationScoped class FailoverService: @@ -21,6 +24,9 @@ class FailoverService: @Inject private var coreGrpcClient: CoreGrpcClient = uninitialized + @Inject + private var config: CoordinatorConfig = uninitialized + private val log = Logger.getLogger(classOf[FailoverService]) private var redisPrefix = "nowchess" // scalafix:on DisableSyntax.var @@ -28,7 +34,7 @@ class FailoverService: def setRedisPrefix(prefix: String): Unit = redisPrefix = prefix - def onInstanceStreamDropped(instanceId: String): Unit = + def onInstanceStreamDropped(instanceId: String): Uni[Unit] = log.infof("Instance %s stream dropped, triggering failover", instanceId) val startTime = System.currentTimeMillis() @@ -37,19 +43,32 @@ class FailoverService: val gameIds = getOrphanedGames(instanceId) log.infof("Found %d orphaned games for instance %s", gameIds.size, instanceId) - if gameIds.nonEmpty then - val healthyInstances = instanceRegistry.getAllInstances - .filter(_.state == "HEALTHY") - .sortBy(_.subscriptionCount) + if gameIds.isEmpty then + cleanupDeadInstance(instanceId) + Uni.createFrom().item(()) + else + waitForHealthyInstanceAsync() + .onItem() + .transform { _ => + val healthyInstances = instanceRegistry.getAllInstances + .filter(_.state == "HEALTHY") + .sortBy(_.subscriptionCount) + distributeGames(gameIds, healthyInstances, instanceId) - if healthyInstances.nonEmpty then - distributeGames(gameIds, healthyInstances, instanceId) - - val elapsed = System.currentTimeMillis() - startTime - log.infof("Failover completed in %dms for instance %s", elapsed, instanceId) - else log.warnf("No healthy instances available for failover of %s", instanceId) - - cleanupDeadInstance(instanceId) + val elapsed = System.currentTimeMillis() - startTime + log.infof("Failover completed in %dms for instance %s", elapsed, instanceId) + cleanupDeadInstance(instanceId) + () + } + .onFailure() + .recoverWithItem { _ => + log.errorf( + "No healthy instance appeared within %s — games orphaned for %s", + config.failoverWaitTimeout, + instanceId, + ) + () + } private def getOrphanedGames(instanceId: String): List[String] = val setKey = s"$redisPrefix:instance:$instanceId:games" @@ -101,3 +120,16 @@ class FailoverService: val setKey = s"$redisPrefix:instance:$instanceId:games" redis.key(classOf[String]).del(setKey) log.infof("Cleaned up games set for instance %s", instanceId) + + private def waitForHealthyInstanceAsync(): Uni[InstanceMetadata] = + Uni.createFrom().deferred(() => + instanceRegistry.getAllInstances + .filter(_.state == "HEALTHY") + .sortBy(_.subscriptionCount) + .headOption match + case Some(inst) => Uni.createFrom().item(inst) + case None => Uni.createFrom().failure(new RuntimeException("no healthy instance")) + ).onFailure() + .retry() + .withBackOff(Duration.ofMillis(500)) + .expireIn(config.failoverWaitTimeout.toMillis) diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/HealthMonitor.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/HealthMonitor.scala index ecc4b76..041fd61 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/HealthMonitor.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/HealthMonitor.scala @@ -2,17 +2,23 @@ package de.nowchess.coordinator.service import jakarta.annotation.PostConstruct import jakarta.enterprise.context.ApplicationScoped +import jakarta.enterprise.event.Observes import jakarta.enterprise.inject.Instance import jakarta.inject.Inject import de.nowchess.coordinator.config.CoordinatorConfig import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.Watcher +import io.fabric8.kubernetes.client.WatcherException import io.micrometer.core.instrument.MeterRegistry import io.quarkus.redis.datasource.RedisDataSource +import io.quarkus.runtime.StartupEvent import scala.jdk.CollectionConverters.* import org.jboss.logging.Logger import scala.compiletime.uninitialized import java.time.Instant +import de.nowchess.coordinator.grpc.CoordinatorGrpcServer +import de.nowchess.coordinator.dto.InstanceMetadata @ApplicationScoped class HealthMonitor: @@ -32,6 +38,12 @@ class HealthMonitor: @Inject private var meterRegistry: MeterRegistry = uninitialized + @Inject + private var grpcServer: CoordinatorGrpcServer = uninitialized + + @Inject + private var failoverService: FailoverService = uninitialized + private val log = Logger.getLogger(classOf[HealthMonitor]) private var redisPrefix = "nowchess" // scalafix:on DisableSyntax.var @@ -48,6 +60,15 @@ class HealthMonitor: meterRegistry.counter("nowchess.coordinator.health.checks").increment(0) meterRegistry.counter("nowchess.coordinator.pods.unhealthy").increment(0) + def onStartup(@Observes ev: StartupEvent): Unit = + instanceRegistry.loadAllFromRedis() + val loaded = instanceRegistry.getAllInstances + log.infof("Startup: loaded %d instances from Redis", loaded.size) + if loaded.nonEmpty then + val timeoutMs = config.startupValidationTimeout.toMillis + Thread.ofVirtual().start(() => validateStartupInstances(timeoutMs)) + startPodWatch() + def checkInstanceHealth: Unit = meterRegistry.counter("nowchess.coordinator.health.checks").increment() val evicted = instanceRegistry.evictStaleInstances(config.instanceDeadTimeout) @@ -98,41 +119,33 @@ class HealthMonitor: true } - def watchK8sPods: Unit = + private def startPodWatch(): Unit = kubeClientOpt match - case None => - log.debug("Kubernetes client not available for pod watch") + case None => log.debug("K8s client unavailable, skipping pod watch") case Some(kube) => try - val pods = kube + kube .pods() .inNamespace(config.k8sNamespace) .withLabel(config.k8sRolloutLabelSelector) - .list() - .getItems - .asScala + .watch(new Watcher[Pod]: + override def eventReceived(action: Watcher.Action, pod: Pod): Unit = + action match + case Watcher.Action.DELETED => + handlePodGone(pod) + case Watcher.Action.MODIFIED + if Option(pod.getMetadata.getDeletionTimestamp).isDefined => + handlePodTerminating(pod) + case _ => () - val instances = instanceRegistry.getAllInstances - instances.foreach { inst => - val matchingPod = pods.find { pod => - pod.getMetadata.getName.contains(inst.instanceId) - } - - matchingPod match - case Some(pod) => - val isReady = isPodReady(pod) - if !isReady && inst.state == "HEALTHY" then - meterRegistry.counter("nowchess.coordinator.pods.unhealthy").increment() - log.warnf("Pod %s not ready, marking instance %s dead", pod.getMetadata.getName, inst.instanceId) - instanceRegistry.markInstanceDead(inst.instanceId) - deleteK8sPod(inst.instanceId) - case None => - log.warnf("No pod found for instance %s, evicting from registry", inst.instanceId) - instanceRegistry.removeInstance(inst.instanceId) - } + override def onClose(cause: WatcherException): Unit = + if cause != null then + log.warnf(cause, "Pod watch closed, restarting") + startPodWatch() + ) + log.info("Pod watch started") catch - case ex: Exception => - log.warnf(ex, "Failed to watch k8s pods") + case ex: Exception => log.warnf(ex, "Failed to start pod watch") private def isPodReady(pod: Pod): Boolean = Option(pod.getStatus) @@ -164,3 +177,48 @@ class HealthMonitor: catch case ex: Exception => log.warnf(ex, "Failed to delete pod for instance %s", instanceId) + + private def validateStartupInstances(timeoutMs: Long): Unit = + Thread.sleep(timeoutMs) + instanceRegistry.getAllInstances.foreach { inst => + if !grpcServer.hasActiveStream(inst.instanceId) then + log.warnf( + "Startup: instance %s did not reconnect within %dms — evicting", + inst.instanceId, + timeoutMs, + ) + instanceRegistry.removeInstance(inst.instanceId) + deleteK8sPod(inst.instanceId) + } + + private def handlePodTerminating(pod: Pod): Unit = + findRegisteredInstance(pod).foreach { inst => + if inst.state == "HEALTHY" then + meterRegistry.counter("nowchess.coordinator.pods.unhealthy").increment() + log.warnf( + "Pod %s terminating — marking instance %s dead", + pod.getMetadata.getName, + inst.instanceId, + ) + instanceRegistry.markInstanceDead(inst.instanceId) + } + + private def handlePodGone(pod: Pod): Unit = + findRegisteredInstance(pod).foreach { inst => + log.warnf( + "Pod %s deleted — triggering failover for %s", + pod.getMetadata.getName, + inst.instanceId, + ) + failoverService + .onInstanceStreamDropped(inst.instanceId) + .subscribe() + .`with`( + _ => (), + ex => log.warnf(ex, "Failover for %s failed", inst.instanceId), + ) + } + + private def findRegisteredInstance(pod: Pod): Option[InstanceMetadata] = + val podName = pod.getMetadata.getName + instanceRegistry.getAllInstances.find(inst => podName.contains(inst.instanceId)) diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala index 859c0cd..1da739e 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala @@ -4,6 +4,7 @@ import jakarta.annotation.PostConstruct import jakarta.enterprise.context.ApplicationScoped import jakarta.inject.Inject import io.quarkus.redis.datasource.ReactiveRedisDataSource +import io.quarkus.redis.datasource.RedisDataSource import scala.jdk.CollectionConverters.* import scala.compiletime.uninitialized import com.fasterxml.jackson.databind.ObjectMapper @@ -19,7 +20,10 @@ class InstanceRegistry: // scalafix:off DisableSyntax.var @Inject private var redis: ReactiveRedisDataSource = uninitialized - private var redisPrefix = "nowchess" + + @Inject + private var syncRedis: RedisDataSource = uninitialized + private var redisPrefix = "nowchess" @Inject private var meterRegistry: MeterRegistry = uninitialized @@ -42,6 +46,21 @@ class InstanceRegistry: def setRedisPrefix(prefix: String): Unit = redisPrefix = prefix + def loadAllFromRedis(): Unit = + val keys = syncRedis.key(classOf[String]).keys(s"$redisPrefix:instances:*") + keys.asScala.foreach { key => + val instanceId = key.stripPrefix(s"$redisPrefix:instances:") + val json = syncRedis.value(classOf[String]).get(key) + if json != null then + try + val metadata = mapper.readValue(json, classOf[InstanceMetadata]) + instances.put(instanceId, metadata) + log.infof("Startup: loaded instance %s from Redis", instanceId) + catch + case ex: Exception => + log.warnf(ex, "Startup: failed to parse instance %s", instanceId) + } + def getInstance(instanceId: String): Option[InstanceMetadata] = Option(instances.get(instanceId))