feat(auto-scaling): enhance AutoScaler with atomic lastScaleTime and improve scaling logic
This commit is contained in:
@@ -51,7 +51,9 @@ message EvictGamesResponse {
|
||||
int32 evictedCount = 1;
|
||||
}
|
||||
|
||||
message DrainInstanceRequest {}
|
||||
message DrainInstanceRequest {
|
||||
string instanceId = 1;
|
||||
}
|
||||
|
||||
message DrainInstanceResponse {
|
||||
int32 gamesMigrated = 1;
|
||||
|
||||
+5
-5
@@ -90,10 +90,10 @@ class CoordinatorGrpcServer extends CoordinatorServiceGrpc.CoordinatorServiceImp
|
||||
request: DrainInstanceRequest,
|
||||
responseObserver: StreamObserver[DrainInstanceResponse],
|
||||
): Unit =
|
||||
log.info("Drain instance request")
|
||||
val response = DrainInstanceResponse
|
||||
.newBuilder()
|
||||
.setGamesMigrated(0)
|
||||
.build()
|
||||
val instanceId = request.getInstanceId
|
||||
log.infof("Drain request for instance %s", instanceId)
|
||||
val gamesBefore = instanceRegistry.getInstance(instanceId).map(_.subscriptionCount).getOrElse(0)
|
||||
failoverService.onInstanceStreamDropped(instanceId)
|
||||
val response = DrainInstanceResponse.newBuilder().setGamesMigrated(gamesBefore).build()
|
||||
responseObserver.onNext(response)
|
||||
responseObserver.onCompleted()
|
||||
|
||||
+37
-54
@@ -1,82 +1,65 @@
|
||||
package de.nowchess.coordinator.grpc
|
||||
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.annotation.PreDestroy
|
||||
import org.jboss.logging.Logger
|
||||
import io.grpc.ManagedChannel
|
||||
import io.grpc.ManagedChannelBuilder
|
||||
import de.nowchess.coordinator.proto.{CoordinatorServiceGrpc, *}
|
||||
import scala.jdk.CollectionConverters.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
@ApplicationScoped
|
||||
class CoreGrpcClient:
|
||||
private val log = Logger.getLogger(classOf[CoreGrpcClient])
|
||||
private val log = Logger.getLogger(classOf[CoreGrpcClient])
|
||||
private val channels = ConcurrentHashMap[String, ManagedChannel]()
|
||||
|
||||
private def getChannel(host: String, port: Int): ManagedChannel =
|
||||
channels.computeIfAbsent(s"$host:$port", _ =>
|
||||
ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(),
|
||||
)
|
||||
|
||||
private def evictStaleChannel(host: String, port: Int): Unit =
|
||||
Option(channels.remove(s"$host:$port")).foreach(_.shutdownNow())
|
||||
|
||||
@PreDestroy
|
||||
def shutdown(): Unit =
|
||||
channels.values.asScala.foreach { ch =>
|
||||
ch.shutdown()
|
||||
if !ch.awaitTermination(5, TimeUnit.SECONDS) then ch.shutdownNow()
|
||||
}
|
||||
channels.clear()
|
||||
|
||||
def batchResubscribeGames(host: String, port: Int, gameIds: List[String]): Int =
|
||||
val channel = createChannel(host, port)
|
||||
try
|
||||
val stub = CoordinatorServiceGrpc.newStub(channel)
|
||||
val request = BatchResubscribeRequest
|
||||
.newBuilder()
|
||||
.addAllGameIds(gameIds.asJava)
|
||||
.build()
|
||||
|
||||
val latch = new java.util.concurrent.CountDownLatch(1)
|
||||
var result = 0
|
||||
|
||||
stub.batchResubscribeGames(
|
||||
request,
|
||||
new io.grpc.stub.StreamObserver[BatchResubscribeResponse]:
|
||||
override def onNext(response: BatchResubscribeResponse): Unit =
|
||||
result = response.getSubscribedCount
|
||||
|
||||
override def onError(t: Throwable): Unit =
|
||||
log.warnf(t, "batchResubscribeGames RPC failed for %s:%d", host, port)
|
||||
latch.countDown()
|
||||
|
||||
override def onCompleted(): Unit =
|
||||
latch.countDown(),
|
||||
)
|
||||
|
||||
latch.await()
|
||||
result
|
||||
finally channel.shutdown()
|
||||
val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
|
||||
val request = BatchResubscribeRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
|
||||
stub.batchResubscribeGames(request).getSubscribedCount
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "batchResubscribeGames RPC failed for %s:%d", host, port)
|
||||
evictStaleChannel(host, port)
|
||||
0
|
||||
|
||||
def unsubscribeGames(host: String, port: Int, gameIds: List[String]): Int =
|
||||
val channel = createChannel(host, port)
|
||||
try
|
||||
val stub = CoordinatorServiceGrpc.newBlockingStub(channel)
|
||||
val request = UnsubscribeGamesRequest
|
||||
.newBuilder()
|
||||
.addAllGameIds(gameIds.asJava)
|
||||
.build()
|
||||
|
||||
val response = stub.unsubscribeGames(request)
|
||||
response.getUnsubscribedCount
|
||||
val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
|
||||
val request = UnsubscribeGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
|
||||
stub.unsubscribeGames(request).getUnsubscribedCount
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "unsubscribeGames RPC failed for %s:%d", host, port)
|
||||
evictStaleChannel(host, port)
|
||||
0
|
||||
finally channel.shutdown()
|
||||
|
||||
def evictGames(host: String, port: Int, gameIds: List[String]): Int =
|
||||
val channel = createChannel(host, port)
|
||||
try
|
||||
val stub = CoordinatorServiceGrpc.newBlockingStub(channel)
|
||||
val request = EvictGamesRequest
|
||||
.newBuilder()
|
||||
.addAllGameIds(gameIds.asJava)
|
||||
.build()
|
||||
|
||||
val response = stub.evictGames(request)
|
||||
response.getEvictedCount
|
||||
val stub = CoordinatorServiceGrpc.newBlockingStub(getChannel(host, port))
|
||||
val request = EvictGamesRequest.newBuilder().addAllGameIds(gameIds.asJava).build()
|
||||
stub.evictGames(request).getEvictedCount
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "evictGames RPC failed for %s:%d", host, port)
|
||||
evictStaleChannel(host, port)
|
||||
0
|
||||
finally channel.shutdown()
|
||||
|
||||
private def createChannel(host: String, port: Int): ManagedChannel =
|
||||
ManagedChannelBuilder
|
||||
.forAddress(host, port)
|
||||
.usePlaintext()
|
||||
.build()
|
||||
|
||||
+2
@@ -78,6 +78,7 @@ class CoordinatorResource:
|
||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
||||
def triggerScaleUp: scala.collection.Map[String, String] =
|
||||
log.info("Manual scale up triggered")
|
||||
autoScaler.scaleUp()
|
||||
Map("status" -> "scale_up_started")
|
||||
|
||||
@POST
|
||||
@@ -85,6 +86,7 @@ class CoordinatorResource:
|
||||
@Produces(Array(MediaType.APPLICATION_JSON))
|
||||
def triggerScaleDown: scala.collection.Map[String, String] =
|
||||
log.info("Manual scale down triggered")
|
||||
autoScaler.scaleDown()
|
||||
Map("status" -> "scale_down_started")
|
||||
|
||||
case class MetricsDto(
|
||||
|
||||
+16
-18
@@ -3,8 +3,10 @@ package de.nowchess.coordinator.service
|
||||
import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import de.nowchess.coordinator.config.CoordinatorConfig
|
||||
import io.fabric8.kubernetes.api.model.GenericKubernetesResource
|
||||
import io.fabric8.kubernetes.client.KubernetesClient
|
||||
import org.jboss.logging.Logger
|
||||
|
||||
import scala.compiletime.uninitialized
|
||||
|
||||
@ApplicationScoped
|
||||
@@ -19,30 +21,26 @@ class AutoScaler:
|
||||
private var instanceRegistry: InstanceRegistry = uninitialized
|
||||
|
||||
private val log = Logger.getLogger(classOf[AutoScaler])
|
||||
private var lastScaleTime = 0L
|
||||
private val lastScaleTime = new java.util.concurrent.atomic.AtomicLong(0L)
|
||||
|
||||
def checkAndScale: Unit =
|
||||
if !config.autoScaleEnabled then return
|
||||
|
||||
val now = System.currentTimeMillis()
|
||||
if now - lastScaleTime < 120000 then // 2 minute backoff
|
||||
return
|
||||
val now = System.currentTimeMillis()
|
||||
val last = lastScaleTime.get()
|
||||
if now - last < 120000 then return
|
||||
if !lastScaleTime.compareAndSet(last, now) then return
|
||||
|
||||
val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY")
|
||||
if instances.isEmpty then return
|
||||
|
||||
val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size
|
||||
val maxCapacity = config.maxGamesPerCore * instances.size
|
||||
val avgLoad = instances.map(_.subscriptionCount).sum.toDouble / instances.size
|
||||
|
||||
if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then
|
||||
scaleUp()
|
||||
lastScaleTime = now
|
||||
if avgLoad > config.scaleUpThreshold * config.maxGamesPerCore then scaleUp()
|
||||
else if avgLoad < config.scaleDownThreshold * config.maxGamesPerCore && instances.size > config.scaleMinReplicas
|
||||
then
|
||||
scaleDown()
|
||||
lastScaleTime = now
|
||||
then scaleDown()
|
||||
|
||||
private def scaleUp(): Unit =
|
||||
def scaleUp(): Unit =
|
||||
log.info("Scaling up Argo Rollout")
|
||||
if kubeClient == null then
|
||||
log.warn("Kubernetes client not available, cannot scale")
|
||||
@@ -50,7 +48,7 @@ class AutoScaler:
|
||||
|
||||
try
|
||||
val rollout = kubeClient
|
||||
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource])
|
||||
.resources(classOf[GenericKubernetesResource])
|
||||
.inNamespace(config.k8sNamespace)
|
||||
.withName(config.k8sRolloutName)
|
||||
.get()
|
||||
@@ -63,7 +61,7 @@ class AutoScaler:
|
||||
if currentReplicas < maxReplicas then
|
||||
spec.put("replicas", currentReplicas + 1)
|
||||
kubeClient
|
||||
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource])
|
||||
.resources(classOf[GenericKubernetesResource])
|
||||
.inNamespace(config.k8sNamespace)
|
||||
.withName(config.k8sRolloutName)
|
||||
.createOrReplace(rollout)
|
||||
@@ -73,7 +71,7 @@ class AutoScaler:
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "Failed to scale up %s", config.k8sRolloutName)
|
||||
|
||||
private def scaleDown(): Unit =
|
||||
def scaleDown(): Unit =
|
||||
log.info("Scaling down Argo Rollout")
|
||||
if kubeClient == null then
|
||||
log.warn("Kubernetes client not available, cannot scale")
|
||||
@@ -81,7 +79,7 @@ class AutoScaler:
|
||||
|
||||
try
|
||||
val rollout = kubeClient
|
||||
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource])
|
||||
.resources(classOf[GenericKubernetesResource])
|
||||
.inNamespace(config.k8sNamespace)
|
||||
.withName(config.k8sRolloutName)
|
||||
.get()
|
||||
@@ -94,7 +92,7 @@ class AutoScaler:
|
||||
if currentReplicas > minReplicas then
|
||||
spec.put("replicas", currentReplicas - 1)
|
||||
kubeClient
|
||||
.resources(classOf[io.fabric8.kubernetes.api.model.GenericKubernetesResource])
|
||||
.resources(classOf[GenericKubernetesResource])
|
||||
.inNamespace(config.k8sNamespace)
|
||||
.withName(config.k8sRolloutName)
|
||||
.createOrReplace(rollout)
|
||||
|
||||
+18
-12
@@ -4,9 +4,7 @@ import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import org.redisson.api.RedissonClient
|
||||
import scala.jdk.CollectionConverters.*
|
||||
import scala.concurrent.duration.*
|
||||
import scala.compiletime.uninitialized
|
||||
import java.time.Instant
|
||||
import org.jboss.logging.Logger
|
||||
import de.nowchess.coordinator.dto.InstanceMetadata
|
||||
import de.nowchess.coordinator.grpc.CoreGrpcClient
|
||||
@@ -67,17 +65,25 @@ class FailoverService:
|
||||
val batches = gameIds.grouped(batchSize).toList
|
||||
|
||||
batches.zipWithIndex.foreach { case (batch, idx) =>
|
||||
val targetInstance = healthyInstances(idx % healthyInstances.size)
|
||||
try
|
||||
val subscribed = coreGrpcClient.batchResubscribeGames(
|
||||
targetInstance.hostname,
|
||||
targetInstance.grpcPort,
|
||||
batch,
|
||||
var migrated = false
|
||||
var attempt = 0
|
||||
while !migrated && attempt < healthyInstances.size do
|
||||
val target = healthyInstances((idx + attempt) % healthyInstances.size)
|
||||
attempt += 1
|
||||
try
|
||||
val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch)
|
||||
if subscribed > 0 then
|
||||
log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, target.instanceId)
|
||||
migrated = true
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "Failed to migrate batch to %s, trying next", target.instanceId)
|
||||
if !migrated then
|
||||
log.errorf(
|
||||
"Failed to migrate batch of %d games from %s to any healthy instance",
|
||||
batch.size,
|
||||
deadInstanceId,
|
||||
)
|
||||
log.infof("Migrated %d games from %s to %s", subscribed, deadInstanceId, targetInstance.instanceId)
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "Failed to migrate batch to %s, will retry", targetInstance.instanceId)
|
||||
}
|
||||
|
||||
private def cleanupDeadInstance(instanceId: String): Unit =
|
||||
|
||||
+8
-22
@@ -4,11 +4,10 @@ import jakarta.enterprise.context.ApplicationScoped
|
||||
import jakarta.inject.Inject
|
||||
import org.redisson.api.RedissonClient
|
||||
import scala.jdk.CollectionConverters.*
|
||||
import scala.collection.mutable
|
||||
import scala.compiletime.uninitialized
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import de.nowchess.coordinator.dto.InstanceMetadata
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
@ApplicationScoped
|
||||
class InstanceRegistry:
|
||||
@@ -16,29 +15,17 @@ class InstanceRegistry:
|
||||
private var redissonClient: RedissonClient = uninitialized
|
||||
|
||||
private val mapper = ObjectMapper()
|
||||
private val instances = mutable.Map[String, InstanceMetadata]()
|
||||
private val instances = ConcurrentHashMap[String, InstanceMetadata]()
|
||||
private var redisPrefix = "nowchess"
|
||||
|
||||
def setRedisPrefix(prefix: String): Unit =
|
||||
redisPrefix = prefix
|
||||
|
||||
def getInstance(instanceId: String): Option[InstanceMetadata] =
|
||||
instances.get(instanceId)
|
||||
Option(instances.get(instanceId))
|
||||
|
||||
def getAllInstances: List[InstanceMetadata] =
|
||||
instances.values.toList
|
||||
|
||||
def listInstancesFromRedis: List[InstanceMetadata] =
|
||||
val pattern = s"$redisPrefix:instances:*"
|
||||
val keys = redissonClient.getKeys.getKeysByPattern(pattern, 100)
|
||||
keys.asScala.flatMap { key =>
|
||||
val bucket = redissonClient.getBucket[String](key)
|
||||
val value = bucket.getAndDelete()
|
||||
if value != null then
|
||||
try Some(mapper.readValue(value, classOf[InstanceMetadata]))
|
||||
catch case _: Exception => None
|
||||
else None
|
||||
}.toList
|
||||
instances.values.asScala.toList
|
||||
|
||||
def updateInstanceFromRedis(instanceId: String): Unit =
|
||||
val key = s"$redisPrefix:instances:$instanceId"
|
||||
@@ -47,14 +34,13 @@ class InstanceRegistry:
|
||||
if value != null then
|
||||
try
|
||||
val metadata = mapper.readValue(value, classOf[InstanceMetadata])
|
||||
instances(instanceId) = metadata
|
||||
instances.put(instanceId, metadata)
|
||||
catch case _: Exception => ()
|
||||
|
||||
def markInstanceDead(instanceId: String): Unit =
|
||||
instances.get(instanceId).foreach { inst =>
|
||||
val dead = inst.copy(state = "DEAD")
|
||||
instances(instanceId) = dead
|
||||
}
|
||||
instances.computeIfPresent(instanceId, (_, inst) => inst.copy(state = "DEAD"))
|
||||
()
|
||||
|
||||
def removeInstance(instanceId: String): Unit =
|
||||
instances.remove(instanceId)
|
||||
()
|
||||
|
||||
+26
-21
@@ -25,7 +25,7 @@ class LoadBalancer:
|
||||
private var coreGrpcClient: CoreGrpcClient = uninitialized
|
||||
|
||||
private val log = Logger.getLogger(classOf[LoadBalancer])
|
||||
private var lastRebalanceTime = 0L
|
||||
private val lastRebalanceTime = new java.util.concurrent.atomic.AtomicLong(0L)
|
||||
private var redisPrefix = "nowchess"
|
||||
|
||||
def setRedisPrefix(prefix: String): Unit =
|
||||
@@ -34,7 +34,7 @@ class LoadBalancer:
|
||||
def shouldRebalance: Boolean =
|
||||
val now = System.currentTimeMillis()
|
||||
val minInterval = config.rebalanceMinInterval.toMillis
|
||||
if now - lastRebalanceTime < minInterval then return false
|
||||
if now - lastRebalanceTime.get() < minInterval then return false
|
||||
|
||||
val instances = instanceRegistry.getAllInstances
|
||||
if instances.isEmpty then return false
|
||||
@@ -54,11 +54,10 @@ class LoadBalancer:
|
||||
def rebalance: Unit =
|
||||
log.info("Starting rebalance")
|
||||
val startTime = System.currentTimeMillis()
|
||||
lastRebalanceTime = startTime
|
||||
lastRebalanceTime.set(startTime)
|
||||
|
||||
try
|
||||
val instances = instanceRegistry.getAllInstances
|
||||
.filter(_.state == "HEALTHY")
|
||||
val instances = instanceRegistry.getAllInstances.filter(_.state == "HEALTHY")
|
||||
|
||||
if instances.size < 2 then
|
||||
log.info("Not enough healthy instances for rebalance")
|
||||
@@ -75,23 +74,29 @@ class LoadBalancer:
|
||||
.filter(_.subscriptionCount < avgLoad * 0.8)
|
||||
.sortBy(_.subscriptionCount)
|
||||
|
||||
overloaded.foreach { over =>
|
||||
val excess = over.subscriptionCount - avgLoad.toInt
|
||||
if excess > 0 && underloaded.nonEmpty then
|
||||
val gamesToMove = getGamesToMove(over.instanceId, excess)
|
||||
if gamesToMove.nonEmpty then
|
||||
underloaded.headOption.foreach { under =>
|
||||
try
|
||||
val unsubscribed = coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, gamesToMove)
|
||||
val subscribed = coreGrpcClient.batchResubscribeGames(under.hostname, under.grpcPort, gamesToMove)
|
||||
if underloaded.isEmpty then
|
||||
log.info("No underloaded instances available for rebalance")
|
||||
return
|
||||
|
||||
if subscribed > 0 then
|
||||
updateRedisGameSets(over.instanceId, under.instanceId, gamesToMove)
|
||||
log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, under.instanceId)
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, under.instanceId)
|
||||
}
|
||||
var targetIdx = 0
|
||||
overloaded.foreach { over =>
|
||||
val excess = math.max(0, over.subscriptionCount - avgLoad.toInt)
|
||||
val gamesToMove = getGamesToMove(over.instanceId, excess)
|
||||
if gamesToMove.nonEmpty then
|
||||
val batchSize = math.max(1, (gamesToMove.size + underloaded.size - 1) / underloaded.size)
|
||||
gamesToMove.grouped(batchSize).foreach { batch =>
|
||||
val target = underloaded(targetIdx % underloaded.size)
|
||||
targetIdx += 1
|
||||
try
|
||||
coreGrpcClient.unsubscribeGames(over.hostname, over.grpcPort, batch)
|
||||
val subscribed = coreGrpcClient.batchResubscribeGames(target.hostname, target.grpcPort, batch)
|
||||
if subscribed > 0 then
|
||||
updateRedisGameSets(over.instanceId, target.instanceId, batch)
|
||||
log.infof("Moved %d games from %s to %s", subscribed, over.instanceId, target.instanceId)
|
||||
catch
|
||||
case ex: Exception =>
|
||||
log.warnf(ex, "Failed to move games from %s to %s", over.instanceId, target.instanceId)
|
||||
}
|
||||
}
|
||||
|
||||
val elapsed = System.currentTimeMillis() - startTime
|
||||
|
||||
Reference in New Issue
Block a user