6 changes: 3 additions & 3 deletions .github/workflows/build_test.yml
@@ -351,7 +351,7 @@ jobs:
python-version: ["3.10", "3.11", "3.12"]
cloud-provider: [aws, azure, gcp]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
@@ -366,7 +366,7 @@ jobs:
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
.github/workflows/parameters/public/parameters_${{ matrix.cloud-provider }}.py.gpg > test/parameters.py
- name: Download wheel(s)
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: ${{ matrix.os.download_name }}_py${{ matrix.python-version }}
path: dist
@@ -388,7 +388,7 @@ jobs:
- name: Combine coverages
run: python -m tox run -e coverage --skip-missing-interpreters false
shell: bash
- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4
with:
name: coverage_aio_${{ matrix.os.download_name }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
path: |
100 changes: 60 additions & 40 deletions src/snowflake/connector/aio/_connection.py
@@ -60,6 +60,7 @@
from ..util_text import split_statements
from ._cursor import SnowflakeCursor
from ._network import SnowflakeRestful
from ._time_util import HeartBeatTimer
from .auth import Auth, AuthByDefault, AuthByPlugin

logger = getLogger(__name__)
@@ -87,7 +88,19 @@ def __init__(
# get the imported modules from sys.modules
# self._log_telemetry_imported_packages() # TODO: async telemetry support
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
# atexit.register(self._close_at_exit) # TODO: async atexit support/test
atexit.register(self._close_at_exit)

def __enter__(self):
# async connection does not support sync context manager
raise TypeError(
"'SnowflakeConnection' object does not support the context manager protocol"
)

def __exit__(self, exc_type, exc_val, exc_tb):
# async connection does not support sync context manager
raise TypeError(
"'SnowflakeConnection' object does not support the context manager protocol"
)
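Since the synchronous context manager protocol is now rejected, connections are meant to be used via `async with`, which goes through `__aenter__`/`__aexit__` below. A minimal usage sketch, assuming `SnowflakeConnection` is exported from `snowflake.connector.aio` and that `__aenter__` establishes the session; the connection parameters are illustrative:

```python
import asyncio

from snowflake.connector.aio import SnowflakeConnection


async def main() -> None:
    # `async with` drives __aenter__/__aexit__; a plain `with` now raises TypeError.
    async with SnowflakeConnection(
        account="my_account", user="my_user", password="***"  # illustrative values
    ) as conn:
        cur = conn.cursor()
        await cur.execute("SELECT CURRENT_VERSION()")
        print(await cur.fetchone())


asyncio.run(main())
```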

async def __aenter__(self) -> SnowflakeConnection:
"""Context manager."""
@@ -135,7 +148,9 @@ async def __open_connection(self):
)

if ".privatelink.snowflakecomputing." in self.host:
SnowflakeConnection.setup_ocsp_privatelink(self.application, self.host)
await SnowflakeConnection.setup_ocsp_privatelink(
self.application, self.host
)
else:
if "SF_OCSP_RESPONSE_CACHE_SERVER_URL" in os.environ:
del os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"]
@@ -164,11 +179,10 @@ async def __open_connection(self):
PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY
] = self._validate_client_session_keep_alive_heartbeat_frequency()

# TODO: client_prefetch_threads support
# if self.client_prefetch_threads:
# self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = (
# self._validate_client_prefetch_threads()
# )
if self.client_prefetch_threads:
self._session_parameters[PARAMETER_CLIENT_PREFETCH_THREADS] = (
self._validate_client_prefetch_threads()
)

# Setup authenticator
auth = Auth(self.rest)
@@ -203,7 +217,7 @@ async def __open_connection(self):
elif self._authenticator == DEFAULT_AUTHENTICATOR:
self.auth_class = AuthByDefault(
password=self._password,
timeout=self._login_timeout,
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
else:
@@ -222,10 +236,21 @@ async def __open_connection(self):
# This will be called after the heartbeat frequency has actually been set.
# By this point it should have been decided if the heartbeat has to be enabled
# and what would the heartbeat frequency be
# TODO: implement asyncio heartbeat/timer
raise NotImplementedError(
"asyncio client_session_keep_alive is not supported"
await self._add_heartbeat()

async def _add_heartbeat(self) -> None:
if not self._heartbeat_task:
self._heartbeat_task = HeartBeatTimer(
self.client_session_keep_alive_heartbeat_frequency, self._heartbeat_tick
)
await self._heartbeat_task.start()
logger.debug("started heartbeat")

async def _heartbeat_tick(self) -> None:
"""Execute a hearbeat if connection isn't closed yet."""
if not self.is_closed():
logger.debug("heartbeating!")
await self.rest._heartbeat()
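`HeartBeatTimer` is imported from `_time_util`, which is not part of this diff. A minimal sketch of an asyncio-based repeating timer with the same surface (`start()`/`stop()`, a period in seconds, an async callback) follows; this is an assumption about its shape, not the actual implementation:

```python
import asyncio
from typing import Awaitable, Callable


class HeartBeatTimerSketch:
    """Calls an async callback every `period` seconds until stopped."""

    def __init__(self, period: float, callback: Callable[[], Awaitable[None]]) -> None:
        self._period = period
        self._callback = callback
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        # Schedule the periodic loop on the currently running event loop.
        self._task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        try:
            while True:
                await asyncio.sleep(self._period)
                await self._callback()
        except asyncio.CancelledError:
            pass  # normal shutdown path

    async def stop(self) -> None:
        if self._task is not None:
            self._task.cancel()
            # Wait for cancellation so no tick fires after stop() returns.
            await asyncio.gather(self._task, return_exceptions=True)
            self._task = None
```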

async def _all_async_queries_finished(self) -> bool:
"""Checks whether all async queries started by this Connection have finished executing."""
@@ -322,6 +347,13 @@ async def _authenticate(self, auth_instance: AuthByPlugin):
continue
break

async def _cancel_heartbeat(self) -> None:
"""Cancel a heartbeat thread."""
if self._heartbeat_task:
await self._heartbeat_task.stop()
self._heartbeat_task = None
logger.debug("stopped heartbeat")

def _init_connection_parameters(
self,
connection_init_kwargs: dict,
@@ -353,7 +385,7 @@ def _init_connection_parameters(
for name, (value, _) in DEFAULT_CONFIGURATION.items():
setattr(self, f"_{name}", value)

self.heartbeat_thread = None
self._heartbeat_task = None
is_kwargs_empty = not connection_init_kwargs

if "application" not in connection_init_kwargs:
@@ -403,7 +435,7 @@ async def _cancel_query(

def _close_at_exit(self):
with suppress(Exception):
asyncio.get_event_loop().run_until_complete(self.close(retry=False))
asyncio.run(self.close(retry=False))
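Switching the atexit hook to `asyncio.run` matters because, by interpreter shutdown, the loop the connection ran on is usually already closed; `run_until_complete` on that loop would fail, while `asyncio.run` creates and tears down a fresh loop just for the cleanup coroutine. A rough sketch of the pattern (the coroutine body is a placeholder for `close(retry=False)`):

```python
import asyncio
from contextlib import suppress


async def _cleanup() -> None:
    await asyncio.sleep(0)  # placeholder for awaiting connection.close(retry=False)


def _close_at_exit_sketch() -> None:
    # Mirror the suppress(Exception) guard: failures at interpreter shutdown are swallowed.
    with suppress(Exception):
        asyncio.run(_cleanup())
```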

async def _get_query_status(
self, sf_qid: str
@@ -587,8 +619,7 @@ async def close(self, retry: bool = True) -> None:
# will hang if the application doesn't close the connection and
# CLIENT_SESSION_KEEP_ALIVE is set, because the heartbeat runs on
# a separate thread.
# TODO: async heartbeat support
# self._cancel_heartbeat()
await self._cancel_heartbeat()

# close telemetry first, since it needs rest to send remaining data
logger.info("closed")
@@ -600,7 +631,12 @@
and not self._server_session_keep_alive
):
logger.info("No async queries seem to be running, deleting session")
await self.rest.delete_session(retry=retry)
try:
await self.rest.delete_session(retry=retry)
except Exception as e:
logger.debug(
"Exception encountered while deleting session, ignoring: %s", e
)
else:
logger.info(
"There are {} async queries still running, not deleting session".format(
@@ -837,33 +873,17 @@ async def get_query_status_throw_if_error(self, sf_qid: str) -> QueryStatus:
"""
status, status_resp = await self._get_query_status(sf_qid)
self._cache_query_status(sf_qid, status)
queries = status_resp["data"]["queries"]
if self.is_an_error(status):
if sf_qid in self._async_sfqids:
self._async_sfqids.pop(sf_qid, None)
message = status_resp.get("message")
if message is None:
message = ""
code = queries[0].get("errorCode", -1)
sql_state = None
if "data" in status_resp:
message += (
queries[0].get("errorMessage", "") if len(queries) > 0 else ""
)
sql_state = status_resp["data"].get("sqlState")
Error.errorhandler_wrapper(
self,
None,
ProgrammingError,
{
"msg": message,
"errno": int(code),
"sqlstate": sql_state,
"sfqid": sf_qid,
},
)
self._process_error_query_status(sf_qid, status_resp)
return status
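The inline error handling removed above now lives in `_process_error_query_status`, which is not shown in this diff. Assuming the shared helper simply mirrors that removed logic (pop the query id, assemble the message and error code from the status response, and route it through the error handler), a rough sketch could read:

```python
from snowflake.connector.errors import Error, ProgrammingError


def process_error_query_status_sketch(conn, sf_qid: str, status_resp: dict) -> None:
    # Mirror of the inline logic removed above; the real shared helper may differ.
    queries = status_resp.get("data", {}).get("queries", [])
    conn._async_sfqids.pop(sf_qid, None)
    message = status_resp.get("message") or ""
    code = queries[0].get("errorCode", -1) if queries else -1
    sql_state = None
    if "data" in status_resp:
        message += queries[0].get("errorMessage", "") if queries else ""
        sql_state = status_resp["data"].get("sqlState")
    Error.errorhandler_wrapper(
        conn,
        None,
        ProgrammingError,
        {"msg": message, "errno": int(code), "sqlstate": sql_state, "sfqid": sf_qid},
    )
```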

@staticmethod
async def setup_ocsp_privatelink(app, hostname) -> None:
async with SnowflakeConnection.OCSP_ENV_LOCK:
ocsp_cache_server = f"http://ocsp.{hostname}/ocsp_response_cache.json"
os.environ["SF_OCSP_RESPONSE_CACHE_SERVER_URL"] = ocsp_cache_server
logger.debug("OCSP Cache Server is updated: %s", ocsp_cache_server)

async def rollback(self) -> None:
"""Rolls back the current transaction."""
await self.cursor().execute("ROLLBACK")
5 changes: 5 additions & 0 deletions src/snowflake/connector/aio/_cursor.py
@@ -70,6 +70,11 @@ def __init__(
def __aiter__(self):
return self

def __iter__(self):
raise TypeError(
"'snowflake.connector.aio.SnowflakeCursor' only supports async iteration."
)
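With the synchronous `__iter__` disabled, result rows are consumed with `async for`. A small usage sketch (connection setup is omitted and the query is illustrative):

```python
async def dump_rows(conn) -> None:
    cur = conn.cursor()
    await cur.execute("SELECT seq4() AS n FROM TABLE(GENERATOR(ROWCOUNT => 3))")
    # `async for` goes through __aiter__/__anext__; a plain `for` raises TypeError.
    async for row in cur:
        print(row)
```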

async def __anext__(self):
while True:
_next = await self.fetchone()
51 changes: 30 additions & 21 deletions src/snowflake/connector/aio/_result_batch.py
@@ -191,14 +191,27 @@ async def create_iter(

async def _download(
self, connection: SnowflakeConnection | None = None, **kwargs
) -> aiohttp.ClientResponse:
) -> tuple[bytes, str]:
"""Downloads the data that the ``ResultBatch`` is pointing at."""
sleep_timer = 1
backoff = (
connection._backoff_generator
if connection is not None
else exponential_backoff()()
)

async def download_chunk(http_session):
response, content, encoding = None, None, None
logger.debug(
f"downloading result batch id: {self.id} with existing session {http_session}"
)
response = await http_session.get(**request_data)
if response.status == OK:
logger.debug(f"successfully downloaded result batch id: {self.id}")
content, encoding = await response.read(), response.get_encoding()
return response, content, encoding

content, encoding = None, None
for retry in range(MAX_DOWNLOAD_RETRY):
try:
# TODO: feature parity with download timeout setting, in sync it's set to 7s
@@ -218,20 +231,16 @@ async def _download(
logger.debug(
f"downloading result batch id: {self.id} with existing session {session}"
)
response = await session.request("get", **request_data)
response, content, encoding = await download_chunk(session)
else:
logger.debug(
f"downloading result batch id: {self.id} with new session"
)
async with aiohttp.ClientSession() as session:
response = await session.get(**request_data)
logger.debug(
f"downloading result batch id: {self.id} with new session"
)
response, content, encoding = await download_chunk(session)

if response.status == OK:
logger.debug(
f"successfully downloaded result batch id: {self.id}"
)
break

# Raise error here to correctly go in to exception clause
if is_retryable_http_code(response.status):
# retryable server exceptions
@@ -259,7 +268,7 @@ async def _download(
self._metrics[DownloadMetrics.download.value] = (
download_metric.get_timing_millis()
)
return response
return content, encoding
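Returning `(content, encoding)` instead of the raw `aiohttp.ClientResponse` means the body is read while the HTTP session is still open, so callers never touch a response whose underlying connection has been released. A rough sketch of the retry-with-backoff shape this follows, assuming the backoff generator yields sleep durations in seconds; the URL, retry limit, and error handling are simplified placeholders:

```python
import asyncio

import aiohttp

MAX_DOWNLOAD_RETRY = 10  # assumed limit; the connector defines its own constant


async def download_with_retry(url: str, backoff) -> tuple[bytes, str]:
    for retry in range(MAX_DOWNLOAD_RETRY):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()
                    # Read the body while the session is open, then hand back bytes
                    # plus the detected encoding instead of the response object.
                    return await response.read(), response.get_encoding()
        except aiohttp.ClientError:
            if retry == MAX_DOWNLOAD_RETRY - 1:
                raise
            await asyncio.sleep(next(backoff))
    raise RuntimeError("unreachable")
```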


class JSONResultBatch(ResultBatch, JSONResultBatchSync):
@@ -268,11 +277,11 @@ async def create_iter(
) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
if self._local:
return iter(self._data)
response = await self._download(connection=connection)
content, encoding = await self._download(connection=connection)
# Load data into an intermediate form
logger.debug(f"started loading result batch id: {self.id}")
async with TimerContextManager() as load_metric:
downloaded_data = await self._load(response)
downloaded_data = self._load(content, encoding)
logger.debug(f"finished loading result batch id: {self.id}")
self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis()
# Process downloaded data
@@ -281,7 +290,7 @@ async def create_iter(
self._metrics[DownloadMetrics.parse.value] = parse_metric.get_timing_millis()
return iter(parsed_data)

async def _load(self, response: aiohttp.ClientResponse) -> list:
async def _load(self, content: bytes, encoding: str) -> list:
"""This function loads a compressed JSON file into memory.

Returns:
@@ -292,29 +301,29 @@ async def _load(self, response: aiohttp.ClientResponse) -> list:
# if users specify how to decode the data, we decode the bytes using the specified encoding
if self._json_result_force_utf8_decoding:
try:
read_data = str(await response.read(), "utf-8", errors="strict")
read_data = str(content, "utf-8", errors="strict")
except Exception as exc:
err_msg = f"failed to decode json result content due to error {exc!r}"
logger.error(err_msg)
raise Error(msg=err_msg)
else:
# note: SNOW-787480 response.apparent_encoding is unreliable, chardet.detect can be wrong which is used by
# response.text to decode content, check issue: https://github.com/chardet/chardet/issues/148
read_data = await response.text()
read_data = content.decode(encoding, "strict")
return json.loads("".join(["[", read_data, "]"]))
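The `"".join(["[", read_data, "]"])` wrapping reflects the shape of Snowflake JSON result chunks, which arrive as comma-separated row arrays without enclosing brackets. A tiny illustration with made-up data:

```python
import json

raw_chunk = b'[1, "a"],[2, "b"]'  # illustrative chunk payload
rows = json.loads("".join(["[", raw_chunk.decode("utf-8", "strict"), "]"]))
assert rows == [[1, "a"], [2, "b"]]
```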


class ArrowResultBatch(ResultBatch, ArrowResultBatchSync):
async def _load(
self, response: aiohttp.ClientResponse, row_unit: IterUnit
self, content, row_unit: IterUnit
) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
"""Creates a ``PyArrowIterator`` from a response.

This is used to iterate through results in different ways depending on which
mode that ``PyArrowIterator`` is in.
"""
return _create_nanoarrow_iterator(
await response.read(),
content,
self._context,
self._use_dict_result,
self._numpy,
@@ -334,14 +343,14 @@ async def _create_iter(
if connection and getattr(connection, "_debug_arrow_chunk", False):
logger.debug(f"arrow data can not be parsed: {self._data}")
raise
response = await self._download(connection=connection)
content, _ = await self._download(connection=connection)
logger.debug(f"started loading result batch id: {self.id}")
async with TimerContextManager() as load_metric:
try:
loaded_data = await self._load(response, iter_unit)
loaded_data = await self._load(content, iter_unit)
except Exception:
if connection and getattr(connection, "_debug_arrow_chunk", False):
logger.debug(f"arrow data can not be parsed: {response}")
logger.debug(f"arrow data can not be parsed: {content}")
raise
logger.debug(f"finished loading result batch id: {self.id}")
self._metrics[DownloadMetrics.load.value] = load_metric.get_timing_millis()