From 9a0886c4b783e2e38c9c74b95ceb27e129dfeda5 Mon Sep 17 00:00:00 2001 From: KShivendu Date: Fri, 22 Jul 2022 16:50:53 +0530 Subject: [PATCH] feat: Refactor session claims and tests --- .../recipe/session/api/implementation.py | 10 +- .../recipe/session/asyncio/__init__.py | 121 +++++++++++++++++- .../claim_base_classes/primitive_claim.py | 22 +++- .../framework/django/asyncio/__init__.py | 2 +- .../framework/django/syncio/__init__.py | 2 +- .../session/framework/fastapi/__init__.py | 2 +- .../session/framework/flask/__init__.py | 2 +- .../recipe/session/interfaces.py | 71 ++++++---- supertokens_python/recipe/session/recipe.py | 38 +++--- .../recipe/session/recipe_implementation.py | 104 +++++++++------ .../recipe/session/session_class.py | 58 ++------- .../recipe/session/syncio/__init__.py | 104 ++++++++++++++- supertokens_python/recipe/session/utils.py | 110 +++++++++++++++- .../session/with_jwt/recipe_implementation.py | 2 +- tests/sessions/claims/test_assert_claims.py | 12 +- .../claims/test_create_new_session.py | 6 +- ...laims_value.py => test_get_claim_value.py} | 31 +++-- tests/sessions/claims/test_primitive_claim.py | 28 ++-- tests/sessions/claims/test_remove_claim.py | 8 +- tests/sessions/claims/test_set_claim_value.py | 8 +- ...test_validate_claims_for_session_handle.py | 65 ++++++++++ tests/sessions/claims/test_verify_session.py | 14 +- tests/sessions/claims/test_with_jwt.py | 83 ++++++++++++ tests/sessions/claims/utils.py | 14 +- 24 files changed, 717 insertions(+), 200 deletions(-) rename tests/sessions/claims/{test_get_claims_value.py => test_get_claim_value.py} (58%) create mode 100644 tests/sessions/claims/test_validate_claims_for_session_handle.py create mode 100644 tests/sessions/claims/test_with_jwt.py diff --git a/supertokens_python/recipe/session/api/implementation.py b/supertokens_python/recipe/session/api/implementation.py index 8cfe57912..405280bb7 100644 --- a/supertokens_python/recipe/session/api/implementation.py +++ b/supertokens_python/recipe/session/api/implementation.py @@ -23,6 +23,7 @@ ) from supertokens_python.types import MaybeAwaitable from supertokens_python.utils import normalise_http_method +from ..utils import get_required_claim_validators if TYPE_CHECKING: from supertokens_python.recipe.session.interfaces import APIOptions @@ -67,7 +68,7 @@ async def verify_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], @@ -91,10 +92,15 @@ async def verify_session( ) if session is not None: - await api_options.recipe_implementation.assert_claims( + claim_validators = await get_required_claim_validators( session, override_global_claim_validators, user_context, ) + await api_options.recipe_implementation.assert_claims( + session, + claim_validators, + user_context, + ) return session diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index 0b0c08580..85a9c7ffd 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -22,10 +22,15 @@ SessionInformationResult, SessionClaim, SessionClaimValidator, + SessionDoesntExistError, + ValidateClaimsOkResult, + JSONObject, + GetClaimValueOkResult, ) from supertokens_python.recipe.session.recipe import SessionRecipe from supertokens_python.types import MaybeAwaitable -from supertokens_python.utils import FRAMEWORKS +from supertokens_python.utils import FRAMEWORKS, resolve +from ..utils import get_required_claim_validators from ...jwt.interfaces import ( CreateJwtOkResult, CreateJwtResultUnsupportedAlgorithm, @@ -53,11 +58,118 @@ async def create_new_session( ) +async def validate_claims_for_session_handle( + session_handle: str, + override_global_claim_validators: Optional[ + Callable[ + [ + List[SessionClaimValidator], + SessionInformationResult, + Dict[str, Any], + ], # Prev. 2nd arg was SessionContainer + MaybeAwaitable[List[SessionClaimValidator]], + ] + ] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[SessionDoesntExistError, ValidateClaimsOkResult]: + if user_context is None: + user_context = {} + + recipe_impl = SessionRecipe.get_instance().recipe_implementation + session_info = await recipe_impl.get_session_information( + session_handle, user_context + ) + + if session_info is None: + return SessionDoesntExistError() + + claim_validators_added_by_other_recipes = ( + SessionRecipe.get_claim_validators_added_by_other_recipes() + ) + global_claim_validators = await resolve( + recipe_impl.get_global_claim_validators( + session_info.user_id, + claim_validators_added_by_other_recipes, + user_context, + ) + ) + + if override_global_claim_validators is not None: + claim_validators = await resolve( + override_global_claim_validators( + global_claim_validators, session_info, user_context + ) + ) + else: + claim_validators = global_claim_validators + + return await recipe_impl.validate_claims_for_session_handle( + session_info, claim_validators, user_context + ) + + +async def validate_claims_in_jwt_payload( + user_id: str, + jwt_payload: JSONObject, + override_global_claim_validators: Optional[ + Callable[ + [ + List[SessionClaimValidator], + str, + Dict[str, Any], + ], # Prev. 2nd arg was SessionContainer + MaybeAwaitable[List[SessionClaimValidator]], + ] + ] = None, + user_context: Union[None, Dict[str, Any]] = None, +): + if user_context is None: + user_context = {} + + recipe_impl = SessionRecipe.get_instance().recipe_implementation + + claim_validators_added_by_other_recipes = ( + SessionRecipe.get_claim_validators_added_by_other_recipes() + ) + global_claim_validators = await resolve( + recipe_impl.get_global_claim_validators( + user_id, + claim_validators_added_by_other_recipes, + user_context, + ) + ) + + if override_global_claim_validators is not None: + claim_validators = await resolve( + override_global_claim_validators( + global_claim_validators, user_id, user_context + ) + ) + else: + claim_validators = global_claim_validators + + return await recipe_impl.validate_claims_in_jwt_payload( + user_id, jwt_payload, claim_validators, user_context + ) + + +async def fetch_and_set_claim( + session_handle: str, + claim: SessionClaim[Any], + user_context: Union[None, Dict[str, Any]] = None, +) -> bool: + if user_context is None: + user_context = {} + return await SessionRecipe.get_instance().recipe_implementation.fetch_and_set_claim( + session_handle, claim, user_context + ) + + async def get_claim_value( session_handle: str, claim: SessionClaim[_T], user_context: Union[None, Dict[str, Any]] = None, -) -> Union[_T, None]: +) -> Union[SessionDoesntExistError, GetClaimValueOkResult[_T]]: if user_context is None: user_context = {} return await SessionRecipe.get_instance().recipe_implementation.get_claim_value( @@ -96,7 +208,7 @@ async def get_session( session_required: bool = True, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, @@ -119,9 +231,10 @@ async def get_session( ) if session is not None: - await session_recipe_impl.assert_claims( + claim_validators = await get_required_claim_validators( session, override_global_claim_validators, user_context ) + await session_recipe_impl.assert_claims(session, claim_validators, user_context) return session diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py index 06fe32052..fedaf74d1 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py @@ -53,7 +53,7 @@ def should_refetch( ): return claim.get_value_from_payload(payload, user_context) is None - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None, @@ -93,23 +93,26 @@ def should_refetch( claim.get_value_from_payload(payload, user_context) is None ) or (payload[claim.key]["t"] < time.time() - max_age_in_sec * 1000) - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None, ): claim_val = claim.get_value_from_payload(payload, user_context) - if claim_val != val: + if claim_val is None: return { "isValid": False, "reason": { - "message": "wrong value", + "message": "value does not exist", "expectedValue": val, "actualValue": claim_val, }, } - age_in_sec = (time.time() - payload[claim.key]["t"]) / 1000 + age_in_sec = ( + time.time() + - float(claim.get_last_refetch_time(payload, user_context) or 0) + ) / 1000 if age_in_sec > max_age_in_sec: return { "isValid": False, @@ -119,6 +122,15 @@ def validate( "maxAgeInSeconds": max_age_in_sec, }, } + if claim_val != val: + return { + "isValid": False, + "reason": { + "message": "wrong value", + "expectedValue": val, + "actualValue": claim_val, + }, + } return {"isValid": True} diff --git a/supertokens_python/recipe/session/framework/django/asyncio/__init__.py b/supertokens_python/recipe/session/framework/django/asyncio/__init__.py index 94c7bfcdf..9034ecb7c 100644 --- a/supertokens_python/recipe/session/framework/django/asyncio/__init__.py +++ b/supertokens_python/recipe/session/framework/django/asyncio/__init__.py @@ -31,7 +31,7 @@ def verify_session( user_context: Union[None, Dict[str, Any]] = None, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, diff --git a/supertokens_python/recipe/session/framework/django/syncio/__init__.py b/supertokens_python/recipe/session/framework/django/syncio/__init__.py index a20c8cc10..7374a9a60 100644 --- a/supertokens_python/recipe/session/framework/django/syncio/__init__.py +++ b/supertokens_python/recipe/session/framework/django/syncio/__init__.py @@ -32,7 +32,7 @@ def verify_session( user_context: Union[None, Dict[str, Any]] = None, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, diff --git a/supertokens_python/recipe/session/framework/fastapi/__init__.py b/supertokens_python/recipe/session/framework/fastapi/__init__.py index 3ac06608f..d9da39558 100644 --- a/supertokens_python/recipe/session/framework/fastapi/__init__.py +++ b/supertokens_python/recipe/session/framework/fastapi/__init__.py @@ -26,7 +26,7 @@ def verify_session( user_context: Union[None, Dict[str, Any]] = None, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, diff --git a/supertokens_python/recipe/session/framework/flask/__init__.py b/supertokens_python/recipe/session/framework/flask/__init__.py index 3505165cf..f274d1fad 100644 --- a/supertokens_python/recipe/session/framework/flask/__init__.py +++ b/supertokens_python/recipe/session/framework/flask/__init__.py @@ -29,7 +29,7 @@ def verify_session( user_context: Union[None, Dict[str, Any]] = None, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 86fc1cfad..d24898672 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -28,7 +28,7 @@ from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.types import APIResponse, GeneralErrorResponse, MaybeAwaitable - +from .exceptions import ClaimValidationError from .utils import SessionConfig from ...utils import resolve @@ -74,7 +74,29 @@ def __init__( self.time_created: int = time_created -class RecipeInterface(ABC): +_T = TypeVar("_T") +JSONObject = Dict[str, Any] + +JSONPrimitive = Union[str, int, bool, None, Dict[str, Any]] + +FetchValueReturnType = Union[_T, None] + + +class SessionDoesntExistError: + pass + + +class GetClaimValueOkResult(Generic[_T]): + def __init__(self, value: Optional[_T]): + self.value = value + + +class ValidateClaimsOkResult: + def __init__(self, invalid_claims: List[ClaimValidationError]): + self.invalid_claims = invalid_claims + + +class RecipeInterface(ABC): # pylint: disable=too-many-public-methods def __init__(self): pass @@ -106,7 +128,7 @@ async def get_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], @@ -118,16 +140,30 @@ async def get_session( async def assert_claims( self, session: SessionContainer, - override_global_claim_validators: Optional[ - Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ], + claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> None: pass + @abstractmethod + async def validate_claims_for_session_handle( + self, + session_info: SessionInformationResult, + claim_validators: List[SessionClaimValidator], + user_context: Dict[str, Any], + ) -> Union[ValidateClaimsOkResult, SessionDoesntExistError]: + pass + + @abstractmethod + async def validate_claims_in_jwt_payload( + self, + user_id: str, + jwt_payload: JSONObject, + claim_validators: List[SessionClaimValidator], + user_context: Dict[str, Any], + ) -> ValidateClaimsOkResult: + pass + @abstractmethod async def refresh_session( self, request: BaseRequest, user_context: Dict[str, Any] @@ -215,7 +251,7 @@ async def get_claim_value( session_handle: str, claim: SessionClaim[Any], user_context: Dict[str, Any], - ): + ) -> Union[SessionDoesntExistError, GetClaimValueOkResult[Any]]: pass @abstractmethod @@ -290,7 +326,7 @@ async def verify_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], @@ -453,15 +489,6 @@ def __getitem__(self, item: str): return getattr(self, item) -# Session claims: -_T = TypeVar("_T") -JSONObject = Dict[str, Any] - -JSONPrimitive = Union[str, int, bool, None, Dict[str, Any]] - -FetchValueReturnType = Union[_T, None] - - class SessionClaim(ABC, Generic[_T]): def __init__(self, key: str) -> None: self.key = key @@ -537,9 +564,9 @@ def __init__(self, id_: Optional[str] = None): self.id = id_ @abstractmethod - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> MaybeAwaitable[ClaimValidationResult]: + ) -> ClaimValidationResult: pass def should_refetch( diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index d88971984..965dd1a7f 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -66,6 +66,9 @@ class SessionRecipe(RecipeModule): recipe_id = "session" __instance = None + claims_added_by_other_recipes: List[SessionClaim[Any]] = [] + claim_validators_added_by_other_recipes: List[SessionClaimValidator] = [] + def __init__( self, recipe_id: str, @@ -151,9 +154,6 @@ def __init__( else self.config.override.apis(api_implementation) ) - self.claims_added_by_other_recipes: List[SessionClaim[Any]] = [] - self.claim_validators_added_by_other_recipes: List[SessionClaimValidator] = [] - def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( isinstance(err, SuperTokensSessionError) @@ -299,21 +299,27 @@ def reset(): raise_general_exception("calling testing function in non testing env") SessionRecipe.__instance = None - def add_claim_from_other_recipe(self, claim: SessionClaim[Any]): - self.claims_added_by_other_recipes.append(claim) + @staticmethod + def add_claim_from_other_recipe(claim: SessionClaim[Any]): + # We are throwing here (and not in addClaimValidatorFromOtherRecipe) because if multiple + # claims are added with the same key they will overwrite each other. Validators will all run + # and work as expected even if they are added multiple times. + if claim.key in [c.key for c in SessionRecipe.claims_added_by_other_recipes]: + raise Exception("Claim added by multiple recipes") - def get_claims_added_by_other_recipes(self) -> List[SessionClaim[Any]]: - return self.claims_added_by_other_recipes + SessionRecipe.claims_added_by_other_recipes.append(claim) - def add_claim_validator_from_other_recipe( - self, claim_validator: SessionClaimValidator - ): - self.claim_validators_added_by_other_recipes.append(claim_validator) + @staticmethod + def get_claims_added_by_other_recipes() -> List[SessionClaim[Any]]: + return SessionRecipe.claims_added_by_other_recipes - def get_claim_validators_added_by_other_recipes( - self, - ) -> List[SessionClaimValidator]: - return self.claim_validators_added_by_other_recipes + @staticmethod + def add_claim_validator_from_other_recipe(claim_validator: SessionClaimValidator): + SessionRecipe.claim_validators_added_by_other_recipes.append(claim_validator) + + @staticmethod + def get_claim_validators_added_by_other_recipes() -> List[SessionClaimValidator]: + return SessionRecipe.claim_validators_added_by_other_recipes async def verify_session( self, @@ -322,7 +328,7 @@ async def verify_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index 55867e25d..b9214b98b 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -13,6 +13,7 @@ # under the License. from __future__ import annotations +import json from typing import TYPE_CHECKING, Any, Dict, TypeVar, Callable, Optional from supertokens_python.framework.request import BaseRequest @@ -24,7 +25,6 @@ frontend_has_interceptor, get_timestamp_ms, normalise_http_method, - resolve, ) from . import session_functions from .cookie_and_header import ( @@ -43,16 +43,24 @@ SessionClaimValidator, SessionInformationResult, SessionObj, + ValidateClaimsOkResult, + SessionDoesntExistError, + JSONObject, + GetClaimValueOkResult, ) from .session_class import Session from ...types import MaybeAwaitable +from .utils import ( + SessionConfig, + update_claims_in_payload_if_needed, + validate_claims_in_payload, +) if TYPE_CHECKING: from typing import List, Union from supertokens_python.querier import Querier - from .utils import SessionConfig from .interfaces import SessionContainer @@ -171,44 +179,62 @@ async def create_new_session( async def assert_claims( self, session: SessionContainer, - override_global_claim_validators: Optional[ - Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ], + claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any], ) -> None: - # TODO: FIXME - # This leads to circular import, so avoiding: - # session_recipe = SessionRecipe.get_instance() - # claim_validators_added_by_other_recipes = ( - # session_recipe.get_claim_validators_added_by_other_recipes() - # ) - claim_validators_added_by_other_recipes = [] - - global_claim_validators = await resolve( - self.get_global_claim_validators( - session.get_user_id(), - claim_validators_added_by_other_recipes, + log_debug_message( + "get_session: required validator ids %s", + ",".join([c.id for c in claim_validators]), + ) + await session.assert_claims(claim_validators, user_context) + log_debug_message("get_session: claim assertion successful") + + async def validate_claims_for_session_handle( + self, + session_info: SessionInformationResult, + claim_validators: List[SessionClaimValidator], + user_context: Dict[str, Any], + ) -> Union[ValidateClaimsOkResult, SessionDoesntExistError]: + original_session_claim_payload_json = json.dumps( + session_info.access_token_payload + ) + + new_access_token_payload = await update_claims_in_payload_if_needed( + claim_validators, + session_info.access_token_payload, + session_info.user_id, + user_context, + ) + + if json.dumps(new_access_token_payload) != original_session_claim_payload_json: + await self.merge_into_access_token_payload( + session_info.session_handle, + new_access_token_payload, user_context, ) + + invalid_claims = await validate_claims_in_payload( + claim_validators, + new_access_token_payload, + user_context, ) - if override_global_claim_validators is not None: - req_claim_validators = await resolve( - override_global_claim_validators( - session, global_claim_validators, user_context - ) - ) - else: - req_claim_validators = global_claim_validators - log_debug_message( - "getSession: required validator is %s", - ",".join([c.id for c in req_claim_validators]), + return ValidateClaimsOkResult(invalid_claims) + + async def validate_claims_in_jwt_payload( + self, + user_id: str, + jwt_payload: JSONObject, + claim_validators: List[SessionClaimValidator], + user_context: Dict[str, Any], + ) -> ValidateClaimsOkResult: + invalid_claims = await validate_claims_in_payload( + claim_validators, + jwt_payload, + user_context, ) - await session.assert_claims(req_claim_validators, user_context) - log_debug_message("getSession: claim assertion successful") + + return ValidateClaimsOkResult(invalid_claims) async def get_session( self, @@ -217,7 +243,7 @@ async def get_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], @@ -447,13 +473,15 @@ async def get_claim_value( session_handle: str, claim: SessionClaim[Any], user_context: Dict[str, Any], - ): + ) -> Union[SessionDoesntExistError, GetClaimValueOkResult[Any]]: session_info = await self.get_session_information(session_handle, user_context) if session_info is None: - raise Exception("Session does not exist") + return SessionDoesntExistError() - return claim.get_value_from_payload( - session_info.access_token_payload, user_context + return GetClaimValueOkResult( + value=claim.get_value_from_payload( + session_info.access_token_payload, user_context + ) ) def get_global_claim_validators( diff --git a/supertokens_python/recipe/session/session_class.py b/supertokens_python/recipe/session/session_class.py index e7475d675..5e9c4e3b4 100644 --- a/supertokens_python/recipe/session/session_class.py +++ b/supertokens_python/recipe/session/session_class.py @@ -17,12 +17,10 @@ from supertokens_python.recipe.session.exceptions import ( raise_unauthorised_exception, raise_invalid_claims_exception, - ClaimValidationError, ) from .interfaces import SessionContainer, SessionClaimValidator, SessionClaim -from supertokens_python.logger import log_debug_message -from ...utils import resolve +from .utils import update_claims_in_payload_if_needed, validate_claims_in_payload _T = TypeVar("_T") @@ -127,57 +125,29 @@ async def assert_claims( claim_validators: List[SessionClaimValidator], user_context: Union[Dict[str, Any], None] = None, ) -> None: + if user_context is None: + user_context = {} + original_session_claim_payload_json = json.dumps( self.get_access_token_payload() ) - new_access_token_payload = self.get_access_token_payload() - validation_errors: List[ClaimValidationError] = [] - for validator in claim_validators: - log_debug_message("Session.validate_claims checking %s", validator.id) - if ( - hasattr(validator, "claim") - and (validator.claim is not None) - and ( - await resolve( - validator.should_refetch(new_access_token_payload, user_context) - ) - ) - ): - log_debug_message("Session.validate_claims refetching %s", validator.id) - value = await resolve( - validator.claim.fetch_value(self.get_user_id(), user_context) - ) - log_debug_message( - "Session.validate_claims %s refetch res %s", - validator.id, - json.dumps(value), - ) - if value is not None: - new_access_token_payload = validator.claim.add_to_payload_( - new_access_token_payload, - value, - user_context, - ) - - claim_validation_res = await resolve( - validator.validate(new_access_token_payload, user_context) - ) - log_debug_message( - "Session.validate_claims %s validate res %s", - validator.id, - json.dumps(claim_validation_res), - ) - if not claim_validation_res.get("isValid"): - validation_errors.append( - ClaimValidationError(validator.id, claim_validation_res["reason"]) - ) + new_access_token_payload = await update_claims_in_payload_if_needed( + claim_validators, + self.get_access_token_payload(), + self.get_user_id(), + user_context, + ) if json.dumps(new_access_token_payload) != original_session_claim_payload_json: await self.merge_into_access_token_payload( new_access_token_payload, user_context ) + validation_errors = await validate_claims_in_payload( + claim_validators, new_access_token_payload, user_context + ) + if len(validation_errors) > 0: raise_invalid_claims_exception("INVALID_CLAIMS", validation_errors) diff --git a/supertokens_python/recipe/session/syncio/__init__.py b/supertokens_python/recipe/session/syncio/__init__.py index 79ed75911..54823eec9 100644 --- a/supertokens_python/recipe/session/syncio/__init__.py +++ b/supertokens_python/recipe/session/syncio/__init__.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Union, Callable, Optional +from typing import Any, Dict, List, Union, Callable, Optional, TypeVar from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.recipe.openid.interfaces import ( @@ -30,6 +30,11 @@ SessionContainer, SessionInformationResult, SessionClaimValidator, + SessionClaim, + JSONObject, + ValidateClaimsOkResult, + SessionDoesntExistError, + GetClaimValueOkResult, ) @@ -61,7 +66,7 @@ def get_session( session_required: bool = True, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ] = None, @@ -212,3 +217,98 @@ def regenerate_access_token( access_token, new_access_token_payload, user_context ) ) + + +_T = TypeVar("_T") + + +def fetch_and_set_claim( + session_handle: str, + claim: SessionClaim[Any], + user_context: Union[None, Dict[str, Any]] = None, +) -> bool: + from supertokens_python.recipe.session.asyncio import ( + fetch_and_set_claim as async_fetch_and_set_claim, + ) + + return sync(async_fetch_and_set_claim(session_handle, claim, user_context)) + + +def set_claim_value( + session_handle: str, + claim: SessionClaim[_T], + value: _T, + user_context: Union[None, Dict[str, Any]] = None, +) -> bool: + from supertokens_python.recipe.session.asyncio import ( + set_claim_value as async_set_claim_value, + ) + + return sync(async_set_claim_value(session_handle, claim, value, user_context)) + + +def get_claim_value( + session_handle: str, + claim: SessionClaim[_T], + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[SessionDoesntExistError, GetClaimValueOkResult[_T]]: + from supertokens_python.recipe.session.asyncio import ( + get_claim_value as async_get_claim_value, + ) + + return sync(async_get_claim_value(session_handle, claim, user_context)) + + +def remove_claim( + session_handle: str, + claim: SessionClaim[Any], + user_context: Union[None, Dict[str, Any]] = None, +) -> bool: + from supertokens_python.recipe.session.asyncio import ( + remove_claim as async_remove_claim, + ) + + return sync(async_remove_claim(session_handle, claim, user_context)) + + +def validate_claims_for_session_handle( + session_handle: str, + override_global_claim_validators: Optional[ + Callable[ + [List[SessionClaimValidator], SessionInformationResult, Dict[str, Any]], + MaybeAwaitable[List[SessionClaimValidator]], + ] + ] = None, + user_context: Union[None, Dict[str, Any]] = None, +) -> Union[SessionDoesntExistError, ValidateClaimsOkResult]: + from supertokens_python.recipe.session.asyncio import ( + validate_claims_for_session_handle as async_validate_claims_for_session_handle, + ) + + return sync( + async_validate_claims_for_session_handle( + session_handle, override_global_claim_validators, user_context + ) + ) + + +def validate_claims_in_jwt_payload( + user_id: str, + jwt_payload: JSONObject, + override_global_claim_validators: Optional[ + Callable[ + [List[SessionClaimValidator], str, Dict[str, Any]], + MaybeAwaitable[List[SessionClaimValidator]], + ] + ] = None, + user_context: Union[None, Dict[str, Any]] = None, +): + from supertokens_python.recipe.session.asyncio import ( + validate_claims_in_jwt_payload as async_validate_claims_in_jwt_payload, + ) + + return sync( + async_validate_claims_in_jwt_payload( + user_id, jwt_payload, override_global_claim_validators, user_context + ) + ) diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index 1a7e921e9..26a70ee3f 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -13,9 +13,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Union, List, Dict +import json +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Union, List, Dict, Optional from urllib.parse import urlparse +from tldextract import extract # type: ignore +from typing_extensions import Literal + from supertokens_python.exceptions import raise_general_exception from supertokens_python.framework import BaseResponse from supertokens_python.normalised_url_path import NormalisedURLPath @@ -28,21 +32,25 @@ send_non_200_response_with_message, resolve, ) -from tldextract import extract # type: ignore -from typing_extensions import Literal - from .constants import SESSION_REFRESH from .cookie_and_header import clear_cookies +from .exceptions import ClaimValidationError from .with_jwt.constants import ( ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY, JWT_RESERVED_KEY_USE_ERROR_MESSAGE, ) +from ...types import MaybeAwaitable if TYPE_CHECKING: from supertokens_python.framework import BaseRequest from supertokens_python.supertokens import AppInfo - from .interfaces import APIInterface, RecipeInterface + from .interfaces import ( + APIInterface, + RecipeInterface, + SessionContainer, + SessionClaimValidator, + ) from .recipe import SessionRecipe from supertokens_python.logger import log_debug_message @@ -457,3 +465,95 @@ def validate_and_normalise_user_input( jwt, invalid_claim_status_code if (invalid_claim_status_code is not None) else 403, ) + + +async def get_required_claim_validators( + session: SessionContainer, + override_global_claim_validators: Optional[ + Callable[ + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], + MaybeAwaitable[List[SessionClaimValidator]], + ] + ], + user_context: Dict[str, Any], +) -> List[SessionClaimValidator]: + claim_validators_added_by_other_recipes = ( + SessionRecipe.get_claim_validators_added_by_other_recipes() + ) + global_claim_validators = await resolve( + SessionRecipe.get_instance().recipe_implementation.get_global_claim_validators( + session.get_user_id(), + claim_validators_added_by_other_recipes, + user_context, + ) + ) + + if override_global_claim_validators is not None: + return await resolve( + override_global_claim_validators( + global_claim_validators, session, user_context + ) + ) + + return global_claim_validators + + +async def update_claims_in_payload_if_needed( + claim_validators: List[SessionClaimValidator], + new_access_token_payload: Dict[str, Any], + user_id: str, + user_context: Dict[str, Any], +): + for validator in claim_validators: + log_debug_message( + "update_claims_in_payload_if_needed checking %s", validator.id + ) + if ( + hasattr(validator, "claim") + and (validator.claim is not None) + and ( + await resolve( + validator.should_refetch(new_access_token_payload, user_context) + ) + ) + ): + log_debug_message( + "update_claims_in_payload_if_needed refetching %s", validator.id + ) + value = await resolve(validator.claim.fetch_value(user_id, user_context)) + log_debug_message( + "update_claims_in_payload_if_needed %s refetch res %s", + validator.id, + json.dumps(value), + ) + if value is not None: + new_access_token_payload = validator.claim.add_to_payload_( + new_access_token_payload, + value, + user_context, + ) + + return new_access_token_payload + + +async def validate_claims_in_payload( + claim_validators: List[SessionClaimValidator], + new_access_token_payload: Dict[str, Any], + user_context: Dict[str, Any], +): + validation_errors: List[ClaimValidationError] = [] + for validator in claim_validators: + claim_validation_res = await validator.validate( + new_access_token_payload, user_context + ) + log_debug_message( + "validate_claims_in_payload %s validate res %s", + validator.id, + json.dumps(claim_validation_res), + ) + if not claim_validation_res.get("isValid"): + validation_errors.append( + ClaimValidationError(validator.id, claim_validation_res["reason"]) + ) + + return validation_errors diff --git a/supertokens_python/recipe/session/with_jwt/recipe_implementation.py b/supertokens_python/recipe/session/with_jwt/recipe_implementation.py index 34f8ec965..411dc9722 100644 --- a/supertokens_python/recipe/session/with_jwt/recipe_implementation.py +++ b/supertokens_python/recipe/session/with_jwt/recipe_implementation.py @@ -93,7 +93,7 @@ async def get_session( session_required: bool, override_global_claim_validators: Optional[ Callable[ - [SessionContainer, List[SessionClaimValidator], Dict[str, Any]], + [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]], ] ], diff --git a/tests/sessions/claims/test_assert_claims.py b/tests/sessions/claims/test_assert_claims.py index e8e90f615..42a16f6bf 100644 --- a/tests/sessions/claims/test_assert_claims.py +++ b/tests/sessions/claims/test_assert_claims.py @@ -45,17 +45,11 @@ async def test_should_call_validate_with_the_same_payload_object(): ) class DummyClaimValidator(SessionClaimValidator): - def __init__(self): - super().__init__("claim_validator_id") - - def should_refetch( - self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Any: - return True + id = "claim_validator_id" - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Any: + ): return {"isValid": True} class DummyClaim(PrimitiveClaim): diff --git a/tests/sessions/claims/test_create_new_session.py b/tests/sessions/claims/test_create_new_session.py index 4b43b9692..b8db20f6e 100644 --- a/tests/sessions/claims/test_create_new_session.py +++ b/tests/sessions/claims/test_create_new_session.py @@ -7,7 +7,7 @@ from supertokens_python.recipe.session.asyncio import create_new_session from tests.utils import setup_function, teardown_function, start_st, min_api_version from .utils import ( - st_init_args, + st_init_args_with_TrueClaim, NoneClaim, get_st_init_args, session_functions_override_with_claim, @@ -22,7 +22,7 @@ @min_api_version("2.13") async def test_create_access_token_payload_with_session_claims(): - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -49,7 +49,7 @@ async def test_should_create_access_token_payload_with_session_claims_with_an_no @min_api_version("2.13") async def test_should_merge_claims_and_passed_access_token_payload_obj(): new_st_init = { - **st_init_args, + **st_init_args_with_TrueClaim, "recipe_list": [ session.init( override=session.InputOverrideConfig( diff --git a/tests/sessions/claims/test_get_claims_value.py b/tests/sessions/claims/test_get_claim_value.py similarity index 58% rename from tests/sessions/claims/test_get_claims_value.py rename to tests/sessions/claims/test_get_claim_value.py index 13d3cc15d..417bdb0cf 100644 --- a/tests/sessions/claims/test_get_claims_value.py +++ b/tests/sessions/claims/test_get_claim_value.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock -from pytest import mark, raises +from pytest import mark from supertokens_python import init from supertokens_python.framework.request import BaseRequest @@ -9,17 +9,22 @@ create_new_session, get_claim_value, ) -from supertokens_python.recipe.session.interfaces import SessionContainer +from supertokens_python.recipe.session.interfaces import ( + SessionContainer, + GetClaimValueOkResult, + SessionDoesntExistError, +) from tests.utils import setup_function, teardown_function, start_st -from .utils import TrueClaim, st_init_args +from .utils import TrueClaim, st_init_args_with_TrueClaim _ = setup_function # type:ignore _ = teardown_function # type:ignore +pytestmark = mark.asyncio + -@mark.asyncio async def test_should_get_the_right_value(): - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -29,24 +34,22 @@ async def test_should_get_the_right_value(): assert res is True -@mark.asyncio async def test_should_get_the_right_value_using_session_handle(): - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() dummy_req: BaseRequest = MagicMock() s: SessionContainer = await create_new_session(dummy_req, "someId") res = await get_claim_value(s.get_handle(), TrueClaim) - assert res is True + assert isinstance(res, GetClaimValueOkResult) + assert res.value is True -@mark.asyncio -async def test_should_throw_for_non_existing_handle(): - new_st_init = {**st_init_args, "recipe_list": [session.init()]} +async def test_should_work_for_non_existing_handle(): + new_st_init = {**st_init_args_with_TrueClaim, "recipe_list": [session.init()]} init(**new_st_init) # type: ignore start_st() - with raises(Exception) as e: - _ = await get_claim_value("non_existing_handle", TrueClaim) - assert str(e) == "Session does not exist" + res = await get_claim_value("non_existing_handle", TrueClaim) + assert isinstance(res, SessionDoesntExistError) diff --git a/tests/sessions/claims/test_primitive_claim.py b/tests/sessions/claims/test_primitive_claim.py index 1a28b9e01..d58671f2a 100644 --- a/tests/sessions/claims/test_primitive_claim.py +++ b/tests/sessions/claims/test_primitive_claim.py @@ -4,6 +4,7 @@ from pytest import mark from supertokens_python.recipe.session.claims import PrimitiveClaim +from supertokens_python.utils import resolve timestamp = real_time.time() val = {"foo": 1} @@ -132,7 +133,7 @@ async def test_should_return_none_for_empty_payload(time_mock: MagicMock): @_test_wrapper async def test_validators_should_not_validate_empty_payload(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - res = claim.validators.has_value(val).validate({}) # TODO: missing await + res = await claim.validators.has_value(val).validate({}) assert res == { "isValid": False, @@ -148,7 +149,7 @@ async def test_validators_should_not_validate_empty_payload(_time_mock: MagicMoc async def test_should_not_validate_mismatching_payload(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build("user_id") - res = claim.validators.has_value(val2).validate(payload) + res = await claim.validators.has_value(val2).validate(payload) assert res == { "isValid": False, @@ -164,7 +165,7 @@ async def test_should_not_validate_mismatching_payload(_time_mock: MagicMock): async def test_validator_should_validate_matching_payload(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build("user_id") - res = claim.validators.has_value(val).validate(payload) + res = await claim.validators.has_value(val).validate(payload) assert res == {"isValid": True} @@ -179,21 +180,26 @@ async def test_should_validate_old_values_as_well(time_mock: MagicMock): # Increase clock time by 1000 time_mock.time.return_value += 100 * SECONDS # type: ignore - res = claim.validators.has_value(val).validate(payload) + res = await claim.validators.has_value(val).validate(payload) assert res == {"isValid": True} @_test_wrapper async def test_should_refetch_if_value_not_set(_time_mock: MagicMock): - claim = PrimitiveClaim("key", sync_fetch_value) - assert claim.validators.has_value(val).should_refetch(val2, {}) is True + claim = PrimitiveClaim("key", async_fetch_value) + assert ( + await resolve(claim.validators.has_value(val).should_refetch(val2, {})) is True + ) @_test_wrapper async def test_validator_should_not_refetch_if_value_is_set(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build("user_id") - assert claim.validators.has_value(val2).should_refetch(payload, {}) is False + assert ( + await resolve(claim.validators.has_value(val2).should_refetch(payload, {})) + is False + ) # validators.has_fresh_value @@ -202,7 +208,7 @@ async def test_validator_should_not_refetch_if_value_is_set(_time_mock: MagicMoc @_test_wrapper async def test_should_not_validate_empty_payload(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) - res = claim.validators.has_fresh_value(val, 600).validate({}, {}) + res = await claim.validators.has_fresh_value(val, 600).validate({}, {}) assert res == { "isValid": False, "reason": { @@ -219,7 +225,7 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload( ): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build("user_id") - res = claim.validators.has_fresh_value(val2, 600).validate(payload) + res = await claim.validators.has_fresh_value(val2, 600).validate(payload) assert res == { "isValid": False, "reason": { @@ -234,7 +240,7 @@ async def test_has_fresh_value_should_not_validate_mismatching_payload( async def test_should_validate_matching_payload(_time_mock: MagicMock): claim = PrimitiveClaim("key", sync_fetch_value) payload = await claim.build("user_id") - res = claim.validators.has_fresh_value(val, 600).validate(payload) + res = await claim.validators.has_fresh_value(val, 600).validate(payload) assert res == {"isValid": True} @@ -248,7 +254,7 @@ async def test_should_not_validate_old_values_as_well(time_mock: MagicMock): # Increase clock time: time_mock.time.return_value += 100 * SECONDS # type: ignore - res = claim.validators.has_fresh_value(val, 10).validate(payload) + res = await claim.validators.has_fresh_value(val, 10).validate(payload) assert res == { "isValid": False, "reason": { diff --git a/tests/sessions/claims/test_remove_claim.py b/tests/sessions/claims/test_remove_claim.py index 0aadb95f5..00a1e82e9 100644 --- a/tests/sessions/claims/test_remove_claim.py +++ b/tests/sessions/claims/test_remove_claim.py @@ -14,7 +14,7 @@ from supertokens_python.recipe.session.session_class import Session from tests.sessions.claims.utils import TrueClaim from tests.utils import start_st, setup_function, teardown_function -from .test_get_claims_value import st_init_args +from .test_get_claim_value import st_init_args_with_TrueClaim from .utils import time_patch_wrapper _ = setup_function # type:ignore @@ -45,7 +45,7 @@ async def test_should_attempt_to_set_claim_to_none(): @time_patch_wrapper async def test_should_clear_previously_set_claim(time_mock: MagicMock): time_mock.time.return_value = timestamp # type: ignore - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -59,7 +59,7 @@ async def test_should_clear_previously_set_claim(time_mock: MagicMock): @time_patch_wrapper async def test_should_clear_previously_set_claim_using_handle(time_mock: MagicMock): time_mock.time.return_value = timestamp # type: ignore - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -79,7 +79,7 @@ async def test_should_clear_previously_set_claim_using_handle(time_mock: MagicMo @time_patch_wrapper async def test_should_work_ok_for_non_existing_handle(_time_mock: MagicMock): - init(**st_init_args) # type:ignore + init(**st_init_args_with_TrueClaim) # type:ignore start_st() res = await remove_claim("non-existing-handle", TrueClaim) diff --git a/tests/sessions/claims/test_set_claim_value.py b/tests/sessions/claims/test_set_claim_value.py index 2469d8215..ab7bf7652 100644 --- a/tests/sessions/claims/test_set_claim_value.py +++ b/tests/sessions/claims/test_set_claim_value.py @@ -11,7 +11,7 @@ set_claim_value, ) from supertokens_python.recipe.session.session_class import Session -from tests.sessions.claims.utils import TrueClaim, st_init_args +from tests.sessions.claims.utils import TrueClaim, st_init_args_with_TrueClaim from tests.utils import setup_function, teardown_function from tests.utils import start_st @@ -49,7 +49,7 @@ async def test_should_merge_the_right_value(): async def test_should_overwrite_claim_value(): - init(**st_init_args) # type: ignore + init(**st_init_args_with_TrueClaim) # type: ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -70,7 +70,7 @@ async def test_should_overwrite_claim_value(): async def test_should_overwrite_claim_value_using_session_handle(): - init(**st_init_args) # type: ignore + init(**st_init_args_with_TrueClaim) # type: ignore start_st() dummy_req: BaseRequest = MagicMock() @@ -95,7 +95,7 @@ async def test_should_overwrite_claim_value_using_session_handle(): async def test_should_work_ok_for_non_existing_handles(): - init(**st_init_args) # type: ignore + init(**st_init_args_with_TrueClaim) # type: ignore start_st() res = await set_claim_value("non-existing-handle", TrueClaim, "NEW_TRUE") diff --git a/tests/sessions/claims/test_validate_claims_for_session_handle.py b/tests/sessions/claims/test_validate_claims_for_session_handle.py new file mode 100644 index 000000000..fec5f23f6 --- /dev/null +++ b/tests/sessions/claims/test_validate_claims_for_session_handle.py @@ -0,0 +1,65 @@ +# Session.validateClaimsForSessionHandle +from unittest.mock import MagicMock +from pytest import mark + +from supertokens_python.framework import BaseRequest +from supertokens_python import init +from supertokens_python.recipe import session +from supertokens_python.recipe.session.asyncio import ( + create_new_session, + validate_claims_for_session_handle, +) +from supertokens_python.recipe.session.exceptions import ClaimValidationError +from supertokens_python.recipe.session.interfaces import ( + ValidateClaimsOkResult, + SessionDoesntExistError, +) +from tests.sessions.claims.utils import ( + get_st_init_args, + NoneClaim, + TrueClaim, + st_init_common_args, +) +from tests.utils import setup_function, teardown_function, start_st + +_ = setup_function # type:ignore +_ = teardown_function # type:ignore + +pytest_mark = mark.asyncio + + +async def test_should_return_the_right_validation_errors(): + init(**get_st_init_args(TrueClaim)) # type:ignore + start_st() + + dummy_req: BaseRequest = MagicMock() + s = await create_new_session(dummy_req, "someId") + + failing_validator = NoneClaim.validators.has_value(True) + res = await validate_claims_for_session_handle( + s.get_handle(), + lambda _, __, ___: [TrueClaim.validators.has_value(True), failing_validator], + ) + assert res == ValidateClaimsOkResult( + [ + ClaimValidationError( + failing_validator.id, + reason={ + "message": "wrong value", + "actualValue": None, + "expectedValue": True, + }, + ) + ] + ) + + +async def test_should_work_for_not_existing_handle(): + new_st_init = {**st_init_common_args, "recipe_list": [session.init()]} + init(**new_st_init) # type: ignore + start_st() + + res = await validate_claims_for_session_handle( + "non_existing_handle", lambda _, __, ___: [] + ) + assert isinstance(res, SessionDoesntExistError) diff --git a/tests/sessions/claims/test_verify_session.py b/tests/sessions/claims/test_verify_session.py index 925c4773b..4e4e388ac 100644 --- a/tests/sessions/claims/test_verify_session.py +++ b/tests/sessions/claims/test_verify_session.py @@ -1,4 +1,4 @@ -from typing import List, Any, Dict, Union, Awaitable +from typing import List, Any, Dict, Union from unittest.mock import patch, AsyncMock from fastapi import FastAPI, Depends @@ -88,18 +88,18 @@ async def new_get_global_claim_validators( class AlwaysValidValidator(SessionClaimValidator): id = "always-valid-validator" - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Union[ClaimValidationResult, Awaitable[ClaimValidationResult]]: + ) -> ClaimValidationResult: return {"isValid": True} class AlwaysInvalidValidator(SessionClaimValidator): id = "always-invalid-validator" - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ): + ) -> ClaimValidationResult: return {"isValid": False, "reason": "foo"} @@ -163,9 +163,9 @@ def __init__(self, is_valid: bool): super().__init__("test_id") self.is_valid = is_valid - def validate( + async def validate( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Union[ClaimValidationResult, Awaitable[ClaimValidationResult]]: + ) -> ClaimValidationResult: if self.is_valid: return {"isValid": True} return {"isValid": False, "reason": "test_reason"} diff --git a/tests/sessions/claims/test_with_jwt.py b/tests/sessions/claims/test_with_jwt.py new file mode 100644 index 000000000..9f5dc2ffb --- /dev/null +++ b/tests/sessions/claims/test_with_jwt.py @@ -0,0 +1,83 @@ +import json + +from fastapi import FastAPI +from fastapi.requests import Request +from pytest import fixture +from starlette.testclient import TestClient + +from supertokens_python import init +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe.session import JWTConfig +from supertokens_python.recipe.session.asyncio import ( + create_new_session, + get_session_information, + validate_claims_in_jwt_payload, +) +from supertokens_python.recipe.session.exceptions import ClaimValidationError +from supertokens_python.recipe.session.interfaces import ValidateClaimsOkResult +from supertokens_python.utils import utf_base64decode +from tests.sessions.claims.utils import get_st_init_args, NoneClaim, TrueClaim +from tests.utils import setup_function, teardown_function, start_st, min_api_version + +_ = setup_function # type:ignore +_ = teardown_function # type:ignore + + +@fixture(scope="function") +async def fastapi_client(): + app = FastAPI() + app.add_middleware(get_middleware()) + + @app.post("/create") + async def create_api(request: Request): # type: ignore + user_id = "userId" + s = await create_new_session(request, user_id, {}, {}) + return {"session_handle": s.get_handle()} + + return TestClient(app) + + +@min_api_version("2.9") +async def test_should_create_the_right_access_token_payload_with_claims_and_JWT_enabled( + fastapi_client: TestClient, +): + init(**get_st_init_args(TrueClaim, jwt=JWTConfig(enable=True))) # type:ignore + start_st() + + create_res = fastapi_client.post(url="/create") + session_handle = create_res.json()["session_handle"] + + session_info = await get_session_information(session_handle) + assert session_info is not None + access_token_payload = session_info.access_token_payload + # TODO: .sub and .iss should be undefined as per node PR + assert access_token_payload["jwt"] is not None + assert access_token_payload["_jwtPName"] == "jwt" + + decoded_jwt = json.loads(utf_base64decode(access_token_payload["jwt"])) + assert decoded_jwt == { + "sub": "userId", + "iss": "https://api.supertokens.io/auth", + "_jwtPName": None, + } + + assert TrueClaim.get_value_from_payload(access_token_payload) is True + assert TrueClaim.get_value_from_payload(decoded_jwt) is True + + failing_validator = NoneClaim.validators.has_value(True) + res = await validate_claims_in_jwt_payload( + session_info.user_id, + decoded_jwt, + lambda _, __, ___: [ + TrueClaim.validators.has_fresh_value(True, 2), + failing_validator, + ], + ) + + assert isinstance(res, ValidateClaimsOkResult) + assert res.invalid_claims == [ + ClaimValidationError( + failing_validator.id, + {"actualValue": None, "expectedValue": True, "message": "wrong value"}, + ) + ] diff --git a/tests/sessions/claims/utils.py b/tests/sessions/claims/utils.py index 3191715d0..6fb976a07 100644 --- a/tests/sessions/claims/utils.py +++ b/tests/sessions/claims/utils.py @@ -1,5 +1,5 @@ import time as real_time -from typing import Dict, Any, Union +from typing import Dict, Any, Union, Optional from unittest.mock import patch from pytest import mark @@ -7,6 +7,7 @@ from supertokens_python import InputAppInfo, SupertokensConfig from supertokens_python.framework.request import BaseRequest from supertokens_python.recipe import session +from supertokens_python.recipe.session import JWTConfig from supertokens_python.recipe.session.claims import BooleanClaim, SessionClaim from supertokens_python.recipe.session.interfaces import RecipeInterface @@ -69,7 +70,7 @@ async def new_create_new_session( "mode": "asgi", } -st_init_args = { +st_init_args_with_TrueClaim = { **st_init_common_args, "recipe_list": [ session.init( @@ -81,14 +82,17 @@ async def new_create_new_session( } -def get_st_init_args(claim: SessionClaim[Any] = TrueClaim): +def get_st_init_args( + claim: SessionClaim[Any] = TrueClaim, jwt: Optional[JWTConfig] = None +): return { - **st_init_args, + **st_init_args_with_TrueClaim, "recipe_list": [ session.init( override=session.InputOverrideConfig( functions=session_functions_override_with_claim(claim), - ) + ), + jwt=jwt, ), ], }