Skip to content

Commit 2a5df36

Browse files
Merge pull request #451 from supertokens/nest-asyncio-config
feat: Use nest-asyncio when configured with env var
2 parents 258790a + 45ae5c8 commit 2a5df36

File tree

9 files changed

+120
-66
lines changed

9 files changed

+120
-66
lines changed

.circleci/config_continue.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ jobs:
7979
- run: make with-django2x
8080
- run: (cd .circleci/ && ./websiteDjango2x.sh)
8181
- slack/status
82+
test-website-flask-nest-asyncio:
83+
docker:
84+
- image: rishabhpoddar/supertokens_python_driver_testing
85+
resource_class: large
86+
environment:
87+
SUPERTOKENS_NEST_ASYNCIO: "1"
88+
steps:
89+
- checkout
90+
- run: update-alternatives --install "/usr/bin/java" "java" "/usr/java/jdk-15.0.1/bin/java" 2
91+
- run: update-alternatives --install "/usr/bin/javac" "javac" "/usr/java/jdk-15.0.1/bin/javac" 2
92+
- run: git config --global url."https://github.com/".insteadOf ssh://[email protected]/
93+
- run: echo "127.0.0.1 localhost.org" >> /etc/hosts
94+
- run: make with-flask
95+
- run: python -m pip install nest-asyncio
96+
- run: (cd .circleci/ && ./websiteFlask.sh)
97+
- slack/status
8298
test-authreact-fastapi:
8399
docker:
84100
- image: rishabhpoddar/supertokens_python_driver_testing

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [unreleased]
1010

11+
## [0.16.2] - 2023-09-20
12+
13+
- Allow use of [nest-asyncio](https://pypi.org/project/nest-asyncio/) when env var `SUPERTOKENS_NEST_ASYNCIO=1`.
14+
- Retry Querier request on `AsyncLibraryNotFoundError`
15+
1116
## [0.16.1] - 2023-09-19
1217
- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute.
1318

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
setup(
7272
name="supertokens_python",
73-
version="0.16.1",
73+
version="0.16.2",
7474
author="SuperTokens",
7575
license="Apache 2.0",
7676
author_email="[email protected]",

supertokens_python/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from typing import Any, Callable, Dict, List, Optional, Union
15+
from typing import Any, Callable, Dict, List, Optional
1616

1717
from typing_extensions import Literal
1818

@@ -32,11 +32,16 @@ def init(
3232
framework: Literal["fastapi", "flask", "django"],
3333
supertokens_config: SupertokensConfig,
3434
recipe_list: List[Callable[[supertokens.AppInfo], RecipeModule]],
35-
mode: Union[Literal["asgi", "wsgi"], None] = None,
36-
telemetry: Union[bool, None] = None,
35+
mode: Optional[Literal["asgi", "wsgi"]] = None,
36+
telemetry: Optional[bool] = None,
3737
):
3838
return Supertokens.init(
39-
app_info, framework, supertokens_config, recipe_list, mode, telemetry
39+
app_info,
40+
framework,
41+
supertokens_config,
42+
recipe_list,
43+
mode,
44+
telemetry,
4045
)
4146

4247

supertokens_python/async_to_sync_wrapper.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,32 @@
1414

1515
import asyncio
1616
from typing import Any, Coroutine, TypeVar
17+
from os import getenv
1718

1819
_T = TypeVar("_T")
1920

2021

21-
def check_event_loop():
22+
def nest_asyncio_enabled():
23+
return getenv("SUPERTOKENS_NEST_ASYNCIO", "") == "1"
24+
25+
26+
def create_or_get_event_loop() -> asyncio.AbstractEventLoop:
2227
try:
23-
asyncio.get_event_loop()
24-
except RuntimeError as ex:
28+
return asyncio.get_event_loop()
29+
except Exception as ex:
2530
if "There is no current event loop in thread" in str(ex):
2631
loop = asyncio.new_event_loop()
32+
33+
if nest_asyncio_enabled():
34+
import nest_asyncio # type: ignore
35+
36+
nest_asyncio.apply(loop) # type: ignore
37+
2738
asyncio.set_event_loop(loop)
39+
return loop
40+
raise ex
2841

2942

3043
def sync(co: Coroutine[Any, Any, _T]) -> _T:
31-
check_event_loop()
32-
loop = asyncio.get_event_loop()
44+
loop = create_or_get_event_loop()
3345
return loop.run_until_complete(co)

supertokens_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
SUPPORTED_CDI_VERSIONS = ["3.0"]
17-
VERSION = "0.16.1"
17+
VERSION = "0.16.2"
1818
TELEMETRY = "/telemetry"
1919
USER_COUNT = "/users/count"
2020
USER_DELETE = "/user/remove"

supertokens_python/querier.py

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from .exceptions import raise_general_exception
4040
from .process_state import AllowedProcessStates, ProcessState
4141
from .utils import find_max_version, is_4xx_error, is_5xx_error
42+
from sniffio import AsyncLibraryNotFoundError
43+
from supertokens_python.async_to_sync_wrapper import create_or_get_event_loop
4244

4345

4446
class Querier:
@@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
7173
raise_general_exception("calling testing function in non testing env")
7274
return Querier.__hosts_alive_for_testing
7375

76+
async def api_request(
77+
self,
78+
url: str,
79+
method: str,
80+
attempts_remaining: int,
81+
*args: Any,
82+
**kwargs: Any,
83+
) -> Response:
84+
if attempts_remaining == 0:
85+
raise_general_exception("Retry request failed")
86+
87+
try:
88+
async with AsyncClient() as client:
89+
if method == "GET":
90+
return await client.get(url, *args, **kwargs) # type: ignore
91+
if method == "POST":
92+
return await client.post(url, *args, **kwargs) # type: ignore
93+
if method == "PUT":
94+
return await client.put(url, *args, **kwargs) # type: ignore
95+
if method == "DELETE":
96+
return await client.delete(url, *args, **kwargs) # type: ignore
97+
raise Exception("Shouldn't come here")
98+
except AsyncLibraryNotFoundError:
99+
# Retry
100+
loop = create_or_get_event_loop()
101+
return loop.run_until_complete(
102+
self.api_request(url, method, attempts_remaining - 1, *args, **kwargs)
103+
)
104+
74105
async def get_api_version(self):
75106
if Querier.api_version is not None:
76107
return Querier.api_version
@@ -79,12 +110,11 @@ async def get_api_version(self):
79110
AllowedProcessStates.CALLING_SERVICE_IN_GET_API_VERSION
80111
)
81112

82-
async def f(url: str) -> Response:
113+
async def f(url: str, method: str) -> Response:
83114
headers = {}
84115
if Querier.__api_key is not None:
85116
headers = {API_KEY_HEADER: Querier.__api_key}
86-
async with AsyncClient() as client:
87-
return await client.get(url, headers=headers) # type:ignore
117+
return await self.api_request(url, method, 2, headers=headers)
88118

89119
response = await self.__send_request_helper(
90120
NormalisedURLPath(API_VERSION), "GET", f, len(self.__hosts)
@@ -134,13 +164,14 @@ async def send_get_request(
134164
if params is None:
135165
params = {}
136166

137-
async def f(url: str) -> Response:
138-
async with AsyncClient() as client:
139-
return await client.get( # type:ignore
140-
url,
141-
params=params,
142-
headers=await self.__get_headers_with_api_version(path),
143-
)
167+
async def f(url: str, method: str) -> Response:
168+
return await self.api_request(
169+
url,
170+
method,
171+
2,
172+
headers=await self.__get_headers_with_api_version(path),
173+
params=params,
174+
)
144175

145176
return await self.__send_request_helper(path, "GET", f, len(self.__hosts))
146177

@@ -163,9 +194,14 @@ async def send_post_request(
163194
headers = await self.__get_headers_with_api_version(path)
164195
headers["content-type"] = "application/json; charset=utf-8"
165196

166-
async def f(url: str) -> Response:
167-
async with AsyncClient() as client:
168-
return await client.post(url, json=data, headers=headers) # type: ignore
197+
async def f(url: str, method: str) -> Response:
198+
return await self.api_request(
199+
url,
200+
method,
201+
2,
202+
headers=await self.__get_headers_with_api_version(path),
203+
json=data,
204+
)
169205

170206
return await self.__send_request_helper(path, "POST", f, len(self.__hosts))
171207

@@ -175,13 +211,14 @@ async def send_delete_request(
175211
if params is None:
176212
params = {}
177213

178-
async def f(url: str) -> Response:
179-
async with AsyncClient() as client:
180-
return await client.delete( # type:ignore
181-
url,
182-
params=params,
183-
headers=await self.__get_headers_with_api_version(path),
184-
)
214+
async def f(url: str, method: str) -> Response:
215+
return await self.api_request(
216+
url,
217+
method,
218+
2,
219+
headers=await self.__get_headers_with_api_version(path),
220+
params=params,
221+
)
185222

186223
return await self.__send_request_helper(path, "DELETE", f, len(self.__hosts))
187224

@@ -194,9 +231,8 @@ async def send_put_request(
194231
headers = await self.__get_headers_with_api_version(path)
195232
headers["content-type"] = "application/json; charset=utf-8"
196233

197-
async def f(url: str) -> Response:
198-
async with AsyncClient() as client:
199-
return await client.put(url, json=data, headers=headers) # type: ignore
234+
async def f(url: str, method: str) -> Response:
235+
return await self.api_request(url, method, 2, headers=headers, json=data)
200236

201237
return await self.__send_request_helper(path, "PUT", f, len(self.__hosts))
202238

@@ -223,7 +259,7 @@ async def __send_request_helper(
223259
self,
224260
path: NormalisedURLPath,
225261
method: str,
226-
http_function: Callable[[str], Awaitable[Response]],
262+
http_function: Callable[[str, str], Awaitable[Response]],
227263
no_of_tries: int,
228264
retry_info_map: Optional[Dict[str, int]] = None,
229265
) -> Any:
@@ -253,7 +289,7 @@ async def __send_request_helper(
253289
ProcessState.get_instance().add_state(
254290
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
255291
)
256-
response = await http_function(url)
292+
response = await http_function(url, method)
257293
if ("SUPERTOKENS_ENV" in environ) and (
258294
environ["SUPERTOKENS_ENV"] == "testing"
259295
):

supertokens_python/supertokens.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def __init__(
148148
framework: Literal["fastapi", "flask", "django"],
149149
supertokens_config: SupertokensConfig,
150150
recipe_list: List[Callable[[AppInfo], RecipeModule]],
151-
mode: Union[Literal["asgi", "wsgi"], None],
152-
telemetry: Union[bool, None],
151+
mode: Optional[Literal["asgi", "wsgi"]],
152+
telemetry: Optional[bool],
153153
):
154154
if not isinstance(app_info, InputAppInfo): # type: ignore
155155
raise ValueError("app_info must be an instance of InputAppInfo")
@@ -215,12 +215,17 @@ def init(
215215
framework: Literal["fastapi", "flask", "django"],
216216
supertokens_config: SupertokensConfig,
217217
recipe_list: List[Callable[[AppInfo], RecipeModule]],
218-
mode: Union[Literal["asgi", "wsgi"], None],
219-
telemetry: Union[bool, None],
218+
mode: Optional[Literal["asgi", "wsgi"]],
219+
telemetry: Optional[bool],
220220
):
221221
if Supertokens.__instance is None:
222222
Supertokens.__instance = Supertokens(
223-
app_info, framework, supertokens_config, recipe_list, mode, telemetry
223+
app_info,
224+
framework,
225+
supertokens_config,
226+
recipe_list,
227+
mode,
228+
telemetry,
224229
)
225230
PostSTInitCallbacks.run_post_init_callbacks()
226231

supertokens_python/utils.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
import asyncio
1817
import json
1918
import threading
2019
import warnings
@@ -27,7 +26,6 @@
2726
Any,
2827
Awaitable,
2928
Callable,
30-
Coroutine,
3129
Dict,
3230
List,
3331
TypeVar,
@@ -39,7 +37,6 @@
3937
from httpx import HTTPStatusError, Response
4038
from tldextract import extract # type: ignore
4139

42-
from supertokens_python.async_to_sync_wrapper import check_event_loop
4340
from supertokens_python.framework.django.framework import DjangoFramework
4441
from supertokens_python.framework.fastapi.framework import FastapiFramework
4542
from supertokens_python.framework.flask.framework import FlaskFramework
@@ -195,28 +192,6 @@ def find_first_occurrence_in_list(
195192
return None
196193

197194

198-
def execute_async(mode: str, func: Callable[[], Coroutine[Any, Any, None]]):
199-
real_mode = None
200-
try:
201-
asyncio.get_running_loop()
202-
real_mode = "asgi"
203-
except RuntimeError:
204-
real_mode = "wsgi"
205-
206-
if mode != real_mode:
207-
warnings.warn(
208-
"Inconsistent mode detected, check if you are using the right asgi / wsgi mode",
209-
category=RuntimeWarning,
210-
)
211-
212-
if real_mode == "wsgi":
213-
asyncio.run(func())
214-
else:
215-
check_event_loop()
216-
loop = asyncio.get_event_loop()
217-
loop.create_task(func())
218-
219-
220195
def frontend_has_interceptor(request: BaseRequest) -> bool:
221196
return get_rid_from_header(request) is not None
222197

0 commit comments

Comments
 (0)