Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
if platform.system() == "Windows":
logger.warning("Windows Support for Local Mode is Experimental")

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this Local SageMaker Session.

Args:
Expand Down Expand Up @@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

self.config = yaml.load(open(sagemaker_config_file, "r"))

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
"""

Expand Down
34 changes: 27 additions & 7 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
bucket based on a naming convention which includes the current AWS account ID.
"""

def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
def __init__(
self,
boto_session=None,
sagemaker_client=None,
sagemaker_runtime_client=None,
default_bucket=None,
):
"""Initialize a SageMaker ``Session``.

Args:
Expand All @@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
default_bucket (str): The default s3 bucket to be used by this session.
Ex: "sagemaker-us-west-2"

"""
self._default_bucket = None

# currently is used for local_code in local mode
self.config = None

self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
self._initialize(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
default_bucket=default_bucket,
)

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this SageMaker Session.

Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Expand All @@ -126,6 +140,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

prepend_user_agent(self.sagemaker_runtime_client)

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

self.local_mode = False

@property
Expand Down Expand Up @@ -314,11 +331,14 @@ def default_bucket(self):
if self._default_bucket:
return self._default_bucket

default_bucket = self._desired_default_bucket_name
region = self.boto_session.region_name
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)

if not default_bucket:
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd move line 334 to be right before the if statement, because that makes it clearer that it's basically just an if/else there. Or just make it an if/else:

if self._desired_default_bucket_name:
    default_bucket = self._desired_default_bucket_name
else:
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the line closer, but I like to avoid if/else blocks when possible.
I think it's more readable to have a default case get overridden under specific circumstances. It also guarantees that the variable always gets set.


s3 = self.boto_session.resource("s3")
try:
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
def __init__(self):
super(LocalSession, self).__init__()

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
if self.config is None:
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
Expand All @@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True

self._default_bucket = None
self._desired_default_bucket_name = default_bucket


@pytest.fixture(scope="module")
def mxnet_model(sagemaker_local_session, mxnet_full_version):
Expand Down
184 changes: 184 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import os

import boto3
import pytest
from botocore.config import Config
from sagemaker import Session
from sagemaker.fw_registry import default_framework_uri

from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
Expand All @@ -23,6 +26,35 @@
from tests.integ.kms_utils import get_or_create_kms_key

ROLE = "SageMakerRole"
DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"


@pytest.fixture(scope="module")
def sagemaker_session_with_custom_bucket(
    boto_config, sagemaker_client_config, sagemaker_runtime_config
):
    """Return a ``Session`` configured with ``CUSTOM_BUCKET_PATH`` as its default bucket.

    Args:
        boto_config: Keyword arguments for ``boto3.Session``. When falsy, a session
            pinned to ``DEFAULT_REGION`` is created instead.
        sagemaker_client_config: Keyword arguments for the ``sagemaker`` boto client.
            A retry ``Config`` is added under ``"config"`` if none is present.
        sagemaker_runtime_config: Keyword arguments for the ``sagemaker-runtime``
            boto client. When falsy, no runtime client is passed to ``Session``.

    Returns:
        Session: A session built with ``default_bucket=CUSTOM_BUCKET_PATH``.
    """
    boto_session = (
        boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
    )

    # Work on a copy so the injected fixture value is not mutated in place
    # (NOTE(review): it is presumably shared with other tests - confirm in conftest).
    client_config = dict(sagemaker_client_config or {})
    client_config.setdefault("config", Config(retries=dict(max_attempts=10)))

    # After setdefault the config is never empty, so the client is always created;
    # the former ``if sagemaker_client_config else None`` branch was unreachable.
    sagemaker_client = boto_session.client("sagemaker", **client_config)

    runtime_client = (
        boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
        if sagemaker_runtime_config
        else None
    )

    return Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=CUSTOM_BUCKET_PATH,
    )


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -170,6 +202,89 @@ def test_sklearn_with_customizations(
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_custom_default_bucket(
    sagemaker_session_with_custom_bucket,
    image_uri,
    sklearn_full_version,
    cpu_instance_type,
    output_kms_key,
):
    """Run an SKLearn processing job on a session with a custom default bucket.

    Checks that the uploaded input and code land under ``CUSTOM_BUCKET_PATH``
    and that the rest of the job description matches the processor settings.
    """
    input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")

    sklearn_processor = SKLearnProcessor(
        framework_version=sklearn_full_version,
        role=ROLE,
        command=["python3"],
        instance_type=cpu_instance_type,
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        base_job_name="test-sklearn-with-customizations",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    sklearn_processor.run(
        code=os.path.join(DATA_DIR, "dummy_script.py"),
        inputs=[
            ProcessingInput(
                source=input_file_path,
                destination="/opt/ml/processing/input/container/path/",
                input_name="dummy_input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="None",
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = sklearn_processor.latest_job.describe()

    # Both the user-supplied input and the code input must use the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingInputs"][1]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    # Compare against the fixture value the job was launched with rather than a
    # hard-coded instance type, so the assertion holds whatever the fixture yields.
    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {
            "InstanceCount": 1,
            "InstanceType": cpu_instance_type,
            "VolumeSizeInGB": 100,
        }
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_no_inputs_or_outputs(
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
):
Expand Down Expand Up @@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k
assert ROLE in job_description["RoleArn"]

assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_processor_with_custom_bucket(
    sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key
):
    """Run a generic ``Processor`` job on a session with a custom default bucket.

    Checks that the uploaded code input lands under ``CUSTOM_BUCKET_PATH`` and
    that the rest of the job description matches the processor settings.
    """
    script_path = os.path.join(DATA_DIR, "dummy_script.py")

    processor = Processor(
        role=ROLE,
        image_uri=image_uri,
        instance_count=1,
        instance_type=cpu_instance_type,
        entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"],
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        base_job_name="test-processor",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    processor.run(
        inputs=[
            ProcessingInput(
                source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = processor.latest_job.describe()

    # The locally sourced script must have been uploaded to the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-processor")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    # Compare against the fixture value the job was launched with rather than a
    # hard-coded instance type, so the assertion holds whatever the fixture yields.
    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {
            "InstanceCount": 1,
            "InstanceType": cpu_instance_type,
            "VolumeSizeInGB": 100,
        }
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
50 changes: 50 additions & 0 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import boto3
from botocore.config import Config

from sagemaker import Session

DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"


def test_sagemaker_session_does_not_create_bucket_on_init(
    sagemaker_client_config, sagemaker_runtime_config, boto_config
):
    """Constructing a ``Session`` with ``default_bucket`` must not create that bucket.

    Builds a ``Session`` with ``default_bucket=CUSTOM_BUCKET_NAME`` and then
    asserts that S3 reports no creation date for a bucket of that name.
    """
    boto_session = (
        boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
    )

    # Work on a copy so the injected fixture value is not mutated in place
    # (NOTE(review): it is presumably shared with other tests - confirm in conftest).
    client_config = dict(sagemaker_client_config or {})
    client_config.setdefault("config", Config(retries=dict(max_attempts=10)))

    # After setdefault the config is never empty, so the client is always created;
    # the former ``if sagemaker_client_config else None`` branch was unreachable.
    sagemaker_client = boto_session.client("sagemaker", **client_config)

    runtime_client = (
        boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
        if sagemaker_runtime_config
        else None
    )

    Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=CUSTOM_BUCKET_NAME,
    )

    # Query S3 through the same boto session handed to ``Session`` so the check
    # uses the same region and credentials as the code under test, not the
    # ambient default boto3 session.
    s3 = boto_session.resource("s3")
    assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None