1818import boto3
1919import pytest
2020from sagemaker import LocalSession , Session
21- from sagemaker .tensorflow import TensorFlow
2221
23- from test .integration import NO_P2_REGIONS , NO_P3_REGIONS
22+ from integration import image_utils
23+ from integration import NO_P2_REGIONS , NO_P3_REGIONS
24+
2425
2526logger = logging .getLogger (__name__ )
2627logging .getLogger ('boto' ).setLevel (logging .INFO )
2930logging .getLogger ('auth.py' ).setLevel (logging .INFO )
3031logging .getLogger ('connectionpool.py' ).setLevel (logging .INFO )
3132
32- SCRIPT_PATH = os .path .dirname (os .path .realpath (__file__ ))
33+ DIR_PATH = os .path .dirname (os .path .realpath (__file__ ))
3334
3435
3536def pytest_addoption (parser ):
36- parser .addoption ('--docker-base-name' , default = 'sagemaker-tensorflow-scriptmode' )
37+ parser .addoption ('--build-image' , '-B' , action = 'store_true' )
38+ parser .addoption ('--push-image' , '-P' , action = 'store_true' )
39+ parser .addoption ('--dockerfile-type' , '-T' , choices = ['dlc.cpu' , 'dlc.gpu' , 'tf' ],
40+ default = 'tf' )
41+ parser .addoption ('--dockerfile' , '-D' , default = None )
42+ parser .addoption ('--docker-base-name' , default = 'sagemaker-tensorflow-training' )
3743 parser .addoption ('--tag' , default = None )
3844 parser .addoption ('--region' , default = 'us-west-2' )
39- parser .addoption ('--framework-version' , default = TensorFlow . LATEST_VERSION )
45+ parser .addoption ('--framework-version' , default = '1.15.2' )
4046 parser .addoption ('--processor' , default = 'cpu' , choices = ['cpu' , 'gpu' , 'cpu,gpu' ])
4147 parser .addoption ('--py-version' , default = '3' , choices = ['2' , '3' , '2,3' ])
4248 parser .addoption ('--account-id' , default = '142577830533' )
@@ -48,6 +54,38 @@ def pytest_configure(config):
4854 os .environ ['TEST_PROCESSORS' ] = config .getoption ('--processor' )
4955
5056
57+ @pytest .fixture (scope = 'session' , name = 'dockerfile_type' )
58+ def fixture_dockerfile_type (request ):
59+ return request .config .getoption ('--dockerfile-type' )
60+
61+
62+ @pytest .fixture (scope = 'session' , name = 'dockerfile' )
63+ def fixture_dockerfile (request , dockerfile_type ):
64+ dockerfile = request .config .getoption ('--dockerfile' )
65+ return dockerfile if dockerfile else 'Dockerfile.{}' .format (dockerfile_type )
66+
67+
68+ @pytest .fixture (scope = 'session' , name = 'build_image' , autouse = True )
69+ def fixture_build_image (request , framework_version , dockerfile , image_uri , region ):
70+ build_image = request .config .getoption ('--build-image' )
71+ if build_image :
72+ return image_utils .build_image (framework_version = framework_version ,
73+ dockerfile = dockerfile ,
74+ image_uri = image_uri ,
75+ region = region ,
76+ cwd = os .path .join (DIR_PATH , '..' , '..' ))
77+
78+ return image_uri
79+
80+
81+ @pytest .fixture (scope = 'session' , name = 'push_image' , autouse = True )
82+ def fixture_push_image (request , image_uri , region , account_id ):
83+ push_image = request .config .getoption ('--push-image' )
84+ if push_image :
85+ return image_utils .push_image (image_uri , region , account_id )
86+ return None
87+
88+
5189@pytest .fixture (scope = 'session' )
5290def docker_base_name (request ):
5391 return request .config .getoption ('--docker-base-name' )
@@ -63,7 +101,7 @@ def framework_version(request):
63101 return request .config .getoption ('--framework-version' )
64102
65103
66- @pytest .fixture
104+ @pytest .fixture ( scope = 'session' )
67105def tag (request , framework_version , processor , py_version ):
68106 provided_tag = request .config .getoption ('--tag' )
69107 default_tag = '{}-{}-py{}' .format (framework_version , processor , py_version )
@@ -107,12 +145,20 @@ def skip_gpu_instance_restricted_regions(region, instance_type):
107145 pytest .skip ('Skipping GPU test in region {}' .format (region ))
108146
109147
110- @pytest .fixture
111- def docker_image (docker_base_name , tag ):
112- return '{}:{}' .format (docker_base_name , tag )
148+ @pytest .fixture (autouse = True )
149+ def skip_by_dockerfile_type (request , dockerfile_type ):
150+ is_generic = (dockerfile_type == 'tf' )
151+ if request .node .get_closest_marker ('skip_generic' ) and is_generic :
152+ pytest .skip ('Skipping because running generic image without mpi and horovod' )
113153
114154
115- @pytest .fixture
116- def ecr_image (account_id , docker_base_name , tag , region ):
117- return '{}.dkr.ecr.{}.amazonaws.com/{}:{}' .format (
118- account_id , region , docker_base_name , tag )
155+ @pytest .fixture (name = 'docker_registry' , scope = 'session' )
156+ def fixture_docker_registry (account_id , region ):
157+ return '{}.dkr.ecr.{}.amazonaws.com' .format (account_id , region ) if account_id else None
158+
159+
160+ @pytest .fixture (name = 'image_uri' , scope = 'session' )
161+ def fixture_image_uri (docker_registry , docker_base_name , tag ):
162+ if docker_registry :
163+ return '{}/{}:{}' .format (docker_registry , docker_base_name , tag )
164+ return '{}:{}' .format (docker_base_name , tag )
0 commit comments