diff --git a/changelogs/fragments/2270-aws_ssm-refactor-create-terminalmanager-class.yml b/changelogs/fragments/2270-aws_ssm-refactor-create-terminalmanager-class.yml new file mode 100644 index 00000000000..89d756ec584 --- /dev/null +++ b/changelogs/fragments/2270-aws_ssm-refactor-create-terminalmanager-class.yml @@ -0,0 +1,3 @@ +--- +minor_changes: + - aws_ssm - Refactor connection/aws_ssm to add new TerminalManager class and move relevant methods to the new class (https://github.com/ansible-collections/community.aws/pull/2270). diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index f39e480d3f0..17ecaa4983b 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -356,12 +356,12 @@ from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.common.process import get_bin_path from ansible.plugins.connection import ConnectionBase -from ansible.plugins.shell.powershell import _common_args from ansible.utils.display import Display from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager +from ansible_collections.community.aws.plugins.plugin_utils.terminalmanager import TerminalManager display = Display() @@ -484,6 +484,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._instance_id = None self._polling_obj = None self._has_timeout = False + self.terminal_manager = TerminalManager(self) if getattr(self._shell, "SHELL_FAMILY", "") == "powershell": self.delegate = None @@ -645,7 +646,7 @@ def start_session(self): self._stdout = os.fdopen(stdout_r, "rb", 0) # For non-windows Hosts: Ensure the session has started, and disable command echo and prompt. - self._prepare_terminal() + self.terminal_manager.prepare_terminal() self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable @@ -743,7 +744,7 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> mark_end = self.generate_mark() # Wrap command in markers accordingly for the shell used - cmd = self._wrap_command(cmd, mark_start, mark_end) + cmd = self.terminal_manager.wrap_command(cmd, mark_start, mark_end) self._flush_stderr(self._session) @@ -752,92 +753,6 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> return self.exec_communicate(cmd, mark_start, mark_begin, mark_end) - def _ensure_ssm_session_has_started(self) -> None: - """Ensure the SSM session has started on the host. We poll stdout - until we match the following string 'Starting session with SessionId' - """ - stdout = "" - for poll_result in self.poll("START SSM SESSION", "start_session"): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}") - match = str(stdout).find("Starting session with SessionId") - if match != -1: - self.verbosity_display(4, "START SSM SESSION startup output received") - break - - def _disable_prompt_command(self) -> None: - """Disable prompt command from the host""" - end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) - disable_prompt_cmd = to_bytes( - "PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n", - errors="surrogate_or_strict", - ) - disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE) - - # Send command - self.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}") - self._session.stdin.write(disable_prompt_cmd) - - stdout = "" - for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}") - if disable_prompt_reply.search(stdout): - break - - def _disable_echo_command(self) -> None: - """Disable echo command from the host""" - disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") - - # Send command - self.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}") - self._session.stdin.write(disable_echo_cmd) - - stdout = "" - for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd): - if poll_result: - stdout += to_text(self._stdout.read(1024)) - self.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}") - match = str(stdout).find("stty -echo") - if match != -1: - break - - def _prepare_terminal(self) -> None: - """perform any one-time terminal settings""" - # No Windows setup for now - if self.is_windows: - return - - # Ensure SSM Session has started - self._ensure_ssm_session_has_started() - - # Disable echo command - self._disable_echo_command() # pylint: disable=unreachable - - # Disable prompt command - self._disable_prompt_command() # pylint: disable=unreachable - - self.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable - - def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: - """Wrap command so stdout and status can be extracted""" - - if self.is_windows: - if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"): - cmd = self._shell._encode_script(cmd, preserve_rc=True) - cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n" - else: - cmd = ( - f"printf '%s\\n' '{mark_start}';\n" - f"echo | {cmd};\n" - f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n" - ) # fmt: skip - - self.verbosity_display(4, f"_wrap_command: \n'{to_text(cmd)}'") - return cmd - def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]: """extract command status and strip unwanted lines""" diff --git a/plugins/plugin_utils/terminalmanager.py b/plugins/plugin_utils/terminalmanager.py new file mode 100644 index 00000000000..b9bde4813f5 --- /dev/null +++ b/plugins/plugin_utils/terminalmanager.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +# Copyright: Contributors to the Ansible project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import random +import re +import string + +from ansible.module_utils._text import to_bytes +from ansible.module_utils._text import to_text +from ansible.plugins.shell.powershell import _common_args + + +class TerminalManager: + def __init__(self, connection): + self.connection = connection + + def prepare_terminal(self) -> None: + """perform any one-time terminal settings""" + # No Windows setup for now + if self.connection.is_windows: + return + + # Ensure SSM Session has started + self.ensure_ssm_session_has_started() + + # Disable echo command + self.disable_echo_command() # pylint: disable=unreachable + + # Disable prompt command + self.disable_prompt_command() # pylint: disable=unreachable + + self.connection.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable + + def wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: + """Wrap command so stdout and status can be extracted""" + + if self.connection.is_windows: + if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"): + cmd = self.connection._shell._encode_script(cmd, preserve_rc=True) + cmd = f"{cmd}; echo {mark_start}\necho {mark_end}\n" + else: + cmd = ( + f"printf '%s\\n' '{mark_start}';\n" + f"echo | {cmd};\n" + f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n" + ) # fmt: skip + + self.connection.verbosity_display(4, f"wrap_command: \n'{to_text(cmd)}'") + return cmd + + def disable_echo_command(self) -> None: + """Disable echo command from the host""" + disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") + + # Send command + self.connection.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}") + self.connection._session.stdin.write(disable_echo_cmd) + + stdout = "" + for poll_result in self.connection.poll("DISABLE ECHO", disable_echo_cmd): + if poll_result: + stdout += to_text(self.connection._stdout.read(1024)) + self.connection.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}") + match = str(stdout).find("stty -echo") + if match != -1: + break + + def disable_prompt_command(self) -> None: + """Disable prompt command from the host""" + end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.connection.MARK_LENGTH)]) + disable_prompt_cmd = to_bytes( + "PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n", + errors="surrogate_or_strict", + ) + disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE) + + # Send command + self.connection.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}") + self.connection._session.stdin.write(disable_prompt_cmd) + + stdout = "" + for poll_result in self.connection.poll("DISABLE PROMPT", disable_prompt_cmd): + if poll_result: + stdout += to_text(self.connection._stdout.read(1024)) + self.connection.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}") + if disable_prompt_reply.search(stdout): + break + + def ensure_ssm_session_has_started(self) -> None: + """Ensure the SSM session has started on the host. We poll stdout + until we match the following string 'Starting session with SessionId' + """ + stdout = "" + for poll_result in self.connection.poll("START SSM SESSION", "start_session"): + if poll_result: + stdout += to_text(self.connection._stdout.read(1024)) + self.connection.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}") + match = str(stdout).find("Starting session with SessionId") + if match != -1: + self.connection.verbosity_display(4, "START SSM SESSION startup output received") + break diff --git a/tests/unit/plugins/connection/aws_ssm/test_exec_command.py b/tests/unit/plugins/connection/aws_ssm/test_exec_command.py index 0817cea1123..78bc71af122 100644 --- a/tests/unit/plugins/connection/aws_ssm/test_exec_command.py +++ b/tests/unit/plugins/connection/aws_ssm/test_exec_command.py @@ -141,6 +141,8 @@ def test_connection_aws_ssm_exec_command(m_chunks, connection_aws_ssm, is_window cmd = MagicMock() in_data = MagicMock() sudoable = MagicMock() + connection_aws_ssm.terminal_manager = MagicMock() + assert result == connection_aws_ssm.exec_command(cmd, in_data, sudoable) # m_chunks.assert_called_once_with(chunk, 1024) connection_aws_ssm._flush_stderr.assert_called_once_with(connection_aws_ssm._session) diff --git a/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py b/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py index fe6f2361402..9319d17b9fc 100644 --- a/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py +++ b/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py @@ -16,6 +16,8 @@ from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 +from ansible_collections.community.aws.plugins.connection.aws_ssm import TerminalManager + if not HAS_BOTO3: pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'") @@ -47,15 +49,18 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss connection_aws_ssm._stdout.read.side_effect = stdout_lines + if not hasattr(connection_aws_ssm, "terminal_manager"): + connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm) + poll_mock.results = [True for i in range(len(stdout_lines))] connection_aws_ssm.poll = MagicMock() connection_aws_ssm.poll.side_effect = poll_mock if timeout_failure: with pytest.raises(TimeoutError): - connection_aws_ssm._ensure_ssm_session_has_started() + connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started() else: - connection_aws_ssm._ensure_ssm_session_has_started() + connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started() @pytest.mark.parametrize( @@ -67,8 +72,8 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss (["stty ", "-ech"], True), ], ) -@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes") -@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text") +@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes") +@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text") def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure): m_to_text.side_effect = str m_to_bytes.side_effect = lambda x, **kw: str(x) @@ -80,19 +85,22 @@ def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_ connection_aws_ssm.poll = MagicMock() connection_aws_ssm.poll.side_effect = poll_mock + if not hasattr(connection_aws_ssm, "terminal_manager"): + connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm) + if timeout_failure: with pytest.raises(TimeoutError): - connection_aws_ssm._disable_echo_command() + connection_aws_ssm.terminal_manager.disable_echo_command() else: - connection_aws_ssm._disable_echo_command() + connection_aws_ssm.terminal_manager.disable_echo_command() connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n") @pytest.mark.parametrize("timeout_failure", [True, False]) -@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random") -@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes") -@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text") +@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.random") +@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes") +@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text") def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure): m_to_text.side_effect = str m_to_bytes.side_effect = lambda x, **kw: str(x) @@ -101,6 +109,9 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ connection_aws_ssm.poll = MagicMock() connection_aws_ssm.poll.side_effect = poll_mock + if not hasattr(connection_aws_ssm, "terminal_manager"): + connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm) + m_random.choice = MagicMock() m_random.choice.side_effect = lambda x: "a" @@ -115,8 +126,8 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ if timeout_failure: with pytest.raises(TimeoutError): - connection_aws_ssm._disable_prompt_command() + connection_aws_ssm.terminal_manager.disable_prompt_command() else: - connection_aws_ssm._disable_prompt_command() + connection_aws_ssm.terminal_manager.disable_prompt_command() connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)