3939from .exceptions import raise_general_exception
4040from .process_state import AllowedProcessStates , ProcessState
4141from .utils import find_max_version , is_4xx_error , is_5xx_error
42+ from sniffio import AsyncLibraryNotFoundError
43+ from supertokens_python .async_to_sync_wrapper import create_or_get_event_loop
4244
4345
4446class Querier :
@@ -71,6 +73,35 @@ def get_hosts_alive_for_testing():
7173 raise_general_exception ("calling testing function in non testing env" )
7274 return Querier .__hosts_alive_for_testing
7375
76+ async def api_request (
77+ self ,
78+ url : str ,
79+ method : str ,
80+ attempts_remaining : int ,
81+ * args : Any ,
82+ ** kwargs : Any ,
83+ ) -> Response :
84+ if attempts_remaining == 0 :
85+ raise_general_exception ("Retry request failed" )
86+
87+ try :
88+ async with AsyncClient () as client :
89+ if method == "GET" :
90+ return await client .get (url , * args , ** kwargs ) # type: ignore
91+ if method == "POST" :
92+ return await client .post (url , * args , ** kwargs ) # type: ignore
93+ if method == "PUT" :
94+ return await client .put (url , * args , ** kwargs ) # type: ignore
95+ if method == "DELETE" :
96+ return await client .delete (url , * args , ** kwargs ) # type: ignore
97+ raise Exception ("Shouldn't come here" )
98+ except AsyncLibraryNotFoundError :
99+ # Retry
100+ loop = create_or_get_event_loop ()
101+ return loop .run_until_complete (
102+ self .api_request (url , method , attempts_remaining - 1 , * args , ** kwargs )
103+ )
104+
74105 async def get_api_version (self ):
75106 if Querier .api_version is not None :
76107 return Querier .api_version
@@ -79,12 +110,11 @@ async def get_api_version(self):
79110 AllowedProcessStates .CALLING_SERVICE_IN_GET_API_VERSION
80111 )
81112
82- async def f (url : str ) -> Response :
113+ async def f (url : str , method : str ) -> Response :
83114 headers = {}
84115 if Querier .__api_key is not None :
85116 headers = {API_KEY_HEADER : Querier .__api_key }
86- async with AsyncClient () as client :
87- return await client .get (url , headers = headers ) # type:ignore
117+ return await self .api_request (url , method , 2 , headers = headers )
88118
89119 response = await self .__send_request_helper (
90120 NormalisedURLPath (API_VERSION ), "GET" , f , len (self .__hosts )
@@ -134,13 +164,14 @@ async def send_get_request(
134164 if params is None :
135165 params = {}
136166
137- async def f (url : str ) -> Response :
138- async with AsyncClient () as client :
139- return await client .get ( # type:ignore
140- url ,
141- params = params ,
142- headers = await self .__get_headers_with_api_version (path ),
143- )
167+ async def f (url : str , method : str ) -> Response :
168+ return await self .api_request (
169+ url ,
170+ method ,
171+ 2 ,
172+ headers = await self .__get_headers_with_api_version (path ),
173+ params = params ,
174+ )
144175
145176 return await self .__send_request_helper (path , "GET" , f , len (self .__hosts ))
146177
@@ -163,9 +194,14 @@ async def send_post_request(
163194 headers = await self .__get_headers_with_api_version (path )
164195 headers ["content-type" ] = "application/json; charset=utf-8"
165196
166- async def f (url : str ) -> Response :
167- async with AsyncClient () as client :
168- return await client .post (url , json = data , headers = headers ) # type: ignore
197+ async def f (url : str , method : str ) -> Response :
198+ return await self .api_request (
199+ url ,
200+ method ,
201+ 2 ,
202+ headers = await self .__get_headers_with_api_version (path ),
203+ json = data ,
204+ )
169205
170206 return await self .__send_request_helper (path , "POST" , f , len (self .__hosts ))
171207
@@ -175,13 +211,14 @@ async def send_delete_request(
175211 if params is None :
176212 params = {}
177213
178- async def f (url : str ) -> Response :
179- async with AsyncClient () as client :
180- return await client .delete ( # type:ignore
181- url ,
182- params = params ,
183- headers = await self .__get_headers_with_api_version (path ),
184- )
214+ async def f (url : str , method : str ) -> Response :
215+ return await self .api_request (
216+ url ,
217+ method ,
218+ 2 ,
219+ headers = await self .__get_headers_with_api_version (path ),
220+ params = params ,
221+ )
185222
186223 return await self .__send_request_helper (path , "DELETE" , f , len (self .__hosts ))
187224
@@ -194,9 +231,8 @@ async def send_put_request(
194231 headers = await self .__get_headers_with_api_version (path )
195232 headers ["content-type" ] = "application/json; charset=utf-8"
196233
197- async def f (url : str ) -> Response :
198- async with AsyncClient () as client :
199- return await client .put (url , json = data , headers = headers ) # type: ignore
234+ async def f (url : str , method : str ) -> Response :
235+ return await self .api_request (url , method , 2 , headers = headers , json = data )
200236
201237 return await self .__send_request_helper (path , "PUT" , f , len (self .__hosts ))
202238
@@ -223,7 +259,7 @@ async def __send_request_helper(
223259 self ,
224260 path : NormalisedURLPath ,
225261 method : str ,
226- http_function : Callable [[str ], Awaitable [Response ]],
262+ http_function : Callable [[str , str ], Awaitable [Response ]],
227263 no_of_tries : int ,
228264 retry_info_map : Optional [Dict [str , int ]] = None ,
229265 ) -> Any :
@@ -253,7 +289,7 @@ async def __send_request_helper(
253289 ProcessState .get_instance ().add_state (
254290 AllowedProcessStates .CALLING_SERVICE_IN_REQUEST_HELPER
255291 )
256- response = await http_function (url )
292+ response = await http_function (url , method )
257293 if ("SUPERTOKENS_ENV" in environ ) and (
258294 environ ["SUPERTOKENS_ENV" ] == "testing"
259295 ):
0 commit comments