Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor connection/aws_ssm to add new TerminalManager class and move relevant methods to the new class (https://github.com/ansible-collections/community.aws/pull/2270).
93 changes: 4 additions & 89 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.process import get_bin_path
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager
from ansible_collections.community.aws.plugins.plugin_utils.terminalmanager import TerminalManager

display = Display()

Expand Down Expand Up @@ -484,6 +484,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._instance_id = None
self._polling_obj = None
self._has_timeout = False
self.terminal_manager = TerminalManager(self)

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
Expand Down Expand Up @@ -645,7 +646,7 @@ def start_session(self):
self._stdout = os.fdopen(stdout_r, "rb", 0)

# For non-windows Hosts: Ensure the session has started, and disable command echo and prompt.
self._prepare_terminal()
self.terminal_manager.prepare_terminal()

self.verbosity_display(4, f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable

Expand Down Expand Up @@ -743,7 +744,7 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->
mark_end = self.generate_mark()

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, mark_start, mark_end)
cmd = self.terminal_manager.wrap_command(cmd, mark_start, mark_end)

self._flush_stderr(self._session)

Expand All @@ -752,92 +753,6 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _ensure_ssm_session_has_started(self) -> None:
"""Ensure the SSM session has started on the host. We poll stdout
until we match the following string 'Starting session with SessionId'
"""
stdout = ""
for poll_result in self.poll("START SSM SESSION", "start_session"):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self.verbosity_display(4, "START SSM SESSION startup output received")
break

def _disable_prompt_command(self) -> None:
"""Disable prompt command from the host"""
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
)
disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)

# Send command
self.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)

stdout = ""
for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
if disable_prompt_reply.search(stdout):
break

def _disable_echo_command(self) -> None:
"""Disable echo command from the host"""
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

# Send command
self.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)

stdout = ""
for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("stty -echo")
if match != -1:
break

def _prepare_terminal(self) -> None:
"""perform any one-time terminal settings"""
# No Windows setup for now
if self.is_windows:
return

# Ensure SSM Session has started
self._ensure_ssm_session_has_started()

# Disable echo command
self._disable_echo_command() # pylint: disable=unreachable

# Disable prompt command
self._disable_prompt_command() # pylint: disable=unreachable

self.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable

def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""

if self.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
cmd = self._shell._encode_script(cmd, preserve_rc=True)
cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n"
else:
cmd = (
f"printf '%s\\n' '{mark_start}';\n"
f"echo | {cmd};\n"
f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
) # fmt: skip

self.verbosity_display(4, f"_wrap_command: \n'{to_text(cmd)}'")
return cmd

def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
"""extract command status and strip unwanted lines"""

Expand Down
103 changes: 103 additions & 0 deletions plugins/plugin_utils/terminalmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import random
import re
import string

from ansible.module_utils._text import to_bytes
from ansible.module_utils._text import to_text
from ansible.plugins.shell.powershell import _common_args


class TerminalManager:
def __init__(self, connection):
self.connection = connection

def prepare_terminal(self) -> None:
"""perform any one-time terminal settings"""
# No Windows setup for now
if self.connection.is_windows:
return

# Ensure SSM Session has started
self.ensure_ssm_session_has_started()

# Disable echo command
self.disable_echo_command() # pylint: disable=unreachable

# Disable prompt command
self.disable_prompt_command() # pylint: disable=unreachable

self.connection.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable

def wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""

if self.connection.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
cmd = self.connection._shell._encode_script(cmd, preserve_rc=True)
cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n"
else:
cmd = (
f"printf '%s\\n' '{mark_start}';\n"
f"echo | {cmd};\n"
f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
) # fmt: skip

self.connection.verbosity_display(4, f"wrap_command: \n'{to_text(cmd)}'")
return cmd

def disable_echo_command(self) -> None:
"""Disable echo command from the host"""
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

# Send command
self.connection.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
self.connection._session.stdin.write(disable_echo_cmd)

stdout = ""
for poll_result in self.connection.poll("DISABLE ECHO", disable_echo_cmd):
if poll_result:
stdout += to_text(self.connection._stdout.read(1024))
self.connection.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("stty -echo")
if match != -1:
break

def disable_prompt_command(self) -> None:
"""Disable prompt command from the host"""
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.connection.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
)
disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)

# Send command
self.connection.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
self.connection._session.stdin.write(disable_prompt_cmd)

stdout = ""
for poll_result in self.connection.poll("DISABLE PROMPT", disable_prompt_cmd):
if poll_result:
stdout += to_text(self.connection._stdout.read(1024))
self.connection.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
if disable_prompt_reply.search(stdout):
break

def ensure_ssm_session_has_started(self) -> None:
"""Ensure the SSM session has started on the host. We poll stdout
until we match the following string 'Starting session with SessionId'
"""
stdout = ""
for poll_result in self.connection.poll("START SSM SESSION", "start_session"):
if poll_result:
stdout += to_text(self.connection._stdout.read(1024))
self.connection.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self.connection.verbosity_display(4, "START SSM SESSION startup output received")
break
2 changes: 2 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/test_exec_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def test_connection_aws_ssm_exec_command(m_chunks, connection_aws_ssm, is_window
cmd = MagicMock()
in_data = MagicMock()
sudoable = MagicMock()
connection_aws_ssm.terminal_manager = MagicMock()

assert result == connection_aws_ssm.exec_command(cmd, in_data, sudoable)
# m_chunks.assert_called_once_with(chunk, 1024)
connection_aws_ssm._flush_stderr.assert_called_once_with(connection_aws_ssm._session)
33 changes: 22 additions & 11 deletions tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.connection.aws_ssm import TerminalManager

if not HAS_BOTO3:
pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'")

Expand Down Expand Up @@ -47,15 +49,18 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss

connection_aws_ssm._stdout.read.side_effect = stdout_lines

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._ensure_ssm_session_has_started()
connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started()
else:
connection_aws_ssm._ensure_ssm_session_has_started()
connection_aws_ssm.terminal_manager.ensure_ssm_session_has_started()


@pytest.mark.parametrize(
Expand All @@ -67,8 +72,8 @@ def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ss
(["stty ", "-ech"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text")
def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
Expand All @@ -80,19 +85,22 @@ def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_echo_command()
connection_aws_ssm.terminal_manager.disable_echo_command()
else:
connection_aws_ssm._disable_echo_command()
connection_aws_ssm.terminal_manager.disable_echo_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n")


@pytest.mark.parametrize("timeout_failure", [True, False])
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.random")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_bytes")
@patch("ansible_collections.community.aws.plugins.plugin_utils.terminalmanager.to_text")
def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
Expand All @@ -101,6 +109,9 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if not hasattr(connection_aws_ssm, "terminal_manager"):
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)

m_random.choice = MagicMock()
m_random.choice.side_effect = lambda x: "a"

Expand All @@ -115,8 +126,8 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_prompt_command()
connection_aws_ssm.terminal_manager.disable_prompt_command()
else:
connection_aws_ssm._disable_prompt_command()
connection_aws_ssm.terminal_manager.disable_prompt_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)
Loading