Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/sagemaker_training/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,10 @@ def install(name, path=environment.code_dir, capture_error=False):

entry_point_type = _entry_point_type.get(path, name)

if (
entry_point_type is _entry_point_type.PYTHON_PACKAGE
or entry_point_type is _entry_point_type.PYTHON_PROGRAM
or modules.has_requirements(path)
):
modules.prepare(path, name)
if entry_point_type is _entry_point_type.PYTHON_PACKAGE:
modules.install(path, capture_error)
elif entry_point_type is _entry_point_type.PYTHON_PROGRAM and modules.has_requirements(path):
modules.install_requirements(path, capture_error)

if entry_point_type is _entry_point_type.COMMAND:
os.chmod(os.path.join(path, name), 511)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker_training/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class InstallModuleError(_CalledProcessError):
"""Error class indicating a module failed to install."""


class InstallRequirementsError(_CalledProcessError):
"""Error class indicating a module failed to install."""


class ImportModuleError(ClientError):
"""Error class indicating a module failed to import."""

Expand Down
17 changes: 17 additions & 0 deletions src/sagemaker_training/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,23 @@ def install(path, capture_error=False): # type: (str, bool) -> None
)


def install_requirements(path, capture_error=False): # type: (str, bool) -> None
"""Install dependencies from requirements.txt the executing Python environment.

Args:
path (str): Real path location of the requirements.txt file.
capture_error (bool): Default false. If True, the running process captures the
stderr, and appends it to the returned Exception message in case of errors.
"""
cmd = "%s -m pip install -r requirements.txt" % process.python_executable()

logger.info("Installing dependencies from requirements.txt:\n%s", cmd)

process.check_error(
shlex.split(cmd), errors.InstallRequirementsError, cwd=path, capture_error=capture_error
)


def import_module(uri, name=DEFAULT_MODULE_NAME): # type: (str, str) -> module
"""Download, prepare and install a compressed tar file from S3 or provided directory as a
module.
Expand Down
4 changes: 2 additions & 2 deletions test/integration/local/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ def test_install_requirements(capsys):

stdout = capsys.readouterr().out

assert "Installing collected packages: pyfiglet, train.py" in stdout
assert "Successfully installed pyfiglet-0.8.post1 train.py-1.0.0" in stdout
assert "Installing collected packages: pyfiglet" in stdout
assert "Successfully installed pyfiglet-0.8.post1" in stdout
assert "Reporting training SUCCESS" in stdout