Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh external-login JWKS when Missing key is thrown #433

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion seacatauth/external_login/providers/appleid.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def get_user_info(self, authorize_data: dict, expected_nonce: str | None =
raise ExternalOAuthFlowError("Unknown error during authorization flow.")

id_token = authorize_data.get("id_token")
verified_claims = self._get_verified_claims(id_token, expected_nonce)
verified_claims = await self._get_verified_claims(id_token, expected_nonce)

user_info = {
"sub": str(verified_claims.get("sub")),
Expand Down
22 changes: 19 additions & 3 deletions seacatauth/external_login/providers/generic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import json
import re
import typing
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(self, external_login_svc, config_section_name, config=None):
assert self.Label is not None

self.JwkSet = None
self.JwkSetLastUpdate = None

# The URL to return to after successful external login
# Mostly for debugging purposes
Expand Down Expand Up @@ -132,6 +134,10 @@ async def _prepare_jwks(self, speculative=True):
return
if self.JwkSet and speculative:
return
jwksSet_last_update_diff = (datetime.datetime.now(datetime.timezone.utc) - self.JwkSetLastUpdate).total_seconds() if self.JwkSetLastUpdate is not None else None
# Refresh JWKS if it's older than 5 minutes
if self.JwkSetLastUpdate is not None and jwksSet_last_update_diff < 300:
return
async with aiohttp.ClientSession() as session:
async with session.get(self.JwksUri) as resp:
if resp.status != 200:
Expand All @@ -146,7 +152,8 @@ async def _prepare_jwks(self, speculative=True):
return
jwks = await resp.text()
self.JwkSet = jwcrypto.jwk.JWKSet.from_json(jwks)
L.info("Identity provider public JWK set loaded.", struct_data={"type": self.Type})
self.JwkSetLastUpdate = datetime.datetime.now(datetime.timezone.utc)
L.info("Identity provider public JWK set loaded.", struct_data={"type": self.Type, "date": self.JwkSetLastUpdate})

def get_authorize_uri(
self, redirect_uri: typing.Optional[str] = None,
Expand Down Expand Up @@ -231,7 +238,7 @@ async def get_user_info(self, authorize_data: dict, expected_nonce: str | None =
id_token = token_data["id_token"]
await self._prepare_jwks()

id_token_claims = self._get_verified_claims(id_token, expected_nonce)
id_token_claims = await self._get_verified_claims(id_token, expected_nonce)
user_info = self._user_data_from_id_token_claims(id_token_claims)
user_info["sub"] = str(user_info["sub"])
return user_info
Expand All @@ -247,7 +254,7 @@ def _user_data_from_id_token_claims(self, id_token_claims: dict):
}
return user_info

def _get_verified_claims(self, id_token, expected_nonce: str | None = None):
async def _get_verified_claims(self, id_token, expected_nonce: str | None = None):
check_claims = self._get_claims_to_verify()
if expected_nonce:
check_claims["nonce"] = expected_nonce
Expand All @@ -260,6 +267,15 @@ def _get_verified_claims(self, id_token, expected_nonce: str | None = None):
except jwcrypto.jwt.JWTExpired:
L.error("Expired ID token.", struct_data={"provider": self.Type})
raise ExternalOAuthFlowError("Expired ID token.")
except jwcrypto.jwt.JWTMissingKey:
L.error("Error reading ID token - Invalid key in JWKSet.", struct_data={
"provider": self.Type})
# provider probably change jwks - refresh them
await self._prepare_jwks()
raise ExternalOAuthFlowError("Error reading ID token - Invalid key in JWKSet.")
except jwcrypto.jwt.JWTMissingClaim:
L.error("Missing ID token claim.", struct_data={"provider": self.Type})
raise ExternalOAuthFlowError("Missing ID token claim.")
except Exception as e:
L.error("Error reading ID token claims.", struct_data={
"provider": self.Type, "error": str(e)})
Expand Down