Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 343e2bdd10 | |||
| 751a58b606 | |||
| 30295a4bb9 |
@@ -22,6 +22,9 @@
|
|||||||
// NOWCHESS_JDBC_URL (default: jdbc:postgresql://localhost:5432/nowchess)
|
// NOWCHESS_JDBC_URL (default: jdbc:postgresql://localhost:5432/nowchess)
|
||||||
// NOWCHESS_DB_USER (default: nowchess)
|
// NOWCHESS_DB_USER (default: nowchess)
|
||||||
// NOWCHESS_DB_PASS (default: nowchess)
|
// NOWCHESS_DB_PASS (default: nowchess)
|
||||||
|
// NOWCHESS_PGN_PATH (optional) — file or http(s) URL of a Lichess PGN dump (.pgn or .pgn.zst).
|
||||||
|
// When set, all batch jobs read games from the dump instead of PostgreSQL and
|
||||||
|
// skip JDBC write-back (Parquet/CSV output only). Demo data source.
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("scala")
|
id("scala")
|
||||||
@@ -71,6 +74,10 @@ dependencies {
|
|||||||
|
|
||||||
// PostgreSQL JDBC driver bundled so it is available on executor classpath.
|
// PostgreSQL JDBC driver bundled so it is available on executor classpath.
|
||||||
implementation("org.postgresql:postgresql:42.7.4")
|
implementation("org.postgresql:postgresql:42.7.4")
|
||||||
|
|
||||||
|
// zstd-jni: decompress Lichess .pgn.zst dumps in-process. Provided at runtime by Spark
|
||||||
|
// (it uses zstd-jni internally for shuffle/event-log compression), so compile-only here.
|
||||||
|
compileOnly("com.github.luben:zstd-jni:1.5.6-9")
|
||||||
}
|
}
|
||||||
|
|
||||||
application {
|
application {
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
package de.nowchess.analytics
|
||||||
|
|
||||||
|
import org.apache.spark.SparkFiles
|
||||||
|
import org.apache.spark.sql.DataFrame
|
||||||
|
import org.apache.spark.sql.SparkSession
|
||||||
|
import org.apache.spark.sql.functions as F
|
||||||
|
|
||||||
|
/** Normalised game-record source for the batch jobs.
|
||||||
|
*
|
||||||
|
* Every batch job consumes the same five-column shape:
|
||||||
|
* - white_id, black_id : player identifiers
|
||||||
|
* - result : one of "white", "black", "draw"
|
||||||
|
* - move_count : number of plies
|
||||||
|
* - pgn : full PGN ("[Event …]…\n\n1. e4 …"), header and movetext separated by a blank line
|
||||||
|
*
|
||||||
|
* Two backends, selected by the `NOWCHESS_PGN_PATH` environment variable:
|
||||||
|
* - unset → PostgreSQL `game_records` table (production)
|
||||||
|
* - set → a Lichess PGN dump file/URL (demo). Point it at a `lichess_db_standard_rated_*.pgn[.zst]`
|
||||||
|
* to drive every batch job from real Lichess games.
|
||||||
|
*
|
||||||
|
* Lichess parsing uses only Spark SQL string functions — no UDFs — so Catalyst can push predicates,
|
||||||
|
* matching the no-UDF approach already used in OpeningBookJob.
|
||||||
|
*/
|
||||||
|
object GameSource:
|
||||||
|
|
||||||
|
private val PgnPathEnv = "NOWCHESS_PGN_PATH"
|
||||||
|
|
||||||
|
/** True when a Lichess PGN dump is configured; jobs use this to skip JDBC write-back. */
|
||||||
|
def isPgnMode: Boolean = sys.env.contains(PgnPathEnv)
|
||||||
|
|
||||||
|
def load(spark: SparkSession, jdbcUrl: String, dbUser: String, dbPass: String): DataFrame =
|
||||||
|
sys.env.get(PgnPathEnv) match
|
||||||
|
case Some(path) => fromLichessPgn(spark, path)
|
||||||
|
case None => fromJdbc(spark, jdbcUrl, dbUser, dbPass)
|
||||||
|
|
||||||
|
def fromJdbc(spark: SparkSession, jdbcUrl: String, dbUser: String, dbPass: String): DataFrame =
|
||||||
|
spark.read
|
||||||
|
.format("jdbc")
|
||||||
|
.option("url", jdbcUrl)
|
||||||
|
.option("dbtable", "game_records")
|
||||||
|
.option("user", dbUser)
|
||||||
|
.option("password", dbPass)
|
||||||
|
.option("driver", "org.postgresql.Driver")
|
||||||
|
.option("fetchsize", "10000")
|
||||||
|
.load()
|
||||||
|
.select("white_id", "black_id", "result", "move_count", "pgn")
|
||||||
|
|
||||||
|
/** Parses a Lichess PGN dump into the normalised game shape.
|
||||||
|
*
|
||||||
|
* `path` may be:
|
||||||
|
* - an http(s)/ftp URL — fetched once via SparkContext.addFile and distributed to executors, then read
|
||||||
|
* from the local replica (no S3/PVC needed; handy for a staging demo)
|
||||||
|
* - any Hadoop-readable path (file://, hdfs://, s3a://, …)
|
||||||
|
*
|
||||||
|
* `.zst` dumps (Lichess' native format) are decompressed in-process via zstd-jni; `.gz`/`.bz2` are
|
||||||
|
* handled by Spark's text reader codecs.
|
||||||
|
*
|
||||||
|
* Records are split on the "[Event " tag that opens every game, so each row holds one complete game
|
||||||
|
* (the empty fragment before the first game is filtered out). Header tags are read with regexp_extract;
|
||||||
|
* the movetext (after the blank line) is cleaned of clock/eval comments and move numbers to count plies.
|
||||||
|
*/
|
||||||
|
def fromLichessPgn(spark: SparkSession, path: String): DataFrame =
|
||||||
|
val resolved = resolvePath(spark, path)
|
||||||
|
val record = F.col("value")
|
||||||
|
|
||||||
|
val resultTag = F.regexp_extract(record, "Result \"([^\"]*)\"", 1)
|
||||||
|
val result = F
|
||||||
|
.when(resultTag === "1-0", "white")
|
||||||
|
.when(resultTag === "0-1", "black")
|
||||||
|
.when(resultTag === "1/2-1/2", "draw")
|
||||||
|
.otherwise(F.lit(null).cast("string"))
|
||||||
|
|
||||||
|
val moveText = F.coalesce(F.split(record, "\n\n").getItem(1), F.lit(""))
|
||||||
|
val noComment = F.regexp_replace(moveText, "\\{[^}]*\\}", "")
|
||||||
|
val noResult = F.regexp_replace(noComment, "(1-0|0-1|1/2-1/2|\\*)", "")
|
||||||
|
val noNumbers = F.regexp_replace(noResult, "\\d+\\.+", " ")
|
||||||
|
val plies = F.size(F.filter(F.split(F.trim(noNumbers), "\\s+"), tok => F.length(tok) > 0))
|
||||||
|
|
||||||
|
spark.read
|
||||||
|
.option("lineSep", "[Event ")
|
||||||
|
.text(resolved)
|
||||||
|
.filter(F.length(F.trim(record)) > 0)
|
||||||
|
.select(
|
||||||
|
F.regexp_extract(record, "White \"([^\"]*)\"", 1).as("white_id"),
|
||||||
|
F.regexp_extract(record, "Black \"([^\"]*)\"", 1).as("black_id"),
|
||||||
|
result.as("result"),
|
||||||
|
plies.as("move_count"),
|
||||||
|
F.concat(F.lit("[Event "), record).as("pgn"),
|
||||||
|
)
|
||||||
|
.filter((F.col("white_id") =!= "").and(F.col("black_id") =!= ""))
|
||||||
|
|
||||||
|
/** Turns an http(s)/ftp URL into a cluster-local path by fetching it once with SparkContext.addFile,
|
||||||
|
* which distributes the file to every executor. `.zst` is decompressed in-process and the plain `.pgn`
|
||||||
|
* is redistributed. Non-URL paths are returned unchanged.
|
||||||
|
*/
|
||||||
|
private def resolvePath(spark: SparkSession, path: String): String =
|
||||||
|
if !path.matches("^(https?|ftp)://.*") then path
|
||||||
|
else
|
||||||
|
spark.sparkContext.addFile(path)
|
||||||
|
val local = SparkFiles.get(baseName(path))
|
||||||
|
if !local.endsWith(".zst") then "file://" + local
|
||||||
|
else distribute(spark, decompressZstd(local))
|
||||||
|
|
||||||
|
private def baseName(path: String): String = path.substring(path.lastIndexOf('/') + 1)
|
||||||
|
|
||||||
|
private def distribute(spark: SparkSession, localPath: String): String =
|
||||||
|
spark.sparkContext.addFile("file://" + localPath)
|
||||||
|
"file://" + SparkFiles.get(baseName(localPath))
|
||||||
|
|
||||||
|
/** Decompresses a `.zst` file to a temp `.pgn` using zstd-jni (bundled with Spark at runtime). */
|
||||||
|
private def decompressZstd(srcPath: String): String =
|
||||||
|
val out = java.io.File.createTempFile("lichess-", ".pgn")
|
||||||
|
out.deleteOnExit()
|
||||||
|
val in = com.github.luben.zstd.ZstdInputStream(
|
||||||
|
java.io.BufferedInputStream(java.io.FileInputStream(srcPath)),
|
||||||
|
)
|
||||||
|
try java.nio.file.Files.copy(in, out.toPath, java.nio.file.StandardCopyOption.REPLACE_EXISTING)
|
||||||
|
finally in.close()
|
||||||
|
out.getAbsolutePath
|
||||||
@@ -37,15 +37,8 @@ object OpeningBookJob:
|
|||||||
outputDir: String,
|
outputDir: String,
|
||||||
maxPlies: Int,
|
maxPlies: Int,
|
||||||
): Unit =
|
): Unit =
|
||||||
val games = spark.read
|
val games = GameSource
|
||||||
.format("jdbc")
|
.load(spark, jdbcUrl, dbUser, dbPass)
|
||||||
.option("url", jdbcUrl)
|
|
||||||
.option("dbtable", "game_records")
|
|
||||||
.option("user", dbUser)
|
|
||||||
.option("password", dbPass)
|
|
||||||
.option("driver", "org.postgresql.Driver")
|
|
||||||
.option("fetchsize", "10000")
|
|
||||||
.load()
|
|
||||||
.select("pgn", "result")
|
.select("pgn", "result")
|
||||||
.filter(F.col("result").isNotNull.and(F.col("pgn").isNotNull))
|
.filter(F.col("result").isNotNull.and(F.col("pgn").isNotNull))
|
||||||
|
|
||||||
@@ -79,15 +72,16 @@ object OpeningBookJob:
|
|||||||
.option("header", "true")
|
.option("header", "true")
|
||||||
.csv(s"$outputDir/opening_book_top1000")
|
.csv(s"$outputDir/opening_book_top1000")
|
||||||
|
|
||||||
top1000.write
|
if !GameSource.isPgnMode then
|
||||||
.mode("overwrite")
|
top1000.write
|
||||||
.format("jdbc")
|
.mode("overwrite")
|
||||||
.option("url", jdbcUrl)
|
.format("jdbc")
|
||||||
.option("dbtable", "analytics_opening_stats")
|
.option("url", jdbcUrl)
|
||||||
.option("user", dbUser)
|
.option("dbtable", "analytics_opening_stats")
|
||||||
.option("password", dbPass)
|
.option("user", dbUser)
|
||||||
.option("driver", "org.postgresql.Driver")
|
.option("password", dbPass)
|
||||||
.save()
|
.option("driver", "org.postgresql.Driver")
|
||||||
|
.save()
|
||||||
|
|
||||||
/** Extracts the first `maxPlies` moves from a PGN column as a space-separated string.
|
/** Extracts the first `maxPlies` moves from a PGN column as a space-separated string.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -50,15 +50,8 @@ object PlayerClusteringJob:
|
|||||||
outputDir: String,
|
outputDir: String,
|
||||||
k: Int,
|
k: Int,
|
||||||
): Unit =
|
): Unit =
|
||||||
val games = spark.read
|
val games = GameSource
|
||||||
.format("jdbc")
|
.load(spark, jdbcUrl, dbUser, dbPass)
|
||||||
.option("url", jdbcUrl)
|
|
||||||
.option("dbtable", "game_records")
|
|
||||||
.option("user", dbUser)
|
|
||||||
.option("password", dbPass)
|
|
||||||
.option("driver", "org.postgresql.Driver")
|
|
||||||
.option("fetchsize", "10000")
|
|
||||||
.load()
|
|
||||||
.select("white_id", "black_id", "result", "move_count")
|
.select("white_id", "black_id", "result", "move_count")
|
||||||
.filter(F.col("result").isNotNull)
|
.filter(F.col("result").isNotNull)
|
||||||
|
|
||||||
@@ -126,25 +119,26 @@ object PlayerClusteringJob:
|
|||||||
.option("header", "true")
|
.option("header", "true")
|
||||||
.csv(s"$outputDir/cluster_archetypes")
|
.csv(s"$outputDir/cluster_archetypes")
|
||||||
|
|
||||||
clustersDf.write
|
if !GameSource.isPgnMode then
|
||||||
.mode("overwrite")
|
clustersDf.write
|
||||||
.format("jdbc")
|
.mode("overwrite")
|
||||||
.option("url", jdbcUrl)
|
.format("jdbc")
|
||||||
.option("dbtable", "analytics_player_clusters")
|
.option("url", jdbcUrl)
|
||||||
.option("user", dbUser)
|
.option("dbtable", "analytics_player_clusters")
|
||||||
.option("password", dbPass)
|
.option("user", dbUser)
|
||||||
.option("driver", "org.postgresql.Driver")
|
.option("password", dbPass)
|
||||||
.save()
|
.option("driver", "org.postgresql.Driver")
|
||||||
|
.save()
|
||||||
|
|
||||||
archetypes.write
|
archetypes.write
|
||||||
.mode("overwrite")
|
.mode("overwrite")
|
||||||
.format("jdbc")
|
.format("jdbc")
|
||||||
.option("url", jdbcUrl)
|
.option("url", jdbcUrl)
|
||||||
.option("dbtable", "analytics_cluster_archetypes")
|
.option("dbtable", "analytics_cluster_archetypes")
|
||||||
.option("user", dbUser)
|
.option("user", dbUser)
|
||||||
.option("password", dbPass)
|
.option("password", dbPass)
|
||||||
.option("driver", "org.postgresql.Driver")
|
.option("driver", "org.postgresql.Driver")
|
||||||
.save()
|
.save()
|
||||||
|
|
||||||
private def buildPlayerStats(games: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame =
|
private def buildPlayerStats(games: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame =
|
||||||
val asWhite = games.select(
|
val asWhite = games.select(
|
||||||
|
|||||||
@@ -53,15 +53,8 @@ object PlayerGraphJob:
|
|||||||
dbPass: String,
|
dbPass: String,
|
||||||
outputDir: String,
|
outputDir: String,
|
||||||
): Unit =
|
): Unit =
|
||||||
val gamesRdd: RDD[Row] = spark.read
|
val gamesRdd: RDD[Row] = GameSource
|
||||||
.format("jdbc")
|
.load(spark, jdbcUrl, dbUser, dbPass)
|
||||||
.option("url", jdbcUrl)
|
|
||||||
.option("dbtable", "game_records")
|
|
||||||
.option("user", dbUser)
|
|
||||||
.option("password", dbPass)
|
|
||||||
.option("driver", "org.postgresql.Driver")
|
|
||||||
.option("fetchsize", "10000")
|
|
||||||
.load()
|
|
||||||
.select("white_id", "black_id", "result")
|
.select("white_id", "black_id", "result")
|
||||||
.filter(F.col("result").isNotNull)
|
.filter(F.col("result").isNotNull)
|
||||||
.rdd
|
.rdd
|
||||||
@@ -116,15 +109,16 @@ object PlayerGraphJob:
|
|||||||
.mode("overwrite")
|
.mode("overwrite")
|
||||||
.parquet(s"$outputDir/player_graph")
|
.parquet(s"$outputDir/player_graph")
|
||||||
|
|
||||||
result.write
|
if !GameSource.isPgnMode then
|
||||||
.mode("overwrite")
|
result.write
|
||||||
.format("jdbc")
|
.mode("overwrite")
|
||||||
.option("url", jdbcUrl)
|
.format("jdbc")
|
||||||
.option("dbtable", "analytics_player_graph")
|
.option("url", jdbcUrl)
|
||||||
.option("user", dbUser)
|
.option("dbtable", "analytics_player_graph")
|
||||||
.option("password", dbPass)
|
.option("user", dbUser)
|
||||||
.option("driver", "org.postgresql.Driver")
|
.option("password", dbPass)
|
||||||
.save()
|
.option("driver", "org.postgresql.Driver")
|
||||||
|
.save()
|
||||||
|
|
||||||
// How many players belong to each connected component?
|
// How many players belong to each connected component?
|
||||||
// A large dominant component + many singletons is the expected shape.
|
// A large dominant component + many singletons is the expected shape.
|
||||||
|
|||||||
@@ -34,15 +34,8 @@ object PlayerStatsJob:
|
|||||||
dbPass: String,
|
dbPass: String,
|
||||||
outputDir: String,
|
outputDir: String,
|
||||||
): Unit =
|
): Unit =
|
||||||
val games = spark.read
|
val games = GameSource
|
||||||
.format("jdbc")
|
.load(spark, jdbcUrl, dbUser, dbPass)
|
||||||
.option("url", jdbcUrl)
|
|
||||||
.option("dbtable", "game_records")
|
|
||||||
.option("user", dbUser)
|
|
||||||
.option("password", dbPass)
|
|
||||||
.option("driver", "org.postgresql.Driver")
|
|
||||||
.option("fetchsize", "10000")
|
|
||||||
.load()
|
|
||||||
.select("white_id", "black_id", "result", "move_count")
|
.select("white_id", "black_id", "result", "move_count")
|
||||||
.filter(F.col("result").isNotNull)
|
.filter(F.col("result").isNotNull)
|
||||||
|
|
||||||
@@ -84,12 +77,13 @@ object PlayerStatsJob:
|
|||||||
.mode("overwrite")
|
.mode("overwrite")
|
||||||
.parquet(s"$outputDir/player_stats")
|
.parquet(s"$outputDir/player_stats")
|
||||||
|
|
||||||
stats.write
|
if !GameSource.isPgnMode then
|
||||||
.mode("overwrite")
|
stats.write
|
||||||
.format("jdbc")
|
.mode("overwrite")
|
||||||
.option("url", jdbcUrl)
|
.format("jdbc")
|
||||||
.option("dbtable", "analytics_player_stats")
|
.option("url", jdbcUrl)
|
||||||
.option("user", dbUser)
|
.option("dbtable", "analytics_player_stats")
|
||||||
.option("password", dbPass)
|
.option("user", dbUser)
|
||||||
.option("driver", "org.postgresql.Driver")
|
.option("password", dbPass)
|
||||||
.save()
|
.option("driver", "org.postgresql.Driver")
|
||||||
|
.save()
|
||||||
|
|||||||
+1
-1
@@ -20,7 +20,7 @@ object TournamentBotConfig:
|
|||||||
tournamentId <- env.get("TOURNAMENT_ID").filter(_.nonEmpty)
|
tournamentId <- env.get("TOURNAMENT_ID").filter(_.nonEmpty)
|
||||||
token <- env.get("TOURNAMENT_BOT_TOKEN").filter(_.nonEmpty)
|
token <- env.get("TOURNAMENT_BOT_TOKEN").filter(_.nonEmpty)
|
||||||
botId <- jwtSubject(token)
|
botId <- jwtSubject(token)
|
||||||
serverUrl = env.getOrElse("TOURNAMENT_SERVER_URL", "http://localhost:8089")
|
serverUrl = env.getOrElse("TOURNAMENT_SERVER_URL", "http://141.37.123.132:8086")
|
||||||
difficulty = env.getOrElse("TOURNAMENT_BOT_DIFFICULTY", "medium")
|
difficulty = env.getOrElse("TOURNAMENT_BOT_DIFFICULTY", "medium")
|
||||||
yield TournamentBotConfig(serverUrl, tournamentId, token, botId, difficulty)
|
yield TournamentBotConfig(serverUrl, tournamentId, token, botId, difficulty)
|
||||||
|
|
||||||
|
|||||||
+28
-1
@@ -39,10 +39,11 @@ class TournamentBotGamePlayer:
|
|||||||
// scalafix:on DisableSyntax.var
|
// scalafix:on DisableSyntax.var
|
||||||
|
|
||||||
val defaultServerUrl: String =
|
val defaultServerUrl: String =
|
||||||
System.getenv().asScala.getOrElse("TOURNAMENT_SERVER_URL", "http://localhost:8089")
|
System.getenv().asScala.getOrElse("TOURNAMENT_SERVER_URL", "http://141.37.123.132:8086")
|
||||||
|
|
||||||
@PostConstruct
|
@PostConstruct
|
||||||
def initialize(): Unit =
|
def initialize(): Unit =
|
||||||
|
parkOnStartup()
|
||||||
config match
|
config match
|
||||||
case None =>
|
case None =>
|
||||||
log.info("Tournament bot disabled — set TOURNAMENT_ID and TOURNAMENT_BOT_TOKEN to enable")
|
log.info("Tournament bot disabled — set TOURNAMENT_ID and TOURNAMENT_BOT_TOKEN to enable")
|
||||||
@@ -50,6 +51,32 @@ class TournamentBotGamePlayer:
|
|||||||
log.infof("Tournament bot enabled — server=%s tournament=%s bot=%s", cfg.serverUrl, cfg.tournamentId, cfg.botId)
|
log.infof("Tournament bot enabled — server=%s tournament=%s bot=%s", cfg.serverUrl, cfg.tournamentId, cfg.botId)
|
||||||
startAsync(cfg)
|
startAsync(cfg)
|
||||||
|
|
||||||
|
private def parkOnStartup(): Unit =
|
||||||
|
park(defaultServerUrl, "expert") match
|
||||||
|
case Some(id) => log.infof("Parked expert bot on %s as id %s", defaultServerUrl, id)
|
||||||
|
case None => log.warnf("Failed to park expert bot on %s", defaultServerUrl)
|
||||||
|
|
||||||
|
private def park(serverUrl: String, difficulty: String): Option[String] =
|
||||||
|
System.getenv().asScala.get("TOURNAMENT_BOT_TOKEN").filter(_.nonEmpty).flatMap { token =>
|
||||||
|
Try {
|
||||||
|
val body = s"""{"name":"${botName(difficulty)}"}"""
|
||||||
|
val response = client
|
||||||
|
.target(serverUrl)
|
||||||
|
.path("api")
|
||||||
|
.path("bots")
|
||||||
|
.request(MediaType.APPLICATION_JSON)
|
||||||
|
.header("Authorization", s"Bearer $token")
|
||||||
|
.post(Entity.entity(body, MediaType.APPLICATION_JSON))
|
||||||
|
if response.getStatus == 201 || response.getStatus == 200 then
|
||||||
|
val id = objectMapper.readTree(response.readEntity(classOf[String])).path("id").asText()
|
||||||
|
response.close()
|
||||||
|
Option(id).filter(_.nonEmpty)
|
||||||
|
else { log.warnf("Parking bot %s returned status %d", botName(difficulty), response.getStatus); response.close(); None }
|
||||||
|
}.getOrElse(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def botName(difficulty: String): String = s"NowChess ${difficulty.capitalize}"
|
||||||
|
|
||||||
def joinTournament(
|
def joinTournament(
|
||||||
tournamentId: String,
|
tournamentId: String,
|
||||||
botToken: String,
|
botToken: String,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package de.nowchess.ws.resource
|
|||||||
import de.nowchess.ws.config.RedisConfig
|
import de.nowchess.ws.config.RedisConfig
|
||||||
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
|
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
|
||||||
import io.quarkus.redis.datasource.RedisDataSource
|
import io.quarkus.redis.datasource.RedisDataSource
|
||||||
import io.quarkus.redis.datasource.pubsub.PubSubCommands
|
|
||||||
import io.quarkus.websockets.next.*
|
import io.quarkus.websockets.next.*
|
||||||
import io.smallrye.jwt.auth.principal.JWTParser
|
import io.smallrye.jwt.auth.principal.JWTParser
|
||||||
import jakarta.annotation.PostConstruct
|
import jakarta.annotation.PostConstruct
|
||||||
@@ -34,6 +33,7 @@ class GameWebSocketResource:
|
|||||||
// scalafix:on DisableSyntax.var
|
// scalafix:on DisableSyntax.var
|
||||||
|
|
||||||
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
|
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
|
||||||
|
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||||
|
|
||||||
@PostConstruct
|
@PostConstruct
|
||||||
def initializeMetrics(): Unit = {
|
def initializeMetrics(): Unit = {
|
||||||
@@ -64,40 +64,60 @@ class GameWebSocketResource:
|
|||||||
s"${redisConfig.prefix}:game:$gameId:c2s"
|
s"${redisConfig.prefix}:game:$gameId:c2s"
|
||||||
|
|
||||||
@OnOpen
|
@OnOpen
|
||||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
def onOpen(connection: WebSocketConnection): Unit =
|
||||||
activeGauge
|
activeGauge
|
||||||
val gameId = connection.pathParam("gameId")
|
val gameId = connection.pathParam("gameId")
|
||||||
val playerId = resolvePlayerId(handshake)
|
|
||||||
log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous"))
|
|
||||||
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
|
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
|
||||||
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
|
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
|
||||||
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId))
|
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None))
|
||||||
connectionsOpened.increment()
|
connectionsOpened.increment()
|
||||||
publishConnected(gameId, playerId)
|
pendingAuth.add(connection.id())
|
||||||
|
log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id())
|
||||||
|
|
||||||
@OnTextMessage
|
@OnTextMessage
|
||||||
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||||
messagesReceived.increment()
|
messagesReceived.increment()
|
||||||
Option(connections.get(connection.id())).foreach { meta =>
|
if pendingAuth.remove(connection.id()) then
|
||||||
val enriched = meta.playerId match
|
val playerIdOpt =
|
||||||
case Some(pid) => injectPlayerId(message, pid)
|
parseAuthToken(message)
|
||||||
case None => message
|
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||||
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
.map(_.getSubject)
|
||||||
}
|
playerIdOpt match
|
||||||
|
case None =>
|
||||||
|
log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id())
|
||||||
|
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||||
|
case Some(playerId) =>
|
||||||
|
Option(connections.get(connection.id())).foreach { meta =>
|
||||||
|
connections.put(connection.id(), meta.copy(playerId = Some(playerId)))
|
||||||
|
publishConnected(meta.gameId, Some(playerId))
|
||||||
|
}
|
||||||
|
else
|
||||||
|
Option(connections.get(connection.id())).foreach { meta =>
|
||||||
|
val enriched = meta.playerId match
|
||||||
|
case Some(pid) => injectPlayerId(message, pid)
|
||||||
|
case None => message
|
||||||
|
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
||||||
|
}
|
||||||
|
|
||||||
@OnClose
|
@OnClose
|
||||||
def onClose(connection: WebSocketConnection): Unit =
|
def onClose(connection: WebSocketConnection): Unit =
|
||||||
|
pendingAuth.remove(connection.id())
|
||||||
Option(connections.remove(connection.id())).foreach { meta =>
|
Option(connections.remove(connection.id())).foreach { meta =>
|
||||||
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
|
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
|
||||||
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
|
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
|
||||||
connectionsClosed.increment()
|
connectionsClosed.increment()
|
||||||
}
|
}
|
||||||
|
|
||||||
private def resolvePlayerId(handshake: HandshakeRequest): Option[String] =
|
private def parseAuthToken(message: String): Option[String] =
|
||||||
Option(handshake.header("Authorization"))
|
val trimmed = message.trim
|
||||||
.filter(_.nonEmpty)
|
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
else
|
||||||
.map(_.getSubject)
|
val start = trimmed.indexOf("\"token\":\"")
|
||||||
|
if start < 0 then None
|
||||||
|
else
|
||||||
|
val valueStart = start + 9
|
||||||
|
val end = trimmed.indexOf('"', valueStart)
|
||||||
|
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||||
|
|
||||||
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
|
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
|
||||||
val connectedMsg = playerId match
|
val connectedMsg = playerId match
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class UserWebSocketResource:
|
|||||||
private val maxStreamLen = 1000L
|
private val maxStreamLen = 1000L
|
||||||
|
|
||||||
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
|
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
|
||||||
|
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||||
|
|
||||||
private def userStreamKey(userId: String): String =
|
private def userStreamKey(userId: String): String =
|
||||||
s"${redisConfig.prefix}:user:$userId:events:stream"
|
s"${redisConfig.prefix}:user:$userId:events:stream"
|
||||||
@@ -39,29 +40,34 @@ class UserWebSocketResource:
|
|||||||
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
|
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
|
||||||
|
|
||||||
@OnOpen
|
@OnOpen
|
||||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
def onOpen(connection: WebSocketConnection): Unit =
|
||||||
val userIdOpt = Option(handshake.header("Authorization"))
|
pendingAuth.add(connection.id())
|
||||||
.filter(_.nonEmpty)
|
|
||||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
|
||||||
.map(_.getSubject)
|
|
||||||
|
|
||||||
userIdOpt match
|
@OnTextMessage
|
||||||
case None =>
|
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||||
log.warn("WebSocket opened with no valid JWT — closing connection")
|
if pendingAuth.remove(connection.id()) then
|
||||||
connection.close().subscribe().`with`(_ => (), _ => ())
|
val userIdOpt =
|
||||||
case Some(userId) =>
|
parseAuthToken(message)
|
||||||
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||||
createGroupIfAbsent(userId, connection.id())
|
.map(_.getSubject)
|
||||||
connections.put(connection.id(), (userId, connection))
|
userIdOpt match
|
||||||
executor.submit(
|
case None =>
|
||||||
new Runnable:
|
log.warn("WebSocket opened with no valid JWT — closing connection")
|
||||||
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||||
)
|
case Some(userId) =>
|
||||||
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
||||||
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
createGroupIfAbsent(userId, connection.id())
|
||||||
|
connections.put(connection.id(), (userId, connection))
|
||||||
|
executor.submit(
|
||||||
|
new Runnable:
|
||||||
|
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
||||||
|
)
|
||||||
|
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
||||||
|
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
||||||
|
|
||||||
@OnClose
|
@OnClose
|
||||||
def onClose(connection: WebSocketConnection): Unit =
|
def onClose(connection: WebSocketConnection): Unit =
|
||||||
|
pendingAuth.remove(connection.id())
|
||||||
log.infof("User WebSocket closed — connectionId=%s", connection.id())
|
log.infof("User WebSocket closed — connectionId=%s", connection.id())
|
||||||
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
|
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
|
||||||
userIdOpt.foreach { userId =>
|
userIdOpt.foreach { userId =>
|
||||||
@@ -128,3 +134,14 @@ class UserWebSocketResource:
|
|||||||
) match
|
) match
|
||||||
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
|
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
|
||||||
case Success(_) => ()
|
case Success(_) => ()
|
||||||
|
|
||||||
|
private def parseAuthToken(message: String): Option[String] =
|
||||||
|
val trimmed = message.trim
|
||||||
|
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||||
|
else
|
||||||
|
val start = trimmed.indexOf("\"token\":\"")
|
||||||
|
if start < 0 then None
|
||||||
|
else
|
||||||
|
val valueStart = start + 9
|
||||||
|
val end = trimmed.indexOf('"', valueStart)
|
||||||
|
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||||
|
|||||||
Reference in New Issue
Block a user