Skip to content

Commit be6ec2f

Browse files
committed
modified based on feedback
1 parent e882fd2 commit be6ec2f

File tree

3 files changed

+65
-30
lines changed

3 files changed

+65
-30
lines changed

plugins/connection/aws_ssm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
484484
self._instance_id = None
485485
self._polling_obj = None
486486
self._has_timeout = False
487-
self.terminal_manager = TerminalManager(self)
487+
self.terminal_manager = TerminalManager(
488+
session=self._session,
489+
stdout=self._stdout,
490+
poller=self._polling_obj or select.poll(),
491+
verbosity_display=lambda level, msg: self.verbosity_display(level, msg),
492+
)
488493

489494
if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
490495
self.delegate = None

plugins/plugin_utils/terminalmanager.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,27 @@
66
import random
77
import re
88
import string
9+
from typing import Any
10+
from typing import Callable
911

1012
from ansible.module_utils._text import to_bytes
1113
from ansible.module_utils._text import to_text
1214
from ansible.plugins.shell.powershell import _common_args
1315

1416

1517
class TerminalManager:
16-
def __init__(self, connection):
17-
self.connection = connection
18+
MARK_LENGTH = 26
1819

19-
def prepare_terminal(self) -> None:
20+
def __init__(self, session: Any, stdout: Any, poller: Callable, verbosity_display: Callable) -> None:
21+
self._session = session
22+
self._stdout = stdout
23+
self._poller = poller
24+
self.verbosity_display = verbosity_display
25+
26+
def prepare_terminal(self, is_windows: bool) -> None:
2027
"""perform any one-time terminal settings"""
2128
# No Windows setup for now
22-
if self.connection.is_windows:
29+
if is_windows:
2330
return
2431

2532
# Ensure SSM Session has started
@@ -31,14 +38,13 @@ def prepare_terminal(self) -> None:
3138
# Disable prompt command
3239
self.disable_prompt_command() # pylint: disable=unreachable
3340

34-
self.connection.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable
41+
self.verbosity_display(4, "PRE Terminal configured") # pylint: disable=unreachable
3542

36-
def wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
43+
def wrap_command(self, cmd: str, mark_start: str, mark_end: str, is_windows: bool, shell: Any) -> str:
3744
"""Wrap command so stdout and status can be extracted"""
38-
39-
if self.connection.is_windows:
45+
if is_windows:
4046
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
41-
cmd = self.connection._shell._encode_script(cmd, preserve_rc=True)
47+
cmd = shell._encode_script(cmd, preserve_rc=True)
4248
cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n"
4349
else:
4450
cmd = (
@@ -47,44 +53,44 @@ def wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
4753
f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
4854
) # fmt: skip
4955

50-
self.connection.verbosity_display(4, f"wrap_command: \n'{to_text(cmd)}'")
56+
self.verbosity_display(4, f"wrap_command: \n'{to_text(cmd)}'")
5157
return cmd
5258

5359
def disable_echo_command(self) -> None:
5460
"""Disable echo command from the host"""
5561
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")
5662

5763
# Send command
58-
self.connection.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
59-
self.connection._session.stdin.write(disable_echo_cmd)
64+
self.verbosity_display(4, f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
65+
self._session.stdin.write(disable_echo_cmd)
6066

6167
stdout = ""
62-
for poll_result in self.connection.poll("DISABLE ECHO", disable_echo_cmd):
68+
for poll_result in self._poller("DISABLE ECHO", disable_echo_cmd):
6369
if poll_result:
64-
stdout += to_text(self.connection._stdout.read(1024))
65-
self.connection.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
70+
stdout += to_text(self._stdout.read(1024))
71+
self.verbosity_display(4, f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
6672
match = str(stdout).find("stty -echo")
6773
if match != -1:
6874
break
6975

7076
def disable_prompt_command(self) -> None:
7177
"""Disable prompt command from the host"""
72-
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.connection.MARK_LENGTH)])
78+
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
7379
disable_prompt_cmd = to_bytes(
7480
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
7581
errors="surrogate_or_strict",
7682
)
7783
disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)
7884

7985
# Send command
80-
self.connection.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
81-
self.connection._session.stdin.write(disable_prompt_cmd)
86+
self.verbosity_display(4, f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
87+
self._session.stdin.write(disable_prompt_cmd)
8288

8389
stdout = ""
84-
for poll_result in self.connection.poll("DISABLE PROMPT", disable_prompt_cmd):
90+
for poll_result in self._poller("DISABLE PROMPT", disable_prompt_cmd):
8591
if poll_result:
86-
stdout += to_text(self.connection._stdout.read(1024))
87-
self.connection.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
92+
stdout += to_text(self._stdout.read(1024))
93+
self.verbosity_display(4, f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
8894
if disable_prompt_reply.search(stdout):
8995
break
9096

@@ -93,11 +99,11 @@ def ensure_ssm_session_has_started(self) -> None:
9399
until we match the following string 'Starting session with SessionId'
94100
"""
95101
stdout = ""
96-
for poll_result in self.connection.poll("START SSM SESSION", "start_session"):
102+
for poll_result in self._poller("START SSM SESSION", "start_session"):
97103
if poll_result:
98-
stdout += to_text(self.connection._stdout.read(1024))
99-
self.connection.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
104+
stdout += to_text(self._stdout.read(1024))
105+
self.verbosity_display(4, f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
100106
match = str(stdout).find("Starting session with SessionId")
101107
if match != -1:
102-
self.connection.verbosity_display(4, "START SSM SESSION startup output received")
108+
self.verbosity_display(4, "START SSM SESSION startup output received")
103109
break

tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,26 @@ def poll_mock(x, y):
4545
def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
4646
m_to_text.side_effect = str
4747
m_to_bytes.side_effect = str
48-
connection_aws_ssm._stdout.read = MagicMock()
48+
connection_aws_ssm.get_option = MagicMock(return_value=10)
49+
50+
mock_session = MagicMock()
51+
mock_session.poll.return_value = None
52+
connection_aws_ssm._session = mock_session
4953

54+
connection_aws_ssm._stdout = MagicMock()
5055
connection_aws_ssm._stdout.read.side_effect = stdout_lines
5156

57+
poll_mock.results = [True for i in range(len(stdout_lines))]
58+
connection_aws_ssm.poll = MagicMock()
59+
connection_aws_ssm.poll.side_effect = poll_mock
60+
5261
if not hasattr(connection_aws_ssm, "terminal_manager"):
53-
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)
62+
connection_aws_ssm.terminal_manager = TerminalManager(
63+
session=connection_aws_ssm._session,
64+
stdout=connection_aws_ssm._stdout,
65+
poller=connection_aws_ssm.poll,
66+
verbosity_display=lambda lvl, msg: None,
67+
)
5468

5569
poll_mock.results = [True for i in range(len(stdout_lines))]
5670
connection_aws_ssm.poll = MagicMock()
@@ -86,7 +100,12 @@ def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_
86100
connection_aws_ssm.poll.side_effect = poll_mock
87101

88102
if not hasattr(connection_aws_ssm, "terminal_manager"):
89-
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)
103+
connection_aws_ssm.terminal_manager = TerminalManager(
104+
session=connection_aws_ssm._session,
105+
stdout=connection_aws_ssm._stdout,
106+
poller=connection_aws_ssm.poll,
107+
verbosity_display=lambda lvl, msg: None,
108+
)
90109

91110
if timeout_failure:
92111
with pytest.raises(TimeoutError):
@@ -110,7 +129,12 @@ def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_
110129
connection_aws_ssm.poll.side_effect = poll_mock
111130

112131
if not hasattr(connection_aws_ssm, "terminal_manager"):
113-
connection_aws_ssm.terminal_manager = TerminalManager(connection_aws_ssm)
132+
connection_aws_ssm.terminal_manager = TerminalManager(
133+
session=connection_aws_ssm._session,
134+
stdout=connection_aws_ssm._stdout,
135+
poller=connection_aws_ssm.poll,
136+
verbosity_display=lambda lvl, msg: None,
137+
)
114138

115139
m_random.choice = MagicMock()
116140
m_random.choice.side_effect = lambda x: "a"

0 commit comments

Comments
 (0)