Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dbt/adapters/clickhouse/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def s3source_clause(
) -> str:
s3config = self.config.vars.vars.get(config_name, {})
s3config.update(s3_model_config)

structure = structure or s3config.get('structure', '')
struct = ''
if structure:
Expand All @@ -284,7 +285,9 @@ def s3source_clause(
struct = f", '{','.join(structure)}'"
else:
struct = f",'{structure}'"

fmt = fmt or s3config.get('fmt')

bucket = bucket or s3config.get('bucket', '')
path = path or s3config.get('path', '')
url = bucket.replace('https://', '')
Expand All @@ -293,19 +296,26 @@ def s3source_clause(
path = f'/{path}'
url = f'{url}{path}'.replace('//', '/')
url = f'https://{url}'

access = ''
aws_access_key_id = aws_access_key_id or s3config.get('aws_access_key_id')
aws_secret_access_key = aws_secret_access_key or s3config.get('aws_secret_access_key')
if aws_access_key_id and not aws_secret_access_key:
raise DbtRuntimeError('S3 aws_access_key_id specified without aws_secret_access_key')
if aws_secret_access_key and not aws_access_key_id:
raise DbtRuntimeError('S3 aws_secret_access_key specified without aws_access_key_id')
if aws_access_key_id:
access = f", '{aws_access_key_id}', '{aws_secret_access_key}'"

comp = compression or s3config.get('compression', '')
if comp:
comp = f"', {comp}'"

extra_credentials = ''
role_arn = role_arn or s3config.get('role_arn')
if role_arn:
extra_credentials = f", extra_credentials(role_arn='{role_arn}')"

return f"s3('{url}'{access}, '{fmt}'{struct}{comp}{extra_credentials})"

def check_schema_exists(self, database, schema):
Expand Down
63 changes: 63 additions & 0 deletions tests/integration/adapter/clickhouse/test_clickhouse_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,66 @@ def test_read(self, project):
run_dbt(["run", "--select", "s3_taxis_source.sql"])
result = project.run_sql("select count() as num_rows from s3_taxis_source", fetch="one")
assert result[0] == 1000


class TestS3AwsAccessGlobal:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
'vars': {
'taxi_s3': {
'bucket': 'https://datasets-documentation.s3.eu-west-3.amazonaws.com/nyc-taxi/',
'fmt': 'TabSeparatedWithNames',
'aws_access_key_id': 'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
'aws_secret_access_key': '1234567890123456789012345678901234567890',
}
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"s3_taxis_source.sql": s3_taxis_full_source,
"schema.yml": schema_yaml,
}

def test_role_arn_in_compiled_sql(self, project):
# Only compile, don't run
result = run_dbt(["compile", "--select", "s3_taxis_source.sql"], expect_pass=True)

# Assert the SQL contains the expected role_arn function call
assert (
", 'ABCDEFGHIJKLMNOPQRSTUVWXYZ', '1234567890123456789012345678901234567890'"
in result.results[0].node.compiled_code
)


class TestS3RoleArnGlobal:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
'vars': {
'taxi_s3': {
'bucket': 'https://datasets-documentation.s3.eu-west-3.amazonaws.com/nyc-taxi/',
'fmt': 'TabSeparatedWithNames',
'role_arn': 'arn:aws:iam::123456789012:role/my-role',
}
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"s3_taxis_source.sql": s3_taxis_full_source,
"schema.yml": schema_yaml,
}

def test_role_arn_in_compiled_sql(self, project):
# Only compile, don't run
result = run_dbt(["compile", "--select", "s3_taxis_source.sql"], expect_pass=True)

# Assert the SQL contains the expected role_arn function call
assert (
"extra_credentials(role_arn='arn:aws:iam::123456789012:role/my-role')"
in result.results[0].node.compiled_code
)
Loading