feat(auto-scaling): enhance AutoScaler with atomic lastScaleTime and improve scaling logic

This commit is contained in:
2026-04-26 19:57:36 +02:00
parent 106b4d3b7e
commit 57e6e5d200
8 changed files with 115 additions and 133 deletions
@@ -51,7 +51,9 @@ message EvictGamesResponse {
int32 evictedCount = 1; int32 evictedCount = 1;
} }
message DrainInstanceRequest {} message DrainInstanceRequest {
string instanceId = 1;
}
message DrainInstanceResponse { message DrainInstanceResponse {
int32 gamesMigrated = 1; int32 gamesMigrated = 1;
@@ -90,10 +90,10 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp
request: DrainInstanceRequest, request: DrainInstanceRequest,
responseObserver: StreamObserver[DrainInstanceResponse], responseObserver: StreamObserver[DrainInstanceResponse],
): Unit = ): Unit =
log.info("Drain instance request") val instanceId = request.getInstanceId
val response = DrainInstanceResponse log.infof("Drain request for instance %s", instanceId)
.newBuilder() val gamesBefore = instanceRegistry.getInstance(instanceId).map(_.subscriptionCount).getOrElse(0)
.setGamesMigrated(0) failoverService.onInstanceStreamDropped(instanceId)
.build() val response = DrainInstanceResponse.newBuilder().setGamesMigrated(gamesBefore).build()
responseObserver.onNext(response) responseObserver.onNext(response)
responseObserver.onCompleted() responseObserver.onCompleted()
@@ -1,82 +1,65 @@
package de.nowchess.coordinator.grpc package de.nowchess.coordinator.grpc
import jakarta.enterprise.context.ApplicationScoped import jakarta.enterprise.context.ApplicationScoped
import jakarta.annotation.PreDestroy
import org.jboss.logging.Logger import org.jboss.logging.Logger
import io.grpc.ManagedChannel import io.grpc.ManagedChannel
import io.grpc.ManagedChannelBuilder import io.grpc.ManagedChannelBuilder
import de.nowchess.coordinator.proto.{CoordinatorServiceGrpc, *} import de.nowchess.coordinator.proto.{CoordinatorServiceGrpc, *}
import scala.jdk.CollectionConverters.* import scala.jdk.CollectionConverters.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
@ApplicationScoped @ApplicationScoped
class CoreGrpcClient: class CoreGrpcClient:
private val log = Logger.getLogger(classOf[CoreGrpcClient]) private val log = Logger.getLogger(classOf[CoreGrpcClient])
private val channels = ConcurrentHashMap[String, ManagedChannel]()
def batchResubscribeGames(host: String, port: Int, gameIds: List[String]): Int = private def getChannel(host: String, port: Int): ManagedChannel =
val channel = createChannel(host, port) channels.computeIfAbsent(s"$host:$port", _ =>
try ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(),
val stub = CoordinatorServiceGrpc.newStub(channel)
val request = BatchResubscribeRequest
.newBuilder()
.addAllGameIds(gameIds.asJava)
.build()
val latch = new java.util.concurrent.CountDownLatch(1)
var result = 0
stub.batchResubscribeGames(
request,
new io.grpc.stub.StreamObserver[BatchResubscribeResponse]:
override def onNext(response: BatchResubscribeResponse): Unit =
result = response.getSubscribedCount
override def onError(t: Throwable): Unit =
log.warnf(t, "batchResubscribeGames RPC failed for %s:%d", host, port)
latch.countDown()
override def onCompleted(): Unit =
latch.countDown(),
) )
latch.await() private def evictStaleChannel(host: String, port: Int): Unit =
result Option(channels.remove(s"$host:$port")).foreach(_.shutdownNow())
finally channel.shutdown()
@PreDestroy
def shutdown(): Unit =
channels.values.asScala.foreach { ch =>
ch.shutdown()
if !ch.awaitTermination(5, TimeUnit.SECONDS) then ch.shutdownNow()
}
channels.clear()
def batchResubscribeGames(host: String, port: Int, gameIds: List[String]): Int =
try
val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
val request = BatchResubscribeRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
stub.batchResubscribeGames(request).getSubscribedCount
catch
case ex: Exception =>
log.warnf(ex, "batchResubscribeGames RPC failed for %s:%d", host, port)
evictStaleChannel(host, port)
0
def unsubscribeGames(host: String, port: Int, gameIds: List[String]): Int = def unsubscribeGames(host: String, port: Int, gameIds: List[String]): Int =
val channel = createChannel(host, port)
try try
val stub = CoordinatorServiceGrpc.newBlockingStub(channel) val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
val request = UnsubscribeGamesRequest val request = UnsubscribeGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
.newBuilder() stub.unsubscribeGames(request).getUnsubscribedCount
.addAllGameIds(gameIds.asJava)
.build()
val response = stub.unsubscribeGames(request)
response.getUnsubscribedCount
catch catch
case ex: Exception => case ex: Exception =>
log.warnf(ex, "unsubscribeGames RPC failed for %s:%d", host, port) log.warnf(ex, "unsubscribeGames RPC failed for %s:%d", host, port)
evictStaleChannel(host, port)
0 0
finally channel.shutdown()
def evictGames(host: String, port: Int, gameIds: List[String]): Int = def evictGames(host: String, port: Int, gameIds: List[String]): Int =
val channel = createChannel(host, port)
try try
val stub = CoordinatorServiceGrpc.newBlockingStub(channel) val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
val request = EvictGamesRequest val request = EvictGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
.newBuilder() stub.evictGames(request).getEvictedCount
.addAllGameIds(gameIds.asJava)
.build()
val response = stub.evictGames(request)
response.getEvictedCount
catch catch
case ex: Exception => case ex: Exception =>
log.warnf(ex, "evictGames RPC failed for %s:%d", host, port) log.warnf(ex, "evictGames RPC failed for %s:%d", host, port)
evictStaleChannel(host, port)
0 0
finally channel.shutdown()
private def createChannel(host: String, port: Int): ManagedChannel =
ManagedChannelBuilder
.forAddress(host, port)
.usePlaintext()
.build()
@@ -78,6 +78,7 @@ class CoordinatorResource:
@Produces(Array(MediaType.APPLICATION_JSON)) @Produces(Array(MediaType.APPLICATION_JSON))
def triggerScaleUp: scala.collection.Map[String, String] = def triggerScaleUp: scala.collection.Map[String, String] =
log.info("Manual scale up triggered") log.info("Manual scale up triggered")
autoScaler.scaleUp()
Map("status" -> "scale_up_started") Map("status" -> "scale_up_started")
@POST @POST
@@ -85,6 +86,7 @@ class CoordinatorResource:
@Produces(Array(MediaType.APPLICATION_JSON)) @Produces(Array(MediaType.APPLICATION_JSON))
def triggerScaleDown: scala.collection.Map[String, String] = def triggerScaleDown: scala.collection.Map[String, String] =
log.info("Manual scale down triggered") log.info("Manual scale down triggered")
autoScaler.scaleDown()
Map("status" -> "scale_down_started") Map("status" -> "scale_down_started")
case class MetricsDto( case class MetricsDto(
@@ -3,8 +3,10 @@ package de.nowchess.coordinator.service
import jakarta.enterprise.context.ApplicationScoped import jakarta.enterprise.context.ApplicationScoped
import jakarta.inject.Inject import jakarta.inject.Inject
import de.nowchess.coordinator.config.CoordinatorConfig import de.nowchess.coordinator.config.CoordinatorConfig
import io.fabric8.kubernetes.api.model.GenericKubernetesResource
import io.fabric8.kubernetes.client.KubernetesClient import io.fabric8.kubernetes.client.KubernetesClient
import org.jboss.logging.Logger import org.jboss.logging.Logger
import scala.compiletime.uninitialized import scala.compiletime.uninitialized
@ApplicationScoped @ApplicationScoped
@@ -19,30 +21,26 @@ class AutoScaler:
private var instanceRegistry: InstanceRegistry = uninitialized private var instanceRegistry: InstanceRegistry = uninitialized
private val log = Logger.getLogger(classOf[AutoScaler]) private val log = Logger.getLogger(classOf[AutoScaler])
private var lastScaleTime = 0L private val lastScaleTime = new java.util.concurrent.atomic.AtomicLong(0L)
def checkAndScale: Unit = def checkAndScale: Unit =
if !config.autoScaleEnabled then return if !config.autoScaleEnabled then return
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
if now - lastScaleTime < 120000 then // 2 minute backoff val last = lastScaleTime.get()
return if now - last < 120000 then return
if !lastScaleTime.compareAndSet(last, now) then return
val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY") val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY")
if instances.isEmpty then return if instances.isEmpty then return
val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size
val maxCapacity = config.maxGamesPerCore * instances.size
if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then scaleUp()
scaleUp()
lastScaleTime = now
else if avgLoad < config.scaleDownThreshold * config.maxGamesPerCore && instances.size > config.scaleMinReplicas else if avgLoad < config.scaleDownThreshold * config.maxGamesPerCore && instances.size > config.scaleMinReplicas
then then scaleDown()
scaleDown()
lastScaleTime = now
private def scaleUp(): Unit = def scaleUp(): Unit =
log.info("Scaling up Argo Rollout") log.info("Scaling up Argo Rollout")
if kubeClient == null then if kubeClient == null then
log.warn("Kubernetes client not available, cannot scale") log.warn("Kubernetes client not available, cannot scale")
@@ -50,7 +48,7 @@ class AutoScaler:
try try
val rollout = kubeClient val rollout = kubeClient
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) .resources(classOf[GenericKubernetesResource])
.inNamespace(config.k8sNamespace) .inNamespace(config.k8sNamespace)
.withName(config.k8sRolloutName) .withName(config.k8sRolloutName)
.get() .get()
@@ -63,7 +61,7 @@ class AutoScaler:
if currentReplicas < maxReplicas then if currentReplicas < maxReplicas then
spec.put("replicas", currentReplicas + 1) spec.put("replicas", currentReplicas + 1)
kubeClient kubeClient
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) .resources(classOf[GenericKubernetesResource])
.inNamespace(config.k8sNamespace) .inNamespace(config.k8sNamespace)
.withName(config.k8sRolloutName) .withName(config.k8sRolloutName)
.createOrReplace(rollout) .createOrReplace(rollout)
@@ -73,7 +71,7 @@ class AutoScaler:
case ex: Exception => case ex: Exception =>
log.warnf(ex, "Failed to scale up %s", config.k8sRolloutName) log.warnf(ex, "Failed to scale up %s", config.k8sRolloutName)
private def scaleDown(): Unit = def scaleDown(): Unit =
log.info("Scaling down Argo Rollout") log.info("Scaling down Argo Rollout")
if kubeClient == null then if kubeClient == null then
log.warn("Kubernetes client not available, cannot scale") log.warn("Kubernetes client not available, cannot scale")
@@ -81,7 +79,7 @@ class AutoScaler:
try try
val rollout = kubeClient val rollout = kubeClient
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) .resources(classOf[GenericKubernetesResource])
.inNamespace(config.k8sNamespace) .inNamespace(config.k8sNamespace)
.withName(config.k8sRolloutName) .withName(config.k8sRolloutName)
.get() .get()
@@ -94,7 +92,7 @@ class AutoScaler:
if currentReplicas > minReplicas then if currentReplicas > minReplicas then
spec.put("replicas", currentReplicas - 1) spec.put("replicas", currentReplicas - 1)
kubeClient kubeClient
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource]) .resources(classOf[GenericKubernetesResource])
.inNamespace(config.k8sNamespace) .inNamespace(config.k8sNamespace)
.withName(config.k8sRolloutName) .withName(config.k8sRolloutName)
.createOrReplace(rollout) .createOrReplace(rollout)
@@ -4,9 +4,7 @@ import jakarta.enterprise.context.ApplicationScoped
import jakarta.inject.Inject import jakarta.inject.Inject
import org.redisson.api.RedissonClient import org.redisson.api.RedissonClient
import scala.jdk.CollectionConverters.* import scala.jdk.CollectionConverters.*
import scala.concurrent.duration.*
import scala.compiletime.uninitialized import scala.compiletime.uninitialized
import java.time.Instant
import org.jboss.logging.Logger import org.jboss.logging.Logger
import de.nowchess.coordinator.dto.InstanceMetadata import de.nowchess.coordinator.dto.InstanceMetadata
import de.nowchess.coordinator.grpc.CoreGrpcClient import de.nowchess.coordinator.grpc.CoreGrpcClient
@@ -67,17 +65,25 @@ class FailoverService:
val batches = gameIds.grouped(batchSize).toList val batches = gameIds.grouped(batchSize).toList
batches.zipWithIndex.foreach { case (batch, idx) => batches.zipWithIndex.foreach { case (batch, idx) =>
val targetInstance = healthyInstances(idx % healthyInstances.size) var migrated = false
var attempt = 0
while !migrated && attempt < healthyInstances.size do
val target = healthyInstances((idx + attempt) % healthyInstances.size)
attempt += 1
try try
val subscribed = coreGrpcClient.batchResubscribeGames( val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch)
targetInstance.hostname, if subscribed > 0 then
targetInstance.grpcPort, log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, target.instanceId)
batch, migrated = true
)
log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, targetInstance.instanceId)
catch catch
case ex: Exception => case ex: Exception =>
log.warnf(ex, "Failed to migrate batch to %s, will retry", targetInstance.instanceId) log.warnf(ex, "Failed to migrate batch to %s, trying next", target.instanceId)
if !migrated then
log.errorf(
"Failed to migrate batch of %d games from %s to any healthy instance",
batch.size,
deadInstanceId,
)
} }
private def cleanupDeadInstance(instanceId: String): Unit = private def cleanupDeadInstance(instanceId: String): Unit =
@@ -4,11 +4,10 @@ import jakarta.enterprise.context.ApplicationScoped
import jakarta.inject.Inject import jakarta.inject.Inject
import org.redisson.api.RedissonClient import org.redisson.api.RedissonClient
import scala.jdk.CollectionConverters.* import scala.jdk.CollectionConverters.*
import scala.collection.mutable
import scala.compiletime.uninitialized import scala.compiletime.uninitialized
import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.ObjectMapper
import de.nowchess.coordinator.dto.InstanceMetadata import de.nowchess.coordinator.dto.InstanceMetadata
import java.time.Instant import java.util.concurrent.ConcurrentHashMap
@ApplicationScoped @ApplicationScoped
class InstanceRegistry: class InstanceRegistry:
@@ -16,29 +15,17 @@ class InstanceRegistry:
private var redissonClient: RedissonClient = uninitialized private var redissonClient: RedissonClient = uninitialized
private val mapper = ObjectMapper() private val mapper = ObjectMapper()
private val instances = mutable.Map[String, InstanceMetadata]() private val instances = ConcurrentHashMap[String, InstanceMetadata]()
private var redisPrefix = "nowchess" private var redisPrefix = "nowchess"
def setRedisPrefix(prefix: String): Unit = def setRedisPrefix(prefix: String): Unit =
redisPrefix = prefix redisPrefix = prefix
def getInstance(instanceId: String): Option[InstanceMetadata] = def getInstance(instanceId: String): Option[InstanceMetadata] =
instances.get(instanceId) Option(instances.get(instanceId))
def getAllInstances: List[InstanceMetadata] = def getAllInstances: List[InstanceMetadata] =
instances.values.toList instances.values.asScala.toList
def listInstancesFromRedis: List[InstanceMetadata] =
val pattern = s"$redisPrefix:instances:*"
val keys = redissonClient.getKeys.getKeysByPattern(pattern, 100)
keys.asScala.flatMap { key =>
val bucket = redissonClient.getBucket[String](key)
val value = bucket.getAndDelete()
if value != null then
try Some(mapper.readValue(value, classOf[InstanceMetadata]))
catch case _: Exception => None
else None
}.toList
def updateInstanceFromRedis(instanceId: String): Unit = def updateInstanceFromRedis(instanceId: String): Unit =
val key = s"$redisPrefix:instances:$instanceId" val key = s"$redisPrefix:instances:$instanceId"
@@ -47,14 +34,13 @@ class InstanceRegistry:
if value != null then if value != null then
try try
val metadata = mapper.readValue(value, classOf[InstanceMetadata]) val metadata = mapper.readValue(value, classOf[InstanceMetadata])
instances(instanceId) = metadata instances.put(instanceId, metadata)
catch case _: Exception => () catch case _: Exception => ()
def markInstanceDead(instanceId: String): Unit = def markInstanceDead(instanceId: String): Unit =
instances.get(instanceId).foreach { inst => instances.computeIfPresent(instanceId, (_, inst) => inst.copy(state = "DEAD"))
val dead = inst.copy(state = "DEAD") ()
instances(instanceId) = dead
}
def removeInstance(instanceId: String): Unit = def removeInstance(instanceId: String): Unit =
instances.remove(instanceId) instances.remove(instanceId)
()
@@ -25,7 +25,7 @@ class LoadBalancer:
private var coreGrpcClient: CoreGrpcClient = uninitialized private var coreGrpcClient: CoreGrpcClient = uninitialized
private val log = Logger.getLogger(classOf[LoadBalancer]) private val log = Logger.getLogger(classOf[LoadBalancer])
private var lastRebalanceTime = 0L private val lastRebalanceTime = new java.util.concurrent.atomic.AtomicLong(0L)
private var redisPrefix = "nowchess" private var redisPrefix = "nowchess"
def setRedisPrefix(prefix: String): Unit = def setRedisPrefix(prefix: String): Unit =
@@ -34,7 +34,7 @@ class LoadBalancer:
def shouldRebalance: Boolean = def shouldRebalance: Boolean =
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val minInterval = config.rebalanceMinInterval.toMillis val minInterval = config.rebalanceMinInterval.toMillis
if now - lastRebalanceTime < minInterval then return false if now - lastRebalanceTime.get() < minInterval then return false
val instances = instanceRegistry.getAllInstances val instances = instanceRegistry.getAllInstances
if instances.isEmpty then return false if instances.isEmpty then return false
@@ -54,11 +54,10 @@ class LoadBalancer:
def rebalance: Unit = def rebalance: Unit =
log.info("Starting rebalance") log.info("Starting rebalance")
val startTime = System.currentTimeMillis() val startTime = System.currentTimeMillis()
lastRebalanceTime = startTime lastRebalanceTime.set(startTime)
try try
val instances = instanceRegistry.getAllInstances val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY")
.filter(_.state == "HEALTHY")
if instances.size < 2 then if instances.size < 2 then
log.info("Not enough healthy instances for rebalance") log.info("Not enough healthy instances for rebalance")
@@ -75,22 +74,28 @@ class LoadBalancer:
.filter(_.subscriptionCount < avgLoad * 0.8) .filter(_.subscriptionCount < avgLoad * 0.8)
.sortBy(_.subscriptionCount) .sortBy(_.subscriptionCount)
if underloaded.isEmpty then
log.info("No underloaded instances available for rebalance")
return
var targetIdx = 0
overloaded.foreach { over => overloaded.foreach { over =>
val excess = over.subscriptionCount - avgLoad.toInt val excess = math.max(0, over.subscriptionCount - avgLoad.toInt)
if excess > 0 && underloaded.nonEmpty then
val gamesToMove = getGamesToMove(over.instanceId, excess) val gamesToMove = getGamesToMove(over.instanceId, excess)
if gamesToMove.nonEmpty then if gamesToMove.nonEmpty then
underloaded.headOption.foreach { under => val batchSize = math.max(1, (gamesToMove.size + underloaded.size - 1) / underloaded.size)
gamesToMove.grouped(batchSize).foreach { batch =>
val target = underloaded(targetIdx % underloaded.size)
targetIdx += 1
try try
val unsubscribed = coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, gamesToMove) coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, batch)
val subscribed = coreGrpcClient.batchResubscribeGames(under.hostname, under.grpcPort, gamesToMove) val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch)
if subscribed > 0 then if subscribed > 0 then
updateRedisGameSets(over.instanceId, under.instanceId, gamesToMove) updateRedisGameSets(over.instanceId, target.instanceId, batch)
log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, under.instanceId) log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, target.instanceId)
catch catch
case ex: Exception => case ex: Exception =>
log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, under.instanceId) log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, target.instanceId)
} }
} }