diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml
index ab98dd3702..ef22ea711e 100644
--- a/.github/workflows/build_test.yml
+++ b/.github/workflows/build_test.yml
@@ -5,6 +5,7 @@ on:
     branches:
       - master
       - main
+      - dev/aio-connector
     tags:
       - v*
   pull_request:
@@ -12,6 +13,7 @@ on:
       - master
       - main
       - prep-**
+      - dev/aio-connector
   workflow_dispatch:
     inputs:
       logLevel:
@@ -332,10 +334,101 @@ jobs:
             .coverage.py${{ env.shortver }}-lambda-ci
             junit.py${{ env.shortver }}-lambda-ci-dev.xml
 
+  test-aio:
+    name: Test asyncio ${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
+    needs: build
+    runs-on: ${{ matrix.os.image_name }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - image_name: ubuntu-latest
+            download_name: manylinux_x86_64
+          - image_name: macos-latest
+            download_name: macosx_x86_64
+          - image_name: windows-2019
+            download_name: win_amd64
+        python-version: ["3.10", "3.11", "3.12"]
+        cloud-provider: [aws, azure, gcp]
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Display Python version
+        run: python -c "import sys; print(sys.version)"
+      - name: Setup parameters file
+        shell: bash
+        env:
+          PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
+        run: |
+          gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
+            .github/workflows/parameters/public/parameters_${{ matrix.cloud-provider }}.py.gpg > test/parameters.py
+      - name: Download wheel(s)
+        uses: actions/download-artifact@v4
+        with:
+          name: ${{ matrix.os.download_name }}_py${{ matrix.python-version }}
+          path: dist
+      - name: Show wheels downloaded
+        run: ls -lh dist
+        shell: bash
+      - name: Upgrade setuptools, pip and wheel
+        run: python -m pip install -U setuptools pip wheel
+      - name: Install tox
+        run: python -m pip install "tox>=4"
+      - name: Run tests
+        run: python -m tox run -e aio
+        env:
+          PYTHON_VERSION: ${{ matrix.python-version }}
+          cloud_provider: ${{ matrix.cloud-provider }}
+          PYTEST_ADDOPTS: --color=yes --tb=short
+          TOX_PARALLEL_NO_SPINNER: 1
+        shell: bash
+      - name: Combine coverages
+        run: python -m tox run -e coverage --skip-missing-interpreters false
+        shell: bash
+      - uses: actions/upload-artifact@v4
+        with:
+          name: coverage_aio_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
+          path: |
+            .tox/.coverage
+            .tox/coverage.xml
+
+  test-unsupported-aio:
+    name: Test unsupported asyncio ${{ matrix.os.download_name }}-${{ matrix.python-version }}
+    runs-on: ${{ matrix.os.image_name }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - image_name: ubuntu-latest
+            download_name: manylinux_x86_64
+        python-version: [ "3.8", "3.9" ]
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Display Python version
+        run: python -c "import sys; print(sys.version)"
+      - name: Upgrade setuptools, pip and wheel
+        run: python -m pip install -U setuptools pip wheel
+      - name: Install tox
+        run: python -m pip install "tox>=4"
+      - name: Run tests
+        run: python -m tox run -e aio-unsupported-python
+        env:
+          PYTHON_VERSION: ${{ matrix.python-version }}
+          PYTEST_ADDOPTS: --color=yes --tb=short
+          TOX_PARALLEL_NO_SPINNER: 1
+        shell: bash
+
   combine-coverage:
     if: ${{ success() || failure() }}
     name: Combine coverage
-    needs: [lint, test, test-fips, test-lambda]
+    needs: [lint, test, test-fips, test-lambda, test-aio]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
diff --git a/ci/test_fips.sh b/ci/test_fips.sh
index bc97c9d7f2..b21b044809 100755
--- a/ci/test_fips.sh
+++ b/ci/test_fips.sh
@@ -21,6 +21,6 @@ python -c "from cryptography.hazmat.backends.openssl import backend;print('Cryp
 pip freeze
 cd $CONNECTOR_DIR
 
-pytest -vvv --cov=snowflake.connector --cov-report=xml:coverage.xml test
+pytest -vvv --cov=snowflake.connector --cov-report=xml:coverage.xml test --ignore=test/integ/aio --ignore=test/unit/aio
 
 deactivate
diff --git a/setup.cfg b/setup.cfg
index 38c3b3e5d2..d9865ac02c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -91,8 +91,11 @@ development =
     pytest-timeout
     pytest-xdist
     pytzdata
+    pytest-asyncio
 pandas =
     pandas>=1.0.0,<3.0.0
     pyarrow
 secure-local-storage =
     keyring>=23.1.0,<26.0.0
+aio =
+    aiohttp
diff --git a/src/snowflake/connector/aio/__init__.py b/src/snowflake/connector/aio/__init__.py
new file mode 100644
index 0000000000..628bc2abf1
--- /dev/null
+++ b/src/snowflake/connector/aio/__init__.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from __future__ import annotations
+
+from ._connection import SnowflakeConnection
+from ._cursor import DictCursor, SnowflakeCursor
+
+__all__ = [
+    "SnowflakeConnection",
+    "SnowflakeCursor",
+    "DictCursor",
+]
+
+
+async def connect(**kwargs) -> SnowflakeConnection:
+    conn = SnowflakeConnection(**kwargs)
+    await conn.connect()
+    return conn
diff --git a/src/snowflake/connector/aio/_azure_storage_client.py b/src/snowflake/connector/aio/_azure_storage_client.py
new file mode 100644
index 0000000000..0299128118
--- /dev/null
+++ b/src/snowflake/connector/aio/_azure_storage_client.py
@@ -0,0 +1,210 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
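# A minimal usage sketch for the connect() helper added in snowflake/connector/aio/__init__.py
# above. Illustrative only; the account, user and password values are placeholders.

import asyncio

import snowflake.connector.aio


async def main() -> None:
    conn = await snowflake.connector.aio.connect(
        account="<account>", user="<user>", password="<password>"
    )
    try:
        cur = conn.cursor()
        await cur.execute("select current_version()")
        print(await cur.fetchone())
    finally:
        await conn.close()


asyncio.run(main())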
+# + +from __future__ import annotations + +import json +import xml.etree.ElementTree as ET +from datetime import datetime, timezone +from logging import getLogger +from random import choice +from string import hexdigits +from typing import TYPE_CHECKING, Any + +import aiohttp + +from ..azure_storage_client import AzureCredentialFilter +from ..azure_storage_client import ( + SnowflakeAzureRestClient as SnowflakeAzureRestClientSync, +) +from ..compat import quote +from ..constants import FileHeader, ResultStatus +from ..encryption_util import EncryptionMetadata +from ._storage_client import SnowflakeStorageClient as SnowflakeStorageClientAsync + +if TYPE_CHECKING: # pragma: no cover + from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential + +from ..azure_storage_client import ( + ENCRYPTION_DATA, + MATDESC, + TOKEN_EXPIRATION_ERR_MESSAGE, +) + +logger = getLogger(__name__) + +getLogger("aiohttp").addFilter(AzureCredentialFilter()) + + +class SnowflakeAzureRestClient( + SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync +): + def __init__( + self, + meta: SnowflakeFileMeta, + credentials: StorageCredential | None, + chunk_size: int, + stage_info: dict[str, Any], + use_s3_regional_url: bool = False, + ) -> None: + SnowflakeAzureRestClientSync.__init__( + self, + meta=meta, + stage_info=stage_info, + chunk_size=chunk_size, + credentials=credentials, + ) + + async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool: + return response.status == 403 and any( + message in response.reason for message in TOKEN_EXPIRATION_ERR_MESSAGE + ) + + async def _send_request_with_authentication_and_retry( + self, + verb: str, + url: str, + retry_id: int | str, + headers: dict[str, Any] = None, + data: bytes = None, + ) -> aiohttp.ClientResponse: + if not headers: + headers = {} + + def generate_authenticated_url_and_rest_args() -> tuple[str, dict[str, Any]]: + curtime = datetime.now(timezone.utc).replace(tzinfo=None) + timestamp = curtime.strftime("YYYY-MM-DD") + sas_token = self.credentials.creds["AZURE_SAS_TOKEN"] + if sas_token and sas_token.startswith("?"): + sas_token = sas_token[1:] + if "?" in url: + _url = url + "&" + sas_token + else: + _url = url + "?" 
+ sas_token + headers["Date"] = timestamp + rest_args = {"headers": headers} + if data: + rest_args["data"] = data + return _url, rest_args + + return await self._send_request_with_retry( + verb, generate_authenticated_url_and_rest_args, retry_id + ) + + async def get_file_header(self, filename: str) -> FileHeader | None: + """Gets Azure file properties.""" + container_name = quote(self.azure_location.container_name) + path = quote(self.azure_location.path) + quote(filename) + meta = self.meta + # HTTP HEAD request + url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}" + retry_id = "HEAD" + self.retry_count[retry_id] = 0 + r = await self._send_request_with_authentication_and_retry( + "HEAD", url, retry_id + ) + if r.status == 200: + meta.result_status = ResultStatus.UPLOADED + enc_data_str = r.headers.get(ENCRYPTION_DATA) + encryption_data = None if enc_data_str is None else json.loads(enc_data_str) + encryption_metadata = ( + None + if not encryption_data + else EncryptionMetadata( + key=encryption_data["WrappedContentKey"]["EncryptedKey"], + iv=encryption_data["ContentEncryptionIV"], + matdesc=r.headers.get(MATDESC), + ) + ) + return FileHeader( + digest=r.headers.get("x-ms-meta-sfcdigest"), + content_length=int(r.headers.get("Content-Length")), + encryption_metadata=encryption_metadata, + ) + elif r.status == 404: + meta.result_status = ResultStatus.NOT_FOUND_FILE + return FileHeader( + digest=None, content_length=None, encryption_metadata=None + ) + else: + r.raise_for_status() + + async def _initiate_multipart_upload(self) -> None: + self.block_ids = [ + "".join(choice(hexdigits) for _ in range(20)) + for _ in range(self.num_of_chunks) + ] + + async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: + container_name = quote(self.azure_location.container_name) + path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/")) + + if self.num_of_chunks > 1: + block_id = self.block_ids[chunk_id] + url = ( + f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp=block" + f"&blockid={block_id}" + ) + headers = {"Content-Length": str(len(chunk))} + r = await self._send_request_with_authentication_and_retry( + "PUT", url, chunk_id, headers=headers, data=chunk + ) + else: + # single request + azure_metadata = self._prepare_file_metadata() + url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}" + headers = { + "x-ms-blob-type": "BlockBlob", + "Content-Encoding": "utf-8", + } + headers.update(azure_metadata) + r = await self._send_request_with_authentication_and_retry( + "PUT", url, chunk_id, headers=headers, data=chunk + ) + r.raise_for_status() # expect status code 201 + + async def _complete_multipart_upload(self) -> None: + container_name = quote(self.azure_location.container_name) + path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/")) + url = ( + f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp" + f"=blocklist" + ) + root = ET.Element("BlockList") + for block_id in self.block_ids: + part = ET.Element("Latest") + part.text = block_id + root.append(part) + headers = {"x-ms-blob-content-encoding": "utf-8"} + azure_metadata = self._prepare_file_metadata() + headers.update(azure_metadata) + retry_id = "COMPLETE" + self.retry_count[retry_id] = 0 + r = await self._send_request_with_authentication_and_retry( + "PUT", url, "COMPLETE", headers=headers, data=ET.tostring(root) + ) + r.raise_for_status() # expects status code 201 + + 
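# The upload and download helpers in this class work with aiohttp.ClientResponse objects,
# so status checks stay synchronous while reading the body must be awaited. A minimal
# sketch of that pattern, assuming only public aiohttp APIs:

import aiohttp


async def fetch_bytes(url: str) -> bytes:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()   # raises aiohttp.ClientResponseError on 4xx/5xx
            return await resp.read()  # unlike requests, the body read is awaited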
async def download_chunk(self, chunk_id: int) -> None: + container_name = quote(self.azure_location.container_name) + path = quote(self.azure_location.path + self.meta.src_file_name.lstrip("/")) + url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}" + if self.num_of_chunks > 1: + chunk_size = self.chunk_size + if chunk_id < self.num_of_chunks - 1: + _range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}" + else: + _range = f"{chunk_id * chunk_size}-" + headers = {"Range": f"bytes={_range}"} + r = await self._send_request_with_authentication_and_retry( + "GET", url, chunk_id, headers=headers + ) # expect 206 + else: + # single request + r = await self._send_request_with_authentication_and_retry( + "GET", url, chunk_id + ) + if r.status in (200, 206): + self.write_downloaded_chunk(chunk_id, await r.read()) + r.raise_for_status() diff --git a/src/snowflake/connector/aio/_build_upload_agent.py b/src/snowflake/connector/aio/_build_upload_agent.py new file mode 100644 index 0000000000..f6f44511dc --- /dev/null +++ b/src/snowflake/connector/aio/_build_upload_agent.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from io import BytesIO +from logging import getLogger +from typing import TYPE_CHECKING, cast + +from snowflake.connector import Error +from snowflake.connector._utils import get_temp_type_for_object +from snowflake.connector.bind_upload_agent import BindUploadAgent as BindUploadAgentSync +from snowflake.connector.errors import BindUploadError + +if TYPE_CHECKING: + from snowflake.connector.aio import SnowflakeCursor + +logger = getLogger(__name__) + + +class BindUploadAgent(BindUploadAgentSync): + def __init__( + self, + cursor: SnowflakeCursor, + rows: list[bytes], + stream_buffer_size: int = 1024 * 1024 * 10, + ) -> None: + super().__init__(cursor, rows, stream_buffer_size) + self.cursor = cast("SnowflakeCursor", cursor) + + async def _create_stage(self) -> None: + create_stage_sql = ( + f"create or replace {get_temp_type_for_object(self._use_scoped_temp_object)} stage {self._STAGE_NAME} " + "file_format=(type=csv field_optionally_enclosed_by='\"')" + ) + await self.cursor.execute(create_stage_sql) + + async def upload(self) -> None: + try: + await self._create_stage() + except Error as err: + self.cursor.connection._session_parameters[ + "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD" + ] = 0 + logger.debug("Failed to create stage for binding.") + raise BindUploadError from err + + row_idx = 0 + while row_idx < len(self.rows): + f = BytesIO() + size = 0 + while True: + f.write(self.rows[row_idx]) + size += len(self.rows[row_idx]) + row_idx += 1 + if row_idx >= len(self.rows) or size >= self._stream_buffer_size: + break + try: + await self.cursor.execute( + f"PUT file://{row_idx}.csv {self.stage_path}", file_stream=f + ) + except Error as err: + logger.debug("Failed to upload the bindings file to stage.") + raise BindUploadError from err + f.close() diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py new file mode 100644 index 0000000000..d62bc33754 --- /dev/null +++ b/src/snowflake/connector/aio/_connection.py @@ -0,0 +1,979 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from __future__ import annotations + +import asyncio +import atexit +import copy +import logging +import os +import pathlib +import sys +import uuid +from contextlib import suppress +from io import StringIO +from logging import getLogger +from types import TracebackType +from typing import Any, AsyncIterator, Iterable + +from snowflake.connector import ( + DatabaseError, + EasyLoggingConfigPython, + Error, + OperationalError, + ProgrammingError, + proxy, +) + +from .._query_context_cache import QueryContextCache +from ..compat import IS_LINUX, quote, urlencode +from ..config_manager import CONFIG_MANAGER, _get_default_connection_params +from ..connection import DEFAULT_CONFIGURATION as DEFAULT_CONFIGURATION_SYNC +from ..connection import SnowflakeConnection as SnowflakeConnectionSync +from ..connection import _get_private_bytes_from_file +from ..constants import ( + _CONNECTIVITY_ERR_MSG, + ENV_VAR_PARTNER, + PARAMETER_AUTOCOMMIT, + PARAMETER_CLIENT_PREFETCH_THREADS, + PARAMETER_CLIENT_REQUEST_MFA_TOKEN, + PARAMETER_CLIENT_SESSION_KEEP_ALIVE, + PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY, + PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, + PARAMETER_CLIENT_TELEMETRY_ENABLED, + PARAMETER_CLIENT_VALIDATE_DEFAULT_PARAMETERS, + PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1, + PARAMETER_QUERY_CONTEXT_CACHE_SIZE, + PARAMETER_SERVICE_NAME, + PARAMETER_TIMEZONE, + QueryStatus, +) +from ..description import PLATFORM, PYTHON_VERSION, SNOWFLAKE_CONNECTOR_VERSION +from ..errorcode import ( + ER_CONNECTION_IS_CLOSED, + ER_FAILED_TO_CONNECT_TO_DB, + ER_INVALID_VALUE, +) +from ..network import ( + DEFAULT_AUTHENTICATOR, + EXTERNAL_BROWSER_AUTHENTICATOR, + KEY_PAIR_AUTHENTICATOR, + OAUTH_AUTHENTICATOR, + REQUEST_ID, + USR_PWD_MFA_AUTHENTICATOR, + ReauthenticationRequest, +) +from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED +from ..telemetry import TelemetryData, TelemetryField +from ..time_util import get_time_millis +from ..util_text import split_statements +from ._cursor import SnowflakeCursor +from ._description import CLIENT_NAME +from ._network import SnowflakeRestful +from ._telemetry import TelemetryClient +from ._time_util import HeartBeatTimer +from .auth import ( + FIRST_PARTY_AUTHENTICATORS, + Auth, + AuthByDefault, + AuthByIdToken, + AuthByKeyPair, + AuthByOAuth, + AuthByOkta, + AuthByPlugin, + AuthByUsrPwdMfa, + AuthByWebBrowser, +) + +logger = getLogger(__name__) + +# deep copy to avoid pollute sync config +DEFAULT_CONFIGURATION = copy.deepcopy(DEFAULT_CONFIGURATION_SYNC) +DEFAULT_CONFIGURATION["application"] = (CLIENT_NAME, (type(None), str)) + + +class SnowflakeConnection(SnowflakeConnectionSync): + OCSP_ENV_LOCK = asyncio.Lock() + + def __init__( + self, + connection_name: str | None = None, + connections_file_path: pathlib.Path | None = None, + **kwargs, + ) -> None: + # note we don't call super here because asyncio can not/is not recommended + # to perform async operation in the __init__ while in the sync connection we + # perform connect + self._conn_parameters = self._init_connection_parameters( + kwargs, connection_name, connections_file_path + ) + self._connected = False + self.expired = False + # check SNOW-1218851 for long term improvement plan to refactor ocsp code + atexit.register(self._close_at_exit) + + def __enter__(self): + # async connection does not support sync context manager + raise TypeError( + "'SnowflakeConnection' object does not support the context manager protocol" + ) + + def __exit__(self, exc_type, 
exc_val, exc_tb): + # async connection does not support sync context manager + raise TypeError( + "'SnowflakeConnection' object does not support the context manager protocol" + ) + + async def __aenter__(self) -> SnowflakeConnection: + """Context manager.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager with commit or rollback teardown.""" + if not self._session_parameters.get("AUTOCOMMIT", False): + # Either AUTOCOMMIT is turned off, or is not set so we default to old behavior + if exc_tb is None: + await self.commit() + else: + await self.rollback() + await self.close() + + async def __open_connection(self): + """Opens a new network connection.""" + self.converter = self._converter_class( + use_numpy=self._numpy, support_negative_year=self._support_negative_year + ) + + proxy.set_proxies( + self.proxy_host, self.proxy_port, self.proxy_user, self.proxy_password + ) + + self._rest = SnowflakeRestful( + host=self.host, + port=self.port, + protocol=self._protocol, + inject_client_pause=self._inject_client_pause, + connection=self, + ) + logger.debug("REST API object was created: %s:%s", self.host, self.port) + + if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os.environ: + logger.debug( + "Custom OCSP Cache Server URL found in environment - %s", + os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"], + ) + + if ".privatelink.snowflakecomputing." in self.host: + await SnowflakeConnection.setup_ocsp_privatelink( + self.application, self.host + ) + else: + if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os.environ: + del os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] + + if self._session_parameters is None: + self._session_parameters = {} + if self._autocommit is not None: + self._session_parameters[PARAMETER_AUTOCOMMIT] = self._autocommit + + if self._timezone is not None: + self._session_parameters[PARAMETER_TIMEZONE] = self._timezone + + if self._validate_default_parameters: + # Snowflake will validate the requested database, schema, and warehouse + self._session_parameters[PARAMETER_CLIENT_VALIDATE_DEFAULT_PARAMETERS] = ( + True + ) + + if self.client_session_keep_alive is not None: + self._session_parameters[PARAMETER_CLIENT_SESSION_KEEP_ALIVE] = ( + self._client_session_keep_alive + ) + + if self.client_session_keep_alive_heartbeat_frequency is not None: + self._session_parameters[ + PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY + ] = self._validate_client_session_keep_alive_heartbeat_frequency() + + if self.client_prefetch_threads: + self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = ( + self._validate_client_prefetch_threads() + ) + + # Setup authenticator + auth = Auth(self.rest) + + if self._session_token and self._master_token: + await auth._rest.update_tokens( + self._session_token, + self._master_token, + self._master_validity_in_seconds, + ) + heartbeat_ret = await auth._rest._heartbeat() + logger.debug(heartbeat_ret) + if not heartbeat_ret or not heartbeat_ret.get("success"): + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": "Session and master tokens invalid", + "errno": ER_INVALID_VALUE, + }, + ) + else: + logger.debug("Session and master token validation successful.") + + else: + if self.auth_class is not None: + if type( + self.auth_class + ) not in FIRST_PARTY_AUTHENTICATORS and not issubclass( + type(self.auth_class), AuthByKeyPair + ): + raise TypeError("auth_class must be a 
child class of AuthByKeyPair") + self.auth_class = self.auth_class + elif self._authenticator == DEFAULT_AUTHENTICATOR: + self.auth_class = AuthByDefault( + password=self._password, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR: + self._session_parameters[ + PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL + ] = (self._client_store_temporary_credential if IS_LINUX else True) + auth.read_temporary_credentials( + self.host, + self.user, + self._session_parameters, + ) + # Depending on whether self._rest.id_token is available we do different + # auth_instance + if self._rest.id_token is None: + self.auth_class = AuthByWebBrowser( + application=self.application, + protocol=self._protocol, + host=self.host, + port=self.port, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + else: + self.auth_class = AuthByIdToken( + id_token=self._rest.id_token, + application=self.application, + protocol=self._protocol, + host=self.host, + port=self.port, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + + elif self._authenticator == KEY_PAIR_AUTHENTICATOR: + private_key = self._private_key + + if self._private_key_file: + private_key = _get_private_bytes_from_file( + self._private_key_file, + self._private_key_file_pwd, + ) + + self.auth_class = AuthByKeyPair( + private_key=private_key, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + elif self._authenticator == OAUTH_AUTHENTICATOR: + self.auth_class = AuthByOAuth( + oauth_token=self._token, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: + self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = ( + self._client_request_mfa_token if IS_LINUX else True + ) + if self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN]: + auth.read_temporary_credentials( + self.host, + self.user, + self._session_parameters, + ) + self.auth_class = AuthByUsrPwdMfa( + password=self._password, + mfa_token=self.rest.mfa_token, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + else: + # okta URL, e.g., https://.okta.com/ + self.auth_class = AuthByOkta( + application=self.application, + timeout=self.login_timeout, + backoff_generator=self._backoff_generator, + ) + + await self.authenticate_with_retry(self.auth_class) + + self._password = None # ensure password won't persist + await self.auth_class.reset_secrets() + + self.initialize_query_context_cache() + + if self.client_session_keep_alive: + # This will be called after the heartbeat frequency has actually been set. 
+ # By this point it should have been decided if the heartbeat has to be enabled + # and what would the heartbeat frequency be + await self._add_heartbeat() + + async def _add_heartbeat(self) -> None: + if not self._heartbeat_task: + self._heartbeat_task = HeartBeatTimer( + self.client_session_keep_alive_heartbeat_frequency, self._heartbeat_tick + ) + await self._heartbeat_task.start() + logger.debug("started heartbeat") + + async def _heartbeat_tick(self) -> None: + """Execute a hearbeat if connection isn't closed yet.""" + if not self.is_closed(): + logger.debug("heartbeating!") + await self.rest._heartbeat() + + async def _all_async_queries_finished(self) -> bool: + """Checks whether all async queries started by this Connection have finished executing.""" + + if not self._async_sfqids: + return True + + queries = list(reversed(self._async_sfqids.keys())) + + found_unfinished_query = False + + async def async_query_check_helper( + sfq_id: str, + ) -> bool: + try: + nonlocal found_unfinished_query + return found_unfinished_query or self.is_still_running( + await self.get_query_status(sfq_id) + ) + except asyncio.CancelledError: + pass + + tasks = [ + asyncio.create_task(async_query_check_helper(sfqid)) for sfqid in queries + ] + for task in asyncio.as_completed(tasks): + if await task: + found_unfinished_query = True + break + for task in tasks: + task.cancel() + await asyncio.gather(*tasks) + return not found_unfinished_query + + async def _authenticate(self, auth_instance: AuthByPlugin): + await auth_instance.prepare( + conn=self, + authenticator=self._authenticator, + service_name=self.service_name, + account=self.account, + user=self.user, + password=self._password, + ) + self._consent_cache_id_token = getattr( + auth_instance, "consent_cache_id_token", True + ) + + auth = Auth(self.rest) + # record start time for computing timeout + auth_instance._retry_ctx.set_start_time() + try: + await auth.authenticate( + auth_instance=auth_instance, + account=self.account, + user=self.user, + database=self.database, + schema=self.schema, + warehouse=self.warehouse, + role=self.role, + passcode=self._passcode, + passcode_in_password=self._passcode_in_password, + mfa_callback=self._mfa_callback, + password_callback=self._password_callback, + session_parameters=self._session_parameters, + ) + except OperationalError as e: + logger.debug( + "Operational Error raised at authentication" + f"for authenticator: {type(auth_instance).__name__}" + ) + while True: + try: + await auth_instance.handle_timeout( + authenticator=self._authenticator, + service_name=self.service_name, + account=self.account, + user=self.user, + password=self._password, + ) + await auth.authenticate( + auth_instance=auth_instance, + account=self.account, + user=self.user, + database=self.database, + schema=self.schema, + warehouse=self.warehouse, + role=self.role, + passcode=self._passcode, + passcode_in_password=self._passcode_in_password, + mfa_callback=self._mfa_callback, + password_callback=self._password_callback, + session_parameters=self._session_parameters, + ) + except OperationalError as auth_op: + if auth_op.errno == ER_FAILED_TO_CONNECT_TO_DB: + if _CONNECTIVITY_ERR_MSG in e.msg: + auth_op.msg += f"\n{_CONNECTIVITY_ERR_MSG}" + raise auth_op from e + logger.debug("Continuing authenticator specific timeout handling") + continue + break + + async def _cancel_heartbeat(self) -> None: + """Cancel a heartbeat thread.""" + if self._heartbeat_task: + await self._heartbeat_task.stop() + self._heartbeat_task = None + 
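            # clearing the reference lets a later _add_heartbeat() start a fresh timer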
logger.debug("stopped heartbeat") + + def _init_connection_parameters( + self, + connection_init_kwargs: dict, + connection_name: str | None = None, + connections_file_path: pathlib.Path | None = None, + ) -> dict: + ret_kwargs = connection_init_kwargs + easy_logging = EasyLoggingConfigPython() + easy_logging.create_log() + self._lock_sequence_counter = asyncio.Lock() + self.sequence_counter = 0 + self._errorhandler = Error.default_errorhandler + self._lock_converter = asyncio.Lock() + self.messages = [] + self._async_sfqids: dict[str, None] = {} + self._done_async_sfqids: dict[str, None] = {} + self._client_param_telemetry_enabled = True + self._server_param_telemetry_enabled = False + self._session_parameters: dict[str, str | int | bool] = {} + logger.info( + "Snowflake Connector for Python Version: %s, " + "Python Version: %s, Platform: %s", + SNOWFLAKE_CONNECTOR_VERSION, + PYTHON_VERSION, + PLATFORM, + ) + + self._rest = None + for name, (value, _) in DEFAULT_CONFIGURATION.items(): + setattr(self, f"_{name}", value) + + self._heartbeat_task = None + is_kwargs_empty = not connection_init_kwargs + + if "application" not in connection_init_kwargs: + if ENV_VAR_PARTNER in os.environ.keys(): + connection_init_kwargs["application"] = os.environ[ENV_VAR_PARTNER] + elif "streamlit" in sys.modules: + connection_init_kwargs["application"] = "streamlit" + + self.converter = None + self.query_context_cache: QueryContextCache | None = None + self.query_context_cache_size = 5 + if connections_file_path is not None: + # Change config file path and force update cache + for i, s in enumerate(CONFIG_MANAGER._slices): + if s.section == "connections": + CONFIG_MANAGER._slices[i] = s._replace(path=connections_file_path) + CONFIG_MANAGER.read_config() + break + if connection_name is not None: + connections = CONFIG_MANAGER["connections"] + if connection_name not in connections: + raise Error( + f"Invalid connection_name '{connection_name}'," + f" known ones are {list(connections.keys())}" + ) + ret_kwargs = {**connections[connection_name], **connection_init_kwargs} + elif is_kwargs_empty: + # connection_name is None and kwargs was empty when called + ret_kwargs = _get_default_connection_params() + # TODO: SNOW-1770153 on self.__set_error_attributes() + return ret_kwargs + + async def _cancel_query( + self, sql: str, request_id: uuid.UUID + ) -> dict[str, bool | None]: + """Cancels the query with the exact SQL query and requestId.""" + logger.debug("_cancel_query sql=[%s], request_id=[%s]", sql, request_id) + url_parameters = {REQUEST_ID: str(uuid.uuid4())} + + return await self.rest.request( + "/queries/v1/abort-request?" + urlencode(url_parameters), + { + "sqlText": sql, + REQUEST_ID: str(request_id), + }, + ) + + def _close_at_exit(self): + with suppress(Exception): + asyncio.run(self.close(retry=False)) + + async def _get_query_status( + self, sf_qid: str + ) -> tuple[QueryStatus, dict[str, Any]]: + """Retrieves the status of query with sf_qid and returns it with the raw response. + + This is the underlying function used by the public get_status functions. + + Args: + sf_qid: Snowflake query id of interest. + + Raises: + ValueError: if sf_qid is not a valid UUID string. 
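
        Example (illustrative polling sketch using the public wrapper):

            status = await conn.get_query_status(sf_qid)
            while conn.is_still_running(status):
                await asyncio.sleep(1)
                status = await conn.get_query_status(sf_qid)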
+ """ + try: + uuid.UUID(sf_qid) + except ValueError: + raise ValueError(f"Invalid UUID: '{sf_qid}'") + logger.debug(f"get_query_status sf_qid='{sf_qid}'") + + status = "NO_DATA" + if self.is_closed(): + return QueryStatus.DISCONNECTED, {"data": {"queries": []}} + status_resp = await self.rest.request( + "/monitoring/queries/" + quote(sf_qid), method="get", client="rest" + ) + if "queries" not in status_resp["data"]: + return QueryStatus.FAILED_WITH_ERROR, status_resp + queries = status_resp["data"]["queries"] + if len(queries) > 0: + status = queries[0]["status"] + status_ret = QueryStatus[status] + return status_ret, status_resp + + async def _log_telemetry(self, telemetry_data) -> None: + if self.telemetry_enabled: + await self._telemetry.try_add_log_to_batch(telemetry_data) + + async def _log_telemetry_imported_packages(self) -> None: + if self._log_imported_packages_in_telemetry: + # filter out duplicates caused by submodules + # and internal modules with names starting with an underscore + imported_modules = { + k.split(".", maxsplit=1)[0] + for k in list(sys.modules) + if not k.startswith("_") + } + ts = get_time_millis() + await self._log_telemetry( + TelemetryData.from_telemetry_data_dict( + from_dict={ + TelemetryField.KEY_TYPE.value: TelemetryField.IMPORTED_PACKAGES.value, + TelemetryField.KEY_VALUE.value: str(imported_modules), + }, + timestamp=ts, + connection=self, + ) + ) + + async def _next_sequence_counter(self) -> int: + """Gets next sequence counter. Used internally.""" + async with self._lock_sequence_counter: + self.sequence_counter += 1 + logger.debug("sequence counter: %s", self.sequence_counter) + return self.sequence_counter + + async def _update_parameters( + self, + parameters: dict[str, str | int | bool], + ) -> None: + """Update session parameters.""" + async with self._lock_converter: + self.converter.set_parameters(parameters) + for name, value in parameters.items(): + self._session_parameters[name] = value + if PARAMETER_CLIENT_TELEMETRY_ENABLED == name: + self._server_param_telemetry_enabled = value + elif PARAMETER_CLIENT_SESSION_KEEP_ALIVE == name: + # Only set if the local config is None. + # Always give preference to user config. + if self.client_session_keep_alive is None: + self.client_session_keep_alive = value + elif ( + PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY == name + and self.client_session_keep_alive_heartbeat_frequency is None + ): + # Only set if local value hasn't been set already. 
+ self.client_session_keep_alive_heartbeat_frequency = value + elif PARAMETER_SERVICE_NAME == name: + self.service_name = value + elif PARAMETER_CLIENT_PREFETCH_THREADS == name: + self.client_prefetch_threads = value + elif PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 == name: + self.enable_stage_s3_privatelink_for_us_east_1 = value + elif PARAMETER_QUERY_CONTEXT_CACHE_SIZE == name: + self.query_context_cache_size = value + + async def _reauthenticate(self): + return await self._auth_class.reauthenticate(conn=self) + + @property + def auth_class(self) -> AuthByPlugin | None: + return self._auth_class + + @auth_class.setter + def auth_class(self, value: AuthByPlugin) -> None: + if isinstance(value, AuthByPlugin): + self._auth_class = value + else: + raise TypeError("auth_class must subclass AuthByPluginAsync") + + @property + def client_prefetch_threads(self) -> int: + return self._client_prefetch_threads + + @client_prefetch_threads.setter + def client_prefetch_threads(self, value) -> None: + self._client_prefetch_threads = value + + @property + def errorhandler(self) -> None: + # check SNOW-1763103 + raise NotImplementedError( + "Async Snowflake Python Connector does not support errorhandler. " + "Please open a feature request issue in github if your want this feature: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + ) + + @errorhandler.setter + def errorhandler(self, value) -> None: + # check SNOW-1763103 + raise NotImplementedError( + "Async Snowflake Python Connector does not support errorhandler. " + "Please open a feature request issue in github if your want this feature: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + ) + + @property + def rest(self) -> SnowflakeRestful | None: + return self._rest + + async def authenticate_with_retry(self, auth_instance) -> None: + # make some changes if needed before real __authenticate + try: + await self._authenticate(auth_instance) + except ReauthenticationRequest as ex: + # cached id_token expiration error, we have cleaned id_token and try to authenticate again + logger.debug("ID token expired. Reauthenticating...: %s", ex) + if isinstance(auth_instance, AuthByIdToken): + # Note: SNOW-733835 IDToken auth needs to authenticate through + # SSO if it has expired + await self._reauthenticate() + else: + await self._authenticate(auth_instance) + + async def autocommit(self, mode) -> None: + """Sets autocommit mode to True, or False. Defaults to True.""" + if not self.rest: + Error.errorhandler_wrapper( + self, + None, + DatabaseError, + { + "msg": "Connection is closed", + "errno": ER_CONNECTION_IS_CLOSED, + "sqlstate": SQLSTATE_CONNECTION_NOT_EXISTS, + }, + ) + if not isinstance(mode, bool): + Error.errorhandler_wrapper( + self, + None, + ProgrammingError, + { + "msg": f"Invalid parameter: {mode}", + "errno": ER_INVALID_VALUE, + }, + ) + try: + await self.cursor().execute(f"ALTER SESSION SET autocommit={mode}") + except Error as e: + if e.sqlstate == SQLSTATE_FEATURE_NOT_SUPPORTED: + logger.debug( + "Autocommit feature is not enabled for this " "connection. 
Ignored" + ) + + async def close(self, retry: bool = True) -> None: + """Closes the connection.""" + # unregister to dereference connection object as it's already closed after the execution + atexit.unregister(self._close_at_exit) + try: + if not self.rest: + logger.debug("Rest object has been destroyed, cannot close session") + return + + # will hang if the application doesn't close the connection and + # CLIENT_SESSION_KEEP_ALIVE is set, because the heartbeat runs on + # a separate thread. + await self._cancel_heartbeat() + + # close telemetry first, since it needs rest to send remaining data + logger.info("closed") + + await self._telemetry.close( + send_on_close=bool(retry and self.telemetry_enabled) + ) + if ( + await self._all_async_queries_finished() + and not self._server_session_keep_alive + ): + logger.info("No async queries seem to be running, deleting session") + try: + await self.rest.delete_session(retry=retry) + except Exception as e: + logger.debug( + "Exception encountered in deleting session. ignoring...: %s", e + ) + else: + logger.info( + "There are {} async queries still running, not deleting session".format( + len(self._async_sfqids) + ) + ) + await self.rest.close() + self._rest = None + if self.query_context_cache: + self.query_context_cache.clear_cache() + del self.messages[:] + logger.debug("Session is closed") + except Exception as e: + logger.debug( + "Exception encountered in closing connection. ignoring...: %s", e + ) + + async def cmd_query( + self, + sql: str, + sequence_counter: int, + request_id: uuid.UUID, + binding_params: None | tuple | dict[str, dict[str, str]] = None, + binding_stage: str | None = None, + is_file_transfer: bool = False, + statement_params: dict[str, str] | None = None, + is_internal: bool = False, + describe_only: bool = False, + _no_results: bool = False, + _update_current_object: bool = True, + _no_retry: bool = False, + timeout: int | None = None, + dataframe_ast: str | None = None, + ) -> dict[str, Any]: + """Executes a query with a sequence counter.""" + logger.debug("_cmd_query") + data = { + "sqlText": sql, + "asyncExec": _no_results, + "sequenceId": sequence_counter, + "querySubmissionTime": get_time_millis(), + } + if dataframe_ast is not None: + data["dataframeAst"] = dataframe_ast + if statement_params is not None: + data["parameters"] = statement_params + if is_internal: + data["isInternal"] = is_internal + if describe_only: + data["describeOnly"] = describe_only + if binding_stage is not None: + # binding stage for bulk array binding + data["bindStage"] = binding_stage + if binding_params is not None: + # binding parameters. This is for qmarks paramstyle. + data["bindings"] = binding_params + if not _no_results: + # not an async query. + queryContext = self.get_query_context() + # Here queryContextDTO should be a dict object field, same with `parameters` field + data["queryContextDTO"] = queryContext + client = "sfsql_file_transfer" if is_file_transfer else "sfsql" + + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug( + "sql=[%s], sequence_id=[%s], is_file_transfer=[%s]", + self._format_query_for_log(data["sqlText"]), + data["sequenceId"], + is_file_transfer, + ) + + url_parameters = {REQUEST_ID: request_id} + + ret = await self.rest.request( + "/queries/v1/query-request?" 
+ urlencode(url_parameters), + data, + client=client, + _no_results=_no_results, + _include_retry_params=True, + _no_retry=_no_retry, + timeout=timeout, + ) + + if ret is None: + ret = {"data": {}} + if ret.get("data") is None: + ret["data"] = {} + if _update_current_object: + data = ret["data"] + if "finalDatabaseName" in data and data["finalDatabaseName"] is not None: + self._database = data["finalDatabaseName"] + if "finalSchemaName" in data and data["finalSchemaName"] is not None: + self._schema = data["finalSchemaName"] + if "finalWarehouseName" in data and data["finalWarehouseName"] is not None: + self._warehouse = data["finalWarehouseName"] + if "finalRoleName" in data: + self._role = data["finalRoleName"] + if "queryContext" in data and not _no_results: + # here the data["queryContext"] field has been automatically converted from JSON into a dict type + self.set_query_context(data["queryContext"]) + + return ret + + async def commit(self) -> None: + """Commits the current transaction.""" + await self.cursor().execute("COMMIT") + + async def connect(self, **kwargs) -> None: + """Establishes connection to Snowflake.""" + logger.debug("connect") + if len(kwargs) > 0: + self.__config(**kwargs) + else: + self.__config(**self._conn_parameters) + + if self.enable_connection_diag: + raise NotImplementedError( + "Connection diagnostic is not supported in asyncio" + ) + else: + await self.__open_connection() + self._telemetry = TelemetryClient(self._rest) + await self._log_telemetry_imported_packages() + + def cursor( + self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor + ) -> SnowflakeCursor: + logger.debug("cursor") + if not self.rest: + Error.errorhandler_wrapper( + self, + None, + DatabaseError, + { + "msg": "Connection is closed.\nPlease establish the connection first by " + "explicitly calling `await SnowflakeConnection.connect()` or " + "using an async context manager: `async with SnowflakeConnection() as conn`. " + "\nEnsure the connection is open before attempting any operations.", + "errno": ER_CONNECTION_IS_CLOSED, + "sqlstate": SQLSTATE_CONNECTION_NOT_EXISTS, + }, + ) + return cursor_class(self) + + async def execute_stream( + self, + stream: StringIO, + remove_comments: bool = False, + cursor_class: type[SnowflakeCursor] = SnowflakeCursor, + **kwargs, + ) -> AsyncIterator[SnowflakeCursor, None, None]: + """Executes a stream of SQL statements. This is a non-standard convenient method.""" + split_statements_list = split_statements( + stream, remove_comments=remove_comments + ) + # Note: split_statements_list is a list of tuples of sql statements and whether they are put/get + non_empty_statements = [e for e in split_statements_list if e[0]] + for sql, is_put_or_get in non_empty_statements: + cur = self.cursor(cursor_class=cursor_class) + await cur.execute(sql, _is_put_get=is_put_or_get, **kwargs) + yield cur + + async def execute_string( + self, + sql_text: str, + remove_comments: bool = False, + return_cursors: bool = True, + cursor_class: type[SnowflakeCursor] = SnowflakeCursor, + **kwargs, + ) -> Iterable[SnowflakeCursor]: + """Executes a SQL text including multiple statements. This is a non-standard convenience method.""" + stream = StringIO(sql_text) + ret = [] + async for cursor in self.execute_stream( + stream, remove_comments=remove_comments, cursor_class=cursor_class, **kwargs + ): + ret.append(cursor) + + return ret if return_cursors else list() + + async def get_query_status(self, sf_qid: str) -> QueryStatus: + """Retrieves the status of query with sf_qid. 
+ + Query status is returned as a QueryStatus. + + Args: + sf_qid: Snowflake query id of interest. + + Raises: + ValueError: if sf_qid is not a valid UUID string. + """ + status, _ = await self._get_query_status(sf_qid) + self._cache_query_status(sf_qid, status) + return status + + async def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus: + """Retrieves the status of query with sf_qid as a QueryStatus and raises an exception if the query terminated with an error. + + Query status is returned as a QueryStatus. + + Args: + sf_qid: Snowflake query id of interest. + + Raises: + ValueError: if sf_qid is not a valid UUID string. + """ + status, status_resp = await self._get_query_status(sf_qid) + self._cache_query_status(sf_qid, status) + if self.is_an_error(status): + self._process_error_query_status(sf_qid, status_resp) + return status + + @staticmethod + async def setup_ocsp_privatelink(app, hostname) -> None: + async with SnowflakeConnection.OCSP_ENV_LOCK: + ocsp_cache_server = f"http://ocsp.{hostname}/ocsp_response_cache.json" + os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] = ocsp_cache_server + logger.debug("OCSP Cache Server is updated: %s", ocsp_cache_server) + + async def rollback(self) -> None: + """Rolls back the current transaction.""" + await self.cursor().execute("ROLLBACK") diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py new file mode 100644 index 0000000000..37a6fbd2c8 --- /dev/null +++ b/src/snowflake/connector/aio/_cursor.py @@ -0,0 +1,1134 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import collections +import logging +import re +import signal +import sys +import typing +import uuid +from logging import getLogger +from types import TracebackType +from typing import IO, TYPE_CHECKING, Any, AsyncIterator, Literal, Sequence, overload + +from typing_extensions import Self + +import snowflake.connector.cursor +from snowflake.connector import ( + Error, + IntegrityError, + InterfaceError, + NotSupportedError, + ProgrammingError, +) +from snowflake.connector._sql_util import get_file_transfer_type +from snowflake.connector.aio._build_upload_agent import BindUploadAgent +from snowflake.connector.aio._result_batch import ( + ResultBatch, + create_batches_from_response, +) +from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator +from snowflake.connector.constants import ( + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, + QueryStatus, +) +from snowflake.connector.cursor import ( + ASYNC_NO_DATA_MAX_RETRY, + ASYNC_RETRY_PATTERN, + DESC_TABLE_RE, +) +from snowflake.connector.cursor import DictCursor as DictCursorSync +from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState +from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync +from snowflake.connector.cursor import T +from snowflake.connector.errorcode import ( + ER_CURSOR_IS_CLOSED, + ER_FAILED_PROCESSING_PYFORMAT, + ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT, + ER_INVALID_VALUE, + ER_NOT_POSITIVE_SIZE, +) +from snowflake.connector.errors import BindUploadError, DatabaseError +from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage +from snowflake.connector.telemetry import TelemetryData, TelemetryField +from snowflake.connector.time_util import get_time_millis + +if TYPE_CHECKING: + from pandas import DataFrame + from pyarrow import Table + + from snowflake.connector.aio import 
SnowflakeConnection + +logger = getLogger(__name__) + + +class SnowflakeCursor(SnowflakeCursorSync): + def __init__( + self, + connection: SnowflakeConnection, + use_dict_result: bool = False, + ): + super().__init__(connection, use_dict_result) + # the following fixes type hint + self._connection = typing.cast("SnowflakeConnection", self._connection) + self._inner_cursor = typing.cast(SnowflakeCursor, self._inner_cursor) + self._lock_canceling = asyncio.Lock() + self._timebomb: asyncio.Task | None = None + self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None + + def __aiter__(self): + return self + + def __iter__(self): + raise TypeError( + "'snowflake.connector.aio.SnowflakeCursor' only supports async iteration." + ) + + async def __anext__(self): + while True: + _next = await self.fetchone() + if _next is None: + raise StopAsyncIteration + return _next + + async def __aenter__(self): + return self + + def __enter__(self): + # async cursor does not support sync context manager + raise TypeError( + "'SnowflakeCursor' object does not support the context manager protocol" + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + # async cursor does not support sync context manager + raise TypeError( + "'SnowflakeCursor' object does not support the context manager protocol" + ) + + def __del__(self): + # do nothing in async, __del__ is unreliable + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager with commit or rollback.""" + await self.close() + + async def _timebomb_task(self, timeout, query): + try: + logger.debug("started timebomb in %ss", timeout) + await asyncio.sleep(timeout) + await self.__cancel_query(query) + return True + except asyncio.CancelledError: + logger.debug("cancelled timebomb in timebomb task") + return False + + async def __cancel_query(self, query) -> None: + if self._sequence_counter >= 0 and not self.is_closed(): + logger.debug("canceled. %s, request_id: %s", query, self._request_id) + async with self._lock_canceling: + await self._connection._cancel_query(query, self._request_id) + + async def _describe_internal( + self, *args: Any, **kwargs: Any + ) -> list[ResultMetadataV2]: + """Obtain the schema of the result without executing the query. + + This function takes the same arguments as execute, please refer to that function + for documentation. + + This function is for internal use only + + Returns: + The schema of the result, in the new result metadata format. + """ + kwargs["_describe_only"] = kwargs["_is_internal"] = True + await self.execute(*args, **kwargs) + return self._description + + async def _execute_helper( + self, + query: str, + timeout: int = 0, + statement_params: dict[str, str] | None = None, + binding_params: tuple | dict[str, dict[str, str]] = None, + binding_stage: str | None = None, + is_internal: bool = False, + describe_only: bool = False, + _no_results: bool = False, + _is_put_get=None, + _no_retry: bool = False, + dataframe_ast: str | None = None, + ) -> dict[str, Any]: + del self.messages[:] + + if statement_params is not None and not isinstance(statement_params, dict): + Error.errorhandler_wrapper( + self.connection, + self, + ProgrammingError, + { + "msg": "The data type of statement params is invalid. 
It must be dict.", + "errno": ER_INVALID_VALUE, + }, + ) + + # check if current installation include arrow extension or not, + # if not, we set statement level query result format to be JSON + if not snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT: + logger.debug("Cannot use arrow result format, fallback to json format") + if statement_params is None: + statement_params = { + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON" + } + else: + result_format_val = statement_params.get( + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT + ) + if str(result_format_val).upper() == "ARROW": + self.check_can_use_arrow_resultset() + elif result_format_val is None: + statement_params[PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT] = ( + "JSON" + ) + + self._sequence_counter = await self._connection._next_sequence_counter() + self._request_id = uuid.uuid4() + + logger.debug(f"Request id: {self._request_id}") + + logger.debug("running query [%s]", self._format_query_for_log(query)) + if _is_put_get is not None: + # if told the query is PUT or GET, use the information + self._is_file_transfer = _is_put_get + else: + # or detect it. + self._is_file_transfer = get_file_transfer_type(query) is not None + logger.debug("is_file_transfer: %s", self._is_file_transfer is not None) + + real_timeout = ( + timeout if timeout and timeout > 0 else self._connection.network_timeout + ) + + if real_timeout is not None: + self._timebomb = asyncio.create_task( + self._timebomb_task(real_timeout, query) + ) + logger.debug("started timebomb in %ss", real_timeout) + else: + self._timebomb = None + + original_sigint = signal.getsignal(signal.SIGINT) + + def interrupt_handler(*_): # pragma: no cover + try: + signal.signal(signal.SIGINT, snowflake.connector.cursor.exit_handler) + except (ValueError, TypeError): + # ignore failures + pass + try: + if self._timebomb is not None: + self._timebomb.cancel() + self._timebomb = None + logger.debug("cancelled timebomb in finally") + asyncio.create_task(self.__cancel_query(query)) + finally: + if original_sigint: + try: + signal.signal(signal.SIGINT, original_sigint) + except (ValueError, TypeError): + # ignore failures + pass + raise KeyboardInterrupt + + try: + if not original_sigint == snowflake.connector.cursor.exit_handler: + signal.signal(signal.SIGINT, interrupt_handler) + except ValueError: # pragma: no cover + logger.debug( + "Failed to set SIGINT handler. " "Not in main thread. Ignored..." + ) + ret: dict[str, Any] = {"data": {}} + try: + ret = await self._connection.cmd_query( + query, + self._sequence_counter, + self._request_id, + binding_params=binding_params, + binding_stage=binding_stage, + is_file_transfer=bool(self._is_file_transfer), + statement_params=statement_params, + is_internal=is_internal, + describe_only=describe_only, + _no_results=_no_results, + _no_retry=_no_retry, + timeout=real_timeout, + dataframe_ast=dataframe_ast, + ) + finally: + try: + if original_sigint: + signal.signal(signal.SIGINT, original_sigint) + except (ValueError, TypeError): # pragma: no cover + logger.debug( + "Failed to reset SIGINT handler. Not in main " "thread. Ignored..." 
+ ) + if self._timebomb is not None: + self._timebomb.cancel() + try: + await self._timebomb + except asyncio.CancelledError: + pass + logger.debug("cancelled timebomb in finally") + + if "data" in ret and "parameters" in ret["data"]: + parameters = ret["data"].get("parameters", list()) + # Set session parameters for cursor object + for kv in parameters: + if "TIMESTAMP_OUTPUT_FORMAT" in kv["name"]: + self._timestamp_output_format = kv["value"] + elif "TIMESTAMP_NTZ_OUTPUT_FORMAT" in kv["name"]: + self._timestamp_ntz_output_format = kv["value"] + elif "TIMESTAMP_LTZ_OUTPUT_FORMAT" in kv["name"]: + self._timestamp_ltz_output_format = kv["value"] + elif "TIMESTAMP_TZ_OUTPUT_FORMAT" in kv["name"]: + self._timestamp_tz_output_format = kv["value"] + elif "DATE_OUTPUT_FORMAT" in kv["name"]: + self._date_output_format = kv["value"] + elif "TIME_OUTPUT_FORMAT" in kv["name"]: + self._time_output_format = kv["value"] + elif "TIMEZONE" in kv["name"]: + self._timezone = kv["value"] + elif "BINARY_OUTPUT_FORMAT" in kv["name"]: + self._binary_output_format = kv["value"] + # Set session parameters for connection object + await self._connection._update_parameters( + {p["name"]: p["value"] for p in parameters} + ) + + self.query = query + self._sequence_counter = -1 + return ret + + async def _init_result_and_meta(self, data: dict[Any, Any]) -> None: + is_dml = self._is_dml(data) + self._query_result_format = data.get("queryResultFormat", "json") + logger.debug("Query result format: %s", self._query_result_format) + + if self._total_rowcount == -1 and not is_dml and data.get("total") is not None: + self._total_rowcount = data["total"] + + self._description: list[ResultMetadataV2] = [ + ResultMetadataV2.from_column(col) for col in data["rowtype"] + ] + + result_chunks = create_batches_from_response( + self, self._query_result_format, data, self._description + ) + + if not (is_dml or self.is_file_transfer): + logger.info( + "Number of results in first chunk: %s", result_chunks[0].rowcount + ) + + self._result_set = ResultSet( + self, + result_chunks, + self._connection.client_prefetch_threads, + ) + self._rownumber = -1 + self._result_state = ResultState.VALID + + # don't update the row count when the result is returned from `describe` method + if is_dml and "rowset" in data and len(data["rowset"]) > 0: + updated_rows = 0 + for idx, desc in enumerate(self._description): + if desc.name in ( + "number of rows updated", + "number of multi-joined rows updated", + "number of rows deleted", + ) or desc.name.startswith("number of rows inserted"): + updated_rows += int(data["rowset"][0][idx]) + if self._total_rowcount == -1: + self._total_rowcount = updated_rows + else: + self._total_rowcount += updated_rows + + async def _init_multi_statement_results(self, data: dict) -> None: + await self._log_telemetry_job_data( + TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE + ) + self.multi_statement_savedIds = data["resultIds"].split(",") + self._multi_statement_resultIds = collections.deque( + self.multi_statement_savedIds + ) + if self._is_file_transfer: + Error.errorhandler_wrapper( + self.connection, + self, + ProgrammingError, + { + "msg": "PUT/GET commands are not supported for multi-statement queries and cannot be executed.", + "errno": ER_INVALID_VALUE, + }, + ) + await self.nextset() + + async def _log_telemetry_job_data( + self, telemetry_field: TelemetryField, value: Any + ) -> None: + ts = get_time_millis() + try: + await self._connection._log_telemetry( + TelemetryData.from_telemetry_data_dict( + from_dict={ 
+ TelemetryField.KEY_TYPE.value: telemetry_field.value, + TelemetryField.KEY_SFQID.value: self._sfqid, + TelemetryField.KEY_VALUE.value: value, + }, + timestamp=ts, + connection=self._connection, + ) + ) + except AttributeError: + logger.warning( + "Cursor failed to log to telemetry. Connection object may be None.", + exc_info=True, + ) + + async def _preprocess_pyformat_query( + self, + command: str, + params: Sequence[Any] | dict[Any, Any] | None = None, + ) -> str: + # pyformat/format paramstyle + # client side binding + processed_params = self._connection._process_params_pyformat(params, self) + # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement + if params is not None and len(params) == 0: + await self._log_telemetry_job_data( + TelemetryField.EMPTY_SEQ_INTERPOLATION, + ( + TelemetryData.TRUE + if self.connection._interpolate_empty_sequences + else TelemetryData.FALSE + ), + ) + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug( + f"binding: [{self._format_query_for_log(command)}] " + f"with input=[{params}], " + f"processed=[{processed_params}]", + ) + if ( + self.connection._interpolate_empty_sequences + and processed_params is not None + ) or ( + not self.connection._interpolate_empty_sequences + and len(processed_params) > 0 + ): + query = command % processed_params + else: + query = command + return query + + async def abort_query(self, qid: str) -> bool: + url = f"/queries/{qid}/abort-request" + ret = await self._connection.rest.request(url=url, method="post") + return ret.get("success") + + @overload + async def callproc(self, procname: str) -> tuple: ... + + @overload + async def callproc(self, procname: str, args: T) -> T: ... + + async def callproc(self, procname: str, args=tuple()): + """Call a stored procedure. + + Args: + procname: The stored procedure to be called. + args: Parameters to be passed into the stored procedure. + + Returns: + The input parameters. + """ + marker_format = "%s" if self._connection.is_pyformat else "?" + command = ( + f"CALL {procname}({', '.join([marker_format for _ in range(len(args))])})" + ) + await self.execute(command, args) + return args + + @property + def connection(self) -> SnowflakeConnection: + return self._connection + + async def close(self): + """Closes the cursor object. + + Returns whether the cursor was closed during this call. 
+ """ + try: + if self.is_closed(): + return False + async with self._lock_canceling: + self.reset(closing=True) + self._connection = None + del self.messages[:] + return True + except Exception: + return None + + async def execute( + self, + command: str, + params: Sequence[Any] | dict[Any, Any] | None = None, + _bind_stage: str | None = None, + timeout: int | None = None, + _exec_async: bool = False, + _no_retry: bool = False, + _do_reset: bool = True, + _put_callback: SnowflakeProgressPercentage = None, + _put_azure_callback: SnowflakeProgressPercentage = None, + _put_callback_output_stream: IO[str] = sys.stdout, + _get_callback: SnowflakeProgressPercentage = None, + _get_azure_callback: SnowflakeProgressPercentage = None, + _get_callback_output_stream: IO[str] = sys.stdout, + _show_progress_bar: bool = True, + _statement_params: dict[str, str] | None = None, + _is_internal: bool = False, + _describe_only: bool = False, + _no_results: bool = False, + _is_put_get: bool | None = None, + _raise_put_get_error: bool = True, + _force_put_overwrite: bool = False, + _skip_upload_on_content_match: bool = False, + file_stream: IO[bytes] | None = None, + num_statements: int | None = None, + _dataframe_ast: str | None = None, + ) -> Self | dict[str, Any] | None: + if _exec_async: + _no_results = True + logger.debug("executing SQL/command") + if self.is_closed(): + Error.errorhandler_wrapper( + self.connection, + self, + InterfaceError, + {"msg": "Cursor is closed in execute.", "errno": ER_CURSOR_IS_CLOSED}, + ) + + if _do_reset: + self.reset() + command = command.strip(" \t\n\r") if command else None + if not command: + logger.warning("execute: no query is given to execute") + return None + logger.debug("query: [%s]", self._format_query_for_log(command)) + + _statement_params = _statement_params or dict() + # If we need to add another parameter, please consider introducing a dict for all extra params + # See discussion in https://github.com/snowflakedb/snowflake-connector-python/pull/1524#discussion_r1174061775 + if num_statements is not None: + _statement_params = { + **_statement_params, + "MULTI_STATEMENT_COUNT": num_statements, + } + + kwargs: dict[str, Any] = { + "timeout": timeout, + "statement_params": _statement_params, + "is_internal": _is_internal, + "describe_only": _describe_only, + "_no_results": _no_results, + "_is_put_get": _is_put_get, + "_no_retry": _no_retry, + "dataframe_ast": _dataframe_ast, + } + + if self._connection.is_pyformat: + query = await self._preprocess_pyformat_query(command, params) + else: + # qmark and numeric paramstyle + query = command + if _bind_stage: + kwargs["binding_stage"] = _bind_stage + else: + if params is not None and not isinstance(params, (list, tuple)): + errorvalue = { + "msg": f"Binding parameters must be a list: {params}", + "errno": ER_FAILED_PROCESSING_PYFORMAT, + } + Error.errorhandler_wrapper( + self.connection, self, ProgrammingError, errorvalue + ) + + kwargs["binding_params"] = self._connection._process_params_qmarks( + params, self + ) + + m = DESC_TABLE_RE.match(query) + if m: + query1 = f"describe table {m.group(1)}" + logger.debug( + "query was rewritten: org=%s, new=%s", + " ".join(line.strip() for line in query.split("\n")), + query1, + ) + query = query1 + + ret = await self._execute_helper(query, **kwargs) + self._sfqid = ( + ret["data"]["queryId"] + if "data" in ret and "queryId" in ret["data"] + else None + ) + logger.debug(f"sfqid: {self.sfqid}") + self._sqlstate = ( + ret["data"]["sqlState"] + if "data" in ret and "sqlState" 
in ret["data"] + else None + ) + logger.debug("query execution done") + + self._first_chunk_time = get_time_millis() + + # if server gives a send time, log the time it took to arrive + if "data" in ret and "sendResultTime" in ret["data"]: + time_consume_first_result = ( + self._first_chunk_time - ret["data"]["sendResultTime"] + ) + await self._log_telemetry_job_data( + TelemetryField.TIME_CONSUME_FIRST_RESULT, time_consume_first_result + ) + + if ret["success"]: + logger.debug("SUCCESS") + data = ret["data"] + + for m in self.ALTER_SESSION_RE.finditer(query): + # session parameters + param = m.group(1).upper() + value = m.group(2) + self._connection.converter.set_parameter(param, value) + + if "resultIds" in data: + await self._init_multi_statement_results(data) + return self + else: + self.multi_statement_savedIds = [] + + self._is_file_transfer = "command" in data and data["command"] in ( + "UPLOAD", + "DOWNLOAD", + ) + logger.debug("PUT OR GET: %s", self.is_file_transfer) + if self.is_file_transfer: + from ._file_transfer_agent import SnowflakeFileTransferAgent + + # Decide whether to use the old, or new code path + sf_file_transfer_agent = SnowflakeFileTransferAgent( + self, + query, + ret, + put_callback=_put_callback, + put_azure_callback=_put_azure_callback, + put_callback_output_stream=_put_callback_output_stream, + get_callback=_get_callback, + get_azure_callback=_get_azure_callback, + get_callback_output_stream=_get_callback_output_stream, + show_progress_bar=_show_progress_bar, + raise_put_get_error=_raise_put_get_error, + force_put_overwrite=_force_put_overwrite + or data.get("overwrite", False), + skip_upload_on_content_match=_skip_upload_on_content_match, + source_from_stream=file_stream, + multipart_threshold=data.get("threshold"), + use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1, + ) + await sf_file_transfer_agent.execute() + data = sf_file_transfer_agent.result() + self._total_rowcount = len(data["rowset"]) if "rowset" in data else -1 + + if _exec_async: + self.connection._async_sfqids[self._sfqid] = None + if _no_results: + self._total_rowcount = ( + ret["data"]["total"] + if "data" in ret and "total" in ret["data"] + else -1 + ) + return data + await self._init_result_and_meta(data) + else: + self._total_rowcount = ( + ret["data"]["total"] if "data" in ret and "total" in ret["data"] else -1 + ) + logger.debug(ret) + err = ret["message"] + code = ret.get("code", -1) + if self._timebomb and self._timebomb.result(): + err = ( + f"SQL execution was cancelled by the client due to a timeout. " + f"Error message received from the server: {err}" + ) + if "data" in ret: + err += ret["data"].get("errorMessage", "") + errvalue = { + "msg": err, + "errno": int(code), + "sqlstate": self._sqlstate, + "sfqid": self._sfqid, + "query": query, + } + is_integrity_error = ( + code == "100072" + ) # NULL result in a non-nullable column + error_class = IntegrityError if is_integrity_error else ProgrammingError + Error.errorhandler_wrapper(self.connection, self, error_class, errvalue) + return self + + async def executemany( + self, + command: str, + seqparams: Sequence[Any] | dict[str, Any], + **kwargs: Any, + ) -> SnowflakeCursor: + """Executes a command/query with the given set of parameters sequentially.""" + logger.debug("executing many SQLs/commands") + command = command.strip(" \t\n\r") if command else None + + if not seqparams: + logger.warning( + "No parameters provided to executemany, returning without doing anything." 
+ ) + return self + + if self.INSERT_SQL_RE.match(command) and ( + "num_statements" not in kwargs or kwargs.get("num_statements") == 1 + ): + if self._connection.is_pyformat: + # TODO(SNOW-940692) - utilize multi-statement instead of rewriting the query and + # accumulate results to mock the result from a single insert statement as formatted below + logger.debug("rewriting INSERT query") + command_wo_comments = re.sub(self.COMMENT_SQL_RE, "", command) + m = self.INSERT_SQL_VALUES_RE.match(command_wo_comments) + if not m: + Error.errorhandler_wrapper( + self.connection, + self, + InterfaceError, + { + "msg": "Failed to rewrite multi-row insert", + "errno": ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT, + }, + ) + + fmt = m.group(1) + values = [] + for param in seqparams: + logger.debug(f"parameter: {param}") + values.append( + fmt % self._connection._process_params_pyformat(param, self) + ) + command = command.replace(fmt, ",".join(values), 1) + await self.execute(command, **kwargs) + return self + else: + logger.debug("bulk insert") + # sanity check + row_size = len(seqparams[0]) + for row in seqparams: + if len(row) != row_size: + error_value = { + "msg": f"Bulk data size don't match. expected: {row_size}, " + f"got: {len(row)}, command: {command}", + "errno": ER_INVALID_VALUE, + } + Error.errorhandler_wrapper( + self.connection, self, InterfaceError, error_value + ) + return self + bind_size = len(seqparams) * row_size + bind_stage = None + if ( + bind_size + > self.connection._session_parameters[ + "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD" + ] + > 0 + ): + # bind stage optimization + try: + rows = self.connection._write_params_to_byte_rows(seqparams) + bind_uploader = BindUploadAgent(self, rows) + await bind_uploader.upload() + bind_stage = bind_uploader.stage_path + except BindUploadError: + logger.debug( + "Failed to upload binds to stage, sending binds to " + "Snowflake instead." + ) + binding_param = ( + None if bind_stage else list(map(list, zip(*seqparams))) + ) # transpose + await self.execute( + command, params=binding_param, _bind_stage=bind_stage, **kwargs + ) + return self + + self.reset() + if "num_statements" not in kwargs: + # fall back to old driver behavior when the user does not provide the parameter to enable + # multi-statement optimizations for executemany + for param in seqparams: + await self.execute(command, params=param, _do_reset=False, **kwargs) + else: + if re.search(";/s*$", command) is None: + command = command + "; " + if self._connection.is_pyformat: + processed_queries = [ + await self._preprocess_pyformat_query(command, params) + for params in seqparams + ] + query = "".join(processed_queries) + params = None + else: + query = command * len(seqparams) + params = [param for parameters in seqparams for param in parameters] + + kwargs["num_statements"]: int = kwargs.get("num_statements") * len( + seqparams + ) + + await self.execute(query, params, _do_reset=False, **kwargs) + + return self + + async def execute_async(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Convenience function to execute a query without waiting for results (asynchronously). + + This function takes the same arguments as execute, please refer to that function + for documentation. Please note that PUT and GET statements are not supported by this method. 
+ """ + kwargs["_exec_async"] = True + return await self.execute(*args, **kwargs) + + @property + def errorhandler(self): + # TODO: SNOW-1763103 for async error handler + raise NotImplementedError( + "Async Snowflake Python Connector does not support errorhandler. " + "Please open a feature request issue in github if your want this feature: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + ) + + @errorhandler.setter + def errorhandler(self, value): + # TODO: SNOW-1763103 for async error handler + raise NotImplementedError( + "Async Snowflake Python Connector does not support errorhandler. " + "Please open a feature request issue in github if your want this feature: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + ) + + async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]: + """Obtain the schema of the result without executing the query. + + This function takes the same arguments as execute, please refer to that function + for documentation. + + Returns: + The schema of the result. + """ + kwargs["_describe_only"] = kwargs["_is_internal"] = True + await self.execute(*args, **kwargs) + + if self._description is None: + return None + return [meta._to_result_metadata_v1() for meta in self._description] + + async def fetchone(self) -> dict | tuple | None: + """Fetches one row.""" + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._result is None and self._result_set is not None: + self._result: ResultSetIterator = await self._result_set._create_iter() + self._result_state = ResultState.VALID + try: + if self._result is None: + raise TypeError("'NoneType' object is not an iterator") + _next = await self._result.get_next() + if isinstance(_next, Exception): + Error.errorhandler_wrapper_from_ready_exception( + self._connection, + self, + _next, + ) + if _next is not None: + self._rownumber += 1 + return _next + except TypeError as err: + if self._result_state == ResultState.DEFAULT: + raise err + else: + return None + + async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: + """Fetches the number of specified rows.""" + if size is None: + size = self.arraysize + + if size < 0: + errorvalue = { + "msg": ( + "The number of rows is not zero or " "positive number: {}" + ).format(size), + "errno": ER_NOT_POSITIVE_SIZE, + } + Error.errorhandler_wrapper( + self.connection, self, ProgrammingError, errorvalue + ) + ret = [] + while size > 0: + row = await self.fetchone() + if row is None: + break + ret.append(row) + if size is not None: + size -= 1 + + return ret + + async def fetchall(self) -> list[tuple] | list[dict]: + """Fetches all of the results.""" + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._result is None and self._result_set is not None: + self._result: ResultSetIterator = await self._result_set._create_iter( + is_fetch_all=True + ) + self._result_state = ResultState.VALID + + if self._result is None: + if self._result_state == ResultState.DEFAULT: + raise TypeError("'NoneType' object is not an iterator") + else: + return [] + + return await self._result.fetch_all_data() + + async def fetch_arrow_batches(self) -> AsyncIterator[Table]: + self.check_can_use_arrow_resultset() + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._query_result_format != "arrow": + raise NotSupportedError + await self._log_telemetry_job_data( + TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE + ) + return await 
self._result_set._fetch_arrow_batches() + + @overload + async def fetch_arrow_all( + self, force_return_table: Literal[False] + ) -> Table | None: ... + + @overload + async def fetch_arrow_all(self, force_return_table: Literal[True]) -> Table: ... + + async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | None: + """ + Args: + force_return_table: Set to True so that when the query returns zero rows, + an empty pyarrow table will be returned with schema using the highest bit length for each column. + Default value is False in which case None is returned in case of zero rows. + """ + self.check_can_use_arrow_resultset() + + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._query_result_format != "arrow": + raise NotSupportedError + await self._log_telemetry_job_data( + TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE + ) + return await self._result_set._fetch_arrow_all( + force_return_table=force_return_table + ) + + async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]: + """Fetches a single Arrow Table.""" + self.check_can_use_pandas() + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._query_result_format != "arrow": + raise NotSupportedError + await self._log_telemetry_job_data( + TelemetryField.PANDAS_FETCH_BATCHES, TelemetryData.TRUE + ) + return await self._result_set._fetch_pandas_batches(**kwargs) + + async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame: + self.check_can_use_pandas() + if self._prefetch_hook is not None: + await self._prefetch_hook() + if self._query_result_format != "arrow": + raise NotSupportedError + await self._log_telemetry_job_data( + TelemetryField.PANDAS_FETCH_ALL, TelemetryData.TRUE + ) + return await self._result_set._fetch_pandas_all(**kwargs) + + async def nextset(self) -> SnowflakeCursor | None: + """ + Fetches the next set of results if the previously executed query was multi-statement so that subsequent calls + to any of the fetch*() methods will return rows from the next query's set of results. Returns None if no more + query results are available. + """ + if self._prefetch_hook is not None: + await self._prefetch_hook() + self.reset() + if self._multi_statement_resultIds: + await self.query_result(self._multi_statement_resultIds[0]) + logger.info( + f"Retrieved results for query ID: {self._multi_statement_resultIds.popleft()}" + ) + return self + + return None + + async def get_result_batches(self) -> list[ResultBatch] | None: + """Get the previously executed query's ``ResultBatch`` s if available. + + If they are unavailable, in case nothing has been executed yet None will + be returned. + + For a detailed description of ``ResultBatch`` s please see the docstring of: + ``snowflake.connector.result_batches.ResultBatch`` + """ + if self._result_set is None: + return None + await self._log_telemetry_job_data( + TelemetryField.GET_PARTITIONS_USED, TelemetryData.TRUE + ) + return self._result_set.batches + + async def get_results_from_sfqid(self, sfqid: str) -> None: + """Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result`` + in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results. 
+ """ + + async def wait_until_ready() -> None: + """Makes sure query has finished executing and once it has retrieves results.""" + no_data_counter = 0 + retry_pattern_pos = 0 + while True: + status, status_resp = await self.connection._get_query_status(sfqid) + self.connection._cache_query_status(sfqid, status) + if not self.connection.is_still_running(status): + break + if status == QueryStatus.NO_DATA: # pragma: no cover + no_data_counter += 1 + if no_data_counter > ASYNC_NO_DATA_MAX_RETRY: + raise DatabaseError( + "Cannot retrieve data on the status of this query. No information returned " + "from server for query '{}'" + ) + await asyncio.sleep( + 0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos] + ) # Same wait as JDBC + # If we can advance in ASYNC_RETRY_PATTERN then do so + if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1): + retry_pattern_pos += 1 + if status != QueryStatus.SUCCESS: + logger.info(f"Status of query '{sfqid}' is {status.name}") + self.connection._process_error_query_status( + sfqid, + status_resp, + error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable", + error_cls=DatabaseError, + ) + await self._inner_cursor.execute( + f"select * from table(result_scan('{sfqid}'))" + ) + self._result = self._inner_cursor._result + self._query_result_format = self._inner_cursor._query_result_format + self._total_rowcount = self._inner_cursor._total_rowcount + self._description = self._inner_cursor._description + self._result_set = self._inner_cursor._result_set + self._result_state = ResultState.VALID + self._rownumber = 0 + # Unset this function, so that we don't block anymore + self._prefetch_hook = None + + if ( + self._inner_cursor._total_rowcount == 1 + and await self._inner_cursor.fetchall() + == [("Multiple statements executed successfully.",)] + ): + url = f"/queries/{sfqid}/result" + ret = await self._connection.rest.request(url=url, method="get") + if "data" in ret and "resultIds" in ret["data"]: + await self._init_multi_statement_results(ret["data"]) + + await self.connection.get_query_status_throw_if_error( + sfqid + ) # Trigger an exception if query failed + klass = self.__class__ + self._inner_cursor = klass(self.connection) + self._sfqid = sfqid + self._prefetch_hook = wait_until_ready + + async def query_result(self, qid: str) -> SnowflakeCursor: + """Query the result of a previously executed query.""" + url = f"/queries/{qid}/result" + ret = await self._connection.rest.request(url=url, method="get") + self._sfqid = ( + ret["data"]["queryId"] + if "data" in ret and "queryId" in ret["data"] + else None + ) + self._sqlstate = ( + ret["data"]["sqlState"] + if "data" in ret and "sqlState" in ret["data"] + else None + ) + logger.debug("sfqid=%s", self._sfqid) + + if ret.get("success"): + data = ret.get("data") + await self._init_result_and_meta(data) + else: + logger.info("failed") + logger.debug(ret) + err = ret["message"] + code = ret.get("code", -1) + if "data" in ret: + err += ret["data"].get("errorMessage", "") + errvalue = { + "msg": err, + "errno": int(code), + "sqlstate": self._sqlstate, + "sfqid": self._sfqid, + } + Error.errorhandler_wrapper( + self.connection, self, ProgrammingError, errvalue + ) + return self + + +class DictCursor(DictCursorSync, SnowflakeCursor): + pass diff --git a/src/snowflake/connector/aio/_description.py b/src/snowflake/connector/aio/_description.py new file mode 100644 index 0000000000..9b5f175408 --- /dev/null +++ b/src/snowflake/connector/aio/_description.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2012-2023 
Snowflake Computing Inc. All rights reserved. +# + +"""Various constants.""" + +from __future__ import annotations + +CLIENT_NAME = "AsyncioPythonConnector" # don't change! diff --git a/src/snowflake/connector/aio/_file_transfer_agent.py b/src/snowflake/connector/aio/_file_transfer_agent.py new file mode 100644 index 0000000000..f87444ef59 --- /dev/null +++ b/src/snowflake/connector/aio/_file_transfer_agent.py @@ -0,0 +1,311 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import os +import sys +from logging import getLogger +from typing import IO, TYPE_CHECKING, Any + +from ..constants import ( + AZURE_CHUNK_SIZE, + AZURE_FS, + CMD_TYPE_DOWNLOAD, + CMD_TYPE_UPLOAD, + GCS_FS, + LOCAL_FS, + S3_FS, + ResultStatus, + megabyte, +) +from ..errorcode import ER_FILE_NOT_EXISTS +from ..errors import Error, OperationalError +from ..file_transfer_agent import SnowflakeFileMeta +from ..file_transfer_agent import ( + SnowflakeFileTransferAgent as SnowflakeFileTransferAgentSync, +) +from ..file_transfer_agent import SnowflakeProgressPercentage, _chunk_size_calculator +from ..local_storage_client import SnowflakeLocalStorageClient +from ._azure_storage_client import SnowflakeAzureRestClient +from ._gcs_storage_client import SnowflakeGCSRestClient +from ._s3_storage_client import SnowflakeS3RestClient +from ._storage_client import SnowflakeStorageClient + +if TYPE_CHECKING: # pragma: no cover + from ._cursor import SnowflakeCursor + + +logger = getLogger(__name__) + + +class SnowflakeFileTransferAgent(SnowflakeFileTransferAgentSync): + """Snowflake File Transfer Agent provides cloud provider independent implementation for putting/getting files.""" + + def __init__( + self, + cursor: SnowflakeCursor, + command: str, + ret: dict[str, Any], + put_callback: type[SnowflakeProgressPercentage] | None = None, + put_azure_callback: type[SnowflakeProgressPercentage] | None = None, + put_callback_output_stream: IO[str] = sys.stdout, + get_callback: type[SnowflakeProgressPercentage] | None = None, + get_azure_callback: type[SnowflakeProgressPercentage] | None = None, + get_callback_output_stream: IO[str] = sys.stdout, + show_progress_bar: bool = True, + raise_put_get_error: bool = True, + force_put_overwrite: bool = True, + skip_upload_on_content_match: bool = False, + multipart_threshold: int | None = None, + source_from_stream: IO[bytes] | None = None, + use_s3_regional_url: bool = False, + ) -> None: + super().__init__( + cursor, + command, + ret, + put_callback, + put_azure_callback, + put_callback_output_stream, + get_callback, + get_azure_callback, + get_callback_output_stream, + show_progress_bar, + raise_put_get_error, + force_put_overwrite, + skip_upload_on_content_match, + multipart_threshold, + source_from_stream, + use_s3_regional_url, + ) + + async def execute(self) -> None: + self._parse_command() + self._init_file_metadata() + + if self._command_type == CMD_TYPE_UPLOAD: + self._process_file_compression_type() + + for m in self._file_metadata: + m.sfagent = self + + await self._transfer_accelerate_config() + + if self._command_type == CMD_TYPE_DOWNLOAD: + if not os.path.isdir(self._local_location): + os.makedirs(self._local_location) + + if self._stage_location_type == LOCAL_FS: + if not os.path.isdir(self._stage_info["location"]): + os.makedirs(self._stage_info["location"]) + + for m in self._file_metadata: + m.overwrite = self._overwrite + m.skip_upload_on_content_match = self._skip_upload_on_content_match + 
m.sfagent = self + if self._stage_location_type != LOCAL_FS: + m.put_callback = self._put_callback + m.put_azure_callback = self._put_azure_callback + m.put_callback_output_stream = self._put_callback_output_stream + m.get_callback = self._get_callback + m.get_azure_callback = self._get_azure_callback + m.get_callback_output_stream = self._get_callback_output_stream + m.show_progress_bar = self._show_progress_bar + + # multichunk threshold + m.multipart_threshold = self._multipart_threshold + + # TODO: SNOW-1625364 for renaming client_prefetch_threads in asyncio + logger.debug(f"parallel=[{self._parallel}]") + if self._raise_put_get_error and not self._file_metadata: + Error.errorhandler_wrapper( + self._cursor.connection, + self._cursor, + OperationalError, + { + "msg": "While getting file(s) there was an error: " + "the file does not exist.", + "errno": ER_FILE_NOT_EXISTS, + }, + ) + await self.transfer(self._file_metadata) + + # turn enum to string, in order to have backward compatible interface + + for result in self._results: + result.result_status = result.result_status.value + + async def transfer(self, metas: list[SnowflakeFileMeta]) -> None: + files = [await self._create_file_transfer_client(m) for m in metas] + is_upload = self._command_type == CMD_TYPE_UPLOAD + finish_download_upload_tasks = [] + + async def preprocess_done_cb( + success: bool, + result: Any, + done_client: SnowflakeStorageClient, + ) -> None: + if not success: + logger.debug(f"Failed to prepare {done_client.meta.name}.") + try: + if is_upload: + await done_client.finish_upload() + done_client.delete_client_data() + else: + await done_client.finish_download() + except Exception as error: + done_client.meta.error_details = error + elif done_client.meta.result_status == ResultStatus.SKIPPED: + # this case applies to upload only + return + else: + try: + logger.debug(f"Finished preparing file {done_client.meta.name}") + tasks = [] + for _chunk_id in range(done_client.num_of_chunks): + task = ( + asyncio.create_task(done_client.upload_chunk(_chunk_id)) + if is_upload + else asyncio.create_task( + done_client.download_chunk(_chunk_id) + ) + ) + task.add_done_callback( + lambda t, dc=done_client, _chunk_id=_chunk_id: transfer_done_cb( + t, dc, _chunk_id + ) + ) + tasks.append(task) + await asyncio.gather(*tasks) + await asyncio.gather(*finish_download_upload_tasks) + except Exception as error: + done_client.meta.error_details = error + + def transfer_done_cb( + task: asyncio.Task, + done_client: SnowflakeStorageClient, + chunk_id: int, + ) -> None: + # Note: chunk_id is 0 based while num_of_chunks is count + logger.debug( + f"Chunk {chunk_id}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback" + ) + if task.exception(): + done_client.failed_transfers += 1 + logger.debug( + f"Chunk {chunk_id} of file {done_client.meta.name} failed to transfer for unexpected exception {task.exception()}" + ) + else: + done_client.successful_transfers += 1 + logger.debug( + f"Chunk progress: {done_client.meta.name}: completed: {done_client.successful_transfers} failed: {done_client.failed_transfers} total: {done_client.num_of_chunks}" + ) + if ( + done_client.successful_transfers + done_client.failed_transfers + == done_client.num_of_chunks + ): + if is_upload: + finish_upload_task = asyncio.create_task( + done_client.finish_upload() + ) + finish_download_upload_tasks.append(finish_upload_task) + done_client.delete_client_data() + else: + finish_download_task = asyncio.create_task( + 
done_client.finish_download() + ) + finish_download_task.add_done_callback( + lambda t, dc=done_client: postprocess_done_cb(t, dc) + ) + finish_download_upload_tasks.append(finish_download_task) + + def postprocess_done_cb( + task: asyncio.Task, + done_client: SnowflakeStorageClient, + ) -> None: + logger.debug(f"File {done_client.meta.name} reached postprocess callback") + + if task.exception(): + done_client.failed_transfers += 1 + logger.debug( + f"File {done_client.meta.name} failed to transfer for unexpected exception {task.exception()}" + ) + # Whether there was an exception or not, we're done the file. + + task_of_files = [] + for file_client in files: + try: + # TODO: SNOW-1708819 for code refactoring + res = ( + await file_client.prepare_upload() + if is_upload + else await file_client.prepare_download() + ) + is_successful = True + except Exception as e: + res = e + file_client.meta.error_details = e + is_successful = False + + task = asyncio.create_task( + preprocess_done_cb(is_successful, res, done_client=file_client) + ) + task_of_files.append(task) + await asyncio.gather(*task_of_files) + + self._results = metas + + async def _transfer_accelerate_config(self) -> None: + if self._stage_location_type == S3_FS and self._file_metadata: + client = await self._create_file_transfer_client(self._file_metadata[0]) + self._use_accelerate_endpoint = await client.transfer_accelerate_config() + + async def _create_file_transfer_client( + self, meta: SnowflakeFileMeta + ) -> SnowflakeStorageClient: + if self._stage_location_type == LOCAL_FS: + return SnowflakeLocalStorageClient( + meta, + self._stage_info, + 4 * megabyte, + ) + elif self._stage_location_type == AZURE_FS: + return SnowflakeAzureRestClient( + meta, + self._credentials, + AZURE_CHUNK_SIZE, + self._stage_info, + use_s3_regional_url=self._use_s3_regional_url, + ) + elif self._stage_location_type == S3_FS: + client = SnowflakeS3RestClient( + meta=meta, + credentials=self._credentials, + stage_info=self._stage_info, + chunk_size=_chunk_size_calculator(meta.src_file_size), + use_accelerate_endpoint=self._use_accelerate_endpoint, + use_s3_regional_url=self._use_s3_regional_url, + ) + await client.transfer_accelerate_config(self._use_accelerate_endpoint) + return client + elif self._stage_location_type == GCS_FS: + client = SnowflakeGCSRestClient( + meta, + self._credentials, + self._stage_info, + self._cursor._connection, + self._command, + use_s3_regional_url=self._use_s3_regional_url, + ) + if client.security_token: + logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}") + else: + logger.debug( + "No access token received from GS, requesting presigned url" + ) + await client._update_presigned_url() + return client + raise Exception(f"{self._stage_location_type} is an unknown stage type") diff --git a/src/snowflake/connector/aio/_gcs_storage_client.py b/src/snowflake/connector/aio/_gcs_storage_client.py new file mode 100644 index 0000000000..5ad3e2f97c --- /dev/null +++ b/src/snowflake/connector/aio/_gcs_storage_client.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
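# --- Editorial usage sketch (not part of this patch) ----------------------
# A minimal end-to-end example of the asyncio API introduced by this diff:
# aio.connect() plus the async cursor methods defined above. Account, user and
# password values are placeholders, and the awaitable conn.close() is an
# assumption made for this sketch (it mirrors the async cursor.close() shown
# in this diff).
import asyncio

import snowflake.connector.aio


async def main() -> None:
    # connect() builds the connection object and awaits conn.connect()
    conn = await snowflake.connector.aio.connect(
        account="<account>",
        user="<user>",
        password="<password>",
    )
    try:
        # cursor() itself is synchronous; its execute/fetch* methods are coroutines
        cur = conn.cursor()
        await cur.execute("select 1")
        print(await cur.fetchall())
    finally:
        # assumed awaitable close(), by analogy with the async cursor close()
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
# --------------------------------------------------------------------------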
+# + +from __future__ import annotations + +import json +import os +from logging import getLogger +from typing import TYPE_CHECKING, Any + +import aiohttp + +from ..constants import HTTP_HEADER_CONTENT_ENCODING, FileHeader, ResultStatus +from ..encryption_util import EncryptionMetadata +from ..gcs_storage_client import SnowflakeGCSRestClient as SnowflakeGCSRestClientSync +from ._storage_client import SnowflakeStorageClient as SnowflakeStorageClientAsync + +if TYPE_CHECKING: # pragma: no cover + from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential + from ._connection import SnowflakeConnection + +logger = getLogger(__name__) + +from ..gcs_storage_client import ( + GCS_METADATA_ENCRYPTIONDATAPROP, + GCS_METADATA_MATDESC_KEY, + GCS_METADATA_SFC_DIGEST, +) + + +class SnowflakeGCSRestClient(SnowflakeStorageClientAsync, SnowflakeGCSRestClientSync): + def __init__( + self, + meta: SnowflakeFileMeta, + credentials: StorageCredential, + stage_info: dict[str, Any], + cnx: SnowflakeConnection, + command: str, + use_s3_regional_url: bool = False, + ) -> None: + """Creates a client object with given stage credentials. + + Args: + stage_info: Access credentials and info of a stage. + + Returns: + The client to communicate with GCS. + """ + SnowflakeStorageClientAsync.__init__( + self, + meta=meta, + stage_info=stage_info, + chunk_size=-1, + credentials=credentials, + chunked_transfer=False, + ) + self.stage_info = stage_info + self._command = command + self.meta = meta + self._cursor = cnx.cursor() + # presigned_url in meta is for downloading + self.presigned_url: str = meta.presigned_url or stage_info.get("presignedUrl") + self.security_token = credentials.creds.get("GCS_ACCESS_TOKEN") + + async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool: + return self.security_token and response.status == 401 + + async def _has_expired_presigned_url( + self, response: aiohttp.ClientResponse + ) -> bool: + # Presigned urls can be generated for any xml-api operation + # offered by GCS. Hence the error codes expected are similar + # to xml api. + # https://cloud.google.com/storage/docs/xml-api/reference-status + + presigned_url_expired = (not self.security_token) and response.status == 400 + if presigned_url_expired and self.last_err_is_presigned_url: + logger.debug("Presigned url expiration error two times in a row.") + response.raise_for_status() + self.last_err_is_presigned_url = presigned_url_expired + return presigned_url_expired + + async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: + meta = self.meta + + content_encoding = "" + if meta.dst_compression_type is not None: + content_encoding = meta.dst_compression_type.name.lower() + + # We set the contentEncoding to blank for GZIP files. We don't + # want GCS to think our gzip files are gzips because it makes + # them download uncompressed, and none of the other providers do + # that. There's essentially no way for us to prevent that + # behavior. Bad Google. 
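# Editorial note (illustrative, not part of this patch): for a file staged as
# e.g. data.csv.gz, labelling the upload with "Content-Encoding: gzip" would
# make GCS transparently decompress it on download, returning the raw CSV
# bytes instead of the gzip stream the connector expects; hence the header is
# blanked for gzip uploads in the branch below.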
+ if content_encoding and content_encoding == "gzip": + content_encoding = "" + + gcs_headers = { + HTTP_HEADER_CONTENT_ENCODING: content_encoding, + GCS_METADATA_SFC_DIGEST: meta.sha256_digest, + } + + if self.encryption_metadata: + gcs_headers.update( + { + GCS_METADATA_ENCRYPTIONDATAPROP: json.dumps( + { + "EncryptionMode": "FullBlob", + "WrappedContentKey": { + "KeyId": "symmKey1", + "EncryptedKey": self.encryption_metadata.key, + "Algorithm": "AES_CBC_256", + }, + "EncryptionAgent": { + "Protocol": "1.0", + "EncryptionAlgorithm": "AES_CBC_256", + }, + "ContentEncryptionIV": self.encryption_metadata.iv, + "KeyWrappingMetadata": {"EncryptionLibrary": "Java 5.3.0"}, + } + ), + GCS_METADATA_MATDESC_KEY: self.encryption_metadata.matdesc, + } + ) + + def generate_url_and_rest_args() -> ( + tuple[str, dict[str, dict[str | Any, str | None] | bytes]] + ): + if not self.presigned_url: + upload_url = self.generate_file_url( + self.stage_info["location"], meta.dst_file_name.lstrip("/") + ) + access_token = self.security_token + else: + upload_url = self.presigned_url + access_token: str | None = None + if access_token: + gcs_headers.update({"Authorization": f"Bearer {access_token}"}) + rest_args = {"headers": gcs_headers, "data": chunk} + return upload_url, rest_args + + response = await self._send_request_with_retry( + "PUT", generate_url_and_rest_args, chunk_id + ) + response.raise_for_status() + meta.gcs_file_header_digest = gcs_headers[GCS_METADATA_SFC_DIGEST] + meta.gcs_file_header_content_length = meta.upload_size + meta.gcs_file_header_encryption_metadata = json.loads( + gcs_headers.get(GCS_METADATA_ENCRYPTIONDATAPROP, "null") + ) + + async def download_chunk(self, chunk_id: int) -> None: + meta = self.meta + + def generate_url_and_rest_args() -> ( + tuple[str, dict[str, dict[str, str] | bool]] + ): + gcs_headers = {} + if not self.presigned_url: + download_url = self.generate_file_url( + self.stage_info["location"], meta.src_file_name.lstrip("/") + ) + access_token = self.security_token + gcs_headers["Authorization"] = f"Bearer {access_token}" + else: + download_url = self.presigned_url + rest_args = {"headers": gcs_headers} + return download_url, rest_args + + response = await self._send_request_with_retry( + "GET", generate_url_and_rest_args, chunk_id + ) + response.raise_for_status() + + self.write_downloaded_chunk(chunk_id, await response.read()) + + encryption_metadata = None + + if response.headers.get(GCS_METADATA_ENCRYPTIONDATAPROP, None): + encryptiondata = json.loads( + response.headers[GCS_METADATA_ENCRYPTIONDATAPROP] + ) + + if encryptiondata: + encryption_metadata = EncryptionMetadata( + key=encryptiondata["WrappedContentKey"]["EncryptedKey"], + iv=encryptiondata["ContentEncryptionIV"], + matdesc=( + response.headers[GCS_METADATA_MATDESC_KEY] + if GCS_METADATA_MATDESC_KEY in response.headers + else None + ), + ) + + meta.gcs_file_header_digest = response.headers.get(GCS_METADATA_SFC_DIGEST) + meta.gcs_file_header_content_length = len(await response.read()) + meta.gcs_file_header_encryption_metadata = encryption_metadata + + async def finish_download(self) -> None: + await SnowflakeStorageClientAsync.finish_download(self) + # Sadly, we can only determine the src file size after we've + # downloaded it, unlike the other cloud providers where the + # metadata can be read beforehand. + self.meta.src_file_size = os.path.getsize(self.full_dst_file_name) + + async def _update_presigned_url(self) -> None: + """Updates the file metas with presigned urls if any. 
+ + Currently only the file metas generated for PUT/GET on a GCP account need the presigned urls. + """ + logger.debug("Updating presigned url") + + # Rewrite the command such that a new PUT call is made for each file + # represented by the regex (if present) separately. This is the only + # way to get the presigned url for that file. + file_path_to_be_replaced = self._get_local_file_path_from_put_command() + + if not file_path_to_be_replaced: + # This prevents GET statements to proceed + return + + # At this point the connector has already figured out and + # validated that the local file exists and has also decided + # upon the destination file name and the compression type. + # The only thing that's left to do is to get the presigned + # url for the destination file. If the command originally + # referred to a single file, then the presigned url got in + # that case is simply ignore, since the file name is not what + # we want. + + # GS only looks at the file name at the end of local file + # path to figure out the remote object name. Hence the prefix + # for local path is not necessary in the reconstructed command. + file_path_to_replace_with = self.meta.dst_file_name + command_with_single_file = self._command + command_with_single_file = command_with_single_file.replace( + file_path_to_be_replaced, file_path_to_replace_with + ) + + logger.debug("getting presigned url for %s", file_path_to_replace_with) + ret = await self._cursor._execute_helper(command_with_single_file) + + stage_info = ret.get("data", dict()).get("stageInfo", dict()) + self.meta.presigned_url = stage_info.get("presignedUrl") + self.presigned_url = stage_info.get("presignedUrl") + + async def get_file_header(self, filename: str) -> FileHeader | None: + """Gets the remote file's metadata. + + Args: + filename: Not applicable to GCS. + + Returns: + The file header, with expected properties populated or None, based on how the request goes with the + storage provider. + + Notes: + Sometimes this method is called to verify that the file has indeed been uploaded. In cases of presigned + url, we have no way of verifying that, except with the http status code of 200 which we have already + confirmed and set the meta.result_status = UPLOADED/DOWNLOADED. 
+ """ + meta = self.meta + if ( + meta.result_status == ResultStatus.UPLOADED + or meta.result_status == ResultStatus.DOWNLOADED + ): + return FileHeader( + digest=meta.gcs_file_header_digest, + content_length=meta.gcs_file_header_content_length, + encryption_metadata=meta.gcs_file_header_encryption_metadata, + ) + elif self.presigned_url: + meta.result_status = ResultStatus.NOT_FOUND_FILE + else: + + def generate_url_and_authenticated_headers(): + url = self.generate_file_url( + self.stage_info["location"], filename.lstrip("/") + ) + gcs_headers = {"Authorization": f"Bearer {self.security_token}"} + rest_args = {"headers": gcs_headers} + return url, rest_args + + retry_id = "HEAD" + self.retry_count[retry_id] = 0 + response = await self._send_request_with_retry( + "HEAD", generate_url_and_authenticated_headers, retry_id + ) + if response.status == 404: + meta.result_status = ResultStatus.NOT_FOUND_FILE + return None + elif response.status == 200: + digest = response.headers.get(GCS_METADATA_SFC_DIGEST, None) + content_length = int(response.headers.get("content-length", "0")) + + encryption_metadata = EncryptionMetadata("", "", "") + if response.headers.get(GCS_METADATA_ENCRYPTIONDATAPROP, None): + encryption_data = json.loads( + response.headers[GCS_METADATA_ENCRYPTIONDATAPROP] + ) + + if encryption_data: + encryption_metadata = EncryptionMetadata( + key=encryption_data["WrappedContentKey"]["EncryptedKey"], + iv=encryption_data["ContentEncryptionIV"], + matdesc=( + response.headers[GCS_METADATA_MATDESC_KEY] + if GCS_METADATA_MATDESC_KEY in response.headers + else None + ), + ) + meta.result_status = ResultStatus.UPLOADED + return FileHeader( + digest=digest, + content_length=content_length, + encryption_metadata=encryption_metadata, + ) + response.raise_for_status() + return None diff --git a/src/snowflake/connector/aio/_network.py b/src/snowflake/connector/aio/_network.py new file mode 100644 index 0000000000..d5a20be348 --- /dev/null +++ b/src/snowflake/connector/aio/_network.py @@ -0,0 +1,878 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import asyncio +import collections +import contextlib +import gzip +import itertools +import json +import logging +import re +import uuid +from typing import TYPE_CHECKING, Any + +import OpenSSL.SSL +from urllib3.util.url import parse_url + +from ..compat import FORBIDDEN, OK, UNAUTHORIZED, urlencode, urlparse +from ..constants import ( + _CONNECTIVITY_ERR_MSG, + HTTP_HEADER_ACCEPT, + HTTP_HEADER_CONTENT_TYPE, + HTTP_HEADER_SERVICE_NAME, + HTTP_HEADER_USER_AGENT, + OCSPMode, +) +from ..errorcode import ( + ER_CONNECTION_IS_CLOSED, + ER_CONNECTION_TIMEOUT, + ER_FAILED_TO_CONNECT_TO_DB, + ER_FAILED_TO_RENEW_SESSION, + ER_FAILED_TO_REQUEST, + ER_RETRYABLE_CODE, +) +from ..errors import ( + DatabaseError, + Error, + ForbiddenError, + InterfaceError, + OperationalError, + ProgrammingError, + RefreshTokenError, +) +from ..network import ( + ACCEPT_TYPE_APPLICATION_SNOWFLAKE, + BAD_REQUEST_GS_CODE, + CONTENT_TYPE_APPLICATION_JSON, + DEFAULT_SOCKET_CONNECT_TIMEOUT, + EXTERNAL_BROWSER_AUTHENTICATOR, + HEADER_AUTHORIZATION_KEY, + HEADER_SNOWFLAKE_TOKEN, + ID_TOKEN_EXPIRED_GS_CODE, + IMPLEMENTATION, + MASTER_TOKEN_EXPIRED_GS_CODE, + MASTER_TOKEN_INVALD_GS_CODE, + MASTER_TOKEN_NOTFOUND_GS_CODE, + NO_TOKEN, + PLATFORM, + PYTHON_VERSION, + QUERY_IN_PROGRESS_ASYNC_CODE, + QUERY_IN_PROGRESS_CODE, + REQUEST_ID, + REQUEST_TYPE_RENEW, + SESSION_EXPIRED_GS_CODE, + SNOWFLAKE_CONNECTOR_VERSION, + ReauthenticationRequest, + RetryRequest, +) +from ..network import SessionPool as SessionPoolSync +from ..network import SnowflakeRestful as SnowflakeRestfulSync +from ..network import get_http_retryable_error, is_login_request, is_retryable_http_code +from ..secret_detector import SecretDetector +from ..sqlstate import ( + SQLSTATE_CONNECTION_NOT_EXISTS, + SQLSTATE_CONNECTION_REJECTED, + SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, +) +from ..time_util import TimeoutBackoffCtx +from ._description import CLIENT_NAME +from ._ssl_connector import SnowflakeSSLConnector + +if TYPE_CHECKING: + from snowflake.connector.aio import SnowflakeConnection + +logger = logging.getLogger(__name__) + +PYTHON_CONNECTOR_USER_AGENT = f"{CLIENT_NAME}/{SNOWFLAKE_CONNECTOR_VERSION} ({PLATFORM}) {IMPLEMENTATION}/{PYTHON_VERSION}" + +try: + import aiohttp +except ImportError: + logger.warning("Please install aiohttp to use asyncio features.") + raise + + +def raise_okta_unauthorized_error( + connection: SnowflakeConnection | None, response: aiohttp.ClientResponse +) -> None: + Error.errorhandler_wrapper( + connection, + None, + DatabaseError, + { + "msg": f"Failed to get authentication by OKTA: {response.status}: {response.reason}", + "errno": ER_FAILED_TO_CONNECT_TO_DB, + "sqlstate": SQLSTATE_CONNECTION_REJECTED, + }, + ) + + +def raise_failed_request_error( + connection: SnowflakeConnection | None, + url: str, + method: str, + response: aiohttp.ClientResponse, +) -> None: + Error.errorhandler_wrapper( + connection, + None, + InterfaceError, + { + "msg": f"{response.status} {response.reason}: {method} {url}", + "errno": ER_FAILED_TO_REQUEST, + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + + +class SessionPool(SessionPoolSync): + def __init__(self, rest: SnowflakeRestful) -> None: + super().__init__(rest) + + async def close(self): + """Closes all active and idle sessions in this session pool.""" + if self._active_sessions: + logger.debug(f"Closing {len(self._active_sessions)} active sessions") + for s in itertools.chain(set(self._active_sessions), set(self._idle_sessions)): + try: + 
await s.close() + except Exception as e: + logger.info(f"Session cleanup failed: {e}") + self._active_sessions.clear() + self._idle_sessions.clear() + + +class SnowflakeRestful(SnowflakeRestfulSync): + def __init__( + self, + host: str = "127.0.0.1", + port: int = 8080, + protocol: str = "http", + inject_client_pause: int = 0, + connection: SnowflakeConnection | None = None, + ): + super().__init__(host, port, protocol, inject_client_pause, connection) + self._lock_token = asyncio.Lock() + self._sessions_map: dict[str | None, SessionPool] = collections.defaultdict( + lambda: SessionPool(self) + ) + self._ocsp_mode = ( + self._connection._ocsp_mode() if self._connection else OCSPMode.FAIL_OPEN + ) + if self._connection and self._connection.proxy_host: + self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname} + else: + self._get_proxy_headers = lambda _: None + + async def close(self) -> None: + if hasattr(self, "_token"): + del self._token + if hasattr(self, "_master_token"): + del self._master_token + if hasattr(self, "_id_token"): + del self._id_token + if hasattr(self, "_mfa_token"): + del self._mfa_token + + for session_pool in self._sessions_map.values(): + await session_pool.close() + + async def request( + self, + url, + body=None, + method: str = "post", + client: str = "sfsql", + timeout: int | None = None, + _no_results: bool = False, + _include_retry_params: bool = False, + _no_retry: bool = False, + ): + if body is None: + body = {} + if self.master_token is None and self.token is None: + Error.errorhandler_wrapper( + self._connection, + None, + DatabaseError, + { + "msg": "Connection is closed", + "errno": ER_CONNECTION_IS_CLOSED, + "sqlstate": SQLSTATE_CONNECTION_NOT_EXISTS, + }, + ) + + if client == "sfsql": + accept_type = ACCEPT_TYPE_APPLICATION_SNOWFLAKE + else: + accept_type = CONTENT_TYPE_APPLICATION_JSON + + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: accept_type, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + try: + from opentelemetry.propagate import inject + + inject(headers) + except ModuleNotFoundError as e: + logger.debug(f"Opentelemtry otel injection failed because of: {e}") + if self._connection.service_name: + headers[HTTP_HEADER_SERVICE_NAME] = self._connection.service_name + if method == "post": + return await self._post_request( + url, + headers, + json.dumps(body), + token=self.token, + _no_results=_no_results, + timeout=timeout, + _include_retry_params=_include_retry_params, + no_retry=_no_retry, + ) + else: + return await self._get_request( + url, + headers, + token=self.token, + timeout=timeout, + ) + + async def update_tokens( + self, + session_token, + master_token, + master_validity_in_seconds=None, + id_token=None, + mfa_token=None, + ) -> None: + """Updates session and master tokens and optionally temporary credential.""" + async with self._lock_token: + self._token = session_token + self._master_token = master_token + self._id_token = id_token + self._mfa_token = mfa_token + self._master_validity_in_seconds = master_validity_in_seconds + + async def _renew_session(self): + """Renew a session and master token.""" + return await self._token_request(REQUEST_TYPE_RENEW) + + async def _token_request(self, request_type): + logger.debug( + "updating session. 
master_token: {}".format( + "****" if self.master_token else None + ) + ) + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if self._connection.service_name: + headers[HTTP_HEADER_SERVICE_NAME] = self._connection.service_name + request_id = str(uuid.uuid4()) + logger.debug("request_id: %s", request_id) + url = "/session/token-request?" + urlencode({REQUEST_ID: request_id}) + + # NOTE: ensure an empty key if master token is not set. + # This avoids HTTP 400. + header_token = self.master_token or "" + body = { + "oldSessionToken": self.token, + "requestType": request_type, + } + ret = await self._post_request( + url, + headers, + json.dumps(body), + token=header_token, + ) + if ret.get("success") and ret.get("data", {}).get("sessionToken"): + logger.debug("success: %s", SecretDetector.mask_secrets(str(ret))) + await self.update_tokens( + ret["data"]["sessionToken"], + ret["data"].get("masterToken"), + master_validity_in_seconds=ret["data"].get("masterValidityInSeconds"), + ) + logger.debug("updating session completed") + return ret + else: + logger.debug("failed: %s", SecretDetector.mask_secrets(str(ret))) + err = ret.get("message") + if err is not None and ret.get("data"): + err += ret["data"].get("errorMessage", "") + errno = ret.get("code") or ER_FAILED_TO_RENEW_SESSION + if errno in ( + ID_TOKEN_EXPIRED_GS_CODE, + SESSION_EXPIRED_GS_CODE, + MASTER_TOKEN_NOTFOUND_GS_CODE, + MASTER_TOKEN_EXPIRED_GS_CODE, + MASTER_TOKEN_INVALD_GS_CODE, + BAD_REQUEST_GS_CODE, + ): + raise ReauthenticationRequest( + ProgrammingError( + msg=err, + errno=int(errno), + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + ) + Error.errorhandler_wrapper( + self._connection, + None, + ProgrammingError, + { + "msg": err, + "errno": int(errno), + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + + async def _heartbeat(self) -> Any | dict[Any, Any] | None: + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if self._connection.service_name: + headers[HTTP_HEADER_SERVICE_NAME] = self._connection.service_name + request_id = str(uuid.uuid4()) + logger.debug("request_id: %s", request_id) + url = "/session/heartbeat?" + urlencode({REQUEST_ID: request_id}) + ret = await self._post_request( + url, + headers, + None, + token=self.token, + ) + if not ret.get("success"): + logger.error("Failed to heartbeat. code: %s, url: %s", ret.get("code"), url) + return ret + + async def delete_session(self, retry: bool = False) -> None: + """Deletes the session.""" + if self.master_token is None: + Error.errorhandler_wrapper( + self._connection, + None, + DatabaseError, + { + "msg": "Connection is closed", + "errno": ER_CONNECTION_IS_CLOSED, + "sqlstate": SQLSTATE_CONNECTION_NOT_EXISTS, + }, + ) + + url = "/session?" 
+ urlencode({"delete": "true"}) + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if self._connection.service_name: + headers[HTTP_HEADER_SERVICE_NAME] = self._connection.service_name + + body = {} + retry_limit = 3 if retry else 1 + num_retries = 0 + should_retry = True + while should_retry and (num_retries < retry_limit): + try: + should_retry = False + ret = await self._post_request( + url, + headers, + json.dumps(body), + token=self.token, + timeout=5, + no_retry=True, + ) + if not ret: + if retry: + should_retry = True + else: + return + elif ret.get("success"): + return + err = ret.get("message") + if err is not None and ret.get("data"): + err += ret["data"].get("errorMessage", "") + # no exception is raised + logger.debug("error in deleting session. ignoring...: %s", err) + except Exception as e: + logger.debug("error in deleting session. ignoring...: %s", e) + finally: + num_retries += 1 + + async def _get_request( + self, + url: str, + headers: dict[str, str], + token: str = None, + timeout: int | None = None, + is_fetch_query_status: bool = False, + ) -> dict[str, Any]: + if "Content-Encoding" in headers: + del headers["Content-Encoding"] + if "Content-Length" in headers: + del headers["Content-Length"] + + full_url = f"{self.server_url}{url}" + ret = await self.fetch( + "get", + full_url, + headers, + timeout=timeout, + token=token, + is_fetch_query_status=is_fetch_query_status, + ) + if ret.get("code") == SESSION_EXPIRED_GS_CODE: + try: + ret = await self._renew_session() + except ReauthenticationRequest as ex: + if self._connection._authenticator != EXTERNAL_BROWSER_AUTHENTICATOR: + raise ex.cause + ret = await self._connection._reauthenticate() + logger.debug( + "ret[code] = {code} after renew_session".format( + code=(ret.get("code", "N/A")) + ) + ) + if ret.get("success"): + return await self._get_request( + url, + headers, + token=self.token, + is_fetch_query_status=is_fetch_query_status, + ) + + return ret + + async def _post_request( + self, + url, + headers, + body, + token=None, + timeout: int | None = None, + socket_timeout: int | None = None, + _no_results: bool = False, + no_retry: bool = False, + _include_retry_params: bool = False, + ) -> dict[str, Any]: + full_url = f"{self.server_url}{url}" + if self._connection._probe_connection: + # TODO: SNOW-1572318 for probe connection + raise NotImplementedError("probe_connection is not supported in asyncio") + + ret = await self.fetch( + "post", + full_url, + headers, + data=body, + timeout=timeout, + token=token, + no_retry=no_retry, + _include_retry_params=_include_retry_params, + socket_timeout=socket_timeout, + ) + logger.debug( + "ret[code] = {code}, after post request".format( + code=(ret.get("code", "N/A")) + ) + ) + + if ret.get("code") == MASTER_TOKEN_EXPIRED_GS_CODE: + self._connection.expired = True + elif ret.get("code") == SESSION_EXPIRED_GS_CODE: + try: + ret = await self._renew_session() + except ReauthenticationRequest as ex: + if self._connection._authenticator != EXTERNAL_BROWSER_AUTHENTICATOR: + raise ex.cause + ret = await self._connection._reauthenticate() + logger.debug( + "ret[code] = {code} after renew_session".format( + code=(ret.get("code", "N/A")) + ) + ) + if ret.get("success"): + return await self._post_request( + url, headers, body, token=self.token, timeout=timeout + ) + + if isinstance(ret.get("data"), dict) and ret["data"].get("queryId"): + 
logger.debug("Query id: {}".format(ret["data"]["queryId"])) + + if ret.get("code") == QUERY_IN_PROGRESS_ASYNC_CODE and _no_results: + return ret + + while ret.get("code") in (QUERY_IN_PROGRESS_CODE, QUERY_IN_PROGRESS_ASYNC_CODE): + if self._inject_client_pause > 0: + logger.debug("waiting for %s...", self._inject_client_pause) + await asyncio.sleep(self._inject_client_pause) + # ping pong + result_url = ret["data"]["getResultUrl"] + logger.debug("ping pong starting...") + ret = await self._get_request( + result_url, + headers, + token=self.token, + timeout=timeout, + is_fetch_query_status=bool( + re.match(r"^/queries/.+/result$", result_url) + ), + ) + logger.debug("ret[code] = %s", ret.get("code", "N/A")) + logger.debug("ping pong done") + + return ret + + async def fetch( + self, + method: str, + full_url: str, + headers: dict[str, Any], + data: dict[str, Any] | None = None, + timeout: int | None = None, + **kwargs, + ) -> dict[Any, Any]: + """Carry out API request with session management.""" + + class RetryCtx(TimeoutBackoffCtx): + def __init__( + self, + _include_retry_params: bool = False, + _include_retry_reason: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.retry_reason = 0 + self._include_retry_params = _include_retry_params + self._include_retry_reason = _include_retry_reason + + def add_retry_params(self, full_url: str) -> str: + if self._include_retry_params and self.current_retry_count > 0: + retry_params = { + "clientStartTime": self._start_time_millis, + "retryCount": self.current_retry_count, + } + if self._include_retry_reason: + retry_params.update({"retryReason": self.retry_reason}) + suffix = urlencode(retry_params) + sep = "&" if urlparse(full_url).query else "?" + return full_url + sep + suffix + else: + return full_url + + include_retry_reason = self._connection._enable_retry_reason_in_query_response + include_retry_params = kwargs.pop("_include_retry_params", False) + + async with self._use_requests_session(full_url) as session: + retry_ctx = RetryCtx( + _include_retry_params=include_retry_params, + _include_retry_reason=include_retry_reason, + timeout=( + timeout if timeout is not None else self._connection.network_timeout + ), + backoff_generator=self._connection._backoff_generator, + ) + + retry_ctx.set_start_time() + while True: + ret = await self._request_exec_wrapper( + session, method, full_url, headers, data, retry_ctx, **kwargs + ) + if ret is not None: + return ret + + async def _request_exec_wrapper( + self, + session, + method, + full_url, + headers, + data, + retry_ctx, + no_retry: bool = False, + token=NO_TOKEN, + **kwargs, + ): + conn = self._connection + logger.debug( + "remaining request timeout: %s ms, retry cnt: %s", + retry_ctx.remaining_time_millis if retry_ctx.timeout is not None else "N/A", + retry_ctx.current_retry_count + 1, + ) + + full_url = retry_ctx.add_retry_params(full_url) + full_url = SnowflakeRestful.add_request_guid(full_url) + is_fetch_query_status = kwargs.pop("is_fetch_query_status", False) + try: + return_object = await self._request_exec( + session=session, + method=method, + full_url=full_url, + headers=headers, + data=data, + token=token, + **kwargs, + ) + if return_object is not None: + return return_object + if is_fetch_query_status: + err_msg = ( + "fetch query status failed and http request returned None, this" + " is usually caused by transient network failures, retrying..." 
+ ) + logger.info(err_msg) + raise RetryRequest(err_msg) + self._handle_unknown_error(method, full_url, headers, data, conn) + return {} + except RetryRequest as e: + cause = e.args[0] + if no_retry: + self.log_and_handle_http_error_with_cause( + e, + full_url, + method, + retry_ctx.timeout, + retry_ctx.current_retry_count, + conn, + timed_out=False, + ) + return {} # required for tests + if not retry_ctx.should_retry: + self.log_and_handle_http_error_with_cause( + e, + full_url, + method, + retry_ctx.timeout, + retry_ctx.current_retry_count, + conn, + ) + return {} # required for tests + + logger.debug( + "retrying: errorclass=%s, " + "error=%s, " + "counter=%s, " + "sleeping=%s(s)", + type(cause), + cause, + retry_ctx.current_retry_count + 1, + retry_ctx.current_sleep_time, + ) + await asyncio.sleep(float(retry_ctx.current_sleep_time)) + retry_ctx.increment() + + reason = getattr(cause, "errno", 0) + retry_ctx.retry_reason = reason + # notes: in sync implementation we check ECONNRESET in error message and close low level urllib session + # we do not have the logic here because aiohttp handles low level connection close-reopen for us + return None # retry + except Exception as e: + if not no_retry: + raise e + logger.debug("Ignored error", exc_info=True) + return {} + + async def _request_exec( + self, + session: aiohttp.ClientSession, + method, + full_url, + headers, + data, + token, + catch_okta_unauthorized_error: bool = False, + is_raw_text: bool = False, + is_raw_binary: bool = False, + binary_data_handler=None, + socket_timeout: int | None = None, + is_okta_authentication: bool = False, + ): + if socket_timeout is None: + if self._connection.socket_timeout is not None: + logger.debug("socket_timeout specified in connection") + socket_timeout = self._connection.socket_timeout + else: + socket_timeout = DEFAULT_SOCKET_CONNECT_TIMEOUT + logger.debug("socket timeout: %s", socket_timeout) + + try: + if not catch_okta_unauthorized_error and data and len(data) > 0: + headers["Content-Encoding"] = "gzip" + input_data = gzip.compress(data.encode("utf-8")) + else: + input_data = data + + if HEADER_AUTHORIZATION_KEY in headers: + del headers[HEADER_AUTHORIZATION_KEY] + if token != NO_TOKEN: + headers[HEADER_AUTHORIZATION_KEY] = HEADER_SNOWFLAKE_TOKEN.format( + token=token + ) + + # socket timeout is constant. You should be able to receive + # the response within the time. If not, asyncio.TimeoutError is raised. 
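+            # note: aiohttp.ClientTimeout(socket_timeout) below sets ClientTimeout.total
+            # (its first positional argument), so the value caps the whole request
+            # (connection setup plus sending and reading the response) rather than a
+            # single socket read as the socket timeout does in the sync connector.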
+ + # delta compared to sync: + # - in sync, we specify "verify" to True; in aiohttp, + # the counter parameter is "ssl" and it already defaults to True + raw_ret = await session.request( + method=method, + url=full_url, + headers=headers, + data=input_data, + timeout=aiohttp.ClientTimeout(socket_timeout), + proxy_headers=self._get_proxy_headers(full_url), + ) + try: + if raw_ret.status == OK: + logger.debug("SUCCESS") + if is_raw_text: + ret = await raw_ret.text() + elif is_raw_binary: + # TODO: SNOW-1738595 for is_raw_binary support + raise NotImplementedError( + "reading raw binary data is not supported in asyncio connector," + " please open a feature request issue in" + " github: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose" + ) + else: + ret = await raw_ret.json() + return ret + + if is_login_request(full_url) and raw_ret.status == FORBIDDEN: + raise ForbiddenError + + elif is_retryable_http_code(raw_ret.status): + err = get_http_retryable_error(raw_ret.status) + # retryable server exceptions + if is_okta_authentication: + raise RefreshTokenError( + msg="OKTA authentication requires token refresh." + ) + if is_login_request(full_url): + logger.debug( + "Received retryable response code while logging in. Will be handled by " + f"authenticator. Ignore the following. Error stack: {err}", + exc_info=True, + ) + raise OperationalError( + msg="Login request is retryable. Will be handled by authenticator", + errno=ER_RETRYABLE_CODE, + ) + else: + logger.debug(f"{err}. Retrying...") + raise RetryRequest(err) + + elif raw_ret.status == UNAUTHORIZED and catch_okta_unauthorized_error: + # OKTA Unauthorized errors + raise_okta_unauthorized_error(self._connection, raw_ret) + return None # required for tests + else: + raise_failed_request_error( + self._connection, full_url, method, raw_ret + ) + return None # required for tests + finally: + raw_ret.close() # ensure response is closed + except (aiohttp.ClientSSLError, aiohttp.ClientConnectorSSLError) as se: + msg = f"Hit non-retryable SSL error, {str(se)}.\n{_CONNECTIVITY_ERR_MSG}" + logger.debug(msg) + # the following code is for backward compatibility with old versions of python connector which calls + # self._handle_unknown_error to process SSLError + Error.errorhandler_wrapper( + self._connection, + None, + OperationalError, + { + "msg": msg, + "errno": ER_FAILED_TO_REQUEST, + }, + ) + except ( + aiohttp.ClientConnectionError, + aiohttp.ClientConnectorError, + aiohttp.ConnectionTimeoutError, + asyncio.TimeoutError, + OpenSSL.SSL.SysCallError, + KeyError, # SNOW-39175: asn1crypto.keys.PublicKeyInfo + ValueError, + RuntimeError, + AttributeError, # json decoding error + ) as err: + if isinstance(err, RuntimeError) and "Event loop is closed" in str(err): + logger.info( + "If you see the logging error message 'RuntimeError: Event loop is closed' during program exit, it probably indicates that the connection was not closed properly before the event loop was shut down. Please use SnowflakeConnection.close() to close connection." + ) + raise err + if is_login_request(full_url): + logger.debug( + "Hit a timeout error while logging in. Will be handled by " + f"authenticator. Ignore the following. Error stack: {err}", + exc_info=True, + ) + raise OperationalError( + msg="ConnectionTimeout occurred during login. Will be handled by authenticator", + errno=ER_CONNECTION_TIMEOUT, + ) + else: + logger.debug( + "Hit retryable client error. Retrying... 
Ignore the following " + f"error stack: {err}", + exc_info=True, + ) + raise RetryRequest(err) + except Exception as err: + if isinstance(err, (Error, RetryRequest, ReauthenticationRequest)): + raise err + raise OperationalError( + msg=f"Unexpected error occurred during request execution: {err}" + "Please check the stack trace for more information and retry the operation." + "If you think this is a bug, please collect the error information and open a bug report in github: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose.", + errno=ER_FAILED_TO_REQUEST, + ) from err + + def make_requests_session(self) -> aiohttp.ClientSession: + s = aiohttp.ClientSession( + connector=SnowflakeSSLConnector(snowflake_ocsp_mode=self._ocsp_mode), + trust_env=True, # this is for proxy support, proxy.set_proxy will set envs and trust_env allows reading env + ) + return s + + @contextlib.asynccontextmanager + async def _use_requests_session( + self, url: str | None = None + ) -> aiohttp.ClientSession: + if self._connection.disable_request_pooling: + session = self.make_requests_session() + try: + yield session + finally: + await session.close() + else: + try: + hostname = urlparse(url).hostname + except Exception: + hostname = None + + session_pool: SessionPool = self._sessions_map[hostname] + session = session_pool.get_session() + logger.debug(f"Session status for SessionPool '{hostname}', {session_pool}") + try: + yield session + finally: + session_pool.return_session(session) + logger.debug( + f"Session status for SessionPool '{hostname}', {session_pool}" + ) diff --git a/src/snowflake/connector/aio/_ocsp_asn1crypto.py b/src/snowflake/connector/aio/_ocsp_asn1crypto.py new file mode 100644 index 0000000000..963d954a4f --- /dev/null +++ b/src/snowflake/connector/aio/_ocsp_asn1crypto.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import ssl +from collections import OrderedDict +from logging import getLogger + +from aiohttp.client_proto import ResponseHandler +from asn1crypto.x509 import Certificate + +from ..ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto as SnowflakeOCSPAsn1CryptoSync +from ._ocsp_snowflake import SnowflakeOCSP + +logger = getLogger(__name__) + + +class SnowflakeOCSPAsn1Crypto(SnowflakeOCSP, SnowflakeOCSPAsn1CryptoSync): + + def extract_certificate_chain(self, connection: ResponseHandler): + ssl_object = connection.transport.get_extra_info("ssl_object") + if not ssl_object: + raise RuntimeError( + "Unable to get the SSL object from the asyncio transport to perform OCSP validation." + "Please open an issue on the Snowflake Python Connector GitHub repository " + "and provide your execution environment" + " details: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + "As a workaround, you can create the connection with `insecure_mode=True` to skip OCSP Validation." + ) + + cert_map = OrderedDict() + # in Python 3.10, get_unverified_chain was introduced as a + # private method: https://github.com/python/cpython/pull/25467 + # which returns all the peer certs in the chain. 
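+        # Because the accessor lives on the private ssl_object._sslobj attribute, this
+        # chain extraction requires Python 3.10+ and may need updating if CPython
+        # reshapes that private interface before the public API is available.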
+ # Python 3.13 will have the method get_unverified_chain publicly available on ssl.SSLSocket class + # https://docs.python.org/pl/3.13/library/ssl.html#ssl.SSLSocket.get_unverified_chain + unverified_chain = ssl_object._sslobj.get_unverified_chain() + logger.debug("# of certificates: %s", len(unverified_chain)) + + for cert in unverified_chain: + cert = Certificate.load(ssl.PEM_cert_to_DER_cert(cert.public_bytes())) + logger.debug( + "subject: %s, issuer: %s", cert.subject.native, cert.issuer.native + ) + cert_map[cert.subject.sha256] = cert + + return self.create_pair_issuer_subject(cert_map) diff --git a/src/snowflake/connector/aio/_ocsp_snowflake.py b/src/snowflake/connector/aio/_ocsp_snowflake.py new file mode 100644 index 0000000000..b7e042cea5 --- /dev/null +++ b/src/snowflake/connector/aio/_ocsp_snowflake.py @@ -0,0 +1,565 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import json +import os +import time +from logging import getLogger +from typing import Any + +import aiohttp +from aiohttp.client_proto import ResponseHandler +from asn1crypto.ocsp import CertId +from asn1crypto.x509 import Certificate + +import snowflake.connector.ocsp_snowflake +from snowflake.connector.backoff_policies import exponential_backoff +from snowflake.connector.compat import OK +from snowflake.connector.constants import HTTP_HEADER_USER_AGENT +from snowflake.connector.errorcode import ( + ER_OCSP_FAILED_TO_CONNECT_CACHE_SERVER, + ER_OCSP_RESPONSE_CACHE_DOWNLOAD_FAILED, + ER_OCSP_RESPONSE_FETCH_EXCEPTION, + ER_OCSP_RESPONSE_FETCH_FAILURE, + ER_OCSP_RESPONSE_UNAVAILABLE, + ER_OCSP_URL_INFO_MISSING, +) +from snowflake.connector.errors import RevocationCheckError +from snowflake.connector.network import PYTHON_CONNECTOR_USER_AGENT +from snowflake.connector.ocsp_snowflake import OCSPCache, OCSPResponseValidationResult +from snowflake.connector.ocsp_snowflake import OCSPServer as OCSPServerSync +from snowflake.connector.ocsp_snowflake import OCSPTelemetryData +from snowflake.connector.ocsp_snowflake import SnowflakeOCSP as SnowflakeOCSPSync +from snowflake.connector.url_util import extract_top_level_domain_from_hostname + +logger = getLogger(__name__) + + +class OCSPServer(OCSPServerSync): + async def download_cache_from_server(self, ocsp): + if self.CACHE_SERVER_ENABLED: + # if any of them is not cache, download the cache file from + # OCSP response cache server. + try: + retval = await OCSPServer._download_ocsp_response_cache( + ocsp, self.CACHE_SERVER_URL + ) + if not retval: + raise RevocationCheckError( + msg="OCSP Cache Server Unavailable.", + errno=ER_OCSP_RESPONSE_CACHE_DOWNLOAD_FAILED, + ) + logger.debug( + "downloaded OCSP response cache file from %s", self.CACHE_SERVER_URL + ) + # len(OCSP_RESPONSE_VALIDATION_CACHE) is thread-safe, however, we do not want to + # block for logging purpose, thus using len(OCSP_RESPONSE_VALIDATION_CACHE._cache) here. + logger.debug( + "# of certificates: %u", + len( + snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE._cache + ), + ) + except RevocationCheckError as rce: + logger.debug( + "OCSP Response cache download failed. 
The client" + "will reach out to the OCSP Responder directly for" + "any missing OCSP responses %s\n" % rce.msg + ) + raise + + @staticmethod + async def _download_ocsp_response_cache(ocsp, url, do_retry: bool = True) -> bool: + """Downloads OCSP response cache from the cache server.""" + headers = {HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT} + sf_timeout = SnowflakeOCSP.OCSP_CACHE_SERVER_CONNECTION_TIMEOUT + + try: + start_time = time.time() + logger.debug("started downloading OCSP response cache file: %s", url) + + if ocsp.test_mode is not None: + test_timeout = os.getenv( + "SF_TEST_OCSP_CACHE_SERVER_CONNECTION_TIMEOUT", None + ) + sf_cache_server_url = os.getenv("SF_TEST_OCSP_CACHE_SERVER_URL", None) + if test_timeout is not None: + sf_timeout = int(test_timeout) + if sf_cache_server_url is not None: + url = sf_cache_server_url + + async with aiohttp.ClientSession() as session: + max_retry = SnowflakeOCSP.OCSP_CACHE_SERVER_MAX_RETRY if do_retry else 1 + sleep_time = 1 + backoff = exponential_backoff()() + for _ in range(max_retry): + response = await session.get( + url, + timeout=sf_timeout, # socket timeout + headers=headers, + ) + if response.status == OK: + ocsp.decode_ocsp_response_cache(await response.json()) + elapsed_time = time.time() - start_time + logger.debug( + "ended downloading OCSP response cache file. " + "elapsed time: %ss", + elapsed_time, + ) + break + elif max_retry > 1: + sleep_time = next(backoff) + logger.debug( + "OCSP server returned %s. Retrying in %s(s)", + response.status, + sleep_time, + ) + await asyncio.sleep(sleep_time) + else: + logger.error( + "Failed to get OCSP response after %s attempt.", max_retry + ) + return False + return True + except Exception as e: + logger.debug("Failed to get OCSP response cache from %s: %s", url, e) + raise RevocationCheckError( + msg=f"Failed to get OCSP Response Cache from {url}: {e}", + errno=ER_OCSP_FAILED_TO_CONNECT_CACHE_SERVER, + ) + + +class SnowflakeOCSP(SnowflakeOCSPSync): + + def __init__( + self, + ocsp_response_cache_uri=None, + use_ocsp_cache_server=None, + use_post_method: bool = True, + use_fail_open: bool = True, + **kwargs, + ) -> None: + self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None) + + if self.test_mode == "true": + logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE") + + self._use_post_method = use_post_method + self.OCSP_CACHE_SERVER = OCSPServer( + top_level_domain=extract_top_level_domain_from_hostname( + kwargs.pop("hostname", None) + ) + ) + + self.debug_ocsp_failure_url = None + + if os.getenv("SF_OCSP_FAIL_OPEN") is not None: + # failOpen Env Variable is for internal usage/ testing only. + # Using it in production is not advised and not supported. 
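+            # FAIL_OPEN (fail-open) means OCSP check failures are logged and tolerated
+            # so the connection can proceed; fail-closed raises the error instead.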
+ self.FAIL_OPEN = os.getenv("SF_OCSP_FAIL_OPEN").lower() == "true" + else: + self.FAIL_OPEN = use_fail_open + + SnowflakeOCSP.OCSP_CACHE.reset_ocsp_response_cache_uri(ocsp_response_cache_uri) + + if not OCSPServer.is_enabled_new_ocsp_endpoint(): + self.OCSP_CACHE_SERVER.reset_ocsp_dynamic_cache_server_url( + use_ocsp_cache_server + ) + + if not snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE: + SnowflakeOCSP.OCSP_CACHE.read_file(self) + + async def validate( + self, + hostname: str | None, + connection: ResponseHandler, + no_exception: bool = False, + ) -> ( + list[ + tuple[ + Exception | None, + Certificate, + Certificate, + CertId, + str | bytes, + ] + ] + | None + ): + """Validates the certificate is not revoked using OCSP.""" + logger.debug("validating certificate: %s", hostname) + + do_retry = SnowflakeOCSP.get_ocsp_retry_choice() + + m = not SnowflakeOCSP.OCSP_WHITELIST.match(hostname) + if m or hostname.startswith("ocspssd"): + logger.debug("skipping OCSP check: %s", hostname) + return [None, None, None, None, None] + + if OCSPServer.is_enabled_new_ocsp_endpoint(): + self.OCSP_CACHE_SERVER.reset_ocsp_endpoint(hostname) + + telemetry_data = OCSPTelemetryData() + telemetry_data.set_cache_enabled(self.OCSP_CACHE_SERVER.CACHE_SERVER_ENABLED) + telemetry_data.set_insecure_mode(False) + telemetry_data.set_sfc_peer_host(hostname) + telemetry_data.set_fail_open(self.is_enabled_fail_open()) + + try: + cert_data = self.extract_certificate_chain(connection) + except RevocationCheckError: + telemetry_data.set_event_sub_type( + OCSPTelemetryData.CERTIFICATE_EXTRACTION_FAILED + ) + logger.debug( + telemetry_data.generate_telemetry_data("RevocationCheckFailure") + ) + return None + + return await self._validate( + hostname, cert_data, telemetry_data, do_retry, no_exception + ) + + async def _validate( + self, + hostname: str | None, + cert_data: list[tuple[Certificate, Certificate]], + telemetry_data: OCSPTelemetryData, + do_retry: bool = True, + no_exception: bool = False, + ) -> list[tuple[Exception | None, Certificate, Certificate, CertId, bytes]]: + """Validate certs sequentially if OCSP response cache server is used.""" + results = await self._validate_certificates_sequential( + cert_data, telemetry_data, hostname, do_retry=do_retry + ) + + SnowflakeOCSP.OCSP_CACHE.update_file(self) + + any_err = False + for err, _, _, _, _ in results: + if isinstance(err, RevocationCheckError): + err.msg += f" for {hostname}" + if not no_exception and err is not None: + raise err + elif err is not None: + any_err = True + + logger.debug("ok" if not any_err else "failed") + return results + + async def _validate_issue_subject( + self, + issuer: Certificate, + subject: Certificate, + telemetry_data: OCSPTelemetryData, + hostname: str | None = None, + do_retry: bool = True, + ) -> tuple[ + tuple[bytes, bytes, bytes], + [Exception | None, Certificate, Certificate, CertId, bytes], + ]: + cert_id, req = self.create_ocsp_request(issuer, subject) + cache_key = self.decode_cert_id_key(cert_id) + ocsp_response_validation_result = ( + snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE.get( + cache_key + ) + ) + + if ( + ocsp_response_validation_result is None + or not ocsp_response_validation_result.validated + ): + r = await self.validate_by_direct_connection( + issuer, + subject, + telemetry_data, + hostname, + do_retry=do_retry, + cache_key=cache_key, + ) + return cache_key, r + else: + return cache_key, ( + ocsp_response_validation_result.exception, + 
ocsp_response_validation_result.issuer, + ocsp_response_validation_result.subject, + ocsp_response_validation_result.cert_id, + ocsp_response_validation_result.ocsp_response, + ) + + async def _check_ocsp_response_cache_server( + self, + cert_data: list[tuple[Certificate, Certificate]], + ) -> None: + """Checks if OCSP response is in cache, and if not it downloads the OCSP response cache from the server. + + Args: + cert_data: Tuple of issuer and subject certificates. + """ + in_cache = False + for issuer, subject in cert_data: + # check if any OCSP response is NOT in cache + cert_id, _ = self.create_ocsp_request(issuer, subject) + in_cache, _ = SnowflakeOCSP.OCSP_CACHE.find_cache(self, cert_id, subject) + if not in_cache: + # not found any + break + + if not in_cache: + await self.OCSP_CACHE_SERVER.download_cache_from_server(self) + + async def _validate_certificates_sequential( + self, + cert_data: list[tuple[Certificate, Certificate]], + telemetry_data: OCSPTelemetryData, + hostname: str | None = None, + do_retry: bool = True, + ) -> list[tuple[Exception | None, Certificate, Certificate, CertId, bytes]]: + try: + await self._check_ocsp_response_cache_server(cert_data) + except RevocationCheckError as rce: + telemetry_data.set_event_sub_type( + OCSPTelemetryData.ERROR_CODE_MAP[rce.errno] + ) + except Exception as ex: + logger.debug( + "Caught unknown exception - %s. Continue to validate by direct connection", + str(ex), + ) + + to_update_cache_dict = {} + + task_results = await asyncio.gather( + *[ + self._validate_issue_subject( + issuer, + subject, + hostname=hostname, + telemetry_data=telemetry_data, + do_retry=do_retry, + ) + for issuer, subject in cert_data + ] + ) + results = [validate_result for _, validate_result in task_results] + for cache_key, validate_result in task_results: + if validate_result[0] is not None or validate_result[4] is not None: + to_update_cache_dict[cache_key] = OCSPResponseValidationResult( + *validate_result, + ts=int(time.time()), + validated=True, + ) + OCSPCache.CACHE_UPDATED = True + + snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE.update( + to_update_cache_dict + ) + return results + + async def validate_by_direct_connection( + self, + issuer: Certificate, + subject: Certificate, + telemetry_data: OCSPTelemetryData, + hostname: str = None, + do_retry: bool = True, + **kwargs: Any, + ) -> tuple[Exception | None, Certificate, Certificate, CertId, bytes]: + cert_id, req = self.create_ocsp_request(issuer, subject) + cache_status, ocsp_response = self.is_cert_id_in_cache( + cert_id, subject, **kwargs + ) + + try: + if not cache_status: + telemetry_data.set_cache_hit(False) + logger.debug("getting OCSP response from CA's OCSP server") + ocsp_response = await self._fetch_ocsp_response( + req, subject, cert_id, telemetry_data, hostname, do_retry + ) + else: + ocsp_url = self.extract_ocsp_url(subject) + cert_id_enc = self.encode_cert_id_base64( + self.decode_cert_id_key(cert_id) + ) + telemetry_data.set_cache_hit(True) + self.debug_ocsp_failure_url = SnowflakeOCSP.create_ocsp_debug_info( + self, req, ocsp_url + ) + telemetry_data.set_ocsp_url(ocsp_url) + telemetry_data.set_ocsp_req(req) + telemetry_data.set_cert_id(cert_id_enc) + logger.debug("using OCSP response cache") + + if not ocsp_response: + telemetry_data.set_event_sub_type( + OCSPTelemetryData.OCSP_RESPONSE_UNAVAILABLE + ) + raise RevocationCheckError( + msg="Could not retrieve OCSP Response. 
Cannot perform Revocation Check", + errno=ER_OCSP_RESPONSE_UNAVAILABLE, + ) + try: + self.process_ocsp_response(issuer, cert_id, ocsp_response) + err = None + except RevocationCheckError as op_er: + telemetry_data.set_event_sub_type( + OCSPTelemetryData.ERROR_CODE_MAP[op_er.errno] + ) + raise op_er + + except RevocationCheckError as rce: + telemetry_data.set_error_msg(rce.msg) + err = self.verify_fail_open(rce, telemetry_data) + + except Exception as ex: + logger.debug("OCSP Validation failed %s", str(ex)) + telemetry_data.set_error_msg(str(ex)) + err = self.verify_fail_open(ex, telemetry_data) + SnowflakeOCSP.OCSP_CACHE.delete_cache(self, cert_id) + + return err, issuer, subject, cert_id, ocsp_response + + async def _fetch_ocsp_response( + self, + ocsp_request, + subject, + cert_id, + telemetry_data, + hostname=None, + do_retry: bool = True, + ): + """Fetches OCSP response using OCSPRequest.""" + sf_timeout = SnowflakeOCSP.CA_OCSP_RESPONDER_CONNECTION_TIMEOUT + ocsp_url = self.extract_ocsp_url(subject) + cert_id_enc = self.encode_cert_id_base64(self.decode_cert_id_key(cert_id)) + if not ocsp_url: + telemetry_data.set_event_sub_type(OCSPTelemetryData.OCSP_URL_MISSING) + raise RevocationCheckError( + msg="No OCSP URL found in cert. Cannot perform Certificate Revocation check", + errno=ER_OCSP_URL_INFO_MISSING, + ) + headers = {HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT} + + if not OCSPServer.is_enabled_new_ocsp_endpoint(): + actual_method = "post" if self._use_post_method else "get" + if self.OCSP_CACHE_SERVER.OCSP_RETRY_URL: + # no POST is supported for Retry URL at the moment. + actual_method = "get" + + if actual_method == "get": + b64data = self.decode_ocsp_request_b64(ocsp_request) + target_url = self.OCSP_CACHE_SERVER.generate_get_url(ocsp_url, b64data) + payload = None + else: + target_url = ocsp_url + payload = self.decode_ocsp_request(ocsp_request) + headers["Content-Type"] = "application/ocsp-request" + else: + actual_method = "post" + target_url = self.OCSP_CACHE_SERVER.OCSP_RETRY_URL + ocsp_req_enc = self.decode_ocsp_request_b64(ocsp_request) + + payload = json.dumps( + { + "hostname": hostname, + "ocsp_request": ocsp_req_enc, + "cert_id": cert_id_enc, + "ocsp_responder_url": ocsp_url, + } + ) + headers["Content-Type"] = "application/json" + + telemetry_data.set_ocsp_connection_method(actual_method) + if self.test_mode is not None: + logger.debug("WARNING - DRIVER IS CONFIGURED IN TESTMODE.") + test_ocsp_url = os.getenv("SF_TEST_OCSP_URL", None) + test_timeout = os.getenv( + "SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT", None + ) + if test_timeout is not None: + sf_timeout = int(test_timeout) + if test_ocsp_url is not None: + target_url = test_ocsp_url + + self.debug_ocsp_failure_url = SnowflakeOCSP.create_ocsp_debug_info( + self, ocsp_request, ocsp_url + ) + telemetry_data.set_ocsp_req(self.decode_ocsp_request_b64(ocsp_request)) + telemetry_data.set_ocsp_url(ocsp_url) + telemetry_data.set_cert_id(cert_id_enc) + + ret = None + logger.debug("url: %s", target_url) + sf_max_retry = SnowflakeOCSP.CA_OCSP_RESPONDER_MAX_RETRY_FO + if not self.is_enabled_fail_open(): + sf_max_retry = SnowflakeOCSP.CA_OCSP_RESPONDER_MAX_RETRY_FC + + async with aiohttp.ClientSession() as session: + max_retry = sf_max_retry if do_retry else 1 + sleep_time = 1 + backoff = exponential_backoff()() + for _ in range(max_retry): + try: + response = await session.request( + headers=headers, + method=actual_method, + url=target_url, + timeout=sf_timeout, + data=payload, + ) + if response.status == OK: + 
logger.debug( + "OCSP response was successfully returned from OCSP " + "server." + ) + ret = await response.content.read() + break + elif max_retry > 1: + sleep_time = next(backoff) + logger.debug( + "OCSP server returned %s. Retrying in %s(s)", + response.status, + sleep_time, + ) + await asyncio.sleep(sleep_time) + except Exception as ex: + if max_retry > 1: + sleep_time = next(backoff) + logger.debug( + "Could not fetch OCSP Response from server" + "Retrying in %s(s)", + sleep_time, + ) + await asyncio.sleep(sleep_time) + else: + telemetry_data.set_event_sub_type( + OCSPTelemetryData.OCSP_RESPONSE_FETCH_EXCEPTION + ) + raise RevocationCheckError( + msg="Could not fetch OCSP Response from server. Consider" + "checking your whitelists : Exception - {}".format(str(ex)), + errno=ER_OCSP_RESPONSE_FETCH_EXCEPTION, + ) + else: + logger.error( + "Failed to get OCSP response after {} attempt. Consider checking " + "for OCSP URLs being blocked".format(max_retry) + ) + telemetry_data.set_event_sub_type( + OCSPTelemetryData.OCSP_RESPONSE_FETCH_FAILURE + ) + raise RevocationCheckError( + msg="Failed to get OCSP response after {} attempt.".format( + max_retry + ), + errno=ER_OCSP_RESPONSE_FETCH_FAILURE, + ) + + return ret diff --git a/src/snowflake/connector/aio/_result_batch.py b/src/snowflake/connector/aio/_result_batch.py new file mode 100644 index 0000000000..3bf9565ee7 --- /dev/null +++ b/src/snowflake/connector/aio/_result_batch.py @@ -0,0 +1,422 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import abc +import asyncio +import json +from logging import getLogger +from typing import TYPE_CHECKING, Any, Iterator, Sequence + +import aiohttp + +from snowflake.connector import Error +from snowflake.connector.aio._network import ( + raise_failed_request_error, + raise_okta_unauthorized_error, +) +from snowflake.connector.aio._time_util import TimerContextManager +from snowflake.connector.arrow_context import ArrowConverterContext +from snowflake.connector.backoff_policies import exponential_backoff +from snowflake.connector.compat import OK, UNAUTHORIZED +from snowflake.connector.constants import IterUnit +from snowflake.connector.converter import SnowflakeConverterType +from snowflake.connector.cursor import ResultMetadataV2 +from snowflake.connector.network import ( + RetryRequest, + get_http_retryable_error, + is_retryable_http_code, +) +from snowflake.connector.result_batch import SSE_C_AES, SSE_C_ALGORITHM, SSE_C_KEY +from snowflake.connector.result_batch import ArrowResultBatch as ArrowResultBatchSync +from snowflake.connector.result_batch import DownloadMetrics +from snowflake.connector.result_batch import JSONResultBatch as JSONResultBatchSync +from snowflake.connector.result_batch import RemoteChunkInfo +from snowflake.connector.result_batch import ResultBatch as ResultBatchSync +from snowflake.connector.result_batch import _create_nanoarrow_iterator +from snowflake.connector.secret_detector import SecretDetector + +if TYPE_CHECKING: + from pandas import DataFrame + from pyarrow import Table + + from snowflake.connector.aio._connection import SnowflakeConnection + from snowflake.connector.aio._cursor import SnowflakeCursor + +logger = getLogger(__name__) + +# we redefine the DOWNLOAD_TIMEOUT and MAX_DOWNLOAD_RETRY for async version on purpose +# because download in sync and async are different in nature and may require separate tuning +# also be aware that currently _result_batch is a private module so these values 
are not exposed to users directly +DOWNLOAD_TIMEOUT = None +MAX_DOWNLOAD_RETRY = 10 + + +def create_batches_from_response( + cursor: SnowflakeCursor, + _format: str, + data: dict[str, Any], + schema: Sequence[ResultMetadataV2], +) -> list[ResultBatch]: + column_converters: list[tuple[str, SnowflakeConverterType]] = [] + arrow_context: ArrowConverterContext | None = None + rowtypes = data["rowtype"] + total_len: int = data.get("total", 0) + first_chunk_len = total_len + rest_of_chunks: list[ResultBatch] = [] + if _format == "json": + + def col_to_converter(col: dict[str, Any]) -> tuple[str, SnowflakeConverterType]: + type_name = col["type"].upper() + python_method = cursor._connection.converter.to_python_method( + type_name, col + ) + return type_name, python_method + + column_converters = [col_to_converter(c) for c in rowtypes] + else: + rowset_b64 = data.get("rowsetBase64") + arrow_context = ArrowConverterContext(cursor._connection._session_parameters) + if "chunks" in data: + chunks = data["chunks"] + logger.debug(f"chunk size={len(chunks)}") + # prepare the downloader for further fetch + qrmk = data.get("qrmk") + chunk_headers: dict[str, Any] = {} + if "chunkHeaders" in data: + chunk_headers = {} + for header_key, header_value in data["chunkHeaders"].items(): + chunk_headers[header_key] = header_value + if "encryption" not in header_key: + logger.debug( + f"added chunk header: key={header_key}, value={header_value}" + ) + elif qrmk is not None: + logger.debug(f"qrmk={SecretDetector.mask_secrets(qrmk)}") + chunk_headers[SSE_C_ALGORITHM] = SSE_C_AES + chunk_headers[SSE_C_KEY] = qrmk + + def remote_chunk_info(c: dict[str, Any]) -> RemoteChunkInfo: + return RemoteChunkInfo( + url=c["url"], + uncompressedSize=c["uncompressedSize"], + compressedSize=c["compressedSize"], + ) + + if _format == "json": + rest_of_chunks = [ + JSONResultBatch( + c["rowCount"], + chunk_headers, + remote_chunk_info(c), + schema, + column_converters, + cursor._use_dict_result, + json_result_force_utf8_decoding=cursor._connection._json_result_force_utf8_decoding, + ) + for c in chunks + ] + else: + rest_of_chunks = [ + ArrowResultBatch( + c["rowCount"], + chunk_headers, + remote_chunk_info(c), + arrow_context, + cursor._use_dict_result, + cursor._connection._numpy, + schema, + cursor._connection._arrow_number_to_decimal, + ) + for c in chunks + ] + for c in rest_of_chunks: + first_chunk_len -= c.rowcount + if _format == "json": + first_chunk = JSONResultBatch.from_data( + data.get("rowset"), + first_chunk_len, + schema, + column_converters, + cursor._use_dict_result, + ) + elif rowset_b64 is not None: + first_chunk = ArrowResultBatch.from_data( + rowset_b64, + first_chunk_len, + arrow_context, + cursor._use_dict_result, + cursor._connection._numpy, + schema, + cursor._connection._arrow_number_to_decimal, + ) + else: + logger.error(f"Don't know how to construct ResultBatches from response: {data}") + first_chunk = ArrowResultBatch.from_data( + "", + 0, + arrow_context, + cursor._use_dict_result, + cursor._connection._numpy, + schema, + cursor._connection._arrow_number_to_decimal, + ) + + return [first_chunk] + rest_of_chunks + + +class ResultBatch(ResultBatchSync): + def __iter__(self): + raise TypeError( + f"Async '{type(self).__name__}' does not support '__iter__', " + f"please call the `create_iter` coroutine method on the '{type(self).__name__}' object" + " to explicitly create an iterator." 
+ ) + + @abc.abstractmethod + async def create_iter( + self, **kwargs + ) -> ( + Iterator[dict | Exception] + | Iterator[tuple | Exception] + | Iterator[Table] + | Iterator[DataFrame] + ): + """Downloads the data from blob storage that this ResultChunk points at. + + This function is the one that does the actual work for ``self.__iter__``. + + It is necessary because a ``ResultBatch`` can return multiple types of + iterators. A good example of this is simply iterating through + ``SnowflakeCursor`` and calling ``fetch_pandas_batches`` on it. + """ + raise NotImplementedError() + + async def _download( + self, connection: SnowflakeConnection | None = None, **kwargs + ) -> tuple[bytes, str]: + """Downloads the data that the ``ResultBatch`` is pointing at.""" + sleep_timer = 1 + backoff = ( + connection._backoff_generator + if connection is not None + else exponential_backoff()() + ) + + async def download_chunk(http_session): + response, content, encoding = None, None, None + logger.debug( + f"downloading result batch id: {self.id} with existing session {http_session}" + ) + response = await http_session.get(**request_data) + if response.status == OK: + logger.debug(f"successfully downloaded result batch id: {self.id}") + content, encoding = await response.read(), response.get_encoding() + return response, content, encoding + + content, encoding = None, None + for retry in range(max(MAX_DOWNLOAD_RETRY, 1)): + try: + + async with TimerContextManager() as download_metric: + logger.debug(f"started downloading result batch id: {self.id}") + chunk_url = self._remote_chunk_info.url + request_data = { + "url": chunk_url, + "headers": self._chunk_headers, + } + # timeout setting for download is different from the sync version which has an + # empirical value 7 seconds. It is difficult to measure this empirical value in async + # as we maximize the network throughput by downloading multiple chunks at the same time compared + # to the sync version that the overall throughput is constrained by the number of + # prefetch threads -- in asyncio we see great download performance improvement. + # if DOWNLOAD_TIMEOUT is not set, by default the aiohttp session timeout comes into effect + # which originates from the connection config. 
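+                    # DOWNLOAD_TIMEOUT is the module-level constant defined above (None by
+                    # default), so this branch only takes effect when it is overridden for tuning.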
+ if DOWNLOAD_TIMEOUT: + request_data["timeout"] = aiohttp.ClientTimeout( + total=DOWNLOAD_TIMEOUT + ) + # Try to reuse a connection if possible + if connection and connection._rest is not None: + async with connection._rest._use_requests_session() as session: + logger.debug( + f"downloading result batch id: {self.id} with existing session {session}" + ) + response, content, encoding = await download_chunk(session) + else: + async with aiohttp.ClientSession() as session: + logger.debug( + f"downloading result batch id: {self.id} with new session" + ) + response, content, encoding = await download_chunk(session) + + if response.status == OK: + break + # Raise error here to correctly go in to exception clause + if is_retryable_http_code(response.status): + # retryable server exceptions + error: Error = get_http_retryable_error(response.status) + raise RetryRequest(error) + elif response.status == UNAUTHORIZED: + # make a unauthorized error + raise_okta_unauthorized_error(None, response) + else: + raise_failed_request_error(None, chunk_url, "get", response) + + except (RetryRequest, Exception) as e: + if retry == MAX_DOWNLOAD_RETRY - 1: + # Re-throw if we failed on the last retry + e = e.args[0] if isinstance(e, RetryRequest) else e + raise e + sleep_timer = next(backoff) + logger.exception( + f"Failed to fetch the large result set batch " + f"{self.id} for the {retry + 1} th time, " + f"backing off for {sleep_timer}s for the reason: '{e}'" + ) + await asyncio.sleep(sleep_timer) + + self._metrics[DownloadMetrics.download.value] = ( + download_metric.get_timing_millis() + ) + return content, encoding + + +class JSONResultBatch(ResultBatch, JSONResultBatchSync): + async def create_iter( + self, connection: SnowflakeConnection | None = None, **kwargs + ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]: + if self._local: + return iter(self._data) + content, encoding = await self._download(connection=connection) + # Load data to a intermediate form + logger.debug(f"started loading result batch id: {self.id}") + async with TimerContextManager() as load_metric: + downloaded_data = await self._load(content, encoding) + logger.debug(f"finished loading result batch id: {self.id}") + self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis() + # Process downloaded data + async with TimerContextManager() as parse_metric: + parsed_data = self._parse(downloaded_data) + self._metrics[DownloadMetrics.parse.value] = parse_metric.get_timing_millis() + return iter(parsed_data) + + async def _load(self, content: bytes, encoding: str) -> list: + """This function loads a compressed JSON file into memory. + + Returns: + Whatever ``json.loads`` return, but in a list. + Unfortunately there's no type hint for this. 
+ For context: https://github.com/python/typing/issues/182 + """ + # if users specify how to decode the data, we decode the bytes using the specified encoding + if self._json_result_force_utf8_decoding: + try: + read_data = str(content, "utf-8", errors="strict") + except Exception as exc: + err_msg = f"failed to decode json result content due to error {exc!r}" + logger.error(err_msg) + raise Error(msg=err_msg) + else: + # note: SNOW-787480 response.apparent_encoding is unreliable, chardet.detect can be wrong which is used by + # response.text to decode content, check issue: https://github.com/chardet/chardet/issues/148 + read_data = content.decode(encoding, "strict") + return json.loads("".join(["[", read_data, "]"])) + + +class ArrowResultBatch(ResultBatch, ArrowResultBatchSync): + async def _load( + self, content, row_unit: IterUnit + ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]: + """Creates a ``PyArrowIterator`` from a response. + + This is used to iterate through results in different ways depending on which + mode that ``PyArrowIterator`` is in. + """ + return _create_nanoarrow_iterator( + content, + self._context, + self._use_dict_result, + self._numpy, + self._number_to_decimal, + row_unit, + ) + + async def _create_iter( + self, iter_unit: IterUnit, connection: SnowflakeConnection | None = None + ) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[Table]: + """Create an iterator for the ResultBatch. Used by get_arrow_iter.""" + """Create an iterator for the ResultBatch. Used by get_arrow_iter.""" + if self._local: + try: + return self._from_data(self._data, iter_unit) + except Exception: + if connection and getattr(connection, "_debug_arrow_chunk", False): + logger.debug(f"arrow data can not be parsed: {self._data}") + raise + content, _ = await self._download(connection=connection) + logger.debug(f"started loading result batch id: {self.id}") + async with TimerContextManager() as load_metric: + try: + loaded_data = await self._load(content, iter_unit) + except Exception: + if connection and getattr(connection, "_debug_arrow_chunk", False): + logger.debug(f"arrow data can not be parsed: {content}") + raise + logger.debug(f"finished loading result batch id: {self.id}") + self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis() + return loaded_data + + async def _get_pandas_iter( + self, connection: SnowflakeConnection | None = None, **kwargs + ) -> Iterator[DataFrame]: + """An iterator for this batch which yields a pandas DataFrame""" + iterator_data = [] + dataframe = await self.to_pandas(connection=connection, **kwargs) + if not dataframe.empty: + iterator_data.append(dataframe) + return iter(iterator_data) + + async def _get_arrow_iter( + self, connection: SnowflakeConnection | None = None + ) -> Iterator[Table]: + """Returns an iterator for this batch which yields a pyarrow Table""" + return await self._create_iter( + iter_unit=IterUnit.TABLE_UNIT, connection=connection + ) + + async def to_arrow(self, connection: SnowflakeConnection | None = None) -> Table: + """Returns this batch as a pyarrow Table""" + val = next(await self._get_arrow_iter(connection=connection), None) + if val is not None: + return val + return self._create_empty_table() + + async def to_pandas( + self, connection: SnowflakeConnection | None = None, **kwargs + ) -> DataFrame: + """Returns this batch as a pandas DataFrame""" + self._check_can_use_pandas() + table = await self.to_arrow(connection=connection) + return table.to_pandas(**kwargs) + + async 
def create_iter( + self, connection: SnowflakeConnection | None = None, **kwargs + ) -> ( + Iterator[dict | Exception] + | Iterator[tuple | Exception] + | Iterator[Table] + | Iterator[DataFrame] + ): + """The interface used by ResultSet to create an iterator for this ResultBatch.""" + iter_unit: IterUnit = kwargs.pop("iter_unit", IterUnit.ROW_UNIT) + if iter_unit == IterUnit.TABLE_UNIT: + structure = kwargs.pop("structure", "pandas") + if structure == "pandas": + return await self._get_pandas_iter(connection=connection, **kwargs) + else: + return await self._get_arrow_iter(connection=connection) + else: + return await self._create_iter(iter_unit=iter_unit, connection=connection) diff --git a/src/snowflake/connector/aio/_result_set.py b/src/snowflake/connector/aio/_result_set.py new file mode 100644 index 0000000000..2ac9639947 --- /dev/null +++ b/src/snowflake/connector/aio/_result_set.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import inspect +from collections import deque +from logging import getLogger +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Deque, + Iterator, + Literal, + Union, + cast, + overload, +) + +from snowflake.connector.aio._result_batch import ( + ArrowResultBatch, + JSONResultBatch, + ResultBatch, +) +from snowflake.connector.constants import IterUnit +from snowflake.connector.options import pandas +from snowflake.connector.result_set import ResultSet as ResultSetSync + +from .. import NotSupportedError +from ..options import pyarrow as pa +from ..result_batch import DownloadMetrics +from ..telemetry import TelemetryField +from ..time_util import get_time_millis + +if TYPE_CHECKING: + from pandas import DataFrame + from pyarrow import Table + + from snowflake.connector.aio._cursor import SnowflakeCursor + +logger = getLogger(__name__) + + +class ResultSetIterator: + def __init__( + self, + first_batch_iter: Iterator[tuple], + unfetched_batches: Deque[ResultBatch], + final: Callable[[], Awaitable[None]], + prefetch_thread_num: int, + **kw: Any, + ) -> None: + self._is_fetch_all = kw.pop("is_fetch_all", False) + self._first_batch_iter = first_batch_iter + self._unfetched_batches = unfetched_batches + self._final = final + self._prefetch_thread_num = prefetch_thread_num + self._kw = kw + self._generator = self.generator() + + async def _download_all_batches(self): + # try to download all the batches at one time, won't return until all the batches are downloaded + tasks = [] + for result_batch in self._unfetched_batches: + tasks.append(result_batch.create_iter(**self._kw)) + await asyncio.sleep(0) + return tasks + + async def _download_batch_and_convert_to_list(self, result_batch): + return list(await result_batch.create_iter(**self._kw)) + + async def fetch_all_data(self): + rets = list(self._first_batch_iter) + tasks = [ + self._download_batch_and_convert_to_list(result_batch) + for result_batch in self._unfetched_batches + ] + batches = await asyncio.gather(*tasks) + for batch in batches: + rets.extend(batch) + # yield to avoid blocking the event loop for too long when processing large result sets + # await asyncio.sleep(0) + return rets + + async def generator(self): + if self._is_fetch_all: + + tasks = await self._download_all_batches() + for value in self._first_batch_iter: + yield value + + new_batches = await asyncio.gather(*tasks) + for batch in new_batches: + for value in batch: + yield 
value + + await self._final() + else: + download_tasks = deque() + for _ in range( + min(self._prefetch_thread_num, len(self._unfetched_batches)) + ): + logger.debug( + f"queuing download of result batch id: {self._unfetched_batches[0].id}" + ) + download_tasks.append( + asyncio.create_task( + self._unfetched_batches.popleft().create_iter(**self._kw) + ) + ) + + for value in self._first_batch_iter: + yield value + + i = 1 + while download_tasks: + logger.debug(f"user requesting to consume result batch {i}") + + # Submit the next un-fetched batch to the pool + if self._unfetched_batches: + logger.debug( + f"queuing download of result batch id: {self._unfetched_batches[0].id}" + ) + download_tasks.append( + asyncio.create_task( + self._unfetched_batches.popleft().create_iter(**self._kw) + ) + ) + + task = download_tasks.popleft() + # this will raise an exception if one has occurred + batch_iterator = await task + + logger.debug(f"user began consuming result batch {i}") + for value in batch_iterator: + yield value + logger.debug(f"user finished consuming result batch {i}") + i += 1 + await self._final() + + async def get_next(self): + return await anext(self._generator, None) + + +class ResultSet(ResultSetSync): + def __init__( + self, + cursor: SnowflakeCursor, + result_chunks: list[JSONResultBatch] | list[ArrowResultBatch], + prefetch_thread_num: int, + ) -> None: + super().__init__(cursor, result_chunks, prefetch_thread_num) + self.batches = cast( + Union[list[JSONResultBatch], list[ArrowResultBatch]], self.batches + ) + + def _can_create_arrow_iter(self) -> None: + # For now we don't support mixed ResultSets, so assume first partition's type + # represents them all + head_type = type(self.batches[0]) + if head_type != ArrowResultBatch: + raise NotSupportedError( + f"Trying to use arrow fetching on {head_type} which " + f"is not ArrowResultChunk" + ) + + async def _create_iter( + self, + **kwargs, + ) -> ResultSetIterator: + """Set up a new iterator through all batches with first 5 chunks downloaded. + + This function is a helper function to ``__iter__`` and it was introduced for the + cases where we need to propagate some values to later ``_download`` calls. + """ + # pop is_fetch_all and pass it to result_set_iterator + is_fetch_all = kwargs.pop("is_fetch_all", False) + + # add connection so that result batches can use sessions + kwargs["connection"] = self._cursor.connection + + first_batch_iter = await self.batches[0].create_iter(**kwargs) + + # batches that have not been fetched + unfetched_batches = deque(self.batches[1:]) + for num, batch in enumerate(unfetched_batches): + logger.debug(f"result batch {num + 1} has id: {batch.id}") + + return ResultSetIterator( + first_batch_iter, + unfetched_batches, + self._finish_iterating, + self.prefetch_thread_num, + is_fetch_all=is_fetch_all, + **kwargs, + ) + + async def _fetch_arrow_batches( + self, + ) -> AsyncIterator[Table]: + """Fetches all the results as Arrow Tables, chunked by Snowflake back-end.""" + self._can_create_arrow_iter() + result_set_iterator = await self._create_iter( + iter_unit=IterUnit.TABLE_UNIT, structure="arrow" + ) + return result_set_iterator.generator() + + @overload + async def _fetch_arrow_all( + self, force_return_table: Literal[False] + ) -> Table | None: ... + + @overload + async def _fetch_arrow_all(self, force_return_table: Literal[True]) -> Table: ... 
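+    # The overloads above only narrow the return type for type checkers:
+    # with force_return_table=True a Table is always returned, otherwise the
+    # result may be None when there is nothing to fetch.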
+ + async def _fetch_arrow_all(self, force_return_table: bool = False) -> Table | None: + """Fetches a single Arrow Table from all of the ``ResultBatch``.""" + self._can_create_arrow_iter() + result_set_iterator = await self._create_iter( + iter_unit=IterUnit.TABLE_UNIT, structure="arrow" + ) + tables = list(await result_set_iterator.fetch_all_data()) + if tables: + return pa.concat_tables(tables) + else: + return await self.batches[0].to_arrow() if force_return_table else None + + async def _fetch_pandas_batches(self, **kwargs) -> AsyncIterator[DataFrame]: + self._can_create_arrow_iter() + result_set_iterator = await self._create_iter( + iter_unit=IterUnit.TABLE_UNIT, structure="pandas", **kwargs + ) + return result_set_iterator.generator() + + async def _fetch_pandas_all(self, **kwargs) -> DataFrame: + """Fetches a single Pandas dataframe.""" + result_set_iterator = await self._create_iter( + iter_unit=IterUnit.TABLE_UNIT, structure="pandas", **kwargs + ) + concat_args = list(inspect.signature(pandas.concat).parameters) + concat_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in concat_args} + dataframes = await result_set_iterator.fetch_all_data() + if dataframes: + return pandas.concat( + dataframes, + ignore_index=True, # Don't keep in result batch indexes + **concat_kwargs, + ) + # Empty dataframe + return await self.batches[0].to_pandas(**kwargs) + + async def _finish_iterating(self) -> None: + await self._report_metrics() + + async def _report_metrics(self) -> None: + """Report metrics for the result set.""" + # TODO: SNOW-1572217 async telemetry + """Report all metrics totalled up. + + This includes TIME_CONSUME_LAST_RESULT, TIME_DOWNLOADING_CHUNKS and + TIME_PARSING_CHUNKS in that order. + """ + if self._cursor._first_chunk_time is not None: + time_consume_last_result = ( + get_time_millis() - self._cursor._first_chunk_time + ) + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_CONSUME_LAST_RESULT, time_consume_last_result + ) + metrics = self._get_metrics() + if DownloadMetrics.download.value in metrics: + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_DOWNLOADING_CHUNKS, + metrics.get(DownloadMetrics.download.value), + ) + if DownloadMetrics.parse.value in metrics: + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_PARSING_CHUNKS, + metrics.get(DownloadMetrics.parse.value), + ) diff --git a/src/snowflake/connector/aio/_s3_storage_client.py b/src/snowflake/connector/aio/_s3_storage_client.py new file mode 100644 index 0000000000..9be04fe215 --- /dev/null +++ b/src/snowflake/connector/aio/_s3_storage_client.py @@ -0,0 +1,425 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
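+# Asyncio S3 storage client: mirrors the sync SnowflakeS3RestClient, performs its
+# HTTP calls through aiohttp, and signs requests with AWS Signature Version 4
+# using virtual-hosted-style addressing.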
+# + +from __future__ import annotations + +import xml.etree.ElementTree as ET +from datetime import datetime, timezone +from io import IOBase +from logging import getLogger +from typing import TYPE_CHECKING, Any + +import aiohttp + +from ..compat import quote, urlparse +from ..constants import ( + HTTP_HEADER_CONTENT_TYPE, + HTTP_HEADER_VALUE_OCTET_STREAM, + FileHeader, + ResultStatus, +) +from ..encryption_util import EncryptionMetadata +from ..s3_storage_client import ( + AMZ_IV, + AMZ_KEY, + AMZ_MATDESC, + EXPIRED_TOKEN, + META_PREFIX, + SFC_DIGEST, + UNSIGNED_PAYLOAD, + S3Location, +) +from ..s3_storage_client import SnowflakeS3RestClient as SnowflakeS3RestClientSync +from ._storage_client import SnowflakeStorageClient as SnowflakeStorageClientAsync + +if TYPE_CHECKING: # pragma: no cover + from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential + +logger = getLogger(__name__) + + +class SnowflakeS3RestClient(SnowflakeStorageClientAsync, SnowflakeS3RestClientSync): + def __init__( + self, + meta: SnowflakeFileMeta, + credentials: StorageCredential, + stage_info: dict[str, Any], + chunk_size: int, + use_accelerate_endpoint: bool | None = None, + use_s3_regional_url: bool = False, + ) -> None: + """Rest client for S3 storage. + + Args: + stage_info: + """ + SnowflakeStorageClientAsync.__init__( + self, + meta=meta, + stage_info=stage_info, + chunk_size=chunk_size, + credentials=credentials, + ) + # Signature version V4 + # Addressing style Virtual Host + self.region_name: str = stage_info["region"] + # Multipart upload only + self.upload_id: str | None = None + self.etags: list[str] | None = None + self.s3location: S3Location = ( + SnowflakeS3RestClient._extract_bucket_name_and_path( + self.stage_info["location"] + ) + ) + self.use_s3_regional_url = use_s3_regional_url + self.location_type = stage_info.get("locationType") + + # if GS sends us an endpoint, it's likely for FIPS. Use it. + self.endpoint: str | None = None + if stage_info["endPoint"]: + self.endpoint = ( + f"https://{self.s3location.bucket_name}." 
+ stage_info["endPoint"] + ) + + async def _send_request_with_authentication_and_retry( + self, + url: str, + verb: str, + retry_id: int | str, + query_parts: dict[str, str] | None = None, + x_amz_headers: dict[str, str] | None = None, + headers: dict[str, str] | None = None, + payload: bytes | bytearray | IOBase | None = None, + unsigned_payload: bool = False, + ignore_content_encoding: bool = False, + ) -> aiohttp.ClientResponse: + if x_amz_headers is None: + x_amz_headers = {} + if headers is None: + headers = {} + if payload is None: + payload = b"" + if query_parts is None: + query_parts = {} + parsed_url = urlparse(url) + x_amz_headers["x-amz-security-token"] = self.credentials.creds.get( + "AWS_TOKEN", "" + ) + x_amz_headers["host"] = parsed_url.hostname + if unsigned_payload: + x_amz_headers["x-amz-content-sha256"] = UNSIGNED_PAYLOAD + else: + x_amz_headers["x-amz-content-sha256"] = ( + SnowflakeS3RestClient._hash_bytes_hex(payload).lower().decode() + ) + + def generate_authenticated_url_and_args_v4() -> tuple[str, dict[str, bytes]]: + t = datetime.now(timezone.utc).replace(tzinfo=None) + amzdate = t.strftime("%Y%m%dT%H%M%SZ") + short_amzdate = amzdate[:8] + x_amz_headers["x-amz-date"] = amzdate + + ( + canonical_request, + signed_headers, + ) = self._construct_canonical_request_and_signed_headers( + verb=verb, + canonical_uri_parameter=parsed_url.path + + (f";{parsed_url.params}" if parsed_url.params else ""), + query_parts=query_parts, + canonical_headers=x_amz_headers, + payload_hash=x_amz_headers["x-amz-content-sha256"], + ) + string_to_sign, scope = self._construct_string_to_sign( + self.region_name, + "s3", + amzdate, + short_amzdate, + self._hash_bytes_hex(canonical_request.encode("utf-8")).lower(), + ) + kDate = self._sign_bytes( + ("AWS4" + self.credentials.creds["AWS_SECRET_KEY"]).encode("utf-8"), + short_amzdate, + ) + kRegion = self._sign_bytes(kDate, self.region_name) + kService = self._sign_bytes(kRegion, "s3") + signing_key = self._sign_bytes(kService, "aws4_request") + + signature = self._sign_bytes_hex(signing_key, string_to_sign).lower() + authorization_header = ( + "AWS4-HMAC-SHA256 " + + f"Credential={self.credentials.creds['AWS_KEY_ID']}/{scope}, " + + f"SignedHeaders={signed_headers}, " + + f"Signature={signature.decode('utf-8')}" + ) + headers.update(x_amz_headers) + headers["Authorization"] = authorization_header + rest_args = {"headers": headers} + + if payload: + rest_args["data"] = payload + + if ignore_content_encoding: + rest_args["auto_decompress"] = False + + return url, rest_args + + return await self._send_request_with_retry( + verb, generate_authenticated_url_and_args_v4, retry_id + ) + + async def get_file_header(self, filename: str) -> FileHeader | None: + """Gets the metadata of file in specified location. + + Args: + filename: Name of remote file. 
+ + Returns: + None if HEAD returns 404, otherwise a FileHeader instance populated + with metadata + """ + path = quote(self.s3location.path + filename.lstrip("/")) + url = self.endpoint + f"/{path}" + + retry_id = "HEAD" + self.retry_count[retry_id] = 0 + response = await self._send_request_with_authentication_and_retry( + url=url, verb="HEAD", retry_id=retry_id + ) + if response.status == 200: + self.meta.result_status = ResultStatus.UPLOADED + metadata = response.headers + encryption_metadata = ( + EncryptionMetadata( + key=metadata.get(META_PREFIX + AMZ_KEY), + iv=metadata.get(META_PREFIX + AMZ_IV), + matdesc=metadata.get(META_PREFIX + AMZ_MATDESC), + ) + if metadata.get(META_PREFIX + AMZ_KEY) + else None + ) + return FileHeader( + digest=metadata.get(META_PREFIX + SFC_DIGEST), + content_length=int(metadata.get("Content-Length")), + encryption_metadata=encryption_metadata, + ) + elif response.status == 404: + logger.debug( + f"not found. bucket: {self.s3location.bucket_name}, path: {path}" + ) + self.meta.result_status = ResultStatus.NOT_FOUND_FILE + return None + else: + response.raise_for_status() + + # for multi-chunk file transfer + async def _initiate_multipart_upload(self) -> None: + query_parts = (("uploads", ""),) + path = quote(self.s3location.path + self.meta.dst_file_name.lstrip("/")) + query_string = self._construct_query_string(query_parts) + url = self.endpoint + f"/{path}?{query_string}" + s3_metadata = self._prepare_file_metadata() + # initiate multipart upload + retry_id = "Initiate" + self.retry_count[retry_id] = 0 + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="POST", + retry_id=retry_id, + x_amz_headers=s3_metadata, + headers={HTTP_HEADER_CONTENT_TYPE: HTTP_HEADER_VALUE_OCTET_STREAM}, + query_parts=dict(query_parts), + ) + if response.status == 200: + self.upload_id = ET.fromstring(await response.read())[2].text + self.etags = [None] * self.num_of_chunks + else: + response.raise_for_status() + + async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: + path = quote(self.s3location.path + self.meta.dst_file_name.lstrip("/")) + url = self.endpoint + f"/{path}" + + if self.num_of_chunks == 1: # single request + s3_metadata = self._prepare_file_metadata() + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="PUT", + retry_id=chunk_id, + payload=chunk, + x_amz_headers=s3_metadata, + headers={HTTP_HEADER_CONTENT_TYPE: HTTP_HEADER_VALUE_OCTET_STREAM}, + unsigned_payload=True, + ) + response.raise_for_status() + else: + # multipart PUT + query_parts = ( + ("partNumber", str(chunk_id + 1)), + ("uploadId", self.upload_id), + ) + query_string = self._construct_query_string(query_parts) + chunk_url = f"{url}?{query_string}" + response = await self._send_request_with_authentication_and_retry( + url=chunk_url, + verb="PUT", + retry_id=chunk_id, + payload=chunk, + unsigned_payload=True, + query_parts=dict(query_parts), + ) + if response.status == 200: + self.etags[chunk_id] = response.headers["ETag"] + response.raise_for_status() + + async def _complete_multipart_upload(self) -> None: + query_parts = (("uploadId", self.upload_id),) + path = quote(self.s3location.path + self.meta.dst_file_name.lstrip("/")) + query_string = self._construct_query_string(query_parts) + url = self.endpoint + f"/{path}?{query_string}" + logger.debug("Initiating multipart upload complete") + # Complete multipart upload + root = ET.Element("CompleteMultipartUpload") + for idx, etag_str in enumerate(self.etags): + part = 
ET.Element("Part") + etag = ET.Element("ETag") + etag.text = etag_str + part.append(etag) + part_number = ET.Element("PartNumber") + part_number.text = str(idx + 1) + part.append(part_number) + root.append(part) + retry_id = "Complete" + self.retry_count[retry_id] = 0 + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="POST", + retry_id=retry_id, + payload=ET.tostring(root), + query_parts=dict(query_parts), + ) + response.raise_for_status() + + async def _abort_multipart_upload(self) -> None: + if self.upload_id is None: + return + query_parts = (("uploadId", self.upload_id),) + path = quote(self.s3location.path + self.meta.dst_file_name.lstrip("/")) + query_string = self._construct_query_string(query_parts) + url = self.endpoint + f"/{path}?{query_string}" + + retry_id = "Abort" + self.retry_count[retry_id] = 0 + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="DELETE", + retry_id=retry_id, + query_parts=dict(query_parts), + ) + response.raise_for_status() + + async def download_chunk(self, chunk_id: int) -> None: + logger.debug(f"Downloading chunk {chunk_id}") + path = quote(self.s3location.path + self.meta.src_file_name.lstrip("/")) + url = self.endpoint + f"/{path}" + if self.num_of_chunks == 1: + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="GET", + retry_id=chunk_id, + ignore_content_encoding=True, + ) + if response.status == 200: + self.write_downloaded_chunk(0, await response.read()) + self.meta.result_status = ResultStatus.DOWNLOADED + response.raise_for_status() + else: + chunk_size = self.chunk_size + if chunk_id < self.num_of_chunks - 1: + _range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}" + else: + _range = f"{chunk_id * chunk_size}-" + + response = await self._send_request_with_authentication_and_retry( + url=url, + verb="GET", + retry_id=chunk_id, + headers={"Range": f"bytes={_range}"}, + ) + if response.status in (200, 206): + self.write_downloaded_chunk(chunk_id, await response.read()) + response.raise_for_status() + + async def _get_bucket_accelerate_config(self, bucket_name: str) -> bool: + query_parts = (("accelerate", ""),) + query_string = self._construct_query_string(query_parts) + url = f"https://{bucket_name}.s3.amazonaws.com/?{query_string}" + retry_id = "accelerate" + self.retry_count[retry_id] = 0 + + response = await self._send_request_with_authentication_and_retry( + url=url, verb="GET", retry_id=retry_id, query_parts=dict(query_parts) + ) + if response.status == 200: + config = ET.fromstring(await response.text()) + namespace = config.tag[: config.tag.index("}") + 1] + statusTag = f"{namespace}Status" + found = config.find(statusTag) + use_accelerate_endpoint = ( + False if found is None else (found.text == "Enabled") + ) + logger.debug(f"use_accelerate_endpoint: {use_accelerate_endpoint}") + return use_accelerate_endpoint + return False + + async def transfer_accelerate_config( + self, use_accelerate_endpoint: bool | None = None + ) -> bool: + # accelerate cannot be used in China and us government + if self.region_name and self.region_name.startswith("cn-"): + self.endpoint = ( + f"https://{self.s3location.bucket_name}." + f"s3.{self.region_name}.amazonaws.com.cn" + ) + return False + # if self.endpoint has been set, e.g. by metadata, no more config is needed. 
+ if self.endpoint is not None: + return self.endpoint.find("s3-accelerate.amazonaws.com") >= 0 + if self.use_s3_regional_url: + self.endpoint = ( + f"https://{self.s3location.bucket_name}." + f"s3.{self.region_name}.amazonaws.com" + ) + return False + else: + if use_accelerate_endpoint is None: + use_accelerate_endpoint = await self._get_bucket_accelerate_config( + self.s3location.bucket_name + ) + + if use_accelerate_endpoint: + self.endpoint = ( + f"https://{self.s3location.bucket_name}.s3-accelerate.amazonaws.com" + ) + else: + self.endpoint = ( + f"https://{self.s3location.bucket_name}.s3.amazonaws.com" + ) + return use_accelerate_endpoint + + async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool: + """Extract error code and error message from the S3's error response. + Expected format: + https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#RESTErrorResponses + Args: + response: Rest error response in XML format + Returns: True if the error response is caused by token expiration + """ + if response.status != 400: + return False + message = await response.text() + if not message: + return False + err = ET.fromstring(await response.read()) + return err.find("Code").text == EXPIRED_TOKEN diff --git a/src/snowflake/connector/aio/_ssl_connector.py b/src/snowflake/connector/aio/_ssl_connector.py new file mode 100644 index 0000000000..86d7d5acf5 --- /dev/null +++ b/src/snowflake/connector/aio/_ssl_connector.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +import sys +from typing import TYPE_CHECKING + +import aiohttp +from aiohttp import ClientRequest, ClientTimeout +from aiohttp.client_proto import ResponseHandler +from aiohttp.connector import Connection + +from snowflake.connector.constants import OCSPMode + +from .. import OperationalError +from ..errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED +from ..ssl_wrap_socket import FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME +from ._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto + +if TYPE_CHECKING: + from aiohttp.tracing import Trace + +log = logging.getLogger(__name__) + + +class SnowflakeSSLConnector(aiohttp.TCPConnector): + def __init__(self, *args, **kwargs): + self._snowflake_ocsp_mode = kwargs.pop( + "snowflake_ocsp_mode", OCSPMode.FAIL_OPEN + ) + if self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN and sys.version_info < ( + 3, + 10, + ): + raise RuntimeError( + "Async Snowflake Python Connector requires Python 3.10+ for OCSP validation related features. " + "Please open a feature request issue in github if your want to use Python 3.9 or lower: " + "https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose." + ) + + super().__init__(*args, **kwargs) + + async def connect( + self, req: ClientRequest, traces: list[Trace], timeout: ClientTimeout + ) -> Connection: + connection = await super().connect(req, traces, timeout) + protocol = connection.protocol + if ( + req.is_ssl() + and protocol is not None + and not getattr(protocol, "_snowflake_ocsp_validated", False) + ): + if self._snowflake_ocsp_mode == OCSPMode.INSECURE: + log.info( + "THIS CONNECTION IS IN INSECURE " + "MODE. IT MEANS THE CERTIFICATE WILL BE " + "VALIDATED BUT THE CERTIFICATE REVOCATION " + "STATUS WILL NOT BE CHECKED." 
+ ) + else: + await self.validate_ocsp(req.url.host, protocol) + protocol._snowflake_ocsp_validated = True + return connection + + async def validate_ocsp(self, hostname: str, protocol: ResponseHandler): + + v = await SnowflakeOCSPAsn1Crypto( + ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME, + use_fail_open=self._snowflake_ocsp_mode == OCSPMode.FAIL_OPEN, + hostname=hostname, + ).validate(hostname, protocol) + if not v: + raise OperationalError( + msg=( + "The certificate is revoked or " + "could not be validated: hostname={}".format(hostname) + ), + errno=ER_OCSP_RESPONSE_CERT_STATUS_REVOKED, + ) diff --git a/src/snowflake/connector/aio/_storage_client.py b/src/snowflake/connector/aio/_storage_client.py new file mode 100644 index 0000000000..5096a8be5d --- /dev/null +++ b/src/snowflake/connector/aio/_storage_client.py @@ -0,0 +1,319 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import os +import shutil +from abc import abstractmethod +from logging import getLogger +from math import ceil +from typing import TYPE_CHECKING, Any, Callable + +import aiohttp +import OpenSSL + +from ..constants import FileHeader, ResultStatus +from ..encryption_util import SnowflakeEncryptionUtil +from ..errors import RequestExceedMaxRetryError +from ..storage_client import SnowflakeStorageClient as SnowflakeStorageClientSync + +if TYPE_CHECKING: # pragma: no cover + from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential + +logger = getLogger(__name__) + + +class SnowflakeStorageClient(SnowflakeStorageClientSync): + TRANSIENT_ERRORS = (OpenSSL.SSL.SysCallError, asyncio.TimeoutError, ConnectionError) + + def __init__( + self, + meta: SnowflakeFileMeta, + stage_info: dict[str, Any], + chunk_size: int, + chunked_transfer: bool | None = True, + credentials: StorageCredential | None = None, + max_retry: int = 5, + ) -> None: + SnowflakeStorageClientSync.__init__( + self, + meta=meta, + stage_info=stage_info, + chunk_size=chunk_size, + chunked_transfer=chunked_transfer, + credentials=credentials, + max_retry=max_retry, + ) + + @abstractmethod + async def get_file_header(self, filename: str) -> FileHeader | None: + """Check if file exists in target location and obtain file metadata if exists. + + Notes: + Updates meta.result_status. + """ + pass + + async def preprocess(self) -> None: + meta = self.meta + logger.debug(f"Preprocessing {meta.src_file_name}") + file_header = await self.get_file_header( + meta.dst_file_name + ) # check if file exists on remote + if not meta.overwrite: + self.get_digest() # self.get_file_header needs digest for multiparts upload when aws is used. 
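[Editor's aside, not part of the patch] Any aiohttp.TCPConnector subclass, such as the SnowflakeSSLConnector defined above, is wired into a session the same way as the stock connector. A minimal usage sketch; the plain TCPConnector and the URL below are stand-ins, not connector code:

import asyncio

import aiohttp


async def fetch_status(url: str) -> int:
    # a custom subclass such as SnowflakeSSLConnector would be constructed here
    connector = aiohttp.TCPConnector(limit=10)
    async with aiohttp.ClientSession(connector=connector) as session:
        async with session.get(url) as resp:
            return resp.status


if __name__ == "__main__":
    print(asyncio.run(fetch_status("https://httpbin.org/get")))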
+ if meta.result_status == ResultStatus.UPLOADED: + # Skipped + logger.debug( + f'file already exists location="{self.stage_info["location"]}", ' + f'file_name="{meta.dst_file_name}"' + ) + meta.dst_file_size = 0 + meta.result_status = ResultStatus.SKIPPED + self.preprocessed = True + return + # Uploading + if meta.require_compress: + self.compress() + self.get_digest() + + if ( + meta.skip_upload_on_content_match + and file_header + and meta.sha256_digest == file_header.digest + ): + logger.debug(f"same file contents for {meta.name}, skipping upload") + meta.result_status = ResultStatus.SKIPPED + + self.preprocessed = True + + async def prepare_upload(self) -> None: + meta = self.meta + + if not self.preprocessed: + await self.preprocess() + elif meta.encryption_material: + # need to clean up previous encrypted file + os.remove(self.data_file) + logger.debug(f"Preparing to upload {meta.src_file_name}") + + if meta.encryption_material: + self.encrypt() + else: + self.data_file = meta.real_src_file_name + logger.debug("finished preprocessing") + if meta.upload_size < meta.multipart_threshold or not self.chunked_transfer: + self.num_of_chunks = 1 + else: + # multi-chunk file transfer + self.num_of_chunks = ceil(meta.upload_size / self.chunk_size) + + logger.debug(f"number of chunks {self.num_of_chunks}") + # clean up + self.retry_count = {} + + for chunk_id in range(self.num_of_chunks): + self.retry_count[chunk_id] = 0 + # multi-chunk file transfer + if self.chunked_transfer and self.num_of_chunks > 1: + await self._initiate_multipart_upload() + + async def finish_upload(self) -> None: + meta = self.meta + if self.successful_transfers == self.num_of_chunks and self.num_of_chunks != 0: + # multi-chunk file transfer + if self.num_of_chunks > 1: + await self._complete_multipart_upload() + meta.result_status = ResultStatus.UPLOADED + meta.dst_file_size = meta.upload_size + logger.debug(f"{meta.src_file_name} upload is completed.") + else: + # TODO: add more error details to result/meta + meta.dst_file_size = 0 + logger.debug(f"{meta.src_file_name} upload is aborted.") + # multi-chunk file transfer + if self.num_of_chunks > 1: + await self._abort_multipart_upload() + meta.result_status = ResultStatus.ERROR + + async def finish_download(self) -> None: + meta = self.meta + if self.num_of_chunks != 0 and self.successful_transfers == self.num_of_chunks: + meta.result_status = ResultStatus.DOWNLOADED + if meta.encryption_material: + logger.debug(f"encrypted data file={self.full_dst_file_name}") + # For storage utils that do not have the privilege of + # getting the metadata early, both object and metadata + # are downloaded at once. In which case, the file meta will + # be updated with all the metadata that we need and + # then we can call get_file_header to get just that and also + # preserve the idea of getting metadata in the first place. + # One example of this is the utils that use presigned url + # for upload/download and not the storage client library. 
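[Editor's aside, not part of the patch] The chunk-count decision made in prepare_upload() above reduces to a threshold check plus a ceiling division. A standalone sketch (plan_chunks is a hypothetical helper for illustration only):

from math import ceil


def plan_chunks(
    upload_size: int,
    chunk_size: int,
    multipart_threshold: int,
    chunked_transfer: bool = True,
) -> int:
    """Return how many chunks an upload of upload_size bytes is split into."""
    if upload_size < multipart_threshold or not chunked_transfer:
        return 1  # single-request upload
    return ceil(upload_size / chunk_size)


assert plan_chunks(upload_size=5, chunk_size=8, multipart_threshold=64) == 1
assert plan_chunks(upload_size=200, chunk_size=64, multipart_threshold=64) == 4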
+ if meta.presigned_url is not None: + file_header = await self.get_file_header(meta.src_file_name) + self.encryption_metadata = file_header.encryption_metadata + + tmp_dst_file_name = SnowflakeEncryptionUtil.decrypt_file( + self.encryption_metadata, + meta.encryption_material, + str(self.intermediate_dst_path), + tmp_dir=self.tmp_dir, + ) + shutil.move(tmp_dst_file_name, self.full_dst_file_name) + self.intermediate_dst_path.unlink() + else: + logger.debug(f"not encrypted data file={self.full_dst_file_name}") + shutil.move(str(self.intermediate_dst_path), self.full_dst_file_name) + stat_info = os.stat(self.full_dst_file_name) + meta.dst_file_size = stat_info.st_size + else: + # TODO: add more error details to result/meta + if os.path.isfile(self.full_dst_file_name): + os.unlink(self.full_dst_file_name) + logger.exception(f"Failed to download a file: {self.full_dst_file_name}") + meta.dst_file_size = -1 + meta.result_status = ResultStatus.ERROR + + async def _send_request_with_retry( + self, + verb: str, + get_request_args: Callable[[], tuple[str, dict[str, Any]]], + retry_id: int, + ) -> aiohttp.ClientResponse: + url = "" + conn = None + if self.meta.sfagent and self.meta.sfagent._cursor.connection: + conn = self.meta.sfagent._cursor._connection + + while self.retry_count[retry_id] < self.max_retry: + cur_timestamp = self.credentials.timestamp + url, rest_kwargs = get_request_args() + # rest_kwargs["timeout"] = (REQUEST_CONNECTION_TIMEOUT, REQUEST_READ_TIMEOUT) + try: + if conn: + async with conn._rest._use_requests_session(url) as session: + logger.debug(f"storage client request with session {session}") + response = await session.request(verb, url, **rest_kwargs) + else: + logger.debug("storage client request with new session") + response = await aiohttp.ClientSession().request( + verb, url, **rest_kwargs + ) + + if await self._has_expired_presigned_url(response): + await self._update_presigned_url() + else: + self.last_err_is_presigned_url = False + if response.status in self.TRANSIENT_HTTP_ERR: + await asyncio.sleep( + min( + # TODO should SLEEP_UNIT come from the parent + # SnowflakeConnection and be customizable by users? + (2 ** self.retry_count[retry_id]) * self.SLEEP_UNIT, + self.SLEEP_MAX, + ) + ) + self.retry_count[retry_id] += 1 + elif await self._has_expired_token(response): + self.credentials.update(cur_timestamp) + else: + return response + except self.TRANSIENT_ERRORS as e: + self.last_err_is_presigned_url = False + await asyncio.sleep( + min( + (2 ** self.retry_count[retry_id]) * self.SLEEP_UNIT, + self.SLEEP_MAX, + ) + ) + logger.warning(f"{verb} with url {url} failed for transient error: {e}") + self.retry_count[retry_id] += 1 + else: + raise RequestExceedMaxRetryError( + f"{verb} with url {url} failed for exceeding maximum retries." 
+ ) + + async def prepare_download(self) -> None: + # TODO: add nicer error message for when target directory is not writeable + # but this should be done before we get here + base_dir = os.path.dirname(self.full_dst_file_name) + if not os.path.exists(base_dir): + os.makedirs(base_dir) + + # HEAD + file_header = await self.get_file_header(self.meta.real_src_file_name) + + if file_header and file_header.encryption_metadata: + self.encryption_metadata = file_header.encryption_metadata + + self.num_of_chunks = 1 + if file_header and file_header.content_length: + self.meta.src_file_size = file_header.content_length + # multi-chunk file transfer + if ( + self.chunked_transfer + and self.meta.src_file_size > self.meta.multipart_threshold + ): + self.num_of_chunks = ceil(file_header.content_length / self.chunk_size) + + # Preallocate encrypted file. + with self.intermediate_dst_path.open("wb+") as fd: + fd.truncate(self.meta.src_file_size) + + async def upload_chunk(self, chunk_id: int) -> None: + new_stream = not bool(self.meta.src_stream or self.meta.intermediate_stream) + fd = ( + self.meta.src_stream + or self.meta.intermediate_stream + or open(self.data_file, "rb") + ) + try: + if self.num_of_chunks == 1: + _data = fd.read() + else: + fd.seek(chunk_id * self.chunk_size) + _data = fd.read(self.chunk_size) + finally: + if new_stream: + fd.close() + logger.debug(f"Uploading chunk {chunk_id} of file {self.data_file}") + await self._upload_chunk(chunk_id, _data) + logger.debug(f"Successfully uploaded chunk {chunk_id} of file {self.data_file}") + + @abstractmethod + async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None: + pass + + @abstractmethod + async def download_chunk(self, chunk_id: int) -> None: + pass + + # Override in GCS + async def _has_expired_presigned_url( + self, response: aiohttp.ClientResponse + ) -> bool: + return False + + # Override in GCS + async def _update_presigned_url(self) -> None: + return + + # Override in S3 + async def _initiate_multipart_upload(self) -> None: + return + + # Override in S3 + async def _complete_multipart_upload(self) -> None: + return + + # Override in S3 + async def _abort_multipart_upload(self) -> None: + return + + @abstractmethod + async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool: + pass diff --git a/src/snowflake/connector/aio/_telemetry.py b/src/snowflake/connector/aio/_telemetry.py new file mode 100644 index 0000000000..f5aa5d4254 --- /dev/null +++ b/src/snowflake/connector/aio/_telemetry.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
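[Editor's aside, not part of the patch] The retry loop in _send_request_with_retry() above sleeps with a capped exponential backoff between attempts. A standalone sketch of that arithmetic; the SLEEP_UNIT and SLEEP_MAX values below are illustrative assumptions, the real constants come from the synchronous storage client base class:

SLEEP_UNIT = 1.0  # assumed value for illustration
SLEEP_MAX = 16.0  # assumed value for illustration


def backoff_seconds(retry_count: int) -> float:
    """Sleep before attempt retry_count + 1: 2**n * SLEEP_UNIT, capped at SLEEP_MAX."""
    return min((2**retry_count) * SLEEP_UNIT, SLEEP_MAX)


assert [backoff_seconds(n) for n in range(6)] == [1.0, 2.0, 4.0, 8.0, 16.0, 16.0]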
+# + +from __future__ import annotations + +import logging +from asyncio import Lock +from typing import TYPE_CHECKING + +from ..secret_detector import SecretDetector +from ..telemetry import TelemetryClient as TelemetryClientSync +from ..telemetry import TelemetryData +from ..test_util import ENABLE_TELEMETRY_LOG, rt_plain_logger + +if TYPE_CHECKING: + from ._network import SnowflakeRestful + +logger = logging.getLogger(__name__) + + +class TelemetryClient(TelemetryClientSync): + """Client to enqueue and send metrics to the telemetry endpoint in batch.""" + + def __init__(self, rest: SnowflakeRestful, flush_size=None) -> None: + super().__init__(rest, flush_size) + self._lock = Lock() + + async def add_log_to_batch(self, telemetry_data: TelemetryData) -> None: + if self.is_closed: + raise Exception("Attempted to add log when TelemetryClient is closed") + elif not self._enabled: + logger.debug("TelemetryClient disabled. Ignoring log.") + return + + async with self._lock: + self._log_batch.append(telemetry_data) + + if len(self._log_batch) >= self._flush_size: + await self.send_batch() + + async def send_batch(self) -> None: + if self.is_closed: + raise Exception("Attempted to send batch when TelemetryClient is closed") + elif not self._enabled: + logger.debug("TelemetryClient disabled. Not sending logs.") + return + + async with self._lock: + to_send = self._log_batch + self._log_batch = [] + + if not to_send: + logger.debug("Nothing to send to telemetry.") + return + + body = {"logs": [x.to_dict() for x in to_send]} + logger.debug( + "Sending %d logs to telemetry. Data is %s.", + len(body), + SecretDetector.mask_secrets(str(body))[1], + ) + if ENABLE_TELEMETRY_LOG: + # This logger guarantees the payload won't be masked. Testing purpose. + rt_plain_logger.debug(f"Inband telemetry data being sent is {body}") + try: + ret = await self._rest.request( + TelemetryClient.SF_PATH_TELEMETRY, + body=body, + method="post", + client=None, + timeout=5, + ) + if not ret["success"]: + logger.info( + "Non-success response from telemetry server: %s. " + "Disabling telemetry.", + str(ret), + ) + self._enabled = False + else: + logger.debug("Successfully uploading metrics to telemetry.") + except Exception: + self._enabled = False + logger.debug("Failed to upload metrics to telemetry.", exc_info=True) + + async def try_add_log_to_batch(self, telemetry_data: TelemetryData) -> None: + try: + await self.add_log_to_batch(telemetry_data) + except Exception: + logger.warning("Failed to add log to telemetry.", exc_info=True) + + async def close(self, send_on_close: bool = True) -> None: + if not self.is_closed: + logger.debug("Closing telemetry client.") + if send_on_close: + await self.send_batch() + self._rest = None diff --git a/src/snowflake/connector/aio/_time_util.py b/src/snowflake/connector/aio/_time_util.py new file mode 100644 index 0000000000..c11f19728f --- /dev/null +++ b/src/snowflake/connector/aio/_time_util.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
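[Editor's aside, not part of the patch] The TelemetryClient above appends entries behind an asyncio.Lock and flushes once the batch reaches the flush size. A simplified, self-contained sketch of that batch-and-flush pattern (BatchBuffer is hypothetical, not the connector's class):

import asyncio


class BatchBuffer:
    def __init__(self, flush_size: int = 3) -> None:
        self._items: list[str] = []
        self._flush_size = flush_size
        self._lock = asyncio.Lock()
        self.sent: list[list[str]] = []  # stands in for the telemetry endpoint

    async def add(self, item: str) -> None:
        async with self._lock:  # serialize concurrent producers
            self._items.append(item)
            if len(self._items) >= self._flush_size:
                to_send, self._items = self._items, []
                self.sent.append(to_send)


async def main() -> None:
    buf = BatchBuffer()
    await asyncio.gather(*(buf.add(f"log-{i}") for i in range(7)))
    print(buf.sent)  # two flushed batches of three; one entry still buffered


asyncio.run(main())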
+# + +from __future__ import annotations + +import asyncio +import logging +from typing import Callable + +from ..time_util import TimerContextManager as TimerContextManagerSync + +logger = logging.getLogger(__name__) + + +class HeartBeatTimer: + """An asyncio-based timer which executes a function every client_session_keep_alive_heartbeat_frequency seconds.""" + + def __init__( + self, client_session_keep_alive_heartbeat_frequency: int, f: Callable + ) -> None: + self.interval = client_session_keep_alive_heartbeat_frequency + self.function = f + self._task = None + self._stopped = asyncio.Event() # Event to stop the loop + + async def run(self) -> None: + """Async function to run the heartbeat at regular intervals.""" + try: + while not self._stopped.is_set(): + await asyncio.sleep(self.interval) + if not self._stopped.is_set(): + try: + await self.function() + except Exception as e: + logger.debug("failed to heartbeat: %s", e) + except asyncio.CancelledError: + logger.debug("Heartbeat timer was cancelled.") + + async def start(self) -> None: + """Starts the heartbeat.""" + self._stopped.clear() + self._task = asyncio.create_task(self.run()) + + async def stop(self) -> None: + """Stops the heartbeat.""" + self._stopped.set() + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + +class TimerContextManager(TimerContextManagerSync): + async def __aenter__(self): + return super().__enter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return super().__exit__(exc_type, exc_val, exc_tb) diff --git a/src/snowflake/connector/aio/auth/__init__.py b/src/snowflake/connector/aio/auth/__init__.py new file mode 100644 index 0000000000..90c76e1875 --- /dev/null +++ b/src/snowflake/connector/aio/auth/__init__.py @@ -0,0 +1,41 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from ...auth.by_plugin import AuthType +from ._auth import Auth +from ._by_plugin import AuthByPlugin +from ._default import AuthByDefault +from ._idtoken import AuthByIdToken +from ._keypair import AuthByKeyPair +from ._oauth import AuthByOAuth +from ._okta import AuthByOkta +from ._usrpwdmfa import AuthByUsrPwdMfa +from ._webbrowser import AuthByWebBrowser + +FIRST_PARTY_AUTHENTICATORS = frozenset( + ( + AuthByDefault, + AuthByKeyPair, + AuthByOAuth, + AuthByOkta, + AuthByUsrPwdMfa, + AuthByWebBrowser, + AuthByIdToken, + ) +) + +__all__ = [ + "AuthByPlugin", + "AuthByDefault", + "AuthByKeyPair", + "AuthByOAuth", + "AuthByOkta", + "AuthByUsrPwdMfa", + "AuthByWebBrowser", + "Auth", + "AuthType", + "FIRST_PARTY_AUTHENTICATORS", +] diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py new file mode 100644 index 0000000000..a11cd89eb1 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_auth.py @@ -0,0 +1,381 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
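[Editor's aside, not part of the patch] HeartBeatTimer above is built on a standard asyncio periodic-task pattern: sleep in a loop, stop via an Event, and cancel the Task on shutdown. A simplified sketch with hypothetical names:

import asyncio


async def periodic(interval: float, func, stopped: asyncio.Event) -> None:
    while not stopped.is_set():
        await asyncio.sleep(interval)
        if not stopped.is_set():
            await func()


async def main() -> None:
    stopped = asyncio.Event()
    beats: list[str] = []

    async def beat() -> None:
        beats.append("beat")

    task = asyncio.create_task(periodic(0.05, beat, stopped))
    await asyncio.sleep(0.18)  # let a few beats happen
    stopped.set()
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print(len(beats), "heartbeats")


asyncio.run(main())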
+# + +from __future__ import annotations + +import asyncio +import copy +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Callable + +from ...auth import Auth as AuthSync +from ...auth._auth import ( + AUTHENTICATION_REQUEST_KEY_WHITELIST, + ID_TOKEN, + MFA_TOKEN, + delete_temporary_credential, +) +from ...compat import urlencode +from ...constants import ( + HTTP_HEADER_ACCEPT, + HTTP_HEADER_CONTENT_TYPE, + HTTP_HEADER_SERVICE_NAME, + HTTP_HEADER_USER_AGENT, +) +from ...errorcode import ER_FAILED_TO_CONNECT_TO_DB +from ...errors import ( + BadGatewayError, + DatabaseError, + Error, + ForbiddenError, + ProgrammingError, + ServiceUnavailableError, +) +from ...network import ( + ACCEPT_TYPE_APPLICATION_SNOWFLAKE, + CONTENT_TYPE_APPLICATION_JSON, + ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE, + PYTHON_CONNECTOR_USER_AGENT, + ReauthenticationRequest, +) +from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED + +if TYPE_CHECKING: + from ._by_plugin import AuthByPlugin + +logger = logging.getLogger(__name__) + + +class Auth(AuthSync): + async def authenticate( + self, + auth_instance: AuthByPlugin, + account: str, + user: str, + database: str | None = None, + schema: str | None = None, + warehouse: str | None = None, + role: str | None = None, + passcode: str | None = None, + passcode_in_password: bool = False, + mfa_callback: Callable[[], None] | None = None, + password_callback: Callable[[], str] | None = None, + session_parameters: dict[Any, Any] | None = None, + # max time waiting for MFA response, currently unused + timeout: int | None = None, + ) -> dict[str, str | int | bool]: + if mfa_callback or password_callback: + # TODO: SNOW-1707210 for mfa_callback and password_callback support + raise NotImplementedError( + "mfa_callback or password_callback is not supported in asyncio connector, please open a feature" + " request issue in github: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose" + ) + logger.debug("authenticate") + + if timeout is None: + timeout = auth_instance.timeout + + if session_parameters is None: + session_parameters = {} + + request_id = str(uuid.uuid4()) + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: ACCEPT_TYPE_APPLICATION_SNOWFLAKE, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if HTTP_HEADER_SERVICE_NAME in session_parameters: + headers[HTTP_HEADER_SERVICE_NAME] = session_parameters[ + HTTP_HEADER_SERVICE_NAME + ] + url = "/session/v1/login-request" + + body_template = Auth.base_auth_data( + user, + account, + self._rest._connection.application, + self._rest._connection._internal_application_name, + self._rest._connection._internal_application_version, + self._rest._connection._ocsp_mode(), + self._rest._connection._login_timeout, + self._rest._connection._network_timeout, + self._rest._connection._socket_timeout, + ) + + body = copy.deepcopy(body_template) + # updating request body + await auth_instance.update_body(body) + + logger.debug( + "account=%s, user=%s, database=%s, schema=%s, " + "warehouse=%s, role=%s, request_id=%s", + account, + user, + database, + schema, + warehouse, + role, + request_id, + ) + url_parameters = {"request_id": request_id} + if database is not None: + url_parameters["databaseName"] = database + if schema is not None: + url_parameters["schemaName"] = schema + if warehouse is not None: + url_parameters["warehouse"] = warehouse + if role is not None: + 
url_parameters["roleName"] = role + + url = url + "?" + urlencode(url_parameters) + + # first auth request + if passcode_in_password: + body["data"]["EXT_AUTHN_DUO_METHOD"] = "passcode" + elif passcode: + body["data"]["EXT_AUTHN_DUO_METHOD"] = "passcode" + body["data"]["PASSCODE"] = passcode + + if session_parameters: + body["data"]["SESSION_PARAMETERS"] = session_parameters + + logger.debug( + "body['data']: %s", + { + k: v if k in AUTHENTICATION_REQUEST_KEY_WHITELIST else "******" + for (k, v) in body["data"].items() + }, + ) + + try: + ret = await self._rest._post_request( + url, + headers, + json.dumps(body), + socket_timeout=auth_instance._socket_timeout, + ) + except ForbiddenError as err: + # HTTP 403 + raise err.__class__( + msg=( + "Failed to connect to DB. " + "Verify the account name is correct: {host}:{port}. " + "{message}" + ).format( + host=self._rest._host, port=self._rest._port, message=str(err) + ), + errno=ER_FAILED_TO_CONNECT_TO_DB, + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + except (ServiceUnavailableError, BadGatewayError) as err: + # HTTP 502/504 + raise err.__class__( + msg=( + "Failed to connect to DB. " + "Service is unavailable: {host}:{port}. " + "{message}" + ).format( + host=self._rest._host, port=self._rest._port, message=str(err) + ), + errno=ER_FAILED_TO_CONNECT_TO_DB, + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + + # waiting for MFA authentication + if ret["data"] and ret["data"].get("nextAction") in ( + "EXT_AUTHN_DUO_ALL", + "EXT_AUTHN_DUO_PUSH_N_PASSCODE", + ): + body["inFlightCtx"] = ret["data"].get("inFlightCtx") + body["data"]["EXT_AUTHN_DUO_METHOD"] = "push" + self.ret = {"message": "Timeout", "data": {}} + + async def post_request_wrapper(self, url, headers, body) -> None: + # get the MFA response + self.ret = await self._rest._post_request( + url, + headers, + body, + socket_timeout=auth_instance._socket_timeout, + ) + + # send new request to wait until MFA is approved + try: + await asyncio.wait_for( + post_request_wrapper(self, url, headers, json.dumps(body)), + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.debug("get the MFA response timed out") + + ret = self.ret + if ( + ret + and ret["data"] + and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS" + ): + body = copy.deepcopy(body_template) + body["inFlightCtx"] = ret["data"].get("inFlightCtx") + # final request to get tokens + ret = await self._rest._post_request( + url, + headers, + json.dumps(body), + socket_timeout=auth_instance._socket_timeout, + ) + elif not ret or not ret["data"] or not ret["data"].get("token"): + # not token is returned. + Error.errorhandler_wrapper( + self._rest._connection, + None, + DatabaseError, + { + "msg": ( + "Failed to connect to DB. MFA " + "authentication failed: {" + "host}:{port}. 
{message}" + ).format( + host=self._rest._host, + port=self._rest._port, + message=ret["message"], + ), + "errno": ER_FAILED_TO_CONNECT_TO_DB, + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + return session_parameters # required for unit test + + elif ret["data"] and ret["data"].get("nextAction") == "PWD_CHANGE": + if callable(password_callback): + body = copy.deepcopy(body_template) + body["inFlightCtx"] = ret["data"].get("inFlightCtx") + body["data"]["LOGIN_NAME"] = user + body["data"]["PASSWORD"] = ( + auth_instance.password + if hasattr(auth_instance, "password") + else None + ) + body["data"]["CHOSEN_NEW_PASSWORD"] = password_callback() + # New Password input + ret = await self._rest._post_request( + url, + headers, + json.dumps(body), + socket_timeout=auth_instance._socket_timeout, + ) + + logger.debug("completed authentication") + if not ret["success"]: + errno = ret.get("code", ER_FAILED_TO_CONNECT_TO_DB) + if errno == ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE: + # clear stored id_token if failed to connect because of id_token + # raise an exception for reauth without id_token + self._rest.id_token = None + delete_temporary_credential(self._rest._host, user, ID_TOKEN) + raise ReauthenticationRequest( + ProgrammingError( + msg=ret["message"], + errno=int(errno), + sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + ) + ) + from . import AuthByKeyPair + + if isinstance(auth_instance, AuthByKeyPair): + logger.debug( + "JWT Token authentication failed. " + "Token expires at: %s. " + "Current Time: %s", + str(auth_instance._jwt_token_exp), + str(datetime.now(timezone.utc).replace(tzinfo=None)), + ) + from . import AuthByUsrPwdMfa + + if isinstance(auth_instance, AuthByUsrPwdMfa): + delete_temporary_credential(self._rest._host, user, MFA_TOKEN) + Error.errorhandler_wrapper( + self._rest._connection, + None, + DatabaseError, + { + "msg": ( + "Failed to connect to DB: {host}:{port}. " "{message}" + ).format( + host=self._rest._host, + port=self._rest._port, + message=ret["message"], + ), + "errno": ER_FAILED_TO_CONNECT_TO_DB, + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + else: + logger.debug( + "token = %s", + ( + "******" + if ret["data"] and ret["data"].get("token") is not None + else "NULL" + ), + ) + logger.debug( + "master_token = %s", + ( + "******" + if ret["data"] and ret["data"].get("masterToken") is not None + else "NULL" + ), + ) + logger.debug( + "id_token = %s", + ( + "******" + if ret["data"] and ret["data"].get("idToken") is not None + else "NULL" + ), + ) + logger.debug( + "mfa_token = %s", + ( + "******" + if ret["data"] and ret["data"].get("mfaToken") is not None + else "NULL" + ), + ) + if not ret["data"]: + Error.errorhandler_wrapper( + None, + None, + Error, + { + "msg": "There is no data in the returning response, please retry the operation." 
+ }, + ) + await self._rest.update_tokens( + ret["data"].get("token"), + ret["data"].get("masterToken"), + master_validity_in_seconds=ret["data"].get("masterValidityInSeconds"), + id_token=ret["data"].get("idToken"), + mfa_token=ret["data"].get("mfaToken"), + ) + self.write_temporary_credentials( + self._rest._host, user, session_parameters, ret + ) + if ret["data"] and "sessionId" in ret["data"]: + self._rest._connection._session_id = ret["data"].get("sessionId") + if ret["data"] and "sessionInfo" in ret["data"]: + session_info = ret["data"].get("sessionInfo") + self._rest._connection._database = session_info.get("databaseName") + self._rest._connection._schema = session_info.get("schemaName") + self._rest._connection._warehouse = session_info.get("warehouseName") + self._rest._connection._role = session_info.get("roleName") + if ret["data"] and "parameters" in ret["data"]: + session_parameters.update( + {p["name"]: p["value"] for p in ret["data"].get("parameters")} + ) + await self._rest._connection._update_parameters(session_parameters) + return session_parameters diff --git a/src/snowflake/connector/aio/auth/_by_plugin.py b/src/snowflake/connector/aio/auth/_by_plugin.py new file mode 100644 index 0000000000..818769a9f2 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_by_plugin.py @@ -0,0 +1,135 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import logging +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Iterator + +from ... import DatabaseError, Error, OperationalError +from ...auth import AuthByPlugin as AuthByPluginSync +from ...errorcode import ER_FAILED_TO_CONNECT_TO_DB +from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED + +if TYPE_CHECKING: + from .. import SnowflakeConnection + +logger = logging.getLogger(__name__) + + +class AuthByPlugin(AuthByPluginSync): + def __init__( + self, + timeout: int | None = None, + backoff_generator: Iterator | None = None, + **kwargs, + ) -> None: + super().__init__(timeout, backoff_generator, **kwargs) + + @abstractmethod + async def prepare( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + password: str | None, + **kwargs: Any, + ) -> str | None: + raise NotImplementedError + + @abstractmethod + async def update_body(self, body: dict[Any, Any]) -> None: + """Update the body of the authentication request.""" + raise NotImplementedError + + @abstractmethod + async def reset_secrets(self) -> None: + """Reset secret members.""" + raise NotImplementedError + + @abstractmethod + async def reauthenticate( + self, + *, + conn: SnowflakeConnection, + **kwargs: Any, + ) -> dict[str, Any]: + """Re-perform authentication. + + The difference between this and authentication is that secrets will be removed + from memory by the time this gets called. + """ + raise NotImplementedError + + async def _handle_failure( + self, + *, + conn: SnowflakeConnection, + ret: dict[Any, Any], + **kwargs: Any, + ) -> None: + """Handles a failure when an issue happens while connecting to Snowflake. + + If the user returns from this function execution will continue. The argument + data can be manipulated from within this function and so recovery is possible + from here. 
+ """ + Error.errorhandler_wrapper( + conn, + None, + DatabaseError, + { + "msg": "Failed to connect to DB: {host}:{port}, {message}".format( + host=conn._rest._host, + port=conn._rest._port, + message=ret["message"], + ), + "errno": int(ret.get("code", -1)), + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + + async def handle_timeout( + self, + *, + authenticator: str, + service_name: str | None, + account: str, + user: str, + password: str, + **kwargs: Any, + ) -> None: + """Default timeout handler. + + This will trigger if the authenticator + hasn't implemented one. By default we retry on timeouts and use + jitter to deduce the time to sleep before retrying. The sleep + time ranges between 1 and 16 seconds. + """ + + # Some authenticators may not want to delete the parameters to this function + # Currently, the only authenticator where this is the case is AuthByKeyPair + if kwargs.pop("delete_params", True): + del authenticator, service_name, account, user, password + + logger.debug("Default timeout handler invoked for authenticator") + if not self._retry_ctx.should_retry: + error = OperationalError( + msg=f"Could not connect to Snowflake backend after {self._retry_ctx.current_retry_count + 1} attempt(s)." + "Aborting", + errno=ER_FAILED_TO_CONNECT_TO_DB, + ) + raise error + else: + logger.debug( + f"Hit connection timeout, attempt number {self._retry_ctx.current_retry_count + 1}." + " Will retry in a bit..." + ) + await asyncio.sleep(float(self._retry_ctx.current_sleep_time)) + self._retry_ctx.increment() diff --git a/src/snowflake/connector/aio/auth/_default.py b/src/snowflake/connector/aio/auth/_default.py new file mode 100644 index 0000000000..1466db4d7a --- /dev/null +++ b/src/snowflake/connector/aio/auth/_default.py @@ -0,0 +1,32 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from logging import getLogger +from typing import Any + +from ...auth.default import AuthByDefault as AuthByDefaultSync +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +logger = getLogger(__name__) + + +class AuthByDefault(AuthByPluginAsync, AuthByDefaultSync): + def __init__(self, password: str, **kwargs) -> None: + """Initializes an instance with a password.""" + AuthByDefaultSync.__init__(self, password, **kwargs) + + async def reset_secrets(self) -> None: + self._password = None + + async def prepare(self, **kwargs: Any) -> None: + AuthByDefaultSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByDefaultSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + """Sets the password if available.""" + AuthByDefaultSync.update_body(self, body) diff --git a/src/snowflake/connector/aio/auth/_idtoken.py b/src/snowflake/connector/aio/auth/_idtoken.py new file mode 100644 index 0000000000..23bca2beaa --- /dev/null +++ b/src/snowflake/connector/aio/auth/_idtoken.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ...auth.idtoken import AuthByIdToken as AuthByIdTokenSync +from ._by_plugin import AuthByPlugin as AuthByPluginAsync +from ._webbrowser import AuthByWebBrowser + +if TYPE_CHECKING: + from .._connection import SnowflakeConnection + + +class AuthByIdToken(AuthByPluginAsync, AuthByIdTokenSync): + def __init__( + self, + id_token: str, + application: str, + protocol: str | None, + host: str | None, + port: str | None, + **kwargs, + ) -> None: + """Initialized an instance with an IdToken.""" + AuthByIdTokenSync.__init__( + self, id_token, application, protocol, host, port, **kwargs + ) + + async def reset_secrets(self) -> None: + AuthByIdTokenSync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + AuthByIdTokenSync.prepare(self, **kwargs) + + async def reauthenticate( + self, + *, + conn: SnowflakeConnection, + **kwargs: Any, + ) -> dict[str, bool]: + conn.auth_class = AuthByWebBrowser( + application=self._application, + protocol=self._protocol, + host=self._host, + port=self._port, + timeout=conn.login_timeout, + backoff_generator=conn._backoff_generator, + ) + await conn._authenticate(conn.auth_class) + await conn._auth_class.reset_secrets() + return {"success": True} + + async def update_body(self, body: dict[Any, Any]) -> None: + """Sets the id_token if available.""" + AuthByIdTokenSync.update_body(self, body) diff --git a/src/snowflake/connector/aio/auth/_keypair.py b/src/snowflake/connector/aio/auth/_keypair.py new file mode 100644 index 0000000000..641f387d11 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_keypair.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
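[Editor's aside, not part of the patch] AuthByIdToken above falls back to the interactive web-browser flow when its cached id_token is rejected. A simplified sketch of that fallback strategy; every helper here is hypothetical and stands in for the real login round trips:

from __future__ import annotations

import asyncio


async def login_with_cached_token(token: str | None) -> str:
    if token == "valid-cached-token":
        return "session-from-cache"
    raise PermissionError("cached id_token rejected")


async def login_interactively() -> str:
    return "session-from-browser"  # stands in for the SSO browser round trip


async def authenticate(cached_token: str | None) -> str:
    try:
        return await login_with_cached_token(cached_token)
    except PermissionError:
        return await login_interactively()


print(asyncio.run(authenticate(None)))                    # session-from-browser
print(asyncio.run(authenticate("valid-cached-token")))    # session-from-cache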
+# +from __future__ import annotations + +from logging import getLogger +from typing import Any + +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + +from ...auth.keypair import AuthByKeyPair as AuthByKeyPairSync +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +logger = getLogger(__name__) + + +class AuthByKeyPair(AuthByPluginAsync, AuthByKeyPairSync): + def __init__( + self, + private_key: bytes | RSAPrivateKey, + lifetime_in_seconds: int = AuthByKeyPairSync.LIFETIME, + **kwargs, + ) -> None: + AuthByKeyPairSync.__init__(self, private_key, lifetime_in_seconds, **kwargs) + + async def reset_secrets(self) -> None: + AuthByKeyPairSync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + AuthByKeyPairSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByKeyPairSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + """Sets the private key if available.""" + AuthByKeyPairSync.update_body(self, body) + + async def handle_timeout( + self, + *, + authenticator: str, + service_name: str | None, + account: str, + user: str, + password: str | None, + **kwargs: Any, + ) -> None: + logger.debug("Invoking base timeout handler") + await AuthByPluginAsync.handle_timeout( + self, + authenticator=authenticator, + service_name=service_name, + account=account, + user=user, + password=password, + delete_params=False, + ) + + logger.debug("Base timeout handler passed, preparing new token before retrying") + await self.prepare(account=account, user=user) diff --git a/src/snowflake/connector/aio/auth/_oauth.py b/src/snowflake/connector/aio/auth/_oauth.py new file mode 100644 index 0000000000..04cd44ba2c --- /dev/null +++ b/src/snowflake/connector/aio/auth/_oauth.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from typing import Any + +from ...auth.oauth import AuthByOAuth as AuthByOAuthSync +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + + +class AuthByOAuth(AuthByPluginAsync, AuthByOAuthSync): + def __init__(self, oauth_token: str, **kwargs) -> None: + """Initializes an instance with an OAuth Token.""" + AuthByOAuthSync.__init__(self, oauth_token, **kwargs) + + async def reset_secrets(self) -> None: + AuthByOAuthSync.reset_secrets(self) + + async def prepare(self, **kwargs: Any) -> None: + AuthByOAuthSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByOAuthSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByOAuthSync.update_body(self, body) diff --git a/src/snowflake/connector/aio/auth/_okta.py b/src/snowflake/connector/aio/auth/_okta.py new file mode 100644 index 0000000000..d8cd216df5 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_okta.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import json +import logging +import time +from functools import partial +from typing import TYPE_CHECKING, Any, Awaitable, Callable + +from snowflake.connector.aio.auth import Auth + +from ... 
import DatabaseError, Error +from ...auth.okta import AuthByOkta as AuthByOktaSync +from ...compat import urlencode +from ...constants import ( + HTTP_HEADER_ACCEPT, + HTTP_HEADER_CONTENT_TYPE, + HTTP_HEADER_SERVICE_NAME, + HTTP_HEADER_USER_AGENT, +) +from ...errorcode import ER_IDP_CONNECTION_ERROR +from ...errors import RefreshTokenError +from ...network import CONTENT_TYPE_APPLICATION_JSON, PYTHON_CONNECTOR_USER_AGENT +from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +if TYPE_CHECKING: + from .. import SnowflakeConnection + +logger = logging.getLogger(__name__) + + +class AuthByOkta(AuthByPluginAsync, AuthByOktaSync): + def __init__(self, application: str, **kwargs) -> None: + AuthByOktaSync.__init__(self, application, **kwargs) + + async def reset_secrets(self) -> None: + AuthByOktaSync.reset_secrets(self) + + async def prepare( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + password: str, + **kwargs: Any, + ) -> None: + """SAML Authentication. + + Steps are: + 1. query GS to obtain IDP token and SSO url + 2. IMPORTANT Client side validation: + validate both token url and sso url contains same prefix + (protocol + host + port) as the given authenticator url. + Explanation: + This provides a way for the user to 'authenticate' the IDP it is + sending his/her credentials to. Without such a check, the user could + be coerced to provide credentials to an IDP impersonator. + 3. query IDP token url to authenticate and retrieve access token + 4. given access token, query IDP URL snowflake app to get SAML response + 5. IMPORTANT Client side validation: + validate the post back url come back with the SAML response + contains the same prefix as the Snowflake's server url, which is the + intended destination url to Snowflake. + Explanation: + This emulates the behavior of IDP initiated login flow in the user + browser where the IDP instructs the browser to POST the SAML + assertion to the specific SP endpoint. This is critical in + preventing a SAML assertion issued to one SP from being sent to + another SP. 
+ """ + logger.debug("authenticating by SAML") + headers, sso_url, token_url = await self._step1( + conn, + authenticator, + service_name, + account, + user, + ) + await self._step2(conn, authenticator, sso_url, token_url) + response_html = await self._step4( + conn, + partial(self._step3, conn, headers, token_url, user, password), + sso_url, + ) + await self._step5(conn, response_html) + + async def reauthenticate(self, **kwargs: Any) -> dict[str, bool]: + return AuthByOktaSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByOktaSync.update_body(self, body) + + async def _step1( + self, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + ) -> tuple[dict[str, str], str, str]: + logger.debug("step 1: query GS to obtain IDP token and SSO url") + + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if service_name: + headers[HTTP_HEADER_SERVICE_NAME] = service_name + url = "/session/authenticator-request" + body = Auth.base_auth_data( + user, + account, + conn.application, + conn._internal_application_name, + conn._internal_application_version, + conn._ocsp_mode(), + conn.login_timeout, + conn._network_timeout, + ) + + body["data"]["AUTHENTICATOR"] = authenticator + logger.debug( + "account=%s, authenticator=%s", + account, + authenticator, + ) + ret = await conn._rest._post_request( + url, + headers, + json.dumps(body), + timeout=conn._rest._connection.login_timeout, + socket_timeout=conn._rest._connection.login_timeout, + ) + + if not ret["success"]: + await self._handle_failure(conn=conn, ret=ret) + + data = ret["data"] + token_url = data["tokenUrl"] + sso_url = data["ssoUrl"] + return headers, sso_url, token_url + + async def _step2( + self, + conn: SnowflakeConnection, + authenticator: str, + sso_url: str, + token_url: str, + ) -> None: + return super()._step2(conn, authenticator, sso_url, token_url) + + @staticmethod + async def _step3( + conn: SnowflakeConnection, + headers: dict[str, str], + token_url: str, + user: str, + password: str, + ) -> str: + logger.debug( + "step 3: query IDP token url to authenticate and " "retrieve access token" + ) + data = { + "username": user, + "password": password, + } + ret = await conn._rest.fetch( + "post", + token_url, + headers, + data=json.dumps(data), + timeout=conn._rest._connection.login_timeout, + socket_timeout=conn._rest._connection.login_timeout, + catch_okta_unauthorized_error=True, + ) + one_time_token = ret.get("sessionToken", ret.get("cookieToken")) + if not one_time_token: + Error.errorhandler_wrapper( + conn._rest._connection, + None, + DatabaseError, + { + "msg": ( + "The authentication failed for {user} " + "by {token_url}.".format( + token_url=token_url, + user=user, + ) + ), + "errno": ER_IDP_CONNECTION_ERROR, + "sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, + }, + ) + return one_time_token + + @staticmethod + async def _step4( + conn: SnowflakeConnection, + generate_one_time_token: Callable[[], Awaitable[str]], + sso_url: str, + ) -> dict[Any, Any]: + logger.debug("step 4: query IDP URL snowflake app to get SAML " "response") + timeout_time = time.time() + conn.login_timeout if conn.login_timeout else None + response_html = {} + origin_sso_url = sso_url + while timeout_time is None or time.time() < timeout_time: + try: + url_parameters = { + "RelayState": "/some/deep/link", + 
"onetimetoken": await generate_one_time_token(), + } + sso_url = origin_sso_url + "?" + urlencode(url_parameters) + headers = { + HTTP_HEADER_ACCEPT: "*/*", + } + remaining_timeout = timeout_time - time.time() if timeout_time else None + response_html = await conn._rest.fetch( + "get", + sso_url, + headers, + timeout=remaining_timeout, + socket_timeout=remaining_timeout, + is_raw_text=True, + is_okta_authentication=True, + ) + break + except RefreshTokenError: + logger.debug("step4: refresh token for re-authentication") + return response_html + + async def _step5( + self, + conn: SnowflakeConnection, + response_html: str, + ) -> None: + return super()._step5(conn, response_html) diff --git a/src/snowflake/connector/aio/auth/_usrpwdmfa.py b/src/snowflake/connector/aio/auth/_usrpwdmfa.py new file mode 100644 index 0000000000..4175bf5015 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_usrpwdmfa.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from ...auth.usrpwdmfa import AuthByUsrPwdMfa as AuthByUsrPwdMfaSync +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + + +class AuthByUsrPwdMfa(AuthByPluginAsync, AuthByUsrPwdMfaSync): + def __init__( + self, + password: str, + mfa_token: str | None = None, + **kwargs, + ) -> None: + """Initializes and instance with a password and a mfa token.""" + AuthByUsrPwdMfaSync.__init__(self, password, mfa_token, **kwargs) + + async def reset_secrets(self) -> None: + AuthByUsrPwdMfaSync.reset_secrets(self) + + async def prepare(self, **kwargs) -> None: + AuthByUsrPwdMfaSync.prepare(self, **kwargs) + + async def reauthenticate(self, **kwargs) -> dict[str, bool]: + return AuthByUsrPwdMfaSync.reauthenticate(self, **kwargs) + + async def update_body(self, body: dict[str, str]) -> None: + AuthByUsrPwdMfaSync.update_body(self, body) diff --git a/src/snowflake/connector/aio/auth/_webbrowser.py b/src/snowflake/connector/aio/auth/_webbrowser.py new file mode 100644 index 0000000000..97e9bbc1b6 --- /dev/null +++ b/src/snowflake/connector/aio/auth/_webbrowser.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from __future__ import annotations + +import asyncio +import json +import logging +import os +import select +import socket +import time +from types import ModuleType +from typing import TYPE_CHECKING, Any + +from snowflake.connector.aio.auth import Auth + +from ... 
import OperationalError +from ...auth.webbrowser import BUF_SIZE +from ...auth.webbrowser import AuthByWebBrowser as AuthByWebBrowserSync +from ...compat import IS_WINDOWS, parse_qs +from ...constants import ( + HTTP_HEADER_ACCEPT, + HTTP_HEADER_CONTENT_TYPE, + HTTP_HEADER_SERVICE_NAME, + HTTP_HEADER_USER_AGENT, +) +from ...errorcode import ( + ER_IDP_CONNECTION_ERROR, + ER_INVALID_VALUE, + ER_NO_HOSTNAME_FOUND, + ER_UNABLE_TO_OPEN_BROWSER, +) +from ...network import ( + CONTENT_TYPE_APPLICATION_JSON, + DEFAULT_SOCKET_CONNECT_TIMEOUT, + PYTHON_CONNECTOR_USER_AGENT, +) +from ...url_util import is_valid_url +from ._by_plugin import AuthByPlugin as AuthByPluginAsync + +if TYPE_CHECKING: + from .._connection import SnowflakeConnection + +logger = logging.getLogger(__name__) + + +class AuthByWebBrowser(AuthByPluginAsync, AuthByWebBrowserSync): + def __init__( + self, + application: str, + webbrowser_pkg: ModuleType | None = None, + socket_pkg: type[socket.socket] | None = None, + protocol: str | None = None, + host: str | None = None, + port: str | None = None, + **kwargs, + ) -> None: + AuthByWebBrowserSync.__init__( + self, + application, + webbrowser_pkg, + socket_pkg, + protocol, + host, + port, + **kwargs, + ) + self._event_loop = asyncio.get_event_loop() + + async def reset_secrets(self) -> None: + AuthByWebBrowserSync.reset_secrets(self) + + async def prepare( + self, + *, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + user: str, + **kwargs: Any, + ) -> None: + """Web Browser based Authentication.""" + logger.debug("authenticating by Web Browser") + + socket_connection = self._socket(socket.AF_INET, socket.SOCK_STREAM) + + if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": + if IS_WINDOWS: + logger.warning( + "Configuration SNOWFLAKE_AUTH_SOCKET_REUSE_PORT is not available in Windows. Ignoring." + ) + else: + socket_connection.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + try: + try: + socket_connection.bind( + ( + os.getenv("SF_AUTH_SOCKET_ADDR", "localhost"), + int(os.getenv("SF_AUTH_SOCKET_PORT", 0)), + ) + ) + except socket.gaierror as ex: + if ex.args[0] == socket.EAI_NONAME: + raise OperationalError( + msg="localhost is not found. Ensure /etc/hosts has " + "localhost entry.", + errno=ER_NO_HOSTNAME_FOUND, + ) + else: + raise ex + socket_connection.listen(0) # no backlog + callback_port = socket_connection.getsockname()[1] + + if conn._disable_console_login: + logger.debug("step 1: query GS to obtain SSO url") + sso_url = await self._get_sso_url( + conn, authenticator, service_name, account, callback_port, user + ) + else: + logger.debug("step 1: constructing console login url") + sso_url = self._get_console_login_url(conn, callback_port, user) + + logger.debug("Validate SSO URL") + if not is_valid_url(sso_url): + await self._handle_failure( + conn=conn, + ret={ + "code": ER_INVALID_VALUE, + "message": (f"The SSO URL provided {sso_url} is invalid"), + }, + ) + return + + print( + "Initiating login request with your identity provider. A " + "browser window should have opened for you to complete the " + "login. If you can't see it, check existing browser windows, " + "or your OS settings. Press CTRL+C to abort and try again..." 
+ ) + + logger.debug("step 2: open a browser") + print(f"Going to open: {sso_url} to authenticate...") + if not self._webbrowser.open_new(sso_url): + print( + "We were unable to open a browser window for you, " + "please open the url above manually then paste the " + "URL you are redirected to into the terminal." + ) + url = input("Enter the URL the SSO URL redirected you to: ") + self._process_get_url(url) + if not self._token: + # Input contained no token, either URL was incorrectly pasted, + # empty or just wrong + await self._handle_failure( + conn=conn, + ret={ + "code": ER_UNABLE_TO_OPEN_BROWSER, + "message": ( + "Unable to open a browser in this environment and " + "SSO URL contained no token" + ), + }, + ) + return + else: + logger.debug("step 3: accept SAML token") + await self._receive_saml_token(conn, socket_connection) + finally: + socket_connection.close() + + async def reauthenticate( + self, + *, + conn: SnowflakeConnection, + **kwargs: Any, + ) -> dict[str, bool]: + await conn.authenticate_with_retry(self) + return {"success": True} + + async def update_body(self, body: dict[Any, Any]) -> None: + AuthByWebBrowserSync.update_body(self, body) + + async def _receive_saml_token( + self, conn: SnowflakeConnection, socket_connection + ) -> None: + """Receives SAML token from web browser.""" + while True: + try: + attempts = 0 + raw_data = bytearray() + socket_client = None + max_attempts = 15 + + # when running in a containerized environment, socket_client.recv occasionally returns an empty byte array + # an immediate successive call to socket_client.recv gets the actual data + while len(raw_data) == 0 and attempts < max_attempts: + attempts += 1 + read_sockets, _write_sockets, _exception_sockets = select.select( + [socket_connection], [], [] + ) + + if read_sockets[0] is not None: + # Receive the data in small chunks and retransmit it + socket_client, _ = await self._event_loop.sock_accept( + socket_connection + ) + + try: + # Async delta: async version of sock_recv does not take flags + # on one hand, sock must be a non-blocking socket in async according to python docs: + # https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.sock_recv + # on the other hand according to linux: https://man7.org/linux/man-pages/man2/recvmsg.2.html + # sync flag MSG_DONTWAIT achieves the same effect as O_NONBLOCK, but it's a per-call flag + # however here for each call we accept a new socket, so they are effectively the same.
+ # https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.sock_recv + socket_client.setblocking(False) + raw_data = await asyncio.wait_for( + self._event_loop.sock_recv(socket_client, BUF_SIZE), + timeout=( + DEFAULT_SOCKET_CONNECT_TIMEOUT + if conn.socket_timeout is None + else conn.socket_timeout + ), + ) + except asyncio.TimeoutError: + logger.debug( + "sock_recv timed out while attempting to retrieve callback token request" + ) + if attempts < max_attempts: + sleep_time = 0.25 + logger.debug( + f"Waiting {sleep_time} seconds before trying again" + ) + await asyncio.sleep(sleep_time) + else: + logger.debug("Exceeded retry count") + + data = raw_data.decode("utf-8").split("\r\n") + + if not await self._process_options(data, socket_client): + await self._process_receive_saml_token(conn, data, socket_client) + break + + finally: + socket_client.shutdown(socket.SHUT_RDWR) + socket_client.close() + + async def _process_options( + self, data: list[str], socket_client: socket.socket + ) -> bool: + """Allows JS Ajax access to this endpoint.""" + for line in data: + if line.startswith("OPTIONS "): + break + else: + return False + + self._get_user_agent(data) + requested_headers, requested_origin = self._check_post_requested(data) + if not requested_headers: + return False + + if not self._validate_origin(requested_origin): + # validate Origin and fail if not match with the server. + return False + + self._origin = requested_origin + content = [ + "HTTP/1.1 200 OK", + "Date: {}".format( + time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) + ), + "Access-Control-Allow-Methods: POST, GET", + f"Access-Control-Allow-Headers: {requested_headers}", + "Access-Control-Max-Age: 86400", + f"Access-Control-Allow-Origin: {self._origin}", + "", + "", + ] + await self._event_loop.sock_sendall( + socket_client, "\r\n".join(content).encode("utf-8") + ) + return True + + async def _process_receive_saml_token( + self, conn: SnowflakeConnection, data: list[str], socket_client: socket.socket + ) -> None: + if not self._process_get(data) and not await self._process_post(conn, data): + return # error + + content = [ + "HTTP/1.1 200 OK", + "Content-Type: text/html", + ] + if self._origin: + data = {"consent": self.consent_cache_id_token} + msg = json.dumps(data) + content.append(f"Access-Control-Allow-Origin: {self._origin}") + content.append("Vary: Accept-Encoding, Origin") + else: + msg = f""" + + +SAML Response for Snowflake + +Your identity was confirmed and propagated to Snowflake {self._application}. +You can close this window now and go back where you started from. +""" + content.append(f"Content-Length: {len(msg)}") + content.append("") + content.append(msg) + + await self._event_loop.sock_sendall( + socket_client, "\r\n".join(content).encode("utf-8") + ) + + async def _process_post(self, conn: SnowflakeConnection, data: list[str]) -> bool: + for line in data: + if line.startswith("POST "): + break + else: + await self._handle_failure( + conn=conn, + ret={ + "code": ER_IDP_CONNECTION_ERROR, + "message": "Invalid HTTP request from web browser. Idp " + "authentication could have failed.", + }, + ) + return False + + self._get_user_agent(data) + try: + # parse the response as JSON + payload = json.loads(data[-1]) + self._token = payload.get("token") + self.consent_cache_id_token = payload.get("consent", True) + except Exception: + # key=value form. 
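+ # i.e. the browser POSTed a urlencoded body such as "token=<saml token>"; parse_qs below extracts the token value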
+ self._token = parse_qs(data[-1])["token"][0] + return True + + async def _get_sso_url( + self, + conn: SnowflakeConnection, + authenticator: str, + service_name: str | None, + account: str, + callback_port: int, + user: str, + ) -> str: + """Gets SSO URL from Snowflake.""" + headers = { + HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON, + HTTP_HEADER_USER_AGENT: PYTHON_CONNECTOR_USER_AGENT, + } + if service_name: + headers[HTTP_HEADER_SERVICE_NAME] = service_name + + url = "/session/authenticator-request" + body = Auth.base_auth_data( + user, + account, + conn._rest._connection.application, + conn._rest._connection._internal_application_name, + conn._rest._connection._internal_application_version, + conn._rest._connection._ocsp_mode(), + conn._rest._connection.login_timeout, + conn._rest._connection._network_timeout, + ) + + body["data"]["AUTHENTICATOR"] = authenticator + body["data"]["BROWSER_MODE_REDIRECT_PORT"] = str(callback_port) + logger.debug( + "account=%s, authenticator=%s, user=%s", account, authenticator, user + ) + ret = await conn._rest._post_request( + url, + headers, + json.dumps(body), + timeout=conn._rest._connection.login_timeout, + socket_timeout=conn._rest._connection.login_timeout, + ) + if not ret["success"]: + await self._handle_failure(conn=conn, ret=ret) + data = ret["data"] + sso_url = data["ssoUrl"] + self._proof_key = data["proofKey"] + return sso_url diff --git a/src/snowflake/connector/errors.py b/src/snowflake/connector/errors.py index 9c262cc4b2..8926afddb0 100644 --- a/src/snowflake/connector/errors.py +++ b/src/snowflake/connector/errors.py @@ -336,10 +336,18 @@ def hand_to_other_handler( connection.messages.append((error_class, error_value)) if cursor is not None: cursor.messages.append((error_class, error_value)) - cursor.errorhandler(connection, cursor, error_class, error_value) + try: + cursor.errorhandler(connection, cursor, error_class, error_value) + except NotImplementedError: + # for async compatibility, check SNOW-1763096 and SNOW-1763103 + cursor._errorhandler(connection, cursor, error_class, error_value) return True elif connection is not None: - connection.errorhandler(connection, cursor, error_class, error_value) + try: + connection.errorhandler(connection, cursor, error_class, error_value) + except NotImplementedError: + # for async compatibility, check SNOW-1763096 and SNOW-1763103 + connection._errorhandler(connection, cursor, error_class, error_value) return True return False diff --git a/test/aiodep/unsupported_python_version.py b/test/aiodep/unsupported_python_version.py new file mode 100644 index 0000000000..2d34947f12 --- /dev/null +++ b/test/aiodep/unsupported_python_version.py @@ -0,0 +1,41 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import asyncio +import sys + +import snowflake.connector.aio + +assert ( + sys.version_info.major == 3 and sys.version_info.minor <= 9 +), "This test is only for Python 3.9 and lower" + + +CONNECTION_PARAMETERS = { + "account": "test", + "user": "test", + "password": "test", + "schema": "test", + "database": "test", + "protocol": "test", + "host": "test.snowflakecomputing.com", + "warehouse": "test", + "port": 443, + "role": "test", +} + + +async def main(): + try: + async with snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS): + pass + except Exception as exc: + assert isinstance( + exc, RuntimeError + ) and "Async Snowflake Python Connector requires Python 3.10+" in str( + exc + ), "should raise RuntimeError" + + +asyncio.run(main()) diff --git a/test/helpers.py b/test/helpers.py index 34cc309bb9..98f1db898a 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -5,20 +5,23 @@ from __future__ import annotations +import asyncio import base64 +import functools import math import os import random import secrets import time from typing import TYPE_CHECKING, Pattern, Sequence -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from snowflake.connector.compat import OK if TYPE_CHECKING: + import snowflake.connector.aio import snowflake.connector.connection try: @@ -41,6 +44,10 @@ from snowflake.connector.constants import QueryStatus except ImportError: QueryStatus = None +try: + import snowflake.connector.aio +except ImportError: + pass def create_mock_response(status_code: int) -> Mock: @@ -56,6 +63,16 @@ def create_mock_response(status_code: int) -> Mock: return mock_resp +def create_async_mock_response(status: int) -> AsyncMock: + async def _create_async_mock_response(url, *, status, **kwargs): + resp = AsyncMock(status=status) + resp.read.return_value = "success" if status == OK else "fail" + resp.status = status + return resp + + return functools.partial(_create_async_mock_response, status=status) + + def verify_log_tuple( module: str, level: int, @@ -112,6 +129,40 @@ def _wait_until_query_success( ) +async def _wait_while_query_running_async( + con: snowflake.connector.aio.SnowflakeConnection, + sfqid: str, + sleep_time: int, + dont_cache: bool = False, +) -> None: + """ + Checks if the provided still returns that it is still running, and if so, + sleeps for the specified time in a while loop. + """ + query_status = con._get_query_status if dont_cache else con.get_query_status + while con.is_still_running(await query_status(sfqid)): + await asyncio.sleep(sleep_time) + + +async def _wait_until_query_success_async( + con: snowflake.connector.aio.SnowflakeConnection, + sfqid: str, + num_checks: int, + sleep_per_check: int, +) -> None: + for _ in range(num_checks): + status = await con.get_query_status(sfqid) + if status == QueryStatus.SUCCESS: + break + await asyncio.sleep(sleep_per_check) + else: + pytest.fail( + "We should have broke out of wait loop for query success." + f"Query ID: {sfqid}" + f"Final query status: {status}" + ) + + def create_nanoarrow_pyarrow_iterator(input_data, use_table_iterator): # create nanoarrow based iterator return ( diff --git a/test/integ/aio/__init__.py b/test/integ/aio/__init__.py new file mode 100644 index 0000000000..ef416f64a0 --- /dev/null +++ b/test/integ/aio/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# diff --git a/test/integ/aio/conftest.py b/test/integ/aio/conftest.py new file mode 100644 index 0000000000..498aae3983 --- /dev/null +++ b/test/integ/aio/conftest.py @@ -0,0 +1,146 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from contextlib import asynccontextmanager +from test.integ.conftest import get_db_parameters, is_public_testaccount +from typing import AsyncContextManager, Callable, Generator + +import pytest + +from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio._telemetry import TelemetryClient +from snowflake.connector.connection import DefaultConverterClass +from snowflake.connector.telemetry import TelemetryData + + +class TelemetryCaptureHandlerAsync(TelemetryClient): + def __init__( + self, + real_telemetry: TelemetryClient, + propagate: bool = True, + ): + super().__init__(real_telemetry._rest) + self.records: list[TelemetryData] = [] + self._real_telemetry = real_telemetry + self._propagate = propagate + + async def add_log_to_batch(self, telemetry_data): + self.records.append(telemetry_data) + if self._propagate: + await super().add_log_to_batch(telemetry_data) + + async def send_batch(self): + self.records = [] + if self._propagate: + await super().send_batch() + + +class TelemetryCaptureFixtureAsync: + """Provides a way to capture Snowflake telemetry messages.""" + + @asynccontextmanager + async def patch_connection( + self, + con: SnowflakeConnection, + propagate: bool = True, + ) -> Generator[TelemetryCaptureHandlerAsync, None, None]: + original_telemetry = con._telemetry + new_telemetry = TelemetryCaptureHandlerAsync( + original_telemetry, + propagate, + ) + con._telemetry = new_telemetry + try: + yield new_telemetry + finally: + con._telemetry = original_telemetry + + +@pytest.fixture(scope="session") +def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync: + return TelemetryCaptureFixtureAsync() + + +async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection: + """Creates a connection using the parameters defined in parameters.py. + + You can select from the different connections by supplying the appropriate + connection_name parameter and then anything else supplied will overwrite the values + from parameters.py.
+ """ + ret = get_db_parameters(connection_name) + ret.update(kwargs) + connection = SnowflakeConnection(**ret) + await connection.connect() + return connection + + +@asynccontextmanager +async def db( + connection_name: str = "default", + **kwargs, +) -> Generator[SnowflakeConnection, None, None]: + if not kwargs.get("timezone"): + kwargs["timezone"] = "UTC" + if not kwargs.get("converter_class"): + kwargs["converter_class"] = DefaultConverterClass() + cnx = await create_connection(connection_name, **kwargs) + try: + yield cnx + finally: + await cnx.close() + + +@asynccontextmanager +async def negative_db( + connection_name: str = "default", + **kwargs, +) -> Generator[SnowflakeConnection, None, None]: + if not kwargs.get("timezone"): + kwargs["timezone"] = "UTC" + if not kwargs.get("converter_class"): + kwargs["converter_class"] = DefaultConverterClass() + cnx = await create_connection(connection_name, **kwargs) + if not is_public_testaccount(): + await cnx.cursor().execute("alter session set SUPPRESS_INCIDENT_DUMPS=true") + try: + yield cnx + finally: + await cnx.close() + + +@pytest.fixture +def conn_cnx(): + return db + + +@pytest.fixture() +async def conn_testaccount() -> SnowflakeConnection: + connection = await create_connection("default") + yield connection + await connection.close() + + +@pytest.fixture() +def negative_conn_cnx() -> Callable[..., AsyncContextManager[SnowflakeConnection]]: + """Use this if an incident is expected and we don't want GS to create a dump file about the incident.""" + return negative_db + + +@pytest.fixture() +async def aio_connection(db_parameters): + cnx = SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + warehouse=db_parameters["warehouse"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + yield cnx + await cnx.close() diff --git a/test/integ/aio/lambda/__init__.py b/test/integ/aio/lambda/__init__.py new file mode 100644 index 0000000000..ef416f64a0 --- /dev/null +++ b/test/integ/aio/lambda/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/test/integ/aio/lambda/test_basic_query_async.py b/test/integ/aio/lambda/test_basic_query_async.py new file mode 100644 index 0000000000..1f34541269 --- /dev/null +++ b/test/integ/aio/lambda/test_basic_query_async.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + + +async def test_connection(conn_cnx): + """Test basic connection.""" + async with conn_cnx() as cnx: + cur = cnx.cursor() + result = await (await cur.execute("select 1;")).fetchall() + assert result == [(1,)] + + +async def test_large_resultset(conn_cnx): + """Test large resultset.""" + async with conn_cnx() as cnx: + cur = cnx.cursor() + result = await ( + await cur.execute( + "select seq8(), randstr(1000, random()) from table(generator(rowcount=>10000));" + ) + ).fetchall() + assert len(result) == 10000 diff --git a/test/integ/aio/pandas/__init__.py b/test/integ/aio/pandas/__init__.py new file mode 100644 index 0000000000..ef416f64a0 --- /dev/null +++ b/test/integ/aio/pandas/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# diff --git a/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py b/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py new file mode 100644 index 0000000000..8ac2ddbee6 --- /dev/null +++ b/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py @@ -0,0 +1,80 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import datetime +import random +from typing import Callable + +import pytest + +try: + from snowflake.connector.options import installed_pandas +except ImportError: + installed_pandas = False + +try: + import snowflake.connector.nanoarrow_arrow_iterator # NOQA + + no_arrow_iterator_ext = False +except ImportError: + no_arrow_iterator_ext = True + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas option is not installed.", +) +@pytest.mark.parametrize("timestamp_type", ("TZ", "LTZ", "NTZ")) +async def test_iterate_over_timestamp_chunk(conn_cnx, timestamp_type): + seed = datetime.datetime.now().timestamp() + row_numbers = 10 + random.seed(seed) + + # Generate random test data + def generator_test_data(scale: int) -> Callable[[], int]: + def generate_test_data() -> int: + nonlocal scale + epoch = random.randint(-100_355_968, 2_534_023_007) + frac = random.randint(0, 10**scale - 1) + if scale == 8: + frac *= 10 ** (9 - scale) + scale = 9 + return int(f"{epoch}{str(frac).rjust(scale, '0')}") + + return generate_test_data + + test_generators = [generator_test_data(i) for i in range(10)] + test_data = [[g() for g in test_generators] for _ in range(row_numbers)] + + async with conn_cnx( + session_parameters={ + "PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "ARROW_FORCE", + "TIMESTAMP_TZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 TZHTZM", + "TIMESTAMP_LTZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 TZHTZM", + "TIMESTAMP_NTZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 ", + } + ) as conn: + async with conn.cursor() as cur: + results = await ( + await cur.execute( + "select " + + ", ".join( + f"to_timestamp_{timestamp_type}(${s + 1}, {s if s != 8 else 9}) c_{s}" + for s in range(10) + ) + + ", " + + ", ".join(f"c_{i}::varchar" for i in range(10)) + + f" from values {', '.join(str(tuple(e)) for e in test_data)}" + ) + ).fetch_arrow_all() + retrieved_results = [ + list(map(lambda e: e.as_py().strftime("%Y-%m-%d %H:%M:%S.%f %z"), line)) + for line in list(results)[:10] + ] + retrieved_strigs = [ + list(map(lambda e: e.as_py().replace("Z", "+0000"), line)) + for line in list(results)[10:] + ] + + assert retrieved_results == retrieved_strigs diff --git a/test/integ/aio/pandas/test_arrow_pandas_async.py b/test/integ/aio/pandas/test_arrow_pandas_async.py new file mode 100644 index 0000000000..dce55241b0 --- /dev/null +++ b/test/integ/aio/pandas/test_arrow_pandas_async.py @@ -0,0 +1,1525 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import decimal +import itertools +import random +import time +from datetime import datetime +from decimal import Decimal +from enum import Enum +from unittest import mock + +import numpy +import pytest +import pytz +from numpy.testing import assert_equal + +try: + from snowflake.connector.constants import ( + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, + IterUnit, + ) +except ImportError: + # This is because of olddriver tests + class IterUnit(Enum): + ROW_UNIT = "row" + TABLE_UNIT = "table" + + +try: + from snowflake.connector.options import installed_pandas, pandas, pyarrow +except ImportError: + installed_pandas = False + pandas = None + pyarrow = None + +try: + from snowflake.connector.nanoarrow_arrow_iterator import PyArrowIterator # NOQA + + no_arrow_iterator_ext = False +except ImportError: + no_arrow_iterator_ext = True + +SQL_ENABLE_ARROW = "alter session set python_connector_query_result_format='ARROW';" + +EPSILON = 1e-8 + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_num_one(conn_cnx): + print("Test fetching one single dataframe") + row_count = 50000 + col_count = 2 + random_seed = get_random_seed() + sql_exec = ( + f"select seq4() as c1, uniform(1, 10, random({random_seed})) as c2 from " + f"table(generator(rowcount=>{row_count})) order by c1, c2" + ) + await fetch_pandas(conn_cnx, sql_exec, row_count, col_count, "one") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_tinyint(conn_cnx): + cases = ["NULL", 0.11, -0.11, "NULL", 1.27, -1.28, "NULL"] + table = "test_arrow_tiny_int" + column = "(a number(5,2))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_smallint(conn_cnx): + cases = ["NULL", 0, 0.11, -0.11, "NULL", 32.767, -32.768, "NULL"] + table = "test_arrow_small_int" + column = "(a number(5,3))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_int(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + 0.123456789, + -0.123456789, + 2.147483647, + -2.147483648, + "NULL", + ] + table = "test_arrow_int" + column = "(a number(10,9))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_bigint(conn_cnx): + cases = [ 
+ "NULL", + 0, + "NULL", + "1.23456789E-10", + "-1.23456789E-10", + "2.147483647E-9", + "-2.147483647E-9", + "-1e-9", + "1e-9", + "1e-8", + "-1e-8", + "NULL", + ] + table = "test_arrow_big_int" + column = "(a number(38,18))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", epsilon=EPSILON) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_decimal(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "10000000000000000000000000000000000000", + "12345678901234567890123456789012345678", + "99999999999999999999999999999999999999", + "-1000000000000000000000000000000000000", + "-2345678901234567890123456789012345678", + "-9999999999999999999999999999999999999", + "NULL", + ] + table = "test_arrow_decimal" + column = "(a number(38,0))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="decimal") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_decimal(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "1.0000000000000000000000000000000000000", + "1.2345678901234567890123456789012345678", + "9.9999999999999999999999999999999999999", + "-1.000000000000000000000000000000000000", + "-2.345678901234567890123456789012345678", + "-9.999999999999999999999999999999999999", + "NULL", + ] + table = "test_arrow_decimal" + column = "(a number(38,37))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="decimal") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_decimal_SNOW_133561(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "1.2345", + "2.1001", + "2.2001", + "2.3001", + "2.3456", + "-9.999", + "-1.000", + "-3.4567", + "3.4567", + "4.5678", + "5.6789", + "-0.0012", + "NULL", + ] + table = "test_scaled_decimal_SNOW_133561" + column = "(a number(38,10))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="float") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_boolean(conn_cnx): + cases = ["NULL", True, "NULL", False, True, True, "NULL", True, False, "NULL"] + table = "test_arrow_boolean" + column = "(a boolean)" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await 
validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_double(conn_cnx): + cases = [ + "NULL", + # SNOW-31249 + "-86.6426540296895", + "3.14159265359", + # SNOW-76269 + "1.7976931348623157E308", + "1.7E308", + "1.7976931348623151E308", + "-1.7976931348623151E308", + "-1.7E308", + "-1.7976931348623157E308", + "NULL", + ] + table = "test_arrow_double" + column = "(a double)" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_semi_struct(conn_cnx): + sql_text = """ + select array_construct(10, 20, 30), + array_construct(null, 'hello', 3::double, 4, 5), + array_construct(), + object_construct('a',1,'b','BBBB', 'c',null), + object_construct('Key_One', parse_json('NULL'), 'Key_Two', null, 'Key_Three', 'null'), + to_variant(3.2), + parse_json('{ "a": null}'), + 100::variant; + """ + res = [ + "[\n" + " 10,\n" + " 20,\n" + " 30\n" + "]", + "[\n" + + " undefined,\n" + + ' "hello",\n' + + " 3.000000000000000e+00,\n" + + " 4,\n" + + " 5\n" + + "]", + "[]", + "{\n" + ' "a": 1,\n' + ' "b": "BBBB"\n' + "}", + "{\n" + ' "Key_One": null,\n' + ' "Key_Three": "null"\n' + "}", + "3.2", + "{\n" + ' "a": null\n' + "}", + "100", + ] + async with conn_cnx() as cnx_table: + # fetch dataframe with new arrow support + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql_text) + df_new = await cursor_table.fetch_pandas_all() + col_new = df_new.iloc[0] + for j, c_new in enumerate(col_new): + assert res[j] == c_new, ( + "{} column: original value is {}, new value is {}, " + "values are not equal".format(j, res[j], c_new) + ) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_date(conn_cnx): + cases = [ + "NULL", + "2017-01-01", + "2014-01-02", + "2014-01-02", + "1970-01-01", + "1970-01-01", + "NULL", + "1969-12-31", + "0200-02-27", + "NULL", + "0200-02-28", + # "0200-02-29", # day is out of range + # "0000-01-01", # year 0 is out of range + "0001-12-31", + "NULL", + ] + table = "test_arrow_date" + column = "(a date)" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="date") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize("scale", [i for i in range(10)]) +async def test_time(conn_cnx, scale): + cases = [ + "NULL", + "00:00:51", + "01:09:03.100000", + "02:23:23.120000", + "03:56:23.123000", + "04:56:53.123400", + "09:01:23.123450", + "11:03:29.123456", + # note: Python's max time precision is microsecond, rest of them will lose precision + # "15:31:23.1234567", + # 
"19:01:43.12345678", + # "23:59:59.99999999", + "NULL", + ] + table = "test_arrow_time" + column = f"(a time({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, cases, 1, "one", data_type="time", scale=scale + ) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize("scale", [i for i in range(10)]) +async def test_timestampntz(conn_cnx, scale): + cases = [ + "NULL", + "1970-01-01 00:00:00", + "1970-01-01 00:00:01", + "1970-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestampntz({scale}))" + + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, cases, 1, "one", data_type="timestamp", scale=scale + ) + await finish(conn, table) + + +@pytest.mark.parametrize( + "timestamp_str", + [ + "'1400-01-01 01:02:03.123456789'::timestamp as low_ts", + "'9999-01-01 01:02:03.123456789789'::timestamp as high_ts", + ], +) +async def test_timestampntz_raises_overflow(conn_cnx, timestamp_str): + async with conn_cnx() as conn: + r = await conn.cursor().execute(f"select {timestamp_str}") + with pytest.raises(OverflowError, match="overflows int64 range."): + await r.fetch_arrow_all() + + +async def test_timestampntz_down_scale(conn_cnx): + async with conn_cnx() as conn: + r = await conn.cursor().execute( + "select '1400-01-01 01:02:03.123456'::timestamp as low_ts, '9999-01-01 01:02:03.123456'::timestamp as high_ts" + ) + table = await r.fetch_arrow_all() + lower_dt = table[0][0].as_py() # type: datetime + assert ( + lower_dt.year, + lower_dt.month, + lower_dt.day, + lower_dt.hour, + lower_dt.minute, + lower_dt.second, + lower_dt.microsecond, + ) == (1400, 1, 1, 1, 2, 3, 123456) + higher_dt = table[1][0].as_py() + assert ( + higher_dt.year, + higher_dt.month, + higher_dt.day, + higher_dt.hour, + higher_dt.minute, + higher_dt.second, + higher_dt.microsecond, + ) == (9999, 1, 1, 1, 2, 3, 123456) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "scale, timezone", + itertools.product( + [i for i in range(10)], ["UTC", "America/New_York", "Australia/Sydney"] + ), +) +async def test_timestamptz(conn_cnx, scale, timezone): + cases = [ + "NULL", + "1971-01-01 00:00:00", + "1971-01-11 00:00:01", + "1971-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 
00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestamptz({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values, timezone=timezone) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, + sql_text, + cases, + 1, + "one", + data_type="timestamptz", + scale=scale, + timezone=timezone, + ) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "scale, timezone", + itertools.product( + [i for i in range(10)], ["UTC", "America/New_York", "Australia/Sydney"] + ), +) +async def test_timestampltz(conn_cnx, scale, timezone): + cases = [ + "NULL", + "1970-01-01 00:00:00", + "1970-01-01 00:00:01", + "1970-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestampltz({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values, timezone=timezone) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, + sql_text, + cases, + 1, + "one", + data_type="timestamp", + scale=scale, + timezone=timezone, + ) + await finish(conn, table) + + +@pytest.mark.skipolddriver +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_vector(conn_cnx, is_public_test): + if is_public_test: + pytest.xfail( + reason="This feature hasn't been rolled out for public Snowflake deployments yet." 
+ ) + tests = [ + ( + "vector(int,3)", + [ + "NULL", + "[1,2,3]::vector(int,3)", + ], + ["NULL", numpy.array([1, 2, 3])], + ), + ( + "vector(float,3)", + [ + "NULL", + "[1.3,2.4,3.5]::vector(float,3)", + ], + ["NULL", numpy.array([1.3, 2.4, 3.5], dtype=numpy.float32)], + ), + ] + for vector_type, cases, typed_cases in tests: + table = "test_arrow_vector" + column = f"(a {vector_type})" + values = [f"{i}, {c}" for i, c in enumerate(cases)] + async with conn_cnx() as conn: + await init_with_insert_select(conn, table, column, values) + # Test general fetches + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, typed_cases, 1, method="one", data_type=vector_type + ) + + # Test empty result sets + cur = conn.cursor() + await cur.execute(f"select a from {table} limit 0") + df = await cur.fetch_pandas_all() + assert len(df) == 0 + assert df.dtypes[0] == "object" + + await finish(conn, table) + + +async def validate_pandas( + cnx_table, + sql, + cases, + col_count, + method="one", + data_type="float", + epsilon=None, + scale=0, + timezone=None, +): + """Tests that parameters can be customized. + + Args: + cnx_table: Connection object. + sql: SQL command for execution. + cases: Test cases. + col_count: Number of columns in dataframe. + method: If method is 'batch', we fetch dataframes in batch. If method is 'one', we fetch a single dataframe + containing all data (Default value = 'one'). + data_type: Defines how to compare values (Default value = 'float'). + epsilon: For comparing double values (Default value = None). + scale: For comparing time values with scale (Default value = 0). + timezone: For comparing timestamp ltz (Default value = None). + """ + + row_count = len(cases) + assert col_count != 0, "# of columns should be larger than 0" + + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql) + + # build dataframe + total_rows, total_batches = 0, 0 + start_time = time.time() + + if method == "one": + df_new = await cursor_table.fetch_pandas_all() + total_rows = df_new.shape[0] + else: + async for df_new in await cursor_table.fetch_pandas_batches(): + total_rows += df_new.shape[0] + total_batches += 1 + end_time = time.time() + + print(f"new way (fetching {method}) took {end_time - start_time}s") + if method == "batch": + print(f"new way has # of batches : {total_batches}") + await cursor_table.close() + assert ( + total_rows == row_count + ), f"there should be {row_count} rows, but {total_rows} rows" + + # verify the correctness + # only do it when fetch one dataframe + if method == "one": + assert (row_count, col_count) == df_new.shape, ( + "the shape of old dataframe is {}, " + "the shape of new dataframe is {}, " + "shapes are not equal".format((row_count, col_count), df_new.shape) + ) + + for i in range(row_count): + for j in range(col_count): + c_new = df_new.iat[i, j] + if type(cases[i]) is str and cases[i] == "NULL": + assert c_new is None or pandas.isnull(c_new), ( + "{} row, {} column: original value is NULL, " + "new value is {}, values are not equal".format(i, j, c_new) + ) + else: + if data_type == "float": + c_case = float(cases[i]) + elif data_type == "decimal": + c_case = Decimal(cases[i]) + elif data_type == "date": + c_case = datetime.strptime(cases[i], "%Y-%m-%d").date() + elif data_type == "time": + time_str_len = 8 if scale == 0 else 9 + scale + c_case = cases[i].strip()[:time_str_len] + c_new = str(c_new).strip()[:time_str_len] + assert c_new == c_case, ( + "{} row, {} 
column: original value is {}, " + "new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + break + elif data_type.startswith("timestamp"): + time_str_len = 19 if scale == 0 else 20 + scale + if timezone: + c_case = pandas.Timestamp( + cases[i][:time_str_len], tz=timezone + ) + if data_type == "timestamptz": + c_case = c_case.tz_convert("UTC") + else: + c_case = pandas.Timestamp(cases[i][:time_str_len]) + assert c_case == c_new, ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + break + elif data_type.startswith("vector"): + assert numpy.array_equal(cases[i], c_new), ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + continue + else: + c_case = cases[i] + if epsilon is None: + assert c_case == c_new, ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + else: + assert abs(c_case - c_new) < epsilon, ( + "{} row, {} column: original value is {}, " + "new value is {}, epsilon is {} \ + values are not equal".format( + i, j, cases[i], c_new, epsilon + ) + ) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_num_batch(conn_cnx): + print("Test fetching dataframes in batch") + row_count = 1000000 + col_count = 2 + random_seed = get_random_seed() + sql_exec = ( + f"select seq4() as c1, uniform(1, 10, random({random_seed})) as c2 from " + f"table(generator(rowcount=>{row_count})) order by c1, c2" + ) + await fetch_pandas(conn_cnx, sql_exec, row_count, col_count, "batch") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "result_format", + ["pandas", "arrow"], +) +async def test_empty(conn_cnx, result_format): + print("Test fetch empty dataframe") + async with conn_cnx() as cnx: + cursor = cnx.cursor() + await cursor.execute(SQL_ENABLE_ARROW) + await cursor.execute( + "select seq4() as foo, seq4() as bar from table(generator(rowcount=>1)) limit 0" + ) + fetch_all_fn = getattr(cursor, f"fetch_{result_format}_all") + fetch_batches_fn = getattr(cursor, f"fetch_{result_format}_batches") + result = await fetch_all_fn() + if result_format == "pandas": + assert len(list(result)) == 2 + assert list(result)[0] == "FOO" + assert list(result)[1] == "BAR" + else: + assert result is None + + await cursor.execute( + "select seq4() as foo from table(generator(rowcount=>1)) limit 0" + ) + df_count = 0 + async for _ in await fetch_batches_fn(): + df_count += 1 + assert df_count == 0 + + +def get_random_seed(): + random.seed(datetime.now().timestamp()) + return random.randint(0, 10000) + + +async def fetch_pandas(conn_cnx, sql, row_count, col_count, method="one"): + """Tests that parameters can be customized. + + Args: + conn_cnx: Connection object. + sql: SQL command for execution. + row_count: Number of total rows combining all dataframes. + col_count: Number of columns in dataframe. + method: If method is 'batch', we fetch dataframes in batch. If method is 'one', we fetch a single dataframe + containing all data (Default value = 'one'). 
+ """ + assert row_count != 0, "# of rows should be larger than 0" + assert col_count != 0, "# of columns should be larger than 0" + + async with conn_cnx() as conn: + # fetch dataframe by fetching row by row + cursor_row = conn.cursor() + await cursor_row.execute(SQL_ENABLE_ARROW) + await cursor_row.execute(sql) + + # build dataframe + # actually its exec time would be different from `pandas.read_sql()` via sqlalchemy as most people use + # further perf test can be done separately + start_time = time.time() + rows = 0 + if method == "one": + df_old = pandas.DataFrame( + await cursor_row.fetchall(), + columns=[f"c{i}" for i in range(col_count)], + ) + else: + print("use fetchmany") + while True: + dat = await cursor_row.fetchmany(10000) + if not dat: + break + else: + df_old = pandas.DataFrame( + dat, columns=[f"c{i}" for i in range(col_count)] + ) + rows += df_old.shape[0] + end_time = time.time() + print(f"The original way took {end_time - start_time}s") + await cursor_row.close() + + # fetch dataframe with new arrow support + cursor_table = conn.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql) + + # build dataframe + total_rows, total_batches = 0, 0 + start_time = time.time() + if method == "one": + df_new = await cursor_table.fetch_pandas_all() + total_rows = df_new.shape[0] + else: + async for df_new in await cursor_table.fetch_pandas_batches(): + total_rows += df_new.shape[0] + total_batches += 1 + end_time = time.time() + print(f"new way (fetching {method}) took {end_time - start_time}s") + if method == "batch": + print(f"new way has # of batches : {total_batches}") + await cursor_table.close() + assert total_rows == row_count, "there should be {} rows, but {} rows".format( + row_count, total_rows + ) + + # verify the correctness + # only do it when fetch one dataframe + if method == "one": + assert ( + df_old.shape == df_new.shape + ), "the shape of old dataframe is {}, the shape of new dataframe is {}, \ + shapes are not equal".format( + df_old.shape, df_new.shape + ) + + for i in range(row_count): + col_old = df_old.iloc[i] + col_new = df_new.iloc[i] + for j, (c_old, c_new) in enumerate(zip(col_old, col_new)): + assert c_old == c_new, ( + f"{i} row, {j} column: old value is {c_old}, new value " + f"is {c_new} values are not equal" + ) + else: + assert ( + rows == total_rows + ), f"the number of rows are not equal {rows} vs {total_rows}" + + +async def init(json_cnx, table, column, values, timezone=None): + cursor_json = json_cnx.cursor() + if timezone is not None: + await cursor_json.execute(f"ALTER SESSION SET TIMEZONE = '{timezone}'") + column_with_seq = column[0] + "s number, " + column[1:] + await cursor_json.execute(f"create or replace table {table} {column_with_seq}") + await cursor_json.execute(f"insert into {table} values {values}") + + +async def init_with_insert_select(json_cnx, table, column, rows, timezone=None): + cursor_json = json_cnx.cursor() + if timezone is not None: + await cursor_json.execute(f"ALTER SESSION SET TIMEZONE = '{timezone}'") + column_with_seq = column[0] + "s number, " + column[1:] + await cursor_json.execute(f"create or replace table {table} {column_with_seq}") + for row in rows: + await cursor_json.execute(f"insert into {table} select {row}") + + +async def finish(json_cnx, table): + cursor_json = json_cnx.cursor() + await cursor_json.execute(f"drop table if exists {table};") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or 
pandas is missing.", +) +async def test_arrow_fetch_result_scan(conn_cnx): + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute("alter session set query_result_format='ARROW_FORCE'") + await cur.execute( + "alter session set python_connector_query_result_format='ARROW_FORCE'" + ) + res = await (await cur.execute("select 1, 2, 3")).fetch_pandas_all() + assert tuple(res) == ("1", "2", "3") + result_scan_res = await ( + await cur.execute(f"select * from table(result_scan('{cur.sfqid}'));") + ).fetch_pandas_all() + assert tuple(result_scan_res) == ("1", "2", "3") + + +@pytest.mark.parametrize("query_format", ("JSON", "ARROW")) +@pytest.mark.parametrize("resultscan_format", ("JSON", "ARROW")) +async def test_query_resultscan_combos(conn_cnx, query_format, resultscan_format): + if query_format == "JSON" and resultscan_format == "ARROW": + pytest.xfail("fix not yet released to test deployment") + async with conn_cnx() as cnx: + sfqid = None + results = None + scanned_results = None + async with cnx.cursor() as query_cur: + await query_cur.execute( + "alter session set python_connector_query_result_format='{}'".format( + query_format + ) + ) + await query_cur.execute( + "select seq8(), randstr(1000,random()) from table(generator(rowcount=>100))" + ) + sfqid = query_cur.sfqid + assert query_cur._query_result_format.upper() == query_format + if query_format == "JSON": + results = await query_cur.fetchall() + else: + results = await query_cur.fetch_pandas_all() + async with cnx.cursor() as resultscan_cur: + await resultscan_cur.execute( + "alter session set python_connector_query_result_format='{}'".format( + resultscan_format + ) + ) + await resultscan_cur.execute(f"select * from table(result_scan('{sfqid}'))") + if resultscan_format == "JSON": + scanned_results = await resultscan_cur.fetchall() + else: + scanned_results = await resultscan_cur.fetch_pandas_all() + assert resultscan_cur._query_result_format.upper() == resultscan_format + if isinstance(results, pandas.DataFrame): + results = [tuple(e) for e in results.values.tolist()] + if isinstance(scanned_results, pandas.DataFrame): + scanned_results = [tuple(e) for e in scanned_results.values.tolist()] + assert results == scanned_results + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + (False, numpy.float64), + pytest.param(True, decimal.Decimal, marks=pytest.mark.skipolddriver), + ], +) +async def test_number_fetchall_retrieve_type(conn_cnx, use_decimal, expected): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + result_df = await cur.fetch_pandas_all() + a_column = result_df["A"] + assert isinstance(a_column.values[0], expected), type(a_column.values[0]) + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + ( + False, + numpy.float64, + ), + pytest.param(True, decimal.Decimal, marks=pytest.mark.skipolddriver), + ], +) +async def test_number_fetchbatches_retrieve_type( + conn_cnx, use_decimal: bool, expected: type +): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + async for batch in await cur.fetch_pandas_batches(): + a_column = batch["A"] + assert isinstance(a_column.values[0], expected), type( + a_column.values[0] + ) + + +async def test_execute_async_and_fetch_pandas_batches(conn_cnx): + """Test get pandas in an asynchronous way""" + + async with conn_cnx() as cnx: 
+ async with cnx.cursor() as cur: + await cur.execute("select 1/2") + res_sync = await cur.fetch_pandas_batches() + + result = await cur.execute_async("select 1/2") + await cur.get_results_from_sfqid(result["queryId"]) + res_async = await cur.fetch_pandas_batches() + + assert res_sync is not None + assert res_async is not None + while True: + try: + r_sync = await res_sync.__anext__() + r_async = await res_async.__anext__() + assert r_sync.values == r_async.values + except StopAsyncIteration: + break + + +async def test_execute_async_and_fetch_arrow_batches(conn_cnx): + """Test fetching result of an asynchronous query as batches of arrow tables""" + + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute("select 1/2") + res_sync = await cur.fetch_arrow_batches() + + result = await cur.execute_async("select 1/2") + await cur.get_results_from_sfqid(result["queryId"]) + res_async = await cur.fetch_arrow_batches() + + assert res_sync is not None + assert res_async is not None + while True: + try: + r_sync = await res_sync.__anext__() + r_async = await res_async.__anext__() + assert r_sync == r_async + except StopAsyncIteration: + break + + +async def test_simple_async_pandas(conn_cnx): + """Simple test to that shows the most simple usage of fire and forget. + + This test also makes sure that wait_until_ready function's sleeping is tested and + that some fields are copied over correctly from the original query. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetch_pandas_all()) == 1 + assert cur.rowcount + assert cur.description + + +async def test_simple_async_arrow(conn_cnx): + """Simple test for async fetch_arrow_all""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetch_arrow_all()) == 1 + assert cur.rowcount + assert cur.description + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + ( + True, + decimal.Decimal, + ), + pytest.param(False, numpy.float64, marks=pytest.mark.xfail), + ], +) +async def test_number_iter_retrieve_type(conn_cnx, use_decimal: bool, expected: type): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + async for row in cur: + assert isinstance(row[0], expected), type(row[0]) + + +async def test_resultbatches_pandas_functionality(conn_cnx): + """Fetch ArrowResultBatches as pandas dataframes and check its result.""" + rowcount = 100000 + expected_df = pandas.DataFrame(data={"A": range(rowcount)}) + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"select seq4() a from table(generator(rowcount => {rowcount}));" + ) + assert cur._result_set.total_row_index() == rowcount + result_batches = await cur.get_result_batches() + assert (await cur.fetch_pandas_all()).index[-1] == rowcount - 1 + assert len(result_batches) > 1 + + iterables = [] + for b in result_batches: + iterables.append( + list(await b.create_iter(iter_unit=IterUnit.TABLE_UNIT, structure="arrow")) + ) + tables = itertools.chain.from_iterable(iterables) + final_df = pyarrow.concat_tables(tables).to_pandas() + assert numpy.array_equal(expected_df, final_df) + + 
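The async tests above all follow the same end-to-end pattern introduced by this changeset: construct an aio connection, await connect(), run a query on an async cursor, and await one of the fetch_* coroutines. A minimal standalone sketch of that pattern, using only the API exercised in these tests and purely illustrative connection parameters (account, user, password, etc. are placeholders), might look like this:

import asyncio

from snowflake.connector.aio import SnowflakeConnection


async def main() -> None:
    # Placeholder credentials -- replace with real account parameters.
    conn = SnowflakeConnection(
        account="my_account",
        user="my_user",
        password="my_password",
        warehouse="my_wh",
        database="my_db",
        schema="public",
        timezone="UTC",
    )
    await conn.connect()  # the aio connection is connected explicitly, as in the conftest above
    try:
        async with conn.cursor() as cur:
            await cur.execute(
                "select seq4() as foo from table(generator(rowcount=>10)) order by foo"
            )
            df = await cur.fetch_pandas_all()  # pandas.DataFrame with a FOO column
            print(df.shape)
    finally:
        await conn.close()


asyncio.run(main())

As in the batch tests, fetch_pandas_batches() and fetch_arrow_batches() are awaited the same way and their results are consumed with async for; all of the fetch_pandas_*/fetch_arrow_* calls assume pandas and pyarrow are installed, which is why the tests above are guarded on installed_pandas and the arrow iterator extension.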
+@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing. or no new telemetry defined - skipolddrive", +) +@pytest.mark.parametrize( + "fetch_method, expected_telemetry_type", + [ + ("one", "client_fetch_pandas_all"), # TelemetryField.PANDAS_FETCH_ALL + ("batch", "client_fetch_pandas_batches"), # TelemetryField.PANDAS_FETCH_BATCHES + ], +) +async def test_pandas_telemetry( + conn_cnx, capture_sf_telemetry_async, fetch_method, expected_telemetry_type +): + cases = ["NULL", 0.11, -0.11, "NULL", 1.27, -1.28, "NULL"] + table = "test_telemetry" + column = "(a number(5,2))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn, capture_sf_telemetry_async.patch_connection( + conn, False + ) as telemetry_test: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + + await validate_pandas( + conn, + sql_text, + cases, + 1, + fetch_method, + ) + + occurence = 0 + for t in telemetry_test.records: + if t.message["type"] == expected_telemetry_type: + occurence += 1 + assert occurence == 1 + + await finish(conn, table) + + +@pytest.mark.parametrize("result_format", ["pandas", "arrow"]) +async def test_batch_to_pandas_arrow(conn_cnx, result_format): + rowcount = 10 + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute(SQL_ENABLE_ARROW) + await cur.execute( + f"select seq4() as foo, seq4() as bar from table(generator(rowcount=>{rowcount})) order by foo asc" + ) + batches = await cur.get_result_batches() + assert len(batches) == 1 + batch = batches[0] + + # check that size, columns, and FOO column data is correct + if result_format == "pandas": + df = await batch.to_pandas() + assert type(df) is pandas.DataFrame + assert df.shape == (10, 2) + assert all(df.columns == ["FOO", "BAR"]) + assert list(df.FOO) == list(range(rowcount)) + elif result_format == "arrow": + arrow_table = await batch.to_arrow() + assert type(arrow_table) is pyarrow.Table + assert arrow_table.shape == (10, 2) + assert arrow_table.column_names == ["FOO", "BAR"] + assert arrow_table.to_pydict()["FOO"] == list(range(rowcount)) + + +@pytest.mark.internal +@pytest.mark.parametrize("enable_structured_types", [True, False]) +async def test_to_arrow_datatypes(enable_structured_types, conn_cnx): + expected_types = ( + pyarrow.int64(), + pyarrow.float64(), + pyarrow.string(), + pyarrow.date64(), + pyarrow.timestamp("ns"), + pyarrow.string(), + pyarrow.timestamp("ns"), + pyarrow.timestamp("ns"), + pyarrow.timestamp("ns"), + pyarrow.binary(), + pyarrow.time64("ns"), + pyarrow.bool_(), + pyarrow.string(), + pyarrow.string(), + pyarrow.list_(pyarrow.float64(), 5), + ) + + query = """ + select + 1 :: INTEGER as FIXED_type, + 2.0 :: FLOAT as REAL_type, + 'test' :: TEXT as TEXT_type, + '2024-02-28' :: DATE as DATE_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP as TIMESTAMP_type, + '{"foo": "bar"}' :: VARIANT as VARIANT_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_LTZ as TIMESTAMP_LTZ_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_TZ as TIMESTAMP_TZ_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_NTZ as TIMESTAMP_NTZ_type, + '0xAAAA' :: BINARY as BINARY_type, + '01:02:03.123456789' :: TIME as TIME_type, + true :: BOOLEAN as BOOLEAN_type, + TO_GEOGRAPHY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOGRAPHY_type, + TO_GEOMETRY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOMETRY_type, + [1,2,3,4,5] :: vector(float, 5) 
as VECTOR_type,
+        object_construct('k1', 1, 'k2', 2, 'k3', 3, 'k4', 4, 'k5', 5) :: map(varchar, int) as MAP_type,
+        object_construct('city', 'san jose', 'population', 0.05) :: object(city varchar, population float) as OBJECT_type,
+        [1.0, 3.1, 4.5] :: array(float) as ARRAY_type
+    WHERE 1=0
+    """
+
+    structured_params = {
+        "ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE",
+        "IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE",
+        "FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT",
+    }
+
+    async with conn_cnx() as cnx:
+        async with cnx.cursor() as cur:
+            await cur.execute(SQL_ENABLE_ARROW)
+            try:
+                if enable_structured_types:
+                    for param in structured_params:
+                        await cur.execute(f"alter session set {param}=true")
+                    expected_types += (
+                        pyarrow.map_(pyarrow.string(), pyarrow.int64()),
+                        pyarrow.struct(
+                            {"city": pyarrow.string(), "population": pyarrow.float64()}
+                        ),
+                        pyarrow.list_(pyarrow.float64()),
+                    )
+                else:
+                    expected_types += (
+                        pyarrow.string(),
+                        pyarrow.string(),
+                        pyarrow.string(),
+                    )
+                # Ensure an empty batch to use default typing
+                # Otherwise arrow will resize types to save space
+                await cur.execute(query)
+                batches = await cur.get_result_batches()
+                assert len(batches) == 1
+                batch = batches[0]
+                arrow_table = await batch.to_arrow()
+                for actual, expected in zip(arrow_table.schema, expected_types):
+                    assert (
+                        actual.type == expected
+                    ), f"Expected {actual.name} :: {actual.type} column to be of type {expected}"
+            finally:
+                if enable_structured_types:
+                    for param in structured_params:
+                        await cur.execute(f"alter session unset {param}")
+
+
+async def test_simple_arrow_fetch(conn_cnx):
+    rowcount = 250_000
+    async with conn_cnx() as cnx:
+        async with cnx.cursor() as cur:
+            await cur.execute(SQL_ENABLE_ARROW)
+            await cur.execute(
+                f"select seq4() as foo from table(generator(rowcount=>{rowcount})) order by foo asc"
+            )
+            arrow_table = await cur.fetch_arrow_all()
+            assert arrow_table.shape == (rowcount, 1)
+            assert arrow_table.to_pydict()["FOO"] == list(range(rowcount))
+
+            await cur.execute(
+                f"select seq4() as foo from table(generator(rowcount=>{rowcount})) order by foo asc"
+            )
+            assert (
+                len(await cur.get_result_batches()) > 1
+            )  # non-trivial number of batches
+
+            # the start and end points of each batch
+            lo, hi = 0, 0
+            async for table in await cur.fetch_arrow_batches():
+                assert type(table) is pyarrow.Table  # sanity type check
+
+                # check that data is correct
+                length = len(table)
+                hi += length
+                assert table.to_pydict()["FOO"] == list(range(lo, hi))
+                lo += length
+
+            assert lo == rowcount
+
+
+async def test_arrow_zero_rows(conn_cnx):
+    async with conn_cnx() as cnx:
+        async with cnx.cursor() as cur:
+            await cur.execute(SQL_ENABLE_ARROW)
+            await cur.execute("select 1::NUMBER(38,0) limit 0")
+            table = await cur.fetch_arrow_all(force_return_table=True)
+            # Snowflake will return an integer dtype with maximum bit-length if
+            # no rows are returned
+            assert table.schema[0].type == pyarrow.int64()
+            await cur.execute("select 1::NUMBER(38,0) limit 0")
+            # test default behavior
+            assert await cur.fetch_arrow_all(force_return_table=False) is None
+
+
+@pytest.mark.parametrize("fetch_fn_name", ["to_arrow", "to_pandas", "create_iter"])
+@pytest.mark.parametrize("pass_connection", [True, False])
+async def test_sessions_used(conn_cnx, fetch_fn_name, pass_connection):
+    rowcount = 250_000
+    async with conn_cnx() as cnx:
+        async with cnx.cursor() as cur:
+            await cur.execute(SQL_ENABLE_ARROW)
+            await cur.execute(
+                f"select seq1() from table(generator(rowcount=>{rowcount}))"
+            )
+            batches
= await cur.get_result_batches() + assert len(batches) > 1 + batch = batches[-1] + + connection = cnx if pass_connection else None + fetch_fn = getattr(batch, fetch_fn_name) + + # check that sessions are used when connection is supplied + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful._use_requests_session", + side_effect=cnx._rest._use_requests_session, + ) as get_session_mock: + await fetch_fn(connection=connection) + assert get_session_mock.call_count == (1 if pass_connection else 0) + + +def assert_dtype_equal(a, b): + """Pandas method of asserting the same numpy dtype of variables by computing hash.""" + assert_equal(a, b) + assert_equal( + hash(a), hash(b), "two equivalent types do not hash to the same value !" + ) + + +def assert_pandas_batch_types( + batch: pandas.DataFrame, expected_types: list[type] +) -> None: + assert batch.dtypes is not None + + pandas_dtypes = batch.dtypes + # pd.string is represented as an np.object + # np.dtype string is not the same as pd.string (python) + for pandas_dtype, expected_type in zip(pandas_dtypes, expected_types): + assert_dtype_equal(pandas_dtype.type, numpy.dtype(expected_type).type) + + +async def test_pandas_dtypes(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute( + "select 1::integer, 2.3::double, 'foo'::string, current_timestamp()::timestamp where 1=0" + ) + expected_types = [numpy.int64, float, object, numpy.datetime64] + assert_pandas_batch_types(await cur.fetch_pandas_all(), expected_types) + + batches = await cur.get_result_batches() + assert await batches[0].to_arrow() is not True + assert_pandas_batch_types(await batches[0].to_pandas(), expected_types) + + +async def test_timestamp_tz(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute("select '1990-01-04 10:00:00 +1100'::timestamp_tz as d") + res = await cur.fetchall() + assert res[0][0].tzinfo is not None + res_pd = await cur.fetch_pandas_all() + assert res_pd.D.dt.tz is pytz.UTC + res_pa = await cur.fetch_arrow_all() + assert res_pa.field("D").type.tz == "UTC" + + +async def test_arrow_number_to_decimal(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + }, + arrow_number_to_decimal=True, + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute("select -3.20 as num") + df = await cur.fetch_pandas_all() + val = df.NUM[0] + assert val == Decimal("-3.20") + assert isinstance(val, decimal.Decimal) + + +@pytest.mark.parametrize( + "timestamp_type", + [ + "TIMESTAMP_TZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_LTZ", + ], +) +async def test_time_interval_microsecond(conn_cnx, timestamp_type): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + res = await ( + await cur.execute( + f"SELECT TO_{timestamp_type}('2010-06-25 12:15:30.747000')+INTERVAL '8999999999999998 MICROSECONDS'" + ) + ).fetchone() + assert res[0].microsecond == 746998 + res = await ( + await cur.execute( + f"SELECT TO_{timestamp_type}('2010-06-25 12:15:30.747000')+INTERVAL '8999999999999999 MICROSECONDS'" + ) + ).fetchone() + assert res[0].microsecond == 746999 + + +async def test_fetch_with_pandas_nullable_types(conn_cnx): + # use several 
float values to test nullable types. Nullable types can preserve both nan and null in float + sql_text = """ + select 1.0::float, 'NaN'::float, Null::float; + """ + # https://arrow.apache.org/docs/python/pandas.html#nullable-types + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + expected_dtypes = pandas.Series( + [pandas.Float64Dtype(), pandas.Float64Dtype(), pandas.Float64Dtype()], + index=["1.0::FLOAT", "'NAN'::FLOAT", "NULL::FLOAT"], + ) + expected_df_to_string = """ 1.0::FLOAT 'NAN'::FLOAT NULL::FLOAT +0 1.0 NaN """ + async with conn_cnx() as cnx_table: + # fetch dataframe with new arrow support + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql_text) + # test fetch_pandas_batches + async for df in await cursor_table.fetch_pandas_batches( + types_mapper=dtype_mapping.get + ): + pandas._testing.assert_series_equal(df.dtypes, expected_dtypes) + print(df) + assert df.to_string() == expected_df_to_string + # test fetch_pandas_all + df = await cursor_table.fetch_pandas_all(types_mapper=dtype_mapping.get) + pandas._testing.assert_series_equal(df.dtypes, expected_dtypes) + assert df.to_string() == expected_df_to_string diff --git a/test/integ/aio/pandas/test_logging_async.py b/test/integ/aio/pandas/test_logging_async.py new file mode 100644 index 0000000000..9b35d11a8b --- /dev/null +++ b/test/integ/aio/pandas/test_logging_async.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import logging + + +async def test_rand_table_log(caplog, conn_cnx, db_parameters): + async with conn_cnx() as conn: + caplog.set_level(logging.DEBUG, "snowflake.connector") + + num_of_rows = 10 + async with conn.cursor() as cur: + await ( + await cur.execute( + "select randstr(abs(mod(random(), 100)), random()) from table(generator(rowcount => {}));".format( + num_of_rows + ) + ) + ).fetchall() + + # make assertions + has_batch_read = has_batch_size = has_chunk_info = has_batch_index = False + for record in caplog.records: + if "Batches read:" in record.msg: + has_batch_read = True + assert "arrow_iterator" in record.filename + assert "__cinit__" in record.funcName + + if "Arrow BatchSize:" in record.msg: + has_batch_size = True + assert "CArrowIterator.cpp" in record.filename + assert "CArrowIterator" in record.funcName + + if "Arrow chunk info:" in record.msg: + has_chunk_info = True + assert "CArrowChunkIterator.cpp" in record.filename + assert "CArrowChunkIterator" in record.funcName + + if "Current batch index:" in record.msg: + has_batch_index = True + assert "CArrowChunkIterator.cpp" in record.filename + assert "next" in record.funcName + + # each of these records appear at least once in records + assert has_batch_read and has_batch_size and has_chunk_info and has_batch_index diff --git a/test/integ/aio/sso/__init__.py b/test/integ/aio/sso/__init__.py new file mode 100644 index 0000000000..ef416f64a0 --- /dev/null +++ b/test/integ/aio/sso/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/test/integ/aio/sso/test_connection_manual_async.py b/test/integ/aio/sso/test_connection_manual_async.py new file mode 100644 index 0000000000..438283131c --- /dev/null +++ b/test/integ/aio/sso/test_connection_manual_async.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +# This test requires the SSO and Snowflake admin connection parameters. +# +# CONNECTION_PARAMETERS_SSO = { +# 'account': 'testaccount', +# 'user': 'qa@snowflakecomputing.com', +# 'protocol': 'http', +# 'host': 'testaccount.reg.snowflakecomputing.com', +# 'port': '8082', +# 'authenticator': 'externalbrowser', +# 'timezone': 'UTC', +# } +# +# CONNECTION_PARAMETERS_ADMIN = { ... Snowflake admin ... 
}
+import os
+import sys
+
+import pytest
+
+import snowflake.connector.aio
+from snowflake.connector.auth._auth import delete_temporary_credential
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+try:
+    from parameters import CONNECTION_PARAMETERS_SSO
+except ImportError:
+    CONNECTION_PARAMETERS_SSO = {}
+
+try:
+    from parameters import CONNECTION_PARAMETERS_ADMIN
+except ImportError:
+    CONNECTION_PARAMETERS_ADMIN = {}
+
+ID_TOKEN = "ID_TOKEN"
+
+
+@pytest.fixture
+async def token_validity_test_values(request):
+    async with snowflake.connector.aio.SnowflakeConnection(
+        **CONNECTION_PARAMETERS_ADMIN
+    ) as cnx:
+        await cnx.cursor().execute(
+            """
+ALTER SYSTEM SET
+    MASTER_TOKEN_VALIDITY=60,
+    SESSION_TOKEN_VALIDITY=5,
+    ID_TOKEN_VALIDITY=60
+"""
+        )
+        # ALLOW_UNPROTECTED_ID_TOKEN is going to be deprecated in the future
+        # cnx.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=true;")
+        await cnx.cursor().execute("alter account testaccount set ALLOW_ID_TOKEN=true;")
+        await cnx.cursor().execute(
+            "alter account testaccount set ID_TOKEN_FEATURE_ENABLED=true;"
+        )
+
+    async def fin():
+        async with snowflake.connector.aio.SnowflakeConnection(
+            **CONNECTION_PARAMETERS_ADMIN
+        ) as cnx:
+            await cnx.cursor().execute(
+                """
+ALTER SYSTEM SET
+    MASTER_TOKEN_VALIDITY=default,
+    SESSION_TOKEN_VALIDITY=default,
+    ID_TOKEN_VALIDITY=default
+"""
+            )
+
+    request.addfinalizer(fin)
+    return None
+
+
+@pytest.mark.skipif(
+    not (
+        CONNECTION_PARAMETERS_SSO
+        and CONNECTION_PARAMETERS_ADMIN
+        and delete_temporary_credential
+    ),
+    reason="SSO and ADMIN connection parameters must be provided.",
+)
+async def test_connect_externalbrowser(token_validity_test_values):
+    """SSO ID token cache tests. This test should only be run if the keyring optional dependency is installed.
+
+    To run this test, provide the SSO and admin connection parameters required by the skipif marker above.
+    The first connection will pop up a browser window once, but the remaining connections should not create popups.
+    """
+    delete_temporary_credential(
+        host=CONNECTION_PARAMETERS_SSO["host"],
+        user=CONNECTION_PARAMETERS_SSO["user"],
+        cred_type=ID_TOKEN,
+    )  # delete existing temporary credential
+    CONNECTION_PARAMETERS_SSO["client_store_temporary_credential"] = True
+
+    # change database and schema to non-default ones
+    print(
+        "[INFO] 1st connection gets the id token and stores it in the local cache (keychain/credential manager/cache file). "
+        "This pops up a browser for SSO login"
+    )
+    cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
+    await cnx.connect()
+    assert cnx.database == "TESTDB"
+    assert cnx.schema == "PUBLIC"
+    assert cnx.role == "SYSADMIN"
+    assert cnx.warehouse == "REGRESS"
+    ret = await (
+        await cnx.cursor().execute(
+            "select current_database(), current_schema(), "
+            "current_role(), current_warehouse()"
+        )
+    ).fetchall()
+    assert ret[0][0] == "TESTDB"
+    assert ret[0][1] == "PUBLIC"
+    assert ret[0][2] == "SYSADMIN"
+    assert ret[0][3] == "REGRESS"
+    await cnx.close()
+
+    print(
+        "[INFO] 2nd connection reads the local cache and uses the id token. "
+        "This should not pop up a browser."
+    )
+    CONNECTION_PARAMETERS_SSO["database"] = "testdb"
+    CONNECTION_PARAMETERS_SSO["schema"] = "testschema"
+    cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
+    await cnx.connect()
+    print(
+        "[INFO] Running a 10 seconds query. If the session expires in 10 "
+        "seconds, the query should renew the token in the middle, "
+        "and the current objects should be refreshed."
+ ) + await cnx.cursor().execute("select seq8() from table(generator(timelimit=>10))") + assert cnx.database == "TESTDB" + assert cnx.schema == "TESTSCHEMA" + assert cnx.role == "SYSADMIN" + assert cnx.warehouse == "REGRESS" + + print("[INFO] Running a 1 second query. ") + await cnx.cursor().execute("select seq8() from table(generator(timelimit=>1))") + assert cnx.database == "TESTDB" + assert cnx.schema == "TESTSCHEMA" + assert cnx.role == "SYSADMIN" + assert cnx.warehouse == "REGRESS" + + print( + "[INFO] Running a 90 seconds query. This pops up a browser in the " + "middle of the query." + ) + await cnx.cursor().execute("select seq8() from table(generator(timelimit=>90))") + assert cnx.database == "TESTDB" + assert cnx.schema == "TESTSCHEMA" + assert cnx.role == "SYSADMIN" + assert cnx.warehouse == "REGRESS" + + await cnx.close() + + # change database and schema again to ensure they are overridden + CONNECTION_PARAMETERS_SSO["database"] = "testdb" + CONNECTION_PARAMETERS_SSO["schema"] = "testschema" + cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO) + await cnx.connect() + assert cnx.database == "TESTDB" + assert cnx.schema == "TESTSCHEMA" + assert cnx.role == "SYSADMIN" + assert cnx.warehouse == "REGRESS" + await cnx.close() + + async with snowflake.connector.aio.SnowflakeConnection( + **CONNECTION_PARAMETERS_ADMIN + ) as cnx_admin: + # cnx_admin.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=false;") + await cnx_admin.cursor().execute( + "alter account testaccount set ALLOW_ID_TOKEN=false;" + ) + await cnx_admin.cursor().execute( + "alter account testaccount set ID_TOKEN_FEATURE_ENABLED=false;" + ) + print( + "[INFO] Login again with ALLOW_UNPROTECTED_ID_TOKEN unset. Please make sure this pops up the browser" + ) + cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO) + await cnx.connect() + await cnx.close() diff --git a/test/integ/aio/sso/test_unit_mfa_cache_async.py b/test/integ/aio/sso/test_unit_mfa_cache_async.py new file mode 100644 index 0000000000..288c33e69e --- /dev/null +++ b/test/integ/aio/sso/test_unit_mfa_cache_async.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import json +import os +from unittest.mock import Mock, patch + +import pytest + +import snowflake.connector.aio +from snowflake.connector.compat import IS_LINUX +from snowflake.connector.errors import DatabaseError + +try: + from snowflake.connector.compat import IS_MACOS +except ImportError: + import platform + + IS_MACOS = platform.system() == "Darwin" + +try: + import keyring # noqa + + from snowflake.connector.auth._auth import delete_temporary_credential +except ImportError: + delete_temporary_credential = None + +MFA_TOKEN = "MFATOKEN" + + +# Although this is an unit test, we put it under test/integ/sso, since it needs keyring package installed +@pytest.mark.skipif( + delete_temporary_credential is None, + reason="delete_temporary_credential is not available.", +) +@patch("snowflake.connector.aio._network.SnowflakeRestful._post_request") +async def test_mfa_cache(mockSnowflakeRestfulPostRequest): + """Connects with (username, pwd, mfa) mock.""" + os.environ["SF_TEMPORARY_CREDENTIAL_CACHE_DIR"] = os.getenv( + "WORKSPACE", os.path.expanduser("~") + ) + + LOCAL_CACHE = dict() + + async def mock_post_request(url, headers, json_body, **kwargs): + global mock_post_req_cnt + ret = None + body = json.loads(json_body) + if mock_post_req_cnt == 0: + # issue MFA token for a succeeded login + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + ret = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + "mfaToken": "MFA_TOKEN", + }, + } + elif mock_post_req_cnt == 2: + # check associated mfa token and issue a new mfa token + # note: Normally, backend doesn't issue a new mfa token in this case, we do it here only to test + # whether the driver can replace the old token when server provides a new token + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + assert body["data"]["TOKEN"] == "MFA_TOKEN" + ret = { + "success": True, + "message": None, + "data": { + "token": "NEW_TOKEN", + "masterToken": "NEW_MASTER_TOKEN", + "mfaToken": "NEW_MFA_TOKEN", + }, + } + elif mock_post_req_cnt == 4: + # check new mfa token + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + assert body["data"]["TOKEN"] == "NEW_MFA_TOKEN" + ret = { + "success": True, + "message": None, + "data": { + "token": "NEW_TOKEN", + "masterToken": "NEW_MASTER_TOKEN", + }, + } + elif mock_post_req_cnt == 6: + # mock a failed log in + ret = {"success": False, "message": None, "data": {}} + elif mock_post_req_cnt == 7: + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + assert "TOKEN" not in body["data"] + ret = { + "success": True, + "data": {"token": "TOKEN", "masterToken": "MASTER_TOKEN"}, + } + elif mock_post_req_cnt in [1, 3, 5, 8]: + # connection.close() + ret = {"success": True} + mock_post_req_cnt += 1 + return ret + + def mock_del_password(system, user): + LOCAL_CACHE.pop(system + user, None) + + def mock_set_password(system, user, pwd): + LOCAL_CACHE[system + user] = pwd + + def mock_get_password(system, user): + return LOCAL_CACHE.get(system + user, None) + + global mock_post_req_cnt + mock_post_req_cnt = 0 + + # POST requests mock + mockSnowflakeRestfulPostRequest.side_effect = mock_post_request + + async def test_body(conn_cfg): + delete_temporary_credential( + host=conn_cfg["host"], user=conn_cfg["user"], cred_type=MFA_TOKEN + ) + + # first connection, no mfa 
token cache + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + assert con._rest.token == "TOKEN" + assert con._rest.master_token == "MASTER_TOKEN" + assert con._rest.mfa_token == "MFA_TOKEN" + await con.close() + + # second connection that uses the mfa token issued for first connection to login + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + assert con._rest.token == "NEW_TOKEN" + assert con._rest.master_token == "NEW_MASTER_TOKEN" + assert con._rest.mfa_token == "NEW_MFA_TOKEN" + await con.close() + + # third connection which is expected to login with new mfa token + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + assert con._rest.mfa_token is None + await con.close() + + with pytest.raises(DatabaseError): + # A failed login will be forced by a mocked response for this connection + # Under authentication failed exception, mfa cache is expected to be cleaned up + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + + # no mfa cache token should be sent at this connection + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + await con.close() + + conn_cfg = { + "account": "testaccount", + "user": "testuser", + "password": "testpwd", + "authenticator": "username_password_mfa", + "host": "testaccount.snowflakecomputing.com", + } + if IS_LINUX: + conn_cfg["client_request_mfa_token"] = True + + if IS_MACOS: + with patch( + "keyring.delete_password", Mock(side_effect=mock_del_password) + ), patch("keyring.set_password", Mock(side_effect=mock_set_password)), patch( + "keyring.get_password", Mock(side_effect=mock_get_password) + ): + await test_body(conn_cfg) + else: + await test_body(conn_cfg) diff --git a/test/integ/aio/test_arrow_result_async.py b/test/integ/aio/test_arrow_result_async.py new file mode 100644 index 0000000000..f5788b2259 --- /dev/null +++ b/test/integ/aio/test_arrow_result_async.py @@ -0,0 +1,1090 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import base64 +import json +import logging +import random +import re +from contextlib import asynccontextmanager +from datetime import timedelta + +import numpy +import pytest + +import snowflake.connector.aio._cursor +from snowflake.connector.errors import OperationalError, ProgrammingError + +pytestmark = [ + pytest.mark.skipolddriver, # old test driver tests won't run this module +] + + +from test.integ.test_arrow_result import ( + DATATYPE_TEST_CONFIGURATIONS, + ICEBERG_CONFIG, + ICEBERG_STRUCTURED_REPRS, + ICEBERG_SUPPORTED, + ICEBERG_UNSUPPORTED_TYPES, + PANDAS_REPRS, + PANDAS_STRUCTURED_REPRS, + SEMI_STRUCTURED_REPRS, + STRUCTURED_TYPES_SUPPORTED, + dumps, + get_random_seed, + no_arrow_iterator_ext, + pandas_available, + random_string, + serialize, +) + + +async def datatype_verify(cur, data, deserialize): + rows = await cur.fetchall() + assert len(rows) == len(data), "Result should have same number of rows as examples" + for row, datum in zip(rows, data): + actual = json.loads(row[0]) if deserialize else row[0] + assert len(row) == 1, "Result should only have one column." + assert actual == datum, "Result values should match input examples." 
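+
+
+# NOTE: datatype_verify (above) and pandas_verify (below) compare fetched rows and
+# DataFrames against the example data imported from test.integ.test_arrow_result;
+# because this module exercises the asyncio connector, the cursor fetch calls are awaited.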
+ + +async def pandas_verify(cur, data, deserialize): + pdf = await cur.fetch_pandas_all() + assert len(pdf) == len(data), "Result should have same number of rows as examples" + for value, datum in zip(pdf.COL.to_list(), data): + if deserialize: + value = json.loads(value) + if isinstance(value, numpy.ndarray): + value = value.tolist() + + # Numpy nans have to be checked with isnan. nan != nan according to numpy + if isinstance(value, float) and numpy.isnan(value): + assert datum is None or numpy.isnan(datum), "nan values should return nan." + else: + if isinstance(value, dict): + value = { + k: v.tolist() if isinstance(v, numpy.ndarray) else v + for k, v in value.items() + } + assert ( + value == datum or value is datum + ), f"Result value {value} should match input example {datum}." + + +async def verify_datatypes( + conn_cnx, + query, + examples, + schema, + iceberg=False, + pandas=False, + deserialize=False, +): + table_name = f"arrow_datatype_test_verifaction_table_{random_string(5)}" + async with structured_type_wrapped_conn(conn_cnx) as conn: + try: + await conn.cursor().execute("alter session set use_cached_result=false") + iceberg_table, iceberg_config = ( + ("iceberg", ICEBERG_CONFIG) if iceberg else ("", "") + ) + await conn.cursor().execute( + f"create {iceberg_table} table if not exists {table_name} {schema} {iceberg_config}" + ) + await conn.cursor().execute(f"insert into {table_name} {query}") + cur = await conn.cursor().execute(f"select * from {table_name}") + if pandas: + await pandas_verify(cur, examples, deserialize) + else: + await datatype_verify(cur, examples, deserialize) + finally: + await conn.cursor().execute(f"drop table if exists {table_name}") + + +@asynccontextmanager +async def structured_type_wrapped_conn(conn_cnx): + parameters = {} + if STRUCTURED_TYPES_SUPPORTED: + parameters = { + "python_connector_query_result_format": "arrow", + "ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE": True, + "ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT": True, + "FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT": True, + "IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE": True, + } + + async with conn_cnx(session_parameters=parameters) as conn: + yield conn + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not ICEBERG_SUPPORTED, reason="Iceberg not supported in this environment." 
+) +@pytest.mark.parametrize("datatype", ICEBERG_UNSUPPORTED_TYPES) +async def test_iceberg_negative(datatype, conn_cnx): + table_name = f"arrow_datatype_test_verification_table_{random_string(5)}" + async with structured_type_wrapped_conn(conn_cnx) as conn: + try: + with pytest.raises(ProgrammingError): + await conn.cursor().execute( + f"create iceberg table if not exists {table_name} (col {datatype}) {ICEBERG_CONFIG}" + ) + finally: + await conn.cursor().execute(f"drop table if exists {table_name}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "datatype,examples,iceberg,pandas", DATATYPE_TEST_CONFIGURATIONS +) +async def test_datatypes(datatype, examples, iceberg, pandas, conn_cnx): + json_values = re.escape(json.dumps(examples, default=serialize)) + query = f""" + SELECT + value :: {datatype} as col + FROM + TABLE(FLATTEN(input => parse_json('{json_values}'))); + """ + if pandas: + examples = PANDAS_REPRS.get(datatype, examples) + if datatype == "VARIANT": + examples = [dumps(ex) for ex in examples] + await verify_datatypes( + conn_cnx, query, examples, f"(col {datatype})", iceberg, pandas + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "datatype,examples,iceberg,pandas", DATATYPE_TEST_CONFIGURATIONS +) +async def test_array(datatype, examples, iceberg, pandas, conn_cnx): + json_values = re.escape(json.dumps(examples, default=serialize)) + + if STRUCTURED_TYPES_SUPPORTED: + col_type = f"array({datatype})" + if datatype == "VARIANT": + examples = [dumps(ex) if ex else ex for ex in examples] + elif pandas: + if iceberg: + examples = ICEBERG_STRUCTURED_REPRS.get(datatype, examples) + else: + examples = PANDAS_STRUCTURED_REPRS.get(datatype, examples) + else: + col_type = "array" + examples = SEMI_STRUCTURED_REPRS.get(datatype, examples) + + query = f""" + SELECT + parse_json('{json_values}') :: {col_type} as col + """ + await verify_datatypes( + conn_cnx, + query, + (examples,), + f"(col {col_type})", + iceberg, + pandas, + not STRUCTURED_TYPES_SUPPORTED, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not STRUCTURED_TYPES_SUPPORTED, reason="Testing structured type feature." 
+) +async def test_structured_type_binds(conn_cnx): + original_style = snowflake.connector.paramstyle + snowflake.connector.paramstyle = "qmark" + data = ( + 1, + [True, False, True], + {"k1": 1, "k2": 2, "k3": 3, "k4": 4, "k5": 5}, + {"city": "san jose", "population": 0.05}, + [1.0, 3.1, 4.5], + ) + json_data = [json.dumps(d) for d in data] + schema = "(num number, arr_b array(boolean), map map(varchar, int), obj object(city varchar, population float), arr_f array(float))" + table_name = f"arrow_structured_type_binds_test_{random_string(5)}" + async with structured_type_wrapped_conn(conn_cnx) as conn: + try: + await conn.cursor().execute("alter session set enable_bind_stage_v2=Enable") + await conn.cursor().execute( + f"create table if not exists {table_name} {schema}" + ) + await conn.cursor().execute( + f"insert into {table_name} select ?, ?, ?, ?, ?", json_data + ) + result = await ( + await conn.cursor().execute(f"select * from {table_name}") + ).fetchall() + assert result[0] == data + + # Binds don't work with values statement yet + with pytest.raises(ProgrammingError): + await conn.cursor().execute( + f"insert into {table_name} values (?, ?, ?, ?, ?)", json_data + ) + finally: + snowflake.connector.paramstyle = original_style + await conn.cursor().execute(f"drop table if exists {table_name}") + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not STRUCTURED_TYPES_SUPPORTED, reason="map type not supported in this environment" +) +@pytest.mark.parametrize("key_type", ["varchar", "number"]) +@pytest.mark.parametrize( + "datatype,examples,iceberg,pandas", DATATYPE_TEST_CONFIGURATIONS +) +async def test_map(key_type, datatype, examples, iceberg, pandas, conn_cnx): + if iceberg and key_type == "number": + pytest.skip("Iceberg does not support number keys.") + data = {str(i) if key_type == "varchar" else i: ex for i, ex in enumerate(examples)} + json_string = re.escape(json.dumps(data, default=serialize)) + + if datatype == "VARIANT": + data = {k: dumps(v) if v else v for k, v in data.items()} + if pandas: + data = list(data.items()) + elif pandas: + examples = PANDAS_STRUCTURED_REPRS.get(datatype, examples) + data = [ + (str(i) if key_type == "varchar" else i, ex) + for i, ex in enumerate(examples) + ] + + query = f""" + SELECT + parse_json('{json_string}') :: map({key_type}, {datatype}) as col + """ + + if iceberg and pandas and datatype in ICEBERG_STRUCTURED_REPRS: + with pytest.raises(ValueError): + # SNOW-1320508: Timestamp types nested in maps currently cause an exception for iceberg tables + await verify_datatypes( + conn_cnx, + query, + [data], + f"(col map({key_type}, {datatype}))", + iceberg, + pandas, + ) + else: + await verify_datatypes( + conn_cnx, + query, + [data], + f"(col map({key_type}, {datatype}))", + iceberg, + pandas, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "datatype,examples,iceberg,pandas", DATATYPE_TEST_CONFIGURATIONS +) +async def test_object(datatype, examples, iceberg, pandas, conn_cnx): + fields = [f"{datatype}_{i}" for i in range(len(examples))] + data = {k: v for k, v in zip(fields, examples)} + json_string = re.escape(json.dumps(data, default=serialize)) + + if STRUCTURED_TYPES_SUPPORTED: + schema = ", ".join(f"{field} {datatype}" for field in fields) + col_type = f"object({schema})" + if datatype == "VARIANT": + examples = [dumps(s) if s else s for s in examples] + elif pandas: + if iceberg: + examples = ICEBERG_STRUCTURED_REPRS.get(datatype, examples) + else: + examples = PANDAS_STRUCTURED_REPRS.get(datatype, examples) + else: + col_type 
= "object" + examples = SEMI_STRUCTURED_REPRS.get(datatype, examples) + expected_data = {k: v for k, v in zip(fields, examples)} + + query = f""" + SELECT + parse_json('{json_string}') :: {col_type} as col + """ + + if iceberg and pandas and datatype in ICEBERG_STRUCTURED_REPRS: + with pytest.raises(ValueError): + # SNOW-1320508: Timestamp types nested in objects currently cause an exception for iceberg tables + await verify_datatypes( + conn_cnx, + query, + [expected_data], + f"(col {col_type})", + iceberg, + pandas, + ) + else: + await verify_datatypes( + conn_cnx, + query, + [expected_data], + f"(col {col_type})", + iceberg, + pandas, + not STRUCTURED_TYPES_SUPPORTED, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not STRUCTURED_TYPES_SUPPORTED, reason="map type not supported in this environment" +) +@pytest.mark.parametrize("pandas", [True, False] if pandas_available else [False]) +@pytest.mark.parametrize("iceberg", [True, False]) +async def test_nested_types(conn_cnx, iceberg, pandas): + data = {"child": [{"key1": {"struct_field": "value"}}]} + json_string = re.escape(json.dumps(data, default=serialize)) + query = f""" + SELECT + parse_json('{json_string}') :: object(child array(map (varchar, object(struct_field varchar)))) as col + """ + if pandas: + data = { + "child": [ + [ + ("key1", {"struct_field": "value"}), + ] + ] + } + await verify_datatypes( + conn_cnx, + query, + [data], + "(col object(child array(map (varchar, object(struct_field varchar)))))", + iceberg, + pandas, + ) + + +@pytest.mark.asyncio +async def test_select_tinyint(conn_cnx): + cases = [0, 1, -1, 127, -128] + table = "test_arrow_tiny_int" + column = "(a int)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_scaled_tinyint(conn_cnx): + cases = [0.0, 0.11, -0.11, 1.27, -1.28] + table = "test_arrow_tiny_int" + column = "(a number(5,3))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_smallint(conn_cnx): + cases = [0, 1, -1, 127, -128, 128, -129, 32767, -32768] + table = "test_arrow_small_int" + column = "(a int)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_scaled_smallint(conn_cnx): + cases = ["0", "2.0", "-2.0", "32.767", "-32.768"] + table = "test_arrow_small_int" + column = "(a number(5,3))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a 
from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_int(conn_cnx): + cases = [ + 0, + 1, + -1, + 127, + -128, + 128, + -129, + 32767, + -32768, + 32768, + -32769, + 2147483647, + -2147483648, + ] + table = "test_arrow_int" + column = "(a int)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_scaled_int(conn_cnx): + cases = ["0", "0.123456789", "-0.123456789", "0.2147483647", "-0.2147483647"] + table = "test_arrow_int" + column = "(a number(10,9))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_bigint(conn_cnx): + cases = [ + 0, + 1, + -1, + 127, + -128, + 128, + -129, + 32767, + -32768, + 32768, + -32769, + 2147483647, + -2147483648, + 2147483648, + -2147483649, + 9223372036854775807, + -9223372036854775808, + ] + table = "test_arrow_bigint" + column = "(a int)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_scaled_bigint(conn_cnx): + cases = [ + "0", + "0.000000000000000001", + "-0.000000000000000001", + "0.000000000000000127", + "-0.000000000000000128", + "0.000000000000000128", + "-0.000000000000000129", + "0.000000000000032767", + "-0.000000000000032768", + "0.000000000000032768", + "-0.000000000000032769", + "0.000000002147483647", + "-0.000000002147483648", + "0.000000002147483648", + "-0.000000002147483649", + "9.223372036854775807", + "-9.223372036854775808", + ] + table = "test_arrow_bigint" + column = "(a number(38,18))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_decimal(conn_cnx): + cases = [ + "10000000000000000000000000000000000000", + "12345678901234567890123456789012345678", + "99999999999999999999999999999999999999", + ] + table = "test_arrow_decimal" + column = "(a number(38,0))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + 
row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_scaled_decimal(conn_cnx): + cases = [ + "0", + "0.000000000000000001", + "-0.000000000000000001", + "0.000000000000000127", + "-0.000000000000000128", + "0.000000000000000128", + "-0.000000000000000129", + "0.000000000000032767", + "-0.000000000000032768", + "0.000000000000032768", + "-0.000000000000032769", + "0.000000002147483647", + "-0.000000002147483648", + "0.000000002147483648", + "-0.000000002147483649", + "9.223372036854775807", + "-9.223372036854775808", + ] + table = "test_arrow_decimal" + column = "(a number(38,37))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_large_scaled_decimal(conn_cnx): + cases = [ + "1.0000000000000000000000000000000000000", + "1.2345678901234567890123456789012345678", + "9.9999999999999999999999999999999999999", + ] + table = "test_arrow_decimal" + column = "(a number(38,37))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_scaled_decimal_SNOW_133561(conn_cnx): + cases = [ + "0", + "1.2345", + "2.3456", + "-9.999", + "-1.000", + "-3.4567", + "3.4567", + "4.5678", + "5.6789", + "NULL", + ] + table = "test_scaled_decimal_SNOW_133561" + column = "(a number(38,10))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("num", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_boolean(conn_cnx): + cases = ["true", "false", "true"] + table = "test_arrow_boolean" + column = "(a boolean)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("boolean", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.skipif( + no_arrow_iterator_ext, reason="arrow_iterator extension is not built." 
+) +@pytest.mark.asyncio +async def test_select_double_precision(conn_cnx): + cases = [ + # SNOW-31249 + "-86.6426540296895", + "3.14159265359", + # SNOW-76269 + "1.7976931348623157e+308", + "1.7e+308", + "1.7976931348623151e+308", + "-1.7976931348623151e+308", + "-1.7e+308", + "-1.7976931348623157e+308", + ] + table = "test_arrow_double" + column = "(a double)" + values = "(" + "),(".join([f"{i}, {c}" for i, c in enumerate(cases)]) + ")" + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + col_count = 1 + await iterate_over_test_chunk( + "float", conn_cnx, sql_text, row_count, col_count, expected=cases + ) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_semi_structure(conn_cnx): + sql_text = """select array_construct(10, 20, 30), + array_construct(null, 'hello', 3::double, 4, 5), + array_construct(), + object_construct('a',1,'b','BBBB', 'c',null), + object_construct('Key_One', parse_json('NULL'), 'Key_Two', null, 'Key_Three', 'null'), + to_variant(3.2), + parse_json('{ "a": null}'), + 100::variant; + """ + row_count = 1 + col_count = 8 + await iterate_over_test_chunk("struct", conn_cnx, sql_text, row_count, col_count) + + +@pytest.mark.asyncio +async def test_select_vector(conn_cnx, is_public_test): + if is_public_test: + pytest.xfail( + reason="This feature hasn't been rolled out for public Snowflake deployments yet." + ) + + sql_text = """select [1,2,3]::vector(int,3), + [1.1,2.2]::vector(float,2), + NULL::vector(int,2), + NULL::vector(float,3); + """ + row_count = 1 + col_count = 4 + await iterate_over_test_chunk("vector", conn_cnx, sql_text, row_count, col_count) + + +@pytest.mark.asyncio +async def test_select_time(conn_cnx): + for scale in range(10): + await select_time_with_scale(conn_cnx, scale) + + +async def select_time_with_scale(conn_cnx, scale): + cases = [ + "00:01:23", + "00:01:23.1", + "00:01:23.12", + "00:01:23.123", + "00:01:23.1234", + "00:01:23.12345", + "00:01:23.123456", + "00:01:23.1234567", + "00:01:23.12345678", + "00:01:23.123456789", + ] + table = "test_arrow_time" + column = f"(a time({scale}))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, '{c}'" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("time", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_date(conn_cnx): + cases = [ + "2016-07-23", + "1970-01-01", + "1969-12-31", + "0001-01-01", + "9999-12-31", + ] + table = "test_arrow_time" + column = "(a date)" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, '{c}'" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + await iterate_over_test_chunk("date", conn_cnx, sql_text, row_count, col_count) + await finish(conn_cnx, table) + + +@pytest.mark.parametrize("scale", range(10)) +@pytest.mark.parametrize("type", ["timestampntz", "timestampltz", "timestamptz"]) +@pytest.mark.asyncio +async def test_select_timestamp_with_scale(conn_cnx, scale, type): + cases = [ + "2017-01-01 12:00:00", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "2014-01-02 12:34:56.1", + 
"1969-12-31 23:59:59.000000001", + "1969-12-31 23:59:58.000000001", + "1969-11-30 23:58:58.000001001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + "0001-12-31 11:59:59.11", + ] + table = "test_arrow_timestamp" + column = f"(a {type}({scale}))" + values = ( + "(-1, NULL), (" + + "),(".join([f"{i}, '{c}'" for i, c in enumerate(cases)]) + + f"), ({len(cases)}, NULL)" + ) + await init(conn_cnx, table, column, values) + sql_text = f"select a from {table} order by s" + row_count = len(cases) + 2 + col_count = 1 + # TODO SNOW-534252 + await iterate_over_test_chunk( + type, + conn_cnx, + sql_text, + row_count, + col_count, + eps=timedelta(microseconds=1), + ) + await finish(conn_cnx, table) + + +@pytest.mark.asyncio +async def test_select_with_string(conn_cnx): + col_count = 2 + row_count = 50000 + random_seed = get_random_seed() + length = random.randint(1, 10) + sql_text = ( + "select seq4() as c1, randstr({}, random({})) as c2 from ".format( + length, random_seed + ) + + "table(generator(rowcount=>50000)) order by c1" + ) + await iterate_over_test_chunk("string", conn_cnx, sql_text, row_count, col_count) + + +@pytest.mark.asyncio +async def test_select_with_bool(conn_cnx): + col_count = 2 + row_count = 50000 + random_seed = get_random_seed() + sql_text = ( + "select seq4() as c1, as_boolean(uniform(0, 1, random({}))) as c2 from ".format( + random_seed + ) + + f"table(generator(rowcount=>{row_count})) order by c1" + ) + await iterate_over_test_chunk("bool", conn_cnx, sql_text, row_count, col_count) + + +@pytest.mark.asyncio +async def test_select_with_float(conn_cnx): + col_count = 2 + row_count = 50000 + random_seed = get_random_seed() + pow_val = random.randint(0, 10) + val_len = random.randint(0, 16) + # if we assign val_len a larger value like 20, then the precision difference between c++ and python will become + # very obvious so if we meet some error in this test in the future, please check that whether it is caused by + # different precision between python and c++ + val_range = random.randint(0, 10**val_len) + + sql_text = "select seq4() as c1, as_double(uniform({}, {}, random({})))/{} as c2 from ".format( + -val_range, val_range, random_seed, 10**pow_val + ) + "table(generator(rowcount=>{})) order by c1".format( + row_count + ) + await iterate_over_test_chunk( + "float", + conn_cnx, + sql_text, + row_count, + col_count, + eps=10 ** (-pow_val + 1), + ) + + +@pytest.mark.asyncio +async def test_select_with_empty_resultset(conn_cnx): + async with conn_cnx() as cnx: + cursor = cnx.cursor() + await cursor.execute("alter session set query_result_format='ARROW_FORCE'") + await cursor.execute( + "alter session set python_connector_query_result_format='ARROW_FORCE'" + ) + await cursor.execute( + "select seq4() from table(generator(rowcount=>100)) limit 0" + ) + + assert await cursor.fetchone() is None + + +@pytest.mark.asyncio +async def test_select_with_large_resultset(conn_cnx): + col_count = 5 + row_count = 1000000 + random_seed = get_random_seed() + + sql_text = ( + "select seq4() as c1, " + "uniform(-10000, 10000, random({})) as c2, " + "randstr(5, random({})) as c3, " + "randstr(10, random({})) as c4, " + "uniform(-100000, 100000, random({})) as c5 " + "from table(generator(rowcount=>{}))".format( + random_seed, random_seed, random_seed, random_seed, row_count + ) + ) + + await iterate_over_test_chunk( + "large_resultset", conn_cnx, sql_text, row_count, col_count + ) + + +@pytest.mark.asyncio +async def test_dict_cursor(conn_cnx): + 
async with conn_cnx() as cnx: + async with cnx.cursor(snowflake.connector.aio.DictCursor) as c: + await c.execute( + "alter session set python_connector_query_result_format='ARROW'" + ) + + # first test small result generated by GS + ret = await (await c.execute("select 1 as foo, 2 as bar")).fetchone() + assert ret["FOO"] == 1 + assert ret["BAR"] == 2 + + # test larger result set + row_index = 1 + async for row in await c.execute( + "select row_number() over (order by val asc) as foo, " + "row_number() over (order by val asc) as bar " + "from (select seq4() as val from table(generator(rowcount=>10000)));" + ): + assert row["FOO"] == row_index + assert row["BAR"] == row_index + row_index += 1 + + +@pytest.mark.asyncio +async def test_fetch_as_numpy_val(conn_cnx): + async with conn_cnx(numpy=True) as cnx: + cursor = cnx.cursor() + await cursor.execute( + "alter session set python_connector_query_result_format='ARROW'" + ) + + val = await ( + await cursor.execute( + """ +select 1.23456::double, 1.3456::number(10, 4), 1234567::number(10, 0) +""" + ) + ).fetchone() + assert isinstance(val[0], numpy.float64) + assert val[0] == numpy.float64("1.23456") + assert isinstance(val[1], numpy.float64) + assert val[1] == numpy.float64("1.3456") + assert isinstance(val[2], numpy.int64) + assert val[2] == numpy.float64("1234567") + + val = await ( + await cursor.execute( + """ +select '2019-08-10'::date, '2019-01-02 12:34:56.1234'::timestamp_ntz(4), +'2019-01-02 12:34:56.123456789'::timestamp_ntz(9), '2019-01-02 12:34:56.123456789'::timestamp_ntz(8) +""" + ) + ).fetchone() + assert isinstance(val[0], numpy.datetime64) + assert val[0] == numpy.datetime64("2019-08-10") + assert isinstance(val[1], numpy.datetime64) + assert val[1] == numpy.datetime64("2019-01-02 12:34:56.1234") + assert isinstance(val[2], numpy.datetime64) + assert val[2] == numpy.datetime64("2019-01-02 12:34:56.123456789") + assert isinstance(val[3], numpy.datetime64) + assert val[3] == numpy.datetime64("2019-01-02 12:34:56.12345678") + + +async def iterate_over_test_chunk( + test_name, conn_cnx, sql_text, row_count, col_count, eps=None, expected=None +): + async with conn_cnx() as json_cnx: + async with conn_cnx() as arrow_cnx: + if expected is None: + cursor_json = json_cnx.cursor() + await cursor_json.execute( + "alter session set query_result_format='JSON'" + ) + await cursor_json.execute( + "alter session set python_connector_query_result_format='JSON'" + ) + await cursor_json.execute(sql_text) + + cursor_arrow = arrow_cnx.cursor() + await cursor_arrow.execute("alter session set use_cached_result=false") + await cursor_arrow.execute( + "alter session set query_result_format='ARROW_FORCE'" + ) + await cursor_arrow.execute( + "alter session set python_connector_query_result_format='ARROW_FORCE'" + ) + await cursor_arrow.execute(sql_text) + assert cursor_arrow._query_result_format == "arrow" + + if expected is None: + for _ in range(0, row_count): + json_res = await cursor_json.fetchone() + arrow_res = await cursor_arrow.fetchone() + for j in range(0, col_count): + if test_name == "float" and eps is not None: + assert abs(json_res[j] - arrow_res[j]) <= eps + elif ( + test_name == "timestampltz" + and json_res[j] is not None + and eps is not None + ): + assert abs(json_res[j] - arrow_res[j]) <= eps + elif test_name == "vector": + assert json_res[j] == pytest.approx(arrow_res[j]) + else: + assert json_res[j] == arrow_res[j] + else: + # only support single column for now + for i in range(0, row_count): + arrow_res = await 
cursor_arrow.fetchone() + assert str(arrow_res[0]) == expected[i] + + +@pytest.mark.parametrize("debug_arrow_chunk", [True, False]) +@pytest.mark.asyncio +async def test_arrow_bad_data(conn_cnx, caplog, debug_arrow_chunk): + with caplog.at_level(logging.DEBUG): + async with conn_cnx( + debug_arrow_chunk=debug_arrow_chunk + ) as arrow_cnx, arrow_cnx.cursor() as cursor: + await cursor.execute("select 1") + cursor._result_set.batches[0]._data = base64.b64encode(b"wrong_data") + with pytest.raises(OperationalError): + await cursor.fetchone() + expr = bool("arrow data can not be parsed" in caplog.text) + assert expr if debug_arrow_chunk else not expr + + +async def init(conn_cnx, table, column, values): + async with conn_cnx() as json_cnx: + cursor_json = json_cnx.cursor() + column_with_seq = column[0] + "s number, " + column[1:] + await cursor_json.execute(f"create or replace table {table} {column_with_seq}") + await cursor_json.execute(f"insert into {table} values {values}") + + +async def finish(conn_cnx, table): + async with conn_cnx() as json_cnx: + cursor_json = json_cnx.cursor() + await cursor_json.execute(f"drop table IF EXISTS {table};") diff --git a/test/integ/aio/test_async_async.py b/test/integ/aio/test_async_async.py new file mode 100644 index 0000000000..8dcdb936d6 --- /dev/null +++ b/test/integ/aio/test_async_async.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import logging + +import pytest + +from snowflake.connector import DatabaseError, ProgrammingError +from snowflake.connector.constants import QueryStatus + +# Mark all tests in this file to time out after 2 minutes to prevent hanging forever +pytestmark = pytest.mark.timeout(120) + + +async def test_simple_async(conn_cnx): + """Simple test to that shows the most simple usage of fire and forget. + + This test also makes sure that wait_until_ready function's sleeping is tested and + that some fields are copied over correctly from the original query. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetchall()) == 1 + assert cur.rowcount + assert cur.description + + +async def test_async_result_iteration(conn_cnx): + """Test yielding results of an async query. + + Ensures that wait_until_ready is also called in __iter__() via _prefetch_hook(). + """ + + async def result_generator(query): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async(query) + await cur.get_results_from_sfqid(cur.sfqid) + async for row in cur: + yield row + + gen = result_generator("select count(*) from table(generator(timeLimit => 5))") + assert await anext(gen) + with pytest.raises(StopAsyncIteration): + await anext(gen) + + +async def test_async_exec(conn_cnx): + """Tests whether simple async query execution works. + + Runs a query that takes a few seconds to finish and then totally closes connection + to Snowflake. Then waits enough time for that query to finish, opens a new connection + and fetches results. It also tests QueryStatus related functionality too. + + This test tends to hang longer than expected when the testing warehouse is overloaded. 
+ Manually looking at query history reveals that when a full GH actions + Jenkins test load hits one warehouse + it can be queued for 15 seconds, so for now we wait 5 seconds before checking and then we give it another 25 + seconds to finish. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + q_id = cur.sfqid + status = await con.get_query_status(q_id) + assert con.is_still_running(status) + await asyncio.sleep(5) + async with conn_cnx() as con: + async with con.cursor() as cur: + for _ in range(25): + # Check upto 15 times once a second to see if it's done + status = await con.get_query_status(q_id) + if status == QueryStatus.SUCCESS: + break + await asyncio.sleep(1) + else: + pytest.fail( + f"We should have broke out of this loop, final query status: {status}" + ) + status = await con.get_query_status_throw_if_error(q_id) + assert status == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(q_id) + assert len(await cur.fetchall()) == 1 + + +async def test_async_error(conn_cnx, caplog): + """Tests whether simple async query error retrieval works. + + Runs a query that will fail to execute and then tests that if we tried to get results for the query + then that would raise an exception. It also tests QueryStatus related functionality too. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + sql = "select * from nonexistentTable" + await cur.execute_async(sql) + q_id = cur.sfqid + with pytest.raises(ProgrammingError) as sync_error: + await cur.execute(sql) + while con.is_still_running(await con.get_query_status(q_id)): + await asyncio.sleep(1) + status = await con.get_query_status(q_id) + assert status == QueryStatus.FAILED_WITH_ERROR + assert con.is_an_error(status) + with pytest.raises(ProgrammingError) as e1: + await con.get_query_status_throw_if_error(q_id) + assert sync_error.value.errno != -1 + with pytest.raises(ProgrammingError) as e2: + await cur.get_results_from_sfqid(q_id) + assert e1.value.errno == e2.value.errno == sync_error.value.errno + + sfqid = (await cur.execute_async("SELECT SYSTEM$WAIT(2)"))["queryId"] + await cur.get_results_from_sfqid(sfqid) + async with con.cursor() as cancel_cursor: + # use separate cursor to cancel as execute will overwrite the previous query status + await cancel_cursor.execute(f"SELECT SYSTEM$CANCEL_QUERY('{sfqid}')") + with pytest.raises(DatabaseError) as e3, caplog.at_level(logging.INFO): + await cur.fetchall() + assert ( + "SQL execution canceled" in e3.value.msg + and f"Status of query '{sfqid}' is {QueryStatus.FAILED_WITH_ERROR.name}" + in caplog.text + ) + + +async def test_mix_sync_async(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + # Setup + await cur.execute( + "alter session set CLIENT_TIMESTAMP_TYPE_MAPPING=TIMESTAMP_TZ" + ) + try: + for table in ["smallTable", "uselessTable"]: + await cur.execute( + "create or replace table {} (colA string, colB int)".format( + table + ) + ) + await cur.execute( + "insert into {} values ('row1', 1), ('row2', 2), ('row3', 3)".format( + table + ) + ) + await cur.execute_async("select * from smallTable") + sf_qid1 = cur.sfqid + await cur.execute_async("select * from uselessTable") + sf_qid2 = cur.sfqid + # Wait until the 2 queries finish + while con.is_still_running(await con.get_query_status(sf_qid1)): + await asyncio.sleep(1) + while con.is_still_running(await con.get_query_status(sf_qid2)): + await asyncio.sleep(1) + await 
cur.execute("drop table uselessTable") + assert await cur.fetchall() == [("USELESSTABLE successfully dropped.",)] + await cur.get_results_from_sfqid(sf_qid1) + assert await cur.fetchall() == [("row1", 1), ("row2", 2), ("row3", 3)] + await cur.get_results_from_sfqid(sf_qid2) + assert await cur.fetchall() == [("row1", 1), ("row2", 2), ("row3", 3)] + finally: + for table in ["smallTable", "uselessTable"]: + await cur.execute(f"drop table if exists {table}") + + +async def test_async_qmark(conn_cnx): + """Tests that qmark parameter binding works with async queries.""" + import snowflake.connector + + orig_format = snowflake.connector.paramstyle + snowflake.connector.paramstyle = "qmark" + try: + async with conn_cnx() as con: + async with con.cursor() as cur: + try: + await cur.execute( + "create or replace table qmark_test (aa STRING, bb STRING)" + ) + await cur.execute( + "insert into qmark_test VALUES(?, ?)", ("test11", "test12") + ) + await cur.execute_async("select * from qmark_test") + async_qid = cur.sfqid + async with conn_cnx() as con2: + async with con2.cursor() as cur2: + await cur2.get_results_from_sfqid(async_qid) + assert await cur2.fetchall() == [("test11", "test12")] + finally: + await cur.execute("drop table if exists qmark_test") + finally: + snowflake.connector.paramstyle = orig_format + + +async def test_done_caching(conn_cnx): + """Tests whether get status caching is working as expected.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + qid1 = cur.sfqid + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 10))" + ) + qid2 = cur.sfqid + assert len(con._async_sfqids) == 2 + await asyncio.sleep(5) + while con.is_still_running(await con.get_query_status(qid1)): + await asyncio.sleep(1) + assert await con.get_query_status(qid1) == QueryStatus.SUCCESS + assert len(con._async_sfqids) == 1 + assert len(con._done_async_sfqids) == 1 + await asyncio.sleep(5) + while con.is_still_running(await con.get_query_status(qid2)): + await asyncio.sleep(1) + assert await con.get_query_status(qid2) == QueryStatus.SUCCESS + assert len(con._async_sfqids) == 0 + assert len(con._done_async_sfqids) == 2 + assert await con._all_async_queries_finished() + + +async def test_invalid_uuid_get_status(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + with pytest.raises( + ValueError, match=r"Invalid UUID: 'doesnt exist, dont even look'" + ): + await cur.get_results_from_sfqid("doesnt exist, dont even look") + + +async def test_unknown_sfqid(conn_cnx): + """Tests the exception that there is no Exception thrown when we attempt to get a status of a not existing query.""" + async with conn_cnx() as con: + assert ( + await con.get_query_status("12345678-1234-4123-A123-123456789012") + == QueryStatus.NO_DATA + ) + + +async def test_unknown_sfqid_results(conn_cnx): + """Tests that there is no Exception thrown when we attempt to get a status of a not existing query.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.get_results_from_sfqid("12345678-1234-4123-A123-123456789012") + + +async def test_not_fetching(conn_cnx): + """Tests whether executing a new query actually cleans up after an async result retrieving. + + If someone tries to retrieve results then the first fetch would have to block. We should not block + if we executed a new query. 
+ """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async("select 1") + sf_qid = cur.sfqid + await cur.get_results_from_sfqid(sf_qid) + await cur.execute("select 2") + assert cur._inner_cursor is None + assert cur._prefetch_hook is None + + +async def test_close_connection_with_running_async_queries(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 10))" + ) + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 1))" + ) + assert not (await con._all_async_queries_finished()) + assert len(con._done_async_sfqids) < 2 and con.rest is None + + +async def test_close_connection_with_completed_async_queries(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async("select 1") + qid1 = cur.sfqid + await cur.execute_async("select 2") + qid2 = cur.sfqid + while con.is_still_running( + (await con._get_query_status(qid1))[0] + ): # use _get_query_status to avoid caching + await asyncio.sleep(1) + while con.is_still_running((await con._get_query_status(qid2))[0]): + await asyncio.sleep(1) + assert await con._all_async_queries_finished() + assert len(con._done_async_sfqids) == 2 and con.rest is None diff --git a/test/integ/aio/test_autocommit_async.py b/test/integ/aio/test_autocommit_async.py new file mode 100644 index 0000000000..ecf05517f3 --- /dev/null +++ b/test/integ/aio/test_autocommit_async.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import snowflake.connector.aio + + +async def exe0(cnx, sql): + return await cnx.cursor().execute(sql) + + +async def _run_autocommit_off(cnx, db_parameters): + """Runs autocommit off test. + + Args: + cnx: The database connection context. + db_parameters: Database parameters. + """ + + async def exe(cnx, sql): + return await cnx.cursor().execute(sql.format(name=db_parameters["name"])) + + await exe( + cnx, + """ +INSERT INTO {name} VALUES(True), (False), (False) +""", + ) + res = await ( + await exe0( + cnx, + """ +SELECT CURRENT_TRANSACTION() +""", + ) + ).fetchone() + assert res[0] is not None + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE c1 +""", + ) + ).fetchone() + assert res[0] == 1 + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE NOT c1 +""", + ) + ).fetchone() + assert res[0] == 2 + await cnx.rollback() + res = await ( + await exe0( + cnx, + """ +SELECT CURRENT_TRANSACTION() +""", + ) + ).fetchone() + assert res[0] is None + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE NOT c1 +""", + ) + ).fetchone() + assert res[0] == 0 + await exe( + cnx, + """ +INSERT INTO {name} VALUES(True), (False), (False) +""", + ) + await cnx.commit() + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE NOT c1 +""", + ) + ).fetchone() + assert res[0] == 2 + await cnx.rollback() + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE NOT c1 +""", + ) + ).fetchone() + assert res[0] == 2 + + +async def _run_autocommit_on(cnx, db_parameters): + """Run autocommit on test. + + Args: + cnx: The database connection context. + db_parameters: Database parameters. 
+ """ + + async def exe(cnx, sql): + return await cnx.cursor().execute(sql.format(name=db_parameters["name"])) + + await exe( + cnx, + """ +INSERT INTO {name} VALUES(True), (False), (False) +""", + ) + await cnx.rollback() + res = await ( + await exe( + cnx, + """ +SELECT COUNT(*) FROM {name} WHERE NOT c1 +""", + ) + ).fetchone() + assert res[0] == 4 + + +async def test_autocommit_attribute(conn_cnx, db_parameters): + """Tests autocommit attribute. + + Args: + conn_cnx: The database connection context. + db_parameters: Database parameters. + """ + + async def exe(cnx, sql): + return await cnx.cursor().execute(sql.format(name=db_parameters["name"])) + + async with conn_cnx() as cnx: + await exe( + cnx, + """ +CREATE TABLE {name} (c1 boolean) +""", + ) + try: + await cnx.autocommit(False) + await _run_autocommit_off(cnx, db_parameters) + await cnx.autocommit(True) + await _run_autocommit_on(cnx, db_parameters) + finally: + await exe( + cnx, + """ +DROP TABLE IF EXISTS {name} + """, + ) + + +async def test_autocommit_parameters(db_parameters): + """Tests autocommit parameter. + + Args: + db_parameters: Database parameters. + """ + + async def exe(cnx, sql): + return await cnx.cursor().execute(sql.format(name=db_parameters["name"])) + + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + schema=db_parameters["schema"], + database=db_parameters["database"], + autocommit=False, + ) as cnx: + await exe( + cnx, + """ +CREATE TABLE {name} (c1 boolean) +""", + ) + await _run_autocommit_off(cnx, db_parameters) + + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + schema=db_parameters["schema"], + database=db_parameters["database"], + autocommit=True, + ) as cnx: + await _run_autocommit_on(cnx, db_parameters) + await exe( + cnx, + """ +DROP TABLE IF EXISTS {name} +""", + ) diff --git a/test/integ/aio/test_bindings_async.py b/test/integ/aio/test_bindings_async.py new file mode 100644 index 0000000000..06b8017918 --- /dev/null +++ b/test/integ/aio/test_bindings_async.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
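The asynchronous-query tests above all exercise the same fire-and-forget flow: submit with execute_async, poll get_query_status on the connection, then attach the finished query to a cursor with get_results_from_sfqid. A minimal standalone sketch of that flow (illustrative only, not part of this change; CONNECTION_PARAMETERS, fire_and_forget and the generator query are placeholders):

import asyncio

import snowflake.connector.aio
from snowflake.connector.constants import QueryStatus

# Placeholder connection arguments -- substitute real account/user/password values.
CONNECTION_PARAMETERS = {"account": "...", "user": "...", "password": "..."}


async def fire_and_forget() -> None:
    async with snowflake.connector.aio.SnowflakeConnection(
        **CONNECTION_PARAMETERS
    ) as con:
        async with con.cursor() as cur:
            # Submit the query without waiting for it to finish.
            await cur.execute_async(
                "select count(*) from table(generator(timeLimit => 5))"
            )
            query_id = cur.sfqid

            # Poll the connection until the query reaches a terminal state.
            while con.is_still_running(await con.get_query_status(query_id)):
                await asyncio.sleep(1)

            # Raises if the query ended in an error state.
            status = await con.get_query_status_throw_if_error(query_id)
            assert status == QueryStatus.SUCCESS

            # Bind the finished query to the cursor and fetch its rows.
            await cur.get_results_from_sfqid(query_id)
            print(await cur.fetchall())


asyncio.run(fire_and_forget())

Because the query id outlives the connection that submitted it, the same get_results_from_sfqid call also works from a brand-new connection, which is what test_async_exec relies on.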
+# + +from __future__ import annotations + +import calendar +import tempfile +import time +from datetime import date, datetime +from datetime import time as datetime_time +from datetime import timedelta, timezone +from decimal import Decimal +from unittest.mock import patch + +import pendulum +import pytest +import pytz + +from snowflake.connector.converter import convert_datetime_to_epoch +from snowflake.connector.errors import ForbiddenError, ProgrammingError +from snowflake.connector.util_text import random_string + +tempfile.gettempdir() + +PST_TZ = "America/Los_Angeles" +JST_TZ = "Asia/Tokyo" +CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD" + + +async def test_invalid_binding_option(conn_cnx): + """Invalid paramstyle parameters.""" + with pytest.raises(ProgrammingError): + async with conn_cnx(paramstyle="hahaha"): + pass + + # valid cases + for s in ["format", "pyformat", "qmark", "numeric"]: + async with conn_cnx(paramstyle=s): + pass + + +@pytest.mark.parametrize( + "bulk_array_optimization", + [True, False], +) +async def test_binding(conn_cnx, db_parameters, bulk_array_optimization): + """Paramstyle qmark binding tests to cover basic data types.""" + CREATE_TABLE = """create or replace table {name} ( + c1 BOOLEAN, + c2 INTEGER, + c3 NUMBER(38,2), + c4 VARCHAR(1234), + c5 FLOAT, + c6 BINARY, + c7 BINARY, + c8 TIMESTAMP_NTZ, + c9 TIMESTAMP_NTZ, + c10 TIMESTAMP_NTZ, + c11 TIMESTAMP_NTZ, + c12 TIMESTAMP_LTZ, + c13 TIMESTAMP_LTZ, + c14 TIMESTAMP_LTZ, + c15 TIMESTAMP_LTZ, + c16 TIMESTAMP_TZ, + c17 TIMESTAMP_TZ, + c18 TIMESTAMP_TZ, + c19 TIMESTAMP_TZ, + c20 DATE, + c21 TIME, + c22 TIMESTAMP_NTZ, + c23 TIME, + c24 STRING, + c25 STRING, + c26 STRING + ) + """ + INSERT = """ +insert into {name} values( +?,?,?, ?,?,?, ?,?,?, ?,?,?, ?,?,?, ?,?,?, ?,?,?, ?,?,?,?,?) 
+""" + async with conn_cnx(paramstyle="qmark") as cnx: + await cnx.cursor().execute(CREATE_TABLE.format(name=db_parameters["name"])) + current_utctime = datetime.now(timezone.utc).replace(tzinfo=None) + current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( + pytz.timezone(PST_TZ) + ) + current_localtime_without_tz = datetime.now() + current_localtime_with_other_tz = pytz.utc.localize( + current_localtime_without_tz, is_dst=False + ).astimezone(pytz.timezone(JST_TZ)) + dt = date(2017, 12, 30) + tm = datetime_time(hour=1, minute=2, second=3, microsecond=456) + struct_time_v = time.strptime("30 Sep 01 11:20:30", "%d %b %y %H:%M:%S") + tdelta = timedelta( + seconds=tm.hour * 3600 + tm.minute * 60 + tm.second, microseconds=tm.microsecond + ) + data = ( + True, + 1, + Decimal("1.2"), + "str1", + 1.2, + # Py2 has bytes in str type, so Python Connector + b"abc", + bytearray(b"def"), + current_utctime, + current_localtime, + current_localtime_without_tz, + current_localtime_with_other_tz, + ("TIMESTAMP_LTZ", current_utctime), + ("TIMESTAMP_LTZ", current_localtime), + ("TIMESTAMP_LTZ", current_localtime_without_tz), + ("TIMESTAMP_LTZ", current_localtime_with_other_tz), + ("TIMESTAMP_TZ", current_utctime), + ("TIMESTAMP_TZ", current_localtime), + ("TIMESTAMP_TZ", current_localtime_without_tz), + ("TIMESTAMP_TZ", current_localtime_with_other_tz), + dt, + tm, + ("TIMESTAMP_NTZ", struct_time_v), + ("TIME", tdelta), + ("TEXT", None), + "", + ',an\\\\escaped"line\n', + ) + try: + async with conn_cnx( + paramstyle="qmark", timezone=PST_TZ + ) as cnx, cnx.cursor() as c: + if bulk_array_optimization: + cnx._session_parameters[CLIENT_STAGE_ARRAY_BINDING_THRESHOLD] = 1 + await c.executemany(INSERT.format(name=db_parameters["name"]), [data]) + else: + await c.execute(INSERT.format(name=db_parameters["name"]), data) + + ret = await ( + await c.execute( + """ +select * from {name} where c1=? and c2=? 
+""".format( + name=db_parameters["name"] + ), + (True, 1), + ) + ).fetchone() + assert len(ret) == 26 + assert ret[0], "BOOLEAN" + assert ret[2] == Decimal("1.2"), "NUMBER" + assert ret[4] == 1.2, "FLOAT" + assert ret[5] == b"abc" + assert ret[6] == b"def" + assert ret[7] == current_utctime + assert convert_datetime_to_epoch(ret[8]) == convert_datetime_to_epoch( + current_localtime + ) + assert convert_datetime_to_epoch(ret[9]) == convert_datetime_to_epoch( + current_localtime_without_tz + ) + assert convert_datetime_to_epoch(ret[10]) == convert_datetime_to_epoch( + current_localtime_with_other_tz + ) + assert convert_datetime_to_epoch(ret[11]) == convert_datetime_to_epoch( + current_utctime + ) + assert convert_datetime_to_epoch(ret[12]) == convert_datetime_to_epoch( + current_localtime + ) + assert convert_datetime_to_epoch(ret[13]) == convert_datetime_to_epoch( + current_localtime_without_tz + ) + assert convert_datetime_to_epoch(ret[14]) == convert_datetime_to_epoch( + current_localtime_with_other_tz + ) + assert convert_datetime_to_epoch(ret[15]) == convert_datetime_to_epoch( + current_utctime + ) + assert convert_datetime_to_epoch(ret[16]) == convert_datetime_to_epoch( + current_localtime + ) + assert convert_datetime_to_epoch(ret[17]) == convert_datetime_to_epoch( + current_localtime_without_tz + ) + assert convert_datetime_to_epoch(ret[18]) == convert_datetime_to_epoch( + current_localtime_with_other_tz + ) + assert ret[19] == dt + assert ret[20] == tm + assert convert_datetime_to_epoch(ret[21]) == calendar.timegm(struct_time_v) + assert ( + timedelta( + seconds=ret[22].hour * 3600 + ret[22].minute * 60 + ret[22].second, + microseconds=ret[22].microsecond, + ) + == tdelta + ) + assert ret[23] is None + assert ret[24] == "" + assert ret[25] == ',an\\\\escaped"line\n' + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +async def test_pendulum_binding(conn_cnx, db_parameters): + pendulum_test = pendulum.now() + try: + async with conn_cnx() as cnx, cnx.cursor() as c: + await c.execute( + """ + create or replace table {name} ( + c1 timestamp + ) + """.format( + name=db_parameters["name"] + ) + ) + fmt = "insert into {name}(c1) values(%(v1)s)".format( + name=db_parameters["name"] + ) + await c.execute(fmt, {"v1": pendulum_test}) + assert ( + len( + await ( + await c.execute( + "select count(*) from {name}".format( + name=db_parameters["name"] + ) + ) + ).fetchall() + ) + == 1 + ) + async with conn_cnx(paramstyle="qmark") as cnx, cnx.cursor() as c: + await c.execute( + """ + create or replace table {name} (c1 timestamp, c2 timestamp) + """.format( + name=db_parameters["name"] + ) + ) + await c.execute( + """ + insert into {name} values(?, ?) + """.format( + name=db_parameters["name"] + ), + (pendulum_test, pendulum_test), + ) + ret = await ( + await c.execute( + """ + select * from {name} + """.format( + name=db_parameters["name"] + ) + ) + ).fetchone() + assert convert_datetime_to_epoch(ret[0]) == convert_datetime_to_epoch( + pendulum_test + ) + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ + drop table if exists {name} + """.format( + name=db_parameters["name"] + ) + ) + + +async def test_binding_with_numeric(conn_cnx, db_parameters): + """Paramstyle numeric tests. 
Both qmark and numeric leverages server side bindings.""" + async with conn_cnx(paramstyle="numeric") as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} (c1 integer, c2 string) +""".format( + name=db_parameters["name"] + ) + ) + + try: + async with conn_cnx(paramstyle="numeric") as cnx, cnx.cursor() as c: + await c.execute( + """ +insert into {name}(c1, c2) values(:2, :1) + """.format( + name=db_parameters["name"] + ), + ("str1", 123), + ) + await c.execute( + """ +insert into {name}(c1, c2) values(:2, :1) + """.format( + name=db_parameters["name"] + ), + ("str2", 456), + ) + # numeric and qmark can be used in the same session + rec = await ( + await c.execute( + """ +select * from {name} where c1=? +""".format( + name=db_parameters["name"] + ), + (123,), + ) + ).fetchall() + assert len(rec) == 1 + assert rec[0][0] == 123 + assert rec[0][1] == "str1" + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +async def test_binding_timestamps(conn_cnx, db_parameters): + """Binding datetime object with TIMESTAMP_LTZ. + + The value is bound as TIMESTAMP_NTZ, but since it is converted to UTC in the backend, + the returned value must be ???. + """ + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} ( + c1 integer, + c2 timestamp_ltz) +""".format( + name=db_parameters["name"] + ) + ) + + try: + async with conn_cnx( + paramstyle="numeric", timezone=PST_TZ + ) as cnx, cnx.cursor() as c: + current_localtime = datetime.now() + await c.execute( + """ +insert into {name}(c1, c2) values(:1, :2) + """.format( + name=db_parameters["name"] + ), + (123, ("TIMESTAMP_LTZ", current_localtime)), + ) + rec = await ( + await c.execute( + """ +select * from {name} where c1=? 
+ """.format( + name=db_parameters["name"] + ), + (123,), + ) + ).fetchall() + assert len(rec) == 1 + assert rec[0][0] == 123 + assert convert_datetime_to_epoch(rec[0][1]) == convert_datetime_to_epoch( + current_localtime + ) + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +@pytest.mark.parametrize( + "num_rows", [pytest.param(100000, marks=pytest.mark.skipolddriver), 4] +) +async def test_binding_bulk_insert(conn_cnx, db_parameters, num_rows): + """Bulk insert test.""" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} ( + c1 integer, + c2 string +) +""".format( + name=db_parameters["name"] + ) + ) + try: + async with conn_cnx(paramstyle="qmark") as cnx: + c = cnx.cursor() + fmt = "insert into {name}(c1,c2) values(?,?)".format( + name=db_parameters["name"] + ) + await c.executemany(fmt, [(idx, f"test{idx}") for idx in range(num_rows)]) + assert c.rowcount == num_rows + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +@pytest.mark.skipolddriver +async def test_binding_bulk_insert_date(conn_cnx, db_parameters): + """Bulk insert test.""" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} ( + c1 date +) +""".format( + name=db_parameters["name"] + ) + ) + try: + async with conn_cnx(paramstyle="qmark") as cnx: + c = cnx.cursor() + cnx._session_parameters[CLIENT_STAGE_ARRAY_BINDING_THRESHOLD] = 1 + dates = [ + [date.fromisoformat("1750-05-09")], + [date.fromisoformat("1969-01-01")], + [date.fromisoformat("1970-01-01")], + [date.fromisoformat("2023-05-12")], + [date.fromisoformat("2999-12-31")], + [date.fromisoformat("3000-12-31")], + [date.fromisoformat("9999-12-31")], + ] + await c.executemany( + f'INSERT INTO {db_parameters["name"]}(c1) VALUES (?)', dates + ) + assert c.rowcount == len(dates) + ret = await ( + await c.execute(f'SELECT c1 from {db_parameters["name"]}') + ).fetchall() + assert ret == [ + (date(1750, 5, 9),), + (date(1969, 1, 1),), + (date(1970, 1, 1),), + (date(2023, 5, 12),), + (date(2999, 12, 31),), + (date(3000, 12, 31),), + (date(9999, 12, 31),), + ] + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +@pytest.mark.skipolddriver +async def test_binding_insert_date(conn_cnx, db_parameters): + bind_query = "SELECT TRY_TO_DATE(TO_CHAR(?,?),?)" + bind_variables = (date(2016, 4, 10), "YYYY-MM-DD", "YYYY-MM-DD") + bind_variables_2 = (date(2016, 4, 10), "YYYY-MM-DD", "DD-MON-YYYY") + async with conn_cnx(paramstyle="qmark") as cnx, cnx.cursor() as cursor: + assert await (await cursor.execute(bind_query, bind_variables)).fetchall() == [ + (date(2016, 4, 10),) + ] + # the second sql returns None because 2016-04-10 doesn't comply with the format DD-MON-YYYY + assert await ( + await cursor.execute(bind_query, bind_variables_2) + ).fetchall() == [(None,)] + + +@pytest.mark.skipolddriver +async def test_bulk_insert_binding_fallback(conn_cnx): + """When stage creation fails, bulk inserts falls back to server side binding and disables stage optimization.""" + async with conn_cnx(paramstyle="qmark") as cnx, cnx.cursor() as csr: + query = f"insert into {random_string(5)}(c1,c2) values(?,?)" + cnx._session_parameters[CLIENT_STAGE_ARRAY_BINDING_THRESHOLD] = 1 + with 
patch.object(csr, "_execute_helper") as mocked_execute_helper, patch( + "snowflake.connector.aio._cursor.BindUploadAgent._create_stage" + ) as mocked_stage_creation: + mocked_stage_creation.side_effect = ForbiddenError + await csr.executemany(query, [(idx, f"test{idx}") for idx in range(4)]) + mocked_stage_creation.assert_called_once() + mocked_execute_helper.assert_called_once() + assert ( + "binding_stage" not in mocked_execute_helper.call_args[1] + ), "Stage binding should fail" + assert ( + "binding_params" in mocked_execute_helper.call_args[1] + ), "Should fall back to server side binding" + assert cnx._session_parameters[CLIENT_STAGE_ARRAY_BINDING_THRESHOLD] == 0 + + +async def test_binding_bulk_update(conn_cnx, db_parameters): + """Bulk update test. + + Notes: + UPDATE,MERGE and DELETE are not supported for actual bulk operation + but executemany accepts the multiple rows and iterate DMLs. + """ + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} ( + c1 integer, + c2 string +) +""".format( + name=db_parameters["name"] + ) + ) + try: + async with conn_cnx(paramstyle="qmark") as cnx, cnx.cursor() as c: + # short list + fmt = "insert into {name}(c1,c2) values(?,?)".format( + name=db_parameters["name"] + ) + await c.executemany( + fmt, + [ + (1, "test1"), + (2, "test2"), + (3, "test3"), + (4, "test4"), + ], + ) + assert c.rowcount == 4 + + fmt = "update {name} set c2=:2 where c1=:1".format( + name=db_parameters["name"] + ) + await c.executemany( + fmt, + [ + (1, "test5"), + (2, "test6"), + ], + ) + assert c.rowcount == 2 + + fmt = "select * from {name} where c1=?".format(name=db_parameters["name"]) + rec = await (await c.execute(fmt, (1,))).fetchall() + assert rec[0][0] == 1 + assert rec[0][1] == "test5" + + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +async def test_binding_identifier(conn_cnx, db_parameters): + """Binding a table name.""" + try: + async with conn_cnx(paramstyle="qmark") as cnx, cnx.cursor() as c: + data = "test" + await c.execute( + """ +create or replace table identifier(?) (c1 string) +""", + (db_parameters["name"],), + ) + await c.execute( + """ +insert into identifier(?) values(?) +""", + (db_parameters["name"], data), + ) + ret = await ( + await c.execute( + """ +select * from identifier(?) +""", + (db_parameters["name"],), + ) + ).fetchall() + assert len(ret) == 1 + assert ret[0][0] == data + finally: + async with conn_cnx(paramstyle="qmark") as cnx: + await cnx.cursor().execute( + """ +drop table if exists identifier(?) +""", + (db_parameters["name"],), + ) diff --git a/test/integ/aio/test_boolean_async.py b/test/integ/aio/test_boolean_async.py new file mode 100644 index 0000000000..93c9bbdebe --- /dev/null +++ b/test/integ/aio/test_boolean_async.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
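The binding tests above lean on qmark-style placeholders and executemany for bulk inserts, with server-side binding handling type conversion. A condensed sketch of that usage (illustrative only; CONNECTION_PARAMETERS and the qmark_demo table are placeholders):

import asyncio

import snowflake.connector
import snowflake.connector.aio

# Placeholder connection arguments -- substitute real account/user/password values.
CONNECTION_PARAMETERS = {"account": "...", "user": "...", "password": "..."}


async def qmark_bulk_insert() -> None:
    # Switch to server-side (qmark) binding at the module level, as test_async_qmark does.
    snowflake.connector.paramstyle = "qmark"
    async with snowflake.connector.aio.SnowflakeConnection(
        **CONNECTION_PARAMETERS
    ) as cnx:
        async with cnx.cursor() as cur:
            await cur.execute(
                "create or replace temporary table qmark_demo (c1 integer, c2 string)"
            )
            # executemany sends all parameter sets; rowcount reflects every inserted row.
            await cur.executemany(
                "insert into qmark_demo(c1, c2) values(?, ?)",
                [(idx, f"test{idx}") for idx in range(4)],
            )
            rows = await (
                await cur.execute("select * from qmark_demo where c1=?", (2,))
            ).fetchall()
            print(rows)  # expected: [(2, 'test2')]


asyncio.run(qmark_bulk_insert())

Setting snowflake.connector.paramstyle at the module level mirrors test_async_qmark; the tests above also show passing paramstyle="qmark" per connection through the conn_cnx fixture.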
+# + +from __future__ import annotations + + +async def test_binding_fetching_boolean(conn_cnx, db_parameters): + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace table {name} (c1 boolean, c2 integer) +""".format( + name=db_parameters["name"] + ) + ) + + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +insert into {name} values(%s,%s), (%s,%s), (%s,%s) +""".format( + name=db_parameters["name"] + ), + (True, 1, False, 2, True, 3), + ) + results = await ( + await cnx.cursor().execute( + """ +select * from {name} order by 1""".format( + name=db_parameters["name"] + ) + ) + ).fetchall() + assert not results[0][0] + assert results[1][0] + assert results[2][0] + results = await ( + await cnx.cursor().execute( + """ +select c1 from {name} where c2=2 +""".format( + name=db_parameters["name"] + ) + ) + ).fetchall() + assert not results[0][0] + + # SNOW-15905: boolean support + results = await ( + await cnx.cursor().execute( + """ +SELECT CASE WHEN (null LIKE trim(null)) THEN null ELSE null END +""" + ) + ).fetchall() + assert not results[0][0] + + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +drop table if exists {name} +""".format( + name=db_parameters["name"] + ) + ) + + +async def test_boolean_from_compiler(conn_cnx): + async with conn_cnx() as cnx: + ret = await (await cnx.cursor().execute("SELECT true")).fetchone() + assert ret[0] + + ret = await (await cnx.cursor().execute("SELECT false")).fetchone() + assert not ret[0] diff --git a/test/integ/aio/test_client_session_keep_alive_async.py b/test/integ/aio/test_client_session_keep_alive_async.py new file mode 100644 index 0000000000..fa242baad9 --- /dev/null +++ b/test/integ/aio/test_client_session_keep_alive_async.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio + +import pytest + +import snowflake.connector.aio + +try: + from parameters import CONNECTION_PARAMETERS +except ImportError: + CONNECTION_PARAMETERS = {} + +try: + from parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + + +@pytest.fixture +async def token_validity_test_values(request): + async with snowflake.connector.aio.SnowflakeConnection( + **CONNECTION_PARAMETERS_ADMIN + ) as cnx: + print("[INFO] Setting token validity to test values") + await cnx.cursor().execute( + """ +ALTER SYSTEM SET + MASTER_TOKEN_VALIDITY=30, + SESSION_TOKEN_VALIDITY=10 +""" + ) + + async def fin(): + async with snowflake.connector.aio.SnowflakeConnection( + **CONNECTION_PARAMETERS_ADMIN + ) as cnx: + print("[INFO] Reverting token validity") + await cnx.cursor().execute( + """ +ALTER SYSTEM SET + MASTER_TOKEN_VALIDITY=default, + SESSION_TOKEN_VALIDITY=default +""" + ) + + request.addfinalizer(fin) + return None + + +@pytest.mark.skipif( + not (CONNECTION_PARAMETERS_ADMIN), + reason="ADMIN connection parameters must be provided.", +) +async def test_client_session_keep_alive(token_validity_test_values): + test_connection_parameters = CONNECTION_PARAMETERS.copy() + print("[INFO] Connected") + test_connection_parameters["client_session_keep_alive"] = True + async with snowflake.connector.aio.SnowflakeConnection( + **test_connection_parameters + ) as con: + print("[INFO] Running a query. 
Ensuring a connection is valid.") + await con.cursor().execute("select 1") + print("[INFO] Sleeping 15s") + await asyncio.sleep(15) + print( + "[INFO] Running a query. Both master and session tokens must " + "have been renewed by token request" + ) + await con.cursor().execute("select 1") + print("[INFO] Sleeping 40s") + await asyncio.sleep(40) + print( + "[INFO] Running a query. Master token must have been renewed " + "by the heartbeat" + ) + await con.cursor().execute("select 1") diff --git a/test/integ/aio/test_concurrent_create_objects_async.py b/test/integ/aio/test_concurrent_create_objects_async.py new file mode 100644 index 0000000000..a376776de6 --- /dev/null +++ b/test/integ/aio/test_concurrent_create_objects_async.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +from logging import getLogger + +import pytest + +from snowflake.connector import ProgrammingError + +try: + from parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + +logger = getLogger(__name__) + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." +) +async def test_snow5871(conn_cnx, db_parameters): + await _test_snow5871( + conn_cnx, + db_parameters, + number_of_threads=5, + rt_max_outgoing_rate=60, + rt_max_burst_size=5, + rt_max_borrowing_limt=1000, + rt_reset_period=10000, + ) + + await _test_snow5871( + conn_cnx, + db_parameters, + number_of_threads=40, + rt_max_outgoing_rate=60, + rt_max_burst_size=1, + rt_max_borrowing_limt=200, + rt_reset_period=1000, + ) + + +async def _create_a_table(meta): + cnx = meta["cnx"] + name = meta["name"] + try: + await cnx.cursor().execute( + """ +create table {} (aa int) + """.format( + name + ) + ) + # print("Success #" + meta['idx']) + return {"success": True} + except ProgrammingError: + logger.exception("Failed to create a table") + return {"success": False} + + +async def _test_snow5871( + conn_cnx, + db_parameters, + number_of_threads=10, + rt_max_outgoing_rate=60, + rt_max_burst_size=1, + rt_max_borrowing_limt=1000, + rt_reset_period=10000, +): + """SNOW-5871: rate limiting for creation of non-recycable objects.""" + logger.debug( + ( + "number_of_threads = %s, rt_max_outgoing_rate = %s, " + "rt_max_burst_size = %s, rt_max_borrowing_limt = %s, " + "rt_reset_period = %s" + ), + number_of_threads, + rt_max_outgoing_rate, + rt_max_burst_size, + rt_max_borrowing_limt, + rt_reset_period, + ) + async with conn_cnx( + user=db_parameters["sf_user"], + password=db_parameters["sf_password"], + account=db_parameters["sf_account"], + ) as cnx: + await cnx.cursor().execute( + """ +alter system set + RT_MAX_OUTGOING_RATE={}, + RT_MAX_BURST_SIZE={}, + RT_MAX_BORROWING_LIMIT={}, + RT_RESET_PERIOD={}""".format( + rt_max_outgoing_rate, + rt_max_burst_size, + rt_max_borrowing_limt, + rt_reset_period, + ) + ) + + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "create or replace database {name}_db".format( + name=db_parameters["name"] + ) + ) + meta = [] + for i in range(number_of_threads): + meta.append( + { + "idx": str(i + 1), + "cnx": cnx, + "name": db_parameters["name"] + "tbl_5871_" + str(i + 1), + } + ) + + tasks = [ + asyncio.create_task(_create_a_table(per_meta)) for per_meta in meta + ] + results = await asyncio.gather(*tasks) + success = 0 + for r in results: + success += 1 if r["success"] else 0 + + # at least one should be 
success + assert success >= 1, "success queries" + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "drop database if exists {name}_db".format(name=db_parameters["name"]) + ) + + async with conn_cnx( + user=db_parameters["sf_user"], + password=db_parameters["sf_password"], + account=db_parameters["sf_account"], + ) as cnx: + await cnx.cursor().execute( + """ +alter system set + RT_MAX_OUTGOING_RATE=default, + RT_MAX_BURST_SIZE=default, + RT_RESET_PERIOD=default, + RT_MAX_BORROWING_LIMIT=default""" + ) diff --git a/test/integ/aio/test_concurrent_insert_async.py b/test/integ/aio/test_concurrent_insert_async.py new file mode 100644 index 0000000000..be98474dfc --- /dev/null +++ b/test/integ/aio/test_concurrent_insert_async.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +from logging import getLogger + +import pytest + +import snowflake.connector.aio +from snowflake.connector.errors import ProgrammingError + +try: + from parameters import CONNECTION_PARAMETERS_ADMIN +except Exception: + CONNECTION_PARAMETERS_ADMIN = {} + +logger = getLogger(__name__) + + +async def _concurrent_insert(meta): + """Concurrent insert method.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=meta["user"], + password=meta["password"], + host=meta["host"], + port=meta["port"], + account=meta["account"], + database=meta["database"], + schema=meta["schema"], + timezone="UTC", + protocol="http", + ) + await cnx.connect() + try: + await cnx.cursor().execute("use warehouse {}".format(meta["warehouse"])) + table = meta["table"] + sql = f"insert into {table} values(%(c1)s, %(c2)s)" + logger.debug(sql) + await cnx.cursor().execute( + sql, + { + "c1": meta["idx"], + "c2": "test string " + meta["idx"], + }, + ) + meta["success"] = True + logger.debug("Succeeded process #%s", meta["idx"]) + except Exception: + logger.exception("failed to insert into a table [%s]", table) + meta["success"] = False + finally: + await cnx.close() + return meta + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, + reason="The user needs a privilege of create warehouse.", +) +async def test_concurrent_insert(conn_cnx, db_parameters): + """Concurrent insert tests. 
Inserts block on the one that's running.""" + number_of_tasks = 22 # change this to increase the concurrency + expected_success_runs = number_of_tasks - 1 + cnx_array = [] + + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace warehouse {} +warehouse_type=standard +warehouse_size=small +""".format( + db_parameters["name_wh"] + ) + ) + sql = """ +create or replace table {name} (c1 integer, c2 string) +""".format( + name=db_parameters["name"] + ) + await cnx.cursor().execute(sql) + for i in range(number_of_tasks): + cnx_array.append( + { + "host": db_parameters["host"], + "port": db_parameters["port"], + "user": db_parameters["user"], + "password": db_parameters["password"], + "account": db_parameters["account"], + "database": db_parameters["database"], + "schema": db_parameters["schema"], + "table": db_parameters["name"], + "idx": str(i), + "warehouse": db_parameters["name_wh"], + } + ) + tasks = [ + asyncio.create_task(_concurrent_insert(cnx_item)) + for cnx_item in cnx_array + ] + results = await asyncio.gather(*tasks) + success = 0 + for record in results: + success += 1 if record["success"] else 0 + + # 21 threads or more + assert success >= expected_success_runs, "Number of success run" + + c = cnx.cursor() + sql = "select * from {name} order by 1".format(name=db_parameters["name"]) + await c.execute(sql) + for rec in c: + logger.debug(rec) + await c.close() + + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "drop table if exists {}".format(db_parameters["name"]) + ) + await cnx.cursor().execute( + "drop warehouse if exists {}".format(db_parameters["name_wh"]) + ) + + +async def _concurrent_insert_using_connection(meta): + connection = meta["connection"] + idx = meta["idx"] + name = meta["name"] + try: + await connection.cursor().execute( + f"INSERT INTO {name} VALUES(%s, %s)", + (idx, f"test string{idx}"), + ) + except ProgrammingError as e: + if e.errno != 619: # SQL Execution Canceled + raise + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, + reason="The user needs a privilege of create warehouse.", +) +async def test_concurrent_insert_using_connection(conn_cnx, db_parameters): + """Concurrent insert tests using the same connection.""" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace warehouse {} +warehouse_type=standard +warehouse_size=small +""".format( + db_parameters["name_wh"] + ) + ) + await cnx.cursor().execute( + """ +CREATE OR REPLACE TABLE {name} (c1 INTEGER, c2 STRING) +""".format( + name=db_parameters["name"] + ) + ) + number_of_tasks = 5 + metas = [] + for i in range(number_of_tasks): + metas.append( + { + "connection": cnx, + "idx": i, + "name": db_parameters["name"], + } + ) + tasks = [ + asyncio.create_task(_concurrent_insert_using_connection(meta)) + for meta in metas + ] + await asyncio.gather(*tasks) + cnt = 0 + async for _ in await cnx.cursor().execute( + "SELECT * FROM {name} ORDER BY 1".format(name=db_parameters["name"]) + ): + cnt += 1 + assert ( + cnt <= number_of_tasks + ), "Number of records should be less than the number of threads" + assert cnt > 0, "Number of records should be one or more number of threads" + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "drop table if exists {}".format(db_parameters["name"]) + ) + await cnx.cursor().execute( + "drop warehouse if exists {}".format(db_parameters["name_wh"]) + ) diff --git a/test/integ/aio/test_connection_async.py b/test/integ/aio/test_connection_async.py 
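The concurrency tests above share one async connection across tasks, giving each task its own cursor and joining everything with asyncio.gather. A minimal sketch of that pattern (illustrative only; CONNECTION_PARAMETERS and concurrent_demo are placeholders):

import asyncio

import snowflake.connector.aio
from snowflake.connector.errors import ProgrammingError

# Placeholder connection arguments -- substitute real account/user/password values.
CONNECTION_PARAMETERS = {"account": "...", "user": "...", "password": "..."}


async def insert_row(cnx, idx: int) -> None:
    # Each task gets its own cursor; the underlying connection is shared.
    try:
        await cnx.cursor().execute(
            "insert into concurrent_demo values(%s, %s)", (idx, f"test string{idx}")
        )
    except ProgrammingError as e:
        if e.errno != 619:  # SQL Execution Canceled -- tolerated, as in the test above
            raise


async def concurrent_inserts(n_tasks: int = 5) -> None:
    async with snowflake.connector.aio.SnowflakeConnection(
        **CONNECTION_PARAMETERS
    ) as cnx:
        await cnx.cursor().execute(
            "create or replace temporary table concurrent_demo (c1 integer, c2 string)"
        )
        tasks = [asyncio.create_task(insert_row(cnx, i)) for i in range(n_tasks)]
        await asyncio.gather(*tasks)
        rows = await (
            await cnx.cursor().execute("select * from concurrent_demo order by 1")
        ).fetchall()
        print(len(rows))


asyncio.run(concurrent_inserts())

As in test_concurrent_insert_using_connection, interleaved statements may be canceled by the server, so the final row count can land anywhere between one and n_tasks.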
new file mode 100644 index 0000000000..ab4a15c614 --- /dev/null +++ b/test/integ/aio/test_connection_async.py @@ -0,0 +1,1559 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import gc +import logging +import os +import pathlib +import queue +import stat +import tempfile +import warnings +import weakref +from test.integ.conftest import RUNNING_ON_GH +from test.randomize import random_string +from unittest import mock +from uuid import uuid4 + +import pytest + +import snowflake.connector.aio +from snowflake.connector import DatabaseError, OperationalError, ProgrammingError +from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio._description import CLIENT_NAME +from snowflake.connector.connection import DEFAULT_CLIENT_PREFETCH_THREADS +from snowflake.connector.errorcode import ( + ER_CONNECTION_IS_CLOSED, + ER_FAILED_PROCESSING_PYFORMAT, + ER_INVALID_VALUE, + ER_NO_ACCOUNT_NAME, + ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE, +) +from snowflake.connector.errors import Error, InterfaceError +from snowflake.connector.network import APPLICATION_SNOWSQL, ReauthenticationRequest +from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED +from snowflake.connector.telemetry import TelemetryField + +try: # pragma: no cover + from ..parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + +from snowflake.connector.aio.auth import AuthByOkta, AuthByPlugin + +try: + from snowflake.connector.errorcode import ER_FAILED_PROCESSING_QMARK +except ImportError: # Keep olddrivertest from breaking + ER_FAILED_PROCESSING_QMARK = 252012 + + +async def test_basic(conn_testaccount): + """Basic Connection test.""" + assert conn_testaccount, "invalid cnx" + # Test default values + assert conn_testaccount.session_id + + +async def test_connection_without_schema(db_parameters): + """Basic Connection test without schema.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx.connect() + assert cnx, "invalid cnx" + await cnx.close() + + +async def test_connection_without_database_schema(db_parameters): + """Basic Connection test without database and schema.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx.connect() + assert cnx, "invalid cnx" + await cnx.close() + + +async def test_connection_without_database2(db_parameters): + """Basic Connection test without database.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx.connect() + assert cnx, "invalid cnx" + await cnx.close() + + +async def test_with_config(db_parameters): + """Creates a connection with the config parameter.""" + config = { + "user": db_parameters["user"], + "password": 
db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + } + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + await cnx.connect() + assert cnx, "invalid cnx" + assert not cnx.client_session_keep_alive # default is False + finally: + await cnx.close() + + +@pytest.mark.skipolddriver +async def test_with_tokens(conn_cnx, db_parameters): + """Creates a connection using session and master token.""" + try: + async with conn_cnx( + timezone="UTC", + ) as initial_cnx: + assert initial_cnx, "invalid initial cnx" + master_token = initial_cnx.rest._master_token + session_token = initial_cnx.rest._token + async with snowflake.connector.aio.SnowflakeConnection( + account=db_parameters["account"], + host=db_parameters["host"], + port=db_parameters["port"], + protocol=db_parameters["protocol"], + session_token=session_token, + master_token=master_token, + ) as token_cnx: + await token_cnx.connect() + assert token_cnx, "invalid second cnx" + except Exception: + # This is my way of guaranteeing that we'll not expose the + # sensitive information that this test needs to handle. + # db_parameter contains passwords. + pytest.fail("something failed", pytrace=False) + + +@pytest.mark.skipolddriver +async def test_with_tokens_expired(conn_cnx, db_parameters): + """Creates a connection using session and master token.""" + try: + async with conn_cnx( + timezone="UTC", + ) as initial_cnx: + assert initial_cnx, "invalid initial cnx" + master_token = initial_cnx._rest._master_token + session_token = initial_cnx._rest._token + + with pytest.raises(ProgrammingError): + token_cnx = snowflake.connector.aio.SnowflakeConnection( + account=db_parameters["account"], + host=db_parameters["host"], + port=db_parameters["port"], + protocol=db_parameters["protocol"], + session_token=session_token, + master_token=master_token, + ) + await token_cnx.connect() + await token_cnx.close() + except Exception: + # This is my way of guaranteeing that we'll not expose the + # sensitive information that this test needs to handle. + # db_parameter contains passwords. + pytest.fail("something failed", pytrace=False) + + +async def test_keep_alive_true(db_parameters): + """Creates a connection with client_session_keep_alive parameter.""" + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "client_session_keep_alive": True, + } + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + await cnx.connect() + assert cnx.client_session_keep_alive + finally: + await cnx.close() + + +async def test_keep_alive_heartbeat_frequency(db_parameters): + """Tests heartbeat setting. + + Creates a connection with client_session_keep_alive_heartbeat_frequency + parameter. 
+ """ + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "client_session_keep_alive": True, + "client_session_keep_alive_heartbeat_frequency": 1000, + } + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + await cnx.connect() + assert cnx.client_session_keep_alive_heartbeat_frequency == 1000 + finally: + await cnx.close() + + +@pytest.mark.skipolddriver +async def test_keep_alive_heartbeat_frequency_min(db_parameters): + """Tests heartbeat setting with custom frequency. + + Creates a connection with client_session_keep_alive_heartbeat_frequency parameter and set the minimum frequency. + Also if a value comes as string, should be properly converted to int and not fail assertion. + """ + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "client_session_keep_alive": True, + "client_session_keep_alive_heartbeat_frequency": "10", + } + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + # The min value of client_session_keep_alive_heartbeat_frequency + # is 1/16 of master token validity, so 14400 / 4 /4 => 900 + await cnx.connect() + assert cnx.client_session_keep_alive_heartbeat_frequency == 900 + finally: + await cnx.close() + + +async def test_keep_alive_heartbeat_send(db_parameters): + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "client_session_keep_alive": True, + "client_session_keep_alive_heartbeat_frequency": "1", + } + with mock.patch( + "snowflake.connector.aio._connection.SnowflakeConnection._validate_client_session_keep_alive_heartbeat_frequency", + return_value=900, + ), mock.patch( + "snowflake.connector.aio._connection.SnowflakeConnection.client_session_keep_alive_heartbeat_frequency", + new_callable=mock.PropertyMock, + return_value=1, + ), mock.patch( + "snowflake.connector.aio._connection.SnowflakeConnection._heartbeat_tick" + ) as mocked_heartbeat: + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + await cnx.connect() + # we manually call the heartbeat function once to verify heartbeat request works + assert "success" in (await cnx._rest._heartbeat()) + assert cnx.client_session_keep_alive_heartbeat_frequency == 1 + await asyncio.sleep(3) + + finally: + await cnx.close() + # we verify the SnowflakeConnection._heartbeat_tick is called at least twice because we sleep for 3 seconds + # while the frequency is 1 second + assert mocked_heartbeat.called + assert mocked_heartbeat.call_count >= 2 + + +async def test_bad_db(db_parameters): + """Attempts to use a bad DB.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + 
protocol=db_parameters["protocol"], + database="baddb", + ) + await cnx.connect() + assert cnx, "invald cnx" + await cnx.close() + + +async def test_with_string_login_timeout(db_parameters): + """Test that login_timeout when passed as string does not raise TypeError. + + In this test, we pass bad login credentials to raise error and trigger login + timeout calculation. We expect to see DatabaseError instead of TypeError that + comes from str - int arithmetic. + """ + with pytest.raises(DatabaseError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + user="bogus", + password="bogus", + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + login_timeout="5", + ): + pass + + +async def test_bogus(db_parameters): + """Attempts to login with invalid user name and password. + + Notes: + This takes a long time. + """ + with pytest.raises(DatabaseError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + user="bogus", + password="bogus", + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + login_timeout=5, + ): + pass + + with pytest.raises(DatabaseError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + user="bogus", + password="bogus", + account="testaccount123", + host=db_parameters["host"], + port=db_parameters["port"], + login_timeout=5, + insecure_mode=True, + ): + pass + + with pytest.raises(DatabaseError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + user="snowman", + password="", + account="testaccount123", + host=db_parameters["host"], + port=db_parameters["port"], + login_timeout=5, + ): + pass + + with pytest.raises(ProgrammingError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + user="", + password="password", + account="testaccount123", + host=db_parameters["host"], + port=db_parameters["port"], + login_timeout=5, + ): + pass + + +async def test_invalid_application(db_parameters): + """Invalid application name.""" + with pytest.raises(snowflake.connector.Error): + async with snowflake.connector.aio.SnowflakeConnection( + protocol=db_parameters["protocol"], + user=db_parameters["user"], + password=db_parameters["password"], + application="%%%", + ): + pass + + +async def test_valid_application(db_parameters): + """Valid application name.""" + application = "Special_Client" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + application=application, + protocol=db_parameters["protocol"], + ) + await cnx.connect() + assert cnx.application == application, "Must be valid application" + await cnx.close() + + +async def test_invalid_default_parameters(db_parameters): + """Invalid database, schema, warehouse and role name.""" + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + database="neverexists", + schema="neverexists", + warehouse="neverexits", + ) + await cnx.connect() + assert cnx, "Must be success" + + with pytest.raises(snowflake.connector.DatabaseError): + # must not success + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + 
password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + database="neverexists", + schema="neverexists", + validate_default_parameters=True, + ): + pass + + with pytest.raises(snowflake.connector.DatabaseError): + # must not success + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + database=db_parameters["database"], + schema="neverexists", + validate_default_parameters=True, + ): + pass + + with pytest.raises(snowflake.connector.DatabaseError): + # must not success + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + database=db_parameters["database"], + schema=db_parameters["schema"], + warehouse="neverexists", + validate_default_parameters=True, + ): + pass + + # Invalid role name is already validated + with pytest.raises(snowflake.connector.DatabaseError): + # must not success + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + protocol=db_parameters["protocol"], + database=db_parameters["database"], + schema=db_parameters["schema"], + role="neverexists", + ): + pass + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, + reason="The user needs a privilege of create warehouse.", +) +async def test_drop_create_user(conn_cnx, db_parameters): + """Drops and creates user.""" + async with conn_cnx() as cnx: + + async def exe(sql): + return await cnx.cursor().execute(sql) + + await exe("use role accountadmin") + await exe("drop user if exists snowdog") + await exe("create user if not exists snowdog identified by 'testdoc'") + await exe("use {}".format(db_parameters["database"])) + await exe("create or replace role snowdog_role") + await exe("grant role snowdog_role to user snowdog") + try: + # This statement will be partially executed because REFERENCE_USAGE + # will not be granted. + await exe( + "grant all on database {} to role snowdog_role".format( + db_parameters["database"] + ) + ) + except ProgrammingError as error: + err_str = ( + "Grant partially executed: privileges [REFERENCE_USAGE] not granted." 
+ ) + assert 3011 == error.errno + assert error.msg.find(err_str) != -1 + await exe( + "grant all on schema {} to role snowdog_role".format( + db_parameters["schema"] + ) + ) + + async with conn_cnx(user="snowdog", password="testdoc") as cnx2: + + async def exe(sql): + return await cnx2.cursor().execute(sql) + + await exe("use role snowdog_role") + await exe("use {}".format(db_parameters["database"])) + await exe("use schema {}".format(db_parameters["schema"])) + await exe("create or replace table friends(name varchar(100))") + await exe("drop table friends") + async with conn_cnx() as cnx: + + async def exe(sql): + return await cnx.cursor().execute(sql) + + await exe("use role accountadmin") + await exe( + "revoke all on database {} from role snowdog_role".format( + db_parameters["database"] + ) + ) + await exe("drop role snowdog_role") + await exe("drop user if exists snowdog") + + +@pytest.mark.timeout(15) +@pytest.mark.skipolddriver +async def test_invalid_account_timeout(): + with pytest.raises(InterfaceError): + async with snowflake.connector.aio.SnowflakeConnection( + account="bogus", user="test", password="test", login_timeout=5 + ): + pass + + +@pytest.mark.timeout(15) +async def test_invalid_proxy(db_parameters): + with pytest.raises(OperationalError): + async with snowflake.connector.aio.SnowflakeConnection( + protocol="http", + account="testaccount", + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + login_timeout=5, + proxy_host="localhost", + proxy_port="3333", + ): + pass + # NOTE environment variable is set if the proxy parameter is specified. + del os.environ["HTTP_PROXY"] + del os.environ["HTTPS_PROXY"] + + +@pytest.mark.timeout(15) +@pytest.mark.skipolddriver +async def test_eu_connection(tmpdir): + """Tests setting custom region. + + If region is specified to eu-central-1, the URL should become + https://testaccount1234.eu-central-1.snowflakecomputing.com/ . + + Notes: + Region is deprecated. + """ + import os + + os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_ENABLED"] = "true" + with pytest.raises(InterfaceError): + # must reach Snowflake + async with snowflake.connector.aio.SnowflakeConnection( + account="testaccount1234", + user="testuser", + password="testpassword", + region="eu-central-1", + login_timeout=5, + ocsp_response_cache_filename=os.path.join( + str(tmpdir), "test_ocsp_cache.txt" + ), + ): + pass + + +@pytest.mark.skipolddriver +async def test_us_west_connection(tmpdir): + """Tests default region setting. + + Region='us-west-2' indicates no region is included in the hostname, i.e., + https://testaccount1234.snowflakecomputing.com. + + Notes: + Region is deprecated. 
+ """ + with pytest.raises(InterfaceError): + # must reach Snowflake + async with snowflake.connector.aio.SnowflakeConnection( + account="testaccount1234", + user="testuser", + password="testpassword", + region="us-west-2", + login_timeout=5, + ): + pass + + +@pytest.mark.timeout(60) +async def test_privatelink(db_parameters): + """Ensure the OCSP cache server URL is overridden if privatelink connection is used.""" + try: + os.environ["SF_OCSP_FAIL_OPEN"] = "false" + os.environ["SF_OCSP_DO_RETRY"] = "false" + async with snowflake.connector.aio.SnowflakeConnection( + account="testaccount", + user="testuser", + password="testpassword", + region="eu-central-1.privatelink", + login_timeout=5, + ): + pass + pytest.fail("should not make connection") + except OperationalError: + ocsp_url = os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL") + assert ocsp_url is not None, "OCSP URL should not be None" + assert ( + ocsp_url == "http://ocsp.testaccount.eu-central-1." + "privatelink.snowflakecomputing.com/" + "ocsp_response_cache.json" + ) + + cnx = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx.connect() + assert cnx, "invalid cnx" + + ocsp_url = os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL") + assert ocsp_url is None, f"OCSP URL should be None: {ocsp_url}" + del os.environ["SF_OCSP_DO_RETRY"] + del os.environ["SF_OCSP_FAIL_OPEN"] + + +async def test_disable_request_pooling(db_parameters): + """Creates a connection with client_session_keep_alive parameter.""" + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "disable_request_pooling": True, + } + cnx = snowflake.connector.aio.SnowflakeConnection(**config) + try: + await cnx.connect() + assert cnx.disable_request_pooling + finally: + await cnx.close() + + +async def test_privatelink_ocsp_url_creation(): + hostname = "testaccount.us-east-1.privatelink.snowflakecomputing.com" + await SnowflakeConnection.setup_ocsp_privatelink(APPLICATION_SNOWSQL, hostname) + + ocsp_cache_server = os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL", None) + assert ( + ocsp_cache_server + == "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json" + ) + + del os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] + + await SnowflakeConnection.setup_ocsp_privatelink(CLIENT_NAME, hostname) + ocsp_cache_server = os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL", None) + assert ( + ocsp_cache_server + == "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json" + ) + + +async def test_privatelink_ocsp_url_concurrent(): + bucket = queue.Queue() + + hostname = "testaccount.us-east-1.privatelink.snowflakecomputing.com" + expectation = "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json" + task = [] + + for _ in range(15): + task.append( + asyncio.create_task( + ExecPrivatelinkAsyncTask( + bucket, hostname, expectation, CLIENT_NAME + ).run() + ) + ) + + await asyncio.gather(*task) + assert bucket.qsize() == 15 + for _ in range(15): + if 
bucket.get() != "Success": + raise AssertionError() + + if os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL", None) is not None: + del os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] + + +async def test_privatelink_ocsp_url_concurrent_snowsql(): + bucket = queue.Queue() + + hostname = "testaccount.us-east-1.privatelink.snowflakecomputing.com" + expectation = "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json" + task = [] + + for _ in range(15): + task.append( + asyncio.create_task( + ExecPrivatelinkAsyncTask( + bucket, hostname, expectation, APPLICATION_SNOWSQL + ).run() + ) + ) + + await asyncio.gather(*task) + assert bucket.qsize() == 15 + for _ in range(15): + if bucket.get() != "Success": + raise AssertionError() + + +class ExecPrivatelinkAsyncTask: + def __init__(self, bucket, hostname, expectation, client_name): + self.bucket = bucket + self.hostname = hostname + self.expectation = expectation + self.client_name = client_name + + async def run(self): + await SnowflakeConnection.setup_ocsp_privatelink( + self.client_name, self.hostname + ) + ocsp_cache_server = os.getenv("SF_OCSP_RESPONSE_CACHE_SERVER_URL", None) + if ocsp_cache_server is not None and ocsp_cache_server != self.expectation: + print(f"Got {ocsp_cache_server} Expected {self.expectation}") + self.bucket.put("Fail") + else: + self.bucket.put("Success") + + +async def test_okta_url(conn_cnx): + orig_authenticator = "https://someaccount.okta.com/snowflake/oO56fExYCGnfV83/2345" + + async def mock_auth(self, auth_instance): + assert isinstance(auth_instance, AuthByOkta) + assert self._authenticator == orig_authenticator + + with mock.patch( + "snowflake.connector.aio.SnowflakeConnection._authenticate", + mock_auth, + ): + async with conn_cnx( + timezone="UTC", + authenticator=orig_authenticator, + ) as cnx: + assert cnx + + +async def test_dashed_url(db_parameters): + """Test whether dashed URLs get created correctly.""" + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + return_value={"data": {"token": None, "masterToken": None}, "success": True}, + ) as mocked_fetch: + async with snowflake.connector.aio.SnowflakeConnection( + user="test-user", + password="test-password", + host="test-host", + port="443", + account="test-account", + ) as cnx: + assert cnx + cnx.commit = cnx.rollback = lambda: asyncio.sleep( + 0 + ) # Skip tear down, there's only a mocked rest api + assert any( + [ + c[0][1].startswith("https://test-host:443") + for c in mocked_fetch.call_args_list + ] + ) + + +async def test_dashed_url_account_name(db_parameters): + """Tests whether dashed URLs get created correctly when no hostname is provided.""" + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + return_value={"data": {"token": None, "masterToken": None}, "success": True}, + ) as mocked_fetch: + async with snowflake.connector.aio.SnowflakeConnection( + user="test-user", + password="test-password", + port="443", + account="test-account", + ) as cnx: + assert cnx + cnx.commit = cnx.rollback = lambda: asyncio.sleep( + 0 + ) # Skip tear down, there's only a mocked rest api + assert any( + [ + c[0][1].startswith( + "https://test-account.snowflakecomputing.com:443" + ) + for c in mocked_fetch.call_args_list + ] + ) + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "name,value,exc_warn", + [ + # Not existing parameter + ( + "no_such_parameter", + True, + UserWarning("'no_such_parameter' is an unknown connection parameter"), + ), + # Typo in parameter 
name + ( + "applucation", + True, + UserWarning( + "'applucation' is an unknown connection parameter, did you mean 'application'?" + ), + ), + # Single type error + ( + "support_negative_year", + "True", + UserWarning( + "'support_negative_year' connection parameter should be of type " + "'bool', but is a 'str'" + ), + ), + # Multiple possible type error + ( + "autocommit", + "True", + UserWarning( + "'autocommit' connection parameter should be of type " + "'(NoneType, bool)', but is a 'str'" + ), + ), + ], +) +async def test_invalid_connection_parameter(db_parameters, name, value, exc_warn): + with warnings.catch_warnings(record=True) as w: + conn_params = { + "account": db_parameters["account"], + "user": db_parameters["user"], + "password": db_parameters["password"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "validate_default_parameters": True, + name: value, + } + try: + conn = snowflake.connector.aio.SnowflakeConnection(**conn_params) + await conn.connect() + assert getattr(conn, "_" + name) == value + assert len(w) == 1 + assert str(w[0].message) == str(exc_warn) + finally: + await conn.close() + + +async def test_invalid_connection_parameters_turned_off(db_parameters): + """Makes sure parameter checking can be turned off.""" + with warnings.catch_warnings(record=True) as w: + conn_params = { + "account": db_parameters["account"], + "user": db_parameters["user"], + "password": db_parameters["password"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "validate_default_parameters": False, + "autocommit": "True", # Wrong type + "applucation": "this is a typo or my own variable", # Wrong name + } + try: + conn = snowflake.connector.aio.SnowflakeConnection(**conn_params) + await conn.connect() + assert conn._autocommit == conn_params["autocommit"] + assert conn._applucation == conn_params["applucation"] + assert len(w) == 0 + finally: + await conn.close() + + +async def test_invalid_connection_parameters_only_warns(db_parameters): + """This test supresses warnings to only have warehouse, database and schema checking.""" + with warnings.catch_warnings(record=True) as w: + conn_params = { + "account": db_parameters["account"], + "user": db_parameters["user"], + "password": db_parameters["password"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "validate_default_parameters": True, + "autocommit": "True", # Wrong type + "applucation": "this is a typo or my own variable", # Wrong name + } + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + conn = snowflake.connector.aio.SnowflakeConnection(**conn_params) + await conn.connect() + assert conn._autocommit == conn_params["autocommit"] + assert conn._applucation == conn_params["applucation"] + assert len(w) == 0 + finally: + await conn.close() + + +@pytest.mark.skipolddriver +async def test_region_deprecation(conn_cnx): + """Tests whether region raises a deprecation warning.""" + async with conn_cnx() as conn: + with warnings.catch_warnings(record=True) as w: + conn.region + assert len(w) == 1 + assert issubclass(w[0].category, PendingDeprecationWarning) + assert "Region has been deprecated" in 
str(w[0].message) + + +@pytest.mark.skip("SNOW-1763103") +async def test_invalid_errorhander_error(conn_cnx): + """Tests that the errorhandler cannot be set to None.""" + async with conn_cnx() as conn: + with pytest.raises(ProgrammingError, match="None errorhandler is specified"): + conn.errorhandler = None + original_handler = conn.errorhandler + conn.errorhandler = original_handler + assert conn.errorhandler is original_handler + + +async def test_disable_request_pooling_setter(conn_cnx): + """Tests whether request pooling can be set successfully.""" + async with conn_cnx() as conn: + original_value = conn.disable_request_pooling + conn.disable_request_pooling = not original_value + assert conn.disable_request_pooling == (not original_value) + conn.disable_request_pooling = original_value + assert conn.disable_request_pooling == original_value + + +async def test_autocommit_closed_already(conn_cnx): + """Tests that setting autocommit on an already closed connection raises the right error.""" + async with conn_cnx() as conn: + pass + with pytest.raises(DatabaseError, match=r"Connection is closed") as dbe: + await conn.autocommit(True) + assert dbe.errno == ER_CONNECTION_IS_CLOSED + + +async def test_autocommit_invalid_type(conn_cnx): + """Tests that setting autocommit to an invalid type raises the right error.""" + async with conn_cnx() as conn: + with pytest.raises(ProgrammingError, match=r"Invalid parameter: True") as dbe: + await conn.autocommit("True") + assert dbe.errno == ER_INVALID_VALUE + + +async def test_autocommit_unsupported(conn_cnx, caplog): + """Tests that a server-side error is handled correctly when setting autocommit.""" + async with conn_cnx() as conn: + caplog.set_level(logging.DEBUG, "snowflake.connector") + with mock.patch( + "snowflake.connector.aio.SnowflakeCursor.execute", + side_effect=Error("Test error", sqlstate=SQLSTATE_FEATURE_NOT_SUPPORTED), + ): + await conn.autocommit(True) + assert ( + "snowflake.connector.aio._connection", + logging.DEBUG, + "Autocommit feature is not enabled for this connection.
Ignored", + ) in caplog.record_tuples + + +async def test_sequence_counter(conn_cnx): + """Tests whether setting the sequence counter and incrementing it work as expected.""" + async with conn_cnx(sequence_counter=4) as conn: + assert conn.sequence_counter == 4 + async with conn.cursor() as cur: + assert await (await cur.execute("select 1 ")).fetchall() == [(1,)] + assert conn.sequence_counter == 5 + + +async def test_missing_account(conn_cnx): + """Tests whether a missing account raises the right exception.""" + with pytest.raises(ProgrammingError, match="Account must be specified") as pe: + async with conn_cnx(account=""): + pass + assert pe.errno == ER_NO_ACCOUNT_NAME + + +@pytest.mark.parametrize("resp", [None, {}]) +async def test_empty_response(conn_cnx, resp): + """Tests that cmd_query returns an empty response when an empty/no response is received from the back-end.""" + async with conn_cnx() as conn: + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.request", + return_value=resp, + ): + assert await conn.cmd_query("select 1", 0, uuid4()) == {"data": {}} + + +@pytest.mark.skipolddriver +async def test_authenticate_error(conn_cnx, caplog): + """Tests reauthentication error handling while authenticating.""" + # The docs say unsafe should make this test work, but + # it doesn't seem to work on MagicMock + mock_auth = mock.Mock(spec=AuthByPlugin, unsafe=True) + mock_auth.prepare.return_value = mock_auth + mock_auth.update_body.side_effect = ReauthenticationRequest(None) + mock_auth._retry_ctx = mock.MagicMock() + async with conn_cnx() as conn: + caplog.set_level(logging.DEBUG, "snowflake.connector") + with pytest.raises(ReauthenticationRequest): + await conn.authenticate_with_retry(mock_auth) + assert ( + "snowflake.connector.aio._connection", + logging.DEBUG, + "ID token expired.
Reauthenticating...: None", + ) in caplog.record_tuples + + +@pytest.mark.skipolddriver +async def test_process_qmark_params_error(conn_cnx): + """Tests errors thrown in _process_params_qmarks.""" + sql = "select 1;" + async with conn_cnx(paramstyle="qmark") as conn: + async with conn.cursor() as cur: + with pytest.raises( + ProgrammingError, + match="Binding parameters must be a list: invalid input", + ) as pe: + await cur.execute(sql, params="invalid input") + assert pe.value.errno == ER_FAILED_PROCESSING_PYFORMAT + with pytest.raises( + ProgrammingError, + match="Binding parameters must be a list where one element is a single " + "value or a pair of Snowflake datatype and a value", + ) as pe: + await cur.execute( + sql, + params=( + ( + 1, + 2, + 3, + ), + ), + ) + assert pe.value.errno == ER_FAILED_PROCESSING_QMARK + with pytest.raises( + ProgrammingError, + match=r"Python data type \[magicmock\] cannot be automatically mapped " + r"to Snowflake", + ) as pe: + await cur.execute(sql, params=[mock.MagicMock()]) + assert pe.value.errno == ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE + + +@pytest.mark.skipolddriver +async def test_process_param_dict_error(conn_cnx): + """Tests whether exceptions in __process_params_dict are handled correctly.""" + async with conn_cnx() as conn: + with pytest.raises( + ProgrammingError, match="Failed processing pyformat-parameters: test" + ) as pe: + with mock.patch( + "snowflake.connector.converter.SnowflakeConverter.to_snowflake", + side_effect=Exception("test"), + ): + conn._process_params_pyformat({"asd": "something"}) + assert pe.errno == ER_FAILED_PROCESSING_PYFORMAT + + +@pytest.mark.skipolddriver +async def test_process_param_error(conn_cnx): + """Tests whether exceptions in __process_params_dict are handled correctly.""" + async with conn_cnx() as conn: + with pytest.raises( + ProgrammingError, match="Failed processing pyformat-parameters; test" + ) as pe: + with mock.patch( + "snowflake.connector.converter.SnowflakeConverter.to_snowflake", + side_effect=Exception("test"), + ): + conn._process_params_pyformat(mock.Mock()) + assert pe.errno == ER_FAILED_PROCESSING_PYFORMAT + + +@pytest.mark.parametrize( + "auto_commit", [pytest.param(True, marks=pytest.mark.skipolddriver), False] +) +async def test_autocommit(conn_cnx, db_parameters, auto_commit): + conn = snowflake.connector.aio.SnowflakeConnection(**db_parameters) + with mock.patch.object(conn, "commit") as mocked_commit: + async with conn: + async with conn.cursor() as cur: + await cur.execute(f"alter session set autocommit = {auto_commit}") + if auto_commit: + assert not mocked_commit.called + else: + assert mocked_commit.called + + +@pytest.mark.skipolddriver +async def test_client_prefetch_threads_setting(conn_cnx): + """Tests whether client_prefetch_threads updated and is propagated to result set.""" + async with conn_cnx() as conn: + assert conn.client_prefetch_threads == DEFAULT_CLIENT_PREFETCH_THREADS + new_thread_count = conn.client_prefetch_threads + 1 + async with conn.cursor() as cur: + await cur.execute( + f"alter session set client_prefetch_threads={new_thread_count}" + ) + assert cur._result_set.prefetch_thread_num == new_thread_count + assert conn.client_prefetch_threads == new_thread_count + + +@pytest.mark.external +async def test_client_failover_connection_url(conn_cnx): + async with conn_cnx("client_failover") as conn: + async with conn.cursor() as cur: + assert await (await cur.execute("select 1;")).fetchall() == [ + (1,), + ] + + +async def test_connection_gc(conn_cnx): + """This test 
makes sure that a heartbeat thread doesn't prevent garbage collection of SnowflakeConnection.""" + conn = await conn_cnx(client_session_keep_alive=True).__aenter__() + conn_wref = weakref.ref(conn) + del conn + # this is different from sync test because we need to yield to give connection.close + # coroutine a chance to run all the teardown tasks + for _ in range(100): + await asyncio.sleep(0.01) + gc.collect() + assert conn_wref() is None + + +@pytest.mark.skipolddriver +async def test_connection_cant_be_reused(conn_cnx): + row_count = 50_000 + async with conn_cnx() as conn: + cursors = await conn.execute_string( + f"select seq4() as n from table(generator(rowcount => {row_count}));" + ) + assert len(cursors[0]._result_set.batches) > 1 # We need to have remote results + res = [] + async for result in cursors[0]: + res.append(result) + assert res + + +@pytest.mark.external +@pytest.mark.skipolddriver +async def test_ocsp_cache_working(conn_cnx): + """Verifies that the OCSP cache is functioning. + + The only way we can verify this is that the number of hits and misses increase. + """ + from snowflake.connector.ocsp_snowflake import OCSP_RESPONSE_VALIDATION_CACHE + + original_count = ( + OCSP_RESPONSE_VALIDATION_CACHE.telemetry["hit"] + + OCSP_RESPONSE_VALIDATION_CACHE.telemetry["miss"] + ) + async with conn_cnx() as cnx: + assert cnx + assert ( + OCSP_RESPONSE_VALIDATION_CACHE.telemetry["hit"] + + OCSP_RESPONSE_VALIDATION_CACHE.telemetry["miss"] + > original_count + ) + + +@pytest.mark.skipolddriver +async def test_imported_packages_telemetry( + conn_cnx, capture_sf_telemetry_async, db_parameters +): + # these imports are not used but for testing + import html.parser # noqa: F401 + import json # noqa: F401 + import multiprocessing as mp # noqa: F401 + from datetime import date # noqa: F401 + from math import sqrt # noqa: F401 + + def check_packages(message: str, expected_packages: list[str]) -> bool: + return ( + all([package in message for package in expected_packages]) + and "__main__" not in message + ) + + packages = [ + "pytest", + "unittest", + "json", + "multiprocessing", + "html", + "datetime", + "math", + ] + + async with conn_cnx() as conn, capture_sf_telemetry_async.patch_connection( + conn, False + ) as telemetry_test: + await conn._log_telemetry_imported_packages() + assert len(telemetry_test.records) > 0 + assert any( + [ + t.message[TelemetryField.KEY_TYPE.value] + == TelemetryField.IMPORTED_PACKAGES.value + and CLIENT_NAME == t.message[TelemetryField.KEY_SOURCE.value] + and check_packages(t.message["value"], packages) + for t in telemetry_test.records + ] + ) + + # test different application + new_application_name = "PythonSnowpark" + config = { + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "account": db_parameters["account"], + "schema": db_parameters["schema"], + "database": db_parameters["database"], + "protocol": db_parameters["protocol"], + "timezone": "UTC", + "application": new_application_name, + } + async with snowflake.connector.aio.SnowflakeConnection( + **config + ) as conn, capture_sf_telemetry_async.patch_connection( + conn, False + ) as telemetry_test: + await conn._log_telemetry_imported_packages() + assert len(telemetry_test.records) > 0 + assert any( + [ + t.message[TelemetryField.KEY_TYPE.value] + == TelemetryField.IMPORTED_PACKAGES.value + and new_application_name == t.message[TelemetryField.KEY_SOURCE.value] + for t in telemetry_test.records + ] + ) + + # test 
opt out + config["log_imported_packages_in_telemetry"] = False + async with snowflake.connector.aio.SnowflakeConnection( + **config + ) as conn, capture_sf_telemetry_async.patch_connection( + conn, False + ) as telemetry_test: + await conn._log_telemetry_imported_packages() + assert len(telemetry_test.records) == 0 + + +@pytest.mark.skipolddriver +async def test_disable_query_context_cache(conn_cnx) -> None: + async with conn_cnx(disable_query_context_cache=True) as conn: + # check that connector function correctly when query context + # cache is disabled + ret = await (await conn.cursor().execute("select 1")).fetchone() + assert ret == (1,) + assert conn.query_context_cache is None + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "mode", + ("file", "env"), +) +async def test_connection_name_loading(monkeypatch, db_parameters, tmp_path, mode): + import tomlkit + + doc = tomlkit.document() + default_con = tomlkit.table() + tmp_connections_file: None | pathlib.Path = None + try: + # If anything unexpected fails here, don't want to expose password + for k, v in db_parameters.items(): + default_con[k] = v + doc["default"] = default_con + with monkeypatch.context() as m: + if mode == "env": + m.setenv("SF_CONNECTIONS", tomlkit.dumps(doc)) + else: + tmp_connections_file = tmp_path / "connections.toml" + tmp_connections_file.write_text(tomlkit.dumps(doc)) + tmp_connections_file.chmod(stat.S_IRUSR | stat.S_IWUSR) + async with snowflake.connector.aio.SnowflakeConnection( + connection_name="default", + connections_file_path=tmp_connections_file, + ) as conn: + async with conn.cursor() as cur: + assert await (await cur.execute("select 1;")).fetchall() == [ + (1,), + ] + except Exception: + # This is my way of guaranteeing that we'll not expose the + # sensitive information that this test needs to handle. + # db_parameter contains passwords. + pytest.fail("something failed", pytrace=False) + + +@pytest.mark.skipolddriver +async def test_default_connection_name_loading(monkeypatch, db_parameters): + import tomlkit + + doc = tomlkit.document() + default_con = tomlkit.table() + try: + # If anything unexpected fails here, don't want to expose password + for k, v in db_parameters.items(): + default_con[k] = v + doc["default"] = default_con + with monkeypatch.context() as m: + m.setenv("SNOWFLAKE_CONNECTIONS", tomlkit.dumps(doc)) + m.setenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "default") + async with snowflake.connector.aio.SnowflakeConnection() as conn: + async with conn.cursor() as cur: + assert await (await cur.execute("select 1;")).fetchall() == [ + (1,), + ] + except Exception: + # This is my way of guaranteeing that we'll not expose the + # sensitive information that this test needs to handle. + # db_parameter contains passwords. 
+ pytest.fail("something failed", pytrace=False) + + +@pytest.mark.skipolddriver +async def test_not_found_connection_name(): + connection_name = random_string(5) + with pytest.raises( + Error, + match=f"Invalid connection_name '{connection_name}', known ones are", + ): + await snowflake.connector.aio.SnowflakeConnection( + connection_name=connection_name + ).connect() + + +@pytest.mark.skipolddriver +async def test_server_session_keep_alive(conn_cnx): + mock_delete_session = mock.MagicMock() + async with conn_cnx(server_session_keep_alive=True) as conn: + conn.rest.delete_session = mock_delete_session + mock_delete_session.assert_not_called() + + mock_delete_session = mock.MagicMock() + async with conn_cnx() as conn: + conn.rest.delete_session = mock_delete_session + mock_delete_session.assert_called_once() + + +@pytest.mark.skipolddriver +async def test_ocsp_mode_insecure(conn_cnx, is_public_test, caplog): + caplog.set_level(logging.DEBUG, "snowflake.connector.ocsp_snowflake") + async with conn_cnx(insecure_mode=True) as conn, conn.cursor() as cur: + assert await (await cur.execute("select 1")).fetchall() == [(1,)] + assert "snowflake.connector.ocsp_snowflake" not in caplog.text + caplog.clear() + + async with conn_cnx() as conn, conn.cursor() as cur: + assert await (await cur.execute("select 1")).fetchall() == [(1,)] + if is_public_test: + assert "snowflake.connector.ocsp_snowflake" in caplog.text + else: + assert "snowflake.connector.ocsp_snowflake" not in caplog.text + + +@pytest.mark.skipolddriver +def test_connection_atexit_close(db_parameters): + """Basic Connection test without schema.""" + conn = snowflake.connector.aio.SnowflakeConnection(**db_parameters) + + async def func(): + await conn.connect() + return conn + + conn = asyncio.run(func()) + conn._close_at_exit() + assert conn.is_closed() + + +@pytest.mark.skipolddriver +async def test_token_file_path(tmp_path, db_parameters): + fake_token = "some token" + token_file_path = tmp_path / "token" + with open(token_file_path, "w") as f: + f.write(fake_token) + + conn = snowflake.connector.aio.SnowflakeConnection( + **db_parameters, token=fake_token + ) + await conn.connect() + assert conn._token == fake_token + conn = snowflake.connector.aio.SnowflakeConnection( + **db_parameters, token_file_path=token_file_path + ) + await conn.connect() + assert conn._token == fake_token + + +@pytest.mark.skipolddriver +@pytest.mark.skipif(not RUNNING_ON_GH, reason="no ocsp in the environment") +async def test_mock_non_existing_server(conn_cnx, caplog): + from snowflake.connector.cache import SFDictCache + + # disabling local cache and pointing ocsp cache server to a non-existing url + # connection should still work as it will directly validate the certs against CA servers + with tempfile.NamedTemporaryFile() as tmp, caplog.at_level(logging.DEBUG): + with mock.patch( + "snowflake.connector.url_util.extract_top_level_domain_from_hostname", + return_value="nonexistingtopleveldomain", + ): + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", + SFDictCache(), + ): + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSPCache.OCSP_RESPONSE_CACHE_FILE_NAME", + tmp.name, + ): + async with conn_cnx(): + pass + assert all( + s in caplog.text + for s in [ + "Failed to read OCSP response cache file", + "It will validate with OCSP server.", + "writing OCSP response cache file to", + ] + ) + + +@pytest.mark.xfail( + reason="TODO: SNOW-1759084 await anext(self._generator, None) does not execute code after yield" +) 
+async def test_disable_telemetry(conn_cnx, caplog): + # default behavior, closing connection, it will send telemetry + with caplog.at_level(logging.DEBUG): + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await (await cur.execute("select 1")).fetchall() + assert ( + len(conn._telemetry._log_batch) == 3 + ) # 3 events are `import package`, `fetch first`, it's missing `fetch last` because of SNOW-1759084 + + assert "POST /telemetry/send" in caplog.text + caplog.clear() + + # set session parameters to false + with caplog.at_level(logging.DEBUG): + async with conn_cnx( + session_parameters={"CLIENT_TELEMETRY_ENABLED": False} + ) as conn, conn.cursor() as cur: + await (await cur.execute("select 1")).fetchall() + assert not conn.telemetry_enabled and not conn._telemetry._log_batch + # this enable won't work as the session parameter is set to false + conn.telemetry_enabled = True + await (await cur.execute("select 1")).fetchall() + assert not conn.telemetry_enabled and not conn._telemetry._log_batch + + assert "POST /telemetry/send" not in caplog.text + caplog.clear() + + # test disable telemetry in the client + with caplog.at_level(logging.DEBUG): + async with conn_cnx() as conn: + assert conn.telemetry_enabled and len(conn._telemetry._log_batch) == 1 + conn.telemetry_enabled = False + async with conn.cursor() as cur: + await (await cur.execute("select 1")).fetchall() + assert not conn.telemetry_enabled + assert "POST /telemetry/send" not in caplog.text diff --git a/test/integ/aio/test_converter_async.py b/test/integ/aio/test_converter_async.py new file mode 100644 index 0000000000..4ab9216721 --- /dev/null +++ b/test/integ/aio/test_converter_async.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
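The converter tests in the new file below drive the same async API that the connection tests above exercise: a coroutine-based SnowflakeConnection whose cursors are awaited for execute/fetch and iterated with async for. The following is a minimal sketch of that usage pattern, not part of the patch; the account, user, and password values are placeholders for whatever the test parameters would normally supply.

import asyncio

import snowflake.connector.aio


async def main() -> None:
    # Placeholder credentials; the tests pull real values from db_parameters.
    conn = snowflake.connector.aio.SnowflakeConnection(
        account="myaccount",
        user="myuser",
        password="mypassword",
    )
    await conn.connect()  # explicit connect, as used throughout the tests above
    try:
        async with conn.cursor() as cur:
            await cur.execute("select 1")
            rows = await cur.fetchall()  # fetch methods are awaited as well
            assert rows == [(1,)]
    finally:
        await conn.close()


asyncio.run(main())

The tests also use the connection itself as an async context manager (async with SnowflakeConnection(...) as cnx), which performs the same connect/close steps implicitly.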
+# + +from __future__ import annotations + +from datetime import time +from test.integ.test_converter import _compose_ltz, _compose_ntz, _compose_tz + +import pytest + +from snowflake.connector.compat import IS_WINDOWS +from snowflake.connector.converter import _generate_tzinfo_from_tzoffset +from snowflake.connector.converter_snowsql import SnowflakeConverterSnowSQL + + +async def test_fetch_timestamps(conn_cnx): + PST_TZ = "America/Los_Angeles" + + tzdiff = 1860 - 1440 # -07:00 + tzinfo = _generate_tzinfo_from_tzoffset(tzdiff) + + # TIMESTAMP_TZ + r0 = _compose_tz("1325568896.123456", tzinfo) + r1 = _compose_tz("1325568896.123456", tzinfo) + r2 = _compose_tz("1325568896.123456", tzinfo) + r3 = _compose_tz("1325568896.123456", tzinfo) + r4 = _compose_tz("1325568896.12345", tzinfo) + r5 = _compose_tz("1325568896.1234", tzinfo) + r6 = _compose_tz("1325568896.123", tzinfo) + r7 = _compose_tz("1325568896.12", tzinfo) + r8 = _compose_tz("1325568896.1", tzinfo) + r9 = _compose_tz("1325568896", tzinfo) + + # TIMESTAMP_NTZ + r10 = _compose_ntz("1325568896.123456") + r11 = _compose_ntz("1325568896.123456") + r12 = _compose_ntz("1325568896.123456") + r13 = _compose_ntz("1325568896.123456") + r14 = _compose_ntz("1325568896.12345") + r15 = _compose_ntz("1325568896.1234") + r16 = _compose_ntz("1325568896.123") + r17 = _compose_ntz("1325568896.12") + r18 = _compose_ntz("1325568896.1") + r19 = _compose_ntz("1325568896") + + # TIMESTAMP_LTZ + r20 = _compose_ltz("1325568896.123456", PST_TZ) + r21 = _compose_ltz("1325568896.123456", PST_TZ) + r22 = _compose_ltz("1325568896.123456", PST_TZ) + r23 = _compose_ltz("1325568896.123456", PST_TZ) + r24 = _compose_ltz("1325568896.12345", PST_TZ) + r25 = _compose_ltz("1325568896.1234", PST_TZ) + r26 = _compose_ltz("1325568896.123", PST_TZ) + r27 = _compose_ltz("1325568896.12", PST_TZ) + r28 = _compose_ltz("1325568896.1", PST_TZ) + r29 = _compose_ltz("1325568896", PST_TZ) + + # TIME + r30 = time(5, 7, 8, 123456) + r31 = time(5, 7, 8, 123456) + r32 = time(5, 7, 8, 123456) + r33 = time(5, 7, 8, 123456) + r34 = time(5, 7, 8, 123450) + r35 = time(5, 7, 8, 123400) + r36 = time(5, 7, 8, 123000) + r37 = time(5, 7, 8, 120000) + r38 = time(5, 7, 8, 100000) + r39 = time(5, 7, 8) + + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute( + """ +ALTER SESSION SET TIMEZONE='{tz}'; +""".format( + tz=PST_TZ + ) + ) + await cur.execute( + """ +SELECT + '2012-01-03 12:34:56.123456789+07:00'::timestamp_tz(9), + '2012-01-03 12:34:56.12345678+07:00'::timestamp_tz(8), + '2012-01-03 12:34:56.1234567+07:00'::timestamp_tz(7), + '2012-01-03 12:34:56.123456+07:00'::timestamp_tz(6), + '2012-01-03 12:34:56.12345+07:00'::timestamp_tz(5), + '2012-01-03 12:34:56.1234+07:00'::timestamp_tz(4), + '2012-01-03 12:34:56.123+07:00'::timestamp_tz(3), + '2012-01-03 12:34:56.12+07:00'::timestamp_tz(2), + '2012-01-03 12:34:56.1+07:00'::timestamp_tz(1), + '2012-01-03 12:34:56+07:00'::timestamp_tz(0), + '2012-01-03 05:34:56.123456789'::timestamp_ntz(9), + '2012-01-03 05:34:56.12345678'::timestamp_ntz(8), + '2012-01-03 05:34:56.1234567'::timestamp_ntz(7), + '2012-01-03 05:34:56.123456'::timestamp_ntz(6), + '2012-01-03 05:34:56.12345'::timestamp_ntz(5), + '2012-01-03 05:34:56.1234'::timestamp_ntz(4), + '2012-01-03 05:34:56.123'::timestamp_ntz(3), + '2012-01-03 05:34:56.12'::timestamp_ntz(2), + '2012-01-03 05:34:56.1'::timestamp_ntz(1), + '2012-01-03 05:34:56'::timestamp_ntz(0), + '2012-01-02 21:34:56.123456789'::timestamp_ltz(9), + '2012-01-02 21:34:56.12345678'::timestamp_ltz(8), + '2012-01-02 
21:34:56.1234567'::timestamp_ltz(7), + '2012-01-02 21:34:56.123456'::timestamp_ltz(6), + '2012-01-02 21:34:56.12345'::timestamp_ltz(5), + '2012-01-02 21:34:56.1234'::timestamp_ltz(4), + '2012-01-02 21:34:56.123'::timestamp_ltz(3), + '2012-01-02 21:34:56.12'::timestamp_ltz(2), + '2012-01-02 21:34:56.1'::timestamp_ltz(1), + '2012-01-02 21:34:56'::timestamp_ltz(0), + '05:07:08.123456789'::time(9), + '05:07:08.12345678'::time(8), + '05:07:08.1234567'::time(7), + '05:07:08.123456'::time(6), + '05:07:08.12345'::time(5), + '05:07:08.1234'::time(4), + '05:07:08.123'::time(3), + '05:07:08.12'::time(2), + '05:07:08.1'::time(1), + '05:07:08'::time(0) +""" + ) + ret = await cur.fetchone() + assert ret[0] == r0 + assert ret[1] == r1 + assert ret[2] == r2 + assert ret[3] == r3 + assert ret[4] == r4 + assert ret[5] == r5 + assert ret[6] == r6 + assert ret[7] == r7 + assert ret[8] == r8 + assert ret[9] == r9 + assert ret[10] == r10 + assert ret[11] == r11 + assert ret[12] == r12 + assert ret[13] == r13 + assert ret[14] == r14 + assert ret[15] == r15 + assert ret[16] == r16 + assert ret[17] == r17 + assert ret[18] == r18 + assert ret[19] == r19 + assert ret[20] == r20 + assert ret[21] == r21 + assert ret[22] == r22 + assert ret[23] == r23 + assert ret[24] == r24 + assert ret[25] == r25 + assert ret[26] == r26 + assert ret[27] == r27 + assert ret[28] == r28 + assert ret[29] == r29 + assert ret[30] == r30 + assert ret[31] == r31 + assert ret[32] == r32 + assert ret[33] == r33 + assert ret[34] == r34 + assert ret[35] == r35 + assert ret[36] == r36 + assert ret[37] == r37 + assert ret[38] == r38 + assert ret[39] == r39 + + +async def test_fetch_timestamps_snowsql(conn_cnx): + PST_TZ = "America/Los_Angeles" + + converter_class = SnowflakeConverterSnowSQL + sql = """ +SELECT + '2012-01-03 12:34:56.123456789+07:00'::timestamp_tz(9), + '2012-01-03 12:34:56.12345678+07:00'::timestamp_tz(8), + '2012-01-03 12:34:56.1234567+07:00'::timestamp_tz(7), + '2012-01-03 12:34:56.123456+07:00'::timestamp_tz(6), + '2012-01-03 12:34:56.12345+07:00'::timestamp_tz(5), + '2012-01-03 12:34:56.1234+07:00'::timestamp_tz(4), + '2012-01-03 12:34:56.123+07:00'::timestamp_tz(3), + '2012-01-03 12:34:56.12+07:00'::timestamp_tz(2), + '2012-01-03 12:34:56.1+07:00'::timestamp_tz(1), + '2012-01-03 12:34:56+07:00'::timestamp_tz(0), + '2012-01-03 05:34:56.123456789'::timestamp_ntz(9), + '2012-01-03 05:34:56.12345678'::timestamp_ntz(8), + '2012-01-03 05:34:56.1234567'::timestamp_ntz(7), + '2012-01-03 05:34:56.123456'::timestamp_ntz(6), + '2012-01-03 05:34:56.12345'::timestamp_ntz(5), + '2012-01-03 05:34:56.1234'::timestamp_ntz(4), + '2012-01-03 05:34:56.123'::timestamp_ntz(3), + '2012-01-03 05:34:56.12'::timestamp_ntz(2), + '2012-01-03 05:34:56.1'::timestamp_ntz(1), + '2012-01-03 05:34:56'::timestamp_ntz(0), + '2012-01-02 21:34:56.123456789'::timestamp_ltz(9), + '2012-01-02 21:34:56.12345678'::timestamp_ltz(8), + '2012-01-02 21:34:56.1234567'::timestamp_ltz(7), + '2012-01-02 21:34:56.123456'::timestamp_ltz(6), + '2012-01-02 21:34:56.12345'::timestamp_ltz(5), + '2012-01-02 21:34:56.1234'::timestamp_ltz(4), + '2012-01-02 21:34:56.123'::timestamp_ltz(3), + '2012-01-02 21:34:56.12'::timestamp_ltz(2), + '2012-01-02 21:34:56.1'::timestamp_ltz(1), + '2012-01-02 21:34:56'::timestamp_ltz(0), + '05:07:08.123456789'::time(9), + '05:07:08.12345678'::time(8), + '05:07:08.1234567'::time(7), + '05:07:08.123456'::time(6), + '05:07:08.12345'::time(5), + '05:07:08.1234'::time(4), + '05:07:08.123'::time(3), + '05:07:08.12'::time(2), + '05:07:08.1'::time(1), + 
'05:07:08'::time(0) +""" + async with conn_cnx(converter_class=converter_class) as cnx: + cur = cnx.cursor() + await cur.execute( + """ +alter session set python_connector_query_result_format='JSON' +""" + ) + await cur.execute( + """ +ALTER SESSION SET TIMEZONE='{tz}'; +""".format( + tz=PST_TZ + ) + ) + await cur.execute( + """ +ALTER SESSION SET + TIMESTAMP_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF9 TZH:TZM', + TIMESTAMP_NTZ_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF9 TZH:TZM', + TIME_OUTPUT_FORMAT='HH24:MI:SS.FF9'; + """ + ) + await cur.execute(sql) + ret = await cur.fetchone() + assert ret[0] == "2012-01-03 12:34:56.123456789 +0700" + assert ret[1] == "2012-01-03 12:34:56.123456780 +0700" + assert ret[2] == "2012-01-03 12:34:56.123456700 +0700" + assert ret[3] == "2012-01-03 12:34:56.123456000 +0700" + assert ret[4] == "2012-01-03 12:34:56.123450000 +0700" + assert ret[5] == "2012-01-03 12:34:56.123400000 +0700" + assert ret[6] == "2012-01-03 12:34:56.123000000 +0700" + assert ret[7] == "2012-01-03 12:34:56.120000000 +0700" + assert ret[8] == "2012-01-03 12:34:56.100000000 +0700" + assert ret[9] == "2012-01-03 12:34:56.000000000 +0700" + assert ret[10] == "2012-01-03 05:34:56.123456789 " + assert ret[11] == "2012-01-03 05:34:56.123456780 " + assert ret[12] == "2012-01-03 05:34:56.123456700 " + assert ret[13] == "2012-01-03 05:34:56.123456000 " + assert ret[14] == "2012-01-03 05:34:56.123450000 " + assert ret[15] == "2012-01-03 05:34:56.123400000 " + assert ret[16] == "2012-01-03 05:34:56.123000000 " + assert ret[17] == "2012-01-03 05:34:56.120000000 " + assert ret[18] == "2012-01-03 05:34:56.100000000 " + assert ret[19] == "2012-01-03 05:34:56.000000000 " + assert ret[20] == "2012-01-02 21:34:56.123456789 -0800" + assert ret[21] == "2012-01-02 21:34:56.123456780 -0800" + assert ret[22] == "2012-01-02 21:34:56.123456700 -0800" + assert ret[23] == "2012-01-02 21:34:56.123456000 -0800" + assert ret[24] == "2012-01-02 21:34:56.123450000 -0800" + assert ret[25] == "2012-01-02 21:34:56.123400000 -0800" + assert ret[26] == "2012-01-02 21:34:56.123000000 -0800" + assert ret[27] == "2012-01-02 21:34:56.120000000 -0800" + assert ret[28] == "2012-01-02 21:34:56.100000000 -0800" + assert ret[29] == "2012-01-02 21:34:56.000000000 -0800" + assert ret[30] == "05:07:08.123456789" + assert ret[31] == "05:07:08.123456780" + assert ret[32] == "05:07:08.123456700" + assert ret[33] == "05:07:08.123456000" + assert ret[34] == "05:07:08.123450000" + assert ret[35] == "05:07:08.123400000" + assert ret[36] == "05:07:08.123000000" + assert ret[37] == "05:07:08.120000000" + assert ret[38] == "05:07:08.100000000" + assert ret[39] == "05:07:08.000000000" + + await cur.execute( + """ +ALTER SESSION SET + TIMESTAMP_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF6 TZH:TZM', + TIMESTAMP_NTZ_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF6 TZH:TZM', + TIME_OUTPUT_FORMAT='HH24:MI:SS.FF6'; + """ + ) + await cur.execute(sql) + ret = await cur.fetchone() + assert ret[0] == "2012-01-03 12:34:56.123456 +0700" + assert ret[1] == "2012-01-03 12:34:56.123456 +0700" + assert ret[2] == "2012-01-03 12:34:56.123456 +0700" + assert ret[3] == "2012-01-03 12:34:56.123456 +0700" + assert ret[4] == "2012-01-03 12:34:56.123450 +0700" + assert ret[5] == "2012-01-03 12:34:56.123400 +0700" + assert ret[6] == "2012-01-03 12:34:56.123000 +0700" + assert ret[7] == "2012-01-03 12:34:56.120000 +0700" + assert ret[8] == "2012-01-03 12:34:56.100000 +0700" + assert ret[9] == "2012-01-03 12:34:56.000000 +0700" + assert ret[10] == "2012-01-03 05:34:56.123456 " + assert 
ret[11] == "2012-01-03 05:34:56.123456 " + assert ret[12] == "2012-01-03 05:34:56.123456 " + assert ret[13] == "2012-01-03 05:34:56.123456 " + assert ret[14] == "2012-01-03 05:34:56.123450 " + assert ret[15] == "2012-01-03 05:34:56.123400 " + assert ret[16] == "2012-01-03 05:34:56.123000 " + assert ret[17] == "2012-01-03 05:34:56.120000 " + assert ret[18] == "2012-01-03 05:34:56.100000 " + assert ret[19] == "2012-01-03 05:34:56.000000 " + assert ret[20] == "2012-01-02 21:34:56.123456 -0800" + assert ret[21] == "2012-01-02 21:34:56.123456 -0800" + assert ret[22] == "2012-01-02 21:34:56.123456 -0800" + assert ret[23] == "2012-01-02 21:34:56.123456 -0800" + assert ret[24] == "2012-01-02 21:34:56.123450 -0800" + assert ret[25] == "2012-01-02 21:34:56.123400 -0800" + assert ret[26] == "2012-01-02 21:34:56.123000 -0800" + assert ret[27] == "2012-01-02 21:34:56.120000 -0800" + assert ret[28] == "2012-01-02 21:34:56.100000 -0800" + assert ret[29] == "2012-01-02 21:34:56.000000 -0800" + assert ret[30] == "05:07:08.123456" + assert ret[31] == "05:07:08.123456" + assert ret[32] == "05:07:08.123456" + assert ret[33] == "05:07:08.123456" + assert ret[34] == "05:07:08.123450" + assert ret[35] == "05:07:08.123400" + assert ret[36] == "05:07:08.123000" + assert ret[37] == "05:07:08.120000" + assert ret[38] == "05:07:08.100000" + assert ret[39] == "05:07:08.000000" + + +async def test_fetch_timestamps_negative_epoch(conn_cnx): + """Negative epoch.""" + r0 = _compose_ntz("-602594703.876544") + r1 = _compose_ntz("1325594096.123456") + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute( + """\ +SELECT + '1950-11-27 12:34:56.123456'::timestamp_ntz(6), + '2012-01-03 12:34:56.123456'::timestamp_ntz(6) +""" + ) + ret = await cur.fetchone() + assert ret[0] == r0 + assert ret[1] == r1 + + +async def test_date_0001_9999(conn_cnx): + """Test 0001 and 9999 for all platforms.""" + async with conn_cnx( + converter_class=SnowflakeConverterSnowSQL, support_negative_year=True + ) as cnx: + await cnx.cursor().execute( + """ +ALTER SESSION SET + DATE_OUTPUT_FORMAT='YYYY-MM-DD' +""" + ) + cur = cnx.cursor() + await cur.execute( + """ +alter session set python_connector_query_result_format='JSON' +""" + ) + await cur.execute( + """ +SELECT + DATE_FROM_PARTS(1900, 1, 1), + DATE_FROM_PARTS(2500, 2, 3), + DATE_FROM_PARTS(1, 10, 31), + DATE_FROM_PARTS(9999, 3, 20) + ; +""" + ) + ret = await cur.fetchone() + assert ret[0] == "1900-01-01" + assert ret[1] == "2500-02-03" + assert ret[2] == "0001-10-31" + assert ret[3] == "9999-03-20" + + +@pytest.mark.skipif(IS_WINDOWS, reason="year out of range error") +async def test_five_or_more_digit_year_date_converter(conn_cnx): + """Past and future dates.""" + async with conn_cnx( + converter_class=SnowflakeConverterSnowSQL, support_negative_year=True + ) as cnx: + await cnx.cursor().execute( + """ +ALTER SESSION SET + DATE_OUTPUT_FORMAT='YYYY-MM-DD' +""" + ) + cur = cnx.cursor() + await cur.execute( + """ +alter session set python_connector_query_result_format='JSON' +""" + ) + await cur.execute( + """ +SELECT + DATE_FROM_PARTS(10000, 1, 1), + DATE_FROM_PARTS(-0001, 2, 5), + DATE_FROM_PARTS(56789, 3, 4), + DATE_FROM_PARTS(198765, 4, 3), + DATE_FROM_PARTS(-234567, 5, 2) + ; +""" + ) + ret = await cur.fetchone() + assert ret[0] == "10000-01-01" + assert ret[1] == "-0001-02-05" + assert ret[2] == "56789-03-04" + assert ret[3] == "198765-04-03" + assert ret[4] == "-234567-05-02" + + await cnx.cursor().execute( + """ +ALTER SESSION SET + DATE_OUTPUT_FORMAT='YY-MM-DD' +""" + ) + 
cur = cnx.cursor() + await cur.execute( + """ +SELECT + DATE_FROM_PARTS(10000, 1, 1), + DATE_FROM_PARTS(-0001, 2, 5), + DATE_FROM_PARTS(56789, 3, 4), + DATE_FROM_PARTS(198765, 4, 3), + DATE_FROM_PARTS(-234567, 5, 2) + ; +""" + ) + ret = await cur.fetchone() + assert ret[0] == "00-01-01" + assert ret[1] == "-01-02-05" + assert ret[2] == "89-03-04" + assert ret[3] == "65-04-03" + assert ret[4] == "-67-05-02" + + +async def test_franction_followed_by_year_format(conn_cnx): + """Both year and franctions are included but fraction shows up followed by year.""" + async with conn_cnx(converter_class=SnowflakeConverterSnowSQL) as cnx: + await cnx.cursor().execute( + """ +alter session set python_connector_query_result_format='JSON' +""" + ) + await cnx.cursor().execute( + """ +ALTER SESSION SET + TIMESTAMP_OUTPUT_FORMAT='HH24:MI:SS.FF6 MON DD, YYYY', + TIMESTAMP_NTZ_OUTPUT_FORMAT='HH24:MI:SS.FF6 MON DD, YYYY' +""" + ) + async for rec in await cnx.cursor().execute( + """ +SELECT + '2012-01-03 05:34:56.123456'::TIMESTAMP_NTZ(6) +""" + ): + assert rec[0] == "05:34:56.123456 Jan 03, 2012" + + +async def test_fetch_fraction_timestamp(conn_cnx): + """Additional fetch timestamp tests. Mainly used for SnowSQL which converts to string representations.""" + PST_TZ = "America/Los_Angeles" + + converter_class = SnowflakeConverterSnowSQL + sql = """ +SELECT + '1900-01-01T05:00:00.000Z'::timestamp_tz(7), + '1900-01-01T05:00:00.000'::timestamp_ntz(7), + '1900-01-01T05:00:01.000Z'::timestamp_tz(7), + '1900-01-01T05:00:01.000'::timestamp_ntz(7), + '1900-01-01T05:00:01.012Z'::timestamp_tz(7), + '1900-01-01T05:00:01.012'::timestamp_ntz(7), + '1900-01-01T05:00:00.012Z'::timestamp_tz(7), + '1900-01-01T05:00:00.012'::timestamp_ntz(7), + '2100-01-01T05:00:00.012Z'::timestamp_tz(7), + '2100-01-01T05:00:00.012'::timestamp_ntz(7), + '1970-01-01T00:00:00Z'::timestamp_tz(7), + '1970-01-01T00:00:00'::timestamp_ntz(7) +""" + async with conn_cnx(converter_class=converter_class) as cnx: + cur = cnx.cursor() + await cur.execute( + """ +alter session set python_connector_query_result_format='JSON' +""" + ) + await cur.execute( + """ +ALTER SESSION SET TIMEZONE='{tz}'; +""".format( + tz=PST_TZ + ) + ) + await cur.execute( + """ +ALTER SESSION SET + TIMESTAMP_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF9 TZH:TZM', + TIMESTAMP_NTZ_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FF9', + TIME_OUTPUT_FORMAT='HH24:MI:SS.FF9'; + """ + ) + await cur.execute(sql) + ret = await cur.fetchone() + assert ret[0] == "1900-01-01 05:00:00.000000000 +0000" + assert ret[1] == "1900-01-01 05:00:00.000000000" + assert ret[2] == "1900-01-01 05:00:01.000000000 +0000" + assert ret[3] == "1900-01-01 05:00:01.000000000" + assert ret[4] == "1900-01-01 05:00:01.012000000 +0000" + assert ret[5] == "1900-01-01 05:00:01.012000000" + assert ret[6] == "1900-01-01 05:00:00.012000000 +0000" + assert ret[7] == "1900-01-01 05:00:00.012000000" + assert ret[8] == "2100-01-01 05:00:00.012000000 +0000" + assert ret[9] == "2100-01-01 05:00:00.012000000" + assert ret[10] == "1970-01-01 00:00:00.000000000 +0000" + assert ret[11] == "1970-01-01 00:00:00.000000000" diff --git a/test/integ/aio/test_converter_more_timestamp_async.py b/test/integ/aio/test_converter_more_timestamp_async.py new file mode 100644 index 0000000000..e8316e4807 --- /dev/null +++ b/test/integ/aio/test_converter_more_timestamp_async.py @@ -0,0 +1,133 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
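The timestamp round-trip test in the next file builds its expected values from epoch seconds plus a UTC offset in minutes, falling back to arithmetic from the epoch when datetime.fromtimestamp rejects pre-1970 values on some platforms. The sketch below is a rough standalone illustration of that expected-value construction, not part of the patch: it substitutes the standard library's datetime.timezone for the connector's _generate_tzinfo_from_tzoffset and assumes ZERO_EPOCH is the naive 1970-01-01 epoch used by the converter.

from datetime import datetime, timedelta, timezone

ZERO_EPOCH = datetime(1970, 1, 1)  # assumed to mirror the converter's constant


def expected_timestamp_tz(epoch_seconds: str, offset_minutes: int) -> datetime:
    """Build the datetime a TIMESTAMP_TZ literal is expected to decode to."""
    tzinfo = timezone(timedelta(minutes=offset_minutes))
    try:
        ts = datetime.fromtimestamp(float(epoch_seconds), tz=tzinfo)
    except (OSError, ValueError):
        # Negative epochs can overflow fromtimestamp on some platforms,
        # so fall back to plain timedelta arithmetic from the epoch.
        ts = ZERO_EPOCH + timedelta(seconds=float(epoch_seconds))
        ts = (ts + tzinfo.utcoffset(ts)).replace(tzinfo=tzinfo)
    return ts


# "+07:00" corresponds to 420 minutes, matching the tzdiff computed in the test.
print(expected_timestamp_tz("1325568896.123456", 420))
print(expected_timestamp_tz("-2208943503", 420))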
+# + +from __future__ import annotations + +from datetime import datetime, timedelta + +import pytz +from dateutil.parser import parse + +from snowflake.connector.converter import ZERO_EPOCH, _generate_tzinfo_from_tzoffset + + +async def test_fetch_various_timestamps(conn_cnx): + """More coverage of timestamp. + + Notes: + Currently TIMESTAMP_LTZ is not tested. + """ + PST_TZ = "America/Los_Angeles" + epoch_times = ["1325568896", "-2208943503", "0", "-1"] + timezones = ["+07:00", "+00:00", "-01:00", "-09:00"] + fractions = "123456789" + data_types = ["TIMESTAMP_TZ", "TIMESTAMP_NTZ"] + + data = [] + for dt in data_types: + for et in epoch_times: + if dt == "TIMESTAMP_TZ": + for tz in timezones: + tzdiff = (int(tz[1:3]) * 60 + int(tz[4:6])) * ( + -1 if tz[0] == "-" else 1 + ) + tzinfo = _generate_tzinfo_from_tzoffset(tzdiff) + try: + ts = datetime.fromtimestamp(float(et), tz=tzinfo) + except (OSError, ValueError): + ts = ZERO_EPOCH + timedelta(seconds=float(et)) + if pytz.utc != tzinfo: + ts += tzinfo.utcoffset(ts) + ts = ts.replace(tzinfo=tzinfo) + data.append( + { + "scale": 0, + "dt": dt, + "inp": ts.strftime(f"%Y-%m-%d %H:%M:%S{tz}"), + "out": ts, + } + ) + for idx in range(len(fractions)): + scale = idx + 1 + if idx + 1 != 6: # SNOW-28597 + try: + ts0 = datetime.fromtimestamp(float(et), tz=tzinfo) + except (OSError, ValueError): + ts0 = ZERO_EPOCH + timedelta(seconds=float(et)) + if pytz.utc != tzinfo: + ts0 += tzinfo.utcoffset(ts0) + ts0 = ts0.replace(tzinfo=tzinfo) + ts0_str = ts0.strftime( + "%Y-%m-%d %H:%M:%S.{ff}{tz}".format( + ff=fractions[: idx + 1], tz=tz + ) + ) + ts1 = parse(ts0_str) + data.append( + {"scale": scale, "dt": dt, "inp": ts0_str, "out": ts1} + ) + elif dt == "TIMESTAMP_LTZ": + # WIP. this test work in edge case + tzinfo = pytz.timezone(PST_TZ) + ts0 = datetime.fromtimestamp(float(et)) + ts0 = pytz.utc.localize(ts0).astimezone(tzinfo) + ts0_str = ts0.strftime("%Y-%m-%d %H:%M:%S") + ts1 = ts0 + data.append({"scale": 0, "dt": dt, "inp": ts0_str, "out": ts1}) + for idx in range(len(fractions)): + ts0 = datetime.fromtimestamp(float(et)) + ts0 = pytz.utc.localize(ts0).astimezone(tzinfo) + ts0_str = ts0.strftime(f"%Y-%m-%d %H:%M:%S.{fractions[: idx + 1]}") + ts1 = ts0 + timedelta(seconds=float(f"0.{fractions[: idx + 1]}")) + data.append( + {"scale": idx + 1, "dt": dt, "inp": ts0_str, "out": ts1} + ) + else: + # TIMESTAMP_NTZ + try: + ts0 = datetime.fromtimestamp(float(et)) + except (OSError, ValueError): + ts0 = ZERO_EPOCH + timedelta(seconds=(float(et))) + ts0_str = ts0.strftime("%Y-%m-%d %H:%M:%S") + ts1 = parse(ts0_str) + data.append({"scale": 0, "dt": dt, "inp": ts0_str, "out": ts1}) + for idx in range(len(fractions)): + try: + ts0 = datetime.fromtimestamp(float(et)) + except (OSError, ValueError): + ts0 = ZERO_EPOCH + timedelta(seconds=(float(et))) + ts0_str = ts0.strftime(f"%Y-%m-%d %H:%M:%S.{fractions[: idx + 1]}") + ts1 = parse(ts0_str) + data.append( + {"scale": idx + 1, "dt": dt, "inp": ts0_str, "out": ts1} + ) + sql = "SELECT " + for d in data: + sql += "'{inp}'::{dt}({scale}), ".format( + inp=d["inp"], dt=d["dt"], scale=d["scale"] + ) + sql += "1" + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute( + """ +ALTER SESSION SET TIMEZONE='{tz}'; +""".format( + tz=PST_TZ + ) + ) + rec = await (await cur.execute(sql)).fetchone() + for idx, d in enumerate(data): + comp, lower, higher = _in_range(d["out"], rec[idx]) + assert ( + comp + ), "data: {d}: target={target}, lower={lower}, higher={" "higher}".format( + d=d, target=rec[idx], lower=lower, 
higher=higher + ) + + +def _in_range(reference, target): + lower = reference - timedelta(microseconds=1) + higher = reference + timedelta(microseconds=1) + return lower <= target <= higher, lower, higher diff --git a/test/integ/aio/test_converter_null_async.py b/test/integ/aio/test_converter_null_async.py new file mode 100644 index 0000000000..4da319ed9d --- /dev/null +++ b/test/integ/aio/test_converter_null_async.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from test.integ.test_converter_null import NUMERIC_VALUES + +import snowflake.connector.aio +from snowflake.connector.converter import ZERO_EPOCH +from snowflake.connector.converter_null import SnowflakeNoConverterToPython + + +async def test_converter_no_converter_to_python(db_parameters): + """Tests no converter. + + This should not translate the Snowflake internal data representation to the Python native types. + """ + async with snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + timezone="UTC", + converter_class=SnowflakeNoConverterToPython, + ) as con: + await con.cursor().execute( + """ + alter session set python_connector_query_result_format='JSON' + """ + ) + + ret = await ( + await con.cursor().execute( + """ + select current_timestamp(), + 1::NUMBER, + 2.0::FLOAT, + 'test1' + """ + ) + ).fetchone() + assert isinstance(ret[0], str) + assert NUMERIC_VALUES.match(ret[0]) + assert isinstance(ret[1], str) + assert NUMERIC_VALUES.match(ret[1]) + await con.cursor().execute( + "create or replace table testtb(c1 timestamp_ntz(6))" + ) + try: + current_time = datetime.now(timezone.utc).replace(tzinfo=None) + # binding value should have no impact + await con.cursor().execute( + "insert into testtb(c1) values(%s)", (current_time,) + ) + ret = ( + await (await con.cursor().execute("select * from testtb")).fetchone() + )[0] + assert ZERO_EPOCH + timedelta(seconds=(float(ret))) == current_time + finally: + await con.cursor().execute("drop table if exists testtb") diff --git a/test/integ/aio/test_cursor_async.py b/test/integ/aio/test_cursor_async.py new file mode 100644 index 0000000000..660cb572b0 --- /dev/null +++ b/test/integ/aio/test_cursor_async.py @@ -0,0 +1,1788 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
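The cursor tests in the next file iterate results with async for and switch the row shape by passing a cursor class, as test_insert_select does below. This is a minimal sketch of those two access patterns, not part of the patch; it assumes an already-connected snowflake.connector.aio.SnowflakeConnection named conn, such as the one the conn_cnx fixture yields in these tests.

from snowflake.connector.aio import DictCursor, SnowflakeConnection


async def show_both_row_shapes(conn: SnowflakeConnection) -> None:
    # Default cursor: rows come back as tuples and are iterated asynchronously.
    async with conn.cursor() as cur:
        await cur.execute("select 1 as aa union all select 2")
        async for row in cur:
            print(row[0])  # positional access

    # DictCursor: the same query yields dicts keyed by (upper-cased) column name.
    async with conn.cursor(DictCursor) as cur:
        await cur.execute("select 1 as aa union all select 2")
        async for row in cur:
            print(row["AA"])  # access by column name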
+# + +from __future__ import annotations + +import asyncio +import decimal +import json +import logging +import os +import pickle +import time +from datetime import date, datetime, timezone +from unittest import mock + +import pytest +import pytz + +import snowflake.connector +import snowflake.connector.aio +from snowflake.connector import ( + InterfaceError, + NotSupportedError, + ProgrammingError, + constants, + errorcode, + errors, +) +from snowflake.connector.aio import DictCursor, SnowflakeCursor, _connection +from snowflake.connector.aio._result_batch import ( + ArrowResultBatch, + JSONResultBatch, + ResultBatch, +) +from snowflake.connector.compat import IS_WINDOWS +from snowflake.connector.constants import ( + FIELD_ID_TO_NAME, + PARAMETER_MULTI_STATEMENT_COUNT, + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, + QueryStatus, +) +from snowflake.connector.cursor import ResultMetadata +from snowflake.connector.description import CLIENT_VERSION +from snowflake.connector.errorcode import ( + ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT, + ER_NO_ARROW_RESULT, + ER_NO_PYARROW, + ER_NO_PYARROW_SNOWSQL, + ER_NOT_POSITIVE_SIZE, +) +from snowflake.connector.errors import Error +from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED +from snowflake.connector.telemetry import TelemetryField +from snowflake.connector.util_text import random_string + + +@pytest.fixture +async def conn(conn_cnx, db_parameters): + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create table {name} ( +aa int, +dt date, +tm time, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(5,2), +b binary) +""".format( + name=db_parameters["name"] + ) + ) + + yield conn_cnx + + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "use {db}.{schema}".format( + db=db_parameters["database"], schema=db_parameters["schema"] + ) + ) + await cnx.cursor().execute( + "drop table {name}".format(name=db_parameters["name"]) + ) + + +def _check_results(cursor, results): + assert cursor.sfqid, "Snowflake query id is None" + assert cursor.rowcount == 3, "the number of records" + assert results[0] == 65432, "the first result was wrong" + assert results[1] == 98765, "the second result was wrong" + assert results[2] == 123456, "the third result was wrong" + + +def _name_from_description(named_access: bool): + if named_access: + return lambda meta: meta.name + else: + return lambda meta: meta[0] + + +def _type_from_description(named_access: bool): + if named_access: + return lambda meta: meta.type_code + else: + return lambda meta: meta[1] + + +async def test_insert_select(conn, db_parameters, caplog): + """Inserts and selects integer data.""" + async with conn() as cnx: + c = cnx.cursor() + try: + await c.execute( + "insert into {name}(aa) values(123456)," + "(98765),(65432)".format(name=db_parameters["name"]) + ) + cnt = 0 + async for rec in c: + cnt += int(rec[0]) + assert cnt == 3, "wrong number of records were inserted" + assert c.rowcount == 3, "wrong number of records were inserted" + finally: + await c.close() + + try: + c = cnx.cursor() + await c.execute( + "select aa from {name} order by aa".format(name=db_parameters["name"]) + ) + results = [] + async for rec in c: + results.append(rec[0]) + _check_results(c, results) + assert "Number of results in first chunk: 3" in caplog.text + finally: + await c.close() + + async with cnx.cursor(snowflake.connector.aio.DictCursor) as c: + caplog.clear() + assert "Number of results in first chunk: 3" not in 
caplog.text + await c.execute( + "select aa from {name} order by aa".format(name=db_parameters["name"]) + ) + results = [] + async for rec in c: + results.append(rec["AA"]) + _check_results(c, results) + assert "Number of results in first chunk: 3" in caplog.text + + +async def test_insert_and_select_by_separate_connection(conn, db_parameters, caplog): + """Inserts a record and select it by a separate connection.""" + async with conn() as cnx: + result = await cnx.cursor().execute( + "insert into {name}(aa) values({value})".format( + name=db_parameters["name"], value="1234" + ) + ) + cnt = 0 + async for rec in result: + cnt += int(rec[0]) + assert cnt == 1, "wrong number of records were inserted" + assert result.rowcount == 1, "wrong number of records were inserted" + + cnx2 = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx2.connect() + try: + c = cnx2.cursor() + await c.execute("select aa from {name}".format(name=db_parameters["name"])) + results = [] + async for rec in c: + results.append(rec[0]) + await c.close() + assert results[0] == 1234, "the first result was wrong" + assert result.rowcount == 1, "wrong number of records were selected" + assert "Number of results in first chunk: 1" in caplog.text + finally: + await cnx2.close() + + +def _total_milliseconds_from_timedelta(td): + """Returns the total number of milliseconds contained in the duration object.""" + return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) // 10**3 + + +def _total_seconds_from_timedelta(td): + """Returns the total number of seconds contained in the duration object.""" + return _total_milliseconds_from_timedelta(td) // 10**3 + + +async def test_insert_timestamp_select(conn, db_parameters): + """Inserts and gets timestamp, timestamp with tz, date, and time. + + Notes: + Currently the session parameter TIMEZONE is ignored. 
+ """ + PST_TZ = "America/Los_Angeles" + JST_TZ = "Asia/Tokyo" + current_timestamp = datetime.now(timezone.utc).replace(tzinfo=None) + current_timestamp = current_timestamp.replace(tzinfo=pytz.timezone(PST_TZ)) + current_date = current_timestamp.date() + current_time = current_timestamp.time() + + other_timestamp = current_timestamp.replace(tzinfo=pytz.timezone(JST_TZ)) + + async with conn() as cnx: + await cnx.cursor().execute("alter session set TIMEZONE=%s", (PST_TZ,)) + c = cnx.cursor() + try: + fmt = ( + "insert into {name}(aa, tsltz, tstz, tsntz, dt, tm) " + "values(%(value)s,%(tsltz)s, %(tstz)s, %(tsntz)s, " + "%(dt)s, %(tm)s)" + ) + await c.execute( + fmt.format(name=db_parameters["name"]), + { + "value": 1234, + "tsltz": current_timestamp, + "tstz": other_timestamp, + "tsntz": current_timestamp, + "dt": current_date, + "tm": current_time, + }, + ) + cnt = 0 + async for rec in c: + cnt += int(rec[0]) + assert cnt == 1, "wrong number of records were inserted" + assert c.rowcount == 1, "wrong number of records were selected" + finally: + await c.close() + + cnx2 = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + timezone="UTC", + ) + await cnx2.connect() + try: + c = cnx2.cursor() + await c.execute( + "select aa, tsltz, tstz, tsntz, dt, tm from {name}".format( + name=db_parameters["name"] + ) + ) + + result_numeric_value = [] + result_timestamp_value = [] + result_other_timestamp_value = [] + result_ntz_timestamp_value = [] + result_date_value = [] + result_time_value = [] + + async for aa, ts, tstz, tsntz, dt, tm in c: + result_numeric_value.append(aa) + result_timestamp_value.append(ts) + result_other_timestamp_value.append(tstz) + result_ntz_timestamp_value.append(tsntz) + result_date_value.append(dt) + result_time_value.append(tm) + await c.close() + assert result_numeric_value[0] == 1234, "the integer result was wrong" + + td_diff = _total_milliseconds_from_timedelta( + current_timestamp - result_timestamp_value[0] + ) + assert td_diff == 0, "the timestamp result was wrong" + + td_diff = _total_milliseconds_from_timedelta( + other_timestamp - result_other_timestamp_value[0] + ) + assert td_diff == 0, "the other timestamp result was wrong" + + td_diff = _total_milliseconds_from_timedelta( + current_timestamp.replace(tzinfo=None) - result_ntz_timestamp_value[0] + ) + assert td_diff == 0, "the other timestamp result was wrong" + + assert current_date == result_date_value[0], "the date result was wrong" + + assert current_time == result_time_value[0], "the time result was wrong" + + name = _name_from_description(False) + type_code = _type_from_description(False) + descriptions = [c.description] + if hasattr(c, "_description_internal"): + # If _description_internal is defined, even the old description attribute will + # return ResultMetadata (v1) and not a plain tuple. 
This indirection is needed + # to support old-driver tests + name = _name_from_description(True) + type_code = _type_from_description(True) + descriptions.append(c._description_internal) + for desc in descriptions: + assert len(desc) == 6, "invalid number of column meta data" + assert name(desc[0]).upper() == "AA", "invalid column name" + assert name(desc[1]).upper() == "TSLTZ", "invalid column name" + assert name(desc[2]).upper() == "TSTZ", "invalid column name" + assert name(desc[3]).upper() == "TSNTZ", "invalid column name" + assert name(desc[4]).upper() == "DT", "invalid column name" + assert name(desc[5]).upper() == "TM", "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[0])] == "FIXED" + ), f"invalid column name: {constants.FIELD_ID_TO_NAME[desc[0][1]]}" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[1])] == "TIMESTAMP_LTZ" + ), "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[2])] == "TIMESTAMP_TZ" + ), "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[3])] == "TIMESTAMP_NTZ" + ), "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[4])] == "DATE" + ), "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[5])] == "TIME" + ), "invalid column name" + finally: + await cnx2.close() + + +async def test_insert_timestamp_ltz(conn, db_parameters): + """Inserts and retrieve timestamp ltz.""" + tzstr = "America/New_York" + # sync with the session parameter + async with conn() as cnx: + await cnx.cursor().execute(f"alter session set timezone='{tzstr}'") + + current_time = datetime.now() + current_time = current_time.replace(tzinfo=pytz.timezone(tzstr)) + + c = cnx.cursor() + try: + fmt = "insert into {name}(aa, tsltz) values(%(value)s,%(ts)s)" + await c.execute( + fmt.format(name=db_parameters["name"]), + { + "value": 8765, + "ts": current_time, + }, + ) + cnt = 0 + async for rec in c: + cnt += int(rec[0]) + assert cnt == 1, "wrong number of records were inserted" + finally: + await c.close() + + try: + c = cnx.cursor() + await c.execute( + "select aa,tsltz from {name}".format(name=db_parameters["name"]) + ) + result_numeric_value = [] + result_timestamp_value = [] + async for aa, ts in c: + result_numeric_value.append(aa) + result_timestamp_value.append(ts) + + td_diff = _total_milliseconds_from_timedelta( + current_time - result_timestamp_value[0] + ) + + assert td_diff == 0, "the first result was wrong" + finally: + await c.close() + + +async def test_struct_time(conn, db_parameters): + """Binds struct_time object for updating timestamp.""" + tzstr = "America/New_York" + os.environ["TZ"] = tzstr + if not IS_WINDOWS: + time.tzset() + test_time = time.strptime("30 Sep 01 11:20:30", "%d %b %y %H:%M:%S") + + async with conn() as cnx: + c = cnx.cursor() + try: + fmt = "insert into {name}(aa, tsltz) values(%(value)s,%(ts)s)" + await c.execute( + fmt.format(name=db_parameters["name"]), + { + "value": 87654, + "ts": test_time, + }, + ) + cnt = 0 + async for rec in c: + cnt += int(rec[0]) + finally: + await c.close() + os.environ["TZ"] = "UTC" + if not IS_WINDOWS: + time.tzset() + assert cnt == 1, "wrong number of records were inserted" + + try: + result = await cnx.cursor().execute( + "select aa, tsltz from {name}".format(name=db_parameters["name"]) + ) + async for _, _tsltz in result: + pass + + _tsltz -= _tsltz.tzinfo.utcoffset(_tsltz) + + assert test_time.tm_year == _tsltz.year, "Year didn't match" + assert test_time.tm_mon == _tsltz.month, "Month didn't match" 
+ assert test_time.tm_mday == _tsltz.day, "Day didn't match" + assert test_time.tm_hour == _tsltz.hour, "Hour didn't match" + assert test_time.tm_min == _tsltz.minute, "Minute didn't match" + assert test_time.tm_sec == _tsltz.second, "Second didn't match" + finally: + os.environ["TZ"] = "UTC" + if not IS_WINDOWS: + time.tzset() + + +async def test_insert_binary_select(conn, db_parameters): + """Inserts and get a binary value.""" + value = b"\x00\xFF\xA1\xB2\xC3" + + async with conn() as cnx: + c = cnx.cursor() + try: + fmt = "insert into {name}(b) values(%(b)s)" + await c.execute(fmt.format(name=db_parameters["name"]), {"b": value}) + count = sum([int(rec[0]) async for rec in c]) + assert count == 1, "wrong number of records were inserted" + assert c.rowcount == 1, "wrong number of records were selected" + finally: + await c.close() + + cnx2 = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + ) + await cnx2.connect() + try: + c = cnx2.cursor() + await c.execute("select b from {name}".format(name=db_parameters["name"])) + + results = [b async for (b,) in c] + assert value == results[0], "the binary result was wrong" + + name = _name_from_description(False) + type_code = _type_from_description(False) + descriptions = [c.description] + if hasattr(c, "_description_internal"): + # If _description_internal is defined, even the old description attribute will + # return ResultMetadata (v1) and not a plain tuple. This indirection is needed + # to support old-driver tests + name = _name_from_description(True) + type_code = _type_from_description(True) + descriptions.append(c._description_internal) + for desc in descriptions: + assert len(desc) == 1, "invalid number of column meta data" + assert name(desc[0]).upper() == "B", "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[0])] == "BINARY" + ), "invalid column name" + finally: + await cnx2.close() + + +async def test_insert_binary_select_with_bytearray(conn, db_parameters): + """Inserts and get a binary value using the bytearray type.""" + value = bytearray(b"\x00\xFF\xA1\xB2\xC3") + + async with conn() as cnx: + c = cnx.cursor() + try: + fmt = "insert into {name}(b) values(%(b)s)" + await c.execute(fmt.format(name=db_parameters["name"]), {"b": value}) + count = sum([int(rec[0]) async for rec in c]) + assert count == 1, "wrong number of records were inserted" + assert c.rowcount == 1, "wrong number of records were selected" + finally: + await c.close() + + cnx2 = snowflake.connector.aio.SnowflakeConnection( + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + account=db_parameters["account"], + database=db_parameters["database"], + schema=db_parameters["schema"], + protocol=db_parameters["protocol"], + ) + await cnx2.connect() + try: + c = cnx2.cursor() + await c.execute("select b from {name}".format(name=db_parameters["name"])) + + results = [b async for (b,) in c] + assert bytes(value) == results[0], "the binary result was wrong" + + name = _name_from_description(False) + type_code = _type_from_description(False) + descriptions = [c.description] + if hasattr(c, "_description_internal"): + # If _description_internal is defined, even the old description attribute will + # return 
ResultMetadata (v1) and not a plain tuple. This indirection is needed + # to support old-driver tests + name = _name_from_description(True) + type_code = _type_from_description(True) + descriptions.append(c._description_internal) + for desc in descriptions: + assert len(desc) == 1, "invalid number of column meta data" + assert name(desc[0]).upper() == "B", "invalid column name" + assert ( + constants.FIELD_ID_TO_NAME[type_code(desc[0])] == "BINARY" + ), "invalid column name" + finally: + await cnx2.close() + + +async def test_variant(conn, db_parameters): + """Variant including JSON object.""" + name_variant = db_parameters["name"] + "_variant" + async with conn() as cnx: + await cnx.cursor().execute( + """ +create table {name} ( +created_at timestamp, data variant) +""".format( + name=name_variant + ) + ) + + try: + async with conn() as cnx: + current_time = datetime.now() + c = cnx.cursor() + try: + fmt = ( + "insert into {name}(created_at, data) " + "select column1, parse_json(column2) " + "from values(%(created_at)s, %(data)s)" + ) + await c.execute( + fmt.format(name=name_variant), + { + "created_at": current_time, + "data": ( + '{"SESSION-PARAMETERS":{' + '"TIMEZONE":"UTC", "SPECIAL_FLAG":true}}' + ), + }, + ) + cnt = 0 + async for rec in c: + cnt += int(rec[0]) + assert cnt == 1, "wrong number of records were inserted" + assert c.rowcount == 1, "wrong number of records were inserted" + finally: + await c.close() + + result = await cnx.cursor().execute( + f"select created_at, data from {name_variant}" + ) + _, data = await result.fetchone() + data = json.loads(data) + assert data["SESSION-PARAMETERS"]["SPECIAL_FLAG"], ( + "JSON data should be parsed properly. " "Invalid JSON data" + ) + finally: + async with conn() as cnx: + await cnx.cursor().execute(f"drop table {name_variant}") + + +async def test_geography(conn_cnx): + """Variant including JSON object.""" + name_geo = random_string(5, "test_geography_") + async with conn_cnx( + session_parameters={ + "GEOGRAPHY_OUTPUT_FORMAT": "geoJson", + }, + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"create temporary table {name_geo} (geo geography)") + await cur.execute( + f"insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')" + ) + expected_data = [ + {"coordinates": [0, 0], "type": "Point"}, + {"coordinates": [[1, 1], [2, 2]], "type": "LineString"}, + ] + + async with cnx.cursor() as cur: + # Test with GEOGRAPHY return type + result = await cur.execute(f"select * from {name_geo}") + for metadata in [cur.description, cur._description_internal]: + assert FIELD_ID_TO_NAME[metadata[0].type_code] == "GEOGRAPHY" + data = await result.fetchall() + for raw_data in data: + row = json.loads(raw_data[0]) + assert row in expected_data + + +async def test_geometry(conn_cnx): + """Variant including JSON object.""" + name_geo = random_string(5, "test_geometry_") + async with conn_cnx( + session_parameters={ + "GEOMETRY_OUTPUT_FORMAT": "geoJson", + }, + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"create temporary table {name_geo} (geo GEOMETRY)") + await cur.execute( + f"insert into {name_geo} values ('POINT(0 0)'), ('LINESTRING(1 1, 2 2)')" + ) + expected_data = [ + {"coordinates": [0, 0], "type": "Point"}, + {"coordinates": [[1, 1], [2, 2]], "type": "LineString"}, + ] + + async with cnx.cursor() as cur: + # Test with GEOMETRY return type + result = await cur.execute(f"select * from {name_geo}") + for metadata in [cur.description, cur._description_internal]: + assert 
FIELD_ID_TO_NAME[metadata[0].type_code] == "GEOMETRY" + data = await result.fetchall() + for raw_data in data: + row = json.loads(raw_data[0]) + assert row in expected_data + + +async def test_vector(conn_cnx, is_public_test): + if is_public_test: + pytest.xfail( + reason="This feature hasn't been rolled out for public Snowflake deployments yet." + ) + name_vectors = random_string(5, "test_vector_") + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + # Seed test data + expected_data_ints = [[1, 3, -5], [40, 1234567, 1], "NULL"] + expected_data_floats = [ + [1.8, -3.4, 6.7, 0, 2.3], + [4.121212121, 31234567.4, 7, -2.123, 1], + "NULL", + ] + await cur.execute( + f"create temporary table {name_vectors} (int_vec VECTOR(INT,3), float_vec VECTOR(FLOAT,5))" + ) + for i in range(len(expected_data_ints)): + await cur.execute( + f"insert into {name_vectors} select {expected_data_ints[i]}::VECTOR(INT,3), {expected_data_floats[i]}::VECTOR(FLOAT,5)" + ) + + async with cnx.cursor() as cur: + # Test a basic fetch + await cur.execute( + f"select int_vec, float_vec from {name_vectors} order by float_vec" + ) + for metadata in [cur.description, cur._description_internal]: + assert FIELD_ID_TO_NAME[metadata[0].type_code] == "VECTOR" + assert FIELD_ID_TO_NAME[metadata[1].type_code] == "VECTOR" + data = await cur.fetchall() + for i, row in enumerate(data): + if expected_data_floats[i] == "NULL": + assert row[0] is None + else: + assert row[0] == expected_data_ints[i] + + if expected_data_ints[i] == "NULL": + assert row[1] is None + else: + assert row[1] == pytest.approx(expected_data_floats[i]) + + # Test an empty result set + await cur.execute( + f"select int_vec, float_vec from {name_vectors} where int_vec = [1,2,3]::VECTOR(int,3)" + ) + for metadata in [cur.description, cur._description_internal]: + assert FIELD_ID_TO_NAME[metadata[0].type_code] == "VECTOR" + assert FIELD_ID_TO_NAME[metadata[1].type_code] == "VECTOR" + data = await cur.fetchall() + assert len(data) == 0 + + +async def test_invalid_bind_data_type(conn_cnx): + """Invalid bind data type.""" + async with conn_cnx() as cnx: + with pytest.raises(errors.ProgrammingError): + await cnx.cursor().execute("select 1 from dual where 1=%s", ([1, 2, 3],)) + + +async def test_timeout_query(conn_cnx): + async with conn_cnx() as cnx: + async with cnx.cursor() as c: + with pytest.raises(errors.ProgrammingError) as err: + await c.execute( + "select seq8() as c1 from table(generator(timeLimit => 60))", + timeout=5, + ) + assert err.value.errno == 604, ( + "Invalid error code" + and "SQL execution was cancelled by the client due to a timeout" + in err.value.msg + ) + + +async def test_executemany(conn, db_parameters): + """Executes many statements. Client binding is supported by either dict, or list data types. + + Notes: + The binding data type is dict and tuple, respectively. 
+ """ + table_name = random_string(5, "test_executemany_") + async with conn() as cnx: + async with cnx.cursor() as c: + await c.execute(f"create temp table {table_name} (aa number)") + await c.executemany( + f"insert into {table_name}(aa) values(%(value)s)", + [ + {"value": 1234}, + {"value": 234}, + {"value": 34}, + {"value": 4}, + ], + ) + assert (await c.fetchone())[0] == 4, "number of records" + assert c.rowcount == 4, "wrong number of records were inserted" + + async with cnx.cursor() as c: + fmt = "insert into {name}(aa) values(%s)".format(name=db_parameters["name"]) + await c.executemany( + fmt, + [ + (12345,), + (1234,), + (234,), + (34,), + (4,), + ], + ) + assert (await c.fetchone())[0] == 5, "number of records" + assert c.rowcount == 5, "wrong number of records were inserted" + + +async def test_executemany_qmark_types(conn, db_parameters): + table_name = random_string(5, "test_executemany_qmark_types_") + async with conn(paramstyle="qmark") as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"create temp table {table_name} (birth_date date)") + + insert_qy = f"INSERT INTO {table_name} (birth_date) values (?)" + date_1, date_2, date_3, date_4 = ( + date(1969, 2, 7), + date(1969, 1, 1), + date(2999, 12, 31), + date(9999, 1, 1), + ) + + # insert two dates, one in tuple format which specifies + # the snowflake type similar to how we support it in this + # example: + # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-qmark-or-numeric-binding-with-datetime-objects + await cur.executemany( + insert_qy, + [[date_1], [("DATE", date_2)], [date_3], [date_4]], + # test that kwargs get passed through executemany properly + _statement_params={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "json" + }, + ) + assert all( + isinstance(rb, JSONResultBatch) for rb in await cur.get_result_batches() + ) + + await cur.execute(f"select * from {table_name}") + assert {row[0] async for row in cur} == {date_1, date_2, date_3, date_4} + + +async def test_executemany_params_iterator(conn): + """Cursor.executemany() works with an interator of params.""" + table_name = random_string(5, "executemany_params_iterator_") + async with conn() as cnx: + async with cnx.cursor() as c: + await c.execute(f"create temp table {table_name}(bar integer)") + fmt = f"insert into {table_name}(bar) values(%(value)s)" + await c.executemany(fmt, ({"value": x} for x in ("1234", "234", "34", "4"))) + assert (await c.fetchone())[0] == 4, "number of records" + assert c.rowcount == 4, "wrong number of records were inserted" + + async with cnx.cursor() as c: + fmt = f"insert into {table_name}(bar) values(%s)" + await c.executemany(fmt, ((x,) for x in (12345, 1234, 234, 34, 4))) + assert (await c.fetchone())[0] == 5, "number of records" + assert c.rowcount == 5, "wrong number of records were inserted" + + +async def test_executemany_empty_params(conn): + """Cursor.executemany() does nothing if params is empty.""" + table_name = random_string(5, "executemany_empty_params_") + async with conn() as cnx: + async with cnx.cursor() as c: + # The table isn't created, so if this were executed, it would error. + await c.executemany(f"insert into {table_name}(aa) values(%(value)s)", []) + assert c.query is None + + +async def test_closed_cursor(conn, db_parameters): + """Attempts to use the closed cursor. It should raise errors. + + Notes: + The binding data type is scalar. 
+ """ + table_name = random_string(5, "test_closed_cursor_") + async with conn() as cnx: + async with cnx.cursor() as c: + await c.execute(f"create temp table {table_name} (aa number)") + fmt = f"insert into {table_name}(aa) values(%s)" + await c.executemany( + fmt, + [ + 12345, + 1234, + 234, + 34, + 4, + ], + ) + assert (await c.fetchone())[0] == 5, "number of records" + assert c.rowcount == 5, "number of records" + + with pytest.raises(InterfaceError, match="Cursor is closed in execute") as err: + await c.execute(f"select aa from {table_name}") + assert err.value.errno == errorcode.ER_CURSOR_IS_CLOSED + assert ( + c.rowcount == 5 + ), "SNOW-647539: rowcount should remain available after cursor is closed" + + +async def test_fetchmany(conn, db_parameters, caplog): + table_name = random_string(5, "test_fetchmany_") + async with conn() as cnx: + async with cnx.cursor() as c: + await c.execute(f"create temp table {table_name} (aa number)") + await c.executemany( + f"insert into {table_name}(aa) values(%(value)s)", + [ + {"value": "3456789"}, + {"value": "234567"}, + {"value": "1234"}, + {"value": "234"}, + {"value": "34"}, + {"value": "4"}, + ], + ) + assert (await c.fetchone())[0] == 6, "number of records" + assert c.rowcount == 6, "number of records" + + async with cnx.cursor() as c: + await c.execute(f"select aa from {table_name} order by aa desc") + assert "Number of results in first chunk: 6" in caplog.text + + rows = await c.fetchmany(2) + assert len(rows) == 2, "The number of records" + assert rows[1][0] == 234567, "The second record" + + rows = await c.fetchmany(1) + assert len(rows) == 1, "The number of records" + assert rows[0][0] == 1234, "The first record" + + rows = await c.fetchmany(5) + assert len(rows) == 3, "The number of records" + assert rows[-1][0] == 4, "The last record" + + assert len(await c.fetchmany(15)) == 0, "The number of records" + + +async def test_process_params(conn, db_parameters): + """Binds variables for insert and other queries.""" + table_name = random_string(5, "test_process_params_") + async with conn() as cnx: + async with cnx.cursor() as c: + await c.execute(f"create temp table {table_name} (aa number)") + await c.executemany( + f"insert into {table_name}(aa) values(%(value)s)", + [ + {"value": "3456789"}, + {"value": "234567"}, + {"value": "1234"}, + {"value": "234"}, + {"value": "34"}, + {"value": "4"}, + ], + ) + assert (await c.fetchone())[0] == 6, "number of records" + + async with cnx.cursor() as c: + await c.execute( + f"select count(aa) from {table_name} where aa > %(value)s", + {"value": 1233}, + ) + assert (await c.fetchone())[0] == 3, "the number of records" + + async with cnx.cursor() as c: + await c.execute( + f"select count(aa) from {table_name} where aa > %s", (1234,) + ) + assert (await c.fetchone())[0] == 2, "the number of records" + + +@pytest.mark.parametrize( + ("interpolate_empty_sequences", "expected_outcome"), [(False, "%%s"), (True, "%s")] +) +async def test_process_params_empty( + conn_cnx, interpolate_empty_sequences, expected_outcome +): + """SQL is interpolated if params aren't None.""" + async with conn_cnx(interpolate_empty_sequences=interpolate_empty_sequences) as cnx: + async with cnx.cursor() as cursor: + await cursor.execute("select '%%s'", None) + assert await cursor.fetchone() == ("%%s",) + await cursor.execute("select '%%s'", ()) + assert await cursor.fetchone() == (expected_outcome,) + + +async def test_real_decimal(conn, db_parameters): + async with conn() as cnx: + c = cnx.cursor() + fmt = ("insert into 
{name}(aa, pct, ratio) " "values(%s,%s,%s)").format( + name=db_parameters["name"] + ) + await c.execute(fmt, (9876, 12.3, decimal.Decimal("23.4"))) + async for (_cnt,) in c: + pass + assert _cnt == 1, "the number of records" + await c.close() + + c = cnx.cursor() + fmt = "select aa, pct, ratio from {name}".format(name=db_parameters["name"]) + await c.execute(fmt) + async for _aa, _pct, _ratio in c: + pass + assert _aa == 9876, "the integer value" + assert _pct == 12.3, "the float value" + assert _ratio == decimal.Decimal("23.4"), "the decimal value" + await c.close() + + async with cnx.cursor(snowflake.connector.aio.DictCursor) as c: + fmt = "select aa, pct, ratio from {name}".format(name=db_parameters["name"]) + await c.execute(fmt) + rec = await c.fetchone() + assert rec["AA"] == 9876, "the integer value" + assert rec["PCT"] == 12.3, "the float value" + assert rec["RATIO"] == decimal.Decimal("23.4"), "the decimal value" + + +@pytest.mark.skip("SNOW-1763103 error handler async") +async def test_none_errorhandler(conn_testaccount): + c = conn_testaccount.cursor() + with pytest.raises(errors.ProgrammingError): + c.errorhandler = None + + +@pytest.mark.skip("SNOW-1763103 error handler async") +async def test_nope_errorhandler(conn_testaccount): + def user_errorhandler(connection, cursor, errorclass, errorvalue): + pass + + c = conn_testaccount.cursor() + c.errorhandler = user_errorhandler + await c.execute("select * foooooo never_exists_table") + await c.execute("select * barrrrr never_exists_table") + await c.execute("select * daaaaaa never_exists_table") + assert c.messages[0][0] == errors.ProgrammingError, "One error was recorded" + assert len(c.messages) == 1, "should be one error" + + +@pytest.mark.internal +async def test_binding_negative(negative_conn_cnx, db_parameters): + async with negative_conn_cnx() as cnx: + with pytest.raises(TypeError): + await cnx.cursor().execute( + "INSERT INTO {name}(aa) VALUES(%s)".format(name=db_parameters["name"]), + (1, 2, 3), + ) + with pytest.raises(errors.ProgrammingError): + await cnx.cursor().execute( + "INSERT INTO {name}(aa) VALUES(%s)".format(name=db_parameters["name"]), + (), + ) + with pytest.raises(errors.ProgrammingError): + await cnx.cursor().execute( + "INSERT INTO {name}(aa) VALUES(%s)".format(name=db_parameters["name"]), + (["a"],), + ) + + +async def test_execute_stores_query(conn_cnx): + async with conn_cnx() as cnx: + async with cnx.cursor() as cursor: + assert cursor.query is None + await cursor.execute("select 1") + assert cursor.query == "select 1" + + +async def test_execute_after_close(conn_testaccount): + """SNOW-13588: Raises an error if executing after the connection is closed.""" + cursor = conn_testaccount.cursor() + await conn_testaccount.close() + with pytest.raises(errors.Error): + await cursor.execute("show tables") + + +async def test_multi_table_insert(conn, db_parameters): + try: + async with conn() as cnx: + cur = cnx.cursor() + await cur.execute( + """ + INSERT INTO {name}(aa) VALUES(1234),(9876),(2345) + """.format( + name=db_parameters["name"] + ) + ) + assert cur.rowcount == 3, "the number of records" + + await cur.execute( + """ +CREATE OR REPLACE TABLE {name}_foo (aa_foo int) + """.format( + name=db_parameters["name"] + ) + ) + + await cur.execute( + """ +CREATE OR REPLACE TABLE {name}_bar (aa_bar int) + """.format( + name=db_parameters["name"] + ) + ) + + await cur.execute( + """ +INSERT ALL + INTO {name}_foo(aa_foo) VALUES(aa) + INTO {name}_bar(aa_bar) VALUES(aa) + SELECT aa FROM {name} + """.format( + 
name=db_parameters["name"] + ) + ) + assert cur.rowcount == 6 + finally: + async with conn() as cnx: + await cnx.cursor().execute( + """ +DROP TABLE IF EXISTS {name}_foo +""".format( + name=db_parameters["name"] + ) + ) + await cnx.cursor().execute( + """ +DROP TABLE IF EXISTS {name}_bar +""".format( + name=db_parameters["name"] + ) + ) + + +@pytest.mark.skipif( + True, + reason=""" +Negative test case. +""", +) +async def test_fetch_before_execute(conn_testaccount): + """SNOW-13574: Fetch before execute.""" + cursor = conn_testaccount.cursor() + with pytest.raises(errors.DataError): + await cursor.fetchone() + + +async def test_close_twice(conn_testaccount): + await conn_testaccount.close() + await conn_testaccount.close() + + +@pytest.mark.parametrize("result_format", ("arrow", "json")) +async def test_fetch_out_of_range_timestamp_value(conn, result_format): + async with conn() as cnx: + cur = cnx.cursor() + await cur.execute( + f"alter session set python_connector_query_result_format='{result_format}'" + ) + await cur.execute("select '12345-01-02'::timestamp_ntz") + with pytest.raises(errors.InterfaceError): + await cur.fetchone() + + +async def test_null_in_non_null(conn): + table_name = random_string(5, "null_in_non_null") + error_msg = "NULL result in a non-nullable column" + async with conn() as cnx: + cur = cnx.cursor() + await cur.execute(f"create temp table {table_name}(bar char not null)") + with pytest.raises(errors.IntegrityError, match=error_msg): + await cur.execute(f"insert into {table_name} values (null)") + + +@pytest.mark.parametrize("sql", (None, ""), ids=["None", "empty"]) +async def test_empty_execution(conn, sql): + """Checks whether executing an empty string, or nothing behaves as expected.""" + async with conn() as cnx: + async with cnx.cursor() as cur: + if sql is not None: + await cur.execute(sql) + assert cur._result is None + with pytest.raises( + TypeError, match="'NoneType' object is not( an)? itera(tor|ble)" + ): + await cur.fetchone() + with pytest.raises( + TypeError, match="'NoneType' object is not( an)? 
itera(tor|ble)" + ): + await cur.fetchall() + + +@pytest.mark.parametrize("reuse_results", [False, True]) +async def test_reset_fetch(conn, reuse_results): + """Tests behavior after resetting an open cursor.""" + async with conn(reuse_results=reuse_results) as cnx: + async with cnx.cursor() as cur: + await cur.execute("select 1") + assert cur.rowcount == 1 + cur.reset() + assert ( + cur.rowcount is None + ), "calling reset on an open cursor should unset rowcount" + assert not cur.is_closed(), "calling reset should not close the cursor" + if reuse_results: + assert await cur.fetchone() == (1,) + else: + assert await cur.fetchone() is None + assert len(await cur.fetchall()) == 0 + + +async def test_rownumber(conn): + """Checks whether rownumber is returned as expected.""" + async with conn() as cnx: + async with cnx.cursor() as cur: + assert await cur.execute("select * from values (1), (2)") + assert cur.rownumber is None + assert await cur.fetchone() == (1,) + assert cur.rownumber == 0 + assert await cur.fetchone() == (2,) + assert cur.rownumber == 1 + + +async def test_values_set(conn): + """Checks whether a bunch of properties start as Nones, but get set to something else when a query was executed.""" + properties = [ + "timestamp_output_format", + "timestamp_ltz_output_format", + "timestamp_tz_output_format", + "timestamp_ntz_output_format", + "date_output_format", + "timezone", + "time_output_format", + "binary_output_format", + ] + async with conn() as cnx: + async with cnx.cursor() as cur: + for property in properties: + assert getattr(cur, property) is None + # use a statement that alters session parameters due to HTAP optimization + assert await ( + await cur.execute("alter session set TIMEZONE='America/Los_Angeles'") + ).fetchone() == ("Statement executed successfully.",) + # The default values might change in future, so let's just check that they aren't None anymore + for property in properties: + assert getattr(cur, property) is not None + + +async def test_execute_helper_params_error(conn_testaccount): + """Tests whether calling _execute_helper with a non-dict statement params is handled correctly.""" + async with conn_testaccount.cursor() as cur: + with pytest.raises( + ProgrammingError, + match=r"The data type of statement params is invalid. 
It must be dict.$", + ): + await cur._execute_helper("select %()s", statement_params="1") + + +async def test_desc_rewrite(conn, caplog): + """Tests whether describe queries are rewritten as expected and this action is logged.""" + async with conn() as cnx: + async with cnx.cursor() as cur: + table_name = random_string(5, "test_desc_rewrite_") + try: + await cur.execute(f"create or replace table {table_name} (a int)") + caplog.set_level(logging.DEBUG, "snowflake.connector") + await cur.execute(f"desc {table_name}") + assert ( + "snowflake.connector.aio._cursor", + 10, + "query was rewritten: org=desc {table_name}, new=describe table {table_name}".format( + table_name=table_name + ), + ) in caplog.record_tuples + finally: + await cur.execute(f"drop table {table_name}") + + +@pytest.mark.parametrize("result_format", [False, None, "json"]) +async def test_execute_helper_cannot_use_arrow(conn_cnx, caplog, result_format): + """Tests whether cannot use arrow is handled correctly inside of _execute_helper.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + with mock.patch( + "snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT", False + ): + if result_format is False: + result_format = None + else: + result_format = { + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: result_format + } + caplog.set_level(logging.DEBUG, "snowflake.connector") + await cur.execute("select 1", _statement_params=result_format) + assert ( + "snowflake.connector.aio._cursor", + logging.DEBUG, + "Cannot use arrow result format, fallback to json format", + ) in caplog.record_tuples + assert await cur.fetchone() == (1,) + + +async def test_execute_helper_cannot_use_arrow_exception(conn_cnx): + """Like test_execute_helper_cannot_use_arrow but when we are trying to force arrow an Exception should be raised.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + with mock.patch( + "snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT", False + ): + with pytest.raises( + ProgrammingError, + match="The result set in Apache Arrow format is not supported for the platform.", + ): + await cur.execute( + "select 1", + _statement_params={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow" + }, + ) + + +async def test_check_can_use_arrow_resultset(conn_cnx, caplog): + """Tests check_can_use_arrow_resultset has no effect when we can use arrow.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + with mock.patch( + "snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT", True + ): + caplog.set_level(logging.DEBUG, "snowflake.connector") + cur.check_can_use_arrow_resultset() + assert "Arrow" not in caplog.text + + +@pytest.mark.parametrize("snowsql", [True, False]) +async def test_check_cannot_use_arrow_resultset(conn_cnx, caplog, snowsql): + """Tests check_can_use_arrow_resultset expected outcomes.""" + config = {} + if snowsql: + config["application"] = "SnowSQL" + async with conn_cnx(**config) as cnx: + async with cnx.cursor() as cur: + with mock.patch( + "snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT", False + ): + with pytest.raises( + ProgrammingError, + match=( + "Currently SnowSQL doesn't support the result set in Apache Arrow format." + if snowsql + else "The result set in Apache Arrow format is not supported for the platform." 
+ ), + ) as pe: + cur.check_can_use_arrow_resultset() + assert pe.errno == ( + ER_NO_PYARROW_SNOWSQL if snowsql else ER_NO_ARROW_RESULT + ) + + +async def test_check_can_use_pandas(conn_cnx): + """Tests check_can_use_arrow_resultset has no effect when we can import pandas.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + with mock.patch("snowflake.connector.cursor.installed_pandas", True): + cur.check_can_use_pandas() + + +async def test_check_cannot_use_pandas(conn_cnx): + """Tests check_can_use_arrow_resultset has expected outcomes.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + with mock.patch("snowflake.connector.cursor.installed_pandas", False): + with pytest.raises( + ProgrammingError, + match=r"Optional dependency: 'pandas' is not installed, please see the " + "following link for install instructions: https:.*", + ) as pe: + cur.check_can_use_pandas() + assert pe.errno == ER_NO_PYARROW + + +async def test_not_supported_pandas(conn_cnx): + """Check that fetch_pandas functions return expected error when arrow results are not available.""" + result_format = {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "json"} + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute("select 1", _statement_params=result_format) + with mock.patch("snowflake.connector.cursor.installed_pandas", True): + with pytest.raises(NotSupportedError): + await cur.fetch_pandas_all() + with pytest.raises(NotSupportedError): + list(await cur.fetch_pandas_batches()) + + +async def test_query_cancellation(conn_cnx): + """Tests whether query_cancellation works.""" + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute( + "select max(seq8()) from table(generator(timeLimit=>30));", + _no_results=True, + ) + sf_qid = cur.sfqid + await cur.abort_query(sf_qid) + + +async def test_executemany_insert_rewrite(conn_cnx): + """Tests calling executemany with a non rewritable pyformat insert query.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + with pytest.raises( + InterfaceError, match="Failed to rewrite multi-row insert" + ) as ie: + await cur.executemany("insert into numbers (select 1)", [1, 2]) + assert ie.errno == ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT + + +async def test_executemany_bulk_insert_size_mismatch(conn_cnx): + """Tests bulk insert error with variable length of arguments.""" + async with conn_cnx(paramstyle="qmark") as con: + async with con.cursor() as cur: + with pytest.raises( + InterfaceError, match="Bulk data size don't match. expected: 1, got: 2" + ) as ie: + await cur.executemany("insert into numbers values (?,?)", [[1], [1, 2]]) + assert ie.errno == ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT + + +async def test_fetchmany_size_error(conn_cnx): + """Tests retrieving a negative number of results.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute("select 1") + with pytest.raises( + ProgrammingError, + match="The number of rows is not zero or positive number: -1", + ) as ie: + await cur.fetchmany(-1) + assert ie.errno == ER_NOT_POSITIVE_SIZE + + +async def test_scroll(conn_cnx): + """Tests if scroll returns a NotSupported exception.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + with pytest.raises( + NotSupportedError, match="scroll is not supported." 
+ ) as nse: + await cur.scroll(2) + assert nse.errno == SQLSTATE_FEATURE_NOT_SUPPORTED + + +async def test__log_telemetry_job_data(conn_cnx, caplog): + """Tests whether we handle missing connection object correctly while logging a telemetry event.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + with mock.patch.object(cur, "_connection", None): + caplog.set_level(logging.DEBUG, "snowflake.connector") + await cur._log_telemetry_job_data( + TelemetryField.ARROW_FETCH_ALL, True + ) # dummy value + assert ( + "snowflake.connector.aio._cursor", + logging.WARNING, + "Cursor failed to log to telemetry. Connection object may be None.", + ) in caplog.record_tuples + + +@pytest.mark.parametrize( + "result_format,expected_chunk_type", + ( + ("json", JSONResultBatch), + ("arrow", ArrowResultBatch), + ), +) +async def test_resultbatch( + conn_cnx, + result_format, + expected_chunk_type, + capture_sf_telemetry_async, +): + """This test checks the following things: + 1. After executing a query can we pickle the result batches + 2. When we get the batches, do we emit a telemetry log + 3. Whether we can iterate through ResultBatches multiple times + 4. Whether the results make sense + 5. See whether getter functions are working + """ + rowcount = 100000 + async with conn_cnx( + session_parameters={ + "python_connector_query_result_format": result_format, + } + ) as con: + async with capture_sf_telemetry_async.patch_connection(con) as telemetry_data: + async with con.cursor() as cur: + await cur.execute( + f"select seq4() from table(generator(rowcount => {rowcount}));" + ) + assert cur._result_set.total_row_index() == rowcount + pre_pickle_partitions = await cur.get_result_batches() + assert len(pre_pickle_partitions) > 1 + assert pre_pickle_partitions is not None + assert all( + isinstance(p, expected_chunk_type) for p in pre_pickle_partitions + ) + pickle_str = pickle.dumps(pre_pickle_partitions) + assert any( + t.message["type"] == TelemetryField.GET_PARTITIONS_USED.value + for t in telemetry_data.records + ) + post_pickle_partitions: list[ResultBatch] = pickle.loads(pickle_str) + total_rows = 0 + # Make sure the batches can be iterated over individually + for it in post_pickle_partitions: + print(it) + + for i, partition in enumerate(post_pickle_partitions): + # Tests whether the getter functions are working + if i == 0: + assert partition.compressed_size is None + assert partition.uncompressed_size is None + else: + assert partition.compressed_size is not None + assert partition.uncompressed_size is not None + # TODO: SNOW-1759076 Async for support in Cursor.get_result_batches() + for row in await partition.create_iter(): + col1 = row[0] + assert col1 == total_rows + total_rows += 1 + assert total_rows == rowcount + total_rows = 0 + # Make sure the batches can be iterated over again + for partition in post_pickle_partitions: + # TODO: SNOW-1759076 Async for support in Cursor.get_result_batches() + for row in await partition.create_iter(): + col1 = row[0] + assert col1 == total_rows + total_rows += 1 + assert total_rows == rowcount + + +@pytest.mark.parametrize( + "result_format,patch_path", + ( + ("json", "snowflake.connector.aio._result_batch.JSONResultBatch.create_iter"), + ("arrow", "snowflake.connector.aio._result_batch.ArrowResultBatch.create_iter"), + ), +) +async def test_resultbatch_lazy_fetching_and_schemas( + conn_cnx, result_format, patch_path +): + """Tests whether pre-fetching results chunks fetches the right amount of them.""" + rowcount = 1000000 # We need at least 
5 chunks for this test + async with conn_cnx( + session_parameters={ + "python_connector_query_result_format": result_format, + } + ) as con: + async with con.cursor() as cur: + # Dummy return value necessary to not iterate through every batch with + # first fetchone call + + downloads = [iter([(i,)]) for i in range(10)] + + with mock.patch( + patch_path, + side_effect=downloads, + ) as patched_download: + await cur.execute( + f"select seq4() as c1, randstr(1,random()) as c2 " + f"from table(generator(rowcount => {rowcount}));" + ) + result_batches = await cur.get_result_batches() + batch_schemas = [batch.schema for batch in result_batches] + for schema in batch_schemas: + # all batches should have the same schema + assert schema == [ + ResultMetadata("C1", 0, None, None, 10, 0, False), + ResultMetadata("C2", 2, None, 16777216, None, None, False), + ] + assert patched_download.call_count == 0 + assert len(result_batches) > 5 + assert result_batches[0]._local # Sanity check first chunk being local + await cur.fetchone() # Trigger pre-fetching + + # While the first chunk is local we still call _download on it, which + # short circuits and just parses (for JSON batches) and then returns + # an iterator through that data, so we expect the call count to be 5. + # (0 local and 1, 2, 3, 4 pre-fetched) = 5 total + start_time = time.time() + while time.time() < start_time + 1: + # TODO: fix me, call count is different + if patched_download.call_count == 5: + break + else: + assert patched_download.call_count == 5 + + +@pytest.mark.parametrize("result_format", ["json", "arrow"]) +async def test_resultbatch_schema_exists_when_zero_rows(conn_cnx, result_format): + async with conn_cnx( + session_parameters={"python_connector_query_result_format": result_format} + ) as con: + async with con.cursor() as cur: + await cur.execute( + "select seq4() as c1, randstr(1,random()) as c2 from table(generator(rowcount => 1)) where 1=0" + ) + result_batches = await cur.get_result_batches() + # verify there is 1 batch and 0 rows in that batch + assert len(result_batches) == 1 + assert result_batches[0].rowcount == 0 + # verify that the schema is correct + schema = result_batches[0].schema + assert schema == [ + ResultMetadata("C1", 0, None, None, 10, 0, False), + ResultMetadata("C2", 2, None, 16777216, None, None, False), + ] + + +async def test_optional_telemetry(conn_cnx, capture_sf_telemetry_async): + """Make sure that we do not fail when _first_chunk_time is not present in cursor.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + async with capture_sf_telemetry_async.patch_connection( + con, False + ) as telemetry: + await cur.execute("select 1;") + cur._first_chunk_time = None + assert await cur.fetchall() == [ + (1,), + ] + assert not any( + r.message.get("type", "") + == TelemetryField.TIME_CONSUME_LAST_RESULT.value + for r in telemetry.records + ) + + +@pytest.mark.parametrize("result_format", ("json", "arrow")) +@pytest.mark.parametrize("cursor_type", (SnowflakeCursor, DictCursor)) +@pytest.mark.parametrize("fetch_method", ("__anext__", "fetchone")) +async def test_out_of_range_year(conn_cnx, result_format, cursor_type, fetch_method): + """Tests whether the year 10000 is out of range exception is raised as expected.""" + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: result_format + } + ) as con: + async with con.cursor(cursor_type) as cur: + await cur.execute( + "select * from VALUES (1, TO_TIMESTAMP('9999-01-01 00:00:00')), (2, 
TO_TIMESTAMP('10000-01-01 00:00:00'))" + ) + iterate_obj = cur if fetch_method == "fetchone" else aiter(cur) + fetch_next_fn = getattr(iterate_obj, fetch_method) + # first fetch doesn't raise error + await fetch_next_fn() + with pytest.raises( + InterfaceError, + match=( + "date value out of range" + if IS_WINDOWS + else "year 10000 is out of range" + ), + ): + await fetch_next_fn() + + +async def test_describe(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + for describe in [cur.describe, cur._describe_internal]: + table_name = random_string(5, "test_describe_") + # test select + description = await describe( + "select * from VALUES(1, 3.1415926, 'snow', TO_TIMESTAMP('2021-01-01 00:00:00'))" + ) + assert description is not None + column_types = [column.type_code for column in description] + assert constants.FIELD_ID_TO_NAME[column_types[0]] == "FIXED" + assert constants.FIELD_ID_TO_NAME[column_types[1]] == "FIXED" + assert constants.FIELD_ID_TO_NAME[column_types[2]] == "TEXT" + assert "TIMESTAMP" in constants.FIELD_ID_TO_NAME[column_types[3]] + assert len(await cur.fetchall()) == 0 + + # test insert + await cur.execute(f"create table {table_name} (aa int)") + try: + description = await describe( + "insert into {name}(aa) values({value})".format( + name=table_name, value="1234" + ) + ) + assert description[0].name == "number of rows inserted" + assert cur.rowcount is None + finally: + await cur.execute(f"drop table if exists {table_name}") + + +async def test_fetch_batches_with_sessions(conn_cnx): + rowcount = 250_000 + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"select seq4() as foo from table(generator(rowcount=>{rowcount}))" + ) + + num_batches = len(await cur.get_result_batches()) + + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful._use_requests_session", + side_effect=con._rest._use_requests_session, + ) as get_session_mock: + result = await cur.fetchall() + # all but one batch is downloaded using a session + assert get_session_mock.call_count == num_batches - 1 + assert len(result) == rowcount + + +async def test_null_connection(conn_cnx): + retries = 15 + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select seq4() as c from table(generator(rowcount=>50000))" + ) + await con.rest.delete_session() + status = await con.get_query_status(cur.sfqid) + for _ in range(retries): + if status not in (QueryStatus.RUNNING,): + break + await asyncio.sleep(1) + status = await con.get_query_status(cur.sfqid) + else: + pytest.fail(f"query is still running after {retries} retries") + assert status == QueryStatus.FAILED_WITH_ERROR + assert con.is_an_error(status) + + +async def test_multi_statement_failure(conn_cnx): + """ + This test mocks the driver version sent to Snowflake to be 2.8.1, which does not support multi-statement. + The backend should not allow multi-statements to be submitted for versions older than 2.9.0 and should raise an + error when a multi-statement is submitted, regardless of the MULTI_STATEMENT_COUNT parameter. 
+ """ + try: + _connection.DEFAULT_CONFIGURATION["internal_application_version"] = ( + "2.8.1", + (type(None), str), + ) + async with conn_cnx() as con: + async with con.cursor() as cur: + with pytest.raises( + ProgrammingError, + match="Multiple SQL statements in a single API call are not supported; use one API call per statement instead.", + ): + await cur.execute( + f"alter session set {PARAMETER_MULTI_STATEMENT_COUNT}=0" + ) + await cur.execute("select 1; select 2; select 3;") + finally: + _connection.DEFAULT_CONFIGURATION["internal_application_version"] = ( + CLIENT_VERSION, + (type(None), str), + ) + + +async def test_decoding_utf8_for_json_result(conn_cnx): + # SNOW-787480, if not explicitly setting utf-8 decoding, the data will be + # detected decoding as windows-1250 by chardet.detect + async with conn_cnx( + session_parameters={"python_connector_query_result_format": "JSON"} + ) as con, con.cursor() as cur: + sql = """select '"",' || '"",' || '"",' || '"",' || '"",' || 'Ofigràfic' || '"",' from TABLE(GENERATOR(ROWCOUNT => 5000)) v;""" + ret = await (await cur.execute(sql)).fetchall() + assert len(ret) == 5000 + # This test case is tricky, for most of the test cases, the decoding is incorrect and can could be different + # on different platforms, however, due to randomness, in rare cases the decoding is indeed utf-8, + # the backend behavior is flaky + assert ret[0] in ( + ('"","","","","",OfigrĂ\xa0fic"",',), # AWS Cloud + ('"","","","","",OfigrÃ\xa0fic"",',), # GCP Mac and Linux Cloud + ('"","","","","",Ofigr\xc3\\xa0fic"",',), # GCP Windows Cloud + ( + '"","","","","",Ofigràfic"",', + ), # regression environment gets the correct decoding + ) + + async with conn_cnx( + session_parameters={"python_connector_query_result_format": "JSON"}, + json_result_force_utf8_decoding=True, + ) as con, con.cursor() as cur: + ret = await (await cur.execute(sql)).fetchall() + assert len(ret) == 5000 + assert ret[0] == ('"","","","","",Ofigràfic"",',) + + result_batch = JSONResultBatch( + None, None, None, None, None, False, json_result_force_utf8_decoding=True + ) + with pytest.raises(Error): + await result_batch._load("À".encode("latin1"), "latin1") + + +async def test_fetch_download_timeout_setting(conn_cnx): + with mock.patch.multiple( + "snowflake.connector.aio._result_batch", + DOWNLOAD_TIMEOUT=0.001, + MAX_DOWNLOAD_RETRY=2, + ): + sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v" + async with conn_cnx() as con, con.cursor() as cur: + with pytest.raises(asyncio.TimeoutError): + await (await cur.execute(sql)).fetchall() + + with mock.patch.multiple( + "snowflake.connector.aio._result_batch", + DOWNLOAD_TIMEOUT=10, + MAX_DOWNLOAD_RETRY=1, + ): + sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v" + async with conn_cnx() as con, con.cursor() as cur: + assert len(await (await cur.execute(sql)).fetchall()) == 100000 diff --git a/test/integ/aio/test_cursor_binding_async.py b/test/integ/aio/test_cursor_binding_async.py new file mode 100644 index 0000000000..b7ba9c2a96 --- /dev/null +++ b/test/integ/aio/test_cursor_binding_async.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import pytest + +from snowflake.connector.errors import ProgrammingError + + +async def test_binding_security(conn_cnx, db_parameters): + """SQL Injection Tests.""" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa INT, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(%s, %s)".format(name=db_parameters["name"]), + (1, "test1"), + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(%(aa)s, %(bb)s)".format( + name=db_parameters["name"] + ), + {"aa": 2, "bb": "test2"}, + ) + async for _rec in await cnx.cursor().execute( + "SELECT * FROM {name} ORDER BY 1 DESC".format( + name=db_parameters["name"] + ) + ): + break + assert _rec[0] == 2, "First column" + assert _rec[1] == "test2", "Second column" + async for _rec in await cnx.cursor().execute( + "SELECT * FROM {name} WHERE aa=%s".format(name=db_parameters["name"]), + (1,), + ): + break + assert _rec[0] == 1, "First column" + assert _rec[1] == "test1", "Second column" + + # SQL injection safe test + # Good Example + with pytest.raises(ProgrammingError): + await cnx.cursor().execute( + "SELECT * FROM {name} WHERE aa=%s".format( + name=db_parameters["name"] + ), + ("1 or aa>0",), + ) + + with pytest.raises(ProgrammingError): + await cnx.cursor().execute( + "SELECT * FROM {name} WHERE aa=%(aa)s".format( + name=db_parameters["name"] + ), + {"aa": "1 or aa>0"}, + ) + + # Bad Example in application. DON'T DO THIS + c = cnx.cursor() + await c.execute( + "SELECT * FROM {name} WHERE aa=%s".format(name=db_parameters["name"]) + % ("1 or aa>0",) + ) + rec = await c.fetchall() + assert len(rec) == 2, "not raising error unlike the previous one." + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "drop table if exists {name}".format(name=db_parameters["name"]) + ) + + +async def test_binding_list(conn_cnx, db_parameters): + """SQL binding list type for IN.""" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa INT, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(%s, %s)".format(name=db_parameters["name"]), + (1, "test1"), + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(%(aa)s, %(bb)s)".format( + name=db_parameters["name"] + ), + {"aa": 2, "bb": "test2"}, + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(3, 'test3')".format( + name=db_parameters["name"] + ) + ) + async for _rec in await cnx.cursor().execute( + """ +SELECT * FROM {name} WHERE aa IN (%s) ORDER BY 1 DESC +""".format( + name=db_parameters["name"] + ), + ([1, 3],), + ): + break + assert _rec[0] == 3, "First column" + assert _rec[1] == "test3", "Second column" + + async for _rec in await cnx.cursor().execute( + "SELECT * FROM {name} WHERE aa=%s".format(name=db_parameters["name"]), + (1,), + ): + break + assert _rec[0] == 1, "First column" + assert _rec[1] == "test1", "Second column" + + await cnx.cursor().execute( + """ +SELECT * FROM {name} WHERE aa IN (%s) ORDER BY 1 DESC +""".format( + name=db_parameters["name"] + ), + ((1,),), + ) + + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "drop table if exists {name}".format(name=db_parameters["name"]) + ) + + +@pytest.mark.internal +async def test_unsupported_binding(negative_conn_cnx, db_parameters): + """Unsupported data binding.""" + try: + async with negative_conn_cnx() as cnx: + 
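# Create and seed the test table; the bind-value queries below run against it. +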
await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa INT, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(%s, %s)".format(name=db_parameters["name"]), + (1, "test1"), + ) + + sql = "select count(*) from {name} where aa=%s".format( + name=db_parameters["name"] + ) + + async with cnx.cursor() as cur: + rec = await (await cur.execute(sql, (1,))).fetchone() + assert rec[0] is not None, "no value is returned" + + # dict + with pytest.raises(ProgrammingError): + await cnx.cursor().execute(sql, ({"value": 1},)) + finally: + async with negative_conn_cnx() as cnx: + await cnx.cursor().execute( + "drop table if exists {name}".format(name=db_parameters["name"]) + ) diff --git a/test/integ/aio/test_cursor_context_manager_async.py b/test/integ/aio/test_cursor_context_manager_async.py new file mode 100644 index 0000000000..c1589468a1 --- /dev/null +++ b/test/integ/aio/test_cursor_context_manager_async.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from logging import getLogger + + +async def test_context_manager(conn_testaccount, db_parameters): + """Tests context Manager support in Cursor.""" + logger = getLogger(__name__) + + async def tables(conn): + async with conn.cursor() as cur: + await cur.execute("show tables") + name_to_idx = {elem[0]: idx for idx, elem in enumerate(cur.description)} + async for row in cur: + yield row[name_to_idx["name"]] + + try: + await conn_testaccount.cursor().execute( + "create or replace table {} (a int)".format(db_parameters["name"]) + ) + all_tables = [ + rec + async for rec in tables(conn_testaccount) + if rec == db_parameters["name"].upper() + ] + logger.info("tables: %s", all_tables) + assert len(all_tables) == 1, "number of tables" + finally: + await conn_testaccount.cursor().execute( + "drop table if exists {}".format(db_parameters["name"]) + ) diff --git a/test/integ/aio/test_dataintegrity_async.py b/test/integ/aio/test_dataintegrity_async.py new file mode 100644 index 0000000000..384e7e9b6e --- /dev/null +++ b/test/integ/aio/test_dataintegrity_async.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python -O +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +"""Script to test database capabilities and the DB-API interface. + +It tests for functionality and data integrity for some of the basic data types. Adapted from a script +taken from the MySQL python driver. 
+""" + +from __future__ import annotations + +import random +import time +from math import fabs + +import pytz + +from snowflake.connector.dbapi import DateFromTicks, TimeFromTicks, TimestampFromTicks + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from ..randomize import random_string + + +async def table_exists(conn_cnx, name): + with conn_cnx() as cnx: + with cnx.cursor() as cursor: + try: + cursor.execute("select * from %s where 1=0" % name) + except Exception: + cnx.rollback() + return False + else: + return True + + +async def create_table(conn_cnx, columndefs, partial_name): + table = f'"dbabi_dibasic_{partial_name}"' + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {table} ({columns})".format( + table=table, columns="\n".join(columndefs) + ) + ) + return table + + +async def check_data_integrity(conn_cnx, columndefs, partial_name, generator): + rows = random.randrange(10, 15) + # floating_point_types = ('REAL','DOUBLE','DECIMAL') + floating_point_types = ("REAL", "DOUBLE") + + table = await create_table(conn_cnx, columndefs, partial_name) + async with conn_cnx() as cnx: + async with cnx.cursor() as cursor: + # insert some data as specified by generator passed in + insert_statement = "INSERT INTO {} VALUES ({})".format( + table, + ",".join(["%s"] * len(columndefs)), + ) + data = [ + [generator(i, j) for j in range(len(columndefs))] for i in range(rows) + ] + await cursor.executemany(insert_statement, data) + await cnx.commit() + + # verify 2 things: correct number of rows, correct values for + # each row + await cursor.execute(f"select * from {table} order by 1") + result_sequences = await cursor.fetchall() + results = [] + for i in result_sequences: + results.append(i) + + # verify the right number of rows were returned + assert len(results) == rows, ( + "fetchall did not return " "expected number of rows" + ) + + # verify the right values were returned + # for numbers, allow a difference of .000001 + for x, y in zip(results, sorted(data)): + if any(data_type in partial_name for data_type in floating_point_types): + for _ in range(rows): + df = fabs(float(x[0]) - float(y[0])) + if float(y[0]) != 0.0: + df = df / float(y[0]) + assert df <= 0.00000001, ( + "fetchall did not return correct values within " + "the expected range" + ) + else: + assert list(x) == list(y), "fetchall did not return correct values" + + await cursor.execute(f"drop table if exists {table}") + + +async def test_INT(conn_cnx): + # Number data + def generator(row, col): + return row * row + + await check_data_integrity(conn_cnx, ("col1 INT",), "INT", generator) + + +async def test_DECIMAL(conn_cnx): + # DECIMAL + def generator(row, col): + from decimal import Decimal + + return Decimal("%d.%02d" % (row, col)) + + await check_data_integrity(conn_cnx, ("col1 DECIMAL(5,2)",), "DECIMAL", generator) + + +async def test_REAL(conn_cnx): + def generator(row, col): + return row * 1000.0 + + await check_data_integrity(conn_cnx, ("col1 REAL",), "REAL", generator) + + +async def test_REAL2(conn_cnx): + def generator(row, col): + return row * 3.14 + + await check_data_integrity(conn_cnx, ("col1 REAL",), "REAL", generator) + + +async def test_DOUBLE(conn_cnx): + def generator(row, col): + return row / 1e-99 + + await check_data_integrity(conn_cnx, ("col1 DOUBLE",), "DOUBLE", generator) + + +async def test_FLOAT(conn_cnx): + def generator(row, col): + return row * 2.0 + + await check_data_integrity(conn_cnx, ("col1 FLOAT(67)",), "FLOAT", generator) 
+
+
+async def test_DATE(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        return DateFromTicks(ticks + row * 86400 - col * 1313)
+
+    await check_data_integrity(conn_cnx, ("col1 DATE",), "DATE", generator)
+
+
+async def test_STRING(conn_cnx):
+    def generator(row, col):
+        import string
+
+        rstr = random_string(1024, choices=string.ascii_letters + string.digits)
+        return rstr
+
+    await check_data_integrity(conn_cnx, ("col2 STRING",), "STRING", generator)
+
+
+async def test_TEXT(conn_cnx):
+    def generator(row, col):
+        rstr = "".join([chr(i) for i in range(33, 127)] * 100)
+        return rstr
+
+    await check_data_integrity(conn_cnx, ("col2 TEXT",), "TEXT", generator)
+
+
+async def test_VARCHAR(conn_cnx):
+    def generator(row, col):
+        import string
+
+        rstr = random_string(50, choices=string.ascii_letters + string.digits)
+        return rstr
+
+    await check_data_integrity(conn_cnx, ("col2 VARCHAR",), "VARCHAR", generator)
+
+
+async def test_BINARY(conn_cnx):
+    def generator(row, col):
+        return bytes(random.getrandbits(8) for _ in range(50))
+
+    await check_data_integrity(conn_cnx, ("col1 BINARY",), "BINARY", generator)
+
+
+async def test_TIMESTAMPNTZ(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        return TimestampFromTicks(ticks + row * 86400 - col * 1313)
+
+    await check_data_integrity(
+        conn_cnx, ("col1 TIMESTAMPNTZ",), "TIMESTAMPNTZ", generator
+    )
+
+
+async def test_TIMESTAMPNTZ_EXPLICIT(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        return TimestampFromTicks(ticks + row * 86400 - col * 1313)
+
+    await check_data_integrity(
+        conn_cnx,
+        ("col1 TIMESTAMP without time zone",),
+        "TIMESTAMPNTZ_EXPLICIT",
+        generator,
+    )
+
+
+# string that contains control characters (white spaces), etc.
+async def test_DATETIME(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks + row * 86400 - col * 1313)
+        myzone = pytz.timezone("US/Pacific")
+        return myzone.localize(ret)
+
+    await check_data_integrity(conn_cnx, ("col1 TIMESTAMP",), "DATETIME", generator)
+
+
+async def test_TIMESTAMP(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks + row * 86400 - col * 1313)
+        myzone = pytz.timezone("US/Pacific")
+        return myzone.localize(ret)
+
+    await check_data_integrity(
+        conn_cnx, ("col1 TIMESTAMP_LTZ",), "TIMESTAMP", generator
+    )
+
+
+async def test_TIMESTAMP_EXPLICIT(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks + row * 86400 - col * 1313)
+        myzone = pytz.timezone("Australia/Sydney")
+        return myzone.localize(ret)
+
+    await check_data_integrity(
+        conn_cnx,
+        ("col1 TIMESTAMP with local time zone",),
+        "TIMESTAMP_EXPLICIT",
+        generator,
+    )
+
+
+async def test_TIMESTAMPTZ(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks + row * 86400 - col * 1313)
+        myzone = pytz.timezone("America/Vancouver")
+        return myzone.localize(ret)
+
+    await check_data_integrity(
+        conn_cnx, ("col1 TIMESTAMPTZ",), "TIMESTAMPTZ", generator
+    )
+
+
+async def test_TIMESTAMPTZ_EXPLICIT(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks + row * 86400 - col * 1313)
+        myzone = pytz.timezone("America/Vancouver")
+        return myzone.localize(ret)
+
+    await check_data_integrity(
+        conn_cnx, ("col1 TIMESTAMP with time zone",), "TIMESTAMPTZ_EXPLICIT", generator
+    )
+
+
+async def test_TIMESTAMPLTZ(conn_cnx):
+    ticks = time.time()
+
+    def generator(row, col):
+        ret = TimestampFromTicks(ticks +
row * 86400 - col * 1313) + myzone = pytz.timezone("America/New_York") + return myzone.localize(ret) + + await check_data_integrity( + conn_cnx, ("col1 TIMESTAMPLTZ",), "TIMESTAMPLTZ", generator + ) + + +async def test_fractional_TIMESTAMP(conn_cnx): + ticks = time.time() + + def generator(row, col): + ret = TimestampFromTicks( + ticks + row * 86400 - col * 1313 + row * 0.7 * col / 3.0 + ) + myzone = pytz.timezone("Europe/Paris") + return myzone.localize(ret) + + await check_data_integrity( + conn_cnx, ("col1 TIMESTAMP_LTZ",), "TIMESTAMP_fractional", generator + ) + + +async def test_TIME(conn_cnx): + ticks = time.time() + + def generator(row, col): + ret = TimeFromTicks(ticks + row * 86400 - col * 1313) + return ret + + await check_data_integrity(conn_cnx, ("col1 TIME",), "TIME", generator) diff --git a/test/integ/aio/test_daylight_savings_async.py b/test/integ/aio/test_daylight_savings_async.py new file mode 100644 index 0000000000..d1cc9c8885 --- /dev/null +++ b/test/integ/aio/test_daylight_savings_async.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from datetime import datetime + +import pytz + + +async def _insert_timestamp(ctx, table, tz, dt): + myzone = pytz.timezone(tz) + ts = myzone.localize(dt, is_dst=True) + print("\n") + print(f"{repr(ts)}") + await ctx.cursor().execute( + "INSERT INTO {table} VALUES(%s)".format( + table=table, + ), + (ts,), + ) + + result = await (await ctx.cursor().execute(f"SELECT * FROM {table}")).fetchone() + retrieved_ts = result[0] + print("#####") + print(f"Retrieved ts: {repr(retrieved_ts)}") + print(f"Retrieved and converted TS{repr(retrieved_ts.astimezone(myzone))}") + print("#####") + assert result[0] == ts + await ctx.cursor().execute(f"DELETE FROM {table}") + + +async def test_daylight_savings_in_TIMESTAMP_LTZ(conn_cnx, db_parameters): + async with conn_cnx() as ctx: + await ctx.cursor().execute( + "CREATE OR REPLACE TABLE {table} (c1 timestamp_ltz)".format( + table=db_parameters["name"], + ) + ) + try: + dt = datetime(year=2016, month=3, day=13, hour=18, minute=47, second=32) + await _insert_timestamp(ctx, db_parameters["name"], "Australia/Sydney", dt) + dt = datetime(year=2016, month=3, day=13, hour=8, minute=39, second=23) + await _insert_timestamp(ctx, db_parameters["name"], "Europe/Paris", dt) + dt = datetime(year=2016, month=3, day=13, hour=8, minute=39, second=23) + await _insert_timestamp(ctx, db_parameters["name"], "UTC", dt) + + dt = datetime(year=2016, month=3, day=13, hour=1, minute=14, second=8) + await _insert_timestamp(ctx, db_parameters["name"], "America/New_York", dt) + + dt = datetime(year=2016, month=3, day=12, hour=22, minute=32, second=4) + await _insert_timestamp(ctx, db_parameters["name"], "US/Pacific", dt) + + finally: + await ctx.cursor().execute( + "DROP TABLE IF EXISTS {table}".format( + table=db_parameters["name"], + ) + ) diff --git a/test/integ/aio/test_dbapi_async.py b/test/integ/aio/test_dbapi_async.py new file mode 100644 index 0000000000..7ea1957a41 --- /dev/null +++ b/test/integ/aio/test_dbapi_async.py @@ -0,0 +1,877 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +"""Script to test database capabilities and the DB-API interface for functionality and data integrity. + +Adapted from a script by M-A Lemburg and taken from the MySQL python driver. 
+""" + +from __future__ import annotations + +import time + +import pytest + +import snowflake.connector.aio +import snowflake.connector.dbapi +from snowflake.connector import dbapi, errorcode, errors +from snowflake.connector.util_text import random_string + +TABLE1 = "dbapi_ddl1" +TABLE2 = "dbapi_ddl2" + + +async def drop_dbapi_tables(conn_cnx): + async with conn_cnx() as cnx: + async with cnx.cursor() as cursor: + for ddl in (TABLE1, TABLE2): + dropsql = f"drop table if exists {ddl}" + await cursor.execute(dropsql) + + +async def executeDDL1(cursor): + await cursor.execute(f"create or replace table {TABLE1} (name string)") + + +async def executeDDL2(cursor): + await cursor.execute(f"create or replace table {TABLE2} (name string)") + + +@pytest.fixture() +async def conn_local(request, conn_cnx): + async def fin(): + await drop_dbapi_tables(conn_cnx) + + yield conn_cnx + await fin() + + +async def _paraminsert(cur): + await executeDDL1(cur) + await cur.execute(f"insert into {TABLE1} values ('string inserted into table')") + assert cur.rowcount in (-1, 1) + + await cur.execute( + f"insert into {TABLE1} values (%(dbapi_ddl2)s)", {TABLE2: "Cooper's"} + ) + assert cur.rowcount in (-1, 1) + + await cur.execute(f"select name from {TABLE1}") + res = await cur.fetchall() + assert len(res) == 2, "cursor.fetchall returned too few rows" + dbapi_ddl2s = [res[0][0], res[1][0]] + dbapi_ddl2s.sort() + assert dbapi_ddl2s[0] == "Cooper's", "cursor.fetchall retrieved incorrect data" + assert ( + dbapi_ddl2s[1] == "string inserted into table" + ), "cursor.fetchall retrieved incorrect data" + + +async def test_connect(conn_cnx): + async with conn_cnx(): + pass + + +async def test_apilevel(): + try: + apilevel = snowflake.connector.apilevel + assert apilevel == "2.0", "test_dbapi:test_apilevel" + except AttributeError: + raise Exception("test_apilevel: apilevel not defined") + + +async def test_threadsafety(): + try: + threadsafety = snowflake.connector.threadsafety + assert threadsafety == 2, "check value of threadsafety is 2" + except errors.AttributeError: + raise Exception("AttributeError: not defined in Snowflake.connector") + + +async def test_paramstyle(): + try: + paramstyle = snowflake.connector.paramstyle + assert paramstyle == "pyformat" + except AttributeError: + raise Exception("snowflake.connector.paramstyle not defined") + + +async def test_exceptions(): + # required exceptions should be defined in a hierarchy + try: + assert issubclass(errors._Warning, Exception) + except AttributeError: + # Compatibility for olddriver tests + assert issubclass(errors.Warning, Exception) + assert issubclass(errors.Error, Exception) + assert issubclass(errors.InterfaceError, errors.Error) + assert issubclass(errors.DatabaseError, errors.Error) + assert issubclass(errors.OperationalError, errors.Error) + assert issubclass(errors.IntegrityError, errors.Error) + assert issubclass(errors.InternalError, errors.Error) + assert issubclass(errors.ProgrammingError, errors.Error) + assert issubclass(errors.NotSupportedError, errors.Error) + + +@pytest.mark.skip("SNOW-1770153 for error as attribute on connection") +async def test_exceptions_as_connection_attributes(conn_cnx): + async with conn_cnx() as con: + try: + assert con.Warning == errors._Warning + except AttributeError: + # Compatibility for olddriver tests + assert con.Warning == errors.Warning + assert con.Error == errors.Error + assert con.InterfaceError == errors.InterfaceError + assert con.DatabaseError == errors.DatabaseError + assert con.OperationalError == 
errors.OperationalError + assert con.IntegrityError == errors.IntegrityError + assert con.InternalError == errors.InternalError + assert con.ProgrammingError == errors.ProgrammingError + assert con.NotSupportedError == errors.NotSupportedError + + +async def test_commit(db_parameters): + con = snowflake.connector.aio.SnowflakeConnection( + account=db_parameters["account"], + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + protocol=db_parameters["protocol"], + ) + await con.connect() + try: + # Commit must work, even if it doesn't do anything + await con.commit() + finally: + await con.close() + + +async def test_rollback(conn_cnx, db_parameters): + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute( + "create or replace table {} (a int)".format(db_parameters["name"]) + ) + await cnx.cursor().execute("begin") + await cur.execute( + """ +insert into {} (select seq8() seq + from table(generator(rowCount => 10)) v) +""".format( + db_parameters["name"] + ) + ) + await cnx.rollback() + dbapi_rollback = await ( + await cur.execute("select count(*) from {}".format(db_parameters["name"])) + ).fetchone() + assert dbapi_rollback[0] == 0, "transaction not rolled back" + await cur.execute("drop table {}".format(db_parameters["name"])) + await cur.close() + + +async def test_cursor(conn_cnx): + async with conn_cnx() as cnx: + try: + cur = cnx.cursor() + finally: + await cur.close() + + +async def test_cursor_isolation(conn_local): + async with conn_local() as con: + # two cursors from same connection have transaction isolation + cur1 = con.cursor() + cur2 = con.cursor() + await executeDDL1(cur1) + await cur1.execute( + f"insert into {TABLE1} values ('string inserted into table')" + ) + await cur2.execute(f"select name from {TABLE1}") + dbapi_ddl1 = await cur2.fetchall() + assert len(dbapi_ddl1) == 1 + assert len(dbapi_ddl1[0]) == 1 + assert dbapi_ddl1[0][0], "string inserted into table" + + +async def test_description(conn_local): + async with conn_local() as con: + cur = con.cursor() + assert cur.description is None, ( + "cursor.description should be none if there has not been any " + "statements executed" + ) + + await executeDDL1(cur) + assert ( + cur.description[0][0].lower() == "status" + ), "cursor.description returns status of insert" + await cur.execute("select name from %s" % TABLE1) + assert ( + len(cur.description) == 1 + ), "cursor.description describes too many columns" + assert ( + len(cur.description[0]) == 7 + ), "cursor.description[x] tuples must have 7 elements" + assert ( + cur.description[0][0].lower() == "name" + ), "cursor.description[x][0] must return column name" + # No, the column type is a numeric value + + # assert cur.description[0][1] == dbapi.STRING, ( + # 'cursor.description[x][1] must return column type. 
Got %r'
+        #     % cur.description[0][1]
+        # )
+
+        # Make sure self.description gets reset
+        await executeDDL2(cur)
+        assert len(cur.description) == 1, "cursor.description is not reset"
+
+
+async def test_rowcount(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+        assert cur.rowcount is None, (
+            "cursor.rowcount should be None when no statement has been "
+            "executed yet"
+        )
+        await executeDDL1(cur)
+        await cur.execute(
+            ("insert into %s values " "('string inserted into table')") % TABLE1
+        )
+        await cur.execute("select name from %s" % TABLE1)
+        assert cur.rowcount == 1, "cursor.rowcount should be the number of rows returned"
+
+
+async def test_close(db_parameters):
+    con = snowflake.connector.aio.SnowflakeConnection(
+        account=db_parameters["account"],
+        user=db_parameters["user"],
+        password=db_parameters["password"],
+        host=db_parameters["host"],
+        port=db_parameters["port"],
+        protocol=db_parameters["protocol"],
+    )
+    await con.connect()
+    try:
+        cur = con.cursor()
+    finally:
+        await con.close()
+
+    # commit is currently a nop; disabling for now
+    # connection.commit should raise an Error if called after connection is
+    # closed.
+    # assert calling(con.commit()),raises(errors.Error,'con.commit'))
+
+    # disabling due to SNOW-13645
+    # cursor.close() should raise an Error if called after connection closed
+    # try:
+    #     cur.close()
+    # should not get here and raise an exception
+    # assert calling(cur.close()),raises(errors.Error,
+    # 'calling cursor.close() twice in a row does not get an error'))
+    # except BASE_EXCEPTION_CLASS as err:
+    # assert error.errno,equal_to(
+    # errorcode.ER_CURSOR_IS_CLOSED),'cursor.close() called twice in a row')
+
+    # calling cursor.execute after connection is closed should raise an error
+    with pytest.raises(errors.Error) as e:
+        await cur.execute(f"create or replace table {TABLE1} (name string)")
+    assert (
+        e.value.errno == errorcode.ER_CURSOR_IS_CLOSED
+    ), "cursor.execute() called twice in a row"
+
+    # try to create a cursor on a closed connection
+    with pytest.raises(errors.Error) as e:
+        con.cursor()
+    assert (
+        e.value.errno == errorcode.ER_CONNECTION_IS_CLOSED
+    ), "tried to create a cursor on a closed cursor"
+
+
+async def test_execute(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+        await _paraminsert(cur)
+
+
+async def test_executemany(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+        await executeDDL1(cur)
+        margs = [{"dbapi_ddl2": "Cooper's"}, {"dbapi_ddl2": "Boag's"}]
+
+        await cur.executemany(
+            "insert into %s values (%%(dbapi_ddl2)s)" % (TABLE1), margs
+        )
+        assert cur.rowcount == 2, (
+            "insert using cursor.executemany set cursor.rowcount to "
+            "incorrect value %r" % cur.rowcount
+        )
+        await cur.execute("select name from %s" % TABLE1)
+        res = await cur.fetchall()
+        assert len(res) == 2, "cursor.fetchall retrieved incorrect number of rows"
+        dbapi_ddl2s = [res[0][0], res[1][0]]
+        dbapi_ddl2s.sort()
+        assert dbapi_ddl2s[0] == "Boag's", "incorrect data retrieved"
+        assert dbapi_ddl2s[1] == "Cooper's", "incorrect data retrieved"
+
+
+async def test_fetchone(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+        # SNOW-13548 - disabled
+        # assert calling(cur.fetchone()),raises(errors.Error),
+        # 'cursor.fetchone does not raise an Error if called before
+        # executing a query'
+        # )
+        await executeDDL1(cur)
+
+        await cur.execute("select name from %s" % TABLE1)
+        # assert calling(
+        #     cur.fetchone()), is_(None),
+        #     'cursor.fetchone should return None if a query does
not return any rows')
+        # assert cur.rowcount==-1))
+
+        await cur.execute("insert into %s values ('Row 1'),('Row 2')" % TABLE1)
+        await cur.execute("select name from %s order by 1" % TABLE1)
+        r = await cur.fetchone()
+        assert len(r) == 1, "cursor.fetchone should have returned 1 row"
+        assert r[0] == "Row 1", "cursor.fetchone returned incorrect data"
+        assert cur.rowcount == 2, "cursor.rowcount should be 2"
+
+
+SAMPLES = [
+    "Carlton Cold",
+    "Carlton Draft",
+    "Mountain Goat",
+    "Redback",
+    "String inserted into table",
+    "XXXX",
+]
+
+
+def _populate():
+    """Returns a list of sql commands to set up the DB for the fetch tests."""
+    populate = [
+        # NOTE NO GOOD using format to bind data
+        f"insert into {TABLE1} values ('{s}')"
+        for s in SAMPLES
+    ]
+    return populate
+
+
+async def test_fetchmany(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+
+        # disable due to SNOW-13648
+        # assert calling(cur.fetchmany()),errors.Error,
+        # 'cursor.fetchmany should raise an Error if called without executing a query')
+
+        await executeDDL1(cur)
+        for sql in _populate():
+            await cur.execute(sql)
+
+        await cur.execute("select name from %s" % TABLE1)
+        cur.arraysize = 1
+        r = await cur.fetchmany()
+        assert len(r) == 1, (
+            "cursor.fetchmany retrieved incorrect number of rows, "
+            "should get 1 row, received %s" % len(r)
+        )
+        cur.arraysize = 10
+        r = await cur.fetchmany(3)  # Should get 3 rows
+        assert len(r) == 3, (
+            "cursor.fetchmany retrieved incorrect number of rows, "
+            "should get 3 rows, received %s" % len(r)
+        )
+        r = await cur.fetchmany(4)  # Should get 2 more
+        assert len(r) == 2, (
+            "cursor.fetchmany retrieved incorrect number of rows, " "should get 2 more."
+        )
+        r = await cur.fetchmany(4)  # Should be an empty sequence
+        assert len(r) == 0, (
+            "cursor.fetchmany should return an empty sequence after "
+            "results are exhausted"
+        )
+        assert cur.rowcount in (-1, 6)
+
+        # Same as above, using cursor.arraysize
+        cur.arraysize = 4
+        await cur.execute("select name from %s" % TABLE1)
+        r = await cur.fetchmany()  # Should get 4 rows
+        assert len(r) == 4, "cursor.arraysize not being honoured by fetchmany"
+        r = await cur.fetchmany()  # Should get 2 more
+        assert len(r) == 2
+        r = await cur.fetchmany()  # Should be an empty sequence
+        assert len(r) == 0
+        assert cur.rowcount in (-1, 6)
+
+        cur.arraysize = 6
+        await cur.execute("select name from %s order by 1" % TABLE1)
+        rows = await cur.fetchmany()  # Should get all rows
+        assert cur.rowcount in (-1, 6)
+        assert len(rows) == 6
+        assert len(rows) == 6
+        rows = [row[0] for row in rows]
+        rows.sort()
+
+        # Make sure we get the right data back out
+        for i in range(0, 6):
+            assert rows[i] == SAMPLES[i], "incorrect data retrieved by cursor.fetchmany"
+
+        rows = await cur.fetchmany()  # Should return an empty list
+        assert len(rows) == 0, (
+            "cursor.fetchmany should return an empty sequence if "
+            "called after the whole result set has been fetched"
+        )
+        assert cur.rowcount in (-1, 6)
+
+        await executeDDL2(cur)
+        await cur.execute("select name from %s" % TABLE2)
+        r = await cur.fetchmany()  # Should get empty sequence
+        assert len(r) == 0, (
+            "cursor.fetchmany should return an empty sequence if "
+            "query retrieved no rows"
+        )
+        assert cur.rowcount in (-1, 0)
+
+
+async def test_fetchall(conn_local):
+    async with conn_local() as con:
+        cur = con.cursor()
+        # disable due to SNOW-13648
+        # assert calling(cur.fetchall()),raises(errors.Error),
+        # 'cursor.fetchall should raise an Error if called without executing a query'
+        # )
+        await
executeDDL1(cur) + for sql in _populate(): + await cur.execute(sql) + # assert calling(cur.fetchall()),errors.Error,'cursor.fetchall should raise an Error if called', + # 'after executing a a statement that does not return rows' + # ) + + await cur.execute(f"select name from {TABLE1}") + rows = await cur.fetchall() + assert cur.rowcount in (-1, len(SAMPLES)) + assert len(rows) == len(SAMPLES), "cursor.fetchall did not retrieve all rows" + rows = [r[0] for r in rows] + rows.sort() + for i in range(0, len(SAMPLES)): + assert rows[i] == SAMPLES[i], "cursor.fetchall retrieved incorrect rows" + rows = await cur.fetchall() + assert len(rows) == 0, ( + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched" + ) + assert cur.rowcount in (-1, len(SAMPLES)) + + await executeDDL2(cur) + await cur.execute("select name from %s" % TABLE2) + rows = await cur.fetchall() + assert cur.rowcount == 0, "executed but no row was returned" + assert len(rows) == 0, ( + "cursor.fetchall should return an empty list if " + "a select query returns no rows" + ) + + +async def test_mixedfetch(conn_local): + async with conn_local() as con: + cur = con.cursor() + await executeDDL1(cur) + for sql in _populate(): + await cur.execute(sql) + + await cur.execute("select name from %s" % TABLE1) + rows1 = await cur.fetchone() + rows23 = await cur.fetchmany(2) + rows4 = await cur.fetchone() + rows56 = await cur.fetchall() + assert cur.rowcount in (-1, 6) + assert len(rows23) == 2, "fetchmany returned incorrect number of rows" + assert len(rows56) == 2, "fetchall returned incorrect number of rows" + + rows = [rows1[0]] + rows.extend([rows23[0][0], rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0], rows56[1][0]]) + rows.sort() + for i in range(0, len(SAMPLES)): + assert rows[i] == SAMPLES[i], "incorrect data returned" + + +async def test_arraysize(conn_cnx): + async with conn_cnx() as con: + cur = con.cursor() + assert hasattr(cur, "arraysize"), "cursor.arraysize must be defined" + + +async def test_setinputsizes(conn_local): + async with conn_local() as con: + cur = con.cursor() + cur.setinputsizes((25,)) + await _paraminsert(cur) # Make sure cursor still works + + +async def test_setoutputsize_basic(conn_local): + # Basic test is to make sure setoutputsize doesn't blow up + async with conn_local() as con: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000, 0) + await _paraminsert(cur) # Make sure the cursor still works + + +async def test_description2(conn_local): + try: + async with conn_local() as con: + # ENABLE_FIX_67159 changes the column size to the actual size. By default it is disabled at the moment. 
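+            # The two sizes below reflect that flag: 26 is the length of the string
+            # value inserted into COL2 further down, while 16777216 is Snowflake's
+            # maximum VARCHAR length, reported when the actual data size is not used.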
+ expected_column_size = ( + 26 if not con.account.startswith("sfctest0") else 16777216 + ) + cur = con.cursor() + await executeDDL1(cur) + assert ( + len(cur.description) == 1 + ), "length cursor.description should be 1 after executing an insert" + await cur.execute("select name from %s" % TABLE1) + assert ( + len(cur.description) == 1 + ), "cursor.description returns too many columns" + assert ( + len(cur.description[0]) == 7 + ), "cursor.description[x] tuples must have 7 elements" + assert ( + cur.description[0][0].lower() == "name" + ), "cursor.description[x][0] must return column name" + + # Make sure self.description gets reset + await executeDDL2(cur) + # assert cur.description is None, ( + # 'cursor.description not being set to None') + # description fields: name | type_code | display_size | internal_size | precision | scale | null_ok + # name and type_code are mandatory, the other five are optional and are set to None if no meaningful values can be provided. + expected = [ + ("COL0", 0, None, None, 38, 0, True), + # number (FIXED) + ("COL1", 0, None, None, 9, 4, False), + # decimal + ("COL2", 2, None, expected_column_size, None, None, False), + # string + ("COL3", 3, None, None, None, None, True), + # date + ("COL4", 6, None, None, 0, 9, True), + # timestamp + ("COL5", 5, None, None, None, None, True), + # variant + ("COL6", 6, None, None, 0, 9, True), + # timestamp_ltz + ("COL7", 7, None, None, 0, 9, True), + # timestamp_tz + ("COL8", 8, None, None, 0, 9, True), + # timestamp_ntz + ("COL9", 9, None, None, None, None, True), + # object + ("COL10", 10, None, None, None, None, True), + # array + # ('col11', 11, ... # binary + ("COL12", 12, None, None, 0, 9, True), + # time + # ('col13', 13, ... # boolean + ] + + async with conn_local() as cnx: + cursor = cnx.cursor() + await cursor.execute( + """ +alter session set timestamp_input_format = 'YYYY-MM-DD HH24:MI:SS TZH:TZM' +""" + ) + await cursor.execute( + """ +create or replace table test_description ( +col0 number, col1 decimal(9,4) not null, +col2 string not null default 'place-holder', col3 date, col4 timestamp_ltz, +col5 variant, col6 timestamp_ltz, col7 timestamp_tz, col8 timestamp_ntz, +col9 object, col10 array, col12 time) +""" # col11 binary, col12 time + ) + await cursor.execute( + """ +insert into test_description select column1, column2, column3, column4, +column5, parse_json(column6), column7, column8, column9, parse_xml(column10), +parse_json(column11), column12 from VALUES +(65538, 12345.1234, 'abcdefghijklmnopqrstuvwxyz', +'2015-09-08','2015-09-08 15:39:20 -00:00','{ name:[1, 2, 3, 4]}', +'2015-06-01 12:00:01 +00:00','2015-04-05 06:07:08 +08:00', +'2015-06-03 12:00:03 +03:00', +' JulietteRomeo', +'["xx", "yy", "zz", null, 1]', '12:34:56') +""" + ) + await cursor.execute("select * from test_description") + await cursor.fetchone() + assert cursor.description == expected, "cursor.description is incorrect" + finally: + async with conn_local() as con: + async with con.cursor() as cursor: + await cursor.execute("drop table if exists test_description") + await cursor.execute( + "alter session set timestamp_input_format = default" + ) + + +async def test_closecursor(conn_cnx): + async with conn_cnx() as cnx: + cursor = cnx.cursor() + await cursor.close() + # The connection will be unusable from this point forward; an Error (or subclass) exception will + # be raised if any operation is attempted with the connection. The same applies to all cursor + # objects trying to use the connection. 
+ # close twice + + +async def test_None(conn_local): + async with conn_local() as con: + cur = con.cursor() + await executeDDL1(cur) + await cur.execute("insert into %s values (NULL)" % TABLE1) + await cur.execute("select name from %s" % TABLE1) + r = await cur.fetchall() + assert len(r) == 1 + assert len(r[0]) == 1 + assert r[0][0] is None, "NULL value not returned as None" + + +def test_Date(): + d1 = snowflake.connector.dbapi.Date(2002, 12, 25) + d2 = snowflake.connector.dbapi.DateFromTicks( + time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)) + ) + # API doesn't specify, but it seems to be implied + assert str(d1) == str(d2) + + +def test_Time(): + t1 = snowflake.connector.dbapi.Time(13, 45, 30) + t2 = snowflake.connector.dbapi.TimeFromTicks( + time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)) + ) + # API doesn't specify, but it seems to be implied + assert str(t1) == str(t2) + + +def test_Timestamp(): + t1 = snowflake.connector.dbapi.Timestamp(2002, 12, 25, 13, 45, 30) + t2 = snowflake.connector.dbapi.TimestampFromTicks( + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) + # API doesn't specify, but it seems to be implied + assert str(t1) == str(t2) + + +def test_STRING(): + assert hasattr(dbapi, "STRING"), "dbapi.STRING must be defined" + + +def test_BINARY(): + assert hasattr(dbapi, "BINARY"), "dbapi.BINARY must be defined." + + +def test_NUMBER(): + assert hasattr(dbapi, "NUMBER"), "dbapi.NUMBER must be defined." + + +def test_DATETIME(): + assert hasattr(dbapi, "DATETIME"), "dbapi.DATETIME must be defined." + + +def test_ROWID(): + assert hasattr(dbapi, "ROWID"), "dbapi.ROWID must be defined." + + +async def test_substring(conn_local): + async with conn_local() as con: + cur = con.cursor() + await executeDDL1(cur) + args = {"dbapi_ddl2": '"" "\'",\\"\\""\'"'} + await cur.execute("insert into %s values (%%(dbapi_ddl2)s)" % TABLE1, args) + await cur.execute("select name from %s" % TABLE1) + res = await cur.fetchall() + dbapi_ddl2 = res[0][0] + assert ( + dbapi_ddl2 == args["dbapi_ddl2"] + ), "incorrect data retrieved, got {}, should be {}".format( + dbapi_ddl2, args["dbapi_ddl2"] + ) + + +async def test_escape(conn_local): + teststrings = [ + "abc\ndef", + "abc\\ndef", + "abc\\\ndef", + "abc\\\\ndef", + "abc\\\\\ndef", + 'abc"def', + 'abc""def', + "abc'def", + "abc''def", + 'abc"def', + 'abc""def', + "abc'def", + "abc''def", + "abc\tdef", + "abc\\tdef", + "abc\\\tdef", + "\\x", + ] + + async with conn_local() as con: + cur = con.cursor() + await executeDDL1(cur) + for i in teststrings: + args = {"dbapi_ddl2": i} + await cur.execute("insert into %s values (%%(dbapi_ddl2)s)" % TABLE1, args) + await cur.execute("select * from %s" % TABLE1) + row = await cur.fetchone() + await cur.execute("delete from %s where name=%%s" % TABLE1, i) + assert ( + i == row[0] + ), f"newline not properly converted, got {row[0]}, should be {i}" + + +@pytest.mark.skipolddriver +async def test_callproc(conn_local): + name_sp = random_string(5, "test_stored_procedure_") + message = random_string(10) + async with conn_local() as con: + cur = con.cursor() + await executeDDL1(cur) + await cur.execute( + f""" + create or replace temporary procedure {name_sp}(message varchar) + returns varchar not null + language sql + as + begin + return message; + end; + """ + ) + ret = await cur.callproc(name_sp, (message,)) + assert ret == (message,) and await cur.fetchall() == [(message,)] + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize("paramstyle", ["pyformat", "qmark"]) +async def test_callproc_overload(conn_cnx, 
paramstyle): + """Test calling stored procedures overloaded with different input parameters and returns.""" + name_sp = random_string(5, "test_stored_procedure_") + async with conn_cnx(paramstyle=paramstyle) as cnx: + async with cnx.cursor() as cursor: + await cursor.execute( + f""" + create or replace temporary procedure {name_sp}(p1 varchar, p2 int, p3 date) + returns string not null + language sql + as + begin + return 'teststring'; + end; + """ + ) + + await cursor.execute( + f""" + create or replace temporary procedure {name_sp}(p1 float, p2 char) + returns float not null + language sql + as + begin + return 1.23; + end; + """ + ) + + await cursor.execute( + f""" + create or replace temporary procedure {name_sp}(p1 boolean) + returns table(col1 int, col2 string) + language sql + as + declare + res resultset default (SELECT * from values(1, 'a'),(2, 'b') as t(col1, col2)); + begin + return table(res); + end; + """ + ) + + await cursor.execute( + f""" + create or replace temporary procedure {name_sp}() + returns boolean + language sql + as + begin + return true; + end; + """ + ) + + ret = await cursor.callproc(name_sp, ("str", 1, "2022-02-22")) + assert ret == ("str", 1, "2022-02-22") and await cursor.fetchall() == [ + ("teststring",) + ] + + ret = await cursor.callproc(name_sp, (0.99, "c")) + assert ret == (0.99, "c") and await cursor.fetchall() == [(1.23,)] + + ret = await cursor.callproc(name_sp, (True,)) + assert ret == (True,) and await cursor.fetchall() == [(1, "a"), (2, "b")] + + ret = await cursor.callproc(name_sp) + assert ret == () and await cursor.fetchall() == [(True,)] + + +@pytest.mark.skipolddriver +async def test_callproc_invalid(conn_cnx): + """Test invalid callproc""" + name_sp = random_string(5, "test_stored_procedure_") + message = random_string(10) + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + # stored procedure does not exist + with pytest.raises(errors.ProgrammingError) as pe: + await cur.callproc(name_sp) + assert pe.value.errno == 2140 + + await cur.execute( + f""" + create or replace temporary procedure {name_sp}(message varchar) + returns varchar not null + language sql + as + begin + return message; + end; + """ + ) + + # parameters do not match the signature + with pytest.raises(errors.ProgrammingError) as pe: + await cur.callproc(name_sp) + assert pe.value.errno == 1044 + + with pytest.raises(TypeError): + await cur.callproc(name_sp, message) + + ret = await cur.callproc(name_sp, (message,)) + assert ret == (message,) and await cur.fetchall() == [(message,)] diff --git a/test/integ/aio/test_errors_async.py b/test/integ/aio/test_errors_async.py new file mode 100644 index 0000000000..e673ea900e --- /dev/null +++ b/test/integ/aio/test_errors_async.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import traceback + +import pytest + +import snowflake.connector.aio +from snowflake.connector import errors +from snowflake.connector.telemetry import TelemetryField + + +@pytest.mark.skip("SNOW-1770153 for error as attribute on connection") +async def test_error_classes(conn_cnx): + """Error classes in Connector module, object.""" + # class + assert snowflake.connector.ProgrammingError == errors.ProgrammingError + assert snowflake.connector.OperationalError == errors.OperationalError + + # object + async with conn_cnx() as ctx: + assert ctx.ProgrammingError == errors.ProgrammingError + + +@pytest.mark.skipolddriver +async def test_error_code(conn_cnx): + """Error code is included in the exception.""" + syntax_errno = 1494 + syntax_errno_old = 1003 + syntax_sqlstate = "42601" + syntax_sqlstate_old = "42000" + query = "SELECT * FROOOM TEST" + async with conn_cnx() as ctx: + with pytest.raises(errors.ProgrammingError) as e: + await ctx.cursor().execute(query) + assert ( + e.value.errno == syntax_errno or e.value.errno == syntax_errno_old + ), "Syntax error code" + assert ( + e.value.sqlstate == syntax_sqlstate + or e.value.sqlstate == syntax_sqlstate_old + ), "Syntax SQL state" + assert e.value.query == query, "Query mismatch" + e.match( + rf"^({syntax_errno:06d} \({syntax_sqlstate}\)|{syntax_errno_old:06d} \({syntax_sqlstate_old}\)): " + ) + + +@pytest.mark.skipolddriver +async def test_error_telemetry(conn_cnx): + async with conn_cnx() as ctx: + with pytest.raises(errors.ProgrammingError) as e: + await ctx.cursor().execute("SELECT * FROOOM TEST") + telemetry_stacktrace = e.value.telemetry_traceback + assert "SELECT * FROOOM TEST" not in telemetry_stacktrace + for frame in traceback.extract_tb(e.value.__traceback__): + assert frame.line not in telemetry_stacktrace + telemetry_data = e.value.generate_telemetry_exception_data() + assert ( + "Failed to detect Syntax error" + not in telemetry_data[TelemetryField.KEY_REASON.value] + ) diff --git a/test/integ/aio/test_execute_multi_statements_async.py b/test/integ/aio/test_execute_multi_statements_async.py new file mode 100644 index 0000000000..fd24f8f2b7 --- /dev/null +++ b/test/integ/aio/test_execute_multi_statements_async.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import codecs +import os +from io import BytesIO, StringIO +from unittest.mock import patch + +import pytest + +from snowflake.connector import ProgrammingError +from snowflake.connector.aio import DictCursor + +THIS_DIR = os.path.dirname(os.path.realpath(__file__)) + + +async def test_execute_string(conn_cnx, db_parameters): + async with conn_cnx() as cnx: + await cnx.execute_string( + """ +CREATE OR REPLACE TABLE {tbl1} (c1 int, c2 string); +CREATE OR REPLACE TABLE {tbl2} (c1 int, c2 string); +INSERT INTO {tbl1} VALUES(1,'test123'); +INSERT INTO {tbl1} VALUES(2,'test234'); +INSERT INTO {tbl1} VALUES(3,'test345'); +INSERT INTO {tbl2} VALUES(101,'test123'); +INSERT INTO {tbl2} VALUES(102,'test234'); +INSERT INTO {tbl2} VALUES(103,'test345'); +""".format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + return_cursors=False, + ) + try: + async with conn_cnx() as cnx: + ret = await ( + await cnx.cursor().execute( + """ +SELECT * FROM {tbl1} ORDER BY 1 +""".format( + tbl1=db_parameters["name"] + "1" + ) + ) + ).fetchall() + assert ret[0][0] == 1 + assert ret[2][1] == "test345" + ret = await ( + await cnx.cursor().execute( + """ +SELECT * FROM {tbl2} ORDER BY 2 +""".format( + tbl2=db_parameters["name"] + "2" + ) + ) + ).fetchall() + assert ret[0][0] == 101 + assert ret[2][1] == "test345" + + curs = await cnx.execute_string( + """ +SELECT * FROM {tbl1} ORDER BY 1 DESC; +SELECT * FROM {tbl2} ORDER BY 1 DESC; +""".format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ) + ) + assert curs[0].rowcount == 3 + assert curs[1].rowcount == 3 + ret1 = await curs[0].fetchone() + assert ret1[0] == 3 + ret2 = await curs[1].fetchone() + assert ret2[0] == 103 + finally: + async with conn_cnx() as cnx: + await cnx.execute_string( + """ + DROP TABLE IF EXISTS {tbl1}; + DROP TABLE IF EXISTS {tbl2}; + """.format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + return_cursors=False, + ) + + +@pytest.mark.skipolddriver +async def test_execute_string_dict_cursor(conn_cnx, db_parameters): + async with conn_cnx() as cnx: + await cnx.execute_string( + """ +CREATE OR REPLACE TABLE {tbl1} (C1 int, C2 string); +CREATE OR REPLACE TABLE {tbl2} (C1 int, C2 string); +INSERT INTO {tbl1} VALUES(1,'test123'); +INSERT INTO {tbl1} VALUES(2,'test234'); +INSERT INTO {tbl1} VALUES(3,'test345'); +INSERT INTO {tbl2} VALUES(101,'test123'); +INSERT INTO {tbl2} VALUES(102,'test234'); +INSERT INTO {tbl2} VALUES(103,'test345'); +""".format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + return_cursors=False, + ) + try: + async with conn_cnx() as cnx: + ret = await cnx.cursor(cursor_class=DictCursor).execute( + """ +SELECT * FROM {tbl1} ORDER BY 1 +""".format( + tbl1=db_parameters["name"] + "1" + ) + ) + assert ret.rowcount == 3 + assert ret._use_dict_result + ret = await ret.fetchall() + assert type(ret) is list + assert type(ret[0]) is dict + assert type(ret[2]) is dict + assert ret[0]["C1"] == 1 + assert ret[2]["C2"] == "test345" + + ret = await cnx.cursor(cursor_class=DictCursor).execute( + """ +SELECT * FROM {tbl2} ORDER BY 2 +""".format( + tbl2=db_parameters["name"] + "2" + ) + ) + assert ret.rowcount == 3 + ret = await ret.fetchall() + assert type(ret) is list + assert type(ret[0]) is dict + assert type(ret[2]) is dict + assert ret[0]["C1"] == 101 + assert ret[2]["C2"] == "test345" + + curs = await cnx.execute_string( + """ +SELECT * FROM {tbl1} ORDER BY 1 DESC; +SELECT * FROM {tbl2} 
ORDER BY 1 DESC; +""".format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + cursor_class=DictCursor, + ) + assert type(curs) is list + assert curs[0].rowcount == 3 + assert curs[1].rowcount == 3 + ret1 = await curs[0].fetchone() + assert type(ret1) is dict + assert ret1["C1"] == 3 + assert ret1["C2"] == "test345" + ret2 = await curs[1].fetchone() + assert type(ret2) is dict + assert ret2["C1"] == 103 + finally: + async with conn_cnx() as cnx: + await cnx.execute_string( + """ + DROP TABLE IF EXISTS {tbl1}; + DROP TABLE IF EXISTS {tbl2}; + """.format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + return_cursors=False, + ) + + +async def test_execute_string_kwargs(conn_cnx, db_parameters): + async with conn_cnx() as cnx: + with patch( + "snowflake.connector.cursor.SnowflakeCursor.execute", autospec=True + ) as mock_execute: + await cnx.execute_string( + """ +CREATE OR REPLACE TABLE {tbl1} (c1 int, c2 string); +CREATE OR REPLACE TABLE {tbl2} (c1 int, c2 string); +INSERT INTO {tbl1} VALUES(1,'test123'); +INSERT INTO {tbl1} VALUES(2,'test234'); +INSERT INTO {tbl1} VALUES(3,'test345'); +INSERT INTO {tbl2} VALUES(101,'test123'); +INSERT INTO {tbl2} VALUES(102,'test234'); +INSERT INTO {tbl2} VALUES(103,'test345'); + """.format( + tbl1=db_parameters["name"] + "1", tbl2=db_parameters["name"] + "2" + ), + return_cursors=False, + _no_results=True, + ) + for call in mock_execute.call_args_list: + assert call[1].get("_no_results", False) + + +async def test_execute_string_with_error(conn_cnx): + async with conn_cnx() as cnx: + with pytest.raises(ProgrammingError): + await cnx.execute_string( + """ +SELECT 1; +SELECT 234; +SELECT bafa; +""" + ) + + +async def test_execute_stream(conn_cnx): + # file stream + expected_results = [1, 2, 3] + with codecs.open( + os.path.join(THIS_DIR, "../../data", "multiple_statements.sql"), + encoding="utf-8", + ) as f: + async with conn_cnx() as cnx: + idx = 0 + async for rec in cnx.execute_stream(f): + assert (await rec.fetchall())[0][0] == expected_results[idx] + idx += 1 + + # text stream + expected_results = [3, 4, 5, 6] + async with conn_cnx() as cnx: + idx = 0 + async for rec in cnx.execute_stream( + StringIO("SELECT 3; SELECT 4; SELECT 5;\nSELECT 6;") + ): + assert (await rec.fetchall())[0][0] == expected_results[idx] + idx += 1 + + +async def test_execute_stream_with_error(conn_cnx): + # file stream + expected_results = [1, 2, 3] + with open(os.path.join(THIS_DIR, "../../data", "multiple_statements.sql")) as f: + async with conn_cnx() as cnx: + idx = 0 + async for rec in cnx.execute_stream(f): + assert (await rec.fetchall())[0][0] == expected_results[idx] + idx += 1 + + # read a file including syntax error in the middle + with codecs.open( + os.path.join(THIS_DIR, "../../data", "multiple_statements_negative.sql"), + encoding="utf-8", + ) as f: + async with conn_cnx() as cnx: + gen = cnx.execute_stream(f) + rec = await anext(gen) + assert (await rec.fetchall())[0][0] == 987 + # rec = await (await anext(gen)).fetchall() + # assert rec[0][0] == 987 # the first statement succeeds + with pytest.raises(ProgrammingError): + await anext(gen) # the second statement fails + + # binary stream including Ascii data + async with conn_cnx() as cnx: + with pytest.raises(TypeError): + gen = cnx.execute_stream( + BytesIO(b"SELECT 3; SELECT 4; SELECT 5;\nSELECT 6;") + ) + await anext(gen) + + +@pytest.mark.skipolddriver +async def test_execute_string_empty_lines(conn_cnx, db_parameters): + """Tests whether execute_string can 
filter out empty lines.""" + async with conn_cnx() as cnx: + cursors = await cnx.execute_string("select 1;\n\n") + assert len(cursors) == 1 + assert [await c.fetchall() for c in cursors] == [[(1,)]] diff --git a/test/integ/aio/test_key_pair_authentication_async.py b/test/integ/aio/test_key_pair_authentication_async.py new file mode 100644 index 0000000000..e138978a95 --- /dev/null +++ b/test/integ/aio/test_key_pair_authentication_async.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import uuid + +import pytest +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import dsa, rsa + +import snowflake.connector +import snowflake.connector.aio + + +async def test_different_key_length(is_public_test, request, conn_cnx, db_parameters): + if is_public_test: + pytest.skip("This test requires ACCOUNTADMIN privilege to set the public key") + + test_user = "python_test_keypair_user_" + str(uuid.uuid4()).replace("-", "_") + + db_config = { + "protocol": db_parameters["protocol"], + "account": db_parameters["account"], + "user": test_user, + "host": db_parameters["host"], + "port": db_parameters["port"], + "database": db_parameters["database"], + "schema": db_parameters["schema"], + "timezone": "UTC", + } + + async def finalizer(): + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ + use role accountadmin + """ + ) + await cnx.cursor().execute( + """ + drop user if exists {user} + """.format( + user=test_user + ) + ) + + def fin(): + loop = asyncio.get_event_loop() + loop.run_until_complete(finalizer()) + + request.addfinalizer(fin) + + testcases = [2048, 4096, 8192] + + async with conn_cnx() as cnx: + cursor = cnx.cursor() + await cursor.execute( + """ + use role accountadmin + """ + ) + await cursor.execute("create user " + test_user) + + for key_length in testcases: + private_key_der, public_key_der_encoded = generate_key_pair(key_length) + + await cnx.cursor().execute( + """ + alter user {user} set rsa_public_key='{public_key}' + """.format( + user=test_user, public_key=public_key_der_encoded + ) + ) + + db_config["private_key"] = private_key_der + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as _: + pass + + +@pytest.mark.skipolddriver +async def test_multiple_key_pair(is_public_test, request, conn_cnx, db_parameters): + if is_public_test: + pytest.skip("This test requires ACCOUNTADMIN privilege to set the public key") + + test_user = "python_test_keypair_user_" + str(uuid.uuid4()).replace("-", "_") + + db_config = { + "protocol": db_parameters["protocol"], + "account": db_parameters["account"], + "user": test_user, + "host": db_parameters["host"], + "port": db_parameters["port"], + "database": db_parameters["database"], + "schema": db_parameters["schema"], + "timezone": "UTC", + } + + async def finalizer(): + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ + use role accountadmin + """ + ) + await cnx.cursor().execute( + """ + drop user if exists {user} + """.format( + user=test_user + ) + ) + + def fin(): + loop = asyncio.get_event_loop() + loop.run_until_complete(finalizer()) + + request.addfinalizer(fin) + + private_key_one_der, public_key_one_der_encoded = generate_key_pair(2048) + private_key_two_der, public_key_two_der_encoded = generate_key_pair(2048) + + async with conn_cnx() as cnx: + await 
cnx.cursor().execute( + """ + use role accountadmin + """ + ) + await cnx.cursor().execute( + """ + create user {user} + """.format( + user=test_user + ) + ) + await cnx.cursor().execute( + """ + alter user {user} set rsa_public_key='{public_key}' + """.format( + user=test_user, public_key=public_key_one_der_encoded + ) + ) + + db_config["private_key"] = private_key_one_der + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as _: + pass + + # assert exception since different key pair is used + db_config["private_key"] = private_key_two_der + # although specifying password, + # key pair authentication should used and it should fail since we don't do fall back + db_config["password"] = "fake_password" + with pytest.raises(snowflake.connector.errors.DatabaseError) as exec_info: + await snowflake.connector.aio.SnowflakeConnection(**db_config).connect() + + assert exec_info.value.errno == 250001 + assert exec_info.value.sqlstate == "08001" + assert "JWT token is invalid" in exec_info.value.msg + + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ + use role accountadmin + """ + ) + await cnx.cursor().execute( + """ + alter user {user} set rsa_public_key_2='{public_key}' + """.format( + user=test_user, public_key=public_key_two_der_encoded + ) + ) + + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as _: + pass + + +async def test_bad_private_key(db_parameters): + db_config = { + "protocol": db_parameters["protocol"], + "account": db_parameters["account"], + "user": db_parameters["user"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "database": db_parameters["database"], + "schema": db_parameters["schema"], + "timezone": "UTC", + } + + dsa_private_key = dsa.generate_private_key(key_size=2048, backend=default_backend()) + dsa_private_key_der = dsa_private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + encrypted_rsa_private_key_der = rsa.generate_private_key( + key_size=2048, public_exponent=65537, backend=default_backend() + ).private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(b"abcd"), + ) + + bad_private_key_test_cases = [ + b"abcd", + dsa_private_key_der, + encrypted_rsa_private_key_der, + ] + + for private_key in bad_private_key_test_cases: + db_config["private_key"] = private_key + with pytest.raises(snowflake.connector.errors.ProgrammingError) as exec_info: + await snowflake.connector.aio.SnowflakeConnection(**db_config).connect() + assert exec_info.value.errno == 251008 + + +def generate_key_pair(key_length): + private_key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=key_length + ) + + private_key_der = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + public_key_pem = ( + private_key.public_key() + .public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo + ) + .decode("utf-8") + ) + + # strip off header + public_key_der_encoded = "".join(public_key_pem.split("\n")[1:-2]) + + return private_key_der, public_key_der_encoded diff --git a/test/integ/aio/test_large_put_async.py b/test/integ/aio/test_large_put_async.py new file mode 100644 index 0000000000..1639a1a3d5 --- /dev/null +++ 
b/test/integ/aio/test_large_put_async.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os +from test.generate_test_files import generate_k_lines_of_n_files +from unittest.mock import patch + +import pytest + +from snowflake.connector.aio._file_transfer_agent import SnowflakeFileTransferAgent + + +@pytest.mark.skipolddriver +@pytest.mark.aws +async def test_put_copy_large_files(tmpdir, conn_cnx, db_parameters): + """[s3] Puts and Copies into large files.""" + # generates N files + number_of_files = 2 + number_of_lines = 200000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + async with conn_cnx() as cnx: + await cnx.cursor().execute( + f""" +create table {db_parameters['name']} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""" + ) + try: + async with conn_cnx() as cnx: + files = files.replace("\\", "\\\\") + + def mocked_file_agent(*args, **kwargs): + newkwargs = kwargs.copy() + newkwargs.update(multipart_threshold=10000) + agent = SnowflakeFileTransferAgent(*args, **newkwargs) + mocked_file_agent.agent = agent + return agent + + with patch( + "snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent", + side_effect=mocked_file_agent, + ): + # upload with auto compress = True + await cnx.cursor().execute( + f"put 'file://{files}' @%{db_parameters['name']} auto_compress=True", + ) + assert mocked_file_agent.agent._multipart_threshold == 10000 + await cnx.cursor().execute(f"remove @%{db_parameters['name']}") + + # upload with auto compress = False + await cnx.cursor().execute( + f"put 'file://{files}' @%{db_parameters['name']} auto_compress=False", + ) + assert mocked_file_agent.agent._multipart_threshold == 10000 + + # Upload again. There was a bug when a large file is uploaded again while it already exists in a stage. + # Refer to preprocess(self) of storage_client.py. + # self.get_digest() needs to be called before self.get_file_header(meta.dst_file_name). + # SNOW-749141 + await cnx.cursor().execute( + f"put 'file://{files}' @%{db_parameters['name']} auto_compress=False", + ) # do not add `overwrite=True` because overwrite will skip the code path to extract file header. + + c = cnx.cursor() + try: + await c.execute("copy into {}".format(db_parameters["name"])) + cnt = 0 + async for _ in c: + cnt += 1 + assert cnt == number_of_files, "Number of PUT files" + finally: + await c.close() + + c = cnx.cursor() + try: + await c.execute( + "select count(*) from {name}".format(name=db_parameters["name"]) + ) + cnt = 0 + async for rec in c: + cnt += rec[0] + assert cnt == number_of_files * number_of_lines, "Number of rows" + finally: + await c.close() + finally: + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + ) as cnx: + await cnx.cursor().execute( + "drop table if exists {table}".format(table=db_parameters["name"]) + ) diff --git a/test/integ/aio/test_large_result_set_async.py b/test/integ/aio/test_large_result_set_async.py new file mode 100644 index 0000000000..08ca9877a9 --- /dev/null +++ b/test/integ/aio/test_large_result_set_async.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from snowflake.connector.telemetry import TelemetryField + +NUMBER_OF_ROWS = 50000 + +PREFETCH_THREADS = [8, 3, 1] + + +@pytest.fixture() +async def ingest_data(request, conn_cnx, db_parameters): + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + ) as cnx: + await cnx.cursor().execute( + """ + create or replace table {name} ( + c0 int, + c1 int, + c2 int, + c3 int, + c4 int, + c5 int, + c6 int, + c7 int, + c8 int, + c9 int) + """.format( + name=db_parameters["name"] + ) + ) + await cnx.cursor().execute( + """ + insert into {name} + select random(100), + random(100), + random(100), + random(100), + random(100), + random(100), + random(100), + random(100), + random(100), + random(100) + from table(generator(rowCount=>{number_of_rows})) + """.format( + name=db_parameters["name"], number_of_rows=NUMBER_OF_ROWS + ) + ) + first_val = ( + await ( + await cnx.cursor().execute( + "select c0 from {name} order by 1 limit 1".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + )[0] + last_val = ( + await ( + await cnx.cursor().execute( + "select c9 from {name} order by 1 desc limit 1".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + )[0] + + async def fin(): + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + ) as cnx: + await cnx.cursor().execute( + "drop table if exists {name}".format(name=db_parameters["name"]) + ) + + yield first_val, last_val + await fin() + + +@pytest.mark.aws +@pytest.mark.parametrize("num_threads", PREFETCH_THREADS) +async def test_query_large_result_set_n_threads( + conn_cnx, db_parameters, ingest_data, num_threads +): + sql = "select * from {name} order by 1".format(name=db_parameters["name"]) + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + client_prefetch_threads=num_threads, + ) as cnx: + assert cnx.client_prefetch_threads == num_threads + results = [] + async for rec in await cnx.cursor().execute(sql): + results.append(rec) + num_rows = len(results) + assert NUMBER_OF_ROWS == num_rows + assert results[0][0] == ingest_data[0] + assert results[num_rows - 1][8] == ingest_data[1] + + +@pytest.mark.aws +@pytest.mark.skipolddriver +async def test_query_large_result_set(conn_cnx, db_parameters, ingest_data): + """[s3] Gets Large Result set.""" + sql = "select * from {name} order by 1".format(name=db_parameters["name"]) + async with conn_cnx() as cnx: + telemetry_data = [] + add_log_mock = Mock() + add_log_mock.side_effect = lambda datum: telemetry_data.append(datum) + cnx._telemetry.add_log_to_batch = add_log_mock + + result2 = [] + async for rec in await cnx.cursor().execute(sql): + result2.append(rec) + + num_rows = len(result2) + assert result2[0][0] == ingest_data[0] + assert result2[num_rows - 1][8] == ingest_data[1] + + result999 = [] + async for rec in await cnx.cursor().execute(sql): + result999.append(rec) + + num_rows = len(result999) + assert result999[0][0] == ingest_data[0] + assert result999[num_rows - 1][8] == ingest_data[1] + + assert len(result2) == len( + result999 + ), "result length is different: result2, and result999" + for i, (x, y) in enumerate(zip(result2, result999)): + assert x == y, f"element {i}" + + # verify that the expected telemetry metrics were logged + expected = [ + 
TelemetryField.TIME_CONSUME_FIRST_RESULT, + TelemetryField.TIME_CONSUME_LAST_RESULT, + # NOTE: Arrow doesn't do parsing like how JSON does, so depending on what + # way this is executed only look for JSON result sets + # TelemetryField.TIME_PARSING_CHUNKS, + TelemetryField.TIME_DOWNLOADING_CHUNKS, + ] + for field in expected: + assert ( + sum( + 1 if x.message["type"] == field.value else 0 for x in telemetry_data + ) + == 2 + ), ( + "Expected two telemetry logs (one per query) " + "for log type {}".format(field.value) + ) diff --git a/test/integ/aio/test_load_unload_async.py b/test/integ/aio/test_load_unload_async.py new file mode 100644 index 0000000000..a45daa33c3 --- /dev/null +++ b/test/integ/aio/test_load_unload_async.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os +import pathlib +from getpass import getuser +from logging import getLogger +from os import path + +import pytest + +try: + from parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + +THIS_DIR = path.dirname(path.realpath(__file__)) + +logger = getLogger(__name__) + + +@pytest.fixture() +def test_data(request, conn_cnx, db_parameters): + def connection(): + """Abstracting away connection creation.""" + return conn_cnx() + + return create_test_data(request, db_parameters, connection) + + +@pytest.fixture() +def s3_test_data(request, conn_cnx, db_parameters): + def connection(): + """Abstracting away connection creation.""" + return conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + ) + + return create_test_data(request, db_parameters, connection) + + +async def create_test_data(request, db_parameters, connection): + assert "AWS_ACCESS_KEY_ID" in os.environ, "AWS_ACCESS_KEY_ID is missing" + assert "AWS_SECRET_ACCESS_KEY" in os.environ, "AWS_SECRET_ACCESS_KEY is missing" + + unique_name = db_parameters["name"] + database_name = f"{unique_name}_db" + warehouse_name = f"{unique_name}_wh" + + async def fin(): + async with connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"drop database {database_name}") + await cur.execute(f"drop warehouse {warehouse_name}") + + request.addfinalizer(fin) + + class TestData: + def __init__(self): + self.test_data_dir = (pathlib.Path(__file__).parent / "data").absolute() + self.AWS_ACCESS_KEY_ID = "'{}'".format(os.environ["AWS_ACCESS_KEY_ID"]) + self.AWS_SECRET_ACCESS_KEY = "'{}'".format( + os.environ["AWS_SECRET_ACCESS_KEY"] + ) + self.stage_name = f"{unique_name}_stage" + self.warehouse_name = warehouse_name + self.database_name = database_name + self.connection = connection + self.user_bucket = os.getenv( + "SF_AWS_USER_BUCKET", f"sfc-eng-regression/{getuser()}/reg" + ) + + ret = TestData() + + async with connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute("use role sysadmin") + await cur.execute( + """ +create or replace warehouse {} +warehouse_size = 'small' warehouse_type='standard' +auto_suspend=1800 +""".format( + warehouse_name + ) + ) + await cur.execute( + """ +create or replace database {} +""".format( + database_name + ) + ) + await cur.execute( + """ +create or replace schema pytesting_schema +""" + ) + await cur.execute( + """ +create or replace file format VSV type = 'CSV' +field_delimiter='|' error_on_column_count_mismatch=false + """ + ) + return ret + + +@pytest.mark.skipif( + not
CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." +) +async def test_load_s3(test_data): + async with test_data.connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"use warehouse {test_data.warehouse_name}") + await cur.execute(f"use schema {test_data.database_name}.pytesting_schema") + await cur.execute( + """ +create or replace table tweets(created_at timestamp, +id number, id_str string, text string, source string, +in_reply_to_status_id number, in_reply_to_status_id_str string, +in_reply_to_user_id number, in_reply_to_user_id_str string, +in_reply_to_screen_name string, user__id number, user__id_str string, +user__name string, user__screen_name string, user__location string, +user__description string, user__url string, +user__entities__description__urls string, user__protected string, +user__followers_count number, user__friends_count number, +user__listed_count number, user__created_at timestamp, +user__favourites_count number, user__utc_offset number, +user__time_zone string, user__geo_enabled string, user__verified string, +user__statuses_count number, user__lang string, +user__contributors_enabled string, user__is_translator string, +user__profile_background_color string, +user__profile_background_image_url string, +user__profile_background_image_url_https string, +user__profile_background_tile string, user__profile_image_url string, +user__profile_image_url_https string, user__profile_link_color string, +user__profile_sidebar_border_color string, +user__profile_sidebar_fill_color string, user__profile_text_color string, +user__profile_use_background_image string, user__default_profile string, +user__default_profile_image string, user__following string, +user__follow_request_sent string, user__notifications string, geo string, +coordinates string, place string, contributors string, retweet_count number, +favorite_count number, entities__hashtags string, entities__symbols string, +entities__urls string, entities__user_mentions string, favorited string, +retweeted string, lang string) +""" + ) + await cur.execute("ls @%tweets") + assert cur.rowcount == 0, ( + "table newly created should not have any files in its " "staging area" + ) + await cur.execute( + """ +copy into tweets from s3://sfc-eng-data/twitter/O1k/tweets/ +credentials=(AWS_KEY_ID={aws_access_key_id} +AWS_SECRET_KEY={aws_secret_access_key}) +file_format=(skip_header=1 null_if=('') field_optionally_enclosed_by='"') +""".format( + aws_access_key_id=test_data.AWS_ACCESS_KEY_ID, + aws_secret_access_key=test_data.AWS_SECRET_ACCESS_KEY, + ) + ) + assert cur.rowcount == 1, "copy into tweets did not set rowcount to 1" + results = await cur.fetchall() + assert ( + results[0][0] == "s3://sfc-eng-data/twitter/O1k/tweets/1.csv.gz" + ), "ls @%tweets failed" + await cur.execute("drop table tweets") + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." 
+) +async def test_put_local_file(test_data): + async with test_data.connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute( + "alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false" + ) + await cur.execute(f"use warehouse {test_data.warehouse_name}") + await cur.execute( + f"""use schema {test_data.database_name}.pytesting_schema""" + ) + await cur.execute( + """ +create or replace table pytest_putget_t1 (c1 STRING, c2 STRING, c3 STRING, +c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING) +stage_file_format = (field_delimiter = '|' error_on_column_count_mismatch=false) +stage_copy_options = (purge=false) +stage_location = (url = 's3://sfc-eng-regression/jenkins/{stage_name}' +credentials = ( +AWS_KEY_ID={aws_access_key_id} +AWS_SECRET_KEY={aws_secret_access_key})) +""".format( + aws_access_key_id=test_data.AWS_ACCESS_KEY_ID, + aws_secret_access_key=test_data.AWS_SECRET_ACCESS_KEY, + stage_name=test_data.stage_name, + ) + ) + await cur.execute( + """put file://{}/ExecPlatform/Database/data/orders_10*.csv @%pytest_putget_t1""".format( + str(test_data.test_data_dir) + ) + ) + await cur.execute("ls @%pytest_putget_t1") + _ = await cur.fetchall() + assert cur.rowcount == 2, "ls @%pytest_putget_t1 did not return 2 rows" + await cur.execute("copy into pytest_putget_t1") + results = await cur.fetchall() + assert len(results) == 2, "2 files were not copied" + assert results[0][1] == "LOADED", "file 1 was not loaded after copy" + assert results[1][1] == "LOADED", "file 2 was not loaded after copy" + + await cur.execute("select count(*) from pytest_putget_t1") + results = await cur.fetchall() + assert results[0][0] == 73, "73 rows not loaded into pytest_putget_t1" + await cur.execute("rm @%pytest_putget_t1") + results = await cur.fetchall() + assert len(results) == 2, "two files were not removed" + await cur.execute( + "select STATUS from information_schema.load_history where table_name='PYTEST_PUTGET_T1'" + ) + results = await cur.fetchall() + assert results[0][0] == "LOADED", "history does not show file to be loaded" + await cur.execute("drop table pytest_putget_t1") + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
+) +async def test_put_load_from_user_stage(test_data): + async with test_data.connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute( + "alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false" + ) + await cur.execute( + """ +use warehouse {} +""".format( + test_data.warehouse_name + ) + ) + await cur.execute( + """ +use schema {}.pytesting_schema +""".format( + test_data.database_name + ) + ) + await cur.execute( + """ +create or replace stage {stage_name} +url='s3://{user_bucket}/{stage_name}' +credentials = ( +AWS_KEY_ID={aws_access_key_id} +AWS_SECRET_KEY={aws_secret_access_key}) +""".format( + aws_access_key_id=test_data.AWS_ACCESS_KEY_ID, + aws_secret_access_key=test_data.AWS_SECRET_ACCESS_KEY, + user_bucket=test_data.user_bucket, + stage_name=test_data.stage_name, + ) + ) + await cur.execute( + """ +create or replace table pytest_putget_t2 (c1 STRING, c2 STRING, c3 STRING, +c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING) +""" + ) + await cur.execute( + """put file://{}/ExecPlatform/Database/data/orders_10*.csv @{}""".format( + test_data.test_data_dir, test_data.stage_name + ) + ) + # two files should have been put in the staging area + results = await cur.fetchall() + assert len(results) == 2 + + await cur.execute("ls @%pytest_putget_t2") + results = await cur.fetchall() + assert len(results) == 0, "no files should have been loaded yet" + + # copy + await cur.execute( + """ +copy into pytest_putget_t2 from @{stage_name} +file_format = (field_delimiter = '|' error_on_column_count_mismatch=false) +purge=true +""".format( + stage_name=test_data.stage_name + ) + ) + results = sorted(await cur.fetchall()) + assert len(results) == 2, "copy failed to load two files from the stage" + assert results[0][ + 0 + ] == "s3://{user_bucket}/{stage_name}/orders_100.csv.gz".format( + user_bucket=test_data.user_bucket, + stage_name=test_data.stage_name, + ), "copy did not load file orders_100" + + assert results[1][ + 0 + ] == "s3://{user_bucket}/{stage_name}/orders_101.csv.gz".format( + user_bucket=test_data.user_bucket, + stage_name=test_data.stage_name, + ), "copy did not load file orders_101" + + # should be empty (purged) + await cur.execute(f"ls @{test_data.stage_name}") + results = await cur.fetchall() + assert len(results) == 0, "copied files not purged" + await cur.execute("drop table pytest_putget_t2") + await cur.execute(f"drop stage {test_data.stage_name}") + + +@pytest.mark.aws +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
+) +async def test_unload(db_parameters, s3_test_data): + async with s3_test_data.connection() as cnx: + async with cnx.cursor() as cur: + await cur.execute(f"""use warehouse {s3_test_data.warehouse_name}""") + await cur.execute( + f"""use schema {s3_test_data.database_name}.pytesting_schema""" + ) + await cur.execute( + """ +create or replace stage {stage_name} +url='s3://{user_bucket}/{stage_name}/unload/' +credentials = ( +AWS_KEY_ID={aws_access_key_id} +AWS_SECRET_KEY={aws_secret_access_key}) +""".format( + aws_access_key_id=s3_test_data.AWS_ACCESS_KEY_ID, + aws_secret_access_key=s3_test_data.AWS_SECRET_ACCESS_KEY, + user_bucket=s3_test_data.user_bucket, + stage_name=s3_test_data.stage_name, + ) + ) + + await cur.execute( + """ +CREATE OR REPLACE TABLE pytest_t3 (c1 STRING, c2 STRING, c3 STRING, +c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING) +stage_file_format = (format_name = 'vsv' field_delimiter = '|' +error_on_column_count_mismatch=false) +""" + ) + await cur.execute( + """ +alter stage {stage_name} set file_format = (format_name = 'VSV' ) +""".format( + stage_name=s3_test_data.stage_name + ) + ) + + # make sure its clean + await cur.execute(f"rm @{s3_test_data.stage_name}") + + # put local file + await cur.execute( + "put file://{}/ExecPlatform/Database/data/orders_10*.csv @%pytest_t3".format( + s3_test_data.test_data_dir + ) + ) + + # copy into table + await cur.execute( + """ +copy into pytest_t3 +file_format = (field_delimiter = '|' error_on_column_count_mismatch=false) +purge=true +""" + ) + # unload from table + await cur.execute( + """ +copy into @{stage_name}/pytest_t3/data_ +from pytest_t3 file_format=(format_name='VSV' compression='gzip') +max_file_size=10000000 +""".format( + stage_name=s3_test_data.stage_name + ) + ) + + # load the data back to another table + await cur.execute( + """ +CREATE OR REPLACE TABLE pytest_t3_copy +(c1 STRING, c2 STRING, c3 STRING, c4 STRING, c5 STRING, +c6 STRING, c7 STRING, c8 STRING, c9 STRING) +stage_file_format = (format_name = 'VSV' ) +""" + ) + + await cur.execute( + """ +copy into pytest_t3_copy +from @{stage_name}/pytest_t3/data_ return_failed_only=true +""".format( + stage_name=s3_test_data.stage_name + ) + ) + + # check to make sure they are equal + await cur.execute( + """ +(select * from pytest_t3 minus select * from pytest_t3_copy) +union +(select * from pytest_t3_copy minus select * from pytest_t3) +""" + ) + assert cur.rowcount == 0, "unloaded/reloaded data were not the same" + # clean stage + await cur.execute( + "rm @{stage_name}/pytest_t3/data_".format( + stage_name=s3_test_data.stage_name + ) + ) + assert cur.rowcount == 1, "only one file was expected to be removed" + + # unload with deflate + await cur.execute( + """ +copy into @{stage_name}/pytest_t3/data_ +from pytest_t3 file_format=(format_name='VSV' compression='deflate') +max_file_size=10000000 +""".format( + stage_name=s3_test_data.stage_name + ) + ) + results = await cur.fetchall() + assert results[0][0] == 73, "73 rows were expected to be loaded" + + # create a table to unload data into + await cur.execute( + """ +CREATE OR REPLACE TABLE pytest_t3_copy +(c1 STRING, c2 STRING, c3 STRING, c4 STRING, c5 STRING, c6 STRING, +c7 STRING, c8 STRING, c9 STRING) +stage_file_format = (format_name = 'VSV' +compression='deflate') +""" + ) + results = await cur.fetchall() + assert results[0][0] == "Table PYTEST_T3_COPY successfully created." 
+ + await cur.execute( + """ +alter stage {stage_name} set file_format = (format_name = 'VSV' + compression='deflate')""".format( + stage_name=s3_test_data.stage_name + ) + ) + + await cur.execute( + """ +copy into pytest_t3_copy from @{stage_name}/pytest_t3/data_ +return_failed_only=true +""".format( + stage_name=s3_test_data.stage_name + ) + ) + results = await cur.fetchall() + assert results[0][2] == "LOADED" + assert results[0][4] == 73 + # check to make sure they are equal + await cur.execute( + """ +(select * from pytest_t3 minus select * from pytest_t3_copy) union +(select * from pytest_t3_copy minus select * from pytest_t3)""" + ) + assert cur.rowcount == 0, "unloaded/reloaded data were not the same" + await cur.execute( + "rm @{stage_name}/pytest_t3/data_".format( + stage_name=s3_test_data.stage_name + ) + ) + assert cur.rowcount == 1, "only one file was expected to be removed" + + # clean stage + await cur.execute( + "rm @{stage_name}/pytest_t3/data_".format( + stage_name=s3_test_data.stage_name + ) + ) + + await cur.execute("drop table pytest_t3_copy") + await cur.execute(f"drop stage {s3_test_data.stage_name}") diff --git a/test/integ/aio/test_multi_statement_async.py b/test/integ/aio/test_multi_statement_async.py new file mode 100644 index 0000000000..0968a42564 --- /dev/null +++ b/test/integ/aio/test_multi_statement_async.py @@ -0,0 +1,398 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from test.helpers import ( + _wait_until_query_success_async, + _wait_while_query_running_async, +) + +import pytest + +from snowflake.connector import ProgrammingError, errors +from snowflake.connector.aio import SnowflakeCursor +from snowflake.connector.constants import PARAMETER_MULTI_STATEMENT_COUNT, QueryStatus +from snowflake.connector.util_text import random_string + + +@pytest.fixture(scope="module", params=[False, True]) +def skip_to_last_set(request) -> bool: + return request.param + + +async def test_multi_statement_wrong_count(conn_cnx): + """Tries to send the wrong number of statements.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 1}) as con: + async with con.cursor() as cur: + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 2 did not match the desired statement count 1.", + ): + await cur.execute("select 1; select 2") + + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 2 did not match the desired statement count 1.", + ): + await cur.execute( + "alter session set MULTI_STATEMENT_COUNT=2; select 1;" + ) + + await cur.execute("alter session set MULTI_STATEMENT_COUNT=5") + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 1 did not match the desired statement count 5.", + ): + await cur.execute("select 1;") + + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 3 did not match the desired statement count 5.", + ): + await cur.execute("select 1; select 2; select 3;") + + +async def _check_multi_statement_results( + cur: SnowflakeCursor, + checks: "list[list[tuple] | function]", + skip_to_last_set: bool, +) -> None: + savedIds = [] + for index, check in enumerate(checks): + if not skip_to_last_set or index == len(checks) - 1: + if callable(check): + assert check(await cur.fetchall()) + else: + assert await cur.fetchall() == check + savedIds.append(cur.sfqid) + assert await cur.nextset() == (cur if index < len(checks) - 1 else None) + assert await cur.fetchall() == [] + + assert 
cur.multi_statement_savedIds[-1 if skip_to_last_set else 0 :] == savedIds + + +async def test_multi_statement_basic(conn_cnx, skip_to_last_set: bool): + """Selects fixed integer data using statement level parameters.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + statement_params = dict() + await cur.execute( + "select 1; select 2; select 'a';", + num_statements=3, + _statement_params=statement_params, + ) + await _check_multi_statement_results( + cur, + checks=[ + [(1,)], + [(2,)], + [("a",)], + ], + skip_to_last_set=skip_to_last_set, + ) + assert len(statement_params) == 0 + + +async def test_insert_select_multi(conn_cnx, db_parameters, skip_to_last_set: bool): + """Naive use of multi-statement to check multiple SQL functions.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + table_name = random_string(5, "test_multi_table_").upper() + await cur.execute( + "use schema {db}.{schema};\n" + "create table {name} (aa int);\n" + "insert into {name}(aa) values(123456),(98765),(65432);\n" + "select aa from {name} order by aa;\n" + "drop table {name};".format( + db=db_parameters["database"], + schema=( + db_parameters["schema"] + if "schema" in db_parameters + else "PUBLIC" + ), + name=table_name, + ) + ) + await _check_multi_statement_results( + cur, + checks=[ + [("Statement executed successfully.",)], + [(f"Table {table_name} successfully created.",)], + [(3,)], + [(65432,), (98765,), (123456,)], + [(f"{table_name} successfully dropped.",)], + ], + skip_to_last_set=skip_to_last_set, + ) + + +@pytest.mark.parametrize("style", ["pyformat", "qmark"]) +async def test_binding_multi(conn_cnx, style: str, skip_to_last_set: bool): + """Tests using pyformat and qmark style bindings with multi-statement""" + test_string = "select {s}; select {s}, {s}; select {s}, {s}, {s};" + async with conn_cnx(paramstyle=style) as con: + async with con.cursor() as cur: + sql = test_string.format(s="%s" if style == "pyformat" else "?") + await cur.execute(sql, (10, 20, 30, "a", "b", "c"), num_statements=3) + await _check_multi_statement_results( + cur, + checks=[[(10,)], [(20, 30)], [("a", "b", "c")]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_async_exec_multi(conn_cnx, skip_to_last_set: bool): + """Tests whether async execution query works within a multi-statement""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select 1; select 2; select count(*) from table(generator(timeLimit => 1)); select 'b';", + num_statements=4, + ) + q_id = cur.sfqid + assert con.is_still_running(await con.get_query_status(q_id)) + await _wait_while_query_running_async(con, q_id, sleep_time=1) + async with conn_cnx() as con: + async with con.cursor() as cur: + await _wait_until_query_success_async( + con, q_id, num_checks=3, sleep_per_check=1 + ) + assert ( + await con.get_query_status_throw_if_error(q_id) == QueryStatus.SUCCESS + ) + + await cur.get_results_from_sfqid(q_id) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [(2,)], lambda x: x > [(0,)], [("b",)]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_async_error_multi(conn_cnx): + """ + Runs a query that will fail to execute and then tests that if we tried to get results for the query + then that would raise an exception. It also tests QueryStatus related functionality too. 
+ """ + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + sql = "select 1; select * from nonexistentTable" + q_id = (await cur.execute_async(sql)).get("queryId") + with pytest.raises( + ProgrammingError, + match="SQL compilation error:\nObject 'NONEXISTENTTABLE' does not exist or not authorized.", + ) as sync_error: + await cur.execute(sql) + await _wait_while_query_running_async(con, q_id, sleep_time=1) + assert await con.get_query_status(q_id) == QueryStatus.FAILED_WITH_ERROR + with pytest.raises(ProgrammingError) as e1: + await con.get_query_status_throw_if_error(q_id) + assert sync_error.value.errno != -1 + with pytest.raises(ProgrammingError) as e2: + await cur.get_results_from_sfqid(q_id) + assert e1.value.errno == e2.value.errno == sync_error.value.errno + + +async def test_mix_sync_async_multi(conn_cnx, skip_to_last_set: bool): + """Tests sending multiple multi-statement async queries at the same time.""" + async with conn_cnx( + session_parameters={ + PARAMETER_MULTI_STATEMENT_COUNT: 0, + "CLIENT_TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_TZ", + } + ) as con: + async with con.cursor() as cur: + await cur.execute( + "create or replace temp table smallTable (colA string, colB int);" + "create or replace temp table uselessTable (colA string, colB int);" + ) + for table in ["smallTable", "uselessTable"]: + await cur.execute( + f"insert into {table} values('row1', 1);" + f"insert into {table} values('row2', 2);" + f"insert into {table} values('row3', 3);" + ) + await cur.execute_async("select 1; select 'a'; select * from smallTable;") + sf_qid1 = cur.sfqid + await cur.execute_async("select 2; select 'b'; select * from uselessTable") + sf_qid2 = cur.sfqid + # Wait until the 2 queries finish + await _wait_while_query_running_async(con, sf_qid1, sleep_time=1) + await _wait_while_query_running_async(con, sf_qid2, sleep_time=1) + await cur.execute("drop table uselessTable") + assert await cur.fetchall() == [("USELESSTABLE successfully dropped.",)] + await cur.get_results_from_sfqid(sf_qid1) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [("a",)], [("row1", 1), ("row2", 2), ("row3", 3)]], + skip_to_last_set=skip_to_last_set, + ) + await cur.get_results_from_sfqid(sf_qid2) + await _check_multi_statement_results( + cur, + checks=[[(2,)], [("b",)], [("row1", 1), ("row2", 2), ("row3", 3)]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_done_caching_multi(conn_cnx, skip_to_last_set: bool): + """Tests whether get status caching is working as expected.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + await cur.execute_async( + "select 1; select 'a'; select count(*) from table(generator(timeLimit => 2));" + ) + qid1 = cur.sfqid + await cur.execute_async( + "select 2; select 'b'; select count(*) from table(generator(timeLimit => 2));" + ) + qid2 = cur.sfqid + assert len(con._async_sfqids) == 2 + await _wait_while_query_running_async(con, qid1, sleep_time=1) + await _wait_until_query_success_async( + con, qid1, num_checks=3, sleep_per_check=1 + ) + assert await con.get_query_status(qid1) == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(qid1) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [("a",)], lambda x: x > [(0,)]], + skip_to_last_set=skip_to_last_set, + ) + assert len(con._async_sfqids) == 1 + assert len(con._done_async_sfqids) == 1 + await _wait_while_query_running_async(con, qid2, sleep_time=1) + 
await _wait_until_query_success_async( + con, qid2, num_checks=3, sleep_per_check=1 + ) + assert await con.get_query_status(qid2) == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(qid2) + await _check_multi_statement_results( + cur, + checks=[[(2,)], [("b",)], lambda x: x > [(0,)]], + skip_to_last_set=skip_to_last_set, + ) + assert len(con._async_sfqids) == 0 + assert len(con._done_async_sfqids) == 2 + assert await con._all_async_queries_finished() + + +async def test_alter_session_multi(conn_cnx): + """Tests whether multiple alter session queries are detected and stored in the connection.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + sql = ( + "select 1;" + "alter session set autocommit=false;" + "select 'a';" + "alter session set json_indent = 4;" + "alter session set CLIENT_TIMESTAMP_TYPE_MAPPING = 'TIMESTAMP_TZ'" + ) + await cur.execute(sql) + assert con.converter.get_parameter("AUTOCOMMIT") == "false" + assert con.converter.get_parameter("JSON_INDENT") == "4" + assert ( + con.converter.get_parameter("CLIENT_TIMESTAMP_TYPE_MAPPING") + == "TIMESTAMP_TZ" + ) + + +async def test_executemany_multi(conn_cnx, skip_to_last_set: bool): + """Tests executemany with multi-statement optimizations enabled through the num_statements parameter.""" + table1 = random_string(5, "test_executemany_multi_") + table2 = random_string(5, "test_executemany_multi_") + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table {table1} (aa number); create temp table {table2} (bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(%(value1)s); insert into {table2}(bb) values(%(value2)s);", + [ + {"value1": 1234, "value2": 4}, + {"value1": 234, "value2": 34}, + {"value1": 34, "value2": 234}, + {"value1": 4, "value2": 1234}, + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[[(1234,), (234,), (34,), (4,)], [(4,), (34,), (234,), (1234,)]], + skip_to_last_set=skip_to_last_set, + ) + + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table {table1} (aa number); create temp table {table2} (bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(%s); insert into {table2}(bb) values(%s);", + [ + (12345, 4), + (1234, 34), + (234, 234), + (34, 1234), + (4, 12345), + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[ + [(12345,), (1234,), (234,), (34,), (4,)], + [(4,), (34,), (234,), (1234,), (12345,)], + ], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_executmany_qmark_multi(conn_cnx, skip_to_last_set: bool): + """Tests executemany with multi-statement optimization with qmark style.""" + table1 = random_string(5, "test_executemany_qmark_multi_") + table2 = random_string(5, "test_executemany_qmark_multi_") + async with conn_cnx(paramstyle="qmark") as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table {table1}(aa number); create temp 
table {table2}(bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(?); insert into {table2}(bb) values(?);", + [ + [1234, 4], + [234, 34], + [34, 234], + [4, 1234], + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[ + [(1234,), (234,), (34,), (4,)], + [(4,), (34,), (234,), (1234,)], + ], + skip_to_last_set=skip_to_last_set, + ) diff --git a/test/integ/aio/test_network_async.py b/test/integ/aio/test_network_async.py new file mode 100644 index 0000000000..0bf153abb7 --- /dev/null +++ b/test/integ/aio/test_network_async.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +import unittest.mock +from logging import getLogger + +import pytest + +import snowflake.connector.aio +from snowflake.connector import errorcode, errors +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.network import ( + QUERY_IN_PROGRESS_ASYNC_CODE, + QUERY_IN_PROGRESS_CODE, +) + +logger = getLogger(__name__) + + +async def test_no_auth(db_parameters): + """SNOW-13588: No auth Rest API test.""" + rest = SnowflakeRestful(host=db_parameters["host"], port=db_parameters["port"]) + try: + # no auth + # show warehouse + await rest.request( + url="/queries", + body={ + "sequenceId": 10000, + "sqlText": "show warehouses", + "parameters": { + "ui_mode": True, + }, + }, + method="post", + client="rest", + ) + raise Exception("Must fail with auth error") + except errors.Error as e: + assert e.errno == errorcode.ER_CONNECTION_IS_CLOSED + finally: + await rest.close() + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "query_return_code", [QUERY_IN_PROGRESS_CODE, QUERY_IN_PROGRESS_ASYNC_CODE] +) +async def test_none_object_when_querying_result( + db_parameters, caplog, query_return_code +): + # this test simulate the case where the response from the server is None + # the following events happen in sequence: + # 1. we send a simple query to the server which is a post request + # 2. we record the query result in a global variable + # 3. we mock return a query in progress code and an url to fetch the query result + # 4. we return None for the fetching query result request for the first time + # 5. for the second time, we return the code for the query result + # 6. in the end, we assert the result, and retry has taken place when result is None by checking logging + + original_request_exec = SnowflakeRestful._request_exec + expected_ret = None + get_executed_time = 0 + + async def side_effect_request_exec(self, *args, **kwargs): + nonlocal expected_ret, get_executed_time + # 1. we send a simple query to the server which is a post request + if "queries/v1/query-request" in kwargs["full_url"]: + ret = await original_request_exec(self, *args, **kwargs) + expected_ret = ret # 2. we record the query result in a global variable + # 3. we mock return a query in progress code and an url to fetch the query result + return { + "code": query_return_code, + "data": {"getResultUrl": "/queries/123/result"}, + } + + if "/queries/123/result" in kwargs["full_url"]: + if get_executed_time == 0: + # 4. 
we return None for the 1st time fetching query result request, this should trigger retry + get_executed_time += 1 + return None + else: + # 5. for the second time, we return the code for the query result, this indicates retry success + return expected_ret + + with caplog.at_level(logging.INFO): + async with snowflake.connector.aio.SnowflakeConnection( + **db_parameters + ) as conn, conn.cursor() as cursor: + with unittest.mock.patch.object( + SnowflakeRestful, "_request_exec", new=side_effect_request_exec + ): + # 6. in the end, we assert the result, and retry has taken place when result is None by checking logging + assert await (await cursor.execute("select 1")).fetchone() == (1,) + assert ( + "fetch query status failed and http request returned None, this is usually caused by transient network failures, retrying" + in caplog.text + ) diff --git a/test/integ/aio/test_numpy_binding_async.py b/test/integ/aio/test_numpy_binding_async.py new file mode 100644 index 0000000000..429c7af9d7 --- /dev/null +++ b/test/integ/aio/test_numpy_binding_async.py @@ -0,0 +1,193 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import datetime +import time + +import numpy as np + + +async def test_numpy_datatype_binding(conn_cnx, db_parameters): + """Tests numpy data type bindings.""" + epoch_time = time.time() + current_datetime = datetime.datetime.fromtimestamp(epoch_time) + current_datetime64 = np.datetime64(current_datetime) + all_data = [ + { + "tz": "America/Los_Angeles", + "float": "1.79769313486e+308", + "numpy_bool": np.True_, + "epoch_time": epoch_time, + "current_time": current_datetime64, + "specific_date": np.datetime64("2005-02-25T03:30"), + "expected_specific_date": np.datetime64("2005-02-25T03:30").astype( + datetime.datetime + ), + }, + { + "tz": "Asia/Tokyo", + "float": "-1.79769313486e+308", + "numpy_bool": np.False_, + "epoch_time": epoch_time, + "current_time": current_datetime64, + "specific_date": np.datetime64("1970-12-31T05:00:00"), + "expected_specific_date": np.datetime64("1970-12-31T05:00:00").astype( + datetime.datetime + ), + }, + { + "tz": "America/New_York", + "float": "-1.79769313486e+308", + "numpy_bool": np.True_, + "epoch_time": epoch_time, + "current_time": current_datetime64, + "specific_date": np.datetime64("1969-12-31T05:00:00"), + "expected_specific_date": np.datetime64("1969-12-31T05:00:00").astype( + datetime.datetime + ), + }, + { + "tz": "UTC", + "float": "-1.79769313486e+308", + "numpy_bool": np.False_, + "epoch_time": epoch_time, + "current_time": current_datetime64, + "specific_date": np.datetime64("1968-11-12T07:00:00.123"), + "expected_specific_date": np.datetime64("1968-11-12T07:00:00.123").astype( + datetime.datetime + ), + }, + ] + try: + async with conn_cnx(numpy=True) as cnx: + await cnx.cursor().execute( + """ +CREATE OR REPLACE TABLE {name} ( + c1 integer, -- int8 + c2 integer, -- int16 + c3 integer, -- int32 + c4 integer, -- int64 + c5 float, -- float16 + c6 float, -- float32 + c7 float, -- float64 + c8 timestamp_ntz, -- datetime64 + c9 date, -- datetime64 + c10 timestamp_ltz, -- datetime64, + c11 timestamp_tz, -- datetime64 + c12 boolean) -- numpy.bool_ + """.format( + name=db_parameters["name"] + ) + ) + for data in all_data: + await cnx.cursor().execute( + """ +ALTER SESSION SET timezone='{tz}'""".format( + tz=data["tz"] + ) + ) + await cnx.cursor().execute( + """ +INSERT INTO {name}( + c1, + c2, + c3, + c4, + c5, + c6, + c7, + c8, + c9, + c10, + c11, + c12 +) +VALUES( + %s, 
+ %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s, + %s)""".format( + name=db_parameters["name"] + ), + ( + np.iinfo(np.int8).max, + np.iinfo(np.int16).max, + np.iinfo(np.int32).max, + np.iinfo(np.int64).max, + np.finfo(np.float16).max, + np.finfo(np.float32).max, + np.float64(data["float"]), + data["current_time"], + data["current_time"], + data["current_time"], + data["specific_date"], + data["numpy_bool"], + ), + ) + rec = await ( + await cnx.cursor().execute( + """ +SELECT + c1, + c2, + c3, + c4, + c5, + c6, + c7, + c8, + c9, + c10, + c11, + c12 + FROM {name}""".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + assert np.int8(rec[0]) == np.iinfo(np.int8).max + assert np.int16(rec[1]) == np.iinfo(np.int16).max + assert np.int32(rec[2]) == np.iinfo(np.int32).max + assert np.int64(rec[3]) == np.iinfo(np.int64).max + assert np.float16(rec[4]) == np.finfo(np.float16).max + assert np.float32(rec[5]) == np.finfo(np.float32).max + assert rec[6] == np.float64(data["float"]) + assert rec[7] == data["current_time"] + assert str(rec[8]) == str(data["current_time"])[0:10] + assert rec[9] == datetime.datetime.fromtimestamp( + epoch_time, rec[9].tzinfo + ) + assert rec[10] == data["expected_specific_date"].replace( + tzinfo=rec[10].tzinfo + ) + assert ( + isinstance(rec[11], bool) + and rec[11] == data["numpy_bool"] + and np.bool_(rec[11]) == data["numpy_bool"] + ) + await cnx.cursor().execute( + """ +DELETE FROM {name}""".format( + name=db_parameters["name"] + ) + ) + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ + DROP TABLE IF EXISTS {name} + """.format( + name=db_parameters["name"] + ) + ) diff --git a/test/integ/aio/test_pickle_timestamp_tz_async.py b/test/integ/aio/test_pickle_timestamp_tz_async.py new file mode 100644 index 0000000000..4317a180ae --- /dev/null +++ b/test/integ/aio/test_pickle_timestamp_tz_async.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os +import pickle + + +async def test_pickle_timestamp_tz(tmpdir, conn_cnx): + """Ensures the timestamp_tz result is pickle-able.""" + tmp_dir = str(tmpdir.mkdir("pickles")) + output = os.path.join(tmp_dir, "tz.pickle") + expected_tz = None + async with conn_cnx() as con: + async for rec in await con.cursor().execute( + "select '2019-08-11 01:02:03.123 -03:00'::TIMESTAMP_TZ" + ): + expected_tz = rec[0] + with open(output, "wb") as f: + pickle.dump(expected_tz, f) + + with open(output, "rb") as f: + read_tz = pickle.load(f) + assert expected_tz == read_tz diff --git a/test/integ/aio/test_put_get_async.py b/test/integ/aio/test_put_get_async.py new file mode 100644 index 0000000000..bf7a7fff9b --- /dev/null +++ b/test/integ/aio/test_put_get_async.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import filecmp +import logging +import os +from io import BytesIO +from logging import getLogger +from os import path +from unittest import mock + +import pytest + +from snowflake.connector import OperationalError + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from test.generate_test_files import generate_k_lines_of_n_files + +try: + from ..parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + +THIS_DIR = path.dirname(path.realpath(__file__)) + +logger = getLogger(__name__) + +pytestmark = pytest.mark.asyncio +CLOUD = os.getenv("cloud_provider", "dev") + + +async def test_utf8_filename(tmp_path, aio_connection): + test_file = tmp_path / "utf卡豆.csv" + test_file.write_text("1,2,3\n") + stage_name = random_string(5, "test_utf8_filename_") + await aio_connection.connect() + cursor = aio_connection.cursor() + await cursor.execute(f"create temporary stage {stage_name}") + await ( + await cursor.execute( + "PUT 'file://{}' @{}".format(str(test_file).replace("\\", "/"), stage_name) + ) + ).fetchall() + await cursor.execute(f"select $1, $2, $3 from @{stage_name}") + assert await cursor.fetchone() == ("1", "2", "3") + + +async def test_put_threshold(tmp_path, aio_connection, is_public_test): + if is_public_test: + pytest.xfail( + reason="This feature hasn't been rolled out for public Snowflake deployments yet." + ) + file_name = "test_put_get_with_aws_token.txt.gz" + stage_name = random_string(5, "test_put_get_threshold_") + file = tmp_path / file_name + file.touch() + await aio_connection.connect() + cursor = aio_connection.cursor() + await cursor.execute(f"create temporary stage {stage_name}") + from snowflake.connector.file_transfer_agent import SnowflakeFileTransferAgent + + with mock.patch( + "snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent", + autospec=SnowflakeFileTransferAgent, + ) as mock_agent: + await cursor.execute(f"put file://{file} @{stage_name} threshold=156") + assert mock_agent.call_args[1].get("multipart_threshold", -1) == 156 + + +# Snowflake on GCP does not support multipart uploads +@pytest.mark.xfail(reason="multipart transfer is not merged yet") +# @pytest.mark.aws +# @pytest.mark.azure +@pytest.mark.parametrize("use_stream", [False, True]) +async def test_multipart_put(aio_connection, tmp_path, use_stream): + """This test does a multipart upload of a smaller file and then downloads it.""" + stage_name = random_string(5, "test_multipart_put_") + chunk_size = 6967790 + # Generate about 12 MB + generate_k_lines_of_n_files(100_000, 1, tmp_dir=str(tmp_path)) + get_dir = tmp_path / "get_dir" + get_dir.mkdir() + upload_file = tmp_path / "file0" + await aio_connection.connect() + cursor = aio_connection.cursor() + await cursor.execute(f"create temporary stage {stage_name}") + real_cmd_query = aio_connection.cmd_query + + async def fake_cmd_query(*a, **kw): + """Create a mock function to inject some value into the returned JSON""" + ret = await real_cmd_query(*a, **kw) + ret["data"]["threshold"] = chunk_size + return ret + + with mock.patch.object(aio_connection, "cmd_query", side_effect=fake_cmd_query): + with mock.patch("snowflake.connector.constants.S3_CHUNK_SIZE", chunk_size): + if use_stream: + kw = { + "command": f"put file://file0 @{stage_name} AUTO_COMPRESS=FALSE", + "file_stream": BytesIO(upload_file.read_bytes()), + } + else: + kw = { + "command": f"put file://{upload_file} 
@{stage_name} AUTO_COMPRESS=FALSE", + } + await cursor.execute(**kw) + res = await cursor.execute(f"list @{stage_name}") + print(await res.fetchall()) + await cursor.execute(f"get @{stage_name}/{upload_file.name} file://{get_dir}") + downloaded_file = get_dir / upload_file.name + assert downloaded_file.exists() + assert filecmp.cmp(upload_file, downloaded_file) + + +async def test_put_special_file_name(tmp_path, aio_connection): + test_file = tmp_path / "data~%23.csv" + test_file.write_text("1,2,3\n") + stage_name = random_string(5, "test_special_filename_") + await aio_connection.connect() + cursor = aio_connection.cursor() + await cursor.execute(f"create temporary stage {stage_name}") + filename_in_put = str(test_file).replace("\\", "/") + await ( + await cursor.execute( + f"PUT 'file://{filename_in_put}' @{stage_name}", + ) + ).fetchall() + await cursor.execute(f"select $1, $2, $3 from @{stage_name}") + assert await cursor.fetchone() == ("1", "2", "3") + + +async def test_get_empty_file(tmp_path, aio_connection): + test_file = tmp_path / "data.csv" + test_file.write_text("1,2,3\n") + stage_name = random_string(5, "test_get_empty_file_") + await aio_connection.connect() + cur = aio_connection.cursor() + await cur.execute(f"create temporary stage {stage_name}") + filename_in_put = str(test_file).replace("\\", "/") + await cur.execute( + f"PUT 'file://{filename_in_put}' @{stage_name}", + ) + empty_file = tmp_path / "foo.csv" + with pytest.raises(OperationalError, match=".*the file does not exist.*$"): + await cur.execute(f"GET @{stage_name}/foo.csv file://{tmp_path}") + assert not empty_file.exists() + + +async def test_get_file_permission(tmp_path, aio_connection, caplog): + test_file = tmp_path / "data.csv" + test_file.write_text("1,2,3\n") + stage_name = random_string(5, "test_get_empty_file_") + await aio_connection.connect() + cur = aio_connection.cursor() + await cur.execute(f"create temporary stage {stage_name}") + filename_in_put = str(test_file).replace("\\", "/") + await cur.execute( + f"PUT 'file://{filename_in_put}' @{stage_name}", + ) + + with caplog.at_level(logging.ERROR): + await cur.execute(f"GET @{stage_name}/data.csv file://{tmp_path}") + assert "FileNotFoundError" not in caplog.text + + # get the default mask, usually it is 0o022 + default_mask = os.umask(0) + os.umask(default_mask) + # files by default are given the permission 644 (Octal) + # umask is for denial, we need to negate + assert oct(os.stat(test_file).st_mode)[-3:] == oct(0o666 & ~default_mask)[-3:] + + +async def test_get_multiple_files_with_same_name(tmp_path, aio_connection, caplog): + test_file = tmp_path / "data.csv" + test_file.write_text("1,2,3\n") + stage_name = random_string(5, "test_get_multiple_files_with_same_name_") + await aio_connection.connect() + cur = aio_connection.cursor() + await cur.execute(f"create temporary stage {stage_name}") + filename_in_put = str(test_file).replace("\\", "/") + await cur.execute( + f"PUT 'file://{filename_in_put}' @{stage_name}/data/1/", + ) + await cur.execute( + f"PUT 'file://{filename_in_put}' @{stage_name}/data/2/", + ) + + with caplog.at_level(logging.WARNING): + try: + await cur.execute( + f"GET @{stage_name} file://{tmp_path} PATTERN='.*data.csv.gz'" + ) + except OperationalError: + # This is expected flakiness + pass + assert "Downloading multiple files with the same name" in caplog.text + + +async def test_transfer_error_message(tmp_path, aio_connection): + test_file = tmp_path / "data.csv" + test_file.write_text("1,2,3\n") + stage_name = 
random_string(5, "test_utf8_filename_") + await aio_connection.connect() + cursor = aio_connection.cursor() + await cursor.execute(f"create temporary stage {stage_name}") + with mock.patch( + "snowflake.connector.aio._storage_client.SnowflakeStorageClient.finish_upload", + side_effect=ConnectionError, + ): + with pytest.raises(OperationalError): + ( + await cursor.execute( + "PUT 'file://{}' @{}".format( + str(test_file).replace("\\", "/"), stage_name + ) + ) + ).fetchall() diff --git a/test/integ/aio/test_put_get_compress_enc_async.py b/test/integ/aio/test_put_get_compress_enc_async.py new file mode 100644 index 0000000000..8035f5b05f --- /dev/null +++ b/test/integ/aio/test_put_get_compress_enc_async.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import filecmp +import pathlib +from test.integ_helpers import put_async +from unittest.mock import patch + +import pytest + +from snowflake.connector.util_text import random_string + +pytestmark = pytest.mark.skipolddriver # old test driver tests won't run this module + +from snowflake.connector.aio._s3_storage_client import SnowflakeS3RestClient + +orig_send_req = SnowflakeS3RestClient._send_request_with_authentication_and_retry + + +def _prepare_tmp_file(to_dir: pathlib.Path) -> tuple[pathlib.Path, str]: + tmp_dir = to_dir / "data" + tmp_dir.mkdir() + file_name = "data.txt" + test_path = tmp_dir / file_name + with test_path.open("w") as f: + f.write("test1,test2\n") + f.write("test3,test4") + return test_path, file_name + + +async def mock_send_request( + self, + url, + verb, + retry_id, + query_parts=None, + x_amz_headers=None, + headers=None, + payload=None, + unsigned_payload=False, + ignore_content_encoding=False, +): + # when called under _initiate_multipart_upload and _upload_chunk, add content-encoding to header + if verb is not None and verb in ("POST", "PUT") and headers is not None: + headers["Content-Encoding"] = "gzip" + return await orig_send_req( + self, + url, + verb, + retry_id, + query_parts, + x_amz_headers, + headers, + payload, + unsigned_payload, + ignore_content_encoding, + ) + + +@pytest.mark.parametrize("auto_compress", [True, False]) +async def test_auto_compress_switch( + tmp_path: pathlib.Path, + conn_cnx, + auto_compress, +): + """Tests PUT command with auto_compress=False|True.""" + _test_name = random_string(5, "test_auto_compress_switch") + test_data, file_name = _prepare_tmp_file(tmp_path) + + async with conn_cnx() as cnx: + await cnx.cursor().execute(f"RM @~/{_test_name}") + try: + file_stream = test_data.open("rb") + async with cnx.cursor() as cur: + await put_async( + cur, + str(test_data), + f"~/{_test_name}", + False, + sql_options=f"auto_compress={auto_compress}", + file_stream=file_stream, + ) + + ret = await (await cnx.cursor().execute(f"LS @~/{_test_name}")).fetchone() + uploaded_gz_name = f"{file_name}.gz" + if auto_compress: + assert uploaded_gz_name in ret[0] + else: + assert uploaded_gz_name not in ret[0] + + # get this file, if the client handle compression meta correctly + get_dir = tmp_path / "get_dir" + get_dir.mkdir() + await cnx.cursor().execute( + f"GET @~/{_test_name}/{file_name} file://{get_dir}" + ) + + downloaded_file = get_dir / ( + uploaded_gz_name if auto_compress else file_name + ) + assert downloaded_file.exists() + if not auto_compress: + assert filecmp.cmp(test_data, downloaded_file) + + finally: + await cnx.cursor().execute(f"RM @~/{_test_name}") + if file_stream: + 
file_stream.close() + + +@pytest.mark.aws +async def test_get_gzip_content_encoding( + tmp_path: pathlib.Path, + conn_cnx, +): + """Tests GET command for a content-encoding=GZIP in stage""" + _test_name = random_string(5, "test_get_gzip_content_encoding") + test_data, file_name = _prepare_tmp_file(tmp_path) + + with patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._send_request_with_authentication_and_retry", + mock_send_request, + ): + async with conn_cnx() as cnx: + await cnx.cursor().execute(f"RM @~/{_test_name}") + try: + file_stream = test_data.open("rb") + async with cnx.cursor() as cur: + await put_async( + cur, + str(test_data), + f"~/{_test_name}", + False, + sql_options="auto_compress=True", + file_stream=file_stream, + ) + + ret = await ( + await cnx.cursor().execute(f"LS @~/{_test_name}") + ).fetchone() + assert f"{file_name}.gz" in ret[0] + + # get this file, if the client handle compression meta correctly + get_dir = tmp_path / "get_dir" + get_dir.mkdir() + ret = await ( + await cnx.cursor().execute( + f"GET @~/{_test_name}/{file_name} file://{get_dir}" + ) + ).fetchone() + downloaded_file = get_dir / ret[0] + assert downloaded_file.exists() + + finally: + await cnx.cursor().execute(f"RM @~/{_test_name}") + if file_stream: + file_stream.close() + + +@pytest.mark.aws +async def test_sse_get_gzip_content_encoding( + tmp_path: pathlib.Path, + conn_cnx, +): + """Tests GET command for a content-encoding=GZIP in stage and it is SSE(server side encrypted)""" + _test_name = random_string(5, "test_sse_get_gzip_content_encoding") + test_data, orig_file_name = _prepare_tmp_file(tmp_path) + stage_name = random_string(5, "sse_stage") + with patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._send_request_with_authentication_and_retry", + mock_send_request, + ): + async with conn_cnx() as cnx: + await cnx.cursor().execute( + f"create or replace stage {stage_name} ENCRYPTION=(TYPE='SNOWFLAKE_SSE')" + ) + await cnx.cursor().execute(f"RM @{stage_name}/{_test_name}") + try: + file_stream = test_data.open("rb") + async with cnx.cursor() as cur: + await put_async( + cur, + str(test_data), + f"{stage_name}/{_test_name}", + False, + sql_options="auto_compress=True", + file_stream=file_stream, + ) + + ret = await ( + await cnx.cursor().execute(f"LS @{stage_name}/{_test_name}") + ).fetchone() + assert f"{orig_file_name}.gz" in ret[0] + + # get this file, if the client handle compression meta correctly + get_dir = tmp_path / "get_dir" + get_dir.mkdir() + ret = await ( + await cnx.cursor().execute( + f"GET @{stage_name}/{_test_name}/{orig_file_name} file://{get_dir}" + ) + ).fetchone() + # TODO: The downloaded file should always be the unzip (original) file + downloaded_file = get_dir / ret[0] + assert downloaded_file.exists() + + finally: + await cnx.cursor().execute(f"RM @{stage_name}/{_test_name}") + if file_stream: + file_stream.close() diff --git a/test/integ/aio/test_put_get_medium_async.py b/test/integ/aio/test_put_get_medium_async.py new file mode 100644 index 0000000000..aeb9fcd2a3 --- /dev/null +++ b/test/integ/aio/test_put_get_medium_async.py @@ -0,0 +1,849 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import asyncio +import datetime +import gzip +import os +import sys +from logging import getLogger +from typing import IO, TYPE_CHECKING + +import pytest +import pytz + +from snowflake.connector import ProgrammingError +from snowflake.connector.aio._cursor import DictCursor +from snowflake.connector.file_transfer_agent import ( + SnowflakeAzureProgressPercentage, + SnowflakeProgressPercentage, + SnowflakeS3ProgressPercentage, +) + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from test.generate_test_files import generate_k_lines_of_n_files +from test.integ_helpers import put_async + +if TYPE_CHECKING: + from snowflake.connector.aio import SnowflakeConnection + from snowflake.connector.aio._cursor import SnowflakeCursor + +try: + from ..parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + +THIS_DIR = os.path.dirname(os.path.realpath(__file__)) +logger = getLogger(__name__) + +pytestmark = pytest.mark.asyncio +CLOUD = os.getenv("cloud_provider", "dev") + + +@pytest.fixture() +def file_src(request) -> tuple[str, int, IO[bytes]]: + file_name = request.param + data_file = os.path.join(THIS_DIR, "../../data", file_name) + file_size = os.stat(data_file).st_size + stream = open(data_file, "rb") + yield data_file, file_size, stream + stream.close() + + +async def run(cnx, db_parameters, sql): + sql = sql.format(name=db_parameters["name"]) + res = await cnx.cursor().execute(sql) + return await res.fetchall() + + +async def run_file_operation(cnx, db_parameters, files, sql): + sql = sql.format(files=files.replace("\\", "\\\\"), name=db_parameters["name"]) + res = await cnx.cursor().execute(sql) + return await res.fetchall() + + +async def run_dict_result(cnx, db_parameters, sql): + sql = sql.format(name=db_parameters["name"]) + res = await cnx.cursor(DictCursor).execute(sql) + return await res.fetchall() + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["put_get_1.txt"], indirect=["file_src"]) +async def test_put_copy0(aio_connection, db_parameters, from_path, file_src): + """Puts and Copies a file.""" + file_path, _, file_stream = file_src + kwargs = { + "_put_callback": SnowflakeS3ProgressPercentage, + "_get_callback": SnowflakeS3ProgressPercentage, + "_put_azure_callback": SnowflakeAzureProgressPercentage, + "_get_azure_callback": SnowflakeAzureProgressPercentage, + "file_stream": file_stream, + } + + async def run_with_cursor( + cnx: SnowflakeConnection, sql: str + ) -> tuple[SnowflakeCursor, list[tuple] | list[dict]]: + sql = sql.format(name=db_parameters["name"]) + cur = cnx.cursor(DictCursor) + res = await cur.execute(sql) + return cur, await res.fetchall() + + await aio_connection.connect() + cursor = aio_connection.cursor(DictCursor) + await run( + aio_connection, + db_parameters, + """ +create or replace table {name} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(5,2)) +""", + ) + + ret = await put_async( + cursor, file_path, f"%{db_parameters['name']}", from_path, **kwargs + ) + ret = await ret.fetchall() + assert cursor.is_file_transfer, "PUT" + assert len(ret) == 1, "Upload one file" + assert ret[0]["source"] == os.path.basename(file_path), "File name" + + c, ret = await run_with_cursor(aio_connection, "copy into {name}") + assert not 
c.is_file_transfer, "COPY" + assert len(ret) == 1 and ret[0]["status"] == "LOADED", "Failed to load data" + + assert ret[0]["rows_loaded"] == 3, "Failed to load 3 rows of data" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["gzip_sample.txt.gz"], indirect=["file_src"]) +async def test_put_copy_compressed(aio_connection, db_parameters, from_path, file_src): + """Puts and Copies compressed files.""" + file_name, file_size, file_stream = file_src + await aio_connection.connect() + + await run_dict_result( + aio_connection, db_parameters, "create or replace table {name} (value string)" + ) + csr = aio_connection.cursor(DictCursor) + ret = await put_async( + csr, + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + ret = await ret.fetchall() + assert ret[0]["source"] == os.path.basename(file_name), "File name" + assert ret[0]["source_size"] == file_size, "File size" + assert ret[0]["status"] == "UPLOADED" + + ret = await run_dict_result(aio_connection, db_parameters, "copy into {name}") + assert len(ret) == 1 and ret[0]["status"] == "LOADED", "Failed to load data" + assert ret[0]["rows_loaded"] == 1, "Failed to load 1 rows of data" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["bzip2_sample.txt.bz2"], indirect=["file_src"]) +@pytest.mark.skip(reason="BZ2 is not detected in this test case. Need investigation") +async def test_put_copy_bz2_compressed( + aio_connection, db_parameters, from_path, file_src +): + """Put and Copy bz2 compressed files.""" + file_name, _, file_stream = file_src + await aio_connection.connect() + + await run( + aio_connection, db_parameters, "create or replace table {name} (value string)" + ) + res = await put_async( + aio_connection.cursor(), + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + for rec in await res.fetchall(): + print(rec) + assert rec[-2] == "UPLOADED" + + for rec in await run(aio_connection, db_parameters, "copy into {name}"): + print(rec) + assert rec[1] == "LOADED" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["brotli_sample.txt.br"], indirect=["file_src"]) +async def test_put_copy_brotli_compressed( + aio_connection, db_parameters, from_path, file_src +): + """Puts and Copies brotli compressed files.""" + file_name, _, file_stream = file_src + await aio_connection.connect() + + await run( + aio_connection, db_parameters, "create or replace table {name} (value string)" + ) + res = await put_async( + aio_connection.cursor(), + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + for rec in await res.fetchall(): + print(rec) + assert rec[-2] == "UPLOADED" + + for rec in await run( + aio_connection, + db_parameters, + "copy into {name} file_format=(compression='BROTLI')", + ): + print(rec) + assert rec[1] == "LOADED" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) 
+@pytest.mark.parametrize("file_src", ["zstd_sample.txt.zst"], indirect=["file_src"]) +async def test_put_copy_zstd_compressed( + aio_connection, db_parameters, from_path, file_src +): + """Puts and Copies zstd compressed files.""" + file_name, _, file_stream = file_src + await aio_connection.connect() + + await run( + aio_connection, db_parameters, "create or replace table {name} (value string)" + ) + res = await put_async( + aio_connection.cursor(), + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + for rec in await res.fetchall(): + print(rec) + assert rec[-2] == "UPLOADED" + for rec in await run( + aio_connection, + db_parameters, + "copy into {name} file_format=(compression='ZSTD')", + ): + print(rec) + assert rec[1] == "LOADED" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["nation.impala.parquet"], indirect=["file_src"]) +async def test_put_copy_parquet_compressed( + aio_connection, db_parameters, from_path, file_src +): + """Puts and Copies parquet compressed files.""" + file_name, _, file_stream = file_src + await aio_connection.connect() + + await run( + aio_connection, + db_parameters, + """ +create or replace table {name} +(value variant) +stage_file_format=(type='parquet') +""", + ) + for rec in await ( + await put_async( + aio_connection.cursor(), + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + ).fetchall(): + print(rec) + assert rec[-2] == "UPLOADED" + assert rec[4] == "PARQUET" + assert rec[5] == "PARQUET" + + for rec in await run(aio_connection, db_parameters, "copy into {name}"): + print(rec) + assert rec[1] == "LOADED" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["TestOrcFile.test1.orc"], indirect=["file_src"]) +async def test_put_copy_orc_compressed( + aio_connection, db_parameters, from_path, file_src +): + """Puts and Copies ORC compressed files.""" + file_name, _, file_stream = file_src + await aio_connection.connect() + await run( + aio_connection, + db_parameters, + """ +create or replace table {name} (value variant) stage_file_format=(type='orc') +""", + ) + for rec in await ( + await put_async( + aio_connection.cursor(), + file_name, + f"%{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + ).fetchall(): + print(rec) + assert rec[-2] == "UPLOADED" + assert rec[4] == "ORC" + assert rec[5] == "ORC" + for rec in await run(aio_connection, db_parameters, "copy into {name}"): + print(rec) + assert rec[1] == "LOADED" + + await run(aio_connection, db_parameters, "drop table if exists {name}") + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." 
+) +async def test_copy_get(tmpdir, aio_connection, db_parameters): + """Copies and Gets a file.""" + name_unload = db_parameters["name"] + "_unload" + tmp_dir = str(tmpdir.mkdir("copy_get_stage")) + tmp_dir_user = str(tmpdir.mkdir("user_get")) + await aio_connection.connect() + + async def run_test(cnx, sql): + sql = sql.format( + name_unload=name_unload, + tmpdir=tmp_dir, + tmp_dir_user=tmp_dir_user, + name=db_parameters["name"], + ) + res = await cnx.cursor().execute(sql) + return await res.fetchall() + + await run_test( + aio_connection, "alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false" + ) + await run_test( + aio_connection, + """ +create or replace table {name} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(5,2)) +""", + ) + await run_test( + aio_connection, + """ +create or replace stage {name_unload} +file_format = ( +format_name = 'common.public.csv' +field_delimiter = '|' +error_on_column_count_mismatch=false); +""", + ) + current_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + current_time = current_time.replace(tzinfo=pytz.timezone("America/Los_Angeles")) + current_date = datetime.date.today() + other_time = current_time.replace(tzinfo=pytz.timezone("Asia/Tokyo")) + + fmt = """ +insert into {name}(aa, dt, tstz) +values(%(value)s,%(dt)s,%(tstz)s) +""".format( + name=db_parameters["name"] + ) + aio_connection.cursor().executemany( + fmt, + [ + {"value": 6543, "dt": current_date, "tstz": other_time}, + {"value": 1234, "dt": current_date, "tstz": other_time}, + ], + ) + + await run_test( + aio_connection, + """ +copy into @{name_unload}/data_ +from {name} +file_format=( +format_name='common.public.csv' +compression='gzip') +max_file_size=10000000 +""", + ) + ret = await run_test(aio_connection, "get @{name_unload}/ file://{tmp_dir_user}/") + + assert ret[0][2] == "DOWNLOADED", "Failed to download" + cnt = 0 + for _, _, _ in os.walk(tmp_dir_user): + cnt += 1 + assert cnt > 0, "No file was downloaded" + + await run_test(aio_connection, "drop stage {name_unload}") + await run_test(aio_connection, "drop table if exists {name}") + + +@pytest.mark.flaky(reruns=3) +async def test_put_copy_many_files(tmpdir, aio_connection, db_parameters): + """Puts and Copies many_files.""" + # generates N files + number_of_files = 100 + number_of_lines = 1000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + await aio_connection.connect() + + await run_file_operation( + aio_connection, + db_parameters, + files, + """ +create or replace table {name} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""", + ) + await run_file_operation( + aio_connection, db_parameters, files, "put 'file://{files}' @%{name}" + ) + await run_file_operation(aio_connection, db_parameters, files, "copy into {name}") + rows = 0 + for rec in await run_file_operation( + aio_connection, db_parameters, files, "select count(*) from {name}" + ): + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + + await run_file_operation( + aio_connection, db_parameters, files, "drop table if exists {name}" + ) + + +@pytest.mark.aws +async def test_put_copy_many_files_s3(tmpdir, aio_connection, db_parameters): + """[s3] Puts and Copies many files.""" + # generates N files + number_of_files = 10 + 
number_of_lines = 1000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + await aio_connection.connect() + + await run_file_operation( + aio_connection, + db_parameters, + files, + """ +create or replace table {name} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""", + ) + try: + await run_file_operation( + aio_connection, db_parameters, files, "put 'file://{files}' @%{name}" + ) + await run_file_operation( + aio_connection, db_parameters, files, "copy into {name}" + ) + + rows = 0 + for rec in await run_file_operation( + aio_connection, db_parameters, files, "select count(*) from {name}" + ): + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run_file_operation( + aio_connection, db_parameters, files, "drop table if exists {name}" + ) + + +@pytest.mark.aws +@pytest.mark.azure +@pytest.mark.flaky(reruns=3) +async def test_put_copy_duplicated_files_s3(tmpdir, aio_connection, db_parameters): + """[s3] Puts and Copies duplicated files.""" + # generates N files + number_of_files = 5 + number_of_lines = 100 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + await aio_connection.connect() + + await run_file_operation( + aio_connection, + db_parameters, + files, + """ +create or replace table {name} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""", + ) + + try: + success_cnt = 0 + skipped_cnt = 0 + for rec in await run_file_operation( + aio_connection, db_parameters, files, "put 'file://{files}' @%{name}" + ): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files, "uploaded files" + assert skipped_cnt == 0, "skipped files" + + deleted_cnt = 0 + await run_file_operation( + aio_connection, db_parameters, files, "rm @%{name}/file0" + ) + deleted_cnt += 1 + await run_file_operation( + aio_connection, db_parameters, files, "rm @%{name}/file1" + ) + deleted_cnt += 1 + await run_file_operation( + aio_connection, db_parameters, files, "rm @%{name}/file2" + ) + deleted_cnt += 1 + + success_cnt = 0 + skipped_cnt = 0 + for rec in await run_file_operation( + aio_connection, db_parameters, files, "put 'file://{files}' @%{name}" + ): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == deleted_cnt, "uploaded files in the second time" + assert ( + skipped_cnt == number_of_files - deleted_cnt + ), "skipped files in the second time" + + await run_file_operation( + aio_connection, db_parameters, files, "copy into {name}" + ) + rows = 0 + for rec in await run_file_operation( + aio_connection, db_parameters, files, "select count(*) from {name}" + ): + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run_file_operation( + aio_connection, db_parameters, files, "drop table if exists {name}" + ) + + +@pytest.mark.skipolddriver +@pytest.mark.aws +@pytest.mark.azure +async def test_put_collision(tmpdir, aio_connection): + """File name collision test. 
The data set have the same file names but contents are different.""" + number_of_files = 5 + number_of_lines = 10 + # data set 1 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, + number_of_files, + compress=True, + tmp_dir=str(tmpdir.mkdir("data1")), + ) + files1 = os.path.join(tmp_dir, "file*") + await aio_connection.connect() + cursor = aio_connection.cursor() + # data set 2 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, + number_of_files, + compress=True, + tmp_dir=str(tmpdir.mkdir("data2")), + ) + files2 = os.path.join(tmp_dir, "file*") + + stage_name = random_string(5, "test_put_collision_") + await cursor.execute(f"RM @~/{stage_name}") + try: + # upload all files + success_cnt = 0 + skipped_cnt = 0 + for rec in await ( + await cursor.execute( + "PUT 'file://{file}' @~/{stage_name}".format( + file=files1.replace("\\", "\\\\"), stage_name=stage_name + ) + ) + ).fetchall(): + + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files + assert skipped_cnt == 0 + + # will skip uploading all files + success_cnt = 0 + skipped_cnt = 0 + for rec in await ( + await cursor.execute( + "PUT 'file://{file}' @~/{stage_name}".format( + file=files2.replace("\\", "\\\\"), stage_name=stage_name + ) + ) + ).fetchall(): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == 0 + assert skipped_cnt == number_of_files + + # will overwrite all files + success_cnt = 0 + skipped_cnt = 0 + for rec in await ( + await cursor.execute( + "PUT 'file://{file}' @~/{stage_name} OVERWRITE=true".format( + file=files2.replace("\\", "\\\\"), stage_name=stage_name + ) + ) + ).fetchall(): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files + assert skipped_cnt == 0 + + finally: + await cursor.execute(f"RM @~/{stage_name}") + + +def _generate_huge_value_json(tmpdir, n=1, value_size=1): + fname = str(tmpdir.join("test_put_get_huge_json")) + f = gzip.open(fname, "wb") + for i in range(n): + logger.debug(f"adding a value in {i}") + f.write(f'{{"k":"{random_string(value_size)}"}}') + f.close() + return fname + + +@pytest.mark.aws +async def test_put_get_large_files_s3(tmpdir, aio_connection, db_parameters): + """[s3] Puts and Gets Large files.""" + number_of_files = 3 + number_of_lines = 200000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + output_dir = os.path.join(tmp_dir, "output_dir") + os.makedirs(output_dir) + await aio_connection.connect() + + class cb(SnowflakeProgressPercentage): + def __init__(self, filename, filesize, **_): + pass + + def __call__(self, bytes_amount): + pass + + async def run_test(cnx, sql): + return await ( + await cnx.cursor().execute( + sql.format( + files=files.replace("\\", "\\\\"), + dir=db_parameters["name"], + output_dir=output_dir.replace("\\", "\\\\"), + ), + _put_callback_output_stream=sys.stdout, + _get_callback_output_stream=sys.stdout, + _get_callback=cb, + _put_callback=cb, + ) + ).fetchall() + + try: + await run_test(aio_connection, "PUT 'file://{files}' @~/{dir}") + # run(cnx, "PUT 'file://{files}' @~/{dir}") # retry + all_recs = [] + for _ in range(100): + all_recs = await run_test(aio_connection, "LIST @~/{dir}") + if len(all_recs) == 
number_of_files: + break + await asyncio.sleep(1) + else: + pytest.fail( + "cannot list all files. Potentially " + "PUT command missed uploading Files: {}".format(all_recs) + ) + all_recs = await run_test(aio_connection, "GET @~/{dir} 'file://{output_dir}'") + assert len(all_recs) == number_of_files + assert all([rec[2] == "DOWNLOADED" for rec in all_recs]) + finally: + await run_test(aio_connection, "RM @~/{dir}") + + +@pytest.mark.aws +@pytest.mark.azure +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +@pytest.mark.parametrize("file_src", ["put_get_1.txt"], indirect=["file_src"]) +async def test_put_get_with_hint( + tmpdir, aio_connection, db_parameters, from_path, file_src +): + """SNOW-15153: PUTs and GETs with hint.""" + tmp_dir = str(tmpdir.mkdir("put_get_with_hint")) + file_name, file_size, file_stream = file_src + await aio_connection.connect() + + async def run_test(cnx, sql, _is_put_get=None): + sql = sql.format( + local_dir=tmp_dir.replace("\\", "\\\\"), name=db_parameters["name"] + ) + res = await cnx.cursor().execute(sql, _is_put_get=_is_put_get) + return await res.fetchone() + + # regular PUT case + ret = await ( + await put_async( + aio_connection.cursor(), + file_name, + f"~/{db_parameters['name']}", + from_path, + file_stream=file_stream, + ) + ).fetchone() + assert ret[0] == os.path.basename(file_name), "PUT filename" + # clean up a file + ret = await run_test(aio_connection, "RM @~/{name}") + assert ret[0].endswith(os.path.basename(file_name) + ".gz"), "RM filename" + + # PUT detection failure + with pytest.raises(ProgrammingError): + await put_async( + aio_connection.cursor(), + file_name, + f"~/{db_parameters['name']}", + from_path, + commented=True, + file_stream=file_stream, + ) + + # PUT with hint + ret = await ( + await put_async( + aio_connection.cursor(), + file_name, + f"~/{db_parameters['name']}", + from_path, + file_stream=file_stream, + _is_put_get=True, + ) + ).fetchone() + assert ret[0] == os.path.basename(file_name), "PUT filename" + + # GET detection failure + commented_get_sql = """ +--- test comments +GET @~/{name} file://{local_dir}""" + + with pytest.raises(ProgrammingError): + await run_test(aio_connection, commented_get_sql) + + # GET with hint + ret = await run_test(aio_connection, commented_get_sql, _is_put_get=True) + assert ret[0] == os.path.basename(file_name) + ".gz", "GET filename" diff --git a/test/integ/aio/test_put_get_snow_4525_async.py b/test/integ/aio/test_put_get_snow_4525_async.py new file mode 100644 index 0000000000..f65a4330aa --- /dev/null +++ b/test/integ/aio/test_put_get_snow_4525_async.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
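The large-file test above re-issues LIST with asyncio.sleep(1) between attempts because uploaded files can appear in the stage listing with a delay. A small stand-alone sketch of that polling pattern (plain asyncio, nothing connector-specific):

import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def poll_until(
    fetch: Callable[[], Awaitable[T]],
    done: Callable[[T], bool],
    attempts: int = 100,
    delay: float = 1.0,
) -> T:
    """Call fetch() up to `attempts` times, sleeping `delay` seconds between
    tries, and return the first result for which done(result) is true."""
    result = await fetch()
    for _ in range(attempts - 1):
        if done(result):
            return result
        await asyncio.sleep(delay)  # cooperative wait; never time.sleep in async code
        result = await fetch()
    if not done(result):
        raise TimeoutError("condition not met after polling")
    return result

In the test this corresponds to fetch being the LIST query and done being len(recs) == number_of_files.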
+# + +from __future__ import annotations + +import os +import pathlib + + +async def test_load_bogus_file(tmp_path: pathlib.Path, conn_cnx, db_parameters): + """SNOW-4525: Loads Bogus file and should fail.""" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + f""" +create or replace table {db_parameters["name"]} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(5,2)) +""" + ) + temp_file = tmp_path / "bogus_files" + with temp_file.open("wb") as random_binary_file: + random_binary_file.write(os.urandom(1024)) + await cnx.cursor().execute(f"put file://{temp_file} @%{db_parameters['name']}") + + async with cnx.cursor() as c: + await c.execute(f"copy into {db_parameters['name']} on_error='skip_file'") + cnt = 0 + async for _rec in c: + cnt += 1 + assert _rec[1] == "LOAD_FAILED" + await cnx.cursor().execute(f"drop table if exists {db_parameters['name']}") + + +async def test_load_bogus_json_file(tmp_path: pathlib.Path, conn_cnx, db_parameters): + """SNOW-4525: Loads Bogus JSON file and should fail.""" + async with conn_cnx() as cnx: + json_table = db_parameters["name"] + "_json" + await cnx.cursor().execute(f"create or replace table {json_table} (v variant)") + + temp_file = tmp_path / "bogus_json_files" + temp_file.write_bytes(os.urandom(1024)) + await cnx.cursor().execute(f"put file://{temp_file} @%{json_table}") + + async with cnx.cursor() as c: + await c.execute( + f"copy into {json_table} on_error='skip_file' " + "file_format=(type='json')" + ) + cnt = 0 + async for _rec in c: + cnt += 1 + assert _rec[1] == "LOAD_FAILED" + await cnx.cursor().execute(f"drop table if exists {json_table}") diff --git a/test/integ/aio/test_put_get_user_stage_async.py b/test/integ/aio/test_put_get_user_stage_async.py new file mode 100644 index 0000000000..f242c41122 --- /dev/null +++ b/test/integ/aio/test_put_get_user_stage_async.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
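The SNOW-4525 tests above stream result rows with async for instead of an explicit fetch call. A short sketch of tallying per-file COPY statuses that way; the cursor argument is assumed to be an async cursor as used in these tests:

from collections import Counter


async def copy_status_counts(cursor, table: str) -> Counter:
    """Run COPY INTO with on_error='skip_file' and count the status column
    (index 1) across the result rows, e.g. LOADED vs LOAD_FAILED."""
    counts: Counter = Counter()
    await cursor.execute(f"copy into {table} on_error='skip_file'")
    async for row in cursor:  # rows are streamed; no fetchall() needed
        counts[row[1]] += 1
    return counts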
+#
+
+from __future__ import annotations
+
+import asyncio
+import mimetypes
+import os
+from getpass import getuser
+from logging import getLogger
+from test.generate_test_files import generate_k_lines_of_n_files
+from test.integ_helpers import put_async
+from unittest.mock import patch
+
+import pytest
+
+from snowflake.connector.cursor import SnowflakeCursor
+from snowflake.connector.util_text import random_string
+
+
+@pytest.mark.aws
+@pytest.mark.parametrize("from_path", [True, False])
+async def test_put_get_small_data_via_user_stage(
+    is_public_test, tmpdir, conn_cnx, from_path
+):
+    """[s3] Puts and Gets Small Data via User Stage."""
+    if is_public_test or "AWS_ACCESS_KEY_ID" not in os.environ:
+        pytest.skip("This test requires to change the internal parameter")
+    number_of_files = 5 if from_path else 1
+    number_of_lines = 1
+    await _put_get_user_stage(
+        tmpdir,
+        conn_cnx,
+        number_of_files=number_of_files,
+        number_of_lines=number_of_lines,
+        from_path=from_path,
+    )
+
+
+@pytest.mark.skip(reason="endpoints don't have s3-acc string, skip it for now")
+@pytest.mark.internal
+@pytest.mark.skipolddriver
+@pytest.mark.aws
+@pytest.mark.parametrize(
+    "from_path",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "accelerate_config",
+    [True, False],
+)
+def test_put_get_accelerate_user_stage(tmpdir, conn_cnx, from_path, accelerate_config):
+    """[s3] Puts and Gets Small Data via User Stage with S3 transfer acceleration."""
+    from snowflake.connector.file_transfer_agent import SnowflakeFileTransferAgent
+    from snowflake.connector.s3_storage_client import SnowflakeS3RestClient
+
+    number_of_files = 5 if from_path else 1
+    number_of_lines = 1
+    endpoints = []
+
+    def mocked_file_agent(*args, **kwargs):
+        agent = SnowflakeFileTransferAgent(*args, **kwargs)
+        mocked_file_agent.agent = agent
+        return agent
+
+    original_accelerate_config = SnowflakeS3RestClient.transfer_accelerate_config
+    expected_cfg = accelerate_config
+
+    def mock_s3_transfer_accelerate_config(self, *args, **kwargs) -> bool:
+        bret = original_accelerate_config(self, *args, **kwargs)
+        endpoints.append(self.endpoint)
+        return bret
+
+    def mock_s3_get_bucket_config(self, *args, **kwargs) -> bool:
+        return expected_cfg
+
+    with patch(
+        "snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent",
+        side_effect=mocked_file_agent,
+    ):
+        with patch.multiple(
+            "snowflake.connector.s3_storage_client.SnowflakeS3RestClient",
+            _get_bucket_accelerate_config=mock_s3_get_bucket_config,
+            transfer_accelerate_config=mock_s3_transfer_accelerate_config,
+        ):
+            _put_get_user_stage(
+                tmpdir,
+                conn_cnx,
+                number_of_files=number_of_files,
+                number_of_lines=number_of_lines,
+                from_path=from_path,
+            )
+    config_accl = mocked_file_agent.agent._use_accelerate_endpoint
+    if accelerate_config:
+        assert (config_accl is True) and all(
+            ele.find("s3-acc") >= 0 for ele in endpoints
+        )
+    else:
+        assert (config_accl is False) and all(
+            ele.find("s3-acc") < 0 for ele in endpoints
+        )
+
+
+@pytest.mark.aws
+@pytest.mark.parametrize(
+    "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)]
+)
+async def test_put_get_large_data_via_user_stage(
+    is_public_test,
+    tmpdir,
+    conn_cnx,
+    from_path,
+):
+    """[s3] Puts and Gets Large Data via User Stage."""
+    if is_public_test or "AWS_ACCESS_KEY_ID" not in os.environ:
+        pytest.skip("This test requires to change the internal parameter")
+    number_of_files = 2 if from_path else 1
+    number_of_lines = 200000
+    await _put_get_user_stage(
+        tmpdir,
+        conn_cnx,
+        number_of_files=number_of_files,
+        number_of_lines=number_of_lines,
+
from_path=from_path, + ) + + +@pytest.mark.aws +@pytest.mark.internal +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +def test_put_small_data_use_s3_regional_url( + is_public_test, + tmpdir, + conn_cnx, + db_parameters, + from_path, +): + """[s3] Puts Small Data via User Stage using regional url.""" + if is_public_test or "AWS_ACCESS_KEY_ID" not in os.environ: + pytest.skip("This test requires to change the internal parameter") + number_of_files = 5 if from_path else 1 + number_of_lines = 1 + put_cursor = _put_get_user_stage_s3_regional_url( + tmpdir, + conn_cnx, + db_parameters, + number_of_files=number_of_files, + number_of_lines=number_of_lines, + from_path=from_path, + ) + assert put_cursor._connection._session_parameters.get( + "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1" + ) + + +async def _put_get_user_stage_s3_regional_url( + tmpdir, + conn_cnx, + db_parameters, + number_of_files=1, + number_of_lines=1, + from_path=True, +) -> SnowflakeCursor | None: + async with conn_cnx( + role="accountadmin", + ) as cnx: + await cnx.cursor().execute( + "alter account set ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 = true;" + ) + try: + put_cursor = await _put_get_user_stage( + tmpdir, + conn_cnx, + number_of_files, + number_of_lines, + from_path, + ) + finally: + async with conn_cnx( + role="accountadmin", + ) as cnx: + await cnx.cursor().execute( + "alter account set ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 = false;" + ) + return put_cursor + + +async def _put_get_user_stage( + tmpdir, + conn_cnx, + number_of_files=1, + number_of_lines=1, + from_path=True, +) -> SnowflakeCursor | None: + put_cursor: SnowflakeCursor | None = None + # sanity check + assert "AWS_ACCESS_KEY_ID" in os.environ, "AWS_ACCESS_KEY_ID is missing" + assert "AWS_SECRET_ACCESS_KEY" in os.environ, "AWS_SECRET_ACCESS_KEY is missing" + if not from_path: + assert number_of_files == 1 + + random_str = random_string(5, "put_get_user_stage_") + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*" if from_path else os.listdir(tmp_dir)[0]) + file_stream = None if from_path else open(files, "rb") + + stage_name = f"{random_str}_stage_{number_of_files}_{number_of_lines}" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + f""" +create or replace table {random_str} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""" + ) + user_bucket = os.getenv( + "SF_AWS_USER_BUCKET", f"sfc-eng-regression/{getuser()}/reg" + ) + await cnx.cursor().execute( + f""" +create or replace stage {stage_name} +url='s3://{user_bucket}/{stage_name}-{number_of_files}-{number_of_lines}' +credentials=( + AWS_KEY_ID='{os.getenv("AWS_ACCESS_KEY_ID")}' + AWS_SECRET_KEY='{os.getenv("AWS_SECRET_ACCESS_KEY")}' +) +""" + ) + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "alter session set disable_put_and_get_on_external_stage = false" + ) + await cnx.cursor().execute(f"rm @{stage_name}") + + put_cursor = cnx.cursor() + await put_async( + put_cursor, files, stage_name, from_path, file_stream=file_stream + ) + await cnx.cursor().execute(f"copy into {random_str} from @{stage_name}") + c = cnx.cursor() + try: + await c.execute(f"select count(*) from {random_str}") + rows = 0 + async for rec in c: + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await 
c.close() + await cnx.cursor().execute(f"rm @{stage_name}") + await cnx.cursor().execute(f"copy into @{stage_name} from {random_str}") + tmp_dir_user = str(tmpdir.mkdir("put_get_stage")) + await cnx.cursor().execute(f"get @{stage_name}/ file://{tmp_dir_user}/") + for _, _, files in os.walk(tmp_dir_user): + for file in files: + mimetypes.init() + _, encoding = mimetypes.guess_type(file) + assert encoding == "gzip", "exported file type" + finally: + if file_stream: + file_stream.close() + async with conn_cnx() as cnx: + await cnx.cursor().execute(f"rm @{stage_name}") + await cnx.cursor().execute(f"drop stage if exists {stage_name}") + await cnx.cursor().execute(f"drop table if exists {random_str}") + return put_cursor + + +@pytest.mark.aws +@pytest.mark.flaky(reruns=3) +async def test_put_get_duplicated_data_user_stage( + is_public_test, + tmpdir, + conn_cnx, + number_of_files=5, + number_of_lines=100, +): + """[s3] Puts and Gets Duplicated Data using User Stage.""" + if is_public_test or "AWS_ACCESS_KEY_ID" not in os.environ: + pytest.skip("This test requires to change the internal parameter") + + random_str = random_string(5, "test_put_get_duplicated_data_user_stage_") + logger = getLogger(__name__) + assert "AWS_ACCESS_KEY_ID" in os.environ, "AWS_ACCESS_KEY_ID is missing" + assert "AWS_SECRET_ACCESS_KEY" in os.environ, "AWS_SECRET_ACCESS_KEY is missing" + + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + + stage_name = f"{random_str}_stage" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + f""" +create or replace table {random_str} ( +aa int, +dt date, +ts timestamp, +tsltz timestamp_ltz, +tsntz timestamp_ntz, +tstz timestamp_tz, +pct float, +ratio number(6,2)) +""" + ) + user_bucket = os.getenv( + "SF_AWS_USER_BUCKET", f"sfc-eng-regression/{getuser()}/reg" + ) + await cnx.cursor().execute( + f""" +create or replace stage {stage_name} +url='s3://{user_bucket}/{stage_name}-{number_of_files}-{number_of_lines}' +credentials=( + AWS_KEY_ID='{os.getenv("AWS_ACCESS_KEY_ID")}' + AWS_SECRET_KEY='{os.getenv("AWS_SECRET_ACCESS_KEY")}' +) +""" + ) + try: + async with conn_cnx() as cnx: + c = cnx.cursor() + try: + async for rec in await c.execute(f"rm @{stage_name}"): + logger.info("rec=%s", rec) + finally: + await c.close() + + success_cnt = 0 + skipped_cnt = 0 + async with cnx.cursor() as c: + await c.execute( + "alter session set disable_put_and_get_on_external_stage = false" + ) + async for rec in await c.execute(f"put file://{files} @{stage_name}"): + logger.info(f"rec={rec}") + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files, "uploaded files" + assert skipped_cnt == 0, "skipped files" + + logger.info(f"deleting files in {stage_name}") + + deleted_cnt = 0 + await cnx.cursor().execute(f"rm @{stage_name}/file0") + deleted_cnt += 1 + await cnx.cursor().execute(f"rm @{stage_name}/file1") + deleted_cnt += 1 + await cnx.cursor().execute(f"rm @{stage_name}/file2") + deleted_cnt += 1 + + success_cnt = 0 + skipped_cnt = 0 + async with cnx.cursor() as c: + async for rec in await c.execute( + f"put file://{files} @{stage_name}", + _raise_put_get_error=False, + ): + logger.info(f"rec={rec}") + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == deleted_cnt, "uploaded files in the second time" + assert ( + skipped_cnt == number_of_files - 
deleted_cnt + ), "skipped files in the second time" + + await asyncio.sleep(5) + await cnx.cursor().execute(f"copy into {random_str} from @{stage_name}") + async with cnx.cursor() as c: + await c.execute(f"select count(*) from {random_str}") + rows = 0 + async for rec in c: + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + await cnx.cursor().execute(f"rm @{stage_name}") + await cnx.cursor().execute(f"copy into @{stage_name} from {random_str}") + tmp_dir_user = str(tmpdir.mkdir("stage2")) + await cnx.cursor().execute(f"get @{stage_name}/ file://{tmp_dir_user}/") + for _, _, files in os.walk(tmp_dir_user): + for file in files: + mimetypes.init() + _, encoding = mimetypes.guess_type(file) + assert encoding == "gzip", "exported file type" + + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute(f"drop stage if exists {stage_name}") + await cnx.cursor().execute(f"drop table if exists {random_str}") + + +@pytest.mark.aws +async def test_get_data_user_stage( + is_public_test, + tmpdir, + conn_cnx, +): + """SNOW-20927: Tests Get failure with 404 error.""" + stage_name = random_string(5, "test_get_data_user_stage_") + if is_public_test or "AWS_ACCESS_KEY_ID" not in os.environ: + pytest.skip("This test requires to change the internal parameter") + + default_s3bucket = os.getenv( + "SF_AWS_USER_BUCKET", f"sfc-eng-regression/{getuser()}/reg" + ) + test_data = [ + { + "s3location": "{}/{}".format(default_s3bucket, f"{stage_name}_stage"), + "stage_name": f"{stage_name}_stage1", + "data_file_name": "data.txt", + }, + ] + for elem in test_data: + await _put_list_rm_files_in_stage(tmpdir, conn_cnx, elem) + + +async def _put_list_rm_files_in_stage(tmpdir, conn_cnx, elem): + s3location = elem["s3location"] + stage_name = elem["stage_name"] + data_file_name = elem["data_file_name"] + + from io import open + + from snowflake.connector.constants import UTF8 + + tmp_dir = str(tmpdir.mkdir("data")) + data_file = os.path.join(tmp_dir, data_file_name) + with open(data_file, "w", encoding=UTF8) as f: + f.write("123,456,string1\n") + f.write("789,012,string2\n") + + output_dir = str(tmpdir.mkdir("output")) + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +create or replace stage {stage_name} + url='s3://{s3location}' + credentials=( + AWS_KEY_ID='{aws_key_id}' + AWS_SECRET_KEY='{aws_secret_key}' + ) +""".format( + s3location=s3location, + stage_name=stage_name, + aws_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + aws_secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + ) + ) + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute(f"RM @{stage_name}") + await cnx.cursor().execute( + "alter session set disable_put_and_get_on_external_stage = false" + ) + rec = await ( + await cnx.cursor().execute( + """ +PUT file://{file} @{stage_name} +""".format( + file=data_file, stage_name=stage_name + ) + ) + ).fetchone() + assert rec[0] == data_file_name + assert rec[6] == "UPLOADED" + rec = await ( + await cnx.cursor().execute( + """ +LIST @{stage_name} + """.format( + stage_name=stage_name + ) + ) + ).fetchone() + assert rec, "LIST should return something" + assert rec[0].startswith("s3://"), "The file location in S3" + rec = await ( + await cnx.cursor().execute( + """ +GET @{stage_name} file://{output_dir} +""".format( + stage_name=stage_name, output_dir=output_dir + ) + ) + ).fetchone() + assert rec[0] == data_file_name + ".gz" + assert rec[2] == "DOWNLOADED" + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + """ +RM 
@{stage_name} +""".format( + stage_name=stage_name + ) + ) + await cnx.cursor().execute(f"drop stage if exists {stage_name}") diff --git a/test/integ/aio/test_put_get_with_aws_token_async.py b/test/integ/aio/test_put_get_with_aws_token_async.py new file mode 100644 index 0000000000..92fa99aed0 --- /dev/null +++ b/test/integ/aio/test_put_get_with_aws_token_async.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import glob +import gzip +import os + +import pytest +from aiohttp import ClientResponseError + +from snowflake.connector.constants import UTF8 + +try: # pragma: no cover + from snowflake.connector.aio._file_transfer_agent import SnowflakeFileMeta + from snowflake.connector.aio._s3_storage_client import ( + S3Location, + SnowflakeS3RestClient, + ) + from snowflake.connector.file_transfer_agent import StorageCredential +except ImportError: + pass + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from test.integ_helpers import put_async + +# Mark every test in this module as an aws test +pytestmark = [pytest.mark.asyncio, pytest.mark.aws] + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +async def test_put_get_with_aws(tmpdir, aio_connection, from_path): + """[s3] Puts and Gets a small text using AWS S3.""" + # create a data file + fname = str(tmpdir.join("test_put_get_with_aws_token.txt.gz")) + original_contents = "123,test1\n456,test2\n" + with gzip.open(fname, "wb") as f: + f.write(original_contents.encode(UTF8)) + tmp_dir = str(tmpdir.mkdir("test_put_get_with_aws_token")) + table_name = random_string(5, "snow9144_") + + await aio_connection.connect() + csr = aio_connection.cursor() + + try: + await csr.execute(f"create or replace table {table_name} (a int, b string)") + file_stream = None if from_path else open(fname, "rb") + await put_async( + csr, + fname, + f"%{table_name}", + from_path, + sql_options=" auto_compress=true parallel=30", + file_stream=file_stream, + ) + rec = await csr.fetchone() + assert rec[6] == "UPLOADED" + await csr.execute(f"copy into {table_name}") + await csr.execute(f"rm @%{table_name}") + assert await (await csr.execute(f"ls @%{table_name}")).fetchall() == [] + await csr.execute( + f"copy into @%{table_name} from {table_name} " + "file_format=(type=csv compression='gzip')" + ) + await csr.execute(f"get @%{table_name} file://{tmp_dir}") + rec = await csr.fetchone() + assert rec[0].startswith("data_"), "A file downloaded by GET" + assert rec[1] == 36, "Return right file size" + assert rec[2] == "DOWNLOADED", "Return DOWNLOADED status" + assert rec[3] == "", "Return no error message" + finally: + await csr.execute(f"drop table {table_name}") + if file_stream: + file_stream.close() + + files = glob.glob(os.path.join(tmp_dir, "data_*")) + with gzip.open(files[0], "rb") as fd: + contents = fd.read().decode(UTF8) + assert original_contents == contents, "Output is different from the original file" + + +@pytest.mark.skipolddriver +async def test_put_with_invalid_token(tmpdir, aio_connection): + """[s3] SNOW-6154: Uses invalid combination of AWS credential.""" + # create a data file + fname = str(tmpdir.join("test_put_get_with_aws_token.txt.gz")) + with gzip.open(fname, "wb") as f: + f.write("123,test1\n456,test2".encode(UTF8)) + table_name = random_string(5, "snow6154_") + + await aio_connection.connect() + 
csr = aio_connection.cursor() + + try: + await csr.execute(f"create or replace table {table_name} (a int, b string)") + ret = await csr._execute_helper(f"put file://{fname} @%{table_name}") + stage_info = ret["data"]["stageInfo"] + stage_credentials = stage_info["creds"] + creds = StorageCredential(stage_credentials, csr, "COMMAND WILL NOT BE USED") + statinfo = os.stat(fname) + meta = SnowflakeFileMeta( + name=os.path.basename(fname), + src_file_name=fname, + src_file_size=statinfo.st_size, + stage_location_type="S3", + encryption_material=None, + dst_file_name=os.path.basename(fname), + sha256_digest="None", + ) + + client = SnowflakeS3RestClient(meta, creds, stage_info, 8388608) + await client.transfer_accelerate_config(None) + await client.get_file_header(meta.name) # positive case + + # negative case, no aws token + token = stage_info["creds"]["AWS_TOKEN"] + del stage_info["creds"]["AWS_TOKEN"] + with pytest.raises(ClientResponseError): + await client.get_file_header(meta.name) + + # negative case, wrong location + stage_info["creds"]["AWS_TOKEN"] = token + s3path = client.s3location.path + bad_path = os.path.dirname(os.path.dirname(s3path)) + "/" + _s3location = S3Location(client.s3location.bucket_name, bad_path) + client.s3location = _s3location + client.chunks = [b"this is a chunk"] + client.num_of_chunks = 1 + client.retry_count[0] = 0 + client.data_file = fname + with pytest.raises(ClientResponseError): + await client.upload_chunk(0) + finally: + await csr.execute(f"drop table if exists {table_name}") diff --git a/test/integ/aio/test_put_get_with_azure_token_async.py b/test/integ/aio/test_put_get_with_azure_token_async.py new file mode 100644 index 0000000000..9dea563b78 --- /dev/null +++ b/test/integ/aio/test_put_get_with_azure_token_async.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
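test_put_get_with_aws above checks the GET result by gunzipping the downloaded data_* file and comparing it to the original text. A small stand-alone helper in the same spirit (pure standard library, not part of the changeset):

import glob
import gzip
import os


def downloaded_matches(download_dir: str, expected: str, encoding: str = "utf-8") -> bool:
    """Return True if the first data_* file in download_dir gunzips to expected."""
    files = sorted(glob.glob(os.path.join(download_dir, "data_*")))
    if not files:
        return False
    with gzip.open(files[0], "rb") as fd:
        return fd.read().decode(encoding) == expected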
+# + +from __future__ import annotations + +import glob +import gzip +import logging +import os +import sys +import time +from logging import getLogger + +import pytest + +from snowflake.connector.constants import UTF8 +from snowflake.connector.file_transfer_agent import ( + SnowflakeAzureProgressPercentage, + SnowflakeProgressPercentage, +) + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from test.generate_test_files import generate_k_lines_of_n_files +from test.integ_helpers import put_async + +logger = getLogger(__name__) + +# Mark every test in this module as an azure and a putget test +pytestmark = [pytest.mark.asyncio, pytest.mark.azure] + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +async def test_put_get_with_azure(tmpdir, aio_connection, from_path, caplog): + """[azure] Puts and Gets a small text using Azure.""" + # create a data file + caplog.set_level(logging.DEBUG) + fname = str(tmpdir.join("test_put_get_with_azure_token.txt.gz")) + original_contents = "123,test1\n456,test2\n" + with gzip.open(fname, "wb") as f: + f.write(original_contents.encode(UTF8)) + tmp_dir = str(tmpdir.mkdir("test_put_get_with_azure_token")) + table_name = random_string(5, "snow32806_") + + await aio_connection.connect() + csr = aio_connection.cursor() + + await csr.execute(f"create or replace table {table_name} (a int, b string)") + try: + file_stream = None if from_path else open(fname, "rb") + await put_async( + csr, + fname, + f"%{table_name}", + from_path, + sql_options=" auto_compress=true parallel=30", + _put_callback=SnowflakeAzureProgressPercentage, + _get_callback=SnowflakeAzureProgressPercentage, + file_stream=file_stream, + ) + assert (await csr.fetchone())[6] == "UPLOADED" + await csr.execute(f"copy into {table_name}") + await csr.execute(f"rm @%{table_name}") + assert await (await csr.execute(f"ls @%{table_name}")).fetchall() == [] + await csr.execute( + f"copy into @%{table_name} from {table_name} " + "file_format=(type=csv compression='gzip')" + ) + await csr.execute( + f"get @%{table_name} file://{tmp_dir}", + _put_callback=SnowflakeAzureProgressPercentage, + _get_callback=SnowflakeAzureProgressPercentage, + ) + rec = await csr.fetchone() + assert rec[0].startswith("data_"), "A file downloaded by GET" + assert rec[1] == 36, "Return right file size" + assert rec[2] == "DOWNLOADED", "Return DOWNLOADED status" + assert rec[3] == "", "Return no error message" + finally: + if file_stream: + file_stream.close() + await csr.execute(f"drop table {table_name}") + + for line in caplog.text.splitlines(): + if "blob.core.windows.net" in line: + assert ( + "sig=" not in line + ), "connectionpool logger is leaking sensitive information" + files = glob.glob(os.path.join(tmp_dir, "data_*")) + with gzip.open(files[0], "rb") as fd: + contents = fd.read().decode(UTF8) + assert original_contents == contents, "Output is different from the original file" + + +async def test_put_copy_many_files_azure(tmpdir, aio_connection): + """[azure] Puts and Copies many files.""" + # generates N files + number_of_files = 10 + number_of_lines = 1000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + folder_name = random_string(5, "test_put_copy_many_files_azure_") + + files = os.path.join(tmp_dir, "file*") + + async def run(csr, sql): + sql = sql.format(files=files, name=folder_name) + return await (await 
csr.execute(sql)).fetchall() + + await aio_connection.connect() + csr = aio_connection.cursor() + + await run( + csr, + """ + create or replace table {name} ( + aa int, + dt date, + ts timestamp, + tsltz timestamp_ltz, + tsntz timestamp_ntz, + tstz timestamp_tz, + pct float, + ratio number(6,2)) + """, + ) + try: + all_recs = await run(csr, "put file://{files} @%{name}") + assert all([rec[6] == "UPLOADED" for rec in all_recs]) + await run(csr, "copy into {name}") + + rows = sum(rec[0] for rec in await run(csr, "select count(*) from {name}")) + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run(csr, "drop table if exists {name}") + + +async def test_put_copy_duplicated_files_azure(tmpdir, aio_connection): + """[azure] Puts and Copies duplicated files.""" + # generates N files + number_of_files = 5 + number_of_lines = 100 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + table_name = random_string(5, "test_put_copy_duplicated_files_azure_") + + files = os.path.join(tmp_dir, "file*") + + async def run(csr, sql): + sql = sql.format(files=files, name=table_name) + return await (await csr.execute(sql, _raise_put_get_error=False)).fetchall() + + await aio_connection.connect() + csr = aio_connection.cursor() + await run( + csr, + """ + create or replace table {name} ( + aa int, + dt date, + ts timestamp, + tsltz timestamp_ltz, + tsntz timestamp_ntz, + tstz timestamp_tz, + pct float, + ratio number(6,2)) + """, + ) + + try: + success_cnt = 0 + skipped_cnt = 0 + for rec in await run(csr, "put file://{files} @%{name}"): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files, "uploaded files" + assert skipped_cnt == 0, "skipped files" + + deleted_cnt = 0 + await run(csr, "rm @%{name}/file0") + deleted_cnt += 1 + await run(csr, "rm @%{name}/file1") + deleted_cnt += 1 + await run(csr, "rm @%{name}/file2") + deleted_cnt += 1 + + success_cnt = 0 + skipped_cnt = 0 + for rec in await run(csr, "put file://{files} @%{name}"): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == deleted_cnt, "uploaded files in the second time" + assert ( + skipped_cnt == number_of_files - deleted_cnt + ), "skipped files in the second time" + + await run(csr, "copy into {name}") + rows = 0 + for rec in await run(csr, "select count(*) from {name}"): + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run(csr, "drop table if exists {name}") + + +async def test_put_get_large_files_azure(tmpdir, aio_connection): + """[azure] Puts and Gets Large files.""" + number_of_files = 3 + number_of_lines = 200000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + + files = os.path.join(tmp_dir, "file*") + output_dir = os.path.join(tmp_dir, "output_dir") + os.makedirs(output_dir) + folder_name = random_string(5, "test_put_get_large_files_azure_") + + class cb(SnowflakeProgressPercentage): + def __init__(self, filename, filesize, **_): + pass + + def __call__(self, bytes_amount): + pass + + async def run(cnx, sql): + return await ( + await cnx.cursor().execute( + sql.format(files=files, dir=folder_name, output_dir=output_dir), + _put_callback_output_stream=sys.stdout, + _get_callback_output_stream=sys.stdout, + 
_get_callback=cb, + _put_callback=cb, + ) + ).fetchall() + + await aio_connection.connect() + try: + all_recs = await run(aio_connection, "PUT file://{files} @~/{dir}") + assert all([rec[6] == "UPLOADED" for rec in all_recs]) + + for _ in range(60): + for _ in range(100): + all_recs = await run(aio_connection, "LIST @~/{dir}") + if len(all_recs) == number_of_files: + break + # you may not get the files right after PUT command + # due to the nature of Azure blob, which synchronizes + # data eventually. + time.sleep(1) + else: + # wait for another second and retry. + # this could happen if the files are partially available + # but not all. + time.sleep(1) + continue + break # success + else: + pytest.fail( + "cannot list all files. Potentially " + "PUT command missed uploading Files: {}".format(all_recs) + ) + all_recs = await run(aio_connection, "GET @~/{dir} file://{output_dir}") + assert len(all_recs) == number_of_files + assert all([rec[2] == "DOWNLOADED" for rec in all_recs]) + finally: + await run(aio_connection, "RM @~/{dir}") diff --git a/test/integ/aio/test_put_get_with_gcp_account_async.py b/test/integ/aio/test_put_get_with_gcp_account_async.py new file mode 100644 index 0000000000..937f45e306 --- /dev/null +++ b/test/integ/aio/test_put_get_with_gcp_account_async.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import glob +import gzip +import os +import sys +from filecmp import cmp +from logging import getLogger + +import pytest + +from snowflake.connector.constants import UTF8 +from snowflake.connector.errors import ProgrammingError +from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from test.generate_test_files import generate_k_lines_of_n_files +from test.integ_helpers import put_async + +logger = getLogger(__name__) + +# Mark every test in this module as a gcp test +pytestmark = [pytest.mark.asyncio, pytest.mark.gcp] + + +@pytest.mark.parametrize("enable_gcs_downscoped", [True]) +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +async def test_put_get_with_gcp( + tmpdir, + aio_connection, + is_public_test, + enable_gcs_downscoped, + from_path, +): + """[gcp] Puts and Gets a small text using gcp.""" + # create a data file + fname = str(tmpdir.join("test_put_get_with_gcp_token.txt.gz")) + original_contents = "123,test1\n456,test2\n" + with gzip.open(fname, "wb") as f: + f.write(original_contents.encode(UTF8)) + tmp_dir = str(tmpdir.mkdir("test_put_get_with_gcp_token")) + table_name = random_string(5, "snow32806_") + + await aio_connection.connect() + csr = aio_connection.cursor() + try: + await csr.execute( + f"ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = {enable_gcs_downscoped}" + ) + except ProgrammingError as e: + if enable_gcs_downscoped: + # not raise error when the parameter is not available yet, using old behavior + raise e + await csr.execute(f"create or replace table {table_name} (a int, b string)") + try: + file_stream = None if from_path else open(fname, "rb") + await put_async( + csr, + fname, + f"%{table_name}", + from_path, + sql_options=" auto_compress=true parallel=30", + file_stream=file_stream, + ) + assert (await csr.fetchone())[6] == "UPLOADED" + await csr.execute(f"copy into {table_name}") + await 
csr.execute(f"rm @%{table_name}") + assert await (await csr.execute(f"ls @%{table_name}")).fetchall() == [] + await csr.execute( + f"copy into @%{table_name} from {table_name} " + "file_format=(type=csv compression='gzip')" + ) + await csr.execute(f"get @%{table_name} file://{tmp_dir}") + rec = await csr.fetchone() + assert rec[0].startswith("data_"), "A file downloaded by GET" + assert rec[1] == 36, "Return right file size" + assert rec[2] == "DOWNLOADED", "Return DOWNLOADED status" + assert rec[3] == "", "Return no error message" + finally: + if file_stream: + file_stream.close() + await csr.execute(f"drop table {table_name}") + + files = glob.glob(os.path.join(tmp_dir, "data_*")) + with gzip.open(files[0], "rb") as fd: + contents = fd.read().decode(UTF8) + assert original_contents == contents, "Output is different from the original file" + + +@pytest.mark.parametrize("enable_gcs_downscoped", [True]) +async def test_put_copy_many_files_gcp( + tmpdir, + aio_connection, + is_public_test, + enable_gcs_downscoped, +): + """[gcp] Puts and Copies many files.""" + # generates N files + number_of_files = 10 + number_of_lines = 1000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + table_name = random_string(5, "test_put_copy_many_files_gcp_") + + files = os.path.join(tmp_dir, "file*") + + async def run(csr, sql): + sql = sql.format(files=files, name=table_name) + return await (await csr.execute(sql)).fetchall() + + await aio_connection.connect() + csr = aio_connection.cursor() + try: + await csr.execute( + f"ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = {enable_gcs_downscoped}" + ) + except ProgrammingError as e: + if enable_gcs_downscoped: + # not raise error when the parameter is not available yet, using old behavior + raise e + await run( + csr, + """ + create or replace table {name} ( + aa int, + dt date, + ts timestamp, + tsltz timestamp_ltz, + tsntz timestamp_ntz, + tstz timestamp_tz, + pct float, + ratio number(6,2)) + """, + ) + try: + statement = "put file://{files} @%{name}" + if enable_gcs_downscoped: + statement += " overwrite = true" + + all_recs = await run(csr, statement) + assert all([rec[6] == "UPLOADED" for rec in all_recs]) + await run(csr, "copy into {name}") + + rows = sum(rec[0] for rec in await run(csr, "select count(*) from {name}")) + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run(csr, "drop table if exists {name}") + + +@pytest.mark.parametrize("enable_gcs_downscoped", [True]) +async def test_put_copy_duplicated_files_gcp( + tmpdir, + aio_connection, + is_public_test, + enable_gcs_downscoped, +): + """[gcp] Puts and Copies duplicated files.""" + # generates N files + number_of_files = 5 + number_of_lines = 100 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + table_name = random_string(5, "test_put_copy_duplicated_files_gcp_") + + files = os.path.join(tmp_dir, "file*") + + async def run(csr, sql): + sql = sql.format(files=files, name=table_name) + return await (await csr.execute(sql)).fetchall() + + await aio_connection.connect() + csr = aio_connection.cursor() + try: + await csr.execute( + f"ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = {enable_gcs_downscoped}" + ) + except ProgrammingError as e: + if enable_gcs_downscoped: + # not raise error when the parameter is not available yet, using old behavior + raise e + await run( + csr, + """ + create or replace table {name} ( + aa int, + dt 
date, + ts timestamp, + tsltz timestamp_ltz, + tsntz timestamp_ntz, + tstz timestamp_tz, + pct float, + ratio number(6,2)) + """, + ) + + try: + success_cnt = 0 + skipped_cnt = 0 + put_statement = "put file://{files} @%{name}" + if enable_gcs_downscoped: + put_statement += " overwrite = true" + for rec in await run(csr, put_statement): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files, "uploaded files" + assert skipped_cnt == 0, "skipped files" + + deleted_cnt = 0 + await run(csr, "rm @%{name}/file0") + deleted_cnt += 1 + await run(csr, "rm @%{name}/file1") + deleted_cnt += 1 + await run(csr, "rm @%{name}/file2") + deleted_cnt += 1 + + success_cnt = 0 + skipped_cnt = 0 + for rec in await run(csr, put_statement): + logger.info("rec=%s", rec) + if rec[6] == "UPLOADED": + success_cnt += 1 + elif rec[6] == "SKIPPED": + skipped_cnt += 1 + assert success_cnt == number_of_files, "uploaded files in the second time" + assert skipped_cnt == 0, "skipped files in the second time" + + await run(csr, "copy into {name}") + rows = 0 + for rec in await run(csr, "select count(*) from {name}"): + rows += rec[0] + assert rows == number_of_files * number_of_lines, "Number of rows" + finally: + await run(csr, "drop table if exists {name}") + + +@pytest.mark.parametrize("enable_gcs_downscoped", [True]) +async def test_put_get_large_files_gcp( + tmpdir, + aio_connection, + is_public_test, + enable_gcs_downscoped, +): + """[gcp] Puts and Gets Large files.""" + number_of_files = 3 + number_of_lines = 200000 + tmp_dir = generate_k_lines_of_n_files( + number_of_lines, number_of_files, tmp_dir=str(tmpdir.mkdir("data")) + ) + folder_name = random_string(5, "test_put_get_large_files_gcp_") + + files = os.path.join(tmp_dir, "file*") + output_dir = os.path.join(tmp_dir, "output_dir") + os.makedirs(output_dir) + + class cb(SnowflakeProgressPercentage): + def __init__(self, filename, filesize, **_): + pass + + def __call__(self, bytes_amount): + pass + + async def run(cnx, sql): + return await ( + await cnx.cursor().execute( + sql.format(files=files, dir=folder_name, output_dir=output_dir), + _put_callback_output_stream=sys.stdout, + _get_callback_output_stream=sys.stdout, + _get_callback=cb, + _put_callback=cb, + ) + ).fetchall() + + await aio_connection.connect() + try: + try: + await run( + aio_connection, + f"ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = {enable_gcs_downscoped}", + ) + except ProgrammingError as e: + if enable_gcs_downscoped: + # not raise error when the parameter is not available yet, using old behavior + raise e + all_recs = await run(aio_connection, "PUT file://{files} @~/{dir}") + assert all([rec[6] == "UPLOADED" for rec in all_recs]) + + for _ in range(60): + for _ in range(100): + all_recs = await run(aio_connection, "LIST @~/{dir}") + if len(all_recs) == number_of_files: + break + # you may not get the files right after PUT command + # due to the nature of gcs blob, which synchronizes + # data eventually. + await asyncio.sleep(1) + else: + # wait for another second and retry. + # this could happen if the files are partially available + # but not all. + await asyncio.sleep(1) + continue + break # success + else: + pytest.fail( + "cannot list all files. 
Potentially " + f"PUT command missed uploading Files: {all_recs}" + ) + all_recs = await run(aio_connection, "GET @~/{dir} file://{output_dir}") + assert len(all_recs) == number_of_files + assert all([rec[2] == "DOWNLOADED" for rec in all_recs]) + finally: + await run(aio_connection, "RM @~/{dir}") + + +@pytest.mark.parametrize("enable_gcs_downscoped", [True]) +async def test_auto_compress_off_gcp( + tmpdir, + aio_connection, + is_public_test, + enable_gcs_downscoped, +): + """[gcp] Puts and Gets a small text using gcp with no auto compression.""" + fname = str( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "../../data", "example.json" + ) + ) + stage_name = random_string(5, "teststage_") + await aio_connection.connect() + cursor = aio_connection.cursor() + try: + await cursor.execute( + f"ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = {enable_gcs_downscoped}" + ) + except ProgrammingError as e: + if enable_gcs_downscoped: + # not raise error when the parameter is not available yet, using old behavior + raise e + try: + await cursor.execute(f"create or replace stage {stage_name}") + await cursor.execute(f"put file://{fname} @{stage_name} auto_compress=false") + await cursor.execute(f"get @{stage_name} file://{tmpdir}") + downloaded_file = os.path.join(str(tmpdir), "example.json") + assert cmp(fname, downloaded_file) + finally: + await cursor.execute(f"drop stage {stage_name}") + + +@pytest.mark.parametrize( + "from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)] +) +async def test_put_overwrite_with_downscope( + tmpdir, + aio_connection, + is_public_test, + from_path, +): + """Tests whether _force_put_overwrite and overwrite=true works as intended.""" + + await aio_connection.connect() + csr = aio_connection.cursor() + tmp_dir = str(tmpdir.mkdir("data")) + test_data = os.path.join(tmp_dir, "data.txt") + stage_dir = f"test_put_overwrite_async_{random_string()}" + with open(test_data, "w") as f: + f.write("test1,test2") + f.write("test3,test4") + + await csr.execute(f"RM @~/{stage_dir}") + try: + file_stream = None if from_path else open(test_data, "rb") + await csr.execute("ALTER SESSION SET GCS_USE_DOWNSCOPED_CREDENTIAL = TRUE") + await put_async( + csr, + test_data, + f"~/{stage_dir}", + from_path, + file_stream=file_stream, + ) + data = await csr.fetchall() + assert data[0][6] == "UPLOADED" + + await put_async( + csr, + test_data, + f"~/{stage_dir}", + from_path, + file_stream=file_stream, + ) + data = await csr.fetchall() + assert data[0][6] == "SKIPPED" + + await put_async( + csr, + test_data, + f"~/{stage_dir}", + from_path, + sql_options="OVERWRITE = TRUE", + file_stream=file_stream, + ) + data = await csr.fetchall() + assert data[0][6] == "UPLOADED" + + ret = await (await csr.execute(f"LS @~/{stage_dir}")).fetchone() + assert f"{stage_dir}/data.txt" in ret[0] + assert "data.txt.gz" in ret[0] + finally: + if file_stream: + file_stream.close() + await csr.execute(f"RM @~/{stage_dir}") diff --git a/test/integ/aio/test_put_windows_path_async.py b/test/integ/aio/test_put_windows_path_async.py new file mode 100644 index 0000000000..5c274706d8 --- /dev/null +++ b/test/integ/aio/test_put_windows_path_async.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import os + + +async def test_abc(conn_cnx, tmpdir, db_parameters): + """Tests PUTing a file on Windows using the URI and Windows path.""" + import pathlib + + tmp_dir = str(tmpdir.mkdir("data")) + test_data = os.path.join(tmp_dir, "data.txt") + with open(test_data, "w") as f: + f.write("test1,test2") + f.write("test3,test4") + + fileURI = pathlib.Path(test_data).as_uri() + + subdir = db_parameters["name"] + async with conn_cnx( + user=db_parameters["user"], + account=db_parameters["account"], + password=db_parameters["password"], + ) as con: + rec = await ( + await con.cursor().execute(f"put {fileURI} @~/{subdir}0/") + ).fetchall() + assert rec[0][6] == "UPLOADED" + + rec = await ( + await con.cursor().execute(f"put file://{test_data} @~/{subdir}1/") + ).fetchall() + assert rec[0][6] == "UPLOADED" + + await con.cursor().execute(f"rm @~/{subdir}0") + await con.cursor().execute(f"rm @~/{subdir}1") diff --git a/test/integ/aio/test_qmark_async.py b/test/integ/aio/test_qmark_async.py new file mode 100644 index 0000000000..71f33b52d1 --- /dev/null +++ b/test/integ/aio/test_qmark_async.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import pytest + +from snowflake.connector import errors + + +async def test_qmark_paramstyle(conn_cnx, db_parameters): + """Tests that binding question marks is not supported by default.""" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa STRING, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES('?', '?')".format(name=db_parameters["name"]) + ) + async for rec in await cnx.cursor().execute( + "SELECT * FROM {name}".format(name=db_parameters["name"]) + ): + assert rec[0] == "?", "First column value" + with pytest.raises(errors.ProgrammingError): + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(?,?)".format( + name=db_parameters["name"] + ) + ) + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "DROP TABLE IF EXISTS {name}".format(name=db_parameters["name"]) + ) + + +async def test_numeric_paramstyle(conn_cnx, db_parameters): + """Tests that binding numeric positional style is not supported.""" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa STRING, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(':1', ':2')".format( + name=db_parameters["name"] + ) + ) + async for rec in await cnx.cursor().execute( + "SELECT * FROM {name}".format(name=db_parameters["name"]) + ): + assert rec[0] == ":1", "First column value" + with pytest.raises(errors.ProgrammingError): + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(:1,:2)".format( + name=db_parameters["name"] + ) + ) + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "DROP TABLE IF EXISTS {name}".format(name=db_parameters["name"]) + ) + + +@pytest.mark.internal +async def test_qmark_paramstyle_enabled(negative_conn_cnx, db_parameters): + """Enable qmark binding.""" + import snowflake.connector + + snowflake.connector.paramstyle = "qmark" + try: + async with negative_conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa STRING, bb STRING)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "INSERT INTO 
{name} VALUES(?, ?)".format(name=db_parameters["name"]), + ("test11", "test12"), + ) + ret = await ( + await cnx.cursor().execute( + "select * from {name}".format(name=db_parameters["name"]) + ) + ).fetchone() + assert ret[0] == "test11" + assert ret[1] == "test12" + finally: + async with negative_conn_cnx() as cnx: + await cnx.cursor().execute( + "DROP TABLE IF EXISTS {name}".format(name=db_parameters["name"]) + ) + snowflake.connector.paramstyle = "pyformat" + + # After changing back to pyformat, binding qmark should fail. + try: + async with negative_conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa STRING, bb STRING)".format(name=db_parameters["name"]) + ) + with pytest.raises(TypeError): + await cnx.cursor().execute( + "INSERT INTO {name} VALUES(?, ?)".format( + name=db_parameters["name"] + ), + ("test11", "test12"), + ) + finally: + async with negative_conn_cnx() as cnx: + await cnx.cursor().execute( + "DROP TABLE IF EXISTS {name}".format(name=db_parameters["name"]) + ) + + +async def test_binding_datetime_qmark(conn_cnx, db_parameters): + """Ensures datetime can bound.""" + import datetime + + import snowflake.connector + + snowflake.connector.paramstyle = "qmark" + try: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "CREATE OR REPLACE TABLE {name} " + "(aa TIMESTAMP_NTZ)".format(name=db_parameters["name"]) + ) + days = 2 + inserts = tuple((datetime.datetime(2018, 1, i + 1),) for i in range(days)) + await cnx.cursor().executemany( + "INSERT INTO {name} VALUES(?)".format(name=db_parameters["name"]), + inserts, + ) + ret = await ( + await cnx.cursor().execute( + "SELECT * FROM {name} ORDER BY 1".format(name=db_parameters["name"]) + ) + ).fetchall() + for i in range(days): + assert ret[i][0] == inserts[i][0] + finally: + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "DROP TABLE IF EXISTS {name}".format(name=db_parameters["name"]) + ) + + +async def test_binding_none(conn_cnx): + import snowflake.connector + + original = snowflake.connector.paramstyle + snowflake.connector.paramstyle = "qmark" + + async with conn_cnx() as con: + try: + table_name = "foo" + await con.cursor().execute(f"CREATE TABLE {table_name}(bar text)") + await con.cursor().execute(f"INSERT INTO {table_name} VALUES (?)", [None]) + finally: + await con.cursor().execute(f"DROP TABLE {table_name}") + snowflake.connector.paramstyle = original diff --git a/test/integ/aio/test_query_cancelling_async.py b/test/integ/aio/test_query_cancelling_async.py new file mode 100644 index 0000000000..72d35d77de --- /dev/null +++ b/test/integ/aio/test_query_cancelling_async.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
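The qmark tests above flip the module-level `snowflake.connector.paramstyle` switch before binding with `?` placeholders and restore it afterwards. A hedged sketch of that pattern outside pytest; the table name, bound values, and `conn_kwargs` are illustrative.

```python
# A sketch only: table name, values, and conn_kwargs are placeholders.
import datetime

import snowflake.connector
import snowflake.connector.aio


async def insert_with_qmark(**conn_kwargs) -> list[tuple]:
    original = snowflake.connector.paramstyle
    snowflake.connector.paramstyle = "qmark"  # module-level switch, as in the tests
    try:
        async with snowflake.connector.aio.SnowflakeConnection(**conn_kwargs) as cnx:
            cur = cnx.cursor()
            await cur.execute("create or replace temp table qmark_demo (ts timestamp_ntz)")
            # executemany() binds one tuple per row through positional ? placeholders.
            await cur.executemany(
                "insert into qmark_demo values(?)",
                [(datetime.datetime(2018, 1, 1),), (datetime.datetime(2018, 1, 2),)],
            )
            return await (
                await cur.execute("select * from qmark_demo order by 1")
            ).fetchall()
    finally:
        snowflake.connector.paramstyle = original  # restore, as test_binding_none does
```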
+# + +from __future__ import annotations + +import asyncio +import logging +from logging import getLogger + +import pytest + +from snowflake.connector import errors + +logger = getLogger(__name__) +logging.basicConfig(level=logging.CRITICAL) + +try: + from ..parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + + +@pytest.fixture() +async def conn_cnx_query_cancelling(request, conn_cnx): + async with conn_cnx() as cnx: + await cnx.cursor().execute("use role securityadmin") + await cnx.cursor().execute( + "create or replace user magicuser1 password='xxx' " "default_role='PUBLIC'" + ) + await cnx.cursor().execute( + "create or replace user magicuser2 password='xxx' " "default_role='PUBLIC'" + ) + + yield conn_cnx + + async with conn_cnx() as cnx: + await cnx.cursor().execute("use role accountadmin") + await cnx.cursor().execute("drop user magicuser1") + await cnx.cursor().execute("drop user magicuser2") + + +async def _query_run(conn, shared, expectedCanceled=True): + """Runs a query, and wait for possible cancellation.""" + async with conn(user="magicuser1", password="xxx") as cnx: + await cnx.cursor().execute("use warehouse regress") + + # Collect the session_id + async with cnx.cursor() as c: + await c.execute("SELECT current_session()") + async for rec in c: + with shared.lock: + shared.session_id = int(rec[0]) + logger.info(f"Current Session id: {shared.session_id}") + + # Run a long query and see if we're canceled + canceled = False + try: + c = cnx.cursor() + await c.execute( + """ +select count(*) from table(generator(timeLimit => 10))""" + ) + except errors.ProgrammingError as e: + logger.info("FAILED TO RUN QUERY: %s", e) + canceled = e.errno == 604 + if not canceled: + logger.exception("must have been canceled") + raise + finally: + await c.close() + + if canceled: + logger.info("Query failed or was canceled") + else: + logger.info("Query finished successfully") + + assert canceled == expectedCanceled + + +async def _query_cancel(conn, shared, user, password, expectedCanceled): + """Tests cancelling the query running in another thread.""" + async with conn(user=user, password=password) as cnx: + await cnx.cursor().execute("use warehouse regress") + # .use_warehouse_database_schema(cnx) + + logger.info( + "User %s's role is: %s", + user, + (await (await cnx.cursor().execute("select current_role()")).fetchone())[0], + ) + # Run the cancel query + logger.info("User %s is waiting for Session ID to be available", user) + while True: + async with shared.lock: + if shared.session_id is not None: + break + logger.info("User %s is waiting for Session ID to be available", user) + await asyncio.sleep(1) + logger.info(f"Target Session id: {shared.session_id}") + try: + query = f"call system$cancel_all_queries({shared.session_id})" + logger.info("Query: %s", query) + await cnx.cursor().execute(query) + assert ( + expectedCanceled + ), "You should NOT be able to " "cancel the query [{}]".format( + shared.session_id + ) + except errors.ProgrammingError as e: + logger.info("FAILED TO CANCEL THE QUERY: %s", e) + assert ( + not expectedCanceled + ), "You should be able to " "cancel the query [{}]".format( + shared.session_id + ) + + +async def _test_helper(conn, expectedCanceled, cancelUser, cancelPass): + """Helper function for the actual tests. + + queryRun is always run with magicuser1/xxx. 
+ queryCancel is run with cancelUser/cancelPass + """ + + class Shared: + def __init__(self): + self.lock = asyncio.Lock() + self.session_id = None + + shared = Shared() + + queryRun = asyncio.create_task(_query_run(conn, shared, expectedCanceled)) + queryCancel = asyncio.create_task( + _query_cancel(conn, shared, cancelUser, cancelPass, expectedCanceled) + ) + await asyncio.gather(queryRun, queryCancel) + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." +) +async def test_same_user_canceling(conn_cnx_query_cancelling): + """Tests that the same user CAN cancel his own query.""" + await _test_helper(conn_cnx_query_cancelling, True, "magicuser1", "xxx") + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." +) +async def test_other_user_canceling(conn_cnx_query_cancelling): + """Tests that the other user CAN NOT cancel his own query.""" + await _test_helper(conn_cnx_query_cancelling, False, "magicuser2", "xxx") diff --git a/test/integ/aio/test_results_async.py b/test/integ/aio/test_results_async.py new file mode 100644 index 0000000000..09aad67802 --- /dev/null +++ b/test/integ/aio/test_results_async.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import pytest + +from snowflake.connector import ProgrammingError + + +async def test_results(conn_cnx): + """Gets results for the given qid.""" + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute("select * from values(1,2),(3,4)") + sfqid = cur.sfqid + cur = await cur.query_result(sfqid) + got_sfqid = cur.sfqid + assert await cur.fetchall() == [(1, 2), (3, 4)] + assert sfqid == got_sfqid + + +async def test_results_with_error(conn_cnx): + """Gets results with error.""" + async with conn_cnx() as cnx: + cur = cnx.cursor() + with pytest.raises(ProgrammingError) as e: + await cur.execute("select blah") + sfqid = e.value.sfqid + + with pytest.raises(ProgrammingError) as e: + await cur.query_result(sfqid) + got_sfqid = e.value.sfqid + + assert sfqid is not None + assert got_sfqid is not None + assert got_sfqid == sfqid diff --git a/test/integ/aio/test_reuse_cursor_async.py b/test/integ/aio/test_reuse_cursor_async.py new file mode 100644 index 0000000000..db6aa41aff --- /dev/null +++ b/test/integ/aio/test_reuse_cursor_async.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
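`test_results` above relies on `cursor.sfqid` and `cursor.query_result()` to re-attach to a finished query. A small sketch of that flow, with `conn_kwargs` as a placeholder:

```python
# A sketch only: conn_kwargs is a placeholder.
from snowflake.connector.aio import SnowflakeConnection


async def refetch_by_query_id(**conn_kwargs) -> list[tuple]:
    async with SnowflakeConnection(**conn_kwargs) as cnx:
        cur = cnx.cursor()
        await cur.execute("select * from values(1,2),(3,4)")
        qid = cur.sfqid                     # id of the query that just ran
        cur2 = await cur.query_result(qid)  # cursor re-attached to that result set
        return await cur2.fetchall()        # [(1, 2), (3, 4)]
```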
+# + + +async def test_reuse_cursor(conn_cnx, db_parameters): + """Ensures only the last executed command/query's result sets are returned.""" + async with conn_cnx() as cnx: + c = cnx.cursor() + await c.execute( + "create or replace table {name}(c1 string)".format( + name=db_parameters["name"] + ) + ) + try: + await c.execute( + "insert into {name} values('123'),('456'),('678')".format( + name=db_parameters["name"] + ) + ) + await c.execute("show tables") + await c.execute("select current_date()") + rec = await c.fetchone() + assert len(rec) == 1, "number of records is wrong" + await c.execute( + "select * from {name} order by 1".format(name=db_parameters["name"]) + ) + recs = await c.fetchall() + assert c.description[0][0] == "C1", "fisrt column name" + assert len(recs) == 3, "number of records is wrong" + finally: + await c.execute( + "drop table if exists {name}".format(name=db_parameters["name"]) + ) diff --git a/test/integ/aio/test_session_parameters_async.py b/test/integ/aio/test_session_parameters_async.py new file mode 100644 index 0000000000..8a291ec0c7 --- /dev/null +++ b/test/integ/aio/test_session_parameters_async.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import pytest + +import snowflake.connector.aio +from snowflake.connector.util_text import random_string + +try: # pragma: no cover + from ..parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + + +async def test_session_parameters(db_parameters): + """Sets the session parameters in connection time.""" + async with snowflake.connector.aio.SnowflakeConnection( + protocol=db_parameters["protocol"], + account=db_parameters["account"], + user=db_parameters["user"], + password=db_parameters["password"], + host=db_parameters["host"], + port=db_parameters["port"], + database=db_parameters["database"], + schema=db_parameters["schema"], + session_parameters={"TIMEZONE": "UTC"}, + ) as connection: + ret = await ( + await connection.cursor().execute("show parameters like 'TIMEZONE'") + ).fetchone() + assert ret[1] == "UTC" + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, + reason="Snowflake admin required to setup parameter.", +) +async def test_client_session_keep_alive(db_parameters, conn_cnx): + """Tests client_session_keep_alive setting. + + Ensures that client's explicit config for client_session_keep_alive + session parameter is always honored and given higher precedence over + user and account level backend configuration. 
+ """ + admin_cnxn = snowflake.connector.aio.SnowflakeConnection( + protocol=db_parameters["sf_protocol"], + account=db_parameters["sf_account"], + user=db_parameters["sf_user"], + password=db_parameters["sf_password"], + host=db_parameters["sf_host"], + port=db_parameters["sf_port"], + ) + await admin_cnxn.connect() + + # Ensure backend parameter is set to False + await set_backend_client_session_keep_alive(db_parameters, admin_cnxn, False) + async with conn_cnx(client_session_keep_alive=True) as connection: + ret = await ( + await connection.cursor() + .execute("show parameters like 'CLIENT_SESSION_KEEP_ALIVE'") + .fetchone() + ) + assert ret[1] == "true" + + # Set backend parameter to True + await set_backend_client_session_keep_alive(db_parameters, admin_cnxn, True) + + # Set session parameter to False + async with conn_cnx(client_session_keep_alive=False) as connection: + ret = await ( + await connection.cursor() + .execute("show parameters like 'CLIENT_SESSION_KEEP_ALIVE'") + .fetchone() + ) + assert ret[1] == "false" + + # Set session parameter to None backend parameter continues to be True + async with conn_cnx(client_session_keep_alive=None) as connection: + ret = await ( + await connection.cursor() + .execute("show parameters like 'CLIENT_SESSION_KEEP_ALIVE'") + .fetchone() + ) + assert ret[1] == "true" + + await admin_cnxn.close() + + +async def set_backend_client_session_keep_alive( + db_parameters: object, admin_cnx: object, val: bool +) -> None: + """Set both at Account level and User level.""" + query = "alter account {} set CLIENT_SESSION_KEEP_ALIVE={}".format( + db_parameters["account"], str(val) + ) + await admin_cnx.cursor().execute(query) + + query = "alter user {}.{} set CLIENT_SESSION_KEEP_ALIVE={}".format( + db_parameters["account"], db_parameters["user"], str(val) + ) + await admin_cnx.cursor().execute(query) + + +@pytest.mark.internal +async def test_htap_optimizations(db_parameters: object, conn_cnx) -> None: + random_prefix = random_string(5, "test_prefix").lower() + test_wh = f"{random_prefix}_wh" + test_db = f"{random_prefix}_db" + test_schema = f"{random_prefix}_schema" + + async with conn_cnx("admin") as admin_cnx: + try: + await admin_cnx.cursor().execute( + f"CREATE WAREHOUSE IF NOT EXISTS {test_wh}" + ) + await admin_cnx.cursor().execute(f"USE WAREHOUSE {test_wh}") + await admin_cnx.cursor().execute(f"CREATE DATABASE IF NOT EXISTS {test_db}") + await admin_cnx.cursor().execute( + f"CREATE SCHEMA IF NOT EXISTS {test_schema}" + ) + query = f"alter account {db_parameters['sf_account']} set ENABLE_SNOW_654741_FOR_TESTING=true" + await admin_cnx.cursor().execute(query) + + # assert wh, db, schema match conn params + assert admin_cnx._warehouse.lower() == test_wh + assert admin_cnx._database.lower() == test_db + assert admin_cnx._schema.lower() == test_schema + + # alter session set TIMESTAMP_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FFTZH' + await admin_cnx.cursor().execute( + "alter session set TIMESTAMP_OUTPUT_FORMAT='YYYY-MM-DD HH24:MI:SS.FFTZH'" + ) + + # create or replace table + await admin_cnx.cursor().execute( + "create or replace temp table testtable1 (cola string, colb int)" + ) + # insert into table 3 vals + await admin_cnx.cursor().execute( + "insert into testtable1 values ('row1', 1), ('row2', 2), ('row3', 3)" + ) + # select * from table + ret = await ( + await admin_cnx.cursor().execute("select * from testtable1") + ).fetchall() + # assert we get 3 results + assert len(ret) == 3 + + # assert wh, db, schema + assert admin_cnx._warehouse.lower() == 
test_wh + assert admin_cnx._database.lower() == test_db + assert admin_cnx._schema.lower() == test_schema + + assert ( + admin_cnx._session_parameters["TIMESTAMP_OUTPUT_FORMAT"] + == "YYYY-MM-DD HH24:MI:SS.FFTZH" + ) + + # alter session unset TIMESTAMP_OUTPUT_FORMAT + await admin_cnx.cursor().execute( + "alter session unset TIMESTAMP_OUTPUT_FORMAT" + ) + finally: + # alter account unset ENABLE_SNOW_654741_FOR_TESTING + query = f"alter account {db_parameters['sf_account']} unset ENABLE_SNOW_654741_FOR_TESTING" + await admin_cnx.cursor().execute(query) + await admin_cnx.cursor().execute(f"DROP SCHEMA IF EXISTS {test_schema}") + await admin_cnx.cursor().execute(f"DROP DATABASE IF EXISTS {test_db}") + await admin_cnx.cursor().execute(f"DROP WAREHOUSE IF EXISTS {test_wh}") diff --git a/test/integ/aio/test_statement_parameter_binding_async.py b/test/integ/aio/test_statement_parameter_binding_async.py new file mode 100644 index 0000000000..da83f87939 --- /dev/null +++ b/test/integ/aio/test_statement_parameter_binding_async.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from datetime import datetime + +import pytest +import pytz + +try: + from parameters import CONNECTION_PARAMETERS_ADMIN +except ImportError: + CONNECTION_PARAMETERS_ADMIN = {} + + +@pytest.mark.skipif( + not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible." +) +async def test_binding_security(conn_cnx): + """Tests binding statement parameters.""" + expected_qa_mode_datetime = datetime(1967, 6, 23, 7, 0, 0, 123000, pytz.UTC) + + async with conn_cnx() as cnx: + await cnx.cursor().execute("alter session set timezone='UTC'") + async with cnx.cursor() as cur: + await cur.execute("show databases like 'TESTDB'") + rec = await cur.fetchone() + assert rec[0] != expected_qa_mode_datetime + + async with cnx.cursor() as cur: + await cur.execute( + "show databases like 'TESTDB'", + _statement_params={ + "QA_MODE": True, + }, + ) + rec = await cur.fetchone() + assert rec[0] == expected_qa_mode_datetime + + async with cnx.cursor() as cur: + await cur.execute("show databases like 'TESTDB'") + rec = await cur.fetchone() + assert rec[0] != expected_qa_mode_datetime diff --git a/test/integ/aio/test_structured_types_async.py b/test/integ/aio/test_structured_types_async.py new file mode 100644 index 0000000000..33a05bfeaa --- /dev/null +++ b/test/integ/aio/test_structured_types_async.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
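The session-parameter tests earlier in this change set `session_parameters` at connect time, while `test_binding_security` above passes `_statement_params` on a single `execute()` call. A sketch combining both scopes; `QUERY_TAG` is used only as an illustrative statement-level parameter and `conn_kwargs` is a placeholder.

```python
# A sketch only: QUERY_TAG is illustrative; conn_kwargs is a placeholder.
from snowflake.connector.aio import SnowflakeConnection


async def show_timezone(**conn_kwargs) -> str:
    async with SnowflakeConnection(
        session_parameters={"TIMEZONE": "UTC"}, **conn_kwargs
    ) as cnx:
        async with cnx.cursor() as cur:
            # Statement-level parameters ride along with a single execute() call.
            await cur.execute(
                "show parameters like 'TIMEZONE'",
                _statement_params={"QUERY_TAG": "aio-demo"},
            )
            row = await cur.fetchone()
            return row[1]  # the value column of SHOW PARAMETERS, "UTC" here
```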
+# +from __future__ import annotations + +from textwrap import dedent + +import pytest + + +async def test_structured_array_types(conn_cnx): + async with conn_cnx() as cnx: + cur = cnx.cursor() + sql = dedent( + """select + [1, 2]::array(int), + [1.1::float, 1.2::float]::array(float), + ['a', 'b']::array(string not null), + [current_timestamp(), current_timestamp()]::array(timestamp), + [current_timestamp()::timestamp_ltz, current_timestamp()::timestamp_ltz]::array(timestamp_ltz), + [current_timestamp()::timestamp_tz, current_timestamp()::timestamp_tz]::array(timestamp_tz), + [current_timestamp()::timestamp_ntz, current_timestamp()::timestamp_ntz]::array(timestamp_ntz), + [current_date(), current_date()]::array(date), + [current_time(), current_time()]::array(time), + [True, False]::array(boolean), + [1::variant, 'b'::variant]::array(variant not null), + [{'a': 'b'}, {'c': 1}]::array(object) + """ + ) + # Geography and geometry are not supported in an array + # [TO_GEOGRAPHY('POINT(-122.35 37.55)'), TO_GEOGRAPHY('POINT(-123.35 37.55)')]::array(GEOGRAPHY), + # [TO_GEOMETRY('POINT(1820.12 890.56)'), TO_GEOMETRY('POINT(1820.12 890.56)')]::array(GEOMETRY), + await cur.execute(sql) + for metadata in cur.description: + assert metadata.type_code == 10 # same as a regular array + for metadata in await cur.describe(sql): + assert metadata.type_code == 10 + + +@pytest.mark.xfail( + reason="SNOW-1305289: Param difference in aws environment", strict=False +) +async def test_structured_map_types(conn_cnx): + async with conn_cnx() as cnx: + cur = cnx.cursor() + sql = dedent( + """select + {'a': 1}::map(string, variant), + {'a': 1.1::float}::map(string, float), + {'a': 'b'}::map(string, string), + {'a': current_timestamp()}::map(string, timestamp), + {'a': current_timestamp()::timestamp_ltz}::map(string, timestamp_ltz), + {'a': current_timestamp()::timestamp_ntz}::map(string, timestamp_ntz), + {'a': current_timestamp()::timestamp_tz}::map(string, timestamp_tz), + {'a': current_date()}::map(string, date), + {'a': current_time()}::map(string, time), + {'a': False}::map(string, boolean), + {'a': 'b'::variant}::map(string, variant not null), + {'a': {'c': 1}}::map(string, object) + """ + ) + await cur.execute(sql) + for metadata in cur.description: + assert metadata.type_code == 9 # same as a regular object + for metadata in await cur.describe(sql): + assert metadata.type_code == 9 diff --git a/test/integ/aio/test_transaction_async.py b/test/integ/aio/test_transaction_async.py new file mode 100644 index 0000000000..487c9c6d84 --- /dev/null +++ b/test/integ/aio/test_transaction_async.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
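The structured-type tests above read column metadata both from `cursor.description` (populated after `execute()`) and from `await cursor.describe(sql)`. A short sketch of the metadata-only path, assuming a placeholder `conn_kwargs`:

```python
# A sketch only: conn_kwargs is a placeholder.
from snowflake.connector.aio import SnowflakeConnection


async def column_type_codes(sql: str, **conn_kwargs) -> list[int]:
    async with SnowflakeConnection(**conn_kwargs) as cnx:
        cur = cnx.cursor()
        meta = await cur.describe(sql)      # ResultMetadata entries, no rows fetched
        return [m.type_code for m in meta]  # e.g. 10 for arrays, 9 for objects/maps
```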
+# + +from __future__ import annotations + +import snowflake.connector.aio + + +async def test_transaction(conn_cnx, db_parameters): + """Tests transaction API.""" + async with conn_cnx() as cnx: + await cnx.cursor().execute( + "create table {name} (c1 int)".format(name=db_parameters["name"]) + ) + await cnx.cursor().execute( + "insert into {name}(c1) " + "values(1234),(3456)".format(name=db_parameters["name"]) + ) + c = cnx.cursor() + await c.execute("select * from {name}".format(name=db_parameters["name"])) + total = 0 + async for rec in c: + total += rec[0] + assert total == 4690, "total integer" + + # + await cnx.cursor().execute("begin") + await cnx.cursor().execute( + "insert into {name}(c1) values(5678),(7890)".format( + name=db_parameters["name"] + ) + ) + c = cnx.cursor() + await c.execute("select * from {name}".format(name=db_parameters["name"])) + total = 0 + async for rec in c: + total += rec[0] + assert total == 18258, "total integer" + await cnx.rollback() + + await c.execute("select * from {name}".format(name=db_parameters["name"])) + total = 0 + async for rec in c: + total += rec[0] + assert total == 4690, "total integer" + + # + await cnx.cursor().execute("begin") + await cnx.cursor().execute( + "insert into {name}(c1) values(2345),(6789)".format( + name=db_parameters["name"] + ) + ) + c = cnx.cursor() + await c.execute("select * from {name}".format(name=db_parameters["name"])) + total = 0 + async for rec in c: + total += rec[0] + assert total == 13824, "total integer" + await cnx.commit() + await cnx.rollback() + c = cnx.cursor() + await c.execute("select * from {name}".format(name=db_parameters["name"])) + total = 0 + async for rec in c: + total += rec[0] + assert total == 13824, "total integer" + + +async def test_connection_context_manager(request, db_parameters): + db_config = { + "protocol": db_parameters["protocol"], + "account": db_parameters["account"], + "user": db_parameters["user"], + "password": db_parameters["password"], + "host": db_parameters["host"], + "port": db_parameters["port"], + "database": db_parameters["database"], + "schema": db_parameters["schema"], + "timezone": "UTC", + } + + async def fin(): + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as cnx: + await cnx.cursor().execute( + """ +DROP TABLE IF EXISTS {name} +""".format( + name=db_parameters["name"] + ) + ) + + try: + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as cnx: + await cnx.autocommit(False) + await cnx.cursor().execute( + """ +CREATE OR REPLACE TABLE {name} (cc1 int) +""".format( + name=db_parameters["name"] + ) + ) + await cnx.cursor().execute( + """ +INSERT INTO {name} VALUES(1),(2),(3) +""".format( + name=db_parameters["name"] + ) + ) + ret = await ( + await cnx.cursor().execute( + """ +SELECT SUM(cc1) FROM {name} +""".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + assert ret[0] == 6 + await cnx.commit() + await cnx.cursor().execute( + """ +INSERT INTO {name} VALUES(4),(5),(6) +""".format( + name=db_parameters["name"] + ) + ) + ret = await ( + await cnx.cursor().execute( + """ +SELECT SUM(cc1) FROM {name} +""".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + assert ret[0] == 21 + await cnx.cursor().execute( + """ +SELECT WRONG SYNTAX QUERY +""" + ) + raise Exception("Failed to cause the syntax error") + except snowflake.connector.Error: + # syntax error should be caught here + # and the last change must have been rollbacked + async with snowflake.connector.aio.SnowflakeConnection(**db_config) as cnx: + ret = 
await ( + await cnx.cursor().execute( + """ +SELECT SUM(cc1) FROM {name} +""".format( + name=db_parameters["name"] + ) + ) + ).fetchone() + assert ret[0] == 6 + yield + await fin() diff --git a/test/integ_helpers.py b/test/integ_helpers.py index cf9e0c9642..d4e32a4e50 100644 --- a/test/integ_helpers.py +++ b/test/integ_helpers.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: # pragma: no cover + from snowflake.connector.aio._cursor import SnowflakeCursor as SnowflakeCursorAsync from snowflake.connector.cursor import SnowflakeCursor @@ -45,3 +46,38 @@ def put( file=file_path.replace("\\", "\\\\"), stage=stage_path, sql_options=sql_options ) return csr.execute(sql, **kwargs) + + +async def put_async( + csr: SnowflakeCursorAsync, + file_path: str, + stage_path: str, + from_path: bool, + sql_options: str | None = "", + **kwargs, +) -> SnowflakeCursorAsync: + """Execute PUT query with given cursor. + + Args: + csr: Snowflake cursor object. + file_path: Path to the target file in local system; Or . when from_path is False. + stage_path: Destination path of file on the stage. + from_path: Whether the target file is fetched with given path, specify file_stream= if False. + sql_options: Optional arguments to the PUT command. + **kwargs: Optional arguments passed to SnowflakeCursor.execute() + + Returns: + A result class with the results in it. This can either be json, or an arrow result class. + """ + sql = "put 'file://{file}' @{stage} {sql_options}" + if from_path: + kwargs.pop("file_stream", None) + else: + # PUT from stream + file_path = os.path.basename(file_path) + if kwargs.pop("commented", False): + sql = "--- test comments\n" + sql + sql = sql.format( + file=file_path.replace("\\", "\\\\"), stage=stage_path, sql_options=sql_options + ) + return await csr.execute(sql, **kwargs) diff --git a/test/stress/aio/README.md b/test/stress/aio/README.md new file mode 100644 index 0000000000..881f8613e1 --- /dev/null +++ b/test/stress/aio/README.md @@ -0,0 +1,21 @@ +## quick start for performance testing + + +### setup + +note: you need to put your own credentials into parameters.py + +```bash +git clone git@github.com:snowflakedb/snowflake-connector-python.git +cd snowflake-connector-python/test/stress +pip install -r dev_requirements.txt +touch parameters.py # set your own connection parameters +``` + +### run e2e perf test + +This test will run query against snowflake. update the script to prepare the data and run the test. + +```python +python e2e_iterator.py +``` diff --git a/test/stress/aio/__init__.py b/test/stress/aio/__init__.py new file mode 100644 index 0000000000..ef416f64a0 --- /dev/null +++ b/test/stress/aio/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/test/stress/aio/dev_requirements.txt b/test/stress/aio/dev_requirements.txt new file mode 100644 index 0000000000..b09f51fa8d --- /dev/null +++ b/test/stress/aio/dev_requirements.txt @@ -0,0 +1,6 @@ +psutil +../.. +matplotlib +aiohttp +pandas +asyncio diff --git a/test/stress/aio/e2e_iterator.py b/test/stress/aio/e2e_iterator.py new file mode 100644 index 0000000000..7bb9b51674 --- /dev/null +++ b/test/stress/aio/e2e_iterator.py @@ -0,0 +1,446 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +""" +This script is used for end-to-end performance test for asyncio python connector. + +1. 
select and consume rows of different types for 3 hr, (very large amount of data 10m rows) + + - goal: timeout/retry/refresh token + - fetch_one/fetch_many/fetch_pandas_batches + - validate the fetched data is accurate + +2. put file + - many small files + - one large file + - verify files(etc. file amount, sha256 signature) + +3. get file + - many small files + - one large file + - verify files (etc. file amount, sha256 signature) +""" + +import argparse +import asyncio +import csv +import datetime +import gzip +import hashlib +import os.path +import random +import secrets +import string +from decimal import Decimal + +import pandas as pd +import pytz +import util as stress_util +from util import task_decorator + +from parameters import CONNECTION_PARAMETERS +from snowflake.connector.aio import SnowflakeConnection + +stress_util.print_to_console = False +can_draw = True +try: + import matplotlib.pyplot as plt +except ImportError: + print("graphs can not be drawn as matplotlib is not installed.") + can_draw = False + +expected_row = ( + 123456, + b"HELP", + True, + "a", + "b", + datetime.date(2023, 7, 18), + datetime.datetime(2023, 7, 18, 12, 51), + Decimal("984.280"), + Decimal("268.350"), + 123.456, + 738.132, + 6789, + 23456, + 12583, + 513.431, + 10, + 9, + "abc456", + "def123", + datetime.time(12, 34, 56), + datetime.datetime(2021, 1, 1, 0, 0), + datetime.datetime(2021, 1, 1, 0, 0, tzinfo=pytz.UTC), + datetime.datetime.strptime( + "2021-01-01 00:00:00 +0000", "%Y-%m-%d %H:%M:%S %z" + ).astimezone(pytz.timezone("America/Los_Angeles")), + datetime.datetime(2021, 1, 1, 0, 0), + 1, + b"HELP", + "vxlmls!21321#@!#!", +) + +expected_pandas = ( + 123456, + b"HELP", + True, + "a", + "b", + datetime.date(2023, 7, 18), + datetime.datetime(2023, 7, 18, 12, 51), + Decimal("984.28"), + Decimal("268.35"), + 123.456, + 738.132, + 6789, + 23456, + 12583, + 513.431, + 10, + 9, + "abc456", + "def123", + datetime.time(12, 34, 56), + datetime.datetime(2021, 1, 1, 0, 0), + datetime.datetime.strptime("2020-12-31 16:00:00 -0800", "%Y-%m-%d %H:%M:%S %z"), + datetime.datetime.strptime( + "2021-01-01 00:00:00 +0000", "%Y-%m-%d %H:%M:%S %z" + ).astimezone(pytz.timezone("America/Los_Angeles")), + datetime.datetime(2021, 1, 1, 0, 0), + 1, + b"HELP", + "vxlmls!21321#@!#!", +) +expected_pandas = pd.DataFrame( + [expected_pandas], + columns=[ + "C1", + "C2", + "C3", + "C4", + "C5", + "C6", + "C7", + "C8", + "C9", + "C10", + "C11", + "C12", + "C13", + "C14", + "C15", + "C16", + "C17", + "C18", + "C19", + "C20", + "C21", + "C22", + "C23", + "C24", + "C25", + "C26", + "C27", + ], +) + + +async def prepare_data(cursor, row_count=100, test_table_name="TEMP_ARROW_TEST_TABLE"): + await cursor.execute( + f"""\ +CREATE OR REPLACE TEMP TABLE {test_table_name} ( + C1 BIGINT, C2 BINARY, C3 BOOLEAN, C4 CHAR, C5 CHARACTER, C6 DATE, C7 DATETIME, C8 DEC(12,3), + C9 DECIMAL(12,3), C10 DOUBLE, C11 FLOAT, C12 INT, C13 INTEGER, C14 NUMBER, C15 REAL, C16 BYTEINT, + C17 SMALLINT, C18 STRING, C19 TEXT, C20 TIME, C21 TIMESTAMP, C22 TIMESTAMP_TZ, C23 TIMESTAMP_LTZ, + C24 TIMESTAMP_NTZ, C25 TINYINT, C26 VARBINARY, C27 VARCHAR); +""" + ) + + for _ in range(row_count): + await cursor.execute( + f"""\ +INSERT INTO {test_table_name} SELECT + 123456, + TO_BINARY('HELP', 'UTF-8'), + TRUE, + 'a', + 'b', + '2023-07-18', + '2023-07-18 12:51:00', + 984.28, + 268.35, + 123.456, + 738.132, + 6789, + 23456, + 12583, + 513.431, + 10, + 9, + 'abc456', + 'def123', + '12:34:56', + '2021-01-01 00:00:00 +0000', + '2021-01-01 00:00:00 +0000', + '2021-01-01 
00:00:00 +0000', + '2021-01-01 00:00:00 +0000', + 1, + TO_BINARY('HELP', 'UTF-8'), + 'vxlmls!21321#@!#!' +; +""" + ) + + +def data_generator(): + return { + "C1": random.randint(-1_000_000, 1_000_000), + "C2": secrets.token_bytes(4), + "C3": random.choice([True, False]), + "C4": random.choice(string.ascii_letters), + "C5": random.choice(string.ascii_letters), + "C6": datetime.date.today().isoformat(), + "C7": datetime.datetime.now().isoformat(), + "C8": round(random.uniform(-1_000, 1_000), 3), + "C9": round(random.uniform(-1_000, 1_000), 3), + "C10": random.uniform(-1_000, 1_000), + "C11": random.uniform(-1_000, 1_000), + "C12": random.randint(-1_000_000, 1_000_000), + "C13": random.randint(-1_000_000, 1_000_000), + "C14": random.randint(-1_000_000, 1_000_000), + "C15": random.uniform(-1_000, 1_000), + "C16": random.randint(-128, 127), + "C17": random.randint(-32_768, 32_767), + "C18": "".join(random.choices(string.ascii_letters + string.digits, k=8)), + "C19": "".join(random.choices(string.ascii_letters + string.digits, k=10)), + "C20": datetime.datetime.now().time().isoformat(), + "C21": datetime.datetime.now().isoformat() + " +00:00", + "C22": datetime.datetime.now().isoformat() + " +00:00", + "C23": datetime.datetime.now().isoformat() + " +00:00", + "C24": datetime.datetime.now().isoformat() + " +00:00", + "C25": random.randint(0, 255), + "C26": secrets.token_bytes(4), + "C27": "".join( + random.choices(string.ascii_letters + string.digits, k=12) + ), # VARCHAR + } + + +async def prepare_file(cursor, stage_location): + if not os.path.exists("../stress_test_data/single_chunk_file_1.csv"): + with open("../stress_test_data/single_chunk_file_1.csv", "w") as f: + d = data_generator() + writer = csv.writer(f) + writer.writerow(d.keys()) + writer.writerow(d.values()) + if not os.path.exists("../stress_test_data/single_chunk_file_2.csv"): + with open("../stress_test_data/single_chunk_file_2.csv", "w") as f: + d = data_generator() + writer = csv.writer(f) + writer.writerow(d.keys()) + writer.writerow(d.values()) + if not os.path.exists("../stress_test_data/multiple_chunks_file_1.csv"): + with open("../stress_test_data/multiple_chunks_file_1.csv", "w") as f: + writer = csv.writer(f) + d = data_generator() + writer.writerow(d.keys()) + for _ in range(2000000): + writer.writerow(data_generator().values()) + if not os.path.exists("../stress_test_data/multiple_chunks_file_2.csv"): + with open("../stress_test_data/multiple_chunks_file_2.csv", "w") as f: + writer = csv.writer(f) + d = data_generator() + writer.writerow(d.keys()) + for _ in range(2000000): + writer.writerow(data_generator().values()) + res = await cursor.execute( + f"PUT file://../stress_test_data/multiple_chunks_file_* {stage_location} OVERWRITE = TRUE" + ) + print(f"test file uploaded to {stage_location}", await res.fetchall()) + await cursor.execute( + f"PUT file://../stress_test_data/single_chunk_file_* {stage_location} OVERWRITE = TRUE" + ) + print(f"test file uploaded to {stage_location}", await res.fetchall()) + + +async def task_fetch_one_row(cursor, table_name, row_count_limit=50000): + res = await cursor.execute(f"select * from {table_name} limit {row_count_limit}") + + for _ in range(row_count_limit): + ret = await res.fetchone() + print("task_fetch_one_row done, result: ", ret) + assert ret == expected_row + + +async def task_fetch_rows(cursor, table_name, row_count_limit=50000): + ret = await ( + await cursor.execute(f"select * from {table_name} limit {row_count_limit}") + ).fetchall() + print("task_fetch_rows done, 
result: ", ret) + print(ret[0]) + assert ret[0] == expected_row + + +async def task_fetch_arrow_batches(cursor, table_name, row_count_limit=50000): + ret = await ( + await cursor.execute(f"select * from {table_name} limit {2}") + ).fetch_arrow_batches() + print("fetch_arrow_batches done, result: ", ret) + async for a in ret: + assert a.to_pandas().iloc[0].to_string(index=False) == expected_pandas.iloc[ + 0 + ].to_string(index=False) + + +async def put_file(cursor, stage_location, is_multiple, is_multi_chunk_file): + file_name = "multiple_chunks_file_" if is_multi_chunk_file else "single_chunk_file_" + source_file = ( + f"file://../stress_test_data/{file_name}*" + if is_multiple + else f"file://../stress_test_data/{file_name}1.csv" + ) + sql = f"PUT {source_file} {stage_location} OVERWRITE = TRUE" + res = await cursor.execute(sql) + print("put_file done, result: ", await res.fetchall()) + + +async def get_file(cursor, stage_location, is_multiple, is_multi_chunk_file): + file_name = "multiple_chunks_file_" if is_multi_chunk_file else "single_chunk_file_" + stage_file = ( + f"{stage_location}" if is_multiple else f"{stage_location}{file_name}1.csv" + ) + sql = ( + f"GET {stage_file} file://../stress_test_data/ PATTERN = '.*{file_name}.*'" + if is_multiple + else f"GET {stage_file} file://../stress_test_data/" + ) + res = await cursor.execute(sql) + print("get_file done, result: ", await res.fetchall()) + hash_downloaded = hashlib.md5() + hash_original = hashlib.md5() + with gzip.open(f"../stress_test_data/{file_name}1.csv.gz", "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_downloaded.update(chunk) + with open(f"../stress_test_data/{file_name}1.csv", "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_original.update(chunk) + assert hash_downloaded.hexdigest() == hash_original.hexdigest() + + +async def async_wrapper(args): + conn = SnowflakeConnection( + user=CONNECTION_PARAMETERS["user"], + password=CONNECTION_PARAMETERS["password"], + host=CONNECTION_PARAMETERS["host"], + account=CONNECTION_PARAMETERS["account"], + database=CONNECTION_PARAMETERS["database"], + schema=CONNECTION_PARAMETERS["schema"], + warehouse=CONNECTION_PARAMETERS["warehouse"], + ) + await conn.connect() + cursor = conn.cursor() + + # prepare file + await prepare_file(cursor, args.stage_location) + await prepare_data(cursor, args.row_count, args.test_table_name) + + perf_record_file = "stress_perf_record" + memory_record_file = "stress_memory_record" + with open(perf_record_file, "w") as perf_file, open( + memory_record_file, "w" + ) as memory_file: + with task_decorator(perf_file, memory_file): + for _ in range(args.iteration_cnt): + if args.test_function == "FETCH_ONE_ROW": + await task_fetch_one_row(cursor, args.test_table_name) + if args.test_function == "FETCH_ROWS": + await task_fetch_rows(cursor, args.test_table_name) + if args.test_function == "FETCH_ARROW_BATCHES": + await task_fetch_arrow_batches(cursor, args.test_table_name) + if args.test_function == "GET_FILE": + await get_file( + cursor, + args.stage_location, + args.is_multiple_file, + args.is_multiple_chunks_file, + ) + if args.test_function == "PUT_FILE": + await put_file( + cursor, + args.stage_location, + args.is_multiple_file, + args.is_multiple_chunks_file, + ) + + if can_draw: + with open(perf_record_file) as perf_file, open( + memory_record_file + ) as memory_file: + # sample rate + perf_lines = perf_file.readlines() + perf_records = [float(line) for line in perf_lines] + + memory_lines = memory_file.readlines() + 
memory_records = [float(line) for line in memory_lines] + + plt.plot([i for i in range(len(perf_records))], perf_records) + plt.title("per iteration execution time") + plt.show(block=False) + plt.figure() + plt.plot([i for i in range(len(memory_records))], memory_records) + plt.title("memory usage") + plt.show(block=True) + + await conn.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--iteration_cnt", + type=int, + default=5000, + help="how many times to run the test function, default is 5000", + ) + parser.add_argument( + "--row_count", + type=int, + default=100, + help="how man rows of data to insert into the temp test able if test_table_name is not provided", + ) + parser.add_argument( + "--test_table_name", + type=str, + default="ARROW_TEST_TABLE", + help="an existing test table that has data prepared, by default the it looks for 'ARROW_TEST_TABLE'", + ) + parser.add_argument( + "--test_function", + type=str, + default="FETCH_ARROW_BATCHES", + help="function to test, by default it is 'FETCH_ONE_ROW', it can also be 'FETCH_ROWS', 'FETCH_ARROW_BATCHES', 'GET_FILE', 'PUT_FILE'", + ) + parser.add_argument( + "--stage_location", + type=str, + default="", + help="stage location used to store files, example: '@test_stage/'", + required=True, + ) + parser.add_argument( + "--is_multiple_file", + type=str, + default=True, + help="transfer multiple file in get or put", + ) + parser.add_argument( + "--is_multiple_chunks_file", + type=str, + default=True, + help="transfer multiple chunks file in get or put", + ) + args = parser.parse_args() + + asyncio.run(async_wrapper(args)) diff --git a/test/stress/aio/util.py b/test/stress/aio/util.py new file mode 100644 index 0000000000..ee961b24ab --- /dev/null +++ b/test/stress/aio/util.py @@ -0,0 +1,31 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import time +from contextlib import contextmanager + +import psutil + +process = psutil.Process() + +SAMPLE_RATE = 10 # record data evey SAMPLE_RATE execution + + +@contextmanager +def task_decorator(perf_file, memory_file): + count = 0 + + start = time.time() + yield + memory_usage = ( + process.memory_info().rss / 1024 / 1024 + ) # rss is of unit bytes, we get unit in MB + period = time.time() - start + if count % SAMPLE_RATE == 0: + perf_file.write(str(period) + "\n") + print(f"execution time {count}") + print(f"memory usage: {memory_usage} MB") + print(f"execution time: {period} s") + memory_file.write(str(memory_usage) + "\n") + count += 1 diff --git a/test/unit/aio/mock_utils.py b/test/unit/aio/mock_utils.py new file mode 100644 index 0000000000..5341904dfe --- /dev/null +++ b/test/unit/aio/mock_utils.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
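A sketch of how the `task_decorator` helper above is driven; it mirrors the usage in `e2e_iterator.py`, with the record-file names and the measured coroutine as placeholders.

```python
# A sketch only: file names and the measured coroutine are placeholders.
from util import task_decorator  # test/stress/aio/util.py


async def measured_run(work_coro) -> None:
    with open("stress_perf_record", "w") as perf_file, open(
        "stress_memory_record", "w"
    ) as memory_file:
        # One with-block records one elapsed-time/RSS sample, so the iteration
        # loop goes inside it, exactly as e2e_iterator.py does.
        with task_decorator(perf_file, memory_file):
            await work_coro
```

Note that `task_decorator` re-initialises its local `count` on every call, so the `SAMPLE_RATE` check always passes within a single with-block; sampling across iterations would require keeping the counter outside the helper.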
+# + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import aiohttp + +from snowflake.connector.auth.by_plugin import DEFAULT_AUTH_CLASS_TIMEOUT +from snowflake.connector.connection import DEFAULT_BACKOFF_POLICY + + +def mock_async_request_with_action(next_action, sleep=None): + async def mock_request(*args, **kwargs): + if sleep is not None: + await asyncio.sleep(sleep) + if next_action == "RETRY": + return MagicMock( + status=503, + close=lambda: None, + ) + elif next_action == "ERROR": + raise aiohttp.ClientConnectionError() + + return mock_request + + +def mock_connection( + login_timeout=DEFAULT_AUTH_CLASS_TIMEOUT, + network_timeout=None, + socket_timeout=None, + backoff_policy=DEFAULT_BACKOFF_POLICY, + disable_saml_url_check=False, +): + return AsyncMock( + _login_timeout=login_timeout, + login_timeout=login_timeout, + _network_timeout=network_timeout, + network_timeout=network_timeout, + _socket_timeout=socket_timeout, + socket_timeout=socket_timeout, + _backoff_policy=backoff_policy, + backoff_policy=backoff_policy, + _disable_saml_url_check=disable_saml_url_check, + ) diff --git a/test/unit/aio/test_auth_async.py b/test/unit/aio/test_auth_async.py new file mode 100644 index 0000000000..b36a64d0eb --- /dev/null +++ b/test/unit/aio/test_auth_async.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import inspect +import sys +from test.unit.aio.mock_utils import mock_connection +from unittest.mock import Mock, PropertyMock + +import pytest + +import snowflake.connector.errors +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.aio.auth import Auth, AuthByDefault, AuthByPlugin +from snowflake.connector.constants import OCSPMode +from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION + + +def _init_rest(application, post_requset): + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + type(connection).application = PropertyMock(return_value=application) + type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_version = PropertyMock( + return_value=CLIENT_VERSION + ) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._post_request = post_requset + return rest + + +def _create_mock_auth_mfs_rest_response(next_action: str): + async def _mock_auth_mfa_rest_response(url, headers, body, **kwargs): + """Tests successful case.""" + global mock_cnt + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + if mock_cnt == 0: + ret = { + "success": True, + "message": None, + "data": { + "nextAction": next_action, + "inFlightCtx": "inFlightCtx", + }, + } + elif mock_cnt == 1: + ret = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + }, + } + + mock_cnt += 1 + return ret + + return _mock_auth_mfa_rest_response + + +async def _mock_auth_mfa_rest_response_failure(url, headers, body, **kwargs): + """Tests failed case.""" + global mock_cnt + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + + if mock_cnt == 0: + ret = { + "success": True, + "message": None, + "data": { + "nextAction": "EXT_AUTHN_DUO_ALL", + "inFlightCtx": "inFlightCtx", + }, + } + elif mock_cnt == 1: + ret = { + "success": True, + 
"message": None, + "data": { + "nextAction": "BAD", + "inFlightCtx": "inFlightCtx", + }, + } + elif mock_cnt == 2: + ret = { + "success": True, + "message": None, + "data": None, + } + mock_cnt += 1 + return ret + + +async def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs): + """Tests timeout case.""" + global mock_cnt + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + if mock_cnt == 0: + ret = { + "success": True, + "message": None, + "data": { + "nextAction": "EXT_AUTHN_DUO_ALL", + "inFlightCtx": "inFlightCtx", + }, + } + elif mock_cnt == 1: + await asyncio.sleep(10) # should timeout while here + ret = {} + elif mock_cnt == 2: + ret = { + "success": True, + "message": None, + "data": None, + } + + mock_cnt += 1 + return ret + + +@pytest.mark.parametrize( + "next_action", ("EXT_AUTHN_DUO_ALL", "EXT_AUTHN_DUO_PUSH_N_PASSCODE") +) +async def test_auth_mfa(next_action: str): + """Authentication by MFA.""" + global mock_cnt + application = "testapplication" + account = "testaccount" + user = "testuser" + password = "testpassword" + + # success test case + mock_cnt = 0 + rest = _init_rest(application, _create_mock_auth_mfs_rest_response(next_action)) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + await auth.authenticate(auth_instance, account, user) + assert not rest._connection.errorhandler.called # not error + assert rest.token == "TOKEN" + assert rest.master_token == "MASTER_TOKEN" + + # failure test case + mock_cnt = 0 + rest = _init_rest(application, _mock_auth_mfa_rest_response_failure) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + await auth.authenticate(auth_instance, account, user) + assert rest._connection.errorhandler.called # error + + # timeout 1 second + mock_cnt = 0 + rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + await auth.authenticate(auth_instance, account, user, timeout=1) + assert rest._connection.errorhandler.called # error + + # ret["data"] is none + with pytest.raises(snowflake.connector.errors.Error): + mock_cnt = 2 + rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + await auth.authenticate(auth_instance, account, user) + + +async def _mock_auth_password_change_rest_response(url, headers, body, **kwargs): + """Test successful case.""" + global mock_cnt + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + if mock_cnt == 0: + ret = { + "success": True, + "message": None, + "data": { + "nextAction": "PWD_CHANGE", + "inFlightCtx": "inFlightCtx", + }, + } + elif mock_cnt == 1: + ret = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + }, + } + + mock_cnt += 1 + return ret + + +@pytest.mark.xfail(reason="SNOW-1707210: password_callback callback not implemented ") +async def test_auth_password_change(): + """Tests password change.""" + global mock_cnt + + async def _password_callback(): + return "NEW_PASSWORD" + + application = "testapplication" + account = "testaccount" + user = "testuser" + password = "testpassword" + + # success test case + mock_cnt = 0 + rest = _init_rest(application, _mock_auth_password_change_rest_response) + auth = Auth(rest) + auth_instance = AuthByDefault(password) + await auth.authenticate( + auth_instance, account, user, password_callback=_password_callback + ) + assert not rest._connection.errorhandler.called # not error + + +async def 
test_authbyplugin_abc_api(): + """This test verifies that the abstract function signatures have not changed.""" + bc = AuthByPlugin + + # Verify properties + assert inspect.isdatadescriptor(bc.timeout) + assert inspect.isdatadescriptor(bc.type_) + assert inspect.isdatadescriptor(bc.assertion_content) + + # Verify method signatures + # update_body + if sys.version_info < (3, 12): + assert inspect.isfunction(bc.update_body) + assert str(inspect.signature(bc.update_body).parameters) == ( + "OrderedDict([('self', ), " + "('body', )])" + ) + + # authenticate + assert inspect.isfunction(bc.prepare) + assert str(inspect.signature(bc.prepare).parameters) == ( + "OrderedDict([('self', ), " + "('conn', ), " + "('authenticator', ), " + "('service_name', ), " + "('account', ), " + "('user', ), " + "('password', ), " + "('kwargs', )])" + ) + + # handle_failure + assert inspect.isfunction(bc._handle_failure) + assert str(inspect.signature(bc._handle_failure).parameters) == ( + "OrderedDict([('self', ), " + "('conn', ), " + "('ret', ), " + "('kwargs', )])" + ) + + # handle_timeout + assert inspect.isfunction(bc.handle_timeout) + assert str(inspect.signature(bc.handle_timeout).parameters) == ( + "OrderedDict([('self', ), " + "('authenticator', ), " + "('service_name', ), " + "('account', ), " + "('user', ), " + "('password', ), " + "('kwargs', )])" + ) + else: + # starting from python 3.12 the repr of collections.OrderedDict is changed + # to use regular dictionary formating instead of pairs of keys and values. + # see https://github.com/python/cpython/issues/101446 + assert inspect.isfunction(bc.update_body) + assert str(inspect.signature(bc.update_body).parameters) == ( + """OrderedDict({'self': , \ +'body': })""" + ) + + # authenticate + assert inspect.isfunction(bc.prepare) + assert str(inspect.signature(bc.prepare).parameters) == ( + """OrderedDict({'self': , \ +'conn': , \ +'authenticator': , \ +'service_name': , \ +'account': , \ +'user': , \ +'password': , \ +'kwargs': })""" + ) + + # handle_failure + assert inspect.isfunction(bc._handle_failure) + assert str(inspect.signature(bc._handle_failure).parameters) == ( + """OrderedDict({'self': , \ +'conn': , \ +'ret': , \ +'kwargs': })""" + ) + + # handle_timeout + assert inspect.isfunction(bc.handle_timeout) + assert str(inspect.signature(bc.handle_timeout).parameters) == ( + """OrderedDict({'self': , \ +'authenticator': , \ +'service_name': , \ +'account': , \ +'user': , \ +'password': , \ +'kwargs': })""" + ) diff --git a/test/unit/aio/test_auth_keypair_async.py b/test/unit/aio/test_auth_keypair_async.py new file mode 100644 index 0000000000..9c4037ed0e --- /dev/null +++ b/test/unit/aio/test_auth_keypair_async.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
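These auth unit tests stub the REST layer by swapping `_post_request` for an async callable that returns canned login payloads, so `Auth.authenticate()` runs without any network traffic. A condensed sketch of that pattern, assuming the `_init_rest()` helper defined in these test modules:

```python
# A sketch only: it assumes the _init_rest() helper defined in these test modules.
from snowflake.connector.aio.auth import Auth, AuthByDefault


async def _mock_single_step_login(url, headers, body, **kwargs):
    # Succeed immediately with session tokens, skipping any MFA next-action step.
    return {
        "success": True,
        "message": None,
        "data": {"token": "TOKEN", "masterToken": "MASTER_TOKEN"},
    }


async def run_stubbed_login() -> None:
    rest = _init_rest("testapplication", _mock_single_step_login)
    await Auth(rest).authenticate(AuthByDefault("testpassword"), "testaccount", "testuser")
    assert rest.token == "TOKEN" and rest.master_token == "MASTER_TOKEN"
```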
+# + +from __future__ import annotations + +from test.unit.aio.mock_utils import mock_connection +from unittest.mock import Mock, PropertyMock, patch + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.primitives.serialization import load_der_private_key +from pytest import raises + +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.aio.auth import Auth, AuthByKeyPair +from snowflake.connector.constants import OCSPMode +from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION + + +def _create_mock_auth_keypair_rest_response(): + async def _mock_auth_key_pair_rest_response(url, headers, body, **kwargs): + return { + "success": True, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + }, + } + + return _mock_auth_key_pair_rest_response + + +async def test_auth_keypair(): + """Simple Key Pair test.""" + private_key_der, public_key_der_encoded = generate_key_pair(2048) + application = "testapplication" + account = "testaccount" + user = "testuser" + auth_instance = AuthByKeyPair(private_key=private_key_der) + auth_instance._retry_ctx.set_start_time() + await auth_instance.handle_timeout( + authenticator="SNOWFLAKE_JWT", + service_name=None, + account=account, + user=user, + password=None, + ) + + # success test case + rest = _init_rest(application, _create_mock_auth_keypair_rest_response()) + auth = Auth(rest) + await auth.authenticate(auth_instance, account, user) + assert not rest._connection.errorhandler.called # not error + assert rest.token == "TOKEN" + assert rest.master_token == "MASTER_TOKEN" + + +async def test_auth_keypair_abc(): + """Simple Key Pair test using abstraction layer.""" + private_key_der, public_key_der_encoded = generate_key_pair(2048) + application = "testapplication" + account = "testaccount" + user = "testuser" + + private_key = load_der_private_key( + data=private_key_der, + password=None, + backend=default_backend(), + ) + + assert isinstance(private_key, RSAPrivateKey) + + auth_instance = AuthByKeyPair(private_key=private_key) + auth_instance._retry_ctx.set_start_time() + await auth_instance.handle_timeout( + authenticator="SNOWFLAKE_JWT", + service_name=None, + account=account, + user=user, + password=None, + ) + + # success test case + rest = _init_rest(application, _create_mock_auth_keypair_rest_response()) + auth = Auth(rest) + await auth.authenticate(auth_instance, account, user) + assert not rest._connection.errorhandler.called # not error + assert rest.token == "TOKEN" + assert rest.master_token == "MASTER_TOKEN" + + +async def test_auth_keypair_bad_type(): + """Simple Key Pair test using abstraction layer.""" + account = "testaccount" + user = "testuser" + + class Bad: + pass + + for bad_private_key in ("abcd", 1234, Bad()): + auth_instance = AuthByKeyPair(private_key=bad_private_key) + with raises(TypeError) as ex: + await auth_instance.prepare(account=account, user=user) + assert str(type(bad_private_key)) in str(ex) + + +@patch("snowflake.connector.aio.auth.AuthByKeyPair.prepare") +async def test_renew_token(mockPrepare): + private_key_der, _ = generate_key_pair(2048) + auth_instance = AuthByKeyPair(private_key=private_key_der) + + # force renew condition to be met + auth_instance._retry_ctx.set_start_time() + auth_instance._jwt_timeout = 0 + account = "testaccount" 
+ user = "testuser" + + await auth_instance.handle_timeout( + authenticator="SNOWFLAKE_JWT", + service_name=None, + account=account, + user=user, + password=None, + ) + + assert mockPrepare.called + + +def _init_rest(application, post_requset): + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + type(connection).application = PropertyMock(return_value=application) + type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_version = PropertyMock( + return_value=CLIENT_VERSION + ) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._post_request = post_requset + return rest + + +def generate_key_pair(key_length): + private_key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=key_length + ) + + private_key_der = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + public_key_pem = ( + private_key.public_key() + .public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo + ) + .decode("utf-8") + ) + + # strip off header + public_key_der_encoded = "".join(public_key_pem.split("\n")[1:-2]) + + return private_key_der, public_key_der_encoded diff --git a/test/unit/aio/test_auth_mfa_async.py b/test/unit/aio/test_auth_mfa_async.py new file mode 100644 index 0000000000..403e70d2e5 --- /dev/null +++ b/test/unit/aio/test_auth_mfa_async.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from unittest import mock + +from snowflake.connector.aio import SnowflakeConnection + + +async def test_mfa_token_cache(): + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + ): + with mock.patch( + "snowflake.connector.aio.auth.Auth._write_temporary_credential", + ) as save_mock: + async with SnowflakeConnection( + account="account", + user="user", + password="password", + authenticator="username_password_mfa", + client_store_temporary_credential=True, + client_request_mfa_token=True, + ): + assert save_mock.called + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + return_value={ + "data": { + "token": "abcd", + "masterToken": "defg", + }, + "success": True, + }, + ): + with mock.patch( + "snowflake.connector.aio.SnowflakeCursor._init_result_and_meta", + ): + with mock.patch( + "snowflake.connector.aio.auth.Auth._write_temporary_credential", + return_value=None, + ) as load_mock: + async with SnowflakeConnection( + account="account", + user="user", + password="password", + authenticator="username_password_mfa", + client_store_temporary_credential=True, + client_request_mfa_token=True, + ): + assert load_mock.called diff --git a/test/unit/aio/test_auth_oauth_async.py b/test/unit/aio/test_auth_oauth_async.py new file mode 100644 index 0000000000..1c99c1f123 --- /dev/null +++ b/test/unit/aio/test_auth_oauth_async.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
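+# Unit tests for OAuth token authentication (AuthByOAuth) against the asyncio connector.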
+# + +from __future__ import annotations + +from snowflake.connector.aio.auth import AuthByOAuth + + +async def test_auth_oauth(): + """Simple OAuth test.""" + token = "oAuthToken" + auth = AuthByOAuth(token) + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == token, body + assert body["data"]["AUTHENTICATOR"] == "OAUTH", body diff --git a/test/unit/aio/test_auth_okta_async.py b/test/unit/aio/test_auth_okta_async.py new file mode 100644 index 0000000000..c2ceee78d3 --- /dev/null +++ b/test/unit/aio/test_auth_okta_async.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +from test.unit.aio.mock_utils import mock_connection +from unittest.mock import AsyncMock, Mock, PropertyMock, patch + +import aiohttp +import pytest + +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.aio.auth import AuthByOkta +from snowflake.connector.constants import OCSPMode +from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION + + +async def test_auth_okta(): + """Authentication by OKTA positive test case.""" + authenticator = "https://testsso.snowflake.net/" + application = "testapplication" + account = "testaccount" + user = "testuser" + password = "testpassword" + service_name = "" + + ref_sso_url = "https://testsso.snowflake.net/sso" + ref_token_url = "https://testsso.snowflake.net/token" + rest = _init_rest(ref_sso_url, ref_token_url) + + auth = AuthByOkta(application) + + # step 1 + headers, sso_url, token_url = await auth._step1( + rest._connection, authenticator, service_name, account, user + ) + assert not rest._connection.errorhandler.called # no error + assert headers.get("accept") is not None + assert headers.get("Content-Type") is not None + assert headers.get("User-Agent") is not None + assert sso_url == ref_sso_url + assert token_url == ref_token_url + + # step 2 + await auth._step2(rest._connection, authenticator, sso_url, token_url) + assert not rest._connection.errorhandler.called # no error + + # step 3 + ref_one_time_token = "1token1" + + async def fake_fetch(method, full_url, headers, **kwargs): + return { + "cookieToken": ref_one_time_token, + } + + rest.fetch = fake_fetch + one_time_token = await auth._step3( + rest._connection, headers, token_url, user, password + ) + assert not rest._connection.errorhandler.called # no error + assert one_time_token == ref_one_time_token + + # step 4 + ref_response_html = """ + +
+
+"""
+
+    async def fake_fetch(method, full_url, headers, **kwargs):
+        return ref_response_html
+
+    async def get_one_time_token():
+        return one_time_token
+
+    rest.fetch = fake_fetch
+    response_html = await auth._step4(rest._connection, get_one_time_token, sso_url)
+    assert response_html == ref_response_html
+
+    # step 5
+    rest._protocol = "https"
+    rest._host = f"{account}.snowflakecomputing.com"
+    rest._port = 443
+    await auth._step5(rest._connection, ref_response_html)
+    assert not rest._connection.errorhandler.called  # no error
+    assert ref_response_html == auth._saml_response
+
+
+async def test_auth_okta_step1_negative():
+    """Authentication by OKTA step1 negative test case."""
+    authenticator = "https://testsso.snowflake.net/"
+    application = "testapplication"
+    account = "testaccount"
+    user = "testuser"
+    service_name = ""
+
+    # a non-success status is returned
+    ref_sso_url = "https://testsso.snowflake.net/sso"
+    ref_token_url = "https://testsso.snowflake.net/token"
+    rest = _init_rest(ref_sso_url, ref_token_url, success=False, message="error")
+    auth = AuthByOkta(application)
+    # step 1
+    _, _, _ = await auth._step1(
+        rest._connection, authenticator, service_name, account, user
+    )
+    assert rest._connection.errorhandler.called  # error should be raised
+
+
+async def test_auth_okta_step2_negative():
+    """Authentication by OKTA step2 negative test case."""
+    authenticator = "https://testsso.snowflake.net/"
+    application = "testapplication"
+    account = "testaccount"
+    user = "testuser"
+    service_name = ""
+
+    # invalid SSO URL
+    ref_sso_url = "https://testssoinvalid.snowflake.net/sso"
+    ref_token_url = "https://testsso.snowflake.net/token"
+    rest = _init_rest(ref_sso_url, ref_token_url)
+
+    auth = AuthByOkta(application)
+    # step 1
+    headers, sso_url, token_url = await auth._step1(
+        rest._connection, authenticator, service_name, account, user
+    )
+    # step 2
+    await auth._step2(rest._connection, authenticator, sso_url, token_url)
+    assert rest._connection.errorhandler.called  # error
+
+    # invalid TOKEN URL
+    ref_sso_url = "https://testsso.snowflake.net/sso"
+    ref_token_url = "https://testssoinvalid.snowflake.net/token"
+    rest = _init_rest(ref_sso_url, ref_token_url)
+
+    auth = AuthByOkta(application)
+    # step 1
+    headers, sso_url, token_url = await auth._step1(
+        rest._connection, authenticator, service_name, account, user
+    )
+    # step 2
+    await auth._step2(rest._connection, authenticator, sso_url, token_url)
+    assert rest._connection.errorhandler.called  # error
+
+
+async def test_auth_okta_step3_negative():
+    """Authentication by OKTA step3 negative test case."""
+    authenticator = "https://testsso.snowflake.net/"
+    application = "testapplication"
+    account = "testaccount"
+    user = "testuser"
+    password = "testpassword"
+    service_name = ""
+
+    ref_sso_url = "https://testsso.snowflake.net/sso"
+    ref_token_url = "https://testsso.snowflake.net/token"
+    rest = _init_rest(ref_sso_url, ref_token_url)
+
+    auth = AuthByOkta(application)
+    # step 1
+    headers, sso_url, token_url = await auth._step1(
+        rest._connection, authenticator, service_name, account, user
+    )
+    # step 2
+    await auth._step2(rest._connection, authenticator, sso_url, token_url)
+    assert not rest._connection.errorhandler.called  # no error
+
+    # step 3: authentication by IdP failed.
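+    # The mocked fetch below omits "cookieToken", so _step3 is expected to
+    # surface the failure through the connection's errorhandler.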
+ async def fake_fetch(method, full_url, headers, **kwargs): + return { + "failed": "auth failed", + } + + rest.fetch = fake_fetch + _ = await auth._step3(rest._connection, headers, token_url, user, password) + assert rest._connection.errorhandler.called # auth failure error + + +async def test_auth_okta_step4_negative(caplog): + """Authentication by OKTA step4 negative test case.""" + authenticator = "https://testsso.snowflake.net/" + application = "testapplication" + account = "testaccount" + user = "testuser" + service_name = "" + + ref_sso_url = "https://testsso.snowflake.net/sso" + ref_token_url = "https://testsso.snowflake.net/token" + rest = _init_rest(ref_sso_url, ref_token_url) + + auth = AuthByOkta(application) + # step 1 + headers, sso_url, token_url = await auth._step1( + rest._connection, authenticator, service_name, account, user + ) + # step 2 + await auth._step2(rest._connection, authenticator, sso_url, token_url) + assert not rest._connection.errorhandler.called # no error + + # step 3: authentication by IdP failed due to throttling + raise_token_refresh_error = True + second_token_generated = False + + async def get_one_time_token(): + nonlocal raise_token_refresh_error + nonlocal second_token_generated + if raise_token_refresh_error: + assert not second_token_generated + return "1token1" + else: + second_token_generated = True + return "2token2" + + # the first time, when step4 gets executed, we return 429 + # the second time when step4 gets retried, we return 200 + async def mock_session_request(*args, **kwargs): + nonlocal second_token_generated + url = kwargs.get("url") + assert url == ( + "https://testsso.snowflake.net/sso?RelayState=%2Fsome%2Fdeep%2Flink&onetimetoken=1token1" + if not second_token_generated + else "https://testsso.snowflake.net/sso?RelayState=%2Fsome%2Fdeep%2Flink&onetimetoken=2token2" + ) + nonlocal raise_token_refresh_error + if raise_token_refresh_error: + raise_token_refresh_error = False + return AsyncMock(status=429) + else: + resp = AsyncMock(status=200) + resp.text.return_value = "success" + return resp + + with patch.object( + aiohttp.ClientSession, + "request", + new=mock_session_request, + ): + caplog.set_level(logging.DEBUG, "snowflake.connector") + response_html = await auth._step4(rest._connection, get_one_time_token, sso_url) + # make sure the RefreshToken error is caught and tried + assert "step4: refresh token for re-authentication" in caplog.text + # test that token generation method is called + assert second_token_generated + assert response_html == "success" + assert not rest._connection.errorhandler.called + + +@pytest.mark.parametrize("disable_saml_url_check", [True, False]) +async def test_auth_okta_step5_negative(disable_saml_url_check): + """Authentication by OKTA step5 negative test case.""" + authenticator = "https://testsso.snowflake.net/" + application = "testapplication" + account = "testaccount" + user = "testuser" + password = "testpassword" + service_name = "" + + ref_sso_url = "https://testsso.snowflake.net/sso" + ref_token_url = "https://testsso.snowflake.net/token" + rest = _init_rest( + ref_sso_url, ref_token_url, disable_saml_url_check=disable_saml_url_check + ) + + auth = AuthByOkta(application) + # step 1 + headers, sso_url, token_url = await auth._step1( + rest._connection, authenticator, service_name, account, user + ) + assert not rest._connection.errorhandler.called # no error + # step 2 + await auth._step2(rest._connection, authenticator, sso_url, token_url) + assert not 
rest._connection.errorhandler.called # no error + # step 3 + ref_one_time_token = "1token1" + + async def fake_fetch(method, full_url, headers, **kwargs): + return { + "cookieToken": ref_one_time_token, + } + + rest.fetch = fake_fetch + one_time_token = await auth._step3( + rest._connection, headers, token_url, user, password + ) + assert not rest._connection.errorhandler.called # no error + + # step 4 + # HTML includes invalid account name + ref_response_html = """ + +
+ +""" + + async def fake_fetch(method, full_url, headers, **kwargs): + return ref_response_html + + async def get_one_time_token(): + return one_time_token + + rest.fetch = fake_fetch + response_html = await auth._step4(rest._connection, get_one_time_token, sso_url) + assert response_html == ref_response_html + + # step 5 + rest._protocol = "https" + rest._host = f"{account}.snowflakecomputing.com" + rest._port = 443 + await auth._step5(rest._connection, ref_response_html) + assert disable_saml_url_check ^ rest._connection.errorhandler.called # error + + +def _init_rest( + ref_sso_url, ref_token_url, success=True, message=None, disable_saml_url_check=False +): + async def post_request(url, headers, body, **kwargs): + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + return { + "success": success, + "message": message, + "data": { + "ssoUrl": ref_sso_url, + "tokenUrl": ref_token_url, + }, + } + + connection = mock_connection(disable_saml_url_check=disable_saml_url_check) + connection.errorhandler = Mock(return_value=None) + connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + type(connection).application = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_version = PropertyMock( + return_value=CLIENT_VERSION + ) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + connection._rest = rest + rest._post_request = post_request + return rest diff --git a/test/unit/aio/test_auth_webbrowser_async.py b/test/unit/aio/test_auth_webbrowser_async.py new file mode 100644 index 0000000000..758529137f --- /dev/null +++ b/test/unit/aio/test_auth_webbrowser_async.py @@ -0,0 +1,873 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
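+# Unit tests for browser-based SSO authentication (AuthByWebBrowser / AuthByIdToken) against the asyncio connector.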
+# + +from __future__ import annotations + +import asyncio +import base64 +import socket +from test.unit.aio.mock_utils import mock_connection +from unittest import mock +from unittest.mock import MagicMock, Mock, PropertyMock, patch + +import pytest + +from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.aio.auth import AuthByIdToken, AuthByWebBrowser +from snowflake.connector.compat import IS_WINDOWS, urlencode +from snowflake.connector.constants import OCSPMode +from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION +from snowflake.connector.network import ( + EXTERNAL_BROWSER_AUTHENTICATOR, + ReauthenticationRequest, +) + +AUTHENTICATOR = "https://testsso.snowflake.net/" +APPLICATION = "testapplication" +ACCOUNT = "testaccount" +USER = "testuser" +PASSWORD = "testpassword" +SERVICE_NAME = "" +REF_PROOF_KEY = "MOCK_PROOF_KEY" +REF_SSO_URL = "https://testsso.snowflake.net/sso" +INVALID_SSO_URL = "this is an invalid URL" +CLIENT_PORT = 12345 +SNOWFLAKE_PORT = 443 +HOST = "testaccount.snowflakecomputing.com" +PROOF_KEY = b"F5mR7M2J4y0jgG9CqyyWqEpyFT2HG48HFUByOS3tGaI" +REF_CONSOLE_LOGIN_SSO_URL = ( + f"http://{HOST}:{SNOWFLAKE_PORT}/console/login?login_name={USER}&browser_mode_redirect_port={CLIENT_PORT}&" + + urlencode({"proof_key": base64.b64encode(PROOF_KEY).decode("ascii")}) +) + + +def mock_webserver(target_instance, application, port): + _ = application + _ = port + target_instance._webserver_status = True + + +def successful_web_callback(token): + return ( + "\r\n".join( + [ + f"GET /?token={token}&confirm=true HTTP/1.1", + "User-Agent: snowflake-agent", + ] + ) + ).encode("utf-8") + + +def _init_socket(): + mock_socket_instance = MagicMock() + mock_socket_instance.getsockname.return_value = [None, CLIENT_PORT] + mock_socket_client = MagicMock() + mock_socket_instance.accept.return_value = (mock_socket_client, None) + return Mock(return_value=mock_socket_instance) + + +def _mock_event_loop_sock_accept(): + async def mock_accept(*_): + mock_socket_client = MagicMock() + mock_socket_client.send.side_effect = lambda *args: None + return mock_socket_client, None + + return mock_accept + + +def _mock_event_loop_sock_recv(recv_side_effect_func): + async def mock_recv(*args): + # first arg is socket_client, second arg is BUF_SIZE + assert len(args) == 2 + return recv_side_effect_func(args[1]) + + return mock_recv + + +class UnexpectedRecvArgs(Exception): + pass + + +def recv_setup(recv_list): + recv_call_number = 0 + + def recv_side_effect(*args): + nonlocal recv_call_number + recv_call_number += 1 + + # if we should block (default behavior), then the only arg should be BUF_SIZE + if len(args) == 1: + return recv_list[recv_call_number - 1] + + raise UnexpectedRecvArgs( + f"socket_client.recv call expected a single argeument, but received: {args}" + ) + + return recv_side_effect + + +def recv_setup_with_msg_nowait( + ref_token, number_of_blocking_io_errors_before_success=1 +): + call_number = 0 + + def internally_scoped_function(*args): + nonlocal call_number + call_number += 1 + + if call_number <= number_of_blocking_io_errors_before_success: + raise BlockingIOError() + else: + return successful_web_callback(ref_token) + + return internally_scoped_function + + +@pytest.mark.parametrize("disable_console_login", [True, False]) +@patch("secrets.token_bytes", return_value=PROOF_KEY) +async def test_auth_webbrowser_get(_, disable_console_login): + """Authentication by WebBrowser positive 
test case.""" + ref_token = "MOCK_TOKEN" + + rest = _init_rest( + REF_SSO_URL, REF_PROOF_KEY, disable_console_login=disable_console_login + ) + + # mock socket + recv_func = recv_setup([successful_web_callback(ref_token)]) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ): + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == ref_token + assert body["data"]["AUTHENTICATOR"] == EXTERNAL_BROWSER_AUTHENTICATOR + + if disable_console_login: + mock_webbrowser.open_new.assert_called_once_with(REF_SSO_URL) + assert body["data"]["PROOF_KEY"] == REF_PROOF_KEY + else: + mock_webbrowser.open_new.assert_called_once_with(REF_CONSOLE_LOGIN_SSO_URL) + + +@pytest.mark.parametrize("disable_console_login", [True, False]) +@patch("secrets.token_bytes", return_value=PROOF_KEY) +async def test_auth_webbrowser_post(_, disable_console_login): + """Authentication by WebBrowser positive test case with POST.""" + ref_token = "MOCK_TOKEN" + + rest = _init_rest( + REF_SSO_URL, REF_PROOF_KEY, disable_console_login=disable_console_login + ) + + # mock socket + recv_func = recv_setup( + [ + ( + "\r\n".join( + [ + "POST / HTTP/1.1", + "User-Agent: snowflake-agent", + f"Host: localhost:{CLIENT_PORT}", + "", + f"token={ref_token}&confirm=true", + ] + ) + ).encode("utf-8") + ] + ) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ): + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == ref_token + assert body["data"]["AUTHENTICATOR"] == EXTERNAL_BROWSER_AUTHENTICATOR + + if disable_console_login: + mock_webbrowser.open_new.assert_called_once_with(REF_SSO_URL) + assert body["data"]["PROOF_KEY"] == REF_PROOF_KEY + else: + 
mock_webbrowser.open_new.assert_called_once_with(REF_CONSOLE_LOGIN_SSO_URL) + + +@pytest.mark.parametrize("disable_console_login", [True, False]) +@pytest.mark.parametrize( + "input_text,expected_error", + [ + ("", True), + ("http://example.com/notokenurl", True), + ("http://example.com/sso?token=", True), + ("http://example.com/sso?token=MOCK_TOKEN", False), + ], +) +@patch("secrets.token_bytes", return_value=PROOF_KEY) +async def test_auth_webbrowser_fail_webbrowser( + _, capsys, input_text, expected_error, disable_console_login +): + """Authentication by WebBrowser with failed to start web browser case.""" + rest = _init_rest( + REF_SSO_URL, REF_PROOF_KEY, disable_console_login=disable_console_login + ) + ref_token = "MOCK_TOKEN" + + # mock socket + recv_func = recv_setup([successful_web_callback(ref_token)]) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = False + + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with patch("builtins.input", return_value=input_text), patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, "sock_recv", side_effect=_mock_event_loop_sock_recv(recv_func) + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + captured = capsys.readouterr() + assert captured.out == ( + "Initiating login request with your identity provider. A browser window " + "should have opened for you to complete the login. If you can't see it, " + "check existing browser windows, or your OS settings. 
Press CTRL+C to " + f"abort and try again...\nGoing to open: {REF_SSO_URL if disable_console_login else REF_CONSOLE_LOGIN_SSO_URL} to authenticate...\nWe were unable to open a browser window for " + "you, please open the url above manually then paste the URL you " + "are redirected to into the terminal.\n" + ) + if expected_error: + assert rest._connection.errorhandler.called # an error + assert auth.assertion_content is None + else: + assert not rest._connection.errorhandler.called # no error + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == ref_token + assert body["data"]["AUTHENTICATOR"] == EXTERNAL_BROWSER_AUTHENTICATOR + if disable_console_login: + assert body["data"]["PROOF_KEY"] == REF_PROOF_KEY + + +@pytest.mark.parametrize("disable_console_login", [True, False]) +@patch("secrets.token_bytes", return_value=PROOF_KEY) +async def test_auth_webbrowser_fail_webserver(_, capsys, disable_console_login): + """Authentication by WebBrowser with failed to start web browser case.""" + rest = _init_rest( + REF_SSO_URL, REF_PROOF_KEY, disable_console_login=disable_console_login + ) + + # mock socket + recv_func = recv_setup( + [("\r\n".join(["GARBAGE", "User-Agent: snowflake-agent"])).encode("utf-8")] + ) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ): + # case 1: invalid HTTP request + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + captured = capsys.readouterr() + assert captured.out == ( + "Initiating login request with your identity provider. A browser window " + "should have opened for you to complete the login. If you can't see it, " + "check existing browser windows, or your OS settings. 
Press CTRL+C to " + f"abort and try again...\nGoing to open: {REF_SSO_URL if disable_console_login else REF_CONSOLE_LOGIN_SSO_URL} to authenticate...\n" + ) + assert rest._connection.errorhandler.called # an error + assert auth.assertion_content is None + + +def _init_rest( + ref_sso_url, + ref_proof_key, + success=True, + message=None, + disable_console_login=False, + socket_timeout=None, +): + async def post_request(url, headers, body, **kwargs): + _ = url + _ = headers + _ = body + _ = kwargs.get("dummy") + return { + "success": success, + "message": message, + "data": { + "ssoUrl": ref_sso_url, + "proofKey": ref_proof_key, + }, + } + + connection = mock_connection(socket_timeout=socket_timeout) + connection.errorhandler = Mock(return_value=None) + connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) + connection._disable_console_login = disable_console_login + type(connection).application = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_name = PropertyMock(return_value=CLIENT_NAME) + type(connection)._internal_application_version = PropertyMock( + return_value=CLIENT_VERSION + ) + + rest = SnowflakeRestful(host=HOST, port=SNOWFLAKE_PORT, connection=connection) + rest._post_request = post_request + connection._rest = rest + return rest + + +async def test_idtoken_reauth(): + """This test makes sure that AuthByIdToken reverts to AuthByWebBrowser. + + This happens when the initial connection fails. Such as when the saved ID + token has expired. + """ + + auth_inst = AuthByIdToken( + id_token="token", + application="application", + protocol="protocol", + host="host", + port="port", + ) + + # We'll use this Exception to make sure AuthByWebBrowser authentication + # flow is called as expected + class StopExecuting(Exception): + pass + + with mock.patch( + "snowflake.connector.aio.auth.AuthByIdToken.prepare", + side_effect=ReauthenticationRequest(Exception()), + ): + with mock.patch( + "snowflake.connector.aio.auth.AuthByWebBrowser.prepare", + side_effect=StopExecuting(), + ): + with pytest.raises(StopExecuting): + async with SnowflakeConnection( + user="user", + account="account", + auth_class=auth_inst, + ): + pass + + +async def test_auth_webbrowser_invalid_sso(monkeypatch): + """Authentication by WebBrowser with failed to start web browser case.""" + rest = _init_rest(INVALID_SSO_URL, REF_PROOF_KEY, disable_console_login=True) + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = False + + # mock socket + mock_socket_instance = MagicMock() + mock_socket_instance.getsockname.return_value = [None, CLIENT_PORT] + + mock_socket_client = MagicMock() + mock_socket_client.recv.return_value = ( + "\r\n".join(["GET /?token=MOCK_TOKEN HTTP/1.1", "User-Agent: snowflake-agent"]) + ).encode("utf-8") + mock_socket_instance.accept.return_value = (mock_socket_client, None) + mock_socket = Mock(return_value=mock_socket_instance) + + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket, + ) + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert rest._connection.errorhandler.called # an error + assert auth.assertion_content is None + + +async def test_auth_webbrowser_socket_recv_retries_up_to_15_times_on_empty_bytearray(): + """Authentication by WebBrowser retries on empty bytearray response from socket.recv""" + ref_token = "MOCK_TOKEN" + rest = 
_init_rest(REF_SSO_URL, REF_PROOF_KEY, disable_console_login=True) + + # mock socket + recv_func = recv_setup( + # 14th return is empty byte array, but 15th call will return successful_web_callback + ([bytearray()] * 14) + + [successful_web_callback(ref_token)] + ) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ), mock.patch("asyncio.sleep") as sleep: + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == ref_token + assert body["data"]["AUTHENTICATOR"] == EXTERNAL_BROWSER_AUTHENTICATOR + assert sleep.call_count == 0 + + +async def test_auth_webbrowser_socket_recv_loop_fails_after_15_attempts(): + """Authentication by WebBrowser stops trying after 15 consective socket.recv emty bytearray returns.""" + ref_token = "MOCK_TOKEN" + rest = _init_rest(REF_SSO_URL, REF_PROOF_KEY) + + # mock socket + recv_func = recv_setup( + # 15th return is empty byte array, so successful_web_callback will never be fetched from recv + ([bytearray()] * 15) + + [successful_web_callback(ref_token)] + ) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ), mock.patch("asyncio.sleep") as sleep: + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert rest._connection.errorhandler.called # an error + assert auth.assertion_content is None + assert sleep.call_count == 0 + + +async def test_auth_webbrowser_socket_recv_does_not_block_with_env_var(monkeypatch): + """Authentication by WebBrowser socket.recv Does not block, but retries if BlockingIOError thrown.""" + ref_token = "MOCK_TOKEN" + rest = _init_rest( + REF_SSO_URL, REF_PROOF_KEY, disable_console_login=True, socket_timeout=1 + ) + + # mock socket + mock_socket_pkg = _init_socket() + + counting = 0 + + async def sock_recv_timeout(*_): + nonlocal counting + if counting < 14: + counting += 1 + raise asyncio.TimeoutError() + return successful_web_callback(ref_token) + + # mock webbrowser + mock_webbrowser = MagicMock() + 
mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ), mock.patch("asyncio.sleep") as sleep: + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + + with mock.patch.object( + auth._event_loop, "sock_recv", new=sock_recv_timeout + ), mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + body = {"data": {}} + await auth.update_body(body) + assert body["data"]["TOKEN"] == ref_token + assert body["data"]["PROOF_KEY"] == REF_PROOF_KEY + assert body["data"]["AUTHENTICATOR"] == EXTERNAL_BROWSER_AUTHENTICATOR + sleep_times = [t[0][0] for t in sleep.call_args_list] + assert sleep.call_count == counting == 14 + assert sleep_times == [0.25] * 14 + + +async def test_auth_webbrowser_socket_recv_blocking_stops_retries_after_15_attempts( + monkeypatch, +): + """Authentication by WebBrowser socket.recv Does not block, but retries if BlockingIOError thrown.""" + rest = _init_rest(REF_SSO_URL, REF_PROOF_KEY) + + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", "true") + + # mock socket + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + async def sock_recv_timeout(*_): + raise asyncio.TimeoutError() + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ), mock.patch("asyncio.sleep") as sleep: + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, "sock_recv", new=sock_recv_timeout + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert rest._connection.errorhandler.called # an error + assert auth.assertion_content is None + sleep_times = [t[0][0] for t in sleep.call_args_list] + assert sleep.call_count == 14 + assert sleep_times == [0.25] * 14 + + +@pytest.mark.skipif( + IS_WINDOWS, reason="SNOWFLAKE_AUTH_SOCKET_REUSE_PORT is not supported on Windows" +) +async def test_auth_webbrowser_socket_reuseport_with_env_flag(monkeypatch): + """Authentication by WebBrowser socket.recv Does not block, but retries if BlockingIOError thrown.""" + ref_token = "MOCK_TOKEN" + rest = _init_rest(REF_SSO_URL, REF_PROOF_KEY) + + # mock socket + recv_func = recv_setup([successful_web_callback(ref_token)]) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], 
[], []) + ): + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert mock_socket_pkg.return_value.setsockopt.call_count == 1 + assert mock_socket_pkg.return_value.setsockopt.call_args.args == ( + socket.SOL_SOCKET, + socket.SO_REUSEPORT, + 1, + ) + + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + + +async def test_auth_webbrowser_socket_reuseport_option_not_set_with_false_flag( + monkeypatch, +): + """Authentication by WebBrowser socket.recv Does not block, but retries if BlockingIOError thrown.""" + ref_token = "MOCK_TOKEN" + rest = _init_rest(REF_SSO_URL, REF_PROOF_KEY) + + # mock socket + recv_func = recv_setup([successful_web_callback(ref_token)]) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "false") + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ): + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert mock_socket_pkg.return_value.setsockopt.call_count == 0 + + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token + + +async def test_auth_webbrowser_socket_reuseport_option_not_set_with_no_flag( + monkeypatch, +): + """Authentication by WebBrowser socket.recv Does not block, but retries if BlockingIOError thrown.""" + ref_token = "MOCK_TOKEN" + rest = _init_rest(REF_SSO_URL, REF_PROOF_KEY) + + # mock socket + recv_func = recv_setup([successful_web_callback(ref_token)]) + mock_socket_pkg = _init_socket() + + # mock webbrowser + mock_webbrowser = MagicMock() + mock_webbrowser.open_new.return_value = True + + # Mock select.select to return socket client + with mock.patch( + "select.select", return_value=([mock_socket_pkg.return_value], [], []) + ): + auth = AuthByWebBrowser( + application=APPLICATION, + webbrowser_pkg=mock_webbrowser, + socket_pkg=mock_socket_pkg, + ) + with mock.patch.object( + auth._event_loop, + "sock_accept", + side_effect=_mock_event_loop_sock_accept(), + ), mock.patch.object( + auth._event_loop, "sock_sendall", return_value=None + ), mock.patch.object( + auth._event_loop, + "sock_recv", + side_effect=_mock_event_loop_sock_recv(recv_func), + ): + await auth.prepare( + conn=rest._connection, + authenticator=AUTHENTICATOR, + service_name=SERVICE_NAME, + 
account=ACCOUNT, + user=USER, + password=PASSWORD, + ) + assert mock_socket_pkg.return_value.setsockopt.call_count == 0 + + assert not rest._connection.errorhandler.called # no error + assert auth.assertion_content == ref_token diff --git a/test/unit/aio/test_bind_upload_agent_async.py b/test/unit/aio/test_bind_upload_agent_async.py new file mode 100644 index 0000000000..ffceb50f15 --- /dev/null +++ b/test/unit/aio/test_bind_upload_agent_async.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from unittest.mock import AsyncMock + + +async def test_bind_upload_agent_uploading_multiple_files(): + from snowflake.connector.aio._build_upload_agent import BindUploadAgent + + csr = AsyncMock(auto_spec=True) + rows = [bytes(10)] * 10 + agent = BindUploadAgent(csr, rows, stream_buffer_size=10) + await agent.upload() + assert csr.execute.call_count == 11 # 1 for stage creation + 10 files + + +async def test_bind_upload_agent_row_size_exceed_buffer_size(): + from snowflake.connector.aio._build_upload_agent import BindUploadAgent + + csr = AsyncMock(auto_spec=True) + rows = [bytes(15)] * 10 + agent = BindUploadAgent(csr, rows, stream_buffer_size=10) + await agent.upload() + assert csr.execute.call_count == 11 # 1 for stage creation + 10 files diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py new file mode 100644 index 0000000000..1e20b244cd --- /dev/null +++ b/test/unit/aio/test_connection_async_unit.py @@ -0,0 +1,553 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import json +import logging +import os +import stat +import sys +from contextlib import asynccontextmanager +from pathlib import Path +from secrets import token_urlsafe +from test.randomize import random_string +from test.unit.aio.mock_utils import mock_async_request_with_action +from test.unit.mock_utils import zero_backoff +from textwrap import dedent +from unittest import mock +from unittest.mock import patch + +import aiohttp +import pytest +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +import snowflake.connector.aio +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.aio.auth import ( + AuthByDefault, + AuthByOAuth, + AuthByOkta, + AuthByUsrPwdMfa, + AuthByWebBrowser, +) +from snowflake.connector.config_manager import CONFIG_MANAGER +from snowflake.connector.connection import DEFAULT_CONFIGURATION +from snowflake.connector.constants import ( + _CONNECTIVITY_ERR_MSG, + ENV_VAR_PARTNER, + QueryStatus, +) +from snowflake.connector.errors import ( + Error, + InterfaceError, + OperationalError, + ProgrammingError, +) + + +def fake_connector(**kwargs) -> snowflake.connector.aio.SnowflakeConnection: + return snowflake.connector.aio.SnowflakeConnection( + user="user", + account="account", + password="testpassword", + database="TESTDB", + warehouse="TESTWH", + **kwargs, + ) + + +@asynccontextmanager +async def fake_db_conn(**kwargs): + conn = fake_connector(**kwargs) + await conn.connect() + yield conn + await conn.close() + + +@pytest.fixture +def mock_post_requests(monkeypatch): + request_body = {} + + async def mock_post_request(request, url, headers, json_body, **kwargs): + nonlocal request_body + 
request_body.update(json.loads(json_body)) + return { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + "idToken": None, + "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}], + }, + } + + monkeypatch.setattr( + snowflake.connector.aio._network.SnowflakeRestful, + "_post_request", + mock_post_request, + ) + + return request_body + + +async def test_connect_with_service_name(mock_post_requests): + async with fake_db_conn() as conn: + assert conn.service_name == "FAKE_SERVICE_NAME" + + +@patch("snowflake.connector.aio._network.SnowflakeRestful._post_request") +async def test_connection_ignore_exception(mockSnowflakeRestfulPostRequest): + async def mock_post_request(url, headers, json_body, **kwargs): + global mock_cnt + ret = None + if mock_cnt == 0: + # return from /v1/login-request + ret = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + "idToken": None, + "parameters": [ + {"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"} + ], + }, + } + elif mock_cnt == 1: + ret = { + "success": False, + "message": "Session gone", + "data": None, + "code": 390111, + } + mock_cnt += 1 + return ret + + # POST requests mock + mockSnowflakeRestfulPostRequest.side_effect = mock_post_request + + global mock_cnt + mock_cnt = 0 + + account = "testaccount" + user = "testuser" + + # connection + con = snowflake.connector.aio.SnowflakeConnection( + account=account, + user=user, + password="testpassword", + database="TESTDB", + warehouse="TESTWH", + ) + await con.connect() + # Test to see if closing connection works or raises an exception. If an exception is raised, test will fail. + await con.close() + + +def test_is_still_running(): + """Checks that is_still_running returns expected results.""" + statuses = [ + (QueryStatus.RUNNING, True), + (QueryStatus.ABORTING, False), + (QueryStatus.SUCCESS, False), + (QueryStatus.FAILED_WITH_ERROR, False), + (QueryStatus.ABORTED, False), + (QueryStatus.QUEUED, True), + (QueryStatus.FAILED_WITH_INCIDENT, False), + (QueryStatus.DISCONNECTED, False), + (QueryStatus.RESUMING_WAREHOUSE, True), + (QueryStatus.QUEUED_REPARING_WAREHOUSE, True), + (QueryStatus.RESTARTED, False), + (QueryStatus.BLOCKED, True), + (QueryStatus.NO_DATA, True), + ] + for status, expected_result in statuses: + assert ( + snowflake.connector.aio.SnowflakeConnection.is_still_running(status) + == expected_result + ) + + +async def test_partner_env_var(mock_post_requests): + PARTNER_NAME = "Amanda" + + with patch.dict(os.environ, {ENV_VAR_PARTNER: PARTNER_NAME}): + async with fake_db_conn() as conn: + assert conn.application == PARTNER_NAME + + assert ( + mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == PARTNER_NAME + ) + + +async def test_imported_module(mock_post_requests): + with patch.dict(sys.modules, {"streamlit": "foo"}): + async with fake_db_conn() as conn: + assert conn.application == "streamlit" + + assert ( + mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == "streamlit" + ) + + +@pytest.mark.parametrize( + "auth_class", + ( + pytest.param( + type("auth_class", (AuthByDefault,), {})("my_secret_password"), + id="AuthByDefault", + ), + pytest.param( + type("auth_class", (AuthByOAuth,), {})("my_token"), + id="AuthByOAuth", + ), + pytest.param( + type("auth_class", (AuthByOkta,), {})("Python connector"), + id="AuthByOkta", + ), + pytest.param( + type("auth_class", (AuthByUsrPwdMfa,), {})("password", "mfa_token"), + id="AuthByUsrPwdMfa", 
+ ), + pytest.param( + type("auth_class", (AuthByWebBrowser,), {})(None, None), + id="AuthByWebBrowser", + ), + ), +) +async def test_negative_custom_auth(auth_class): + """Tests that non-AuthByKeyPair custom auth is not allowed.""" + with pytest.raises( + TypeError, + match="auth_class must be a child class of AuthByKeyPair", + ): + await snowflake.connector.aio.SnowflakeConnection( + account="account", + user="user", + auth_class=auth_class, + ).connect() + + +async def test_missing_default_connection(monkeypatch, tmp_path): + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" + with monkeypatch.context() as m: + m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) + m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) + m.setattr(CONFIG_MANAGER, "conf_file_cache", None) + m.setattr(CONFIG_MANAGER, "file_path", config_file) + + with pytest.raises( + Error, + match="Default connection with name 'default' cannot be found, known ones are \\[\\]", + ): + snowflake.connector.aio.SnowflakeConnection( + connections_file_path=connections_file + ) + + +async def test_missing_default_connection_conf_file(monkeypatch, tmp_path): + connection_name = random_string(5) + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" + config_file.write_text( + dedent( + f"""\ + default_connection_name = "{connection_name}" + """ + ) + ) + config_file.chmod(stat.S_IRUSR | stat.S_IWUSR) + with monkeypatch.context() as m: + m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) + m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) + m.setattr(CONFIG_MANAGER, "conf_file_cache", None) + m.setattr(CONFIG_MANAGER, "file_path", config_file) + + with pytest.raises( + Error, + match=f"Default connection with name '{connection_name}' cannot be found, known ones are \\[\\]", + ): + await snowflake.connector.aio.SnowflakeConnection( + connections_file_path=connections_file + ).connect() + + +async def test_missing_default_connection_conn_file(monkeypatch, tmp_path): + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" + connections_file.write_text( + dedent( + """\ + [con_a] + user = "test user" + account = "test account" + password = "test password" + """ + ) + ) + connections_file.chmod(stat.S_IRUSR | stat.S_IWUSR) + with monkeypatch.context() as m: + m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) + m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) + m.setattr(CONFIG_MANAGER, "conf_file_cache", None) + m.setattr(CONFIG_MANAGER, "file_path", config_file) + + with pytest.raises( + Error, + match="Default connection with name 'default' cannot be found, known ones are \\['con_a'\\]", + ): + await snowflake.connector.aio.SnowflakeConnection( + connections_file_path=connections_file + ).connect() + + +async def test_missing_default_connection_conf_conn_file(monkeypatch, tmp_path): + connection_name = random_string(5) + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" + config_file.write_text( + dedent( + f"""\ + default_connection_name = "{connection_name}" + """ + ) + ) + config_file.chmod(stat.S_IRUSR | stat.S_IWUSR) + connections_file.write_text( + dedent( + """\ + [con_a] + user = "test user" + account = "test account" + password = "test password" + """ + ) + ) + connections_file.chmod(stat.S_IRUSR | stat.S_IWUSR) + with monkeypatch.context() as m: + m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) + 
m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) + m.setattr(CONFIG_MANAGER, "conf_file_cache", None) + m.setattr(CONFIG_MANAGER, "file_path", config_file) + + with pytest.raises( + Error, + match=f"Default connection with name '{connection_name}' cannot be found, known ones are \\['con_a'\\]", + ): + await snowflake.connector.aio.SnowflakeConnection( + connections_file_path=connections_file + ).connect() + + +async def test_invalid_backoff_policy(): + with pytest.raises(ProgrammingError): + # zero_backoff() is a generator, not a generator function + _ = await fake_connector(backoff_policy=zero_backoff()).connect() + + with pytest.raises(ProgrammingError): + # passing a non-generator function should not work + _ = await fake_connector(backoff_policy=lambda: None).connect() + + with pytest.raises(InterfaceError): + # passing a generator function should make it pass config and error during connection + _ = await fake_connector(backoff_policy=zero_backoff).connect() + + +@pytest.mark.parametrize("next_action", ("RETRY", "ERROR")) +@patch("aiohttp.ClientSession.request") +async def test_handle_timeout(mockSessionRequest, next_action): + mockSessionRequest.side_effect = mock_async_request_with_action( + next_action, sleep=5 + ) + + with pytest.raises(OperationalError): + # no backoff for testing + async with fake_db_conn( + login_timeout=9, + backoff_policy=zero_backoff, + ): + pass + + # authenticator should be the only retry mechanism for login requests + # 9 seconds should be enough for authenticator to attempt twice + # however, loosen restrictions to avoid thread scheduling causing failure + assert 1 < mockSessionRequest.call_count < 4 + + +async def test_private_key_file_reading(tmp_path: Path): + key_file = tmp_path / "aio_key.pem" + + private_key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048 + ) + + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + key_file.write_bytes(private_key_pem) + + pkb = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + exc_msg = "stop execution" + + with mock.patch( + "snowflake.connector.aio.auth.AuthByKeyPair.__init__", + side_effect=Exception(exc_msg), + ) as m: + with pytest.raises( + Exception, + match=exc_msg, + ): + await snowflake.connector.aio.SnowflakeConnection( + account="test_account", + user="test_user", + private_key_file=str(key_file), + ).connect() + assert m.call_count == 1 + assert m.call_args_list[0].kwargs["private_key"] == pkb + + +async def test_encrypted_private_key_file_reading(tmp_path: Path): + key_file = tmp_path / "aio_key.pem" + private_key_password = token_urlsafe(25) + private_key = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048 + ) + + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption( + private_key_password.encode("utf-8") + ), + ) + + key_file.write_bytes(private_key_pem) + + pkb = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + exc_msg = "stop execution" + + with mock.patch( + 
"snowflake.connector.aio.auth.AuthByKeyPair.__init__", + side_effect=Exception(exc_msg), + ) as m: + with pytest.raises( + Exception, + match=exc_msg, + ): + await snowflake.connector.aio.SnowflakeConnection( + account="test_account", + user="test_user", + private_key_file=str(key_file), + private_key_file_pwd=private_key_password, + ).connect() + assert m.call_count == 1 + assert m.call_args_list[0].kwargs["private_key"] == pkb + + +async def test_expired_detection(): + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful._post_request", + return_value={ + "data": { + "masterToken": "some master token", + "token": "some token", + "validityInSeconds": 3600, + "masterValidityInSeconds": 14400, + "displayUserName": "TEST_USER", + "serverVersion": "7.42.0", + }, + "code": None, + "message": None, + "success": True, + }, + ): + conn = fake_connector() + await conn.connect() + assert not conn.expired + async with conn.cursor() as cur: + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful.fetch", + return_value={ + "data": { + "errorCode": "390114", + "reAuthnMethods": ["USERNAME_PASSWORD"], + }, + "code": "390114", + "message": "Authentication token has expired. The user must authenticate again.", + "success": False, + "headers": None, + }, + ): + with pytest.raises(ProgrammingError): + await cur.execute("select 1;") + assert conn.expired + + +async def test_disable_saml_url_check_config(): + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful._post_request", + return_value={ + "data": { + "serverVersion": "a.b.c", + }, + "code": None, + "message": None, + "success": True, + }, + ): + async with fake_db_conn() as conn: + assert ( + conn._disable_saml_url_check + == DEFAULT_CONFIGURATION.get("disable_saml_url_check")[0] + ) + + +def test_request_guid(): + assert ( + SnowflakeRestful.add_request_guid( + "https://test.snowflakecomputing.com" + ).startswith("https://test.snowflakecomputing.com?request_guid=") + and SnowflakeRestful.add_request_guid( + "http://test.snowflakecomputing.cn?a=b" + ).startswith("http://test.snowflakecomputing.cn?a=b&request_guid=") + and SnowflakeRestful.add_request_guid( + "https://test.snowflakecomputing.com.cn" + ).startswith("https://test.snowflakecomputing.com.cn?request_guid=") + and SnowflakeRestful.add_request_guid("https://test.abc.cn?a=b") + == "https://test.abc.cn?a=b" + ) + + +async def test_ssl_error_hint(caplog): + with mock.patch( + "aiohttp.ClientSession.request", + side_effect=aiohttp.ClientSSLError(mock.Mock(), OSError("SSL error")), + ), caplog.at_level(logging.DEBUG): + with pytest.raises(OperationalError) as exc: + await fake_connector().connect() + assert _CONNECTIVITY_ERR_MSG in exc.value.msg and isinstance( + exc.value, OperationalError + ) + assert "SSL error" in caplog.text and _CONNECTIVITY_ERR_MSG in caplog.text diff --git a/test/unit/aio/test_cursor_async_unit.py b/test/unit/aio/test_cursor_async_unit.py new file mode 100644 index 0000000000..ec23635731 --- /dev/null +++ b/test/unit/aio/test_cursor_async_unit.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import asyncio +import unittest.mock +from unittest.mock import MagicMock, patch + +import pytest + +from snowflake.connector.aio import SnowflakeConnection, SnowflakeCursor +from snowflake.connector.errors import ServiceUnavailableError + +try: + from snowflake.connector.constants import FileTransferType +except ImportError: + from enum import Enum + + class FileTransferType(Enum): + GET = "get" + PUT = "put" + + +class FakeConnection(SnowflakeConnection): + def __init__(self): + self._log_max_query_length = 0 + self._reuse_results = None + + +@pytest.mark.parametrize( + "sql,_type", + ( + ("", None), + ("select 1;", None), + ("PUT file:///tmp/data/mydata.csv @my_int_stage;", FileTransferType.PUT), + ("GET @%mytable file:///tmp/data/;", FileTransferType.GET), + ("/**/PUT file:///tmp/data/mydata.csv @my_int_stage;", FileTransferType.PUT), + ("/**/ GET @%mytable file:///tmp/data/;", FileTransferType.GET), + pytest.param( + "/**/\n" + + "\t/*/get\t*/\t/**/\n" * 10000 + + "\t*/get @~/test.csv file:///tmp\n", + None, + id="long_incorrect", + ), + pytest.param( + "/**/\n" + "\t/*/put\t*/\t/**/\n" * 10000 + "put file:///tmp/data.csv @~", + FileTransferType.PUT, + id="long_correct", + ), + ), +) +def test_get_filetransfer_type(sql, _type): + assert SnowflakeCursor.get_file_transfer_type(sql) == _type + + +def test_cursor_attribute(): + fake_conn = FakeConnection() + cursor = SnowflakeCursor(fake_conn) + assert cursor.lastrowid is None + + +@patch("snowflake.connector.aio._cursor.SnowflakeCursor._SnowflakeCursor__cancel_query") +async def test_cursor_execute_timeout(mockCancelQuery): + async def mock_cmd_query(*args, **kwargs): + await asyncio.sleep(10) + raise ServiceUnavailableError() + + fake_conn = FakeConnection() + fake_conn.cmd_query = mock_cmd_query + fake_conn._rest = unittest.mock.AsyncMock() + fake_conn._paramstyle = MagicMock() + fake_conn._next_sequence_counter = unittest.mock.AsyncMock() + + cursor = SnowflakeCursor(fake_conn) + + with pytest.raises(ServiceUnavailableError): + await cursor.execute( + command="SELECT * FROM nonexistent", + timeout=1, + ) + + # query cancel request should be sent upon timeout + assert mockCancelQuery.called diff --git a/test/unit/aio/test_gcs_client_async.py b/test/unit/aio/test_gcs_client_async.py new file mode 100644 index 0000000000..4ff648e620 --- /dev/null +++ b/test/unit/aio/test_gcs_client_async.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import asyncio +import logging +from os import path +from unittest import mock +from unittest.mock import AsyncMock, Mock + +import pytest +from aiohttp import ClientResponse + +from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.constants import SHA256_DIGEST + +try: + from snowflake.connector.util_text import random_string +except ImportError: + from test.randomize import random_string + +from snowflake.connector.aio._file_transfer_agent import ( + SnowflakeFileMeta, + SnowflakeFileTransferAgent, +) +from snowflake.connector.errors import RequestExceedMaxRetryError +from snowflake.connector.file_transfer_agent import StorageCredential +from snowflake.connector.vendored.requests import HTTPError + +try: # pragma: no cover + from snowflake.connector.aio._gcs_storage_client import SnowflakeGCSRestClient +except ImportError: + SnowflakeGCSRestClient = None + + +from snowflake.connector.vendored import requests + +vendored_request = True + + +THIS_DIR = path.dirname(path.realpath(__file__)) + + +@pytest.mark.parametrize("errno", [408, 429, 500, 503]) +async def test_upload_retry_errors(errno, tmpdir): + """Tests whether retryable errors are handled correctly when upploading.""" + error = AsyncMock() + error.status = errno + f_name = str(tmpdir.join("some_file.txt")) + meta = SnowflakeFileMeta( + name=f_name, + src_file_name=f_name, + stage_location_type="GCS", + presigned_url="some_url", + sha256_digest="asd", + ) + if RequestExceedMaxRetryError is not None: + mock_connection = mock.create_autospec(SnowflakeConnection) + client = SnowflakeGCSRestClient( + meta, + StorageCredential({}, mock_connection, ""), + {}, + mock_connection, + "", + ) + with open(f_name, "w") as f: + f.write(random_string(15)) + client.data_file = f_name + + with mock.patch( + "aiohttp.ClientSession.request", + new_callable=AsyncMock, + ) as m: + m.return_value = error + with pytest.raises(RequestExceedMaxRetryError): + # Retry quickly during unit tests + client.SLEEP_UNIT = 0.0 + await client.upload_chunk(0) + + +async def test_upload_uncaught_exception(tmpdir): + """Tests whether non-retryable errors are handled correctly when uploading.""" + f_name = str(tmpdir.join("some_file.txt")) + exc = HTTPError("501 Server Error") + with open(f_name, "w") as f: + f.write(random_string(15)) + agent = SnowflakeFileTransferAgent( + mock.MagicMock(), + f"put {f_name} @~", + { + "data": { + "command": "UPLOAD", + "src_locations": [f_name], + "stageInfo": { + "locationType": "GCS", + "location": "", + "creds": {"AWS_SECRET_KEY": "", "AWS_KEY_ID": ""}, + "region": "test", + "endPoint": None, + }, + "localLocation": "/tmp", + } + }, + ) + with mock.patch( + "snowflake.connector.aio._gcs_storage_client.SnowflakeGCSRestClient.get_file_header", + ), mock.patch( + "snowflake.connector.aio._gcs_storage_client.SnowflakeGCSRestClient._upload_chunk", + side_effect=exc, + ): + await agent.execute() + assert agent._file_metadata[0].error_details is exc + + +@pytest.mark.parametrize("errno", [403, 408, 429, 500, 503]) +async def test_download_retry_errors(errno, tmp_path): + """Tests whether retryable errors are handled correctly when downloading.""" + error = AsyncMock() + error.status = errno + if errno == 403: + pytest.skip("This behavior has changed in the move from SDKs") + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", 
+ "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": ""} + cnx = mock.MagicMock(autospec=SnowflakeConnection) + rest_client = SnowflakeGCSRestClient( + meta, + StorageCredential( + creds, + cnx, + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + cnx, + "GET file:///tmp/file.txt @~", + ) + + rest_client.SLEEP_UNIT = 0 + with mock.patch( + "aiohttp.ClientSession.request", + new_callable=AsyncMock, + ) as m: + m.return_value = error + with pytest.raises( + RequestExceedMaxRetryError, + match="GET with url .* failed for exceeding maximum retries", + ): + await rest_client.download_chunk(0) + + +@pytest.mark.parametrize("errno", (501, 403)) +async def test_download_uncaught_exception(tmp_path, errno): + """Tests whether non-retryable errors are handled correctly when downloading.""" + error = AsyncMock(spec=ClientResponse) + error.status = errno + error.raise_for_status.return_value = None + error.raise_for_status.side_effect = HTTPError("Fake exceptiom") + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": ""} + cnx = mock.MagicMock(autospec=SnowflakeConnection) + rest_client = SnowflakeGCSRestClient( + meta, + StorageCredential( + creds, + cnx, + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + cnx, + "GET file:///tmp/file.txt @~", + ) + + rest_client.SLEEP_UNIT = 0 + with mock.patch( + "aiohttp.ClientSession.request", + new_callable=AsyncMock, + ) as m: + m.return_value = error + with pytest.raises( + requests.exceptions.HTTPError, + ): + await rest_client.download_chunk(0) + + +async def test_upload_put_timeout(tmp_path, caplog): + """Tests whether timeout error is handled correctly when uploading.""" + caplog.set_level(logging.DEBUG, "snowflake.connector") + f_name = str(tmp_path / "some_file.txt") + with open(f_name, "w") as f: + f.write(random_string(15)) + agent = SnowflakeFileTransferAgent( + mock.Mock(autospec=SnowflakeConnection, connection=None), + f"put {f_name} @~", + { + "data": { + "command": "UPLOAD", + "src_locations": [f_name], + "stageInfo": { + "locationType": "GCS", + "location": "", + "creds": {"AWS_SECRET_KEY": "", "AWS_KEY_ID": ""}, + "region": "test", + "endPoint": None, + }, + "localLocation": "/tmp", + } + }, + ) + + async def custom_side_effect(method, url, **kwargs): + if method in ["PUT"]: + raise asyncio.TimeoutError() + return AsyncMock(spec=ClientResponse) + + SnowflakeGCSRestClient.SLEEP_UNIT = 0 + + with mock.patch( + "aiohttp.ClientSession.request", + AsyncMock(side_effect=custom_side_effect), + ): + await agent.execute() + assert ( + "snowflake.connector.aio._storage_client", + logging.WARNING, + "PUT with url https://storage.googleapis.com//some_file.txt.gz failed for transient error: ", + ) in caplog.record_tuples + assert ( + "snowflake.connector.aio._file_transfer_agent", + logging.DEBUG, + "Chunk 0 of file some_file.txt failed 
to transfer for unexpected exception PUT with url https://storage.googleapis.com//some_file.txt.gz failed for exceeding maximum retries.", + ) in caplog.record_tuples + + +async def test_download_timeout(tmp_path, caplog): + """Tests whether timeout error is handled correctly when downloading.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": ""} + cnx = mock.MagicMock(autospec=SnowflakeConnection) + rest_client = SnowflakeGCSRestClient( + meta, + StorageCredential( + creds, + cnx, + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + cnx, + "GET file:///tmp/file.txt @~", + ) + + rest_client.SLEEP_UNIT = 0 + + async def custom_side_effect(method, url, **kwargs): + if method in ["GET"]: + raise asyncio.TimeoutError() + return AsyncMock(spec=ClientResponse) + + SnowflakeGCSRestClient.SLEEP_UNIT = 0 + + with mock.patch( + "aiohttp.ClientSession.request", + AsyncMock(side_effect=custom_side_effect), + ): + exc = Exception("stop execution") + with mock.patch.object(rest_client.credentials, "update", side_effect=exc): + with pytest.raises(RequestExceedMaxRetryError): + await rest_client.download_chunk(0) + + +async def test_get_file_header_none_with_presigned_url(tmp_path): + """Tests whether default file handle created by get_file_header is as expected.""" + meta = SnowflakeFileMeta( + name=str(tmp_path / "some_file"), + src_file_name=str(tmp_path / "some_file"), + stage_location_type="GCS", + presigned_url="www.example.com", + ) + storage_credentials = Mock() + storage_credentials.creds = {} + stage_info = Mock() + connection = Mock() + client = SnowflakeGCSRestClient( + meta, storage_credentials, stage_info, connection, "" + ) + if not client.security_token: + await client._update_presigned_url() + file_header = await client.get_file_header(meta.name) + assert file_header is None diff --git a/test/unit/aio/test_mfa_no_cache_async.py b/test/unit/aio/test_mfa_no_cache_async.py new file mode 100644 index 0000000000..b90bd51eb6 --- /dev/null +++ b/test/unit/aio/test_mfa_no_cache_async.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import json +from unittest.mock import patch + +import pytest + +import snowflake.connector.aio +from snowflake.connector.compat import IS_LINUX + +try: + from snowflake.connector.options import installed_keyring +except ImportError: + # if installed_keyring is unavailable, we set it as True to skip the test + installed_keyring = True +try: + from snowflake.connector.auth._auth import delete_temporary_credential +except ImportError: + delete_temporary_credential = None + +MFA_TOKEN = "MFATOKEN" + + +@pytest.mark.skipif( + IS_LINUX or installed_keyring or not delete_temporary_credential, + reason="Required test env is Mac/Win with no pre-installed keyring package" + "and available delete_temporary_credential.", +) +@patch("snowflake.connector.aio._network.SnowflakeRestful._post_request") +async def test_mfa_no_local_secure_storage(mockSnowflakeRestfulPostRequest): + """Test whether username_password_mfa authenticator can work when no local secure storage is available.""" + global mock_post_req_cnt + mock_post_req_cnt = 0 + + # This test requires Mac/Win and no keyring lib is installed + assert not installed_keyring + + async def mock_post_request(url, headers, json_body, **kwargs): + global mock_post_req_cnt + ret = None + body = json.loads(json_body) + if mock_post_req_cnt == 0: + # issue MFA token for a succeeded login + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + ret = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + "mfaToken": "MFA_TOKEN", + }, + } + elif mock_post_req_cnt == 2: + # No local secure storage available, so no mfa cache token should be provided + assert ( + body["data"]["SESSION_PARAMETERS"].get("CLIENT_REQUEST_MFA_TOKEN") + is True + ) + assert "TOKEN" not in body["data"] + ret = { + "success": True, + "message": None, + "data": { + "token": "NEW_TOKEN", + "masterToken": "NEW_MASTER_TOKEN", + }, + } + elif mock_post_req_cnt in [1, 3]: + # connection.close() + ret = {"success": True} + mock_post_req_cnt += 1 + return ret + + # POST requests mock + mockSnowflakeRestfulPostRequest.side_effect = mock_post_request + + conn_cfg = { + "account": "testaccount", + "user": "testuser", + "password": "testpwd", + "authenticator": "username_password_mfa", + "host": "testaccount.snowflakecomputing.com", + } + + delete_temporary_credential( + host=conn_cfg["host"], user=conn_cfg["user"], cred_type=MFA_TOKEN + ) + + # first connection, no mfa token cache + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + assert con._rest.token == "TOKEN" + assert con._rest.master_token == "MASTER_TOKEN" + assert con._rest.mfa_token == "MFA_TOKEN" + await con.close() + + # second connection, no mfa token should be issued as well since no available local secure storage + con = snowflake.connector.aio.SnowflakeConnection(**conn_cfg) + await con.connect() + assert con._rest.token == "NEW_TOKEN" + assert con._rest.master_token == "NEW_MASTER_TOKEN" + assert not con._rest.mfa_token + await con.close() diff --git a/test/unit/aio/test_ocsp.py b/test/unit/aio/test_ocsp.py new file mode 100644 index 0000000000..90cbcc3cbf --- /dev/null +++ b/test/unit/aio/test_ocsp.py @@ -0,0 +1,441 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +# Please note that not all the unit tests from test/unit/test_ocsp.py are ported to this file, +# as those un-ported test cases are irrelevant to the asyncio implementation. + +from __future__ import annotations + +import asyncio +import functools +import os +import platform +import ssl +import time +from contextlib import asynccontextmanager +from os import environ, path +from unittest import mock + +import aiohttp +import aiohttp.client_proto +import pytest + +import snowflake.connector.ocsp_snowflake +from snowflake.connector.aio._ocsp_asn1crypto import SnowflakeOCSPAsn1Crypto as SFOCSP +from snowflake.connector.aio._ocsp_snowflake import OCSPCache, SnowflakeOCSP +from snowflake.connector.errors import RevocationCheckError +from snowflake.connector.util_text import random_string + +pytestmark = pytest.mark.asyncio + +try: + from snowflake.connector.cache import SFDictFileCache + from snowflake.connector.errorcode import ( + ER_OCSP_RESPONSE_CERT_STATUS_REVOKED, + ER_OCSP_RESPONSE_FETCH_FAILURE, + ) + from snowflake.connector.ocsp_snowflake import OCSP_CACHE + + @pytest.fixture(autouse=True) + def overwrite_ocsp_cache(tmpdir): + """This fixture swaps out the actual OCSP cache for a temporary one.""" + if OCSP_CACHE is not None: + tmp_cache_file = os.path.join(tmpdir, "tmp_cache") + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSP_CACHE", + SFDictFileCache(file_path=tmp_cache_file), + ): + yield + os.unlink(tmp_cache_file) + +except ImportError: + ER_OCSP_RESPONSE_CERT_STATUS_REVOKED = None + ER_OCSP_RESPONSE_FETCH_FAILURE = None + OCSP_CACHE = None + +TARGET_HOSTS = [ + "ocspssd.us-east-1.snowflakecomputing.com", + "sqs.us-west-2.amazonaws.com", + "sfcsupport.us-east-1.snowflakecomputing.com", + "sfcsupport.eu-central-1.snowflakecomputing.com", + "sfc-eng-regression.s3.amazonaws.com", + "sfctest0.snowflakecomputing.com", + "sfc-ds2-customer-stage.s3.amazonaws.com", + "snowflake.okta.com", + "sfcdev1.blob.core.windows.net", + "sfc-aus-ds1-customer-stage.s3-ap-southeast-2.amazonaws.com", +] + +THIS_DIR = path.dirname(path.realpath(__file__)) + + +@asynccontextmanager +async def _asyncio_connect(url, timeout=5): + loop = asyncio.get_event_loop() + transport, protocol = await loop.create_connection( + functools.partial(aiohttp.client_proto.ResponseHandler, loop), + host=url, + port=443, + ssl=ssl.create_default_context(), + ssl_handshake_timeout=timeout, + ) + yield protocol + transport.close() + + +@pytest.fixture(autouse=True) +def random_ocsp_response_validation_cache(): + file_path = { + "linux": os.path.join( + "~", + ".cache", + "snowflake", + f"ocsp_response_validation_cache{random_string()}", + ), + "darwin": os.path.join( + "~", + "Library", + "Caches", + "Snowflake", + f"ocsp_response_validation_cache{random_string()}", + ), + "windows": os.path.join( + "~", + "AppData", + "Local", + "Snowflake", + "Caches", + f"ocsp_response_validation_cache{random_string()}", + ), + } + yield SFDictFileCache( + entry_lifetime=3600, + file_path=file_path, + ) + try: + os.unlink(file_path[platform.system().lower()]) + except Exception: + pass + + +async def test_ocsp(): + """OCSP tests.""" + # reset the memory cache + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP() + for url in TARGET_HOSTS: + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + + +async def test_ocsp_wo_cache_server(): + """OCSP Tests with Cache Server Disabled.""" + SnowflakeOCSP.clear_cache() + ocsp =
SFOCSP(use_ocsp_cache_server=False) + for url in TARGET_HOSTS: + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + + +async def test_ocsp_wo_cache_file(): + """OCSP tests without File cache. + + Notes: + Use /etc as a readonly directory such that no cache file is used. + """ + # reset the memory cache + SnowflakeOCSP.clear_cache() + OCSPCache.del_cache_file() + environ["SF_OCSP_RESPONSE_CACHE_DIR"] = "/etc" + OCSPCache.reset_cache_dir() + + try: + ocsp = SFOCSP() + for url in TARGET_HOSTS: + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate( + url, connection + ), f"Failed to validate: {url}" + finally: + del environ["SF_OCSP_RESPONSE_CACHE_DIR"] + OCSPCache.reset_cache_dir() + + +async def test_ocsp_fail_open_w_single_endpoint(): + SnowflakeOCSP.clear_cache() + + OCSPCache.del_cache_file() + + environ["SF_OCSP_TEST_MODE"] = "true" + environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10" + environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5" + + ocsp = SFOCSP(use_ocsp_cache_server=False) + + try: + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Failed to validate: {}".format("snowflake.okta.com") + finally: + del environ["SF_OCSP_TEST_MODE"] + del environ["SF_TEST_OCSP_URL"] + del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] + + +@pytest.mark.skipif( + ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is None, + reason="No ER_OCSP_RESPONSE_CERT_STATUS_REVOKED is available.", +) +async def test_ocsp_fail_close_w_single_endpoint(): + SnowflakeOCSP.clear_cache() + + environ["SF_OCSP_TEST_MODE"] = "true" + environ["SF_TEST_OCSP_URL"] = "http://httpbin.org/delay/10" + environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] = "5" + + OCSPCache.del_cache_file() + + ocsp = SFOCSP(use_ocsp_cache_server=False, use_fail_open=False) + + with pytest.raises(RevocationCheckError) as ex: + async with _asyncio_connect("snowflake.okta.com") as connection: + await ocsp.validate("snowflake.okta.com", connection) + + try: + assert ( + ex.value.errno == ER_OCSP_RESPONSE_FETCH_FAILURE + ), "Connection should have failed" + finally: + del environ["SF_OCSP_TEST_MODE"] + del environ["SF_TEST_OCSP_URL"] + del environ["SF_TEST_CA_OCSP_RESPONDER_CONNECTION_TIMEOUT"] + + +async def test_ocsp_bad_validity(): + SnowflakeOCSP.clear_cache() + + environ["SF_OCSP_TEST_MODE"] = "true" + environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] = "true" + + OCSPCache.del_cache_file() + + ocsp = SFOCSP(use_ocsp_cache_server=False) + async with _asyncio_connect("snowflake.okta.com") as connection: + + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Connection should have passed with fail open" + del environ["SF_OCSP_TEST_MODE"] + del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] + + +async def test_ocsp_single_endpoint(): + environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] = "True" + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP() + ocsp.OCSP_CACHE_SERVER.NEW_DEFAULT_CACHE_SERVER_BASE_URL = "https://snowflake.preprod3.us-west-2-dev.external-zone.snowflakecomputing.com:8085/ocsp/" + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Failed to validate: {}".format("snowflake.okta.com") + + del environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] + + +async def test_ocsp_by_post_method(): + """OCSP tests.""" + # reset the memory 
cache + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP(use_post_method=True) + for url in TARGET_HOSTS: + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + + +async def test_ocsp_with_file_cache(tmpdir): + """OCSP tests and the cache server and file.""" + tmp_dir = str(tmpdir.mkdir("ocsp_response_cache")) + cache_file_name = path.join(tmp_dir, "cache_file.txt") + + # reset the memory cache + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) + for url in TARGET_HOSTS: + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + + +async def test_ocsp_with_bogus_cache_files( + tmpdir, random_ocsp_response_validation_cache +): + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", + random_ocsp_response_validation_cache, + ): + from snowflake.connector.ocsp_snowflake import OCSPResponseValidationResult + + """Attempts to use bogus OCSP response data.""" + cache_file_name, target_hosts = await _store_cache_in_file(tmpdir) + + ocsp = SFOCSP() + OCSPCache.read_ocsp_response_cache_file(ocsp, cache_file_name) + cache_data = snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE + assert cache_data, "more than one cache entries should be stored." + + # setting bogus data + current_time = int(time.time()) + for k, _ in cache_data.items(): + cache_data[k] = OCSPResponseValidationResult( + ocsp_response=b"bogus", + ts=current_time, + validated=True, + ) + + # write back the cache file + OCSPCache.CACHE = cache_data + OCSPCache.write_ocsp_response_cache_file(ocsp, cache_file_name) + + # forces to use the bogus cache file but it should raise errors + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP() + for hostname in target_hosts: + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + hostname, connection + ), f"Failed to validate: {hostname}" + + +async def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache): + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", + random_ocsp_response_validation_cache, + ): + from snowflake.connector.ocsp_snowflake import OCSPResponseValidationResult + + """Attempts to use outdated OCSP response cache file.""" + cache_file_name, target_hosts = await _store_cache_in_file(tmpdir) + + ocsp = SFOCSP() + + # reading cache file + OCSPCache.read_ocsp_response_cache_file(ocsp, cache_file_name) + cache_data = snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE + assert cache_data, "more than one cache entries should be stored." + + # setting outdated data + current_time = int(time.time()) + for k, v in cache_data.items(): + cache_data[k] = OCSPResponseValidationResult( + ocsp_response=v.ocsp_response, + ts=current_time - 144 * 60 * 60, + validated=True, + ) + + # write back the cache file + OCSPCache.CACHE = cache_data + OCSPCache.write_ocsp_response_cache_file(ocsp, cache_file_name) + + # forces to use the bogus cache file but it should raise errors + SnowflakeOCSP.clear_cache() # reset the memory cache + SFOCSP() + assert ( + SnowflakeOCSP.cache_size() == 0 + ), "must be empty. 
outdated cache should not be loaded" + + +async def _store_cache_in_file(tmpdir, target_hosts=None): + if target_hosts is None: + target_hosts = TARGET_HOSTS + os.environ["SF_OCSP_RESPONSE_CACHE_DIR"] = str(tmpdir) + OCSPCache.reset_cache_dir() + filename = path.join(str(tmpdir), "ocsp_response_cache.json") + + # cache OCSP response + SnowflakeOCSP.clear_cache() + ocsp = SFOCSP( + ocsp_response_cache_uri="file://" + filename, use_ocsp_cache_server=False + ) + for hostname in target_hosts: + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + hostname, connection + ), f"Failed to validate: {hostname}" + assert path.exists(filename), "OCSP response cache file" + return filename, target_hosts + + +async def test_ocsp_with_invalid_cache_file(): + """OCSP tests with an invalid cache file.""" + SnowflakeOCSP.clear_cache() # reset the memory cache + ocsp = SFOCSP(ocsp_response_cache_uri="NEVER_EXISTS") + for url in TARGET_HOSTS[0:1]: + async with _asyncio_connect(url) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + + +@mock.patch( + "snowflake.connector.aio._ocsp_snowflake.SnowflakeOCSP._fetch_ocsp_response", + new_callable=mock.AsyncMock, + side_effect=BrokenPipeError("fake error"), +) +async def test_ocsp_cache_when_server_is_down( + mock_fetch_ocsp_response, tmpdir, random_ocsp_response_validation_cache +): + with mock.patch( + "snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE", + random_ocsp_response_validation_cache, + ): + ocsp = SFOCSP() + + """Attempts to use outdated OCSP response cache file.""" + cache_file_name, target_hosts = await _store_cache_in_file(tmpdir) + + # reading cache file + OCSPCache.read_ocsp_response_cache_file(ocsp, cache_file_name) + cache_data = snowflake.connector.ocsp_snowflake.OCSP_RESPONSE_VALIDATION_CACHE + assert not cache_data, "no cache should present because of broken pipe" + + +async def test_concurrent_ocsp_requests(tmpdir): + """Run OCSP revocation checks in parallel. The memory and file caches are deleted randomly.""" + cache_file_name = path.join(str(tmpdir), "cache_file.txt") + SnowflakeOCSP.clear_cache() # reset the memory cache + SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) + + target_hosts = TARGET_HOSTS * 5 + await asyncio.gather( + *[ + _validate_certs_using_ocsp(hostname, cache_file_name) + for hostname in target_hosts + ] + ) + + +async def _validate_certs_using_ocsp(url, cache_file_name): + """Validate OCSP response. 
Deleting memory cache and file cache randomly.""" + import logging + + logger = logging.getLogger("test") + + logging.basicConfig(level=logging.DEBUG) + import random + + await asyncio.sleep(random.randint(0, 3)) + if random.random() < 0.2: + logger.info("clearing up cache: OCSP_VALIDATION_CACHE") + SnowflakeOCSP.clear_cache() + if random.random() < 0.05: + logger.info("deleting a cache file: %s", cache_file_name) + try: + # deleting the cache file can fail because another coroutine is reading the file + # here we just randomly delete the file; passing on OSError achieves the same effect + SnowflakeOCSP.delete_cache_file() + except OSError: + pass + + async with _asyncio_connect(url) as connection: + ocsp = SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) + await ocsp.validate(url, connection) diff --git a/test/unit/aio/test_put_get_async.py b/test/unit/aio/test_put_get_async.py new file mode 100644 index 0000000000..702e1bb50d --- /dev/null +++ b/test/unit/aio/test_put_get_async.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import os +from os import chmod, path +from unittest import mock + +import pytest + +from snowflake.connector import OperationalError +from snowflake.connector.aio._cursor import SnowflakeCursor +from snowflake.connector.aio._file_transfer_agent import SnowflakeFileTransferAgent +from snowflake.connector.compat import IS_WINDOWS +from snowflake.connector.errors import Error + +pytestmark = pytest.mark.asyncio +CLOUD = os.getenv("cloud_provider", "dev") + + +@pytest.mark.skip +@pytest.mark.skipif(IS_WINDOWS, reason="permission model is different") +async def test_put_error(tmpdir): + """Tests for raise_put_get_error flag (now turned on by default) in SnowflakeFileTransferAgent.""" + tmp_dir = str(tmpdir.mkdir("putfiledir")) + file1 = path.join(tmp_dir, "file1") + remote_location = path.join(tmp_dir, "remote_loc") + with open(file1, "w") as f: + f.write("test1") + + con = mock.AsyncMock() + cursor = await con.cursor() + cursor.errorhandler = Error.default_errorhandler + query = "PUT something" + ret = { + "data": { + "command": "UPLOAD", + "autoCompress": False, + "src_locations": [file1], + "sourceCompression": "none", + "stageInfo": { + "creds": {}, + "location": remote_location, + "locationType": "LOCAL_FS", + "path": "remote_loc", + }, + }, + "success": True, + } + + agent_class = SnowflakeFileTransferAgent + + # no error is raised + sf_file_transfer_agent = agent_class(cursor, query, ret, raise_put_get_error=False) + await sf_file_transfer_agent.execute() + sf_file_transfer_agent.result() + + # nobody can read now.
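+ # (chmod to 0o000 below strips all permission bits, so the PUT attempt cannot open the file and should surface a PermissionError, as asserted next)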
+ chmod(file1, 0o000) + # Permission error should be raised + sf_file_transfer_agent = agent_class(cursor, query, ret, raise_put_get_error=True) + await sf_file_transfer_agent.execute() + with pytest.raises(OperationalError, match="PermissionError"): + sf_file_transfer_agent.result() + + # unspecified, should fail because flag is on by default now + sf_file_transfer_agent = agent_class(cursor, query, ret) + await sf_file_transfer_agent.execute() + with pytest.raises(OperationalError, match="PermissionError"): + sf_file_transfer_agent.result() + + chmod(file1, 0o700) + + +async def test_get_empty_file(tmpdir): + """Tests for error message when retrieving missing file.""" + tmp_dir = str(tmpdir.mkdir("getfiledir")) + + con = mock.AsyncMock() + cursor = await con.cursor() + cursor.errorhandler = Error.default_errorhandler + query = f"GET something file:\\{tmp_dir}" + ret = { + "data": { + "localLocation": tmp_dir, + "command": "DOWNLOAD", + "autoCompress": False, + "src_locations": [], + "sourceCompression": "none", + "stageInfo": { + "creds": {}, + "location": "", + "locationType": "S3", + "path": "remote_loc", + }, + }, + "success": True, + } + + sf_file_transfer_agent = SnowflakeFileTransferAgent( + cursor, query, ret, raise_put_get_error=True + ) + with pytest.raises(OperationalError, match=".*the file does not exist.*$"): + await sf_file_transfer_agent.execute() + assert not sf_file_transfer_agent.result()["rowset"] + + +@pytest.mark.skipolddriver +async def test_upload_file_with_azure_upload_failed_error(tmp_path): + """Tests Upload file with expired Azure storage token.""" + file1 = tmp_path / "file1" + with file1.open("w") as f: + f.write("test1") + rest_client = SnowflakeFileTransferAgent( + mock.MagicMock(autospec=SnowflakeCursor), + "PUT some_file.txt", + { + "data": { + "command": "UPLOAD", + "src_locations": [file1], + "sourceCompression": "none", + "stageInfo": { + "creds": { + "AZURE_SAS_TOKEN": "sas_token", + }, + "location": "some_bucket", + "region": "no_region", + "locationType": "AZURE", + "path": "remote_loc", + "endPoint": "", + "storageAccount": "storage_account", + }, + }, + "success": True, + }, + ) + exc = Exception("Stop executing") + with mock.patch( + "snowflake.connector.aio._azure_storage_client.SnowflakeAzureRestClient._has_expired_token", + return_value=True, + ): + with mock.patch( + "snowflake.connector.file_transfer_agent.StorageCredential.update", + side_effect=exc, + ) as mock_update: + await rest_client.execute() + assert mock_update.called + assert rest_client._results[0].error_details is exc diff --git a/test/unit/aio/test_renew_session_async.py b/test/unit/aio/test_renew_session_async.py new file mode 100644 index 0000000000..205bbcac3d --- /dev/null +++ b/test/unit/aio/test_renew_session_async.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import logging +from test.unit.mock_utils import mock_connection +from unittest.mock import Mock, PropertyMock + +from snowflake.connector.aio._network import SnowflakeRestful + + +async def test_renew_session(): + OLD_SESSION_TOKEN = "old_session_token" + OLD_MASTER_TOKEN = "old_master_token" + NEW_SESSION_TOKEN = "new_session_token" + NEW_MASTER_TOKEN = "new_master_token" + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + type(connection)._probe_connection = PropertyMock(return_value=False) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._token = OLD_SESSION_TOKEN + rest._master_token = OLD_MASTER_TOKEN + + # inject a fake method (success) + async def fake_request_exec(**_): + return { + "success": True, + "data": { + "sessionToken": NEW_SESSION_TOKEN, + "masterToken": NEW_MASTER_TOKEN, + }, + } + + rest._request_exec = fake_request_exec + + await rest._renew_session() + assert not rest._connection.errorhandler.called # no error + assert rest.master_token == NEW_MASTER_TOKEN + assert rest.token == NEW_SESSION_TOKEN + + # inject a fake method (failure) + async def fake_request_exec(**_): + return {"success": False, "message": "failed to renew session", "code": 987654} + + rest._request_exec = fake_request_exec + + await rest._renew_session() + assert rest._connection.errorhandler.called # error + + # no master token + del rest._master_token + await rest._renew_session() + assert rest._connection.errorhandler.called # error + + +async def test_mask_token_when_renew_session(caplog): + caplog.set_level(logging.DEBUG) + OLD_SESSION_TOKEN = "old_session_token" + OLD_MASTER_TOKEN = "old_master_token" + NEW_SESSION_TOKEN = "new_session_token" + NEW_MASTER_TOKEN = "new_master_token" + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + type(connection)._probe_connection = PropertyMock(return_value=False) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._token = OLD_SESSION_TOKEN + rest._master_token = OLD_MASTER_TOKEN + + # inject a fake method (success) + async def fake_request_exec(**_): + return { + "success": True, + "data": { + "sessionToken": NEW_SESSION_TOKEN, + "masterToken": NEW_MASTER_TOKEN, + }, + } + + rest._request_exec = fake_request_exec + + # no secrets recorded when renew succeed + await rest._renew_session() + assert "new_session_token" not in caplog.text + assert "new_master_token" not in caplog.text + assert "old_session_token" not in caplog.text + assert "old_master_token" not in caplog.text + + async def fake_request_exec(**_): + return {"success": False, "message": "failed to renew session", "code": 987654} + + rest._request_exec = fake_request_exec + + # no secrets recorded when renew failed + await rest._renew_session() + assert "new_session_token" not in caplog.text + assert "new_master_token" not in caplog.text + assert "old_session_token" not in caplog.text + assert "old_master_token" not in caplog.text diff --git a/test/unit/aio/test_result_batch_async.py b/test/unit/aio/test_result_batch_async.py new file mode 100644 index 0000000000..2b43799db2 --- /dev/null +++ b/test/unit/aio/test_result_batch_async.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +from collections import namedtuple +from http import HTTPStatus +from test.helpers import create_async_mock_response +from unittest import mock + +import pytest + +from snowflake.connector import DatabaseError, InterfaceError +from snowflake.connector.compat import ( + BAD_GATEWAY, + BAD_REQUEST, + FORBIDDEN, + GATEWAY_TIMEOUT, + INTERNAL_SERVER_ERROR, + METHOD_NOT_ALLOWED, + OK, + REQUEST_TIMEOUT, + SERVICE_UNAVAILABLE, + UNAUTHORIZED, +) +from snowflake.connector.errorcode import ( + ER_FAILED_TO_CONNECT_TO_DB, + ER_FAILED_TO_REQUEST, +) +from snowflake.connector.errors import ( + BadGatewayError, + BadRequest, + ForbiddenError, + GatewayTimeoutError, + InternalServerError, + MethodNotAllowed, + OtherHTTPRetryableError, + ServiceUnavailableError, +) + +try: + from snowflake.connector.aio._result_batch import ( + MAX_DOWNLOAD_RETRY, + JSONResultBatch, + ) + from snowflake.connector.compat import TOO_MANY_REQUESTS + from snowflake.connector.errors import TooManyRequests + + REQUEST_MODULE_PATH = "aiohttp.ClientSession" +except ImportError: + MAX_DOWNLOAD_RETRY = None + JSONResultBatch = None + REQUEST_MODULE_PATH = "aiohttp.ClientSession" + TooManyRequests = None + TOO_MANY_REQUESTS = None +from snowflake.connector.sqlstate import ( + SQLSTATE_CONNECTION_REJECTED, + SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, +) + +MockRemoteChunkInfo = namedtuple("MockRemoteChunkInfo", "url") +chunk_info = MockRemoteChunkInfo("http://www.chunk-url.com") +result_batch = ( + JSONResultBatch(100, None, chunk_info, [], [], True) if JSONResultBatch else None +) + + +pytestmark = pytest.mark.asyncio + + +@mock.patch(REQUEST_MODULE_PATH + ".get") +async def test_ok_response_download(mock_get): + mock_get.side_effect = create_async_mock_response(200) + + content, encoding = await result_batch._download() + + # successful on first try + assert mock_get.call_count == 1 and content == "success" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "errcode,error_class", + [ + (BAD_REQUEST, BadRequest), # 400 + (FORBIDDEN, ForbiddenError), # 403 + (METHOD_NOT_ALLOWED, MethodNotAllowed), # 405 + (REQUEST_TIMEOUT, OtherHTTPRetryableError), # 408 + (TOO_MANY_REQUESTS, TooManyRequests), # 429 + (INTERNAL_SERVER_ERROR, InternalServerError), # 500 + (BAD_GATEWAY, BadGatewayError), # 502 + (SERVICE_UNAVAILABLE, ServiceUnavailableError), # 503 + (GATEWAY_TIMEOUT, GatewayTimeoutError), # 504 + (555, OtherHTTPRetryableError), # random 5xx error + ], +) +async def test_retryable_response_download(errcode, error_class): + """This test checks that responses which are deemed 'retryable' are handled correctly.""" + # retryable exceptions + with mock.patch( + REQUEST_MODULE_PATH + ".get", side_effect=create_async_mock_response(errcode) + ) as mock_get: + # mock_get.return_value = create_async_mock_response(errcode) + + with mock.patch("asyncio.sleep", return_value=None): + with pytest.raises(error_class) as ex: + _ = await result_batch._download() + err_msg = ex.value.msg + if isinstance(errcode, HTTPStatus): + assert str(errcode.value) in err_msg + else: + assert str(errcode) in err_msg + assert mock_get.call_count == MAX_DOWNLOAD_RETRY + + +async def test_unauthorized_response_download(): + """This tests that the Unauthorized response (401 status code) is handled correctly.""" + with mock.patch( + REQUEST_MODULE_PATH + ".get", + side_effect=create_async_mock_response(UNAUTHORIZED), + ) as mock_get: + with mock.patch("asyncio.sleep", return_value=None): + with 
pytest.raises(DatabaseError) as ex: + _ = await result_batch._download() + error = ex.value + assert error.errno == ER_FAILED_TO_CONNECT_TO_DB + assert error.sqlstate == SQLSTATE_CONNECTION_REJECTED + assert "401" in error.msg + assert mock_get.call_count == MAX_DOWNLOAD_RETRY + + +@pytest.mark.parametrize("status_code", [201, 302]) +async def test_non_200_response_download(status_code): + """This test checks that "success" codes which are not 200 still retry.""" + with mock.patch( + REQUEST_MODULE_PATH + ".get", + side_effect=create_async_mock_response(status_code), + ) as mock_get: + with mock.patch("asyncio.sleep", return_value=None): + with pytest.raises(InterfaceError) as ex: + _ = await result_batch._download() + error = ex.value + assert error.errno == ER_FAILED_TO_REQUEST + assert error.sqlstate == SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED + assert mock_get.call_count == MAX_DOWNLOAD_RETRY + + +async def test_retries_until_success(): + with mock.patch(REQUEST_MODULE_PATH + ".get") as mock_get: + error_codes = [BAD_REQUEST, UNAUTHORIZED, 201] + # There is an OK added to the list of responses so that there is a success + # and the retry loop ends. + mock_responses = [ + create_async_mock_response(code)("") for code in error_codes + [OK] + ] + mock_get.side_effect = mock_responses + + with mock.patch("asyncio.sleep", return_value=None): + res, _ = await result_batch._download() + assert res == "success" + # call `get` once for each error and one last time when it succeeds + assert mock_get.call_count == len(error_codes) + 1 diff --git a/test/unit/aio/test_retry_network_async.py b/test/unit/aio/test_retry_network_async.py new file mode 100644 index 0000000000..0dbb35235e --- /dev/null +++ b/test/unit/aio/test_retry_network_async.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import asyncio +import errno +import json +import logging +import os +from test.unit.aio.mock_utils import mock_async_request_with_action, mock_connection +from test.unit.mock_utils import zero_backoff +from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch +from uuid import uuid4 + +import aiohttp +import OpenSSL.SSL +import pytest + +import snowflake.connector.aio +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.compat import ( + BAD_GATEWAY, + BAD_REQUEST, + FORBIDDEN, + GATEWAY_TIMEOUT, + INTERNAL_SERVER_ERROR, + OK, + SERVICE_UNAVAILABLE, + UNAUTHORIZED, +) +from snowflake.connector.errors import ( + DatabaseError, + Error, + ForbiddenError, + InterfaceError, + OperationalError, + OtherHTTPRetryableError, + ServiceUnavailableError, +) +from snowflake.connector.network import STATUS_TO_EXCEPTION, RetryRequest + +pytestmark = pytest.mark.skipolddriver + + +THIS_DIR = os.path.dirname(os.path.realpath(__file__)) + + +class Cnt: + def __init__(self): + self.c = 0 + + def set(self, cnt): + self.c = cnt + + def reset(self): + self.set(0) + + +async def fake_connector() -> snowflake.connector.aio.SnowflakeConnection: + conn = snowflake.connector.aio.SnowflakeConnection( + user="user", + account="account", + password="testpassword", + database="TESTDB", + warehouse="TESTWH", + ) + await conn.connect() + return conn + + +@patch("snowflake.connector.aio._network.SnowflakeRestful._request_exec") +async def test_retry_reason(mockRequestExec): + url = "" + cnt = Cnt() + + async def mock_exec(session, method, full_url, headers, data, token, **kwargs): + # take actions based on data["sqlText"] + nonlocal url + url = full_url + data = json.loads(data) + sql = data.get("sqlText", "default") + success_result = { + "success": True, + "message": None, + "data": { + "token": "TOKEN", + "masterToken": "MASTER_TOKEN", + "idToken": None, + "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}], + }, + } + cnt.c += 1 + if "retry" in sql: + # error = HTTP Error 429 + if cnt.c < 3: # retry twice for 429 error + raise RetryRequest(OtherHTTPRetryableError(errno=429)) + return success_result + elif "unknown error" in sql: + # Raise unknown http error + if cnt.c == 1: # retry once for 100 error + raise RetryRequest(OtherHTTPRetryableError(errno=100)) + return success_result + elif "flip" in sql: + if cnt.c == 1: # retry first with 100 + raise RetryRequest(OtherHTTPRetryableError(errno=100)) + elif cnt.c == 2: # then with 429 + raise RetryRequest(OtherHTTPRetryableError(errno=429)) + return success_result + + return success_result + + conn = await fake_connector() + mockRequestExec.side_effect = mock_exec + + # ensure query requests don't have the retryReason if retryCount == 0 + cnt.reset() + await conn.cmd_query("success", 0, uuid4()) + assert "retryReason" not in url + assert "retryCount" not in url + + # ensure query requests have correct retryReason when retry reason is sent by server + cnt.reset() + await conn.cmd_query("retry", 0, uuid4()) + assert "retryReason=429" in url + assert "retryCount=2" in url + + cnt.reset() + await conn.cmd_query("unknown error", 0, uuid4()) + assert "retryReason=100" in url + assert "retryCount=1" in url + + # ensure query requests have retryReason reset to 0 when no reason is given + cnt.reset() + await conn.cmd_query("success", 0, uuid4()) + assert "retryReason" not in url + assert "retryCount" not in url + + # ensure query requests have retryReason gets updated 
with updated error code + cnt.reset() + await conn.cmd_query("flip", 0, uuid4()) + assert "retryReason=429" in url + assert "retryCount=2" in url + + # ensure that disabling works and only suppresses retryReason + conn._enable_retry_reason_in_query_response = False + + cnt.reset() + await conn.cmd_query("retry", 0, uuid4()) + assert "retryReason" not in url + assert "retryCount=2" in url + + cnt.reset() + await conn.cmd_query("unknown error", 0, uuid4()) + assert "retryReason" not in url + assert "retryCount=1" in url + + +async def test_request_exec(): + connection = mock_connection() + connection.errorhandler = Error.default_errorhandler + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", + port=443, + connection=connection, + ) + + default_parameters = { + "method": "POST", + "full_url": "https://testaccount.snowflakecomputing.com/", + "headers": {}, + "data": '{"code": 12345}', + "token": None, + } + + login_parameters = { + **default_parameters, + "full_url": "https://bad_id.snowflakecomputing.com:443/session/v1/login-request?request_id=s0m3-r3a11Y-rAnD0m-reqID&request_guid=s0m3-r3a11Y-rAnD0m-reqGUID", + } + + # request mock + output_data = {"success": True, "code": 12345} + request_mock = AsyncMock() + type(request_mock).status = PropertyMock(return_value=OK) + request_mock.json.return_value = output_data + + # session mock + session = AsyncMock() + session.request.return_value = request_mock + + # success + ret = await rest._request_exec(session=session, **default_parameters) + assert ret == output_data, "output data" + + # retryable exceptions + for errcode in [ + BAD_REQUEST, # 400 + FORBIDDEN, # 403 + INTERNAL_SERVER_ERROR, # 500 + BAD_GATEWAY, # 502 + SERVICE_UNAVAILABLE, # 503 + GATEWAY_TIMEOUT, # 504 + 555, # random 5xx error + ]: + type(request_mock).status = PropertyMock(return_value=errcode) + try: + await rest._request_exec(session=session, **default_parameters) + pytest.fail("should fail") + except RetryRequest as e: + cls = STATUS_TO_EXCEPTION.get(errcode, OtherHTTPRetryableError) + assert isinstance(e.args[0], cls), "must be internal error exception" + + # unauthorized + type(request_mock).status = PropertyMock(return_value=UNAUTHORIZED) + with pytest.raises(InterfaceError): + await rest._request_exec(session=session, **default_parameters) + + # unauthorized with catch okta unauthorized error + # TODO: what is the difference to InterfaceError? 
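+ # (the check below only asserts that, with catch_okta_unauthorized_error=True, the same 401 surfaces as a DatabaseError rather than the InterfaceError raised above)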
+ type(request_mock).status = PropertyMock(return_value=UNAUTHORIZED) + with pytest.raises(DatabaseError): + await rest._request_exec( + session=session, catch_okta_unauthorized_error=True, **default_parameters + ) + + # forbidden on login-request raises ForbiddenError + type(request_mock).status = PropertyMock(return_value=FORBIDDEN) + with pytest.raises(ForbiddenError): + await rest._request_exec(session=session, **login_parameters) + + # handle retryable exception + for exc in [ + aiohttp.ConnectionTimeoutError, + aiohttp.ClientConnectorError(MagicMock(), OSError(1)), + asyncio.TimeoutError, + AttributeError, + ]: + session = AsyncMock() + session.request = Mock(side_effect=exc) + + try: + await rest._request_exec(session=session, **default_parameters) + pytest.fail("should fail") + except RetryRequest as e: + cause = e.args[0] + assert ( + isinstance(cause, exc) + if not isinstance(cause, aiohttp.ClientConnectorError) + else cause == exc + ) + + # handle OpenSSL errors and BadStateLine + for exc in [ + OpenSSL.SSL.SysCallError(errno.ECONNRESET), + OpenSSL.SSL.SysCallError(errno.ETIMEDOUT), + OpenSSL.SSL.SysCallError(errno.EPIPE), + OpenSSL.SSL.SysCallError(-1), # unknown + ]: + session = AsyncMock() + session.request = Mock(side_effect=exc) + try: + await rest._request_exec(session=session, **default_parameters) + pytest.fail("should fail") + except RetryRequest as e: + assert e.args[0] == exc, "same error instance" + + +async def test_fetch(): + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + + cnt = Cnt() + default_parameters = { + "method": "POST", + "full_url": "https://testaccount.snowflakecomputing.com/", + "headers": {"cnt": cnt}, + "data": '{"code": 12345}', + } + + NOT_RETRYABLE = 1000 + + class NotRetryableException(Exception): + pass + + async def fake_request_exec(**kwargs): + headers = kwargs.get("headers") + cnt = headers["cnt"] + await asyncio.sleep(3) + if cnt.c <= 1: + # the first two raises failure + cnt.c += 1 + raise RetryRequest(Exception("can retry")) + elif cnt.c == NOT_RETRYABLE: + # not retryable exception + raise NotRetryableException("cannot retry") + else: + # return success in the third attempt + return {"success": True, "data": "valid data"} + + # inject a fake method + rest._request_exec = fake_request_exec + + # first two attempts will fail but third will success + cnt.reset() + ret = await rest.fetch(timeout=10, **default_parameters) + assert ret == {"success": True, "data": "valid data"} + assert not rest._connection.errorhandler.called # no error + + # first attempt to reach timeout even if the exception is retryable + cnt.reset() + ret = await rest.fetch(timeout=1, **default_parameters) + assert ret == {} + assert rest._connection.errorhandler.called # error + + # not retryable excpetion + cnt.set(NOT_RETRYABLE) + with pytest.raises(NotRetryableException): + await rest.fetch(timeout=7, **default_parameters) + + # first attempt fails and will not retry + cnt.reset() + default_parameters["no_retry"] = True + ret = await rest.fetch(timeout=10, **default_parameters) + assert ret == {} + assert cnt.c == 1 # failed on first call - did not retry + assert rest._connection.errorhandler.called # error + + +async def test_secret_masking(caplog): + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, 
connection=connection + ) + + data = ( + '{"code": 12345,' + ' "data": {"TOKEN": "_Y1ZNETTn5/qfUWj3Jedb", "PASSWORD": "dummy_pass"}' + "}" + ) + default_parameters = { + "method": "POST", + "full_url": "https://testaccount.snowflakecomputing.com/", + "headers": {}, + "data": data, + } + + class NotRetryableException(Exception): + pass + + async def fake_request_exec(**kwargs): + return None + + # inject a fake method + rest._request_exec = fake_request_exec + + # first two attempts will fail but third will success + with caplog.at_level(logging.ERROR): + ret = await rest.fetch(timeout=10, **default_parameters) + assert '"TOKEN": "****' in caplog.text + assert '"PASSWORD": "****' in caplog.text + assert ret == {} + + +async def test_retry_connection_reset_error(caplog): + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + + data = ( + '{"code": 12345,' + ' "data": {"TOKEN": "_Y1ZNETTn5/qfUWj3Jedb", "PASSWORD": "dummy_pass"}' + "}" + ) + default_parameters = { + "method": "POST", + "full_url": "https://testaccount.snowflakecomputing.com/", + "headers": {}, + "data": data, + } + + async def error_send(*args, **kwargs): + raise OSError(104, "ECONNRESET") + + with patch( + "snowflake.connector.aio._ssl_connector.SnowflakeSSLConnector.connect" + ) as mock_conn, patch("aiohttp.client_reqrep.ClientRequest.send", error_send): + with caplog.at_level(logging.DEBUG): + await rest.fetch(timeout=10, **default_parameters) + + # this test is different from sync test because aiohttp automatically + # closes the underlying broken socket if it encounters a connection reset error + assert mock_conn.call_count > 1 + + +@pytest.mark.parametrize("next_action", ("RETRY", "ERROR")) +@patch("aiohttp.ClientSession.request") +async def test_login_request_timeout(mockSessionRequest, next_action): + """For login requests, all errors should be bubbled up as OperationalError for authenticator to handle""" + mockSessionRequest.side_effect = mock_async_request_with_action(next_action) + + connection = mock_connection() + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + + with pytest.raises(OperationalError): + await rest.fetch( + method="post", + full_url="https://testaccount.snowflakecomputing.com/session/v1/login-request", + headers=dict(), + ) + + +@pytest.mark.parametrize( + "next_action_result", + (("RETRY", ServiceUnavailableError), ("ERROR", OperationalError)), +) +@patch("aiohttp.ClientSession.request") +async def test_retry_request_timeout(mockSessionRequest, next_action_result): + next_action, next_result = next_action_result + mockSessionRequest.side_effect = mock_async_request_with_action(next_action, 5) + # no backoff for testing + connection = mock_connection( + network_timeout=13, + backoff_policy=zero_backoff, + ) + connection.errorhandler = Error.default_errorhandler + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + + with pytest.raises(next_result): + await rest.fetch( + method="post", + full_url="https://testaccount.snowflakecomputing.com/queries/v1/query-request", + headers=dict(), + ) + + # 13 seconds should be enough for authenticator to attempt thrice + # however, loosen restrictions to avoid thread scheduling causing failure + assert 1 < mockSessionRequest.call_count < 5 diff --git a/test/unit/aio/test_s3_util_async.py 
b/test/unit/aio/test_s3_util_async.py new file mode 100644 index 0000000000..821246aafb --- /dev/null +++ b/test/unit/aio/test_s3_util_async.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +import re +from os import path +from test.helpers import verify_log_tuple +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from snowflake.connector.aio import SnowflakeConnection +from snowflake.connector.aio._cursor import SnowflakeCursor +from snowflake.connector.aio._file_transfer_agent import SnowflakeFileTransferAgent +from snowflake.connector.constants import SHA256_DIGEST + +try: + from aiohttp import ClientResponse, ClientResponseError + + from snowflake.connector.aio._s3_storage_client import SnowflakeS3RestClient + from snowflake.connector.constants import megabyte + from snowflake.connector.errors import RequestExceedMaxRetryError + from snowflake.connector.file_transfer_agent import ( + SnowflakeFileMeta, + StorageCredential, + ) + from snowflake.connector.s3_storage_client import ERRORNO_WSAECONNABORTED + from snowflake.connector.vendored.requests import HTTPError +except ImportError: + # Compatibility for olddriver tests + from requests import HTTPError + + from snowflake.connector.s3_util import ERRORNO_WSAECONNABORTED # NOQA + + SnowflakeFileMeta = dict + SnowflakeS3RestClient = None + RequestExceedMaxRetryError = None + StorageCredential = None + megabytes = 1024 * 1024 + DEFAULT_MAX_RETRY = 5 + +THIS_DIR = path.dirname(path.realpath(__file__)) +MINIMAL_METADATA = SnowflakeFileMeta( + name="file.txt", + stage_location_type="S3", + src_file_name="file.txt", +) + + +@pytest.mark.parametrize( + "input, bucket_name, s3path", + [ + ("sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"), + ( + "sfc-eng-regression/stakeda/test_stg/test_sub_dir/", + "sfc-eng-regression", + "stakeda/test_stg/test_sub_dir/", + ), + ("sfc-eng-regression/", "sfc-eng-regression", ""), + ("sfc-eng-regression//", "sfc-eng-regression", "/"), + ("sfc-eng-regression///", "sfc-eng-regression", "//"), + ], +) +def test_extract_bucket_name_and_path(input, bucket_name, s3path): + """Extracts bucket name and S3 path.""" + s3_loc = SnowflakeS3RestClient._extract_bucket_name_and_path(input) + assert s3_loc.bucket_name == bucket_name + assert s3_loc.path == s3path + + +async def test_upload_file_with_s3_upload_failed_error(tmp_path): + """Tests Upload file with S3UploadFailedError, which could indicate AWS token expires.""" + file1 = tmp_path / "file1" + with file1.open("w") as f: + f.write("test1") + rest_client = SnowflakeFileTransferAgent( + MagicMock(autospec=SnowflakeCursor), + "PUT some_file.txt", + { + "data": { + "command": "UPLOAD", + "autoCompress": False, + "src_locations": [file1], + "sourceCompression": "none", + "stageInfo": { + "creds": { + "AWS_SECRET_KEY": "secret key", + "AWS_KEY_ID": "secret id", + "AWS_TOKEN": "", + }, + "location": "some_bucket", + "region": "no_region", + "locationType": "S3", + "path": "remote_loc", + "endPoint": "", + }, + }, + "success": True, + }, + ) + exc = Exception("Stop executing") + + async def mock_transfer_accelerate_config( + self: SnowflakeS3RestClient, + use_accelerate_endpoint: bool | None = None, + ) -> bool: + self.endpoint = f"https://{self.s3location.bucket_name}.s3.awsamazon.com" + return False + + with mock.patch( + 
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token", + return_value=True, + ): + with mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config", + mock_transfer_accelerate_config, + ): + with mock.patch( + "snowflake.connector.file_transfer_agent.StorageCredential.update", + side_effect=exc, + ) as mock_update: + await rest_client.execute() + assert mock_update.called + assert rest_client._results[0].error_details is exc + + +async def test_get_header_expiry_error(): + """Tests whether token expiry error is handled as expected when getting header.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "PUT file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + await rest_client.transfer_accelerate_config(None) + + with mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token", + return_value=True, + ): + exc = Exception("stop execution") + with mock.patch.object(rest_client.credentials, "update", side_effect=exc): + with pytest.raises(Exception) as caught_exc: + await rest_client.get_file_header("file.txt") + assert caught_exc.value is exc + + +async def test_get_header_unknown_error(caplog): + """Tests whether unexpected errors are handled as expected when getting header.""" + caplog.set_level(logging.DEBUG, "snowflake.connector") + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "PUT file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + exc = HTTPError("555 Server Error") + with mock.patch.object(rest_client, "get_file_header", side_effect=exc): + with pytest.raises(HTTPError, match="555 Server Error"): + await rest_client.get_file_header("file.txt") + + +async def test_upload_expiry_error(): + """Tests whether token expiry error is handled as expected when uploading.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + 
meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "PUT file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + await rest_client.transfer_accelerate_config(None) + + with mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token", + return_value=True, + ): + exc = Exception("stop execution") + with mock.patch.object(rest_client.credentials, "update", side_effect=exc): + with mock.patch( + "snowflake.connector.aio._storage_client.SnowflakeStorageClient.preprocess" + ): + await rest_client.prepare_upload() + with pytest.raises(Exception) as caught_exc: + await rest_client.upload_chunk(0) + assert caught_exc.value is exc + + +async def test_upload_unknown_error(): + """Tests whether unknown errors are handled as expected when uploading.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "PUT file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + + exc = Exception("stop execution") + with mock.patch.object(rest_client.credentials, "update", side_effect=exc): + with mock.patch( + "snowflake.connector.aio._storage_client.SnowflakeStorageClient.preprocess" + ): + await rest_client.prepare_upload() + with pytest.raises(HTTPError, match="555 Server Error"): + e = HTTPError("555 Server Error") + with mock.patch.object(rest_client, "_upload_chunk", side_effect=e): + await rest_client.upload_chunk(0) + + +async def test_download_expiry_error(): + """Tests whether token expiry error is handled as expected when downloading.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": "path/to/put_get_1.txt", + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + await rest_client.transfer_accelerate_config(None) + + with mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token", + return_value=True, + ): + exc = Exception("stop execution") + with mock.patch.object(rest_client.credentials, "update", side_effect=exc): + with pytest.raises(Exception) as caught_exc: + await rest_client.download_chunk(0) + assert caught_exc.value is exc + + +async def test_download_unknown_error(caplog): + """Tests whether an unknown error is handled as expected when downloading.""" + caplog.set_level(logging.DEBUG, 
"snowflake.connector") + agent = SnowflakeFileTransferAgent( + MagicMock(), + "get @~/f /tmp", + { + "data": { + "command": "DOWNLOAD", + "src_locations": ["/tmp/a"], + "stageInfo": { + "locationType": "S3", + "location": "", + "creds": {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""}, + "region": "", + "endPoint": None, + }, + "localLocation": "/tmp", + } + }, + ) + + error = ClientResponseError( + mock.AsyncMock(), + mock.AsyncMock(spec=ClientResponse), + status=400, + message="No, just chuck testing...", + headers={}, + ) + + async def mock_transfer_accelerate_config( + self: SnowflakeS3RestClient, + use_accelerate_endpoint: bool | None = None, + ) -> bool: + self.endpoint = f"https://{self.s3location.bucket_name}.s3.awsamazon.com" + return False + + with mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._send_request_with_authentication_and_retry", + side_effect=error, + ), mock.patch( + "snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent._transfer_accelerate_config", + side_effect=None, + ), mock.patch( + "snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config", + mock_transfer_accelerate_config, + ): + await agent.execute() + assert agent._file_metadata[0].error_details.status == 400 + assert agent._file_metadata[0].error_details.message == "No, just chuck testing..." + assert verify_log_tuple( + "snowflake.connector.aio._storage_client", + logging.ERROR, + re.compile("Failed to download a file: .*a"), + caplog.record_tuples, + ) + + +async def test_download_retry_exceeded_error(): + """Tests whether a retry exceeded error is handled as expected when downloading.""" + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": "path/to/put_get_1.txt", + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + await rest_client.transfer_accelerate_config() + rest_client.SLEEP_UNIT = 0 + + with mock.patch( + "aiohttp.ClientSession.request", + side_effect=ConnectionError("transit error"), + ): + with mock.patch.object(rest_client.credentials, "update"): + with pytest.raises( + RequestExceedMaxRetryError, + match=r"GET with url .* failed for exceeding maximum retries", + ): + await rest_client.download_chunk(0) + + +async def test_accelerate_in_china_endpoint(): + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + SHA256_DIGEST: "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": "path/to/put_get_1.txt", + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "S3China", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + 
}, + 8 * megabyte, + ) + assert not await rest_client.transfer_accelerate_config() + + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "GET file:/tmp/file.txt @~", + ), + { + "locationType": "S3", + "location": "bucket/path", + "creds": creds, + "region": "cn-north-1", + "endPoint": None, + }, + 8 * megabyte, + ) + assert not await rest_client.transfer_accelerate_config() diff --git a/test/unit/aio/test_session_manager_async.py b/test/unit/aio/test_session_manager_async.py new file mode 100644 index 0000000000..b117e0faf5 --- /dev/null +++ b/test/unit/aio/test_session_manager_async.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from unittest import mock + +from snowflake.connector.aio._network import SnowflakeRestful +from snowflake.connector.ssl_wrap_socket import DEFAULT_OCSP_MODE + +hostname_1 = "sfctest0.snowflakecomputing.com" +url_1 = f"https://{hostname_1}:443/session/v1/login-request" + +hostname_2 = "sfc-ds2-customer-stage.s3.amazonaws.com" +url_2 = f"https://{hostname_2}/rgm1-s-sfctest0/stages/" +url_3 = f"https://{hostname_2}/rgm1-s-sfctst0/stages/another-url" + + +mock_conn = mock.AsyncMock() +mock_conn.disable_request_pooling = False +mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE + + +async def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None: + """Helper function to call SnowflakeRestful.close(). Asserts close was called on all SessionPools.""" + with mock.patch("snowflake.connector.aio._network.SessionPool.close") as close_mock: + await rest.close() + assert close_mock.call_count == num_session_pools + + +async def create_session( + rest: SnowflakeRestful, num_sessions: int = 1, url: str | None = None +) -> None: + """ + Creates 'num_sessions' sessions to 'url'. This is recursive so that idle sessions + are not reused. 
+ """ + if num_sessions == 0: + return + async with rest._use_requests_session(url): + await create_session(rest, num_sessions - 1, url) + + +@mock.patch("snowflake.connector.aio._network.SnowflakeRestful.make_requests_session") +async def test_no_url_multiple_sessions(make_session_mock): + rest = SnowflakeRestful(connection=mock_conn) + + await create_session(rest, 2) + + assert make_session_mock.call_count == 2 + + assert list(rest._sessions_map.keys()) == [None] + + session_pool = rest._sessions_map[None] + assert len(session_pool._idle_sessions) == 2 + assert len(session_pool._active_sessions) == 0 + + await close_sessions(rest, 1) + + +@mock.patch("snowflake.connector.aio._network.SnowflakeRestful.make_requests_session") +async def test_multiple_urls_multiple_sessions(make_session_mock): + rest = SnowflakeRestful(connection=mock_conn) + + for url in [url_1, url_2, None]: + await create_session(rest, num_sessions=2, url=url) + + assert make_session_mock.call_count == 6 + + hostnames = list(rest._sessions_map.keys()) + for hostname in [hostname_1, hostname_2, None]: + assert hostname in hostnames + + for pool in rest._sessions_map.values(): + assert len(pool._idle_sessions) == 2 + assert len(pool._active_sessions) == 0 + + await close_sessions(rest, 3) + + +@mock.patch("snowflake.connector.aio._network.SnowflakeRestful.make_requests_session") +async def test_multiple_urls_reuse_sessions(make_session_mock): + rest = SnowflakeRestful(connection=mock_conn) + for url in [url_1, url_2, url_3, None]: + # create 10 sessions, one after another + for _ in range(10): + await create_session(rest, url=url) + + # only one session is created and reused thereafter + assert make_session_mock.call_count == 3 + + hostnames = list(rest._sessions_map.keys()) + assert len(hostnames) == 3 + for hostname in [hostname_1, hostname_2, None]: + assert hostname in hostnames + + for pool in rest._sessions_map.values(): + assert len(pool._idle_sessions) == 1 + assert len(pool._active_sessions) == 0 + + await close_sessions(rest, 3) diff --git a/test/unit/aio/test_storage_client_async.py b/test/unit/aio/test_storage_client_async.py new file mode 100644 index 0000000000..648332a2d9 --- /dev/null +++ b/test/unit/aio/test_storage_client_async.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from os import path +from unittest.mock import MagicMock + +try: + from snowflake.connector.aio import SnowflakeConnection + from snowflake.connector.aio._file_transfer_agent import SnowflakeFileMeta + from snowflake.connector.aio._s3_storage_client import SnowflakeS3RestClient + from snowflake.connector.constants import ResultStatus + from snowflake.connector.file_transfer_agent import StorageCredential +except ImportError: + # Compatibility for olddriver tests + from snowflake.connector.s3_util import ERRORNO_WSAECONNABORTED # NOQA + + SnowflakeFileMeta = dict + SnowflakeS3RestClient = None + RequestExceedMaxRetryError = None + StorageCredential = None + megabytes = 1024 * 1024 + DEFAULT_MAX_RETRY = 5 + +THIS_DIR = path.dirname(path.realpath(__file__)) +megabyte = 1024 * 1024 + + +async def test_status_when_num_of_chunks_is_zero(): + meta_info = { + "name": "data1.txt.gz", + "stage_location_type": "S3", + "no_sleeping_time": True, + "put_callback": None, + "put_callback_output_stream": None, + "sha256_digest": "123456789abcdef", + "dst_file_name": "data1.txt.gz", + "src_file_name": path.join(THIS_DIR, "../data", "put_get_1.txt"), + "overwrite": True, + } + meta = SnowflakeFileMeta(**meta_info) + creds = {"AWS_SECRET_KEY": "", "AWS_KEY_ID": "", "AWS_TOKEN": ""} + rest_client = SnowflakeS3RestClient( + meta, + StorageCredential( + creds, + MagicMock(autospec=SnowflakeConnection), + "PUT file:/tmp/file.txt @~", + ), + { + "locationType": "AWS", + "location": "bucket/path", + "creds": creds, + "region": "test", + "endPoint": None, + }, + 8 * megabyte, + ) + rest_client.successful_transfers = 0 + rest_client.num_of_chunks = 0 + await rest_client.finish_upload() + assert meta.result_status == ResultStatus.ERROR diff --git a/test/unit/aio/test_telemetry_async.py b/test/unit/aio/test_telemetry_async.py new file mode 100644 index 0000000000..d7716107bc --- /dev/null +++ b/test/unit/aio/test_telemetry_async.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +from unittest.mock import Mock + +import snowflake.connector.aio._telemetry +import snowflake.connector.telemetry + + +def test_telemetry_data_to_dict(): + """Tests that TelemetryData instances are properly converted to dicts.""" + assert snowflake.connector.telemetry.TelemetryData({}, 2000).to_dict() == { + "message": {}, + "timestamp": "2000", + } + + d = {"type": "test", "query_id": "1", "value": 20} + assert snowflake.connector.telemetry.TelemetryData(d, 1234).to_dict() == { + "message": d, + "timestamp": "1234", + } + + +def get_client_and_mock(): + rest_call = Mock() + rest_call.return_value = {"success": True} + rest = Mock() + rest.attach_mock(rest_call, "request") + client = snowflake.connector.aio._telemetry.TelemetryClient(rest, 2) + return client, rest_call + + +async def test_telemetry_simple_flush(): + """Tests that metrics are properly enqueued and sent to telemetry.""" + client, rest_call = get_client_and_mock() + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert rest_call.call_count == 0 + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 3000)) + assert rest_call.call_count == 1 + + +async def test_telemetry_close(): + """Tests that remaining metrics are flushed on close.""" + client, rest_call = get_client_and_mock() + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert rest_call.call_count == 0 + + await client.close() + assert rest_call.call_count == 1 + assert client.is_closed + + +async def test_telemetry_close_empty(): + """Tests that no calls are made on close if there are no metrics to flush.""" + client, rest_call = get_client_and_mock() + + await client.close() + assert rest_call.call_count == 0 + assert client.is_closed + + +async def test_telemetry_send_batch(): + """Tests that metrics are sent with the send_batch method.""" + client, rest_call = get_client_and_mock() + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert rest_call.call_count == 0 + + await client.send_batch() + assert rest_call.call_count == 1 + + +async def test_telemetry_send_batch_empty(): + """Tests that send_batch does nothing when there are no metrics to send.""" + client, rest_call = get_client_and_mock() + + await client.send_batch() + assert rest_call.call_count == 0 + + +async def test_telemetry_send_batch_clear(): + """Tests that send_batch clears the first batch and will not send anything on a second call.""" + client, rest_call = get_client_and_mock() + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert rest_call.call_count == 0 + + await client.send_batch() + assert rest_call.call_count == 1 + + await client.send_batch() + assert rest_call.call_count == 1 + + +async def test_telemetry_auto_disable(): + """Tests that the client will automatically disable itself if a request fails.""" + client, rest_call = get_client_and_mock() + rest_call.return_value = {"success": False} + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert client.is_enabled() + + await client.send_batch() + assert not client.is_enabled() + + +async def test_telemetry_add_batch_disabled(): + """Tests that the client will not add logs if disabled.""" + client, _ = get_client_and_mock() + + client.disable() + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + + assert client.buffer_size() 
== 0 + + +async def test_telemetry_send_batch_disabled(): + """Tests that the client will not send logs if disabled.""" + client, rest_call = get_client_and_mock() + + await client.add_log_to_batch(snowflake.connector.telemetry.TelemetryData({}, 2000)) + assert client.buffer_size() == 1 + + client.disable() + + await client.send_batch() + assert client.buffer_size() == 1 + assert rest_call.call_count == 0 diff --git a/test/unit/mock_utils.py b/test/unit/mock_utils.py index d3bdc43031..b6e27d514d 100644 --- a/test/unit/mock_utils.py +++ b/test/unit/mock_utils.py @@ -1,7 +1,6 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # - import time from unittest.mock import MagicMock diff --git a/tox.ini b/tox.ini index 6faca8c0d8..27339bc60f 100644 --- a/tox.ini +++ b/tox.ini @@ -33,15 +33,17 @@ setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} ci: SNOWFLAKE_PYTEST_OPTS = -vvv # Set test type, either notset, unit, integ, or both + # aio is only supported on python >= 3.10 unit-integ: SNOWFLAKE_TEST_TYPE = (unit or integ) !unit-!integ: SNOWFLAKE_TEST_TYPE = (unit or integ) - unit: SNOWFLAKE_TEST_TYPE = unit - integ: SNOWFLAKE_TEST_TYPE = integ + unit: SNOWFLAKE_TEST_TYPE = unit and not aio + integ: SNOWFLAKE_TEST_TYPE = integ and not aio parallel: SNOWFLAKE_PYTEST_OPTS = {env:SNOWFLAKE_PYTEST_OPTS:} -n auto # Add common parts into pytest command SNOWFLAKE_PYTEST_COV_LOCATION = {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}-{env:cloud_provider:dev}.xml SNOWFLAKE_PYTEST_COV_CMD = --cov snowflake.connector --junitxml {env:SNOWFLAKE_PYTEST_COV_LOCATION} --cov-report= SNOWFLAKE_PYTEST_CMD = pytest {env:SNOWFLAKE_PYTEST_OPTS:} {env:SNOWFLAKE_PYTEST_COV_CMD} + SNOWFLAKE_PYTEST_CMD_IGNORE_AIO = {env:SNOWFLAKE_PYTEST_CMD} --ignore=test/integ/aio --ignore=test/unit/aio SNOWFLAKE_TEST_MODE = true passenv = AWS_ACCESS_KEY_ID @@ -60,10 +62,10 @@ passenv = commands = # Test environments # Note: make sure to have a default env and all the other special ones - !pandas-!sso-!lambda-!extras: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not sso and not pandas and not lambda" {posargs:} test - pandas: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and pandas" {posargs:} test - sso: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and sso" {posargs:} test - lambda: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and lambda" {posargs:} test + !pandas-!sso-!lambda-!extras: {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} -m "{env:SNOWFLAKE_TEST_TYPE} and not sso and not pandas and not lambda and not aio" {posargs:} test + pandas: {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} -m "{env:SNOWFLAKE_TEST_TYPE} and pandas and not aio" {posargs:} test + sso: {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} -m "{env:SNOWFLAKE_TEST_TYPE} and sso and not aio" {posargs:} test + lambda: {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} -m "{env:SNOWFLAKE_TEST_TYPE} and lambda and not aio" {posargs:} test extras: python -m test.extras.run {posargs:} [testenv:olddriver] @@ -86,7 +88,7 @@ skip_install = True setenv = {[testenv]setenv} passenv = {[testenv]passenv} commands = - {env:SNOWFLAKE_PYTEST_CMD} -m "not skipolddriver" -vvv {posargs:} test + {env:SNOWFLAKE_PYTEST_CMD_IGNORE_AIO} -m "not skipolddriver" -vvv {posargs:} test [testenv:noarrowextension] basepython = python3.8 @@ -97,6 +99,22 @@ commands = pip install . 
 python -c 'import snowflake.connector.result_batch'
 
+[testenv:aio]
+description = Run aio tests
+extras=
+    development
+    aio
+    pandas
+commands = {env:SNOWFLAKE_PYTEST_CMD} -m "aio" -vvv {posargs:} test
+
+[testenv:aio-unsupported-python]
+description = Run aio connector on unsupported python versions
+extras=
+    aio
+commands =
+    pip install .
+    python test/aiodep/unsupported_python_version.py
+
 [testenv:coverage]
 description = [run locally after tests]: combine coverage data and create report
 ; generates a diff coverage against origin/master (can be changed by setting DIFF_AGAINST env var)
@@ -173,6 +191,8 @@ markers =
     timeout: tests that need a timeout time
     internal: tests that could but should only run on our internal CI
    external: tests that could but should only run on our external CI
+    aio: asyncio tests
+asyncio_mode=auto
 
 [isort]
 multi_line_output = 3
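
Note on the tox and pytest changes above: the aio testenv installs the development, aio, and pandas extras and selects only tests carrying the aio marker, while asyncio_mode=auto lets pytest-asyncio collect plain async def test functions without per-test decorators. A minimal sketch of the convention the new test modules rely on is shown below; the module and test names are hypothetical and not part of this patch.

    # test/unit/aio/test_example_async.py -- hypothetical, for illustration only
    from unittest import mock

    import pytest

    # the aio marker registered above lets the aio testenv select this module via -m "aio";
    # with asyncio_mode=auto no @pytest.mark.asyncio decorator is needed
    pytestmark = pytest.mark.aio


    async def test_awaits_a_mocked_coroutine():
        # AsyncMock stands in for the network objects, as in the unit tests above
        client = mock.AsyncMock()
        client.fetch.return_value = {"success": True}

        ret = await client.fetch(timeout=10)

        assert ret == {"success": True}
        client.fetch.assert_awaited_once_with(timeout=10)

Locally the same selection can be reproduced with tox -e aio, while the pre-existing unit and integ environments now exclude test/unit/aio and test/integ/aio via SNOWFLAKE_PYTEST_CMD_IGNORE_AIO and the "not aio" marker expressions.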