From 343e2bdd100649a96d96da8a6d98caad6de4ad14 Mon Sep 17 00:00:00 2001 From: Leon Hermann Date: Wed, 17 Jun 2026 10:42:52 +0200 Subject: [PATCH] fix: NCS-122 authenticate WebSocket connections via first-message auth (#73) Replace header-based auth (not possible with browser WebSocket API) with a first-message auth protocol: client sends {"type":"auth","token":""} as the first text frame; server validates and proceeds or closes the connection. Both GameWebSocketResource and UserWebSocketResource now hold incoming connections in a pendingAuth set until the auth frame arrives, preventing any game or event messages from being processed before identity is established. Also removes the broken Bearer-prefix handling that caused header-based auth to silently fail even for non-browser clients. --------- Co-authored-by: LQ63 Reviewed-on: https://git.janis-eccarius.de/NowChess/NowChessSystems/pulls/73 Co-authored-by: Leon Hermann Co-committed-by: Leon Hermann --- .../ws/resource/GameWebSocketResource.scala | 56 +++++++++++++------ .../ws/resource/UserWebSocketResource.scala | 55 +++++++++++------- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/modules/ws/src/main/scala/de/nowchess/ws/resource/GameWebSocketResource.scala b/modules/ws/src/main/scala/de/nowchess/ws/resource/GameWebSocketResource.scala index 30f4c26..8dbc7b0 100644 --- a/modules/ws/src/main/scala/de/nowchess/ws/resource/GameWebSocketResource.scala +++ b/modules/ws/src/main/scala/de/nowchess/ws/resource/GameWebSocketResource.scala @@ -3,7 +3,6 @@ package de.nowchess.ws.resource import de.nowchess.ws.config.RedisConfig import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry} import io.quarkus.redis.datasource.RedisDataSource -import io.quarkus.redis.datasource.pubsub.PubSubCommands import io.quarkus.websockets.next.* import io.smallrye.jwt.auth.principal.JWTParser import jakarta.annotation.PostConstruct @@ -34,6 +33,7 @@ class GameWebSocketResource: // scalafix:on DisableSyntax.var private val connections = new ConcurrentHashMap[String, ConnectionMeta]() + private val pendingAuth = ConcurrentHashMap.newKeySet[String]() @PostConstruct def initializeMetrics(): Unit = { @@ -64,40 +64,60 @@ class GameWebSocketResource: s"${redisConfig.prefix}:game:$gameId:c2s" @OnOpen - def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit = + def onOpen(connection: WebSocketConnection): Unit = activeGauge - val gameId = connection.pathParam("gameId") - val playerId = resolvePlayerId(handshake) - log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous")) + val gameId = connection.pathParam("gameId") val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ()) val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler) - connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId)) + connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None)) connectionsOpened.increment() - publishConnected(gameId, playerId) + pendingAuth.add(connection.id()) + log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id()) @OnTextMessage def onTextMessage(connection: WebSocketConnection, message: String): Unit = messagesReceived.increment() - Option(connections.get(connection.id())).foreach { meta => - val enriched = meta.playerId match - case Some(pid) => injectPlayerId(message, pid) - case None => message - redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched) - } + if pendingAuth.remove(connection.id()) then + val playerIdOpt = + parseAuthToken(message) + .flatMap(token => Try(jwtParser.parse(token)).toOption) + .map(_.getSubject) + playerIdOpt match + case None => + log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id()) + connection.close().subscribe().`with`(_ => (), _ => ()) + case Some(playerId) => + Option(connections.get(connection.id())).foreach { meta => + connections.put(connection.id(), meta.copy(playerId = Some(playerId))) + publishConnected(meta.gameId, Some(playerId)) + } + else + Option(connections.get(connection.id())).foreach { meta => + val enriched = meta.playerId match + case Some(pid) => injectPlayerId(message, pid) + case None => message + redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched) + } @OnClose def onClose(connection: WebSocketConnection): Unit = + pendingAuth.remove(connection.id()) Option(connections.remove(connection.id())).foreach { meta => log.infof("Game WebSocket closed — gameId=%s", meta.gameId) meta.subscriber.unsubscribe(s2cTopic(meta.gameId)) connectionsClosed.increment() } - private def resolvePlayerId(handshake: HandshakeRequest): Option[String] = - Option(handshake.header("Authorization")) - .filter(_.nonEmpty) - .flatMap(token => Try(jwtParser.parse(token)).toOption) - .map(_.getSubject) + private def parseAuthToken(message: String): Option[String] = + val trimmed = message.trim + if !trimmed.contains("\"type\":\"auth\"") then None + else + val start = trimmed.indexOf("\"token\":\"") + if start < 0 then None + else + val valueStart = start + 9 + val end = trimmed.indexOf('"', valueStart) + if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty) private def publishConnected(gameId: String, playerId: Option[String]): Unit = val connectedMsg = playerId match diff --git a/modules/ws/src/main/scala/de/nowchess/ws/resource/UserWebSocketResource.scala b/modules/ws/src/main/scala/de/nowchess/ws/resource/UserWebSocketResource.scala index 52da4e9..b19fc93 100644 --- a/modules/ws/src/main/scala/de/nowchess/ws/resource/UserWebSocketResource.scala +++ b/modules/ws/src/main/scala/de/nowchess/ws/resource/UserWebSocketResource.scala @@ -32,6 +32,7 @@ class UserWebSocketResource: private val maxStreamLen = 1000L private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]() + private val pendingAuth = ConcurrentHashMap.newKeySet[String]() private def userStreamKey(userId: String): String = s"${redisConfig.prefix}:user:$userId:events:stream" @@ -39,29 +40,34 @@ class UserWebSocketResource: private def dlqKey: String = s"${redisConfig.prefix}:dlq" @OnOpen - def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit = - val userIdOpt = Option(handshake.header("Authorization")) - .filter(_.nonEmpty) - .flatMap(token => Try(jwtParser.parse(token)).toOption) - .map(_.getSubject) + def onOpen(connection: WebSocketConnection): Unit = + pendingAuth.add(connection.id()) - userIdOpt match - case None => - log.warn("WebSocket opened with no valid JWT — closing connection") - connection.close().subscribe().`with`(_ => (), _ => ()) - case Some(userId) => - log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id()) - createGroupIfAbsent(userId, connection.id()) - connections.put(connection.id(), (userId, connection)) - executor.submit( - new Runnable: - def run(): Unit = pollLoop(connection.id(), userId, connection), - ) - val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}""" - connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ()) + @OnTextMessage + def onTextMessage(connection: WebSocketConnection, message: String): Unit = + if pendingAuth.remove(connection.id()) then + val userIdOpt = + parseAuthToken(message) + .flatMap(token => Try(jwtParser.parse(token)).toOption) + .map(_.getSubject) + userIdOpt match + case None => + log.warn("WebSocket opened with no valid JWT — closing connection") + connection.close().subscribe().`with`(_ => (), _ => ()) + case Some(userId) => + log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id()) + createGroupIfAbsent(userId, connection.id()) + connections.put(connection.id(), (userId, connection)) + executor.submit( + new Runnable: + def run(): Unit = pollLoop(connection.id(), userId, connection), + ) + val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}""" + connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ()) @OnClose def onClose(connection: WebSocketConnection): Unit = + pendingAuth.remove(connection.id()) log.infof("User WebSocket closed — connectionId=%s", connection.id()) val userIdOpt = Option(connections.remove(connection.id())).map(_._1) userIdOpt.foreach { userId => @@ -128,3 +134,14 @@ class UserWebSocketResource: ) match case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key) case Success(_) => () + + private def parseAuthToken(message: String): Option[String] = + val trimmed = message.trim + if !trimmed.contains("\"type\":\"auth\"") then None + else + val start = trimmed.indexOf("\"token\":\"") + if start < 0 then None + else + val valueStart = start + 9 + val end = trimmed.indexOf('"', valueStart) + if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)