Skip to content

Commit da66f6f

Browse files
authored
fix: run Python script entry point as script and install from requirements.txt (#64)
1 parent f02ed4b commit da66f6f

File tree

5 files changed

+40
-8
lines changed

5 files changed

+40
-8
lines changed

src/sagemaker_training/entry_point.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,10 @@ def install(name, path=environment.code_dir, capture_error=False):
117117

118118
entry_point_type = _entry_point_type.get(path, name)
119119

120-
if (
121-
entry_point_type is _entry_point_type.PYTHON_PACKAGE
122-
or entry_point_type is _entry_point_type.PYTHON_PROGRAM
123-
or modules.has_requirements(path)
124-
):
125-
modules.prepare(path, name)
120+
if entry_point_type is _entry_point_type.PYTHON_PACKAGE:
126121
modules.install(path, capture_error)
122+
elif entry_point_type is _entry_point_type.PYTHON_PROGRAM and modules.has_requirements(path):
123+
modules.install_requirements(path, capture_error)
127124

128125
if entry_point_type is _entry_point_type.COMMAND:
129126
os.chmod(os.path.join(path, name), 511)

src/sagemaker_training/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ class InstallModuleError(_CalledProcessError):
5252
"""Error class indicating a module failed to install."""
5353

5454

55+
class InstallRequirementsError(_CalledProcessError):
56+
"""Error class indicating a module failed to install."""
57+
58+
5559
class ImportModuleError(ClientError):
5660
"""Error class indicating a module failed to import."""
5761

src/sagemaker_training/modules.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,23 @@ def install(path, capture_error=False): # type: (str, bool) -> None
129129
)
130130

131131

132+
def install_requirements(path, capture_error=False): # type: (str, bool) -> None
133+
"""Install dependencies from requirements.txt in the executing Python environment.
134+
135+
Args:
136+
path (str): Real path location of the requirements.txt file.
137+
capture_error (bool): Default false. If True, the running process captures the
138+
stderr, and appends it to the returned Exception message in case of errors.
139+
"""
140+
cmd = "{} -m pip install -r requirements.txt".format(process.python_executable())
141+
142+
logger.info("Installing dependencies from requirements.txt:\n{}".format(cmd))
143+
144+
process.check_error(
145+
shlex.split(cmd), errors.InstallRequirementsError, cwd=path, capture_error=capture_error
146+
)
147+
148+
132149
def import_module(uri, name=DEFAULT_MODULE_NAME): # type: (str, str) -> module
133150
"""Download, prepare and install a compressed tar file from S3 or provided directory as a
134151
module.

test/integration/local/test_dummy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,6 @@ def test_install_requirements(capsys):
4747

4848
stdout = capsys.readouterr().out
4949

50-
assert "Installing collected packages: pyfiglet, train.py" in stdout
51-
assert "Successfully installed pyfiglet-0.8.post1 train.py-1.0.0" in stdout
50+
assert "Installing collected packages: pyfiglet" in stdout
51+
assert "Successfully installed pyfiglet-0.8.post1" in stdout
5252
assert "Reporting training SUCCESS" in stdout

test/unit/test_modules.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,20 @@ def test_install(check_error):
7676
)
7777

7878

79+
@patch("sagemaker_training.process.check_error", autospec=True)
80+
def test_install_requirements(check_error):
81+
path = "c://sagemaker-pytorch-container"
82+
83+
cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
84+
85+
with patch("os.path.exists", return_value=True):
86+
modules.install_requirements(path)
87+
88+
check_error.assert_called_with(
89+
cmd, errors.InstallRequirementsError, cwd=path, capture_error=False
90+
)
91+
92+
7993
@patch("sagemaker_training.process.check_error", autospec=True)
8094
def test_install_fails(check_error):
8195
check_error.side_effect = errors.ClientError()

0 commit comments

Comments
 (0)