package services

import com.typesafe.config.Config
import play.api.libs.ws.WSClient
import play.api.Configuration
import play.api.Logging
import play.api.libs.json.*
import java.net.URI
import java.nio.charset.StandardCharsets
import java.security.MessageDigest
import javax.inject.*
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import com.nimbusds.oauth2.sdk.*
import com.nimbusds.oauth2.sdk.id.*
import com.nimbusds.openid.connect.sdk.*
import play.api.libs.ws.DefaultBodyWritables.*

/** Normalised user profile returned by [[OpenIDConnectService.getUserInfo]].
  *
  * @param id           provider-side user identifier
  * @param email        e-mail address, when the provider returned one
  * @param name         display name, when the provider returned one
  * @param picture      avatar URL, when the provider returned one
  * @param provider     key of the provider in the service's provider map (e.g. "discord")
  * @param providerName human-readable provider name (e.g. "Discord")
  */
final case class OpenIDUserInfo(
  id: String,
  email: Option[String],
  name: Option[String],
  picture: Option[String],
  provider: String,
  providerName: String
)

object OpenIDUserInfo {
  implicit val writes: Writes[OpenIDUserInfo] = Json.writes[OpenIDUserInfo]
  implicit val reads: Reads[OpenIDUserInfo] = Json.reads[OpenIDUserInfo]
}

/** Static configuration of one OAuth2 / OpenID Connect provider.
  *
  * A provider is treated as an OpenID Connect provider when `scopes` contains "openid";
  * otherwise a plain OAuth2 authorization request is built for it.
  */
final case class OpenIDProvider(
  name: String,
  clientId: String,
  clientSecret: String,
  redirectUri: String,
  authorizationEndpoint: String,
  tokenEndpoint: String,
  userInfoEndpoint: String,
  scopes: Set[String] = Set("openid", "profile", "email")
)

/** Subset of the token endpoint response used by this application. */
final case class TokenResponse(
  accessToken: String,
  tokenType: String,
  expiresIn: Option[Int],
  refreshToken: Option[String],
  idToken: Option[String]
)

/** Drives the authorization-code flow against the configured providers ("discord", "keycloak").
  *
  * All provider settings are read from configuration keys under `openid.<provider>.*` when the
  * service is constructed, so a missing key fails at construction time.
  */
@Singleton
class OpenIDConnectService @Inject() (ws: WSClient, config: Configuration)(implicit ec: ExecutionContext)
    extends Logging {

  private val providers: Map[String, OpenIDProvider] = {
    val discord = OpenIDProvider(
      name = "Discord",
      clientId = config.get[String]("openid.discord.clientId"),
      clientSecret = config.get[String]("openid.discord.clientSecret"),
      redirectUri = config.get[String]("openid.discord.redirectUri"),
      authorizationEndpoint = "https://discord.com/oauth2/authorize",
      tokenEndpoint = "https://discord.com/api/oauth2/token",
      userInfoEndpoint = "https://discord.com/api/users/@me",
      scopes = Set("identify", "email")
    )

    // All Keycloak endpoints hang off the same configured base URL; read it once.
    val keycloakClientId = config.get[String]("openid.keycloak.clientId")
    val keycloakClientSecret = config.get[String]("openid.keycloak.clientSecret")
    val keycloakRedirectUri = config.get[String]("openid.keycloak.redirectUri")
    val keycloakBase = config.get[String]("openid.keycloak.authUrl")
    val keycloak = OpenIDProvider(
      name = "Identity",
      clientId = keycloakClientId,
      clientSecret = keycloakClientSecret,
      redirectUri = keycloakRedirectUri,
      authorizationEndpoint = s"$keycloakBase/protocol/openid-connect/auth",
      tokenEndpoint = s"$keycloakBase/protocol/openid-connect/token",
      userInfoEndpoint = s"$keycloakBase/protocol/openid-connect/userinfo",
      scopes = Set("openid", "profile", "email")
    )

    Map("discord" -> discord, "keycloak" -> keycloak)
  }

  /** Builds the URL the browser must be redirected to in order to start the login.
    *
    * @param providerName key of the provider ("discord" or "keycloak")
    * @param state        opaque CSRF token that the provider echoes back on the callback
    * @param nonce        nonce sent with OpenID Connect requests; unused for plain OAuth2 providers
    * @return the authorization URL, or `None` when `providerName` is not configured
    */
  def getAuthorizationUrl(providerName: String, state: String, nonce: String): Option[String] =
    providers.get(providerName).map { provider =>
      // Scope.parse splits the space-delimited string into individual scope values. The vararg
      // constructor would instead treat the joined string as ONE value, so "openid" would not
      // be present as a scope value of its own.
      val scope = com.nimbusds.oauth2.sdk.Scope.parse(provider.scopes.mkString(" "))
      val responseType = new ResponseType(ResponseType.Value.CODE)
      val clientId = new com.nimbusds.oauth2.sdk.id.ClientID(provider.clientId)
      val csrfState = new com.nimbusds.oauth2.sdk.id.State(state)
      val redirectUri = URI.create(provider.redirectUri)
      val endpoint = URI.create(provider.authorizationEndpoint)

      val authRequest: AuthorizationRequest =
        if (provider.scopes.contains("openid")) {
          // OpenID Connect provider: authentication request carrying a nonce.
          new AuthenticationRequest.Builder(responseType, scope, clientId, redirectUri)
            .state(csrfState)
            .nonce(new Nonce(nonce))
            .endpointURI(endpoint)
            .build()
        } else {
          // Plain OAuth2 provider (e.g. Discord): standard authorization request, no nonce.
          new AuthorizationRequest.Builder(responseType, clientId)
            .scope(scope)
            .state(csrfState)
            .redirectionURI(redirectUri)
            .endpointURI(endpoint)
            .build()
        }

      authRequest.toURI.toString
    }

  /** Exchanges an authorization code for tokens at the provider's token endpoint.
    *
    * @param providerName key of the provider
    * @param code         authorization code received on the callback
    * @param state        state received on the callback; not used here, kept for API compatibility
    * @return the parsed tokens, or `None` when the provider is unknown, the endpoint does not
    *         answer with HTTP 200, the response lacks required fields, or the request fails
    */
  def exchangeCodeForTokens(providerName: String, code: String, state: String): Future[Option[TokenResponse]] =
    providers.get(providerName) match {
      case None =>
        Future.successful(None)

      case Some(provider) =>
        val form: Map[String, Seq[String]] = Map(
          "client_id" -> Seq(provider.clientId),
          "client_secret" -> Seq(provider.clientSecret),
          "code" -> Seq(code),
          "grant_type" -> Seq("authorization_code"),
          "redirect_uri" -> Seq(provider.redirectUri)
        )

        ws.url(provider.tokenEndpoint)
          .withHttpHeaders(
            "Accept" -> "application/json",
            "Content-Type" -> "application/x-www-form-urlencoded"
          )
          .post(form)
          .map { response =>
            if (response.status == 200) {
              val tokens = parseTokenResponse(response.json)
              if (tokens.isEmpty) {
                logger.warn(s"Token response from provider '$providerName' is missing required fields")
              }
              tokens
            } else {
              // Only the status is logged: the body is not needed to diagnose and is kept out of logs.
              logger.warn(s"Token exchange with provider '$providerName' failed with HTTP ${response.status}")
              None
            }
          }
          .recover { case NonFatal(e) =>
            logger.warn(s"Token exchange with provider '$providerName' failed", e)
            None
          }
    }

  /** Fetches the profile of the user the access token belongs to.
    *
    * @param providerName key of the provider
    * @param accessToken  bearer token obtained from [[exchangeCodeForTokens]]
    * @return the normalised profile, or `None` when the provider is unknown, the endpoint does not
    *         answer with HTTP 200, the response carries no user identifier, or the request fails
    */
  def getUserInfo(providerName: String, accessToken: String): Future[Option[OpenIDUserInfo]] =
    providers.get(providerName) match {
      case None =>
        Future.successful(None)

      case Some(provider) =>
        ws.url(provider.userInfoEndpoint)
          .withHttpHeaders("Authorization" -> s"Bearer $accessToken")
          .get()
          .map { response =>
            if (response.status == 200) {
              val profile = parseUserInfo(providerName, provider, response.json)
              if (profile.isEmpty) {
                logger.warn(s"User info from provider '$providerName' has no 'id' or 'sub' field")
              }
              profile
            } else {
              logger.warn(s"User info request to provider '$providerName' failed with HTTP ${response.status}")
              None
            }
          }
          .recover { case NonFatal(e) =>
            logger.warn(s"User info request to provider '$providerName' failed", e)
            None
          }
    }

  /** Checks the state returned on the callback against the one stored in the session.
    *
    * Null or empty states never validate, so a session without a stored state cannot be
    * matched by an empty `state` parameter. The comparison uses `MessageDigest.isEqual`
    * rather than `==` to avoid leaking the match position through timing.
    */
  def validateState(sessionState: String, returnedState: String): Boolean =
    (Option(sessionState), Option(returnedState)) match {
      case (Some(expected), Some(actual)) if expected.nonEmpty =>
        MessageDigest.isEqual(
          expected.getBytes(StandardCharsets.UTF_8),
          actual.getBytes(StandardCharsets.UTF_8)
        )
      case _ => false
    }

  /** Generates a fresh random `state` value (32 hexadecimal characters). */
  def generateState(): String = randomToken()

  /** Generates a fresh random `nonce` value (32 hexadecimal characters). */
  def generateNonce(): String = randomToken()

  /** Random UUID rendered without dashes. */
  private def randomToken(): String =
    java.util.UUID.randomUUID().toString.replace("-", "")

  /** Extracts the token fields; `None` when `access_token` or `token_type` is absent. */
  private def parseTokenResponse(json: JsValue): Option[TokenResponse] =
    for {
      accessToken <- (json \ "access_token").asOpt[String]
      tokenType <- (json \ "token_type").asOpt[String]
    } yield TokenResponse(
      accessToken = accessToken,
      tokenType = tokenType,
      expiresIn = (json \ "expires_in").asOpt[Int],
      refreshToken = (json \ "refresh_token").asOpt[String],
      idToken = (json \ "id_token").asOpt[String]
    )

  /** Maps a provider-specific user document onto [[OpenIDUserInfo]].
    *
    * The identifier is taken from `id`, falling back to `sub`, which is the subject claim used
    * by OpenID Connect userinfo responses. Returns `None` when neither is present.
    */
  private def parseUserInfo(providerName: String, provider: OpenIDProvider, json: JsValue): Option[OpenIDUserInfo] = {
    def field(key: String): Option[String] = (json \ key).asOpt[String]
    def firstOf(keys: String*): Option[String] = keys.iterator.flatMap(field).nextOption()

    firstOf("id", "sub").map { userId =>
      OpenIDUserInfo(
        id = userId,
        email = field("email"),
        // Providers name this field differently; take the first one present.
        name = firstOf("name", "login", "preferred_username", "global_name", "username"),
        picture = firstOf("picture", "avatar_url"),
        provider = providerName,
        providerName = provider.name
      )
    }
  }
}