diff --git a/modules/coordinator/src/main/proto/coordinator_service.proto b/modules/coordinator/src/main/proto/coordinator_service.proto index 2f49119..8c00215 100644 --- a/modules/coordinator/src/main/proto/coordinator_service.proto +++ b/modules/coordinator/src/main/proto/coordinator_service.proto @@ -51,7 +51,9 @@ message EvictGamesResponse { int32 evictedCount = 1; } -message DrainInstanceRequest {} +message DrainInstanceRequest { + string instanceId = 1; +} message DrainInstanceResponse { int32 gamesMigrated = 1; diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala index 37777ea..a781c01 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoordinatorGrpcServer.scala @@ -90,10 +90,10 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp request: DrainInstanceRequest, responseObserver: StreamObserver[DrainInstanceResponse], ): Unit = - log.info("Drain instance request") - val response = DrainInstanceResponse - .newBuilder() - .setGamesMigrated(0) - .build() + val instanceId = request.getInstanceId + log.infof("Drain request for instance %s", instanceId) + val gamesBefore = instanceRegistry.getInstance(instanceId).map(_.subscriptionCount).getOrElse(0) + failoverService.onInstanceStreamDropped(instanceId) + val response = DrainInstanceResponse.newBuilder().setGamesMigrated(gamesBefore).build() responseObserver.onNext(response) responseObserver.onCompleted() diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoreGrpcClient.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoreGrpcClient.scala index d5f3618..88489ab 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoreGrpcClient.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/grpc/CoreGrpcClient.scala @@ -1,82 +1,65 @@ package de.nowchess.coordinator.grpc import jakarta.enterprise.context.ApplicationScoped +import jakarta.annotation.PreDestroy import org.jboss.logging.Logger import io.grpc.ManagedChannel import io.grpc.ManagedChannelBuilder import de.nowchess.coordinator.proto.{CoordinatorServiceGrpc, *} import scala.jdk.CollectionConverters.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.TimeUnit @ApplicationScoped class CoreGrpcClient: - private val log = Logger.getLogger(classOf[CoreGrpcClient]) + private val log = Logger.getLogger(classOf[CoreGrpcClient]) + private val channels = ConcurrentHashMap[String, ManagedChannel]() + + private def getChannel(host: String, port: Int): ManagedChannel = + channels.computeIfAbsent(s"$host:$port", _ => + ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(), + ) + + private def evictStaleChannel(host: String, port: Int): Unit = + Option(channels.remove(s"$host:$port")).foreach(_.shutdownNow()) + + @PreDestroy + def shutdown(): Unit = + channels.values.asScala.foreach { ch => + ch.shutdown() + if !ch.awaitTermination(5, TimeUnit.SECONDS) then ch.shutdownNow() + } + channels.clear() def batchResubscribeGames(host: String, port: Int, gameIds: List[String]): Int = - val channel = createChannel(host, port) try - val stub = CoordinatorServiceGrpc.newStub(channel) - val request = BatchResubscribeRequest - .newBuilder() - .addAllGameIds(gameIds.asJava) - .build() - - val latch = new java.util.concurrent.CountDownLatch(1) - var result = 0 - - stub.batchResubscribeGames( - request, - new io.grpc.stub.StreamObserver[BatchResubscribeResponse]: - override def onNext(response: BatchResubscribeResponse): Unit = - result = response.getSubscribedCount - - override def onError(t: Throwable): Unit = - log.warnf(t, "batchResubscribeGames RPC failed for %s:%d", host, port) - latch.countDown() - - override def onCompleted(): Unit = - latch.countDown(), - ) - - latch.await() - result - finally channel.shutdown() + val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port)) + val request = BatchResubscribeRequest.newBuilder().addAllGameIds(gameIds.asJava).build() + stub.batchResubscribeGames(request).getSubscribedCount + catch + case ex: Exception => + log.warnf(ex, "batchResubscribeGames RPC failed for %s:%d", host, port) + evictStaleChannel(host, port) + 0 def unsubscribeGames(host: String, port: Int, gameIds: List[String]): Int = - val channel = createChannel(host, port) try - val stub = CoordinatorServiceGrpc.newBlockingStub(channel) - val request = UnsubscribeGamesRequest - .newBuilder() - .addAllGameIds(gameIds.asJava) - .build() - - val response = stub.unsubscribeGames(request) - response.getUnsubscribedCount + val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port)) + val request = UnsubscribeGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build() + stub.unsubscribeGames(request).getUnsubscribedCount catch case ex: Exception => log.warnf(ex, "unsubscribeGames RPC failed for %s:%d", host, port) + evictStaleChannel(host, port) 0 - finally channel.shutdown() def evictGames(host: String, port: Int, gameIds: List[String]): Int = - val channel = createChannel(host, port) try - val stub = CoordinatorServiceGrpc.newBlockingStub(channel) - val request = EvictGamesRequest - .newBuilder() - .addAllGameIds(gameIds.asJava) - .build() - - val response = stub.evictGames(request) - response.getEvictedCount + val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port)) + val request = EvictGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build() + stub.evictGames(request).getEvictedCount catch case ex: Exception => log.warnf(ex, "evictGames RPC failed for %s:%d", host, port) + evictStaleChannel(host, port) 0 - finally channel.shutdown() - - private def createChannel(host: String, port: Int): ManagedChannel = - ManagedChannelBuilder - .forAddress(host, port) - .usePlaintext() - .build() diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala index 3725d9e..5735459 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/resource/CoordinatorResource.scala @@ -78,6 +78,7 @@ class CoordinatorResource: @Produces(Array(MediaType.APPLICATION_JSON)) def triggerScaleUp: scala.collection.Map[String, String] = log.info("Manual scale up triggered") + autoScaler.scaleUp() Map("status" -> "scale_up_started") @POST @@ -85,6 +86,7 @@ class CoordinatorResource: @Produces(Array(MediaType.APPLICATION_JSON)) def triggerScaleDown: scala.collection.Map[String, String] = log.info("Manual scale down triggered") + autoScaler.scaleDown() Map("status" -> "scale_down_started") case class MetricsDto( diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/AutoScaler.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/AutoScaler.scala index 9db7f10..4d0010e 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/AutoScaler.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/AutoScaler.scala @@ -3,8 +3,10 @@ package de.nowchess.coordinator.service import jakarta.enterprise.context.ApplicationScoped import jakarta.inject.Inject import de.nowchess.coordinator.config.CoordinatorConfig +import io.fabric8.kubernetes.api.model.GenericKubernetesResource import io.fabric8.kubernetes.client.KubernetesClient import org.jboss.logging.Logger + import scala.compiletime.uninitialized @ApplicationScoped @@ -19,30 +21,26 @@ class AutoScaler: private var instanceRegistry: InstanceRegistry = uninitialized private val log = Logger.getLogger(classOf[AutoScaler]) - private var lastScaleTime = 0L + private val lastScaleTime = new java.util.concurrent.atomic.AtomicLong(0L) def checkAndScale: Unit = if !config.autoScaleEnabled then return - val now = System.currentTimeMillis() - if now - lastScaleTime < 120000 then // 2 minute backoff - return + val now = System.currentTimeMillis() + val last = lastScaleTime.get() + if now - last < 120000 then return + if !lastScaleTime.compareAndSet(last, now) then return val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY") if instances.isEmpty then return - val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size - val maxCapacity = config.maxGamesPerCore * instances.size + val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size - if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then - scaleUp() - lastScaleTime = now + if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then scaleUp() else if avgLoad < config.scaleDownThreshold * config.maxGamesPerCore && instances.size > config.scaleMinReplicas - then - scaleDown() - lastScaleTime = now + then scaleDown() - private def scaleUp(): Unit = + def scaleUp(): Unit = log.info("Scaling up Argo Rollout") if kubeClient == null then log.warn("Kubernetes client not available, cannot scale") @@ -50,7 +48,7 @@ class AutoScaler: try val rollout = kubeClient - .resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) + .resources(classOf[GenericKubernetesResource]) .inNamespace(config.k8sNamespace) .withName(config.k8sRolloutName) .get() @@ -63,7 +61,7 @@ class AutoScaler: if currentReplicas < maxReplicas then spec.put("replicas", currentReplicas + 1) kubeClient - .resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) + .resources(classOf[GenericKubernetesResource]) .inNamespace(config.k8sNamespace) .withName(config.k8sRolloutName) .createOrReplace(rollout) @@ -73,7 +71,7 @@ class AutoScaler: case ex: Exception => log.warnf(ex, "Failed to scale up %s", config.k8sRolloutName) - private def scaleDown(): Unit = + def scaleDown(): Unit = log.info("Scaling down Argo Rollout") if kubeClient == null then log.warn("Kubernetes client not available, cannot scale") @@ -81,7 +79,7 @@ class AutoScaler: try val rollout = kubeClient - .resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) + .resources(classOf[GenericKubernetesResource]) .inNamespace(config.k8sNamespace) .withName(config.k8sRolloutName) .get() @@ -94,7 +92,7 @@ class AutoScaler: if currentReplicas > minReplicas then spec.put("replicas", currentReplicas - 1) kubeClient - .resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) + .resources(classOf[GenericKubernetesResource]) .inNamespace(config.k8sNamespace) .withName(config.k8sRolloutName) .createOrReplace(rollout) diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala index f3bf3fb..40f724d 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/FailoverService.scala @@ -4,9 +4,7 @@ import jakarta.enterprise.context.ApplicationScoped import jakarta.inject.Inject import org.redisson.api.RedissonClient import scala.jdk.CollectionConverters.* -import scala.concurrent.duration.* import scala.compiletime.uninitialized -import java.time.Instant import org.jboss.logging.Logger import de.nowchess.coordinator.dto.InstanceMetadata import de.nowchess.coordinator.grpc.CoreGrpcClient @@ -67,17 +65,25 @@ class FailoverService: val batches = gameIds.grouped(batchSize).toList batches.zipWithIndex.foreach { case (batch, idx) => - val targetInstance = healthyInstances(idx % healthyInstances.size) - try - val subscribed = coreGrpcClient.batchResubscribeGames( - targetInstance.hostname, - targetInstance.grpcPort, - batch, + var migrated = false + var attempt = 0 + while !migrated && attempt < healthyInstances.size do + val target = healthyInstances((idx + attempt) % healthyInstances.size) + attempt += 1 + try + val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch) + if subscribed > 0 then + log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, target.instanceId) + migrated = true + catch + case ex: Exception => + log.warnf(ex, "Failed to migrate batch to %s, trying next", target.instanceId) + if !migrated then + log.errorf( + "Failed to migrate batch of %d games from %s to any healthy instance", + batch.size, + deadInstanceId, ) - log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, targetInstance.instanceId) - catch - case ex: Exception => - log.warnf(ex, "Failed to migrate batch to %s, will retry", targetInstance.instanceId) } private def cleanupDeadInstance(instanceId: String): Unit = diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala index 8fe1be7..dd2798f 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/InstanceRegistry.scala @@ -4,11 +4,10 @@ import jakarta.enterprise.context.ApplicationScoped import jakarta.inject.Inject import org.redisson.api.RedissonClient import scala.jdk.CollectionConverters.* -import scala.collection.mutable import scala.compiletime.uninitialized import com.fasterxml.jackson.databind.ObjectMapper import de.nowchess.coordinator.dto.InstanceMetadata -import java.time.Instant +import java.util.concurrent.ConcurrentHashMap @ApplicationScoped class InstanceRegistry: @@ -16,29 +15,17 @@ class InstanceRegistry: private var redissonClient: RedissonClient = uninitialized private val mapper = ObjectMapper() - private val instances = mutable.Map[String, InstanceMetadata]() + private val instances = ConcurrentHashMap[String, InstanceMetadata]() private var redisPrefix = "nowchess" def setRedisPrefix(prefix: String): Unit = redisPrefix = prefix def getInstance(instanceId: String): Option[InstanceMetadata] = - instances.get(instanceId) + Option(instances.get(instanceId)) def getAllInstances: List[InstanceMetadata] = - instances.values.toList - - def listInstancesFromRedis: List[InstanceMetadata] = - val pattern = s"$redisPrefix:instances:*" - val keys = redissonClient.getKeys.getKeysByPattern(pattern, 100) - keys.asScala.flatMap { key => - val bucket = redissonClient.getBucket[String](key) - val value = bucket.getAndDelete() - if value != null then - try Some(mapper.readValue(value, classOf[InstanceMetadata])) - catch case _: Exception => None - else None - }.toList + instances.values.asScala.toList def updateInstanceFromRedis(instanceId: String): Unit = val key = s"$redisPrefix:instances:$instanceId" @@ -47,14 +34,13 @@ class InstanceRegistry: if value != null then try val metadata = mapper.readValue(value, classOf[InstanceMetadata]) - instances(instanceId) = metadata + instances.put(instanceId, metadata) catch case _: Exception => () def markInstanceDead(instanceId: String): Unit = - instances.get(instanceId).foreach { inst => - val dead = inst.copy(state = "DEAD") - instances(instanceId) = dead - } + instances.computeIfPresent(instanceId, (_, inst) => inst.copy(state = "DEAD")) + () def removeInstance(instanceId: String): Unit = instances.remove(instanceId) + () diff --git a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/LoadBalancer.scala b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/LoadBalancer.scala index 98b76cd..28e8b15 100644 --- a/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/LoadBalancer.scala +++ b/modules/coordinator/src/main/scala/de/nowchess/coordinator/service/LoadBalancer.scala @@ -25,7 +25,7 @@ class LoadBalancer: private var coreGrpcClient: CoreGrpcClient = uninitialized private val log = Logger.getLogger(classOf[LoadBalancer]) - private var lastRebalanceTime = 0L + private val lastRebalanceTime = new java.util.concurrent.atomic.AtomicLong(0L) private var redisPrefix = "nowchess" def setRedisPrefix(prefix: String): Unit = @@ -34,7 +34,7 @@ class LoadBalancer: def shouldRebalance: Boolean = val now = System.currentTimeMillis() val minInterval = config.rebalanceMinInterval.toMillis - if now - lastRebalanceTime < minInterval then return false + if now - lastRebalanceTime.get() < minInterval then return false val instances = instanceRegistry.getAllInstances if instances.isEmpty then return false @@ -54,11 +54,10 @@ class LoadBalancer: def rebalance: Unit = log.info("Starting rebalance") val startTime = System.currentTimeMillis() - lastRebalanceTime = startTime + lastRebalanceTime.set(startTime) try - val instances = instanceRegistry.getAllInstances - .filter(_.state == "HEALTHY") + val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY") if instances.size < 2 then log.info("Not enough healthy instances for rebalance") @@ -75,23 +74,29 @@ class LoadBalancer: .filter(_.subscriptionCount < avgLoad * 0.8) .sortBy(_.subscriptionCount) - overloaded.foreach { over => - val excess = over.subscriptionCount - avgLoad.toInt - if excess > 0 && underloaded.nonEmpty then - val gamesToMove = getGamesToMove(over.instanceId, excess) - if gamesToMove.nonEmpty then - underloaded.headOption.foreach { under => - try - val unsubscribed = coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, gamesToMove) - val subscribed = coreGrpcClient.batchResubscribeGames(under.hostname, under.grpcPort, gamesToMove) + if underloaded.isEmpty then + log.info("No underloaded instances available for rebalance") + return - if subscribed > 0 then - updateRedisGameSets(over.instanceId, under.instanceId, gamesToMove) - log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, under.instanceId) - catch - case ex: Exception => - log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, under.instanceId) - } + var targetIdx = 0 + overloaded.foreach { over => + val excess = math.max(0, over.subscriptionCount - avgLoad.toInt) + val gamesToMove = getGamesToMove(over.instanceId, excess) + if gamesToMove.nonEmpty then + val batchSize = math.max(1, (gamesToMove.size + underloaded.size - 1) / underloaded.size) + gamesToMove.grouped(batchSize).foreach { batch => + val target = underloaded(targetIdx % underloaded.size) + targetIdx += 1 + try + coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, batch) + val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch) + if subscribed > 0 then + updateRedisGameSets(over.instanceId, target.instanceId, batch) + log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, target.instanceId) + catch + case ex: Exception => + log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, target.instanceId) + } } val elapsed = System.currentTimeMillis() - startTime