diff --git a/seacatauth/external_login/providers/appleid.py b/seacatauth/external_login/providers/appleid.py index 46817a52..3fdcf654 100644 --- a/seacatauth/external_login/providers/appleid.py +++ b/seacatauth/external_login/providers/appleid.py @@ -81,7 +81,7 @@ async def get_user_info(self, authorize_data: dict, expected_nonce: str | None = raise ExternalOAuthFlowError("Unknown error during authorization flow.") id_token = authorize_data.get("id_token") - verified_claims = self._get_verified_claims(id_token, expected_nonce) + verified_claims = await self._get_verified_claims(id_token, expected_nonce) user_info = { "sub": str(verified_claims.get("sub")), diff --git a/seacatauth/external_login/providers/generic.py b/seacatauth/external_login/providers/generic.py index 5b69d55a..6c5e0929 100644 --- a/seacatauth/external_login/providers/generic.py +++ b/seacatauth/external_login/providers/generic.py @@ -1,3 +1,4 @@ +import datetime import json import re import typing @@ -103,6 +104,7 @@ def __init__(self, external_login_svc, config_section_name, config=None): assert self.Label is not None self.JwkSet = None + self.JwkSetLastUpdate = None # The URL to return to after successful external login # Mostly for debugging purposes @@ -111,11 +113,17 @@ def __init__(self, external_login_svc, config_section_name, config=None): else: self.CallbackUrl = external_login_svc.CallbackUrlTemplate.format(provider_type=self.Type) + external_login_svc.App.PubSub.subscribe("Application.housekeeping!", self._on_housekeeping) + async def initialize(self, app): await self._prepare_jwks() + async def _on_housekeeping(self, event_name): + await self._prepare_jwks(speculative=False) + + def acr_value(self) -> str: """ Authentication Context Class Reference (ACR) @@ -132,6 +140,10 @@ async def _prepare_jwks(self, speculative=True): return if self.JwkSet and speculative: return + jwksSet_last_update_diff = (datetime.datetime.now(datetime.timezone.utc) - self.JwkSetLastUpdate).total_seconds() if self.JwkSetLastUpdate is not None else None + # Refresh JWKS if it's older than 5 minutes + if self.JwkSetLastUpdate is not None and jwksSet_last_update_diff < 300: + return async with aiohttp.ClientSession() as session: async with session.get(self.JwksUri) as resp: if resp.status != 200: @@ -146,7 +158,8 @@ async def _prepare_jwks(self, speculative=True): return jwks = await resp.text() self.JwkSet = jwcrypto.jwk.JWKSet.from_json(jwks) - L.info("Identity provider public JWK set loaded.", struct_data={"type": self.Type}) + self.JwkSetLastUpdate = datetime.datetime.now(datetime.timezone.utc) + L.info("JWK set was loaded.", struct_data={"type": self.Type, "date": self.JwkSetLastUpdate}) def get_authorize_uri( self, redirect_uri: typing.Optional[str] = None, @@ -231,7 +244,7 @@ async def get_user_info(self, authorize_data: dict, expected_nonce: str | None = id_token = token_data["id_token"] await self._prepare_jwks() - id_token_claims = self._get_verified_claims(id_token, expected_nonce) + id_token_claims = await self._get_verified_claims(id_token, expected_nonce) user_info = self._user_data_from_id_token_claims(id_token_claims) user_info["sub"] = str(user_info["sub"]) return user_info @@ -247,7 +260,7 @@ def _user_data_from_id_token_claims(self, id_token_claims: dict): } return user_info - def _get_verified_claims(self, id_token, expected_nonce: str | None = None): + async def _get_verified_claims(self, id_token, expected_nonce: str | None = None): check_claims = self._get_claims_to_verify() if expected_nonce: check_claims["nonce"] = expected_nonce @@ -260,6 +273,15 @@ def _get_verified_claims(self, id_token, expected_nonce: str | None = None): except jwcrypto.jwt.JWTExpired: L.error("Expired ID token.", struct_data={"provider": self.Type}) raise ExternalOAuthFlowError("Expired ID token.") + except jwcrypto.jwt.JWTMissingKey: + L.error("Error reading ID token - Invalid key in JWKSet.", struct_data={ + "provider": self.Type}) + # provider probably change jwks - refresh them + await self._prepare_jwks(speculative=False) + raise ExternalOAuthFlowError("Error reading ID token - Invalid key in JWKSet.") + except jwcrypto.jwt.JWTMissingClaim: + L.error("Missing ID token claim.", struct_data={"provider": self.Type}) + raise ExternalOAuthFlowError("Missing ID token claim.") except Exception as e: L.error("Error reading ID token claims.", struct_data={ "provider": self.Type, "error": str(e)})