fix: NCS-122 authenticate WebSocket connections via first-message auth (#73)
Build & Test (NowChessSystems) TeamCity build failed

Replace header-based auth (not possible with browser WebSocket API) with a
first-message auth protocol: client sends {"type":"auth","token":"<JWT>"}
as the first text frame; server validates and proceeds or closes the connection.

Both GameWebSocketResource and UserWebSocketResource now hold incoming
connections in a pendingAuth set until the auth frame arrives, preventing
any game or event messages from being processed before identity is established.

Also removes the broken Bearer-prefix handling that caused header-based auth
to silently fail even for non-browser clients.

---------

Co-authored-by: LQ63 <lkhermann@web.de>
Reviewed-on: #73
Co-authored-by: Leon Hermann <lq@blackhole.local>
Co-committed-by: Leon Hermann <lq@blackhole.local>
This commit was merged in pull request #73.
This commit is contained in:
2026-06-17 10:42:52 +02:00
committed by Janis
parent 751a58b606
commit 343e2bdd10
2 changed files with 74 additions and 37 deletions
@@ -3,7 +3,6 @@ package de.nowchess.ws.resource
import de.nowchess.ws.config.RedisConfig
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
import io.quarkus.redis.datasource.RedisDataSource
import io.quarkus.redis.datasource.pubsub.PubSubCommands
import io.quarkus.websockets.next.*
import io.smallrye.jwt.auth.principal.JWTParser
import jakarta.annotation.PostConstruct
@@ -34,6 +33,7 @@ class GameWebSocketResource:
// scalafix:on DisableSyntax.var
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
@PostConstruct
def initializeMetrics(): Unit = {
@@ -64,40 +64,60 @@ class GameWebSocketResource:
s"${redisConfig.prefix}:game:$gameId:c2s"
@OnOpen
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
def onOpen(connection: WebSocketConnection): Unit =
activeGauge
val gameId = connection.pathParam("gameId")
val playerId = resolvePlayerId(handshake)
log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous"))
val gameId = connection.pathParam("gameId")
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId))
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None))
connectionsOpened.increment()
publishConnected(gameId, playerId)
pendingAuth.add(connection.id())
log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id())
@OnTextMessage
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
messagesReceived.increment()
Option(connections.get(connection.id())).foreach { meta =>
val enriched = meta.playerId match
case Some(pid) => injectPlayerId(message, pid)
case None => message
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
}
if pendingAuth.remove(connection.id()) then
val playerIdOpt =
parseAuthToken(message)
.flatMap(token => Try(jwtParser.parse(token)).toOption)
.map(_.getSubject)
playerIdOpt match
case None =>
log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id())
connection.close().subscribe().`with`(_ => (), _ => ())
case Some(playerId) =>
Option(connections.get(connection.id())).foreach { meta =>
connections.put(connection.id(), meta.copy(playerId = Some(playerId)))
publishConnected(meta.gameId, Some(playerId))
}
else
Option(connections.get(connection.id())).foreach { meta =>
val enriched = meta.playerId match
case Some(pid) => injectPlayerId(message, pid)
case None => message
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
}
@OnClose
def onClose(connection: WebSocketConnection): Unit =
pendingAuth.remove(connection.id())
Option(connections.remove(connection.id())).foreach { meta =>
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
connectionsClosed.increment()
}
private def resolvePlayerId(handshake: HandshakeRequest): Option[String] =
Option(handshake.header("Authorization"))
.filter(_.nonEmpty)
.flatMap(token => Try(jwtParser.parse(token)).toOption)
.map(_.getSubject)
private def parseAuthToken(message: String): Option[String] =
val trimmed = message.trim
if !trimmed.contains("\"type\":\"auth\"") then None
else
val start = trimmed.indexOf("\"token\":\"")
if start < 0 then None
else
val valueStart = start + 9
val end = trimmed.indexOf('"', valueStart)
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
val connectedMsg = playerId match
@@ -32,6 +32,7 @@ class UserWebSocketResource:
private val maxStreamLen = 1000L
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
private def userStreamKey(userId: String): String =
s"${redisConfig.prefix}:user:$userId:events:stream"
@@ -39,29 +40,34 @@ class UserWebSocketResource:
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
@OnOpen
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
val userIdOpt = Option(handshake.header("Authorization"))
.filter(_.nonEmpty)
.flatMap(token => Try(jwtParser.parse(token)).toOption)
.map(_.getSubject)
def onOpen(connection: WebSocketConnection): Unit =
pendingAuth.add(connection.id())
userIdOpt match
case None =>
log.warn("WebSocket opened with no valid JWT — closing connection")
connection.close().subscribe().`with`(_ => (), _ => ())
case Some(userId) =>
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
createGroupIfAbsent(userId, connection.id())
connections.put(connection.id(), (userId, connection))
executor.submit(
new Runnable:
def run(): Unit = pollLoop(connection.id(), userId, connection),
)
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
@OnTextMessage
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
if pendingAuth.remove(connection.id()) then
val userIdOpt =
parseAuthToken(message)
.flatMap(token => Try(jwtParser.parse(token)).toOption)
.map(_.getSubject)
userIdOpt match
case None =>
log.warn("WebSocket opened with no valid JWT — closing connection")
connection.close().subscribe().`with`(_ => (), _ => ())
case Some(userId) =>
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
createGroupIfAbsent(userId, connection.id())
connections.put(connection.id(), (userId, connection))
executor.submit(
new Runnable:
def run(): Unit = pollLoop(connection.id(), userId, connection),
)
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
@OnClose
def onClose(connection: WebSocketConnection): Unit =
pendingAuth.remove(connection.id())
log.infof("User WebSocket closed — connectionId=%s", connection.id())
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
userIdOpt.foreach { userId =>
@@ -128,3 +134,14 @@ class UserWebSocketResource:
) match
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
case Success(_) => ()
private def parseAuthToken(message: String): Option[String] =
val trimmed = message.trim
if !trimmed.contains("\"type\":\"auth\"") then None
else
val start = trimmed.indexOf("\"token\":\"")
if start < 0 then None
else
val valueStart = start + 9
val end = trimmed.indexOf('"', valueStart)
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)