diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e5058541fc..d835b65667 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods bucket based on a naming convention which includes the current AWS account ID. """ - def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None): + def __init__( + self, + boto_session=None, + sagemaker_client=None, + sagemaker_runtime_client=None, + default_bucket=None, + ): """Initialize a SageMaker ``Session``. Args: @@ -91,13 +97,25 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client. If not provided, one will be created using this instance's ``boto_session``. + default_bucket (str): The default Amazon S3 bucket to be used by this session. + This will be created the next time an Amazon S3 bucket is needed (by calling + :func:`default_bucket`). + If not provided, a default bucket will be created based on the following format: + "sagemaker-{region}-{aws-account-id}". + Example: "sagemaker-my-custom-bucket". + """ self._default_bucket = None + self._default_bucket_name_override = default_bucket # currently is used for local_code in local mode self.config = None - self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client) + self._initialize( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime_client, + ) def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): """Initialize this SageMaker Session. @@ -315,10 +333,13 @@ def default_bucket(self): return self._default_bucket region = self.boto_session.region_name - account = self.boto_session.client( - "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) - ).get_caller_identity()["Account"] - default_bucket = "sagemaker-{}-{}".format(region, account) + + default_bucket = self._default_bucket_name_override + if not default_bucket: + account = self.boto_session.client( + "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) + ).get_caller_identity()["Account"] + default_bucket = "sagemaker-{}-{}".format(region, account) s3 = self.boto_session.resource("s3") try: diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index 364bc1d6d6..d686a309f5 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -14,7 +14,10 @@ import os +import boto3 import pytest +from botocore.config import Config +from sagemaker import Session from sagemaker.fw_registry import default_framework_uri from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor @@ -23,6 +26,35 @@ from tests.integ.kms_utils import get_or_create_kms_key ROLE = "SageMakerRole" +DEFAULT_REGION = "us-west-2" +CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket" + + +@pytest.fixture(scope="module") +def sagemaker_session_with_custom_bucket( + boto_config, sagemaker_client_config, sagemaker_runtime_config +): + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + default_bucket=CUSTOM_BUCKET_PATH, + ) @pytest.fixture(scope="module") @@ -170,6 +202,89 @@ def test_sklearn_with_customizations( assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} +def test_sklearn_with_custom_default_bucket( + sagemaker_session_with_custom_bucket, + image_uri, + sklearn_full_version, + cpu_instance_type, + output_kms_key, +): + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_full_version, + role=ROLE, + command=["python3"], + instance_type=cpu_instance_type, + instance_count=1, + volume_size_in_gb=100, + volume_kms_key=None, + output_kms_key=output_kms_key, + max_runtime_in_seconds=3600, + base_job_name="test-sklearn-with-customizations", + env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}, + tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}], + sagemaker_session=sagemaker_session_with_custom_bucket, + ) + + sklearn_processor.run( + code=os.path.join(DATA_DIR, "dummy_script.py"), + inputs=[ + ProcessingInput( + source=input_file_path, + destination="/opt/ml/processing/input/container/path/", + input_name="dummy_input", + s3_data_type="S3Prefix", + s3_input_mode="File", + s3_data_distribution_type="FullyReplicated", + s3_compression_type="None", + ) + ], + outputs=[ + ProcessingOutput( + source="/opt/ml/processing/output/container/path/", + output_name="dummy_output", + s3_upload_mode="EndOfJob", + ) + ], + arguments=["-v"], + wait=True, + logs=True, + ) + + job_description = sklearn_processor.latest_job.describe() + + assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"] + + assert job_description["ProcessingInputs"][1]["InputName"] == "code" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"] + + assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations") + + assert job_description["ProcessingJobStatus"] == "Completed" + + assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key + assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output" + + assert job_description["ProcessingResources"] == { + "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100} + } + + assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"] + assert job_description["AppSpecification"]["ContainerEntrypoint"] == [ + "python3", + "/opt/ml/processing/input/code/dummy_script.py", + ] + assert job_description["AppSpecification"]["ImageUri"] == image_uri + + assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + assert ROLE in job_description["RoleArn"] + + assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + def test_sklearn_with_no_inputs_or_outputs( sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type ): @@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k assert ROLE in job_description["RoleArn"] assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + +def test_processor_with_custom_bucket( + sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key +): + script_path = os.path.join(DATA_DIR, "dummy_script.py") + + processor = Processor( + role=ROLE, + image_uri=image_uri, + instance_count=1, + instance_type=cpu_instance_type, + entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"], + volume_size_in_gb=100, + volume_kms_key=None, + output_kms_key=output_kms_key, + max_runtime_in_seconds=3600, + base_job_name="test-processor", + env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}, + tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}], + sagemaker_session=sagemaker_session_with_custom_bucket, + ) + + processor.run( + inputs=[ + ProcessingInput( + source=script_path, destination="/opt/ml/processing/input/code/", input_name="code" + ) + ], + outputs=[ + ProcessingOutput( + source="/opt/ml/processing/output/container/path/", + output_name="dummy_output", + s3_upload_mode="EndOfJob", + ) + ], + arguments=["-v"], + wait=True, + logs=True, + ) + + job_description = processor.latest_job.describe() + + assert job_description["ProcessingInputs"][0]["InputName"] == "code" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"] + + assert job_description["ProcessingJobName"].startswith("test-processor") + + assert job_description["ProcessingJobStatus"] == "Completed" + + assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key + assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output" + + assert job_description["ProcessingResources"] == { + "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100} + } + + assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"] + assert job_description["AppSpecification"]["ContainerEntrypoint"] == [ + "python3", + "/opt/ml/processing/input/code/dummy_script.py", + ] + assert job_description["AppSpecification"]["ImageUri"] == image_uri + + assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + assert ROLE in job_description["RoleArn"] + + assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py new file mode 100644 index 0000000000..0340666858 --- /dev/null +++ b/tests/integ/test_session.py @@ -0,0 +1,50 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +from botocore.config import Config + +from sagemaker import Session + +DEFAULT_REGION = "us-west-2" +CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist" + + +def test_sagemaker_session_does_not_create_bucket_on_init( + sagemaker_client_config, sagemaker_runtime_config, boto_config +): + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + default_bucket=CUSTOM_BUCKET_NAME, + ) + + s3 = boto3.resource("s3") + assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None