2727)
2828
2929from .._query_context_cache import QueryContextCache
30- from ..auth import AuthByIdToken
31- from ..compat import quote , urlencode
30+ from ..compat import IS_LINUX , quote , urlencode
3231from ..config_manager import CONFIG_MANAGER , _get_default_connection_params
3332from ..connection import DEFAULT_CONFIGURATION
3433from ..connection import SnowflakeConnection as SnowflakeConnectionSync
34+ from ..connection import _get_private_bytes_from_file
3535from ..connection_diagnostic import ConnectionDiagnostic
3636from ..constants import (
3737 ENV_VAR_PARTNER ,
3838 PARAMETER_AUTOCOMMIT ,
3939 PARAMETER_CLIENT_PREFETCH_THREADS ,
40+ PARAMETER_CLIENT_REQUEST_MFA_TOKEN ,
4041 PARAMETER_CLIENT_SESSION_KEEP_ALIVE ,
4142 PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY ,
43+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ,
4244 PARAMETER_CLIENT_TELEMETRY_ENABLED ,
4345 PARAMETER_CLIENT_VALIDATE_DEFAULT_PARAMETERS ,
4446 PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 ,
5355 ER_FAILED_TO_CONNECT_TO_DB ,
5456 ER_INVALID_VALUE ,
5557)
56- from ..network import DEFAULT_AUTHENTICATOR , REQUEST_ID , ReauthenticationRequest
58+ from ..network import (
59+ DEFAULT_AUTHENTICATOR ,
60+ EXTERNAL_BROWSER_AUTHENTICATOR ,
61+ KEY_PAIR_AUTHENTICATOR ,
62+ OAUTH_AUTHENTICATOR ,
63+ REQUEST_ID ,
64+ USR_PWD_MFA_AUTHENTICATOR ,
65+ ReauthenticationRequest ,
66+ )
5767from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS , SQLSTATE_FEATURE_NOT_SUPPORTED
5868from ..telemetry import TelemetryData , TelemetryField
5969from ..time_util import get_time_millis
6070from ..util_text import split_statements
6171from ._cursor import SnowflakeCursor
6272from ._network import SnowflakeRestful
6373from ._time_util import HeartBeatTimer
64- from .auth import Auth , AuthByDefault , AuthByPlugin
74+ from .auth import (
75+ FIRST_PARTY_AUTHENTICATORS ,
76+ Auth ,
77+ AuthByDefault ,
78+ AuthByIdToken ,
79+ AuthByKeyPair ,
80+ AuthByOAuth ,
81+ AuthByOkta ,
82+ AuthByPlugin ,
83+ AuthByUsrPwdMfa ,
84+ AuthByWebBrowser ,
85+ )
6586
6687logger = getLogger (__name__ )
6788
@@ -196,7 +217,6 @@ async def __open_connection(self):
196217 heartbeat_ret = await auth ._rest ._heartbeat ()
197218 logger .debug (heartbeat_ret )
198219 if not heartbeat_ret or not heartbeat_ret .get ("success" ):
199- # TODO: errorhandler could be async?
200220 Error .errorhandler_wrapper (
201221 self ,
202222 None ,
@@ -211,20 +231,94 @@ async def __open_connection(self):
211231
212232 else :
213233 if self .auth_class is not None :
214- raise NotImplementedError (
215- "asyncio support for auth_class is not supported"
216- )
234+ if type (
235+ self .auth_class
236+ ) not in FIRST_PARTY_AUTHENTICATORS and not issubclass (
237+ type (self .auth_class ), AuthByKeyPair
238+ ):
239+ raise TypeError ("auth_class must be a child class of AuthByKeyPair" )
240+ self .auth_class = self .auth_class
217241 elif self ._authenticator == DEFAULT_AUTHENTICATOR :
218242 self .auth_class = AuthByDefault (
219243 password = self ._password ,
220244 timeout = self .login_timeout ,
221245 backoff_generator = self ._backoff_generator ,
222246 )
247+ elif self ._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR :
248+ self ._session_parameters [
249+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL
250+ ] = (self ._client_store_temporary_credential if IS_LINUX else True )
251+ auth .read_temporary_credentials (
252+ self .host ,
253+ self .user ,
254+ self ._session_parameters ,
255+ )
256+ # Depending on whether self._rest.id_token is available we do different
257+ # auth_instance
258+ if self ._rest .id_token is None :
259+ self .auth_class = AuthByWebBrowser (
260+ application = self .application ,
261+ protocol = self ._protocol ,
262+ host = self .host ,
263+ port = self .port ,
264+ timeout = self .login_timeout ,
265+ backoff_generator = self ._backoff_generator ,
266+ )
267+ else :
268+ self .auth_class = AuthByIdToken (
269+ id_token = self ._rest .id_token ,
270+ application = self .application ,
271+ protocol = self ._protocol ,
272+ host = self .host ,
273+ port = self .port ,
274+ timeout = self .login_timeout ,
275+ backoff_generator = self ._backoff_generator ,
276+ )
277+
278+ elif self ._authenticator == KEY_PAIR_AUTHENTICATOR :
279+ private_key = self ._private_key
280+
281+ if self ._private_key_file :
282+ private_key = _get_private_bytes_from_file (
283+ self ._private_key_file ,
284+ self ._private_key_file_pwd ,
285+ )
286+
287+ self .auth_class = AuthByKeyPair (
288+ private_key = private_key ,
289+ timeout = self .login_timeout ,
290+ backoff_generator = self ._backoff_generator ,
291+ )
292+ elif self ._authenticator == OAUTH_AUTHENTICATOR :
293+ self .auth_class = AuthByOAuth (
294+ oauth_token = self ._token ,
295+ timeout = self .login_timeout ,
296+ backoff_generator = self ._backoff_generator ,
297+ )
298+ elif self ._authenticator == USR_PWD_MFA_AUTHENTICATOR :
299+ self ._session_parameters [PARAMETER_CLIENT_REQUEST_MFA_TOKEN ] = (
300+ self ._client_request_mfa_token if IS_LINUX else True
301+ )
302+ if self ._session_parameters [PARAMETER_CLIENT_REQUEST_MFA_TOKEN ]:
303+ auth .read_temporary_credentials (
304+ self .host ,
305+ self .user ,
306+ self ._session_parameters ,
307+ )
308+ self .auth_class = AuthByUsrPwdMfa (
309+ password = self ._password ,
310+ mfa_token = self .rest .mfa_token ,
311+ timeout = self .login_timeout ,
312+ backoff_generator = self ._backoff_generator ,
313+ )
223314 else :
224- raise NotImplementedError (
225- f"asyncio support for authenticator is not supported { self ._authenticator } "
315+ # okta URL, e.g., https://<account>.okta.com/
316+ self .auth_class = AuthByOkta (
317+ application = self .application ,
318+ timeout = self .login_timeout ,
319+ backoff_generator = self ._backoff_generator ,
226320 )
227- # TODO: asyncio support for other authenticators
321+
228322 await self .authenticate_with_retry (self .auth_class )
229323
230324 self ._password = None # ensure password won't persist
0 commit comments