diff --git a/supertokens_python/post_init_callbacks.py b/supertokens_python/post_init_callbacks.py index ec27fc2a1..227acb78e 100644 --- a/supertokens_python/post_init_callbacks.py +++ b/supertokens_python/post_init_callbacks.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. +# +# This software is licensed under the Apache License, Version 2.0 (the +# "License") as published by the Apache Software Foundation. +# +# You may not use this file except in compliance with the License. You may +# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + from typing import Callable, List diff --git a/supertokens_python/recipe/emailpassword/emaildelivery/services/smtp/service_implementation/__init__.py b/supertokens_python/recipe/emailpassword/emaildelivery/services/smtp/service_implementation/__init__.py index e6ea9481b..b7d9088b4 100644 --- a/supertokens_python/recipe/emailpassword/emaildelivery/services/smtp/service_implementation/__init__.py +++ b/supertokens_python/recipe/emailpassword/emaildelivery/services/smtp/service_implementation/__init__.py @@ -14,7 +14,6 @@ from typing import Any, Dict -from supertokens_python.ingredients.emaildelivery.services.smtp import Transporter from supertokens_python.ingredients.emaildelivery.types import ( EmailContent, SMTPServiceInterface, @@ -23,21 +22,9 @@ get_password_reset_email_content, ) from supertokens_python.recipe.emailpassword.types import EmailTemplateVars -from supertokens_python.recipe.emailverification.emaildelivery.services.smtp.service_implementation import ( - ServiceImplementation as EVServiceImplementation, -) class ServiceImplementation(SMTPServiceInterface[EmailTemplateVars]): - def __init__(self, transporter: Transporter) -> None: - super().__init__(transporter) - - email_verification_service_implementation = EVServiceImplementation(transporter) - self.ev_send_raw_email = ( - email_verification_service_implementation.send_raw_email - ) - self.ev_get_content = email_verification_service_implementation.get_content - async def send_raw_email( self, content: EmailContent, user_context: Dict[str, Any] ) -> None: diff --git a/supertokens_python/recipe/emailpassword/recipe.py b/supertokens_python/recipe/emailpassword/recipe.py index cbf32e78b..61195aafb 100644 --- a/supertokens_python/recipe/emailpassword/recipe.py +++ b/supertokens_python/recipe/emailpassword/recipe.py @@ -189,7 +189,7 @@ async def handle_api_request( ) if request_id == USER_PASSWORD_RESET: return await handle_password_reset_api(self.api_implementation, api_options) - # FIXME: Should be False as per Node PR but the spec here don't allow it. + return None async def handle_error( diff --git a/supertokens_python/recipe/emailverification/__init__.py b/supertokens_python/recipe/emailverification/__init__.py index ee37a884e..ec952b84a 100644 --- a/supertokens_python/recipe/emailverification/__init__.py +++ b/supertokens_python/recipe/emailverification/__init__.py @@ -20,6 +20,7 @@ from .emaildelivery import services as emaildelivery_services from . import recipe from .interfaces import TypeGetEmailForUserIdFunction +from .recipe import EmailVerificationRecipe from .types import EmailTemplateVars, User from ...ingredients.emaildelivery.types import EmailDeliveryConfig @@ -27,7 +28,6 @@ exception = ex SMTPService = emaildelivery_services.SMTPService EmailVerificationClaim = recipe.EmailVerificationClaim -EmailVerificationRecipe = recipe.EmailVerificationRecipe if TYPE_CHECKING: diff --git a/supertokens_python/recipe/emailverification/api/email_verify.py b/supertokens_python/recipe/emailverification/api/email_verify.py index 7d35ed530..1481db177 100644 --- a/supertokens_python/recipe/emailverification/api/email_verify.py +++ b/supertokens_python/recipe/emailverification/api/email_verify.py @@ -52,7 +52,7 @@ async def handle_email_verify_api( ) result = await api_implementation.email_verify_post( - token, api_options, session, user_context + token, session, api_options, user_context ) else: if api_implementation.disable_is_email_verified_get: @@ -63,9 +63,9 @@ async def handle_email_verify_api( override_global_claim_validators=lambda _, __, ___: [], user_context=user_context, ) - + assert session is not None result = await api_implementation.is_email_verified_get( - api_options, session, user_context + session, api_options, user_context ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailverification/api/generate_email_verify_token.py b/supertokens_python/recipe/emailverification/api/generate_email_verify_token.py index e14bdf887..d201904f8 100644 --- a/supertokens_python/recipe/emailverification/api/generate_email_verify_token.py +++ b/supertokens_python/recipe/emailverification/api/generate_email_verify_token.py @@ -35,6 +35,6 @@ async def handle_generate_email_verify_token_api( assert session is not None result = await api_implementation.generate_email_verify_token_post( - api_options, session, user_context + session, api_options, user_context ) return send_200_response(result.to_json(), api_options.response) diff --git a/supertokens_python/recipe/emailverification/asyncio/__init__.py b/supertokens_python/recipe/emailverification/asyncio/__init__.py index 27896eee4..e3c46c8a4 100644 --- a/supertokens_python/recipe/emailverification/asyncio/__init__.py +++ b/supertokens_python/recipe/emailverification/asyncio/__init__.py @@ -18,6 +18,7 @@ CreateEmailVerificationTokenEmailAlreadyVerifiedError, UnverifyEmailOkResult, CreateEmailVerificationTokenOkResult, + RevokeEmailVerificationTokensOkResult, ) from supertokens_python.recipe.emailverification.types import EmailTemplateVars from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe @@ -81,11 +82,11 @@ async def is_email_verified( ) -async def revoke_email_verification_token( +async def revoke_email_verification_tokens( user_id: str, email: Optional[str] = None, user_context: Optional[Dict[str, Any]] = None, -): +) -> RevokeEmailVerificationTokensOkResult: if user_context is None: user_context = {} @@ -95,9 +96,7 @@ async def revoke_email_verification_token( if isinstance(email_info, GetEmailForUserIdOkResult): email = email_info.email elif isinstance(email_info, EmailDoesNotExistError): - # Here we are returning OK since that's how it used to work, but a later call - # to is_verified will still return true - return CreateEmailVerificationTokenEmailAlreadyVerifiedError() + return RevokeEmailVerificationTokensOkResult() else: raise Exception("Unknown User ID provided without email") @@ -131,16 +130,6 @@ async def unverify_email( ) -async def revoke_email_verification_tokens( - user_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -): - if user_context is None: - user_context = {} - return await EmailVerificationRecipe.get_instance().recipe_implementation.revoke_email_verification_tokens( - user_id, email, user_context - ) - - async def send_email( input_: EmailTemplateVars, user_context: Union[None, Dict[str, Any]] = None ): diff --git a/supertokens_python/recipe/emailverification/ev_claim.py b/supertokens_python/recipe/emailverification/ev_claim.py deleted file mode 100644 index 0c810a00a..000000000 --- a/supertokens_python/recipe/emailverification/ev_claim.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from __future__ import annotations -from typing import Dict, Any - -from supertokens_python.recipe.session.claim_base_classes.boolean_claim import ( - BooleanClaim, - BooleanClaimValidators, -) -from supertokens_python.recipe.session.interfaces import ( - SessionClaimValidator, - JSONObject, - ClaimValidationResult, -) -from supertokens_python.types import MaybeAwaitable -from supertokens_python.utils import get_timestamp_ms - - -class IsVerifiedSCV(SessionClaimValidator): - def __init__( - self, - claim: BooleanClaim, - has_value_validator: SessionClaimValidator, - refetch_time_on_false_in_seconds: int, - max_age_in_seconds: int, - ): - super().__init__("st-ev-is-verified") - self.claim: BooleanClaim = claim # TODO: Should work without specifying type of self.claim (no pyright errors) - self.has_value_validator = has_value_validator - self.refetch_time_on_false_in_ms = refetch_time_on_false_in_seconds * 1000 - self.max_age_in_ms = max_age_in_seconds * 1000 - - async def validate( - self, payload: JSONObject, user_context: Dict[str, Any] - ) -> ClaimValidationResult: - return await self.has_value_validator.validate(payload, user_context) - - def should_refetch( - self, payload: JSONObject, user_context: Dict[str, Any] - ) -> MaybeAwaitable[bool]: - value = self.claim.get_value_from_payload(payload, user_context) - last_refetch_time = self.claim.get_last_refetch_time(payload, user_context) - assert last_refetch_time is not None - - return ( - (value is None) - or (last_refetch_time < get_timestamp_ms() - self.max_age_in_ms) - or ( - value is False - and last_refetch_time - < ( - get_timestamp_ms() - self.refetch_time_on_false_in_ms - ) # TODO: Default 5 min? - ) - ) - - -class EmailVerificationClaimValidators(BooleanClaimValidators): - def is_verified( - self, - refetch_time_on_false_in_seconds: int = 10, - max_age_in_seconds: int = 300, - ) -> SessionClaimValidator: - has_value_res = self.has_value(True, id_="st-ev-is-verified") - assert isinstance(self.claim, BooleanClaim) - return IsVerifiedSCV( - self.claim, - has_value_res, - refetch_time_on_false_in_seconds, - max_age_in_seconds, - ) diff --git a/supertokens_python/recipe/emailverification/api/implementation.py b/supertokens_python/recipe/emailverification/ev_claim_validators.py similarity index 92% rename from supertokens_python/recipe/emailverification/api/implementation.py rename to supertokens_python/recipe/emailverification/ev_claim_validators.py index 400856fa3..dd5f414fc 100644 --- a/supertokens_python/recipe/emailverification/api/implementation.py +++ b/supertokens_python/recipe/emailverification/ev_claim_validators.py @@ -12,8 +12,3 @@ # License for the specific language governing permissions and limitations # under the License. from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - pass diff --git a/supertokens_python/recipe/emailverification/interfaces.py b/supertokens_python/recipe/emailverification/interfaces.py index 31027edfb..1f508548c 100644 --- a/supertokens_python/recipe/emailverification/interfaces.py +++ b/supertokens_python/recipe/emailverification/interfaces.py @@ -167,8 +167,8 @@ def __init__(self): async def email_verify_post( self, token: str, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ EmailVerifyPostOkResult, EmailVerifyPostInvalidTokenError, GeneralErrorResponse @@ -178,8 +178,8 @@ async def email_verify_post( @abstractmethod async def is_email_verified_get( self, + session: SessionContainer, api_options: APIOptions, - session: Optional[SessionContainer], user_context: Dict[str, Any], ) -> Union[IsEmailVerifiedGetOkResult, GeneralErrorResponse]: pass @@ -187,8 +187,8 @@ async def is_email_verified_get( @abstractmethod async def generate_email_verify_token_post( self, - api_options: APIOptions, session: SessionContainer, + api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ GenerateEmailVerifyTokenPostOkResult, diff --git a/supertokens_python/recipe/emailverification/recipe.py b/supertokens_python/recipe/emailverification/recipe.py index 0e1b03a6c..162d00d1c 100644 --- a/supertokens_python/recipe/emailverification/recipe.py +++ b/supertokens_python/recipe/emailverification/recipe.py @@ -34,9 +34,16 @@ from ...logger import log_debug_message from ...post_init_callbacks import PostSTInitCallbacks from ..session import SessionRecipe -from ..session.claim_base_classes.boolean_claim import BooleanClaim -from ..session.interfaces import SessionContainer -from .ev_claim import EmailVerificationClaimValidators +from ..session.claim_base_classes.boolean_claim import ( + BooleanClaim, + BooleanClaimValidators, +) +from ..session.interfaces import ( + SessionContainer, + SessionClaimValidator, + JSONObject, + ClaimValidationResult, +) from .interfaces import ( APIInterface, APIOptions, @@ -53,6 +60,8 @@ VerifyEmailUsingTokenOkResult, ) from .recipe_implementation import RecipeImplementation +from ...types import MaybeAwaitable +from ...utils import get_timestamp_ms if TYPE_CHECKING: from supertokens_python.framework.request import BaseRequest @@ -272,6 +281,25 @@ def add_get_email_for_user_id_func(self, f: TypeGetEmailForUserIdFunction): self.get_email_for_user_id_funcs_from_other_recipes.append(f) +class EmailVerificationClaimValidators(BooleanClaimValidators): + def is_verified( + self, + refetch_time_on_false_in_seconds: int = 10, + max_age_in_seconds: Optional[int] = None, + id_: Optional[str] = None, + ) -> SessionClaimValidator: + max_age_in_seconds = max_age_in_seconds or self.default_max_age_in_sec + + assert isinstance(self.claim, EmailVerificationClaimClass) + return IsVerifiedSCV( + (id_ or self.claim.key), + self.claim, + self, + refetch_time_on_false_in_seconds, + max_age_in_seconds, + ) + + class EmailVerificationClaimClass(BooleanClaim): def __init__(self): async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool: @@ -291,7 +319,9 @@ async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> bool: super().__init__("st-ev", fetch_value) - self.validators = EmailVerificationClaimValidators(claim=self) + self.validators = EmailVerificationClaimValidators( + claim=self, default_max_age_in_sec=300 + ) EmailVerificationClaim = EmailVerificationClaimClass() @@ -301,8 +331,8 @@ class APIImplementation(APIInterface): async def email_verify_post( self, token: str, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[EmailVerifyPostOkResult, EmailVerifyPostInvalidTokenError]: @@ -318,8 +348,8 @@ async def email_verify_post( async def is_email_verified_get( self, + session: SessionContainer, api_options: APIOptions, - session: Optional[SessionContainer], user_context: Dict[str, Any], ) -> IsEmailVerifiedGetOkResult: if session is None: @@ -338,16 +368,13 @@ async def is_email_verified_get( async def generate_email_verify_token_post( self, - api_options: APIOptions, session: SessionContainer, + api_options: APIOptions, user_context: Dict[str, Any], ) -> Union[ GenerateEmailVerifyTokenPostOkResult, GenerateEmailVerifyTokenPostEmailAlreadyVerifiedError, ]: - if session is None: - raise Exception("Session is undefined. Should not come here.") - user_id = session.get_user_id(user_context) email_info = await EmailVerificationRecipe.get_instance().get_email_for_user_id( user_id, user_context @@ -380,7 +407,7 @@ async def generate_email_verify_token_post( email_verify_link = ( api_options.app_info.website_domain.get_as_string_dangerous() + api_options.app_info.website_base_path.get_as_string_dangerous() - + "/verify-email/" + + "/verify-email" + "?token=" + response.token + "&rid=" @@ -401,3 +428,42 @@ async def generate_email_verify_token_post( raise Exception( "Should never come here: UNKNOWN_USER_ID or invalid result from get_email_for_user_id" ) + + +class IsVerifiedSCV(SessionClaimValidator): + def __init__( + self, + id_: str, + claim: EmailVerificationClaimClass, + ev_claim_validators: EmailVerificationClaimValidators, + refetch_time_on_false_in_seconds: int, + max_age_in_seconds: int, + ): + super().__init__(id_) + self.claim: EmailVerificationClaimClass = claim + self.ev_claim_validators = ev_claim_validators + self.refetch_time_on_false_in_ms = refetch_time_on_false_in_seconds * 1000 + self.max_age_in_ms = max_age_in_seconds * 1000 + + async def validate( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> ClaimValidationResult: + return await self.ev_claim_validators.has_value(True).validate( + payload, user_context + ) + + def should_refetch( + self, payload: JSONObject, user_context: Dict[str, Any] + ) -> MaybeAwaitable[bool]: + value = self.claim.get_value_from_payload(payload, user_context) + if value is None: + return True + + last_refetch_time = self.claim.get_last_refetch_time(payload, user_context) + assert last_refetch_time is not None + + return (last_refetch_time < get_timestamp_ms() - self.max_age_in_ms) or ( + value is False + and last_refetch_time + < (get_timestamp_ms() - self.refetch_time_on_false_in_ms) + ) diff --git a/supertokens_python/recipe/emailverification/syncio/__init__.py b/supertokens_python/recipe/emailverification/syncio/__init__.py index 5213e2ffc..aba537804 100644 --- a/supertokens_python/recipe/emailverification/syncio/__init__.py +++ b/supertokens_python/recipe/emailverification/syncio/__init__.py @@ -12,14 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Union, Optional from supertokens_python.async_to_sync_wrapper import sync from supertokens_python.recipe.emailverification.types import EmailTemplateVars def create_email_verification_token( - user_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None + user_id: str, + email: Optional[str] = None, + user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import ( create_email_verification_token, @@ -39,15 +41,9 @@ def verify_email_using_token( def is_email_verified( - user_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None -): - from supertokens_python.recipe.emailverification.asyncio import is_email_verified - - return sync(is_email_verified(user_id, email, user_context)) - - -def unverify_email( - user_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None + user_id: str, + email: Optional[str] = None, + user_context: Union[None, Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import is_email_verified @@ -55,7 +51,9 @@ def unverify_email( def revoke_email_verification_tokens( - user_id: str, email: str, user_context: Union[None, Dict[str, Any]] = None + user_id: str, + email: Optional[str] = None, + user_context: Optional[Dict[str, Any]] = None, ): from supertokens_python.recipe.emailverification.asyncio import ( revoke_email_verification_tokens, @@ -64,6 +62,16 @@ def revoke_email_verification_tokens( return sync(revoke_email_verification_tokens(user_id, email, user_context)) +def unverify_email( + user_id: str, + email: Optional[str] = None, + user_context: Union[None, Dict[str, Any]] = None, +): + from supertokens_python.recipe.emailverification.asyncio import is_email_verified + + return sync(is_email_verified(user_id, email, user_context)) + + def send_email( input_: EmailTemplateVars, user_context: Union[None, Dict[str, Any]] = None ): diff --git a/supertokens_python/recipe/session/api/implementation.py b/supertokens_python/recipe/session/api/implementation.py index ac3b3c51c..ea7b6c41c 100644 --- a/supertokens_python/recipe/session/api/implementation.py +++ b/supertokens_python/recipe/session/api/implementation.py @@ -42,8 +42,8 @@ async def refresh_post( async def signout_post( self, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ) -> SignOutOkayResponse: if session is not None: @@ -76,7 +76,6 @@ async def verify_session( api_options.request, anti_csrf_check, session_required, - override_global_claim_validators, user_context, ) diff --git a/supertokens_python/recipe/session/api/signout.py b/supertokens_python/recipe/session/api/signout.py index a1aed2e4e..7bbe8a2cc 100644 --- a/supertokens_python/recipe/session/api/signout.py +++ b/supertokens_python/recipe/session/api/signout.py @@ -31,13 +31,12 @@ async def handle_signout_api(api_implementation: APIInterface, api_options: APIO session = await api_options.recipe_implementation.get_session( request=api_options.request, - anti_csrf_check=None, # TODO: What should I pass here? + anti_csrf_check=None, session_required=False, - override_global_claim_validators=lambda _, __, ___: [], user_context=user_context, ) - response = await api_implementation.signout_post(api_options, session, user_context) + response = await api_implementation.signout_post(session, api_options, user_context) if api_options.response is None: raise Exception("Should never come here") return send_200_response(response.to_json(), api_options.response) diff --git a/supertokens_python/recipe/session/asyncio/__init__.py b/supertokens_python/recipe/session/asyncio/__init__.py index 89f063479..6225905e2 100644 --- a/supertokens_python/recipe/session/asyncio/__init__.py +++ b/supertokens_python/recipe/session/asyncio/__init__.py @@ -259,7 +259,6 @@ async def get_session( request, anti_csrf_check, session_required, - lambda _, __, ___: [], user_context, ) diff --git a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py index 3ac1e8f16..ada5111a7 100644 --- a/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/boolean_claim.py @@ -11,24 +11,18 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -from typing import Callable, Optional, Dict, Any, TypeVar +from typing import Callable, Optional, Dict, Any from supertokens_python.types import MaybeAwaitable from .primitive_claim import PrimitiveClaim, PrimitiveClaimValidators -_T = TypeVar("_T", bound=bool) - class BooleanClaimValidators(PrimitiveClaimValidators[bool]): - def is_true(self, max_age: Optional[int]): - if max_age is not None: - return self.has_value(True, max_age) - return self.has_value(True) + def is_true(self, max_age: Optional[int], id_: Optional[str] = None): + return self.has_value(True, max_age, id_) - def is_false(self, max_age: Optional[int]): - if max_age is not None: - return self.has_value(False, max_age) - return self.has_value(False) + def is_false(self, max_age: Optional[int], id_: Optional[str] = None): + return self.has_value(False, max_age, id_) class BooleanClaim(PrimitiveClaim[bool]): @@ -37,8 +31,8 @@ def __init__( key: str, fetch_value: Callable[ [str, Dict[str, Any]], - MaybeAwaitable[Optional[_T]], + MaybeAwaitable[Optional[bool]], ], ): super().__init__(key, fetch_value) - self.validators = BooleanClaimValidators(claim=self) + self.validators = BooleanClaimValidators(claim=self, default_max_age_in_sec=300) diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py index cc73bdc14..7a8899029 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_array_claim.py @@ -39,10 +39,10 @@ def __init__( id_: str, claim: SessionClaim[PrimitiveList], val: _T, - max_age_in_sec: int, # TODO: Default 5 min + max_age_in_sec: int, ): super().__init__(id_) - self.claim: SessionClaim[PrimitiveList] = claim + self.claim: SessionClaim[PrimitiveList] = claim # TODO:PrimitiveArrayClaim self.val = val self.max_age_in_sec = max_age_in_sec @@ -56,7 +56,7 @@ def should_refetch( return (claim.get_value_from_payload(payload, user_context) is None) or ( self.max_age_in_sec is not None and ( - payload[claim.key].get("t", 0) + payload[claim.key]["t"] < get_timestamp_ms() - self.max_age_in_sec * 1000 ) ) @@ -112,7 +112,7 @@ async def _validate( is_valid=False, reason={ "message": "wrong value", - expected_key: vals, # FIXME: Returns a list when val is Primitive whereas + expected_key: val, # other SDKs return the item itself "actualValue": claim_val, }, @@ -124,7 +124,7 @@ async def _validate( is_valid=False, reason={ "message": "wrong value", - expected_key: vals, # FIXME: Returns a list when val is Primitive whereas + expected_key: val, # other SDKs return the item itself "actualValue": claim_val, }, @@ -179,8 +179,8 @@ def __init__( def includes( # pyright: ignore[reportInvalidTypeVarUse] self, val: Primitive, # pyright: ignore[reportInvalidTypeVarUse] - id_: Union[str, None] = None, max_age_in_seconds: Optional[int] = None, + id_: Union[str, None] = None, ) -> SessionClaimValidator: max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec return IncludesSCV( @@ -190,8 +190,8 @@ def includes( # pyright: ignore[reportInvalidTypeVarUse] def excludes( # pyright: ignore[reportInvalidTypeVarUse] self, val: Primitive, # pyright: ignore[reportInvalidTypeVarUse] - id_: Union[str, None] = None, max_age_in_seconds: Optional[int] = None, + id_: Union[str, None] = None, ) -> SessionClaimValidator: max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec return ExcludesSCV( @@ -201,8 +201,8 @@ def excludes( # pyright: ignore[reportInvalidTypeVarUse] def includes_all( self, val: PrimitiveList, - id_: Union[str, None] = None, max_age_in_seconds: Optional[int] = None, + id_: Union[str, None] = None, ) -> SessionClaimValidator: max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec return IncludesAllSCV( @@ -212,8 +212,8 @@ def includes_all( def excludes_all( self, val: PrimitiveList, - id_: Union[str, None] = None, max_age_in_seconds: Optional[int] = None, + id_: Union[str, None] = None, ) -> SessionClaimValidator: max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec return ExcludesAllSCV( @@ -250,13 +250,13 @@ def add_to_payload_( return payload def remove_from_payload_by_merge_( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: payload[self.key] = None return payload def remove_from_payload( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: del payload[self.key] return payload diff --git a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py index 64b2b9384..c308c5696 100644 --- a/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py +++ b/supertokens_python/recipe/session/claim_base_classes/primitive_claim.py @@ -25,21 +25,21 @@ ClaimValidationResult, ) -_T = TypeVar("_T", bound=JSONPrimitive) +Primitive = TypeVar("Primitive", bound=JSONPrimitive) class HasValueSCV(SessionClaimValidator): def __init__( self, id_: str, - claim: SessionClaim[_T], - val: _T, - max_age_in_sec: Optional[int] = None, + claim: SessionClaim[Primitive], + val: Primitive, + max_age_in_sec: int, ): super().__init__(id_) - self.claim: SessionClaim[_T] = claim # Required to fix the type for pyright + self.claim: SessionClaim[Primitive] = claim # to fix the type for pyright self.val = val - self.max_age_in_sec = max_age_in_sec or 300 + self.max_age_in_sec = max_age_in_sec def should_refetch( self, @@ -61,7 +61,9 @@ async def validate( val = self.val max_age_in_sec = self.max_age_in_sec - claim_val = self.claim.get_value_from_payload(payload, user_context) + claim_val: JSONPrimitive = self.claim.get_value_from_payload( + payload, user_context + ) if claim_val is None: return ClaimValidationResult( is_valid=False, @@ -87,7 +89,7 @@ async def validate( }, ) - if claim_val != val: # type: ignore + if claim_val != val: return ClaimValidationResult( is_valid=False, reason={ @@ -100,16 +102,18 @@ async def validate( return ClaimValidationResult(is_valid=True) -class PrimitiveClaimValidators(Generic[_T]): +class PrimitiveClaimValidators(Generic[Primitive]): def __init__( - self, claim: SessionClaim[_T], default_max_age_in_sec: Optional[int] = None + self, + claim: SessionClaim[Primitive], + default_max_age_in_sec: int, ) -> None: self.claim = claim - self.default_max_age_in_sec = default_max_age_in_sec or 300 + self.default_max_age_in_sec = default_max_age_in_sec def has_value( self, - val: _T, + val: Primitive, max_age_in_sec: Optional[int] = None, id_: Optional[str] = None, ) -> SessionClaimValidator: @@ -119,25 +123,25 @@ def has_value( ) -class PrimitiveClaim(SessionClaim[_T]): +class PrimitiveClaim(SessionClaim[Primitive]): def __init__( self, key: str, fetch_value: Callable[ [str, Dict[str, Any]], - MaybeAwaitable[Optional[_T]], + MaybeAwaitable[Optional[Primitive]], ], default_max_age_in_sec: Optional[int] = None, ) -> None: super().__init__(key, fetch_value) claim = self - self.validators = PrimitiveClaimValidators(claim, default_max_age_in_sec) + self.validators = PrimitiveClaimValidators(claim, default_max_age_in_sec or 300) def add_to_payload_( self, payload: Dict[str, Any], - value: _T, + value: Primitive, user_context: Union[Dict[str, Any], None] = None, ) -> JSONObject: payload[self.key] = {"v": value, "t": get_timestamp_ms()} @@ -146,20 +150,20 @@ def add_to_payload_( return payload def remove_from_payload_by_merge_( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: payload[self.key] = None return payload def remove_from_payload( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: del payload[self.key] return payload def get_value_from_payload( self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None - ) -> Union[_T, None]: + ) -> Union[Primitive, None]: _ = user_context return payload.get(self.key, {}).get("v") diff --git a/supertokens_python/recipe/session/claims.py b/supertokens_python/recipe/session/claims.py index e3bd68b78..83bc17fc4 100644 --- a/supertokens_python/recipe/session/claims.py +++ b/supertokens_python/recipe/session/claims.py @@ -13,8 +13,9 @@ # under the License. from . import interfaces -from .claim_base_classes import boolean_claim, primitive_claim +from .claim_base_classes import boolean_claim, primitive_claim, primitive_array_claim SessionClaim = interfaces.SessionClaim BooleanClaim = boolean_claim.BooleanClaim PrimitiveClaim = primitive_claim.PrimitiveClaim +PrimitiveArrayClaim = primitive_array_claim.PrimitiveArrayClaim diff --git a/supertokens_python/recipe/session/exceptions.py b/supertokens_python/recipe/session/exceptions.py index 7d874d8fb..8df97a111 100644 --- a/supertokens_python/recipe/session/exceptions.py +++ b/supertokens_python/recipe/session/exceptions.py @@ -56,14 +56,7 @@ class TryRefreshTokenError(SuperTokensSessionError): class InvalidClaimsError(SuperTokensSessionError): def __init__(self, msg: str, payload: List[ClaimValidationError]): super().__init__(msg) - self.payload: List[Dict[str, Any]] = [] - for p in payload: - res = ( - p.__dict__.copy() - ) # Must be JSON serializable as it will be used in response - if p.reason is None: - res.pop("reason") - self.payload.append(res) + self.payload = payload class ClaimValidationError: diff --git a/supertokens_python/recipe/session/interfaces.py b/supertokens_python/recipe/session/interfaces.py index 56f0465ca..916572020 100644 --- a/supertokens_python/recipe/session/interfaces.py +++ b/supertokens_python/recipe/session/interfaces.py @@ -82,8 +82,6 @@ def __init__( List[str], List[int], List[bool], List[None], List[Dict[str, Any]] ] -FetchValueReturnType = Union[_T, None] - class SessionDoesNotExistError: pass @@ -134,12 +132,6 @@ async def get_session( request: BaseRequest, anti_csrf_check: Union[bool, None], session_required: bool, - override_global_claim_validators: Optional[ - Callable[ - [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ], user_context: Dict[str, Any], ) -> Union[SessionContainer, None]: pass @@ -222,7 +214,7 @@ async def update_access_token_payload( async def merge_into_access_token_payload( self, session_handle: str, - access_token_payload_update: Dict[str, Any], + access_token_payload_update: JSONObject, user_context: Dict[str, Any], ) -> bool: pass @@ -324,8 +316,8 @@ async def refresh_post( @abstractmethod async def signout_post( self, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ) -> SignOutOkayResponse: pass @@ -368,12 +360,14 @@ def __init__( self.remove_cookies = False @abstractmethod - async def revoke_session(self, user_context: Union[Any, None] = None) -> None: + async def revoke_session( + self, user_context: Optional[Dict[str, Any]] = None + ) -> None: pass @abstractmethod async def get_session_data( - self, user_context: Union[Dict[str, Any], None] = None + self, user_context: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: pass @@ -381,7 +375,7 @@ async def get_session_data( async def update_session_data( self, new_session_data: Dict[str, Any], - user_context: Union[Dict[str, Any], None] = None, + user_context: Optional[Dict[str, Any]] = None, ) -> None: pass @@ -389,55 +383,57 @@ async def update_session_data( async def update_access_token_payload( self, new_access_token_payload: Dict[str, Any], - user_context: Dict[str, Any], + user_context: Optional[Dict[str, Any]] = None, ) -> None: """DEPRECATED: Use merge_into_access_token_payload instead""" @abstractmethod async def merge_into_access_token_payload( - self, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any] + self, + access_token_payload_update: JSONObject, + user_context: Optional[Dict[str, Any]] = None, ) -> None: pass @abstractmethod - def get_user_id(self, user_context: Union[Dict[str, Any], None] = None) -> str: + def get_user_id(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass @abstractmethod def get_access_token_payload( - self, user_context: Union[Dict[str, Any], None] = None + self, user_context: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: pass @abstractmethod - def get_handle(self, user_context: Union[Dict[str, Any], None] = None) -> str: + def get_handle(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass @abstractmethod - def get_access_token(self, user_context: Union[Dict[str, Any], None] = None) -> str: + def get_access_token(self, user_context: Optional[Dict[str, Any]] = None) -> str: pass @abstractmethod async def get_time_created( - self, user_context: Union[Dict[str, Any], None] = None + self, user_context: Optional[Dict[str, Any]] = None ) -> int: pass @abstractmethod - async def get_expiry(self, user_context: Union[Dict[str, Any], None] = None) -> int: + async def get_expiry(self, user_context: Optional[Dict[str, Any]] = None) -> int: pass @abstractmethod async def assert_claims( self, claim_validators: List[SessionClaimValidator], - user_context: Union[Dict[str, Any], None] = None, + user_context: Optional[Dict[str, Any]] = None, ) -> None: pass @abstractmethod async def fetch_and_set_claim( - self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None + self, claim: SessionClaim[Any], user_context: Optional[Dict[str, Any]] = None ) -> None: pass @@ -446,13 +442,13 @@ async def set_claim_value( self, claim: SessionClaim[_T], value: _T, - user_context: Union[Dict[str, Any], None] = None, + user_context: Optional[Dict[str, Any]] = None, ) -> None: pass @abstractmethod async def get_claim_value( - self, claim: SessionClaim[_T], user_context: Union[Dict[str, Any], None] = None + self, claim: SessionClaim[_T], user_context: Optional[Dict[str, Any]] = None ) -> Union[_T, None]: pass @@ -460,15 +456,15 @@ async def get_claim_value( async def remove_claim( self, claim: SessionClaim[Any], - user_context: Union[Dict[str, Any], None] = None, + user_context: Optional[Dict[str, Any]] = None, ) -> None: pass - def sync_get_expiry(self, user_context: Union[Dict[str, Any], None] = None) -> int: + def sync_get_expiry(self, user_context: Optional[Dict[str, Any]] = None) -> int: return sync(self.get_expiry(user_context)) def sync_revoke_session( - self, user_context: Union[Dict[str, Any], None] = None + self, user_context: Optional[Dict[str, Any]] = None ) -> None: return sync(self.revoke_session(user_context=user_context)) @@ -478,12 +474,14 @@ def sync_get_session_data( return sync(self.get_session_data(user_context)) def sync_get_time_created( - self, user_context: Union[Dict[str, Any], None] = None + self, user_context: Optional[Dict[str, Any]] = None ) -> int: return sync(self.get_time_created(user_context)) def sync_merge_into_access_token_payload( - self, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any] + self, + access_token_payload_update: Dict[str, Any], + user_context: Optional[Dict[str, Any]] = None, ) -> None: return sync( self.merge_into_access_token_payload( @@ -494,7 +492,7 @@ def sync_merge_into_access_token_payload( def sync_update_access_token_payload( self, new_access_token_payload: Dict[str, Any], - user_context: Dict[str, Any], + user_context: Optional[Dict[str, Any]] = None, ) -> None: return sync( self.update_access_token_payload(new_access_token_payload, user_context) @@ -503,7 +501,7 @@ def sync_update_access_token_payload( def sync_update_session_data( self, new_session_data: Dict[str, Any], - user_context: Union[Dict[str, Any], None] = None, + user_context: Optional[Dict[str, Any]] = None, ) -> None: return sync(self.update_session_data(new_session_data, user_context)) @@ -511,27 +509,30 @@ def sync_update_session_data( def sync_assert_claims( self, claim_validators: List[SessionClaimValidator], - user_context: Dict[str, Any], + user_context: Optional[Dict[str, Any]] = None, ) -> None: return sync(self.assert_claims(claim_validators, user_context)) def sync_fetch_and_set_claim( - self, claim: SessionClaim[Any], user_context: Dict[str, Any] + self, claim: SessionClaim[Any], user_context: Optional[Dict[str, Any]] = None ) -> None: return sync(self.fetch_and_set_claim(claim, user_context)) def sync_set_claim_value( - self, claim: SessionClaim[_T], value: _T, user_context: Dict[str, Any] + self, + claim: SessionClaim[_T], + value: _T, + user_context: Optional[Dict[str, Any]] = None, ) -> None: return sync(self.set_claim_value(claim, value, user_context)) def sync_get_claim_value( - self, claim: SessionClaim[_T], user_context: Dict[str, Any] + self, claim: SessionClaim[_T], user_context: Optional[Dict[str, Any]] = None ) -> Union[_T, None]: return sync(self.get_claim_value(claim, user_context)) def sync_remove_claim( - self, claim: SessionClaim[Any], user_context: Dict[str, Any] + self, claim: SessionClaim[Any], user_context: Optional[Dict[str, Any]] = None ) -> None: return sync(self.remove_claim(claim, user_context)) @@ -571,13 +572,13 @@ def add_to_payload_( @abstractmethod def remove_from_payload_by_merge_( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: """Removes the claim from the payload by setting it to None, so merge_into_access_token_payload can clear it""" @abstractmethod def remove_from_payload( - self, payload: JSONObject, user_context: Dict[str, Any] + self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: """Removes the claim from the payload, by cloning and updating the entire object.""" @@ -588,7 +589,7 @@ def get_value_from_payload( """Gets the value of the claim stored in the payload""" async def build( - self, user_id: str, user_context: Union[Dict[str, Any], None] = None + self, user_id: str, user_context: Optional[Dict[str, Any]] = None ) -> JSONObject: if user_context is None: user_context = {} diff --git a/supertokens_python/recipe/session/recipe.py b/supertokens_python/recipe/session/recipe.py index 70d5df661..59893f834 100644 --- a/supertokens_python/recipe/session/recipe.py +++ b/supertokens_python/recipe/session/recipe.py @@ -26,7 +26,6 @@ TokenTheftError, UnauthorisedError, InvalidClaimsError, - TryRefreshTokenError, ) from ...types import MaybeAwaitable @@ -240,14 +239,11 @@ async def handle_error( return await self.config.error_handlers.on_invalid_claim( self, request, err.payload, response ) - if isinstance(err, TryRefreshTokenError): - log_debug_message("errorHandler: returning TRY_REFRESH_TOKEN") - return await self.config.error_handlers.on_try_refresh_token( - request, str(err), response - ) - # TODO: Is raising err okay? - raise err + log_debug_message("errorHandler: returning TRY_REFRESH_TOKEN") + return await self.config.error_handlers.on_try_refresh_token( + request, str(err), response + ) def get_all_cors_headers(self) -> List[str]: cors_headers = get_cors_allowed_headers() diff --git a/supertokens_python/recipe/session/recipe_implementation.py b/supertokens_python/recipe/session/recipe_implementation.py index b3bc97066..f783fd409 100644 --- a/supertokens_python/recipe/session/recipe_implementation.py +++ b/supertokens_python/recipe/session/recipe_implementation.py @@ -14,7 +14,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Dict, Optional, Callable +from typing import TYPE_CHECKING, Any, Dict, Optional from supertokens_python.framework.request import BaseRequest from supertokens_python.logger import log_debug_message from supertokens_python.normalised_url_path import NormalisedURLPath @@ -236,12 +236,6 @@ async def get_session( request: BaseRequest, anti_csrf_check: Union[bool, None], session_required: bool, - override_global_claim_validators: Optional[ - Callable[ - [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ], user_context: Dict[str, Any], ) -> Optional[SessionContainer]: log_debug_message("getSession: Started") diff --git a/supertokens_python/recipe/session/utils.py b/supertokens_python/recipe/session/utils.py index abe24fd9e..86acd101d 100644 --- a/supertokens_python/recipe/session/utils.py +++ b/supertokens_python/recipe/session/utils.py @@ -137,7 +137,7 @@ def __init__( Union[BaseResponse, Awaitable[BaseResponse]], ], on_invalid_claim: Callable[ - [BaseRequest, List[Dict[str, Any]], BaseResponse], + [BaseRequest, List[ClaimValidationError], BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], ): @@ -185,7 +185,7 @@ async def on_invalid_claim( self, recipe: SessionRecipe, request: BaseRequest, - claim_validation_errors: List[Dict[str, Any]], + claim_validation_errors: List[ClaimValidationError], response: BaseResponse, ): _ = recipe @@ -214,7 +214,7 @@ def __init__( ] = None, on_invalid_claim: Union[ Callable[ - [BaseRequest, List[Dict[str, Any]], BaseResponse], + [BaseRequest, List[ClaimValidationError], BaseResponse], Union[BaseResponse, Awaitable[BaseResponse]], ], None, @@ -275,13 +275,23 @@ async def default_token_theft_detected_callback( async def default_invalid_claim_callback( _: BaseRequest, - claim_validation_errors: List[Dict[str, Any]], + claim_validation_errors: List[ClaimValidationError], response: BaseResponse, ) -> BaseResponse: from .recipe import SessionRecipe + payload: List[Dict[str, Any]] = [] + + for p in claim_validation_errors: + res = ( + p.__dict__.copy() + ) # Must be JSON serializable as it will be used in response + if p.reason is None: + res.pop("reason") + payload.append(res) + return send_non_200_response( - {"message": "invalid claim", "claimValidationErrors": claim_validation_errors}, + {"message": "invalid claim", "claimValidationErrors": payload}, SessionRecipe.get_instance().config.invalid_claim_status_code, response, ) diff --git a/supertokens_python/recipe/session/with_jwt/recipe_implementation.py b/supertokens_python/recipe/session/with_jwt/recipe_implementation.py index 6e7f92078..6e9144ef2 100644 --- a/supertokens_python/recipe/session/with_jwt/recipe_implementation.py +++ b/supertokens_python/recipe/session/with_jwt/recipe_implementation.py @@ -13,13 +13,10 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Union, Optional, List, Callable - -# TODO: Missing changes for session_class inside with_jwt? supertokens/supertokens-node#278 (files) +from typing import TYPE_CHECKING, Any, Dict, Union, Optional from jwt import decode -from supertokens_python.types import MaybeAwaitable from supertokens_python.utils import get_timestamp_ms from .constants import ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY @@ -32,7 +29,6 @@ RecipeInterface, SessionContainer, SessionInformationResult, - SessionClaimValidator, ) from supertokens_python.framework.types import BaseRequest @@ -93,19 +89,12 @@ async def get_session( request: BaseRequest, anti_csrf_check: Union[bool, None], session_required: bool, - override_global_claim_validators: Optional[ - Callable[ - [List[SessionClaimValidator], SessionContainer, Dict[str, Any]], - MaybeAwaitable[List[SessionClaimValidator]], - ] - ], user_context: Dict[str, Any], ) -> Union[SessionContainer, None]: session_container = await og_get_session( request, anti_csrf_check, session_required, - override_global_claim_validators, user_context, ) if session_container is None: @@ -196,9 +185,11 @@ async def jwt_aware_update_access_token_payload( async def update_access_token_payload( session_handle: str, - new_access_token_payload: Dict[str, Any], + new_access_token_payload: Optional[Dict[str, Any]], user_context: Dict[str, Any], ) -> bool: + if new_access_token_payload is None: + new_access_token_payload = {} session_information = await original_implementation.get_session_information( session_handle, user_context diff --git a/supertokens_python/recipe/session/with_jwt/session_class.py b/supertokens_python/recipe/session/with_jwt/session_class.py index 3af280bfe..709a60796 100644 --- a/supertokens_python/recipe/session/with_jwt/session_class.py +++ b/supertokens_python/recipe/session/with_jwt/session_class.py @@ -15,7 +15,7 @@ from __future__ import annotations from math import ceil -from typing import TYPE_CHECKING, Any, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, Union, Optional from jwt import decode from supertokens_python.recipe.session.with_jwt.constants import ( @@ -40,11 +40,14 @@ def get_session_with_jwt( original_update_access_token_payload = original_session.update_access_token_payload async def update_access_token_payload( - new_access_token_payload: Dict[str, Any], + new_access_token_payload: Optional[Dict[str, Any]], user_context: Union[None, Dict[str, Any]] = None, ) -> None: if user_context is None: user_context = {} + if new_access_token_payload is None: + new_access_token_payload = {} + access_token_payload = original_session.get_access_token_payload() if ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY not in access_token_payload: diff --git a/supertokens_python/recipe/thirdparty/recipe.py b/supertokens_python/recipe/thirdparty/recipe.py index 73062e6ad..24f535edc 100644 --- a/supertokens_python/recipe/thirdparty/recipe.py +++ b/supertokens_python/recipe/thirdparty/recipe.py @@ -33,7 +33,7 @@ from .utils import SignInAndUpFeature, InputOverrideConfig from supertokens_python.exceptions import SuperTokensError, raise_general_exception -from supertokens_python.recipe.emailverification import EmailVerificationRecipe +from supertokens_python.recipe.emailverification.recipe import EmailVerificationRecipe from .api import ( handle_apple_redirect_api, @@ -138,7 +138,7 @@ async def handle_api_request( if request_id == APPLE_REDIRECT_HANDLER: return await handle_apple_redirect_api(self.api_implementation, api_options) - return None # TODO: Node PR returns False, but here signature is different. Verify if this is correct. + return None async def handle_error( self, request: BaseRequest, err: SuperTokensError, response: BaseResponse diff --git a/supertokens_python/recipe/userroles/__init__.py b/supertokens_python/recipe/userroles/__init__.py index 0b6ca45b9..c053698ac 100644 --- a/supertokens_python/recipe/userroles/__init__.py +++ b/supertokens_python/recipe/userroles/__init__.py @@ -13,10 +13,14 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Union, Optional from . import utils from .recipe import UserRolesRecipe +from . import recipe + +PermissionClaim = recipe.PermissionClaim +UserRoleClaim = recipe.UserRoleClaim if TYPE_CHECKING: from supertokens_python.supertokens import AppInfo @@ -25,6 +29,12 @@ def init( - override: Union[utils.InputOverrideConfig, None] = None + skip_adding_roles_to_access_token: Optional[bool] = None, + skip_adding_permissions_to_access_token: Optional[bool] = None, + override: Union[utils.InputOverrideConfig, None] = None, ) -> Callable[[AppInfo], RecipeModule]: - return UserRolesRecipe.init(override) + return UserRolesRecipe.init( + skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token, + override, + ) diff --git a/supertokens_python/recipe/userroles/permission_claim.py b/supertokens_python/recipe/userroles/permission_claim.py deleted file mode 100644 index 5c8824769..000000000 --- a/supertokens_python/recipe/userroles/permission_claim.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from typing import List, Any, Set, Dict - -from supertokens_python.recipe.session.claim_base_classes.primitive_array_claim import ( - PrimitiveArrayClaim, -) -from supertokens_python.recipe.userroles import UserRolesRecipe -from supertokens_python.recipe.userroles.interfaces import GetPermissionsForRoleOkResult - - -class PermissionClaimClass(PrimitiveArrayClaim[List[str]]): - def __init__(self) -> None: - key = "st-perm" - - async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]: - recipe = UserRolesRecipe.get_instance() - - user_roles = await recipe.recipe_implementation.get_roles_for_user( - user_id, user_context - ) - - user_permissions: Set[str] = set() - - for role in user_roles.roles: - role_permissions = ( - await recipe.recipe_implementation.get_permissions_for_role( - role, user_context - ) - ) - - if isinstance(role_permissions, GetPermissionsForRoleOkResult): - for permission in role_permissions.permissions: - user_permissions.add(permission) - - return list(user_permissions) - - super().__init__(key, fetch_value) - - -PermissionClaim = PermissionClaimClass() diff --git a/supertokens_python/recipe/userroles/recipe.py b/supertokens_python/recipe/userroles/recipe.py index e6696d3d1..0a3471de6 100644 --- a/supertokens_python/recipe/userroles/recipe.py +++ b/supertokens_python/recipe/userroles/recipe.py @@ -15,7 +15,7 @@ from __future__ import annotations from os import environ -from typing import List, Union +from typing import List, Union, Optional, Dict, Any, Set from supertokens_python.exceptions import SuperTokensError, raise_general_exception from supertokens_python.framework import BaseRequest, BaseResponse @@ -29,7 +29,11 @@ from supertokens_python.supertokens import AppInfo from .exceptions import SuperTokensUserRolesError +from .interfaces import GetPermissionsForRoleOkResult from .utils import InputOverrideConfig +from ..session import SessionRecipe +from ..session.claim_base_classes.primitive_array_claim import PrimitiveArrayClaim +from ...post_init_callbacks import PostSTInitCallbacks class UserRolesRecipe(RecipeModule): @@ -40,10 +44,18 @@ def __init__( self, recipe_id: str, app_info: AppInfo, + skip_adding_roles_to_access_token: Optional[bool] = None, + skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[InputOverrideConfig, None] = None, ): super().__init__(recipe_id, app_info) - self.config = validate_and_normalise_user_input(self, app_info, override) + self.config = validate_and_normalise_user_input( + self, + app_info, + skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token, + override, + ) recipe_implementation = RecipeImplementation(Querier.get_instance(recipe_id)) self.recipe_implementation = ( recipe_implementation @@ -51,6 +63,16 @@ def __init__( else self.config.override.functions(recipe_implementation) ) + def callback(): + if self.config.skip_adding_roles_to_access_token is False: + SessionRecipe.get_instance().add_claim_from_other_recipe(UserRoleClaim) + if self.config.skip_adding_permissions_to_access_token is False: + SessionRecipe.get_instance().add_claim_from_other_recipe( + PermissionClaim + ) + + PostSTInitCallbacks.add_post_init_callback(callback) + def is_error_from_this_recipe_based_on_instance(self, err: Exception) -> bool: return isinstance(err, SuperTokensError) and ( isinstance(err, SuperTokensUserRolesError) @@ -78,11 +100,19 @@ def get_all_cors_headers(self) -> List[str]: return [] @staticmethod - def init(override: Union[InputOverrideConfig, None] = None): + def init( + skip_adding_roles_to_access_token: Optional[bool] = None, + skip_adding_permissions_to_access_token: Optional[bool] = None, + override: Union[InputOverrideConfig, None] = None, + ): def func(app_info: AppInfo): if UserRolesRecipe.__instance is None: UserRolesRecipe.__instance = UserRolesRecipe( - UserRolesRecipe.recipe_id, app_info, override + UserRolesRecipe.recipe_id, + app_info, + skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token, + override, ) return UserRolesRecipe.__instance raise Exception( @@ -107,3 +137,52 @@ def get_instance() -> UserRolesRecipe: raise_general_exception( "Initialisation not done. Did you forget to call the SuperTokens.init or UserRoles.init function?" ) + + +class PermissionClaimClass(PrimitiveArrayClaim[List[str]]): + def __init__(self) -> None: + key = "st-perm" + + async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]: + recipe = UserRolesRecipe.get_instance() + + user_roles = await recipe.recipe_implementation.get_roles_for_user( + user_id, user_context + ) + + user_permissions: Set[str] = set() + + for role in user_roles.roles: + role_permissions = ( + await recipe.recipe_implementation.get_permissions_for_role( + role, user_context + ) + ) + + if isinstance(role_permissions, GetPermissionsForRoleOkResult): + for permission in role_permissions.permissions: + user_permissions.add(permission) + + return list(user_permissions) + + super().__init__(key, fetch_value) + + +PermissionClaim = PermissionClaimClass() + + +class UserRoleClaimClass(PrimitiveArrayClaim[List[str]]): + def __init__(self) -> None: + key = "st-role" + + async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]: + recipe = UserRolesRecipe.get_instance() + res = await recipe.recipe_implementation.get_roles_for_user( + user_id, user_context + ) + return res.roles + + super().__init__(key, fetch_value) + + +UserRoleClaim = UserRoleClaimClass() diff --git a/supertokens_python/recipe/userroles/syncio/__init__.py b/supertokens_python/recipe/userroles/syncio/__init__.py index 97c5a2af9..72fe3cb1f 100644 --- a/supertokens_python/recipe/userroles/syncio/__init__.py +++ b/supertokens_python/recipe/userroles/syncio/__init__.py @@ -39,7 +39,7 @@ def add_role_to_user( def remove_user_role( - user_id: str, role: str, user_context: Dict[str, Any] + user_id: str, role: str, user_context: Union[Dict[str, Any], None] = None ) -> Union[RemoveUserRoleOkResult, UnknownRoleError]: from supertokens_python.recipe.userroles.asyncio import remove_user_role @@ -47,7 +47,7 @@ def remove_user_role( def get_roles_for_user( - user_id: str, user_context: Dict[str, Any] + user_id: str, user_context: Union[Dict[str, Any], None] = None ) -> GetRolesForUserOkResult: from supertokens_python.recipe.userroles.asyncio import get_roles_for_user @@ -55,7 +55,7 @@ def get_roles_for_user( def get_users_that_have_role( - role: str, user_context: Dict[str, Any] + role: str, user_context: Union[Dict[str, Any], None] = None ) -> Union[GetUsersThatHaveRoleOkResult, UnknownRoleError]: from supertokens_python.recipe.userroles.asyncio import get_users_that_have_role @@ -63,7 +63,7 @@ def get_users_that_have_role( def create_new_role_or_add_permissions( - role: str, permissions: List[str], user_context: Dict[str, Any] + role: str, permissions: List[str], user_context: Union[Dict[str, Any], None] = None ) -> CreateNewRoleOrAddPermissionsOkResult: from supertokens_python.recipe.userroles.asyncio import ( create_new_role_or_add_permissions, @@ -73,7 +73,7 @@ def create_new_role_or_add_permissions( def get_permissions_for_role( - role: str, user_context: Dict[str, Any] + role: str, user_context: Union[Dict[str, Any], None] = None ) -> Union[GetPermissionsForRoleOkResult, UnknownRoleError]: from supertokens_python.recipe.userroles.asyncio import get_permissions_for_role @@ -81,7 +81,7 @@ def get_permissions_for_role( def remove_permissions_from_role( - role: str, permissions: List[str], user_context: Dict[str, Any] + role: str, permissions: List[str], user_context: Union[Dict[str, Any], None] = None ) -> Union[RemovePermissionsFromRoleOkResult, UnknownRoleError]: from supertokens_python.recipe.userroles.asyncio import remove_permissions_from_role @@ -89,7 +89,7 @@ def remove_permissions_from_role( def get_roles_that_have_permission( - permission: str, user_context: Dict[str, Any] + permission: str, user_context: Union[Dict[str, Any], None] = None ) -> GetRolesThatHavePermissionOkResult: from supertokens_python.recipe.userroles.asyncio import ( get_roles_that_have_permission, @@ -98,13 +98,17 @@ def get_roles_that_have_permission( return sync(get_roles_that_have_permission(permission, user_context)) -def delete_role(role: str, user_context: Dict[str, Any]) -> DeleteRoleOkResult: +def delete_role( + role: str, user_context: Union[Dict[str, Any], None] = None +) -> DeleteRoleOkResult: from supertokens_python.recipe.userroles.asyncio import delete_role return sync(delete_role(role, user_context)) -def get_all_roles(user_context: Dict[str, Any]) -> GetAllRolesOkResult: +def get_all_roles( + user_context: Union[Dict[str, Any], None] = None +) -> GetAllRolesOkResult: from supertokens_python.recipe.userroles.asyncio import get_all_roles return sync(get_all_roles(user_context)) diff --git a/supertokens_python/recipe/userroles/user_role_claim.py b/supertokens_python/recipe/userroles/user_role_claim.py deleted file mode 100644 index b72be18ee..000000000 --- a/supertokens_python/recipe/userroles/user_role_claim.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved. -# -# This software is licensed under the Apache License, Version 2.0 (the -# "License") as published by the Apache Software Foundation. -# -# You may not use this file except in compliance with the License. You may -# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from typing import List, Any, Dict - -from supertokens_python.recipe.session.claim_base_classes.primitive_array_claim import ( - PrimitiveArrayClaim, -) -from supertokens_python.recipe.userroles import UserRolesRecipe - - -class UserRoleClaimClass(PrimitiveArrayClaim[List[str]]): - def __init__(self) -> None: - key = "st-role" - - async def fetch_value(user_id: str, user_context: Dict[str, Any]) -> List[str]: - recipe = UserRolesRecipe.get_instance() - res = await recipe.recipe_implementation.get_roles_for_user( - user_id, user_context - ) - return res.roles - - super().__init__(key, fetch_value) - - -UserRoleClaim = UserRoleClaimClass() diff --git a/supertokens_python/recipe/userroles/utils.py b/supertokens_python/recipe/userroles/utils.py index c88e12107..b7e45095e 100644 --- a/supertokens_python/recipe/userroles/utils.py +++ b/supertokens_python/recipe/userroles/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Union, Optional from supertokens_python.recipe.userroles.interfaces import APIInterface, RecipeInterface from supertokens_python.supertokens import AppInfo @@ -34,13 +34,24 @@ def __init__( class UserRolesConfig: - def __init__(self, override: InputOverrideConfig) -> None: + def __init__( + self, + skip_adding_roles_to_access_token: bool, + skip_adding_permissions_to_access_token: bool, + override: InputOverrideConfig, + ) -> None: + self.skip_adding_roles_to_access_token = skip_adding_roles_to_access_token + self.skip_adding_permissions_to_access_token = ( + skip_adding_permissions_to_access_token + ) self.override = override def validate_and_normalise_user_input( _recipe: UserRolesRecipe, _app_info: AppInfo, + skip_adding_roles_to_access_token: Optional[bool] = None, + skip_adding_permissions_to_access_token: Optional[bool] = None, override: Union[InputOverrideConfig, None] = None, ) -> UserRolesConfig: if override is not None and not isinstance(override, InputOverrideConfig): # type: ignore @@ -49,4 +60,13 @@ def validate_and_normalise_user_input( if override is None: override = InputOverrideConfig() - return UserRolesConfig(override=override) + if skip_adding_roles_to_access_token is None: + skip_adding_roles_to_access_token = False + if skip_adding_permissions_to_access_token is None: + skip_adding_permissions_to_access_token = False + + return UserRolesConfig( + skip_adding_roles_to_access_token=skip_adding_roles_to_access_token, + skip_adding_permissions_to_access_token=skip_adding_permissions_to_access_token, + override=override, + ) diff --git a/tests/auth-react/django3x/mysite/utils.py b/tests/auth-react/django3x/mysite/utils.py index c97fa2d70..957203555 100644 --- a/tests/auth-react/django3x/mysite/utils.py +++ b/tests/auth-react/django3x/mysite/utils.py @@ -264,8 +264,8 @@ def override_email_verification_apis( async def email_verify_post( token: str, - api_options: EVAPIOptions, session: Optional[SessionContainer], + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -275,14 +275,14 @@ async def email_verify_post( return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( token, - api_options, session, + api_options, user_context, ) async def generate_email_verify_token_post( - api_options: EVAPIOptions, session: SessionContainer, + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -293,7 +293,7 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - api_options, session, user_context + session, api_options, user_context ) original_implementation_email_verification.email_verify_post = email_verify_post @@ -584,8 +584,8 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - api_options: SAPIOptions, session: Optional[SessionContainer], + api_options: SAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -593,7 +593,7 @@ async def signout_post( ) if is_general_error: raise Exception("general error from signout API") - return await original_signout_post(api_options, session, user_context) + return await original_signout_post(session, api_options, user_context) original_implementation.signout_post = signout_post return original_implementation diff --git a/tests/auth-react/fastapi-server/app.py b/tests/auth-react/fastapi-server/app.py index ebf8a4993..3978d88af 100644 --- a/tests/auth-react/fastapi-server/app.py +++ b/tests/auth-react/fastapi-server/app.py @@ -317,8 +317,8 @@ def override_email_verification_apis( async def email_verify_post( token: str, - api_options: EVAPIOptions, session: Optional[SessionContainer], + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -327,12 +327,12 @@ async def email_verify_post( if is_general_error: return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( - token, api_options, session, user_context + token, session, api_options, user_context ) async def generate_email_verify_token_post( - api_options: EVAPIOptions, session: SessionContainer, + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -343,8 +343,8 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - api_options, session, + api_options, user_context, ) @@ -636,8 +636,8 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - api_options: SAPIOptions, session: Optional[SessionContainer], + api_options: SAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -645,7 +645,7 @@ async def signout_post( ) if is_general_error: raise Exception("general error from signout API") - return await original_signout_post(api_options, session, user_context) + return await original_signout_post(session, api_options, user_context) original_implementation.signout_post = signout_post return original_implementation diff --git a/tests/auth-react/flask-server/app.py b/tests/auth-react/flask-server/app.py index 90a5e10ed..f3ac130db 100644 --- a/tests/auth-react/flask-server/app.py +++ b/tests/auth-react/flask-server/app.py @@ -286,8 +286,8 @@ def override_email_verification_apis( async def email_verify_post( token: str, - api_options: EVAPIOptions, session: Optional[SessionContainer], + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -297,14 +297,14 @@ async def email_verify_post( return GeneralErrorResponse("general error from API email verify") return await original_email_verify_post( token, - api_options, session, + api_options, user_context, ) async def generate_email_verify_token_post( - api_options: EVAPIOptions, session: SessionContainer, + api_options: EVAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -315,7 +315,7 @@ async def generate_email_verify_token_post( "general error from API email verification code" ) return await original_generate_email_verify_token_post( - api_options, session, user_context + session, api_options, user_context ) original_implementation_email_verification.email_verify_post = email_verify_post @@ -606,8 +606,8 @@ def override_session_apis(original_implementation: SessionAPIInterface): original_signout_post = original_implementation.signout_post async def signout_post( - api_options: SAPIOptions, session: Optional[SessionContainer], + api_options: SAPIOptions, user_context: Dict[str, Any], ): is_general_error = await check_for_general_error( @@ -615,7 +615,7 @@ async def signout_post( ) if is_general_error: raise Exception("general error from signout API") - return await original_signout_post(api_options, session, user_context) + return await original_signout_post(session, api_options, user_context) original_implementation.signout_post = signout_post return original_implementation diff --git a/tests/emailpassword/test_emailverify.py b/tests/emailpassword/test_emailverify.py index 6e1523685..8804382b8 100644 --- a/tests/emailpassword/test_emailverify.py +++ b/tests/emailpassword/test_emailverify.py @@ -27,7 +27,7 @@ from supertokens_python.recipe.emailverification.asyncio import ( create_email_verification_token, is_email_verified, - revoke_email_verification_token, + revoke_email_verification_tokens, unverify_email, verify_email_using_token, ) @@ -600,13 +600,13 @@ def apis_override_email_password(param: APIInterface): async def email_verify_post( token: str, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ): nonlocal user_info_from_callback - response = await temp(token, api_options, session, user_context) + response = await temp(token, session, api_options, user_context) if isinstance(response, EmailVerifyPostOkResult): user_info_from_callback = response.user @@ -812,13 +812,13 @@ def apis_override_email_password(param: APIInterface): async def email_verify_post( token: str, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ): nonlocal user_info_from_callback - response = await temp(token, api_options, session, user_context) + response = await temp(token, session, api_options, user_context) if isinstance(response, EmailVerifyPostOkResult): user_info_from_callback = response.user @@ -908,13 +908,13 @@ def apis_override_email_password(param: APIInterface): async def email_verify_post( token: str, - api_options: APIOptions, session: Optional[SessionContainer], + api_options: APIOptions, user_context: Dict[str, Any], ): nonlocal user_info_from_callback - response = await temp(token, api_options, session, user_context) + response = await temp(token, session, api_options, user_context) if isinstance(response, EmailVerifyPostOkResult): user_info_from_callback = response.user @@ -1021,7 +1021,7 @@ async def test_the_generate_token_api_with_valid_input_and_then_remove_token( user_id = dict_response["user"]["id"] verify_token = await create_email_verification_token(user_id) - await revoke_email_verification_token(user_id) + await revoke_email_verification_tokens(user_id) if isinstance(verify_token, CreateEmailVerificationTokenOkResult): response = await verify_email_using_token(verify_token.token) diff --git a/tests/userroles/test_add_role_to_user.py b/tests/userroles/test_add_role_to_user.py index 3fec82a51..30efc46a0 100644 --- a/tests/userroles/test_add_role_to_user.py +++ b/tests/userroles/test_add_role_to_user.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_add_new_role_to_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -81,7 +81,7 @@ async def test_add_duplicate_role_to_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -124,7 +124,7 @@ async def test_add_unknown_role_to_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_config.py b/tests/userroles/test_config.py index 85b093035..f3531e00c 100644 --- a/tests/userroles/test_config.py +++ b/tests/userroles/test_config.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.querier import Querier -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st @@ -41,7 +41,7 @@ async def test_recipe_works_without_config(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_create_new_role_or_add_permissions.py b/tests/userroles/test_create_new_role_or_add_permissions.py index 33ee88710..6f76f0055 100644 --- a/tests/userroles/test_create_new_role_or_add_permissions.py +++ b/tests/userroles/test_create_new_role_or_add_permissions.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.querier import Querier -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.recipe.userroles import asyncio, interfaces from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st @@ -42,7 +42,7 @@ async def test_create_new_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -69,7 +69,7 @@ async def test_create_new_role_twice(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -101,7 +101,7 @@ async def test_create_new_role_with_permissions(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -134,7 +134,7 @@ async def test_add_permissions_to_new_role_(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -172,7 +172,7 @@ async def test_add_duplicate_permission(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_delete_role.py b/tests/userroles/test_delete_role.py index 583852893..60dd56678 100644 --- a/tests/userroles/test_delete_role.py +++ b/tests/userroles/test_delete_role.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_create_and_assign_new_role_and_delete_it(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -90,7 +90,7 @@ async def test_delete_non_existent_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_get_permissions_for_role.py b/tests/userroles/test_get_permissions_for_role.py index 6fd71be40..f08f089d6 100644 --- a/tests/userroles/test_get_permissions_for_role.py +++ b/tests/userroles/test_get_permissions_for_role.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python import InputAppInfo, SupertokensConfig, init from supertokens_python.querier import Querier -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.recipe.userroles import asyncio, interfaces from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st @@ -42,7 +42,7 @@ async def test_get_permission_for_a_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -75,7 +75,7 @@ async def test_get_permission_for_non_existent_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_get_roles_for_user.py b/tests/userroles/test_get_roles_for_user.py index 79ce13fb8..1d56cfbb4 100644 --- a/tests/userroles/test_get_roles_for_user.py +++ b/tests/userroles/test_get_roles_for_user.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_get_roles_for_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_get_roles_that_have_permissions.py b/tests/userroles/test_get_roles_that_have_permissions.py index 8c69166b8..c8d7167e4 100644 --- a/tests/userroles/test_get_roles_that_have_permissions.py +++ b/tests/userroles/test_get_roles_that_have_permissions.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_get_roles_for_that_have_permission(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -77,7 +77,7 @@ async def test_get_roles_for_unknown_permission(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_get_users_that_have_role.py b/tests/userroles/test_get_users_that_have_role.py index 4def281d6..7254ed3f9 100644 --- a/tests/userroles/test_get_users_that_have_role.py +++ b/tests/userroles/test_get_users_that_have_role.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_get_users_that_have_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -82,7 +82,7 @@ async def test_get_users_for_unknown_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_remove_permissions_from_role.py b/tests/userroles/test_remove_permissions_from_role.py index d67c6051f..e2339cc07 100644 --- a/tests/userroles/test_remove_permissions_from_role.py +++ b/tests/userroles/test_remove_permissions_from_role.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_remove_permissions_from_a_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -80,7 +80,7 @@ async def test_remove_permissions_from_unknown_role(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() diff --git a/tests/userroles/test_remove_user_role.py b/tests/userroles/test_remove_user_role.py index 084971330..21b15ae79 100644 --- a/tests/userroles/test_remove_user_role.py +++ b/tests/userroles/test_remove_user_role.py @@ -15,7 +15,7 @@ from pytest import mark, skip from supertokens_python.querier import Querier from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.recipe import userroles +from supertokens_python.recipe import userroles, session from supertokens_python.utils import is_version_gte from tests.utils import clean_st, reset, setup_st, start_st from supertokens_python.recipe.userroles import asyncio @@ -43,7 +43,7 @@ async def test_remove_role_from_a_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -91,7 +91,7 @@ async def test_remove_unassigned_role_from_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st() @@ -124,7 +124,7 @@ async def test_remove_non_existent_role_from_user(): website_domain="supertokens.io", ), framework="fastapi", - recipe_list=[userroles.init()], + recipe_list=[userroles.init(), session.init()], ) start_st()