fix: NCS-122 authenticate WebSocket connections via first-message auth (#73)
Build & Test (NowChessSystems) TeamCity build failed

Replace header-based auth (not possible with browser WebSocket API) with a
first-message auth protocol: client sends {"type":"auth","token":"<JWT>"}
as the first text frame; server validates and proceeds or closes the connection.

Both GameWebSocketResource and UserWebSocketResource now hold incoming
connections in a pendingAuth set until the auth frame arrives, preventing
any game or event messages from being processed before identity is established.

Also removes the broken Bearer-prefix handling that caused header-based auth
to silently fail even for non-browser clients.

---------

Co-authored-by: LQ63 <lkhermann@web.de>
Reviewed-on: #73
Co-authored-by: Leon Hermann <lq@blackhole.local>
Co-committed-by: Leon Hermann <lq@blackhole.local>
This commit was merged in pull request #73.
This commit is contained in:
2026-06-17 10:42:52 +02:00
committed by Janis
parent 751a58b606
commit 343e2bdd10
2 changed files with 74 additions and 37 deletions
@@ -3,7 +3,6 @@ package de.nowchess.ws.resource
import de.nowchess.ws.config.RedisConfig import de.nowchess.ws.config.RedisConfig
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry} import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
import io.quarkus.redis.datasource.RedisDataSource import io.quarkus.redis.datasource.RedisDataSource
import io.quarkus.redis.datasource.pubsub.PubSubCommands
import io.quarkus.websockets.next.* import io.quarkus.websockets.next.*
import io.smallrye.jwt.auth.principal.JWTParser import io.smallrye.jwt.auth.principal.JWTParser
import jakarta.annotation.PostConstruct import jakarta.annotation.PostConstruct
@@ -34,6 +33,7 @@ class GameWebSocketResource:
// scalafix:on DisableSyntax.var // scalafix:on DisableSyntax.var
private val connections = new ConcurrentHashMap[String, ConnectionMeta]() private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
@PostConstruct @PostConstruct
def initializeMetrics(): Unit = { def initializeMetrics(): Unit = {
@@ -64,40 +64,60 @@ class GameWebSocketResource:
s"${redisConfig.prefix}:game:$gameId:c2s" s"${redisConfig.prefix}:game:$gameId:c2s"
@OnOpen @OnOpen
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit = def onOpen(connection: WebSocketConnection): Unit =
activeGauge activeGauge
val gameId = connection.pathParam("gameId") val gameId = connection.pathParam("gameId")
val playerId = resolvePlayerId(handshake)
log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous"))
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ()) val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler) val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId)) connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None))
connectionsOpened.increment() connectionsOpened.increment()
publishConnected(gameId, playerId) pendingAuth.add(connection.id())
log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id())
@OnTextMessage @OnTextMessage
def onTextMessage(connection: WebSocketConnection, message: String): Unit = def onTextMessage(connection: WebSocketConnection, message: String): Unit =
messagesReceived.increment() messagesReceived.increment()
Option(connections.get(connection.id())).foreach { meta => if pendingAuth.remove(connection.id()) then
val enriched = meta.playerId match val playerIdOpt =
case Some(pid) => injectPlayerId(message, pid) parseAuthToken(message)
case None => message .flatMap(token => Try(jwtParser.parse(token)).toOption)
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched) .map(_.getSubject)
} playerIdOpt match
case None =>
log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id())
connection.close().subscribe().`with`(_ => (), _ => ())
case Some(playerId) =>
Option(connections.get(connection.id())).foreach { meta =>
connections.put(connection.id(), meta.copy(playerId = Some(playerId)))
publishConnected(meta.gameId, Some(playerId))
}
else
Option(connections.get(connection.id())).foreach { meta =>
val enriched = meta.playerId match
case Some(pid) => injectPlayerId(message, pid)
case None => message
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
}
@OnClose @OnClose
def onClose(connection: WebSocketConnection): Unit = def onClose(connection: WebSocketConnection): Unit =
pendingAuth.remove(connection.id())
Option(connections.remove(connection.id())).foreach { meta => Option(connections.remove(connection.id())).foreach { meta =>
log.infof("Game WebSocket closed — gameId=%s", meta.gameId) log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
meta.subscriber.unsubscribe(s2cTopic(meta.gameId)) meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
connectionsClosed.increment() connectionsClosed.increment()
} }
private def resolvePlayerId(handshake: HandshakeRequest): Option[String] = private def parseAuthToken(message: String): Option[String] =
Option(handshake.header("Authorization")) val trimmed = message.trim
.filter(_.nonEmpty) if !trimmed.contains("\"type\":\"auth\"") then None
.flatMap(token => Try(jwtParser.parse(token)).toOption) else
.map(_.getSubject) val start = trimmed.indexOf("\"token\":\"")
if start < 0 then None
else
val valueStart = start + 9
val end = trimmed.indexOf('"', valueStart)
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
private def publishConnected(gameId: String, playerId: Option[String]): Unit = private def publishConnected(gameId: String, playerId: Option[String]): Unit =
val connectedMsg = playerId match val connectedMsg = playerId match
@@ -32,6 +32,7 @@ class UserWebSocketResource:
private val maxStreamLen = 1000L private val maxStreamLen = 1000L
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]() private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
private def userStreamKey(userId: String): String = private def userStreamKey(userId: String): String =
s"${redisConfig.prefix}:user:$userId:events:stream" s"${redisConfig.prefix}:user:$userId:events:stream"
@@ -39,29 +40,34 @@ class UserWebSocketResource:
private def dlqKey: String = s"${redisConfig.prefix}:dlq" private def dlqKey: String = s"${redisConfig.prefix}:dlq"
@OnOpen @OnOpen
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit = def onOpen(connection: WebSocketConnection): Unit =
val userIdOpt = Option(handshake.header("Authorization")) pendingAuth.add(connection.id())
.filter(_.nonEmpty)
.flatMap(token => Try(jwtParser.parse(token)).toOption)
.map(_.getSubject)
userIdOpt match @OnTextMessage
case None => def onTextMessage(connection: WebSocketConnection, message: String): Unit =
log.warn("WebSocket opened with no valid JWT — closing connection") if pendingAuth.remove(connection.id()) then
connection.close().subscribe().`with`(_ => (), _ => ()) val userIdOpt =
case Some(userId) => parseAuthToken(message)
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id()) .flatMap(token => Try(jwtParser.parse(token)).toOption)
createGroupIfAbsent(userId, connection.id()) .map(_.getSubject)
connections.put(connection.id(), (userId, connection)) userIdOpt match
executor.submit( case None =>
new Runnable: log.warn("WebSocket opened with no valid JWT — closing connection")
def run(): Unit = pollLoop(connection.id(), userId, connection), connection.close().subscribe().`with`(_ => (), _ => ())
) case Some(userId) =>
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}""" log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ()) createGroupIfAbsent(userId, connection.id())
connections.put(connection.id(), (userId, connection))
executor.submit(
new Runnable:
def run(): Unit = pollLoop(connection.id(), userId, connection),
)
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
@OnClose @OnClose
def onClose(connection: WebSocketConnection): Unit = def onClose(connection: WebSocketConnection): Unit =
pendingAuth.remove(connection.id())
log.infof("User WebSocket closed — connectionId=%s", connection.id()) log.infof("User WebSocket closed — connectionId=%s", connection.id())
val userIdOpt = Option(connections.remove(connection.id())).map(_._1) val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
userIdOpt.foreach { userId => userIdOpt.foreach { userId =>
@@ -128,3 +134,14 @@ class UserWebSocketResource:
) match ) match
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key) case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
case Success(_) => () case Success(_) => ()
private def parseAuthToken(message: String): Option[String] =
val trimmed = message.trim
if !trimmed.contains("\"type\":\"auth\"") then None
else
val start = trimmed.indexOf("\"token\":\"")
if start < 0 then None
else
val valueStart = start + 9
val end = trimmed.indexOf('"', valueStart)
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)