fix: NCS-122 authenticate WebSocket connections via first-message auth (#73)
Build & Test (NowChessSystems) TeamCity build failed
Build & Test (NowChessSystems) TeamCity build failed
Replace header-based auth (not possible with browser WebSocket API) with a
first-message auth protocol: client sends {"type":"auth","token":"<JWT>"}
as the first text frame; server validates and proceeds or closes the connection.
Both GameWebSocketResource and UserWebSocketResource now hold incoming
connections in a pendingAuth set until the auth frame arrives, preventing
any game or event messages from being processed before identity is established.
Also removes the broken Bearer-prefix handling that caused header-based auth
to silently fail even for non-browser clients.
---------
Co-authored-by: LQ63 <lkhermann@web.de>
Reviewed-on: #73
Co-authored-by: Leon Hermann <lq@blackhole.local>
Co-committed-by: Leon Hermann <lq@blackhole.local>
This commit was merged in pull request #73.
This commit is contained in:
@@ -3,7 +3,6 @@ package de.nowchess.ws.resource
|
|||||||
import de.nowchess.ws.config.RedisConfig
|
import de.nowchess.ws.config.RedisConfig
|
||||||
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
|
import io.micrometer.core.instrument.{Counter, Gauge, MeterRegistry}
|
||||||
import io.quarkus.redis.datasource.RedisDataSource
|
import io.quarkus.redis.datasource.RedisDataSource
|
||||||
import io.quarkus.redis.datasource.pubsub.PubSubCommands
|
|
||||||
import io.quarkus.websockets.next.*
|
import io.quarkus.websockets.next.*
|
||||||
import io.smallrye.jwt.auth.principal.JWTParser
|
import io.smallrye.jwt.auth.principal.JWTParser
|
||||||
import jakarta.annotation.PostConstruct
|
import jakarta.annotation.PostConstruct
|
||||||
@@ -34,6 +33,7 @@ class GameWebSocketResource:
|
|||||||
// scalafix:on DisableSyntax.var
|
// scalafix:on DisableSyntax.var
|
||||||
|
|
||||||
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
|
private val connections = new ConcurrentHashMap[String, ConnectionMeta]()
|
||||||
|
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||||
|
|
||||||
@PostConstruct
|
@PostConstruct
|
||||||
def initializeMetrics(): Unit = {
|
def initializeMetrics(): Unit = {
|
||||||
@@ -64,40 +64,60 @@ class GameWebSocketResource:
|
|||||||
s"${redisConfig.prefix}:game:$gameId:c2s"
|
s"${redisConfig.prefix}:game:$gameId:c2s"
|
||||||
|
|
||||||
@OnOpen
|
@OnOpen
|
||||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
def onOpen(connection: WebSocketConnection): Unit =
|
||||||
activeGauge
|
activeGauge
|
||||||
val gameId = connection.pathParam("gameId")
|
val gameId = connection.pathParam("gameId")
|
||||||
val playerId = resolvePlayerId(handshake)
|
|
||||||
log.infof("Game WebSocket opened — gameId=%s playerId=%s", gameId, playerId.getOrElse("anonymous"))
|
|
||||||
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
|
val handler: Consumer[String] = msg => connection.sendText(msg).subscribe().`with`(_ => (), _ => ())
|
||||||
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
|
val subscriber = redis.pubsub(classOf[String]).subscribe(s2cTopic(gameId), handler)
|
||||||
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, playerId))
|
connections.put(connection.id(), ConnectionMeta(gameId, subscriber, None))
|
||||||
connectionsOpened.increment()
|
connectionsOpened.increment()
|
||||||
publishConnected(gameId, playerId)
|
pendingAuth.add(connection.id())
|
||||||
|
log.infof("Game WebSocket opened — gameId=%s connId=%s awaiting auth", gameId, connection.id())
|
||||||
|
|
||||||
@OnTextMessage
|
@OnTextMessage
|
||||||
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||||
messagesReceived.increment()
|
messagesReceived.increment()
|
||||||
Option(connections.get(connection.id())).foreach { meta =>
|
if pendingAuth.remove(connection.id()) then
|
||||||
val enriched = meta.playerId match
|
val playerIdOpt =
|
||||||
case Some(pid) => injectPlayerId(message, pid)
|
parseAuthToken(message)
|
||||||
case None => message
|
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||||
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
.map(_.getSubject)
|
||||||
}
|
playerIdOpt match
|
||||||
|
case None =>
|
||||||
|
log.warnf("Game WebSocket auth failed — closing connId=%s", connection.id())
|
||||||
|
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||||
|
case Some(playerId) =>
|
||||||
|
Option(connections.get(connection.id())).foreach { meta =>
|
||||||
|
connections.put(connection.id(), meta.copy(playerId = Some(playerId)))
|
||||||
|
publishConnected(meta.gameId, Some(playerId))
|
||||||
|
}
|
||||||
|
else
|
||||||
|
Option(connections.get(connection.id())).foreach { meta =>
|
||||||
|
val enriched = meta.playerId match
|
||||||
|
case Some(pid) => injectPlayerId(message, pid)
|
||||||
|
case None => message
|
||||||
|
redis.pubsub(classOf[String]).publish(c2sTopic(meta.gameId), enriched)
|
||||||
|
}
|
||||||
|
|
||||||
@OnClose
|
@OnClose
|
||||||
def onClose(connection: WebSocketConnection): Unit =
|
def onClose(connection: WebSocketConnection): Unit =
|
||||||
|
pendingAuth.remove(connection.id())
|
||||||
Option(connections.remove(connection.id())).foreach { meta =>
|
Option(connections.remove(connection.id())).foreach { meta =>
|
||||||
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
|
log.infof("Game WebSocket closed — gameId=%s", meta.gameId)
|
||||||
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
|
meta.subscriber.unsubscribe(s2cTopic(meta.gameId))
|
||||||
connectionsClosed.increment()
|
connectionsClosed.increment()
|
||||||
}
|
}
|
||||||
|
|
||||||
private def resolvePlayerId(handshake: HandshakeRequest): Option[String] =
|
private def parseAuthToken(message: String): Option[String] =
|
||||||
Option(handshake.header("Authorization"))
|
val trimmed = message.trim
|
||||||
.filter(_.nonEmpty)
|
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
else
|
||||||
.map(_.getSubject)
|
val start = trimmed.indexOf("\"token\":\"")
|
||||||
|
if start < 0 then None
|
||||||
|
else
|
||||||
|
val valueStart = start + 9
|
||||||
|
val end = trimmed.indexOf('"', valueStart)
|
||||||
|
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||||
|
|
||||||
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
|
private def publishConnected(gameId: String, playerId: Option[String]): Unit =
|
||||||
val connectedMsg = playerId match
|
val connectedMsg = playerId match
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class UserWebSocketResource:
|
|||||||
private val maxStreamLen = 1000L
|
private val maxStreamLen = 1000L
|
||||||
|
|
||||||
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
|
private val connections = new ConcurrentHashMap[String, (String, WebSocketConnection)]()
|
||||||
|
private val pendingAuth = ConcurrentHashMap.newKeySet[String]()
|
||||||
|
|
||||||
private def userStreamKey(userId: String): String =
|
private def userStreamKey(userId: String): String =
|
||||||
s"${redisConfig.prefix}:user:$userId:events:stream"
|
s"${redisConfig.prefix}:user:$userId:events:stream"
|
||||||
@@ -39,29 +40,34 @@ class UserWebSocketResource:
|
|||||||
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
|
private def dlqKey: String = s"${redisConfig.prefix}:dlq"
|
||||||
|
|
||||||
@OnOpen
|
@OnOpen
|
||||||
def onOpen(connection: WebSocketConnection, handshake: HandshakeRequest): Unit =
|
def onOpen(connection: WebSocketConnection): Unit =
|
||||||
val userIdOpt = Option(handshake.header("Authorization"))
|
pendingAuth.add(connection.id())
|
||||||
.filter(_.nonEmpty)
|
|
||||||
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
|
||||||
.map(_.getSubject)
|
|
||||||
|
|
||||||
userIdOpt match
|
@OnTextMessage
|
||||||
case None =>
|
def onTextMessage(connection: WebSocketConnection, message: String): Unit =
|
||||||
log.warn("WebSocket opened with no valid JWT — closing connection")
|
if pendingAuth.remove(connection.id()) then
|
||||||
connection.close().subscribe().`with`(_ => (), _ => ())
|
val userIdOpt =
|
||||||
case Some(userId) =>
|
parseAuthToken(message)
|
||||||
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
.flatMap(token => Try(jwtParser.parse(token)).toOption)
|
||||||
createGroupIfAbsent(userId, connection.id())
|
.map(_.getSubject)
|
||||||
connections.put(connection.id(), (userId, connection))
|
userIdOpt match
|
||||||
executor.submit(
|
case None =>
|
||||||
new Runnable:
|
log.warn("WebSocket opened with no valid JWT — closing connection")
|
||||||
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
connection.close().subscribe().`with`(_ => (), _ => ())
|
||||||
)
|
case Some(userId) =>
|
||||||
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
log.infof("User WebSocket opened — userId=%s connId=%s", userId, connection.id())
|
||||||
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
createGroupIfAbsent(userId, connection.id())
|
||||||
|
connections.put(connection.id(), (userId, connection))
|
||||||
|
executor.submit(
|
||||||
|
new Runnable:
|
||||||
|
def run(): Unit = pollLoop(connection.id(), userId, connection),
|
||||||
|
)
|
||||||
|
val connectedMsg = s"""{"type":"CONNECTED","userId":"$userId"}"""
|
||||||
|
connection.sendText(connectedMsg).subscribe().`with`(_ => (), _ => ())
|
||||||
|
|
||||||
@OnClose
|
@OnClose
|
||||||
def onClose(connection: WebSocketConnection): Unit =
|
def onClose(connection: WebSocketConnection): Unit =
|
||||||
|
pendingAuth.remove(connection.id())
|
||||||
log.infof("User WebSocket closed — connectionId=%s", connection.id())
|
log.infof("User WebSocket closed — connectionId=%s", connection.id())
|
||||||
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
|
val userIdOpt = Option(connections.remove(connection.id())).map(_._1)
|
||||||
userIdOpt.foreach { userId =>
|
userIdOpt.foreach { userId =>
|
||||||
@@ -128,3 +134,14 @@ class UserWebSocketResource:
|
|||||||
) match
|
) match
|
||||||
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
|
case Failure(ex) => log.warnf(ex, "Failed to publish to stream %s", key)
|
||||||
case Success(_) => ()
|
case Success(_) => ()
|
||||||
|
|
||||||
|
private def parseAuthToken(message: String): Option[String] =
|
||||||
|
val trimmed = message.trim
|
||||||
|
if !trimmed.contains("\"type\":\"auth\"") then None
|
||||||
|
else
|
||||||
|
val start = trimmed.indexOf("\"token\":\"")
|
||||||
|
if start < 0 then None
|
||||||
|
else
|
||||||
|
val valueStart = start + 9
|
||||||
|
val end = trimmed.indexOf('"', valueStart)
|
||||||
|
if end < 0 then None else Some(trimmed.substring(valueStart, end)).filter(_.nonEmpty)
|
||||||
|
|||||||
Reference in New Issue
Block a user