32 changes: 20 additions & 12 deletions src/snowflake/connector/aio/_result_batch.py
@@ -29,12 +29,7 @@
get_http_retryable_error,
is_retryable_http_code,
)
from snowflake.connector.result_batch import (
MAX_DOWNLOAD_RETRY,
SSE_C_AES,
SSE_C_ALGORITHM,
SSE_C_KEY,
)
from snowflake.connector.result_batch import SSE_C_AES, SSE_C_ALGORITHM, SSE_C_KEY
from snowflake.connector.result_batch import ArrowResultBatch as ArrowResultBatchSync
from snowflake.connector.result_batch import DownloadMetrics
from snowflake.connector.result_batch import JSONResultBatch as JSONResultBatchSync
@@ -52,8 +47,13 @@

logger = getLogger(__name__)

# We intentionally redefine DOWNLOAD_TIMEOUT and MAX_DOWNLOAD_RETRY for the async version,
# because sync and async downloads differ in nature and may require separate tuning.
# Also be aware that _result_batch is currently a private module, so these values are not
# exposed to users directly.
DOWNLOAD_TIMEOUT = None
MAX_DOWNLOAD_RETRY = 10


# TODO: consolidate this with the sync version
def create_batches_from_response(
cursor: SnowflakeCursor,
_format: str,
@@ -212,19 +212,27 @@ async def download_chunk(http_session):
return response, content, encoding

content, encoding = None, None
for retry in range(MAX_DOWNLOAD_RETRY):
for retry in range(max(MAX_DOWNLOAD_RETRY, 1)):
try:
# TODO: feature parity with download timeout setting, in sync it's set to 7s
# but in async we schedule multiple tasks at the same time so some tasks might
# take longer than 7s to finish which is expected

async with TimerContextManager() as download_metric:
logger.debug(f"started downloading result batch id: {self.id}")
chunk_url = self._remote_chunk_info.url
request_data = {
"url": chunk_url,
"headers": self._chunk_headers,
# "timeout": DOWNLOAD_TIMEOUT,
}
# The download timeout setting here differs from the sync version, which uses an
# empirical value of 7 seconds. That empirical value is hard to reproduce in async,
# because we maximize network throughput by downloading multiple chunks concurrently,
# whereas the sync version's overall throughput is constrained by the number of
# prefetch threads -- with asyncio we see a substantial download performance improvement.
# If DOWNLOAD_TIMEOUT is not set, the aiohttp session timeout, which originates from
# the connection config, takes effect by default.
if DOWNLOAD_TIMEOUT:
request_data["timeout"] = aiohttp.ClientTimeout(
total=DOWNLOAD_TIMEOUT
)
# Try to reuse a connection if possible
if connection and connection._rest is not None:
async with connection._rest._use_requests_session() as session:
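
For reference, here is a minimal standalone sketch of the retry-plus-optional-timeout pattern this hunk applies. The fetch_chunk helper, its signature, and the error handling are illustrative stand-ins, not the connector's actual implementation:

import asyncio

import aiohttp

DOWNLOAD_TIMEOUT = None  # None defers to the aiohttp session timeout
MAX_DOWNLOAD_RETRY = 10


async def fetch_chunk(session: aiohttp.ClientSession, url: str, headers: dict) -> bytes:
    # max(..., 1) guarantees at least one attempt even if MAX_DOWNLOAD_RETRY
    # is patched to 0, mirroring the change in this hunk.
    attempts = max(MAX_DOWNLOAD_RETRY, 1)
    for retry in range(attempts):
        try:
            request_kwargs = {"url": url, "headers": headers}
            # Attach an explicit per-request timeout only when DOWNLOAD_TIMEOUT
            # is set; otherwise the session-level timeout applies.
            if DOWNLOAD_TIMEOUT:
                request_kwargs["timeout"] = aiohttp.ClientTimeout(total=DOWNLOAD_TIMEOUT)
            async with session.get(**request_kwargs) as response:
                response.raise_for_status()
                return await response.read()
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # Re-raise on the final attempt; otherwise fall through and retry.
            if retry == attempts - 1:
                raise

Since _result_batch is private, there is no public knob for these constants; the new test added below exercises them by patching the module directly.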
88 changes: 38 additions & 50 deletions test/integ/aio/test_cursor_async.py
@@ -13,13 +13,13 @@
import pickle
import time
from datetime import date, datetime, timezone
from typing import TYPE_CHECKING, NamedTuple
from unittest import mock

import pytest
import pytz

import snowflake.connector
import snowflake.connector.aio
from snowflake.connector import (
InterfaceError,
NotSupportedError,
@@ -30,64 +30,31 @@
errors,
)
from snowflake.connector.aio import DictCursor, SnowflakeCursor
from snowflake.connector.aio._result_batch import (
ArrowResultBatch,
JSONResultBatch,
ResultBatch,
)
from snowflake.connector.compat import IS_WINDOWS

try:
from snowflake.connector.cursor import ResultMetadata
except ImportError:

class ResultMetadata(NamedTuple):
name: str
type_code: int
display_size: int
internal_size: int
precision: int
scale: int
is_nullable: bool


import snowflake.connector.aio
from snowflake.connector.constants import (
FIELD_ID_TO_NAME,
PARAMETER_MULTI_STATEMENT_COUNT,
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
QueryStatus,
)
from snowflake.connector.cursor import ResultMetadata
from snowflake.connector.description import CLIENT_VERSION
from snowflake.connector.errorcode import (
ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT,
ER_NO_ARROW_RESULT,
ER_NO_PYARROW,
ER_NO_PYARROW_SNOWSQL,
ER_NOT_POSITIVE_SIZE,
)
from snowflake.connector.errors import Error
from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
from snowflake.connector.telemetry import TelemetryField

try:
from snowflake.connector.util_text import random_string
except ImportError:
from ..randomize import random_string

try:
from snowflake.connector.aio._result_batch import ArrowResultBatch, JSONResultBatch
from snowflake.connector.constants import (
FIELD_ID_TO_NAME,
PARAMETER_MULTI_STATEMENT_COUNT,
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
)
from snowflake.connector.errorcode import (
ER_NO_ARROW_RESULT,
ER_NO_PYARROW,
ER_NO_PYARROW_SNOWSQL,
)
except ImportError:
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = None
ER_NO_ARROW_RESULT = None
ER_NO_PYARROW = None
ER_NO_PYARROW_SNOWSQL = None
ArrowResultBatch = JSONResultBatch = None
FIELD_ID_TO_NAME = {}

if TYPE_CHECKING: # pragma: no cover
from snowflake.connector.result_batch import ResultBatch

try: # pragma: no cover
from snowflake.connector.constants import QueryStatus
except ImportError:
QueryStatus = None
from snowflake.connector.util_text import random_string


@pytest.fixture
@@ -1826,3 +1793,24 @@ async def test_decoding_utf8_for_json_result(conn_cnx):
)
with pytest.raises(Error):
await result_batch._load("À".encode("latin1"), "latin1")


async def test_fetch_download_timeout_setting(conn_cnx):
with mock.patch.multiple(
"snowflake.connector.aio._result_batch",
DOWNLOAD_TIMEOUT=0.001,
MAX_DOWNLOAD_RETRY=2,
):
sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
async with conn_cnx() as con, con.cursor() as cur:
with pytest.raises(asyncio.TimeoutError):
await (await cur.execute(sql)).fetchall()

with mock.patch.multiple(
"snowflake.connector.aio._result_batch",
DOWNLOAD_TIMEOUT=10,
MAX_DOWNLOAD_RETRY=1,
):
sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
async with conn_cnx() as con, con.cursor() as cur:
assert len(await (await cur.execute(sql)).fetchall()) == 100000
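
A note on the test technique above: mock.patch.multiple swaps module-level attributes for the duration of the with block and restores them on exit, which is what lets the test force a timeout by driving DOWNLOAD_TIMEOUT down to 1 ms. A minimal self-contained sketch of the same pattern, using a throwaway module rather than snowflake.connector.aio._result_batch:

import types
from unittest import mock

# Stand-in for a module with tunable constants.
settings = types.ModuleType("settings")
settings.TIMEOUT = 10
settings.RETRIES = 3

with mock.patch.multiple(settings, TIMEOUT=0.001, RETRIES=2):
    assert settings.TIMEOUT == 0.001  # patched inside the block
assert settings.TIMEOUT == 10  # restored on exit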