From f6fc4b92a39b744debb61bf8dfb56d55b3932bad Mon Sep 17 00:00:00 2001 From: Karim Nakad Date: Mon, 16 Dec 2019 12:53:02 -0800 Subject: [PATCH 1/3] feature: allow setting the default bucket in Session Default bucket not created on init. --- src/sagemaker/local/local_session.py | 5 +- src/sagemaker/session.py | 34 ++++- tests/integ/test_local_mode.py | 5 +- tests/integ/test_processing.py | 184 +++++++++++++++++++++++++++ tests/integ/test_session.py | 50 ++++++++ 5 files changed, 269 insertions(+), 9 deletions(-) create mode 100644 tests/integ/test_session.py diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index e96b068899..15de7833f0 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -379,7 +379,7 @@ def __init__(self, boto_session=None): if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): """Initialize this Local SageMaker Session. Args: @@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.config = yaml.load(open(sagemaker_config_file, "r")) + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): """ diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e5058541fc..58b5d549d4 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods bucket based on a naming convention which includes the current AWS account ID. """ - def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None): + def __init__( + self, + boto_session=None, + sagemaker_client=None, + sagemaker_runtime_client=None, + default_bucket=None, + ): """Initialize a SageMaker ``Session``. Args: @@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client. If not provided, one will be created using this instance's ``boto_session``. + default_bucket (str): The default s3 bucket to be used by this session. + Ex: "sagemaker-us-west-2" + """ self._default_bucket = None # currently is used for local_code in local mode self.config = None - self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client) + self._initialize( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime_client, + default_bucket=default_bucket, + ) - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): """Initialize this SageMaker Session. Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. @@ -126,6 +140,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): prepend_user_agent(self.sagemaker_runtime_client) + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + self.local_mode = False @property @@ -314,11 +331,14 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket + default_bucket = self._desired_default_bucket_name region = self.boto_session.region_name - account = self.boto_session.client( - "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) - ).get_caller_identity()["Account"] - default_bucket = "sagemaker-{}-{}".format(region, account) + + if not default_bucket: + account = self.boto_session.client( + "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) + ).get_caller_identity()["Account"] + default_bucket = "sagemaker-{}-{}".format(region, account) s3 = self.boto_session.resource("s3") try: diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index a73c9e1e0d..f076a79404 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession): def __init__(self): super(LocalSession, self).__init__() - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + self._default_bucket = None + self._desired_default_bucket_name = default_bucket + @pytest.fixture(scope="module") def mxnet_model(sagemaker_local_session, mxnet_full_version): diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index 364bc1d6d6..d686a309f5 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -14,7 +14,10 @@ import os +import boto3 import pytest +from botocore.config import Config +from sagemaker import Session from sagemaker.fw_registry import default_framework_uri from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor @@ -23,6 +26,35 @@ from tests.integ.kms_utils import get_or_create_kms_key ROLE = "SageMakerRole" +DEFAULT_REGION = "us-west-2" +CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket" + + +@pytest.fixture(scope="module") +def sagemaker_session_with_custom_bucket( + boto_config, sagemaker_client_config, sagemaker_runtime_config +): + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + default_bucket=CUSTOM_BUCKET_PATH, + ) @pytest.fixture(scope="module") @@ -170,6 +202,89 @@ def test_sklearn_with_customizations( assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} +def test_sklearn_with_custom_default_bucket( + sagemaker_session_with_custom_bucket, + image_uri, + sklearn_full_version, + cpu_instance_type, + output_kms_key, +): + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_full_version, + role=ROLE, + command=["python3"], + instance_type=cpu_instance_type, + instance_count=1, + volume_size_in_gb=100, + volume_kms_key=None, + output_kms_key=output_kms_key, + max_runtime_in_seconds=3600, + base_job_name="test-sklearn-with-customizations", + env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}, + tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}], + sagemaker_session=sagemaker_session_with_custom_bucket, + ) + + sklearn_processor.run( + code=os.path.join(DATA_DIR, "dummy_script.py"), + inputs=[ + ProcessingInput( + source=input_file_path, + destination="/opt/ml/processing/input/container/path/", + input_name="dummy_input", + s3_data_type="S3Prefix", + s3_input_mode="File", + s3_data_distribution_type="FullyReplicated", + s3_compression_type="None", + ) + ], + outputs=[ + ProcessingOutput( + source="/opt/ml/processing/output/container/path/", + output_name="dummy_output", + s3_upload_mode="EndOfJob", + ) + ], + arguments=["-v"], + wait=True, + logs=True, + ) + + job_description = sklearn_processor.latest_job.describe() + + assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"] + + assert job_description["ProcessingInputs"][1]["InputName"] == "code" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"] + + assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations") + + assert job_description["ProcessingJobStatus"] == "Completed" + + assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key + assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output" + + assert job_description["ProcessingResources"] == { + "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100} + } + + assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"] + assert job_description["AppSpecification"]["ContainerEntrypoint"] == [ + "python3", + "/opt/ml/processing/input/code/dummy_script.py", + ] + assert job_description["AppSpecification"]["ImageUri"] == image_uri + + assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + assert ROLE in job_description["RoleArn"] + + assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + def test_sklearn_with_no_inputs_or_outputs( sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type ): @@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k assert ROLE in job_description["RoleArn"] assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} + + +def test_processor_with_custom_bucket( + sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key +): + script_path = os.path.join(DATA_DIR, "dummy_script.py") + + processor = Processor( + role=ROLE, + image_uri=image_uri, + instance_count=1, + instance_type=cpu_instance_type, + entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"], + volume_size_in_gb=100, + volume_kms_key=None, + output_kms_key=output_kms_key, + max_runtime_in_seconds=3600, + base_job_name="test-processor", + env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}, + tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}], + sagemaker_session=sagemaker_session_with_custom_bucket, + ) + + processor.run( + inputs=[ + ProcessingInput( + source=script_path, destination="/opt/ml/processing/input/code/", input_name="code" + ) + ], + outputs=[ + ProcessingOutput( + source="/opt/ml/processing/output/container/path/", + output_name="dummy_output", + s3_upload_mode="EndOfJob", + ) + ], + arguments=["-v"], + wait=True, + logs=True, + ) + + job_description = processor.latest_job.describe() + + assert job_description["ProcessingInputs"][0]["InputName"] == "code" + assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"] + + assert job_description["ProcessingJobName"].startswith("test-processor") + + assert job_description["ProcessingJobStatus"] == "Completed" + + assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key + assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output" + + assert job_description["ProcessingResources"] == { + "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100} + } + + assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"] + assert job_description["AppSpecification"]["ContainerEntrypoint"] == [ + "python3", + "/opt/ml/processing/input/code/dummy_script.py", + ] + assert job_description["AppSpecification"]["ImageUri"] == image_uri + + assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + assert ROLE in job_description["RoleArn"] + + assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600} diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py new file mode 100644 index 0000000000..0340666858 --- /dev/null +++ b/tests/integ/test_session.py @@ -0,0 +1,50 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +from botocore.config import Config + +from sagemaker import Session + +DEFAULT_REGION = "us-west-2" +CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist" + + +def test_sagemaker_session_does_not_create_bucket_on_init( + sagemaker_client_config, sagemaker_runtime_config, boto_config +): + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + default_bucket=CUSTOM_BUCKET_NAME, + ) + + s3 = boto3.resource("s3") + assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None From ba825c478a084a0b20e1f292d8dcc8dd07f0c746 Mon Sep 17 00:00:00 2001 From: Karim Nakad Date: Mon, 16 Dec 2019 13:15:12 -0800 Subject: [PATCH 2/3] Fixing docstring, variable name, and line ordering. --- src/sagemaker/session.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 58b5d549d4..37e636ed0d 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -97,11 +97,16 @@ def __init__( ``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created using this ``Session`` use this client. If not provided, one will be created using this instance's ``boto_session``. - default_bucket (str): The default s3 bucket to be used by this session. - Ex: "sagemaker-us-west-2" + default_bucket (str): The default Amazon S3 bucket to be used by this session. + This will be created the next time an Amazon S3 bucket is needed (by calling + :func:`default_bucket`). + If not provided, a default bucket will be created based on the following format: + "sagemaker-{region}-{aws-account-id}". + Example: "sagemaker-my-custom-bucket". """ self._default_bucket = None + self._default_bucket_name_override = default_bucket # currently is used for local_code in local mode self.config = None @@ -140,9 +145,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, prepend_user_agent(self.sagemaker_runtime_client) - self._default_bucket = None - self._desired_default_bucket_name = default_bucket - self.local_mode = False @property @@ -331,9 +333,9 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket - default_bucket = self._desired_default_bucket_name region = self.boto_session.region_name + default_bucket = self._default_bucket_name_override if not default_bucket: account = self.boto_session.client( "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) From 31fc678e81722b12a89e7e238e85c56bf59b9772 Mon Sep 17 00:00:00 2001 From: Karim Nakad Date: Mon, 16 Dec 2019 13:20:56 -0800 Subject: [PATCH 3/3] Removing unnecessary code. --- src/sagemaker/local/local_session.py | 5 +---- src/sagemaker/session.py | 3 +-- tests/integ/test_local_mode.py | 5 +---- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 15de7833f0..e96b068899 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -379,7 +379,7 @@ def __init__(self, boto_session=None): if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): """Initialize this Local SageMaker Session. Args: @@ -413,9 +413,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.config = yaml.load(open(sagemaker_config_file, "r")) - self._default_bucket = None - self._desired_default_bucket_name = default_bucket - def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"): """ diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 37e636ed0d..d835b65667 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -115,10 +115,9 @@ def __init__( boto_session=boto_session, sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client, - default_bucket=default_bucket, ) - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): """Initialize this SageMaker Session. Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index f076a79404..a73c9e1e0d 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession): def __init__(self): super(LocalSession, self).__init__() - def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket): + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} @@ -53,9 +53,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True - self._default_bucket = None - self._desired_default_bucket_name = default_bucket - @pytest.fixture(scope="module") def mxnet_model(sagemaker_local_session, mxnet_full_version):