207 changes: 207 additions & 0 deletions src/snowflake/connector/aio/_azure_storage_client.py
@@ -0,0 +1,207 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import json
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from logging import getLogger
from random import choice
from string import hexdigits
from typing import TYPE_CHECKING, Any

import aiohttp

from ..azure_storage_client import (
SnowflakeAzureRestClient as SnowflakeAzureRestClientSync,
)
from ..compat import quote
from ..constants import FileHeader, ResultStatus
from ..encryption_util import EncryptionMetadata
from ._storage_client import SnowflakeStorageClient as SnowflakeStorageClientAsync

if TYPE_CHECKING: # pragma: no cover
from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential

logger = getLogger(__name__)

from ..azure_storage_client import (
ENCRYPTION_DATA,
MATDESC,
TOKEN_EXPIRATION_ERR_MESSAGE,
)


class SnowflakeAzureRestClient(
SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync
):
def __init__(
self,
meta: SnowflakeFileMeta,
credentials: StorageCredential | None,
chunk_size: int,
stage_info: dict[str, Any],
use_s3_regional_url: bool = False,
) -> None:
SnowflakeAzureRestClientSync.__init__(
self,
meta=meta,
stage_info=stage_info,
chunk_size=chunk_size,
credentials=credentials,
)

async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
return response.status == 403 and any(
message in response.reason for message in TOKEN_EXPIRATION_ERR_MESSAGE
)

async def _send_request_with_authentication_and_retry(
self,
verb: str,
url: str,
retry_id: int | str,
headers: dict[str, Any] = None,
data: bytes = None,
) -> aiohttp.ClientResponse:
if not headers:
headers = {}

def generate_authenticated_url_and_rest_args() -> tuple[str, dict[str, Any]]:
curtime = datetime.now(timezone.utc).replace(tzinfo=None)
timestamp = curtime.strftime("YYYY-MM-DD")
sas_token = self.credentials.creds["AZURE_SAS_TOKEN"]
if sas_token and sas_token.startswith("?"):
sas_token = sas_token[1:]
if "?" in url:
_url = url + "&" + sas_token
else:
_url = url + "?" + sas_token
headers["Date"] = timestamp
rest_args = {"headers": headers}
if data:
rest_args["data"] = data
return _url, rest_args

return await self._send_request_with_retry(
verb, generate_authenticated_url_and_rest_args, retry_id
)

async def get_file_header(self, filename: str) -> FileHeader | None:
"""Gets Azure file properties."""
container_name = quote(self.azure_location.container_name)
path = quote(self.azure_location.path) + quote(filename)
meta = self.meta
# HTTP HEAD request
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
retry_id = "HEAD"
self.retry_count[retry_id] = 0
r = await self._send_request_with_authentication_and_retry(
"HEAD", url, retry_id
)
if r.status == 200:
meta.result_status = ResultStatus.UPLOADED
enc_data_str = r.headers.get(ENCRYPTION_DATA)
encryption_data = None if enc_data_str is None else json.loads(enc_data_str)
encryption_metadata = (
None
if not encryption_data
else EncryptionMetadata(
key=encryption_data["WrappedContentKey"]["EncryptedKey"],
iv=encryption_data["ContentEncryptionIV"],
matdesc=r.headers.get(MATDESC),
)
)
return FileHeader(
digest=r.headers.get("x-ms-meta-sfcdigest"),
content_length=int(r.headers.get("Content-Length")),
encryption_metadata=encryption_metadata,
)
elif r.status == 404:
meta.result_status = ResultStatus.NOT_FOUND_FILE
return FileHeader(
digest=None, content_length=None, encryption_metadata=None
)
else:
r.raise_for_status()

async def _initiate_multipart_upload(self) -> None:
self.block_ids = [
"".join(choice(hexdigits) for _ in range(20))
for _ in range(self.num_of_chunks)
]

async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None:
container_name = quote(self.azure_location.container_name)
path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/"))

if self.num_of_chunks > 1:
block_id = self.block_ids[chunk_id]
url = (
f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp=block"
f"&blockid={block_id}"
)
headers = {"Content-Length": str(len(chunk))}
r = await self._send_request_with_authentication_and_retry(
"PUT", url, chunk_id, headers=headers, data=chunk
)
else:
# single request
azure_metadata = self._prepare_file_metadata()
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
headers = {
"x-ms-blob-type": "BlockBlob",
"Content-Encoding": "utf-8",
}
headers.update(azure_metadata)
r = await self._send_request_with_authentication_and_retry(
"PUT", url, chunk_id, headers=headers, data=chunk
)
r.raise_for_status() # expect status code 201

async def _complete_multipart_upload(self) -> None:
container_name = quote(self.azure_location.container_name)
path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/"))
url = (
f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp"
f"=blocklist"
)
root = ET.Element("BlockList")
for block_id in self.block_ids:
part = ET.Element("Latest")
part.text = block_id
root.append(part)
headers = {"x-ms-blob-content-encoding": "utf-8"}
azure_metadata = self._prepare_file_metadata()
headers.update(azure_metadata)
retry_id = "COMPLETE"
self.retry_count[retry_id] = 0
r = await self._send_request_with_authentication_and_retry(
"PUT", url, "COMPLETE", headers=headers, data=ET.tostring(root)
)
r.raise_for_status() # expects status code 201

async def download_chunk(self, chunk_id: int) -> None:
container_name = quote(self.azure_location.container_name)
path = quote(self.azure_location.path + self.meta.src_file_name.lstrip("/"))
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
if self.num_of_chunks > 1:
chunk_size = self.chunk_size
if chunk_id < self.num_of_chunks - 1:
_range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}"
else:
_range = f"{chunk_id * chunk_size}-"
headers = {"Range": f"bytes={_range}"}
r = await self._send_request_with_authentication_and_retry(
"GET", url, chunk_id, headers=headers
) # expect 206
else:
# single request
r = await self._send_request_with_authentication_and_retry(
"GET", url, chunk_id
)
if r.status in (200, 206):
self.write_downloaded_chunk(chunk_id, await r.read())
r.raise_for_status()
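
For orientation, a minimal standalone sketch of the two pieces of request plumbing the async client above relies on: appending the SAS token to the blob URL and computing the Range header for chunked downloads. The helper names here are illustrative only and are not part of the connector.

# Hypothetical helpers mirroring the logic above; names are illustrative only.
def append_sas_token(url: str, sas_token: str) -> str:
    # Mirrors generate_authenticated_url_and_rest_args: strip a leading "?" from the
    # token and join with "&" if the URL already has a query string, "?" otherwise.
    token = sas_token[1:] if sas_token.startswith("?") else sas_token
    separator = "&" if "?" in url else "?"
    return f"{url}{separator}{token}"


def chunk_range_header(chunk_id: int, chunk_size: int, num_of_chunks: int) -> dict[str, str]:
    # Mirrors download_chunk: a bounded byte range for every chunk except the last,
    # which is left open-ended so it reads to the end of the blob.
    start = chunk_id * chunk_size
    if chunk_id < num_of_chunks - 1:
        return {"Range": f"bytes={start}-{start + chunk_size - 1}"}
    return {"Range": f"bytes={start}-"}


assert append_sas_token("https://acct.blob.core.windows.net/c/p", "?sv=2022&sig=x") == (
    "https://acct.blob.core.windows.net/c/p?sv=2022&sig=x"
)
assert chunk_range_header(0, 8 * 1024 * 1024, 3) == {"Range": "bytes=0-8388607"}
assert chunk_range_header(2, 8 * 1024 * 1024, 3) == {"Range": "bytes=16777216-"}
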
28 changes: 21 additions & 7 deletions src/snowflake/connector/aio/_file_transfer_agent.py
@@ -10,7 +10,6 @@
from logging import getLogger
from typing import IO, TYPE_CHECKING, Any

from ..azure_storage_client import SnowflakeAzureRestClient
from ..constants import (
AZURE_CHUNK_SIZE,
AZURE_FS,
@@ -29,8 +28,9 @@
SnowflakeFileTransferAgent as SnowflakeFileTransferAgentSync,
)
from ..file_transfer_agent import SnowflakeProgressPercentage, _chunk_size_calculator
from ..gcs_storage_client import SnowflakeGCSRestClient
from ..local_storage_client import SnowflakeLocalStorageClient
from ._azure_storage_client import SnowflakeAzureRestClient
from ._gcs_storage_client import SnowflakeGCSRestClient
from ._s3_storage_client import SnowflakeS3RestClient
from ._storage_client import SnowflakeStorageClient

@@ -92,7 +92,7 @@ async def execute(self) -> None:
for m in self._file_metadata:
m.sfagent = self

self._transfer_accelerate_config()
await self._transfer_accelerate_config()

if self._command_type == CMD_TYPE_DOWNLOAD:
if not os.path.isdir(self._local_location):
@@ -139,7 +139,7 @@ async def execute(self) -> None:
result.result_status = result.result_status.value

async def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
files = [self._create_file_transfer_client(m) for m in metas]
files = [await self._create_file_transfer_client(m) for m in metas]
is_upload = self._command_type == CMD_TYPE_UPLOAD
finish_download_upload_tasks = []

@@ -258,7 +258,12 @@ def postprocess_done_cb(

self._results = metas

def _create_file_transfer_client(
async def _transfer_accelerate_config(self) -> None:
if self._stage_location_type == S3_FS and self._file_metadata:
client = await self._create_file_transfer_client(self._file_metadata[0])
self._use_accelerate_endpoint = client.transfer_accelerate_config()

async def _create_file_transfer_client(
self, meta: SnowflakeFileMeta
) -> SnowflakeStorageClient:
if self._stage_location_type == LOCAL_FS:
@@ -276,21 +281,30 @@ def _create_file_transfer_client(
use_s3_regional_url=self._use_s3_regional_url,
)
elif self._stage_location_type == S3_FS:
return SnowflakeS3RestClient(
client = SnowflakeS3RestClient(
meta=meta,
credentials=self._credentials,
stage_info=self._stage_info,
chunk_size=_chunk_size_calculator(meta.src_file_size),
use_accelerate_endpoint=self._use_accelerate_endpoint,
use_s3_regional_url=self._use_s3_regional_url,
)
return client
elif self._stage_location_type == GCS_FS:
return SnowflakeGCSRestClient(
client = SnowflakeGCSRestClient(
meta,
self._credentials,
self._stage_info,
self._cursor._connection,
self._command,
use_s3_regional_url=self._use_s3_regional_url,
)
if client.security_token:
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")
else:
logger.debug(
"No access token received from GS, requesting presigned url"
)
await client._update_presigned_url()
return client
raise Exception(f"{self._stage_location_type} is an unknown stage type")
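
A minimal asyncio sketch of the pattern the now-async factory follows for GCS (all names here are stand-ins, not connector APIs): build the storage client, and if no access token came back, await the presigned-URL fetch before handing the client to the caller.

# Hypothetical stand-in for the async creation flow; FakeGCSClient is not a real class.
import asyncio


class FakeGCSClient:
    def __init__(self, security_token: str | None) -> None:
        self.security_token = security_token
        self.presigned_url: str | None = None

    async def _update_presigned_url(self) -> None:
        # Placeholder for the real round trip that fetches a presigned URL.
        self.presigned_url = "https://storage.googleapis.com/bucket/object?X-Goog-Signature=example"


async def create_client(security_token: str | None) -> FakeGCSClient:
    client = FakeGCSClient(security_token)
    if client.security_token:
        print(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")
    else:
        # No access token received, so fall back to a presigned URL.
        await client._update_presigned_url()
    return client


client = asyncio.run(create_client(None))
assert client.presigned_url is not None
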