Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/snowflake/connector/aio/_file_transfer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def postprocess_done_cb(
async def _transfer_accelerate_config(self) -> None:
if self._stage_location_type == S3_FS and self._file_metadata:
client = await self._create_file_transfer_client(self._file_metadata[0])
self._use_accelerate_endpoint = client.transfer_accelerate_config()
self._use_accelerate_endpoint = await client.transfer_accelerate_config()

async def _create_file_transfer_client(
self, meta: SnowflakeFileMeta
Expand Down Expand Up @@ -289,6 +289,7 @@ async def _create_file_transfer_client(
use_accelerate_endpoint=self._use_accelerate_endpoint,
use_s3_regional_url=self._use_s3_regional_url,
)
await client.transfer_accelerate_config(self._use_accelerate_endpoint)
return client
elif self._stage_location_type == GCS_FS:
client = SnowflakeGCSRestClient(
Expand Down
38 changes: 35 additions & 3 deletions src/snowflake/connector/aio/_s3_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ def __init__(
self.endpoint = (
f"https://{self.s3location.bucket_name}." + stage_info["endPoint"]
)
# self.transfer_accelerate_config(use_accelerate_endpoint)
self.transfer_accelerate_config(False)
# TODO: fix accelerate logic SNOW-1628850

async def _send_request_with_authentication_and_retry(
self,
Expand Down Expand Up @@ -376,6 +373,41 @@ async def _get_bucket_accelerate_config(self, bucket_name: str) -> bool:
return use_accelerate_endpoint
return False

async def transfer_accelerate_config(
self, use_accelerate_endpoint: bool | None = None
) -> bool:
# accelerate cannot be used in China and us government
if self.region_name and self.region_name.startswith("cn-"):
self.endpoint = (
f"https://{self.s3location.bucket_name}."
f"s3.{self.region_name}.amazonaws.com.cn"
)
return False
# if self.endpoint has been set, e.g. by metadata, no more config is needed.
if self.endpoint is not None:
return self.endpoint.find("s3-accelerate.amazonaws.com") >= 0
if self.use_s3_regional_url:
self.endpoint = (
f"https://{self.s3location.bucket_name}."
f"s3.{self.region_name}.amazonaws.com"
)
return False
else:
if use_accelerate_endpoint is None:
use_accelerate_endpoint = await self._get_bucket_accelerate_config(
self.s3location.bucket_name
)

if use_accelerate_endpoint:
self.endpoint = (
f"https://{self.s3location.bucket_name}.s3-accelerate.amazonaws.com"
)
else:
self.endpoint = (
f"https://{self.s3location.bucket_name}.s3.amazonaws.com"
)
return use_accelerate_endpoint

async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
"""Extract error code and error message from the S3's error response.
Expected format:
Expand Down
1 change: 1 addition & 0 deletions test/integ/aio/test_put_get_with_aws_token_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ async def test_put_with_invalid_token(tmpdir, aio_connection):
)

client = SnowflakeS3RestClient(meta, creds, stage_info, 8388608)
await client.transfer_accelerate_config(None)
await client.get_file_header(meta.name) # positive case

# negative case, no aws token
Expand Down
23 changes: 19 additions & 4 deletions test/unit/aio/test_s3_util_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def test_upload_file_with_s3_upload_failed_error(tmp_path):
)
exc = Exception("Stop executing")

def mock_transfer_accelerate_config(
async def mock_transfer_accelerate_config(
self: SnowflakeS3RestClient,
use_accelerate_endpoint: bool | None = None,
) -> bool:
Expand All @@ -117,7 +117,7 @@ def mock_transfer_accelerate_config(
return_value=True,
):
with mock.patch(
"snowflake.connector.s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
mock_transfer_accelerate_config,
):
with mock.patch(
Expand Down Expand Up @@ -160,6 +160,7 @@ async def test_get_header_expiry_error():
},
8 * megabyte,
)
await rest_client.transfer_accelerate_config(None)

with mock.patch(
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
Expand Down Expand Up @@ -241,6 +242,7 @@ async def test_upload_expiry_error():
},
8 * megabyte,
)
await rest_client.transfer_accelerate_config(None)

with mock.patch(
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
Expand Down Expand Up @@ -332,6 +334,7 @@ async def test_download_expiry_error():
},
8 * megabyte,
)
await rest_client.transfer_accelerate_config(None)

with mock.patch(
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._has_expired_token",
Expand Down Expand Up @@ -373,12 +376,23 @@ async def test_download_unknown_error(caplog):
message="No, just chuck testing...",
headers={},
)

async def mock_transfer_accelerate_config(
self: SnowflakeS3RestClient,
use_accelerate_endpoint: bool | None = None,
) -> bool:
self.endpoint = f"https://{self.s3location.bucket_name}.s3.awsamazon.com"
return False

with mock.patch(
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient._send_request_with_authentication_and_retry",
side_effect=error,
), mock.patch(
"snowflake.connector.aio._file_transfer_agent.SnowflakeFileTransferAgent._transfer_accelerate_config",
side_effect=None,
), mock.patch(
"snowflake.connector.aio._s3_storage_client.SnowflakeS3RestClient.transfer_accelerate_config",
mock_transfer_accelerate_config,
):
await agent.execute()
assert agent._file_metadata[0].error_details.status == 400
Expand Down Expand Up @@ -422,6 +436,7 @@ async def test_download_retry_exceeded_error():
},
8 * megabyte,
)
await rest_client.transfer_accelerate_config()
rest_client.SLEEP_UNIT = 0

with mock.patch(
Expand Down Expand Up @@ -466,7 +481,7 @@ async def test_accelerate_in_china_endpoint():
},
8 * megabyte,
)
assert not rest_client.transfer_accelerate_config()
assert not await rest_client.transfer_accelerate_config()

rest_client = SnowflakeS3RestClient(
meta,
Expand All @@ -484,4 +499,4 @@ async def test_accelerate_in_china_endpoint():
},
8 * megabyte,
)
assert not rest_client.transfer_accelerate_config()
assert not await rest_client.transfer_accelerate_config()
Loading