fix: NCS-122 authenticate WebSocket connections via first-message auth (#73)
Build & Test (NowChessSystems) TeamCity build failed
Build & Test (NowChessSystems) TeamCity build failed
Replace header-based auth (not possible with browser WebSocket API) with a
first-message auth protocol: client sends {"type":"auth","token":"<JWT>"}
as the first text frame; server validates and proceeds or closes the connection.
Both GameWebSocketResource and UserWebSocketResource now hold incoming
connections in a pendingAuth set until the auth frame arrives, preventing
any game or event messages from being processed before identity is established.
Also removes the broken Bearer-prefix handling that caused header-based auth
to silently fail even for non-browser clients.
---------
Co-authored-by: LQ63 <lkhermann@web.de>
Reviewed-on: #73
Co-authored-by: Leon Hermann <lq@blackhole.local>
Co-committed-by: Leon Hermann <lq@blackhole.local>
This commit was merged in pull request #73.
This commit is contained in:
@@ -3,7 +3,6 @@ package de.nowchess.ws.resource
|
||||
import de.nowchess.ws.config.RedisConfig
|
||||
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
|
||||
import io.quarkus.redis.datasource.RedisDataSource
|
||||
import io.quarkus.redis.datasource.pubsub.PubSubCommands
|
||||
import io.quarkus.websockets.next.*
|
||||
import io.smallrye.jwt.auth.principal.JWTParser
|
||||
import jakarta.annotation.PostConstruct
|
||||
@@ -34,6 +33,7 @@ class GameWebSocketResource:
|
||||
// scalafix:on DisableSyntax.var
|
||||
|
||||
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
|
||||
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||
|
||||
@PostConstruct
|
||||
def initializeMetrics(): Unit = {
|
||||
@@ -64,40 +64,60 @@ class GameWebSocketResource:
|
||||
s"${redisConfig.prefix}:game:$gameId:c2s"
|
||||
|
||||
@OnOpen
|
||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
||||
def onOpen(connection: WebSocketConnection): Unit =
|
||||
activeGauge
|
||||
val gameId = connection.pathParam("gameId")
|
||||
val playerId = resolvePlayerId(handshake)
|
||||
log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous"))
|
||||
val gameId = connection.pathParam("gameId")
|
||||
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
|
||||
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
|
||||
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId))
|
||||
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None))
|
||||
connectionsOpened.increment()
|
||||
publishConnected(gameId, playerId)
|
||||
pendingAuth.add(connection.id())
|
||||
log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id())
|
||||
|
||||
@OnTextMessage
|
||||
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||
messagesReceived.increment()
|
||||
Option(connections.get(connection.id())).foreach { meta =>
|
||||
val enriched = meta.playerId match
|
||||
case Some(pid) => injectPlayerId(message, pid)
|
||||
case None => message
|
||||
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
||||
}
|
||||
if pendingAuth.remove(connection.id()) then
|
||||
val playerIdOpt =
|
||||
parseAuthToken(message)
|
||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||
.map(_.getSubject)
|
||||
playerIdOpt match
|
||||
case None =>
|
||||
log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id())
|
||||
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||
case Some(playerId) =>
|
||||
Option(connections.get(connection.id())).foreach { meta =>
|
||||
connections.put(connection.id(), meta.copy(playerId = Some(playerId)))
|
||||
publishConnected(meta.gameId, Some(playerId))
|
||||
}
|
||||
else
|
||||
Option(connections.get(connection.id())).foreach { meta =>
|
||||
val enriched = meta.playerId match
|
||||
case Some(pid) => injectPlayerId(message, pid)
|
||||
case None => message
|
||||
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
||||
}
|
||||
|
||||
@OnClose
|
||||
def onClose(connection: WebSocketConnection): Unit =
|
||||
pendingAuth.remove(connection.id())
|
||||
Option(connections.remove(connection.id())).foreach { meta =>
|
||||
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
|
||||
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
|
||||
connectionsClosed.increment()
|
||||
}
|
||||
|
||||
private def resolvePlayerId(handshake: HandshakeRequest): Option[String] =
|
||||
Option(handshake.header("Authorization"))
|
||||
.filter(_.nonEmpty)
|
||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||
.map(_.getSubject)
|
||||
private def parseAuthToken(message: String): Option[String] =
|
||||
val trimmed = message.trim
|
||||
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||
else
|
||||
val start = trimmed.indexOf("\"token\":\"")
|
||||
if start < 0 then None
|
||||
else
|
||||
val valueStart = start + 9
|
||||
val end = trimmed.indexOf('"', valueStart)
|
||||
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||
|
||||
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
|
||||
val connectedMsg = playerId match
|
||||
|
||||
@@ -32,6 +32,7 @@ class UserWebSocketResource:
|
||||
private val maxStreamLen = 1000L
|
||||
|
||||
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
|
||||
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||
|
||||
private def userStreamKey(userId: String): String =
|
||||
s"${redisConfig.prefix}:user:$userId:events:stream"
|
||||
@@ -39,29 +40,34 @@ class UserWebSocketResource:
|
||||
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
|
||||
|
||||
@OnOpen
|
||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
||||
val userIdOpt = Option(handshake.header("Authorization"))
|
||||
.filter(_.nonEmpty)
|
||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||
.map(_.getSubject)
|
||||
def onOpen(connection: WebSocketConnection): Unit =
|
||||
pendingAuth.add(connection.id())
|
||||
|
||||
userIdOpt match
|
||||
case None =>
|
||||
log.warn("WebSocket opened with no valid JWT — closing connection")
|
||||
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||
case Some(userId) =>
|
||||
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
||||
createGroupIfAbsent(userId, connection.id())
|
||||
connections.put(connection.id(), (userId, connection))
|
||||
executor.submit(
|
||||
new Runnable:
|
||||
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
||||
)
|
||||
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
||||
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
||||
@OnTextMessage
|
||||
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||
if pendingAuth.remove(connection.id()) then
|
||||
val userIdOpt =
|
||||
parseAuthToken(message)
|
||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||
.map(_.getSubject)
|
||||
userIdOpt match
|
||||
case None =>
|
||||
log.warn("WebSocket opened with no valid JWT — closing connection")
|
||||
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||
case Some(userId) =>
|
||||
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
||||
createGroupIfAbsent(userId, connection.id())
|
||||
connections.put(connection.id(), (userId, connection))
|
||||
executor.submit(
|
||||
new Runnable:
|
||||
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
||||
)
|
||||
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
||||
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
||||
|
||||
@OnClose
|
||||
def onClose(connection: WebSocketConnection): Unit =
|
||||
pendingAuth.remove(connection.id())
|
||||
log.infof("User WebSocket closed — connectionId=%s", connection.id())
|
||||
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
|
||||
userIdOpt.foreach { userId =>
|
||||
@@ -128,3 +134,14 @@ class UserWebSocketResource:
|
||||
) match
|
||||
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
|
||||
case Success(_) => ()
|
||||
|
||||
private def parseAuthToken(message: String): Option[String] =
|
||||
val trimmed = message.trim
|
||||
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||
else
|
||||
val start = trimmed.indexOf("\"token\":\"")
|
||||
if start < 0 then None
|
||||
else
|
||||
val valueStart = start + 9
|
||||
val end = trimmed.indexOf('"', valueStart)
|
||||
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||
|
||||
Reference in New Issue
Block a user