Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
minor_changes:
- aws_ssm - Refactor S3 operations methods for improved clarity (https://github.com/ansible-collections/community.aws/pull/2275).
243 changes: 53 additions & 190 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,12 @@
import time
from functools import wraps
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypedDict

try:
import boto3
from botocore.client import Config
except ImportError:
pass

from ansible.errors import AnsibleConnectionFailure
from ansible.errors import AnsibleError
from ansible.errors import AnsibleFileNotFound
Expand All @@ -359,7 +351,7 @@
from ansible.utils.display import Display

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.plugin_utils.base import AwsConnectionPluginBase
from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager
from ansible_collections.community.aws.plugins.plugin_utils.terminalmanager import TerminalManager

Expand Down Expand Up @@ -455,7 +447,7 @@ class CommandResult(TypedDict):
stderr_combined: str


class Connection(ConnectionBase):
class Connection(ConnectionBase, AwsConnectionPluginBase):
"""AWS SSM based connections"""

transport = "community.aws.aws_ssm"
Expand All @@ -467,7 +459,6 @@ class Connection(ConnectionBase):
is_windows = False

_client = None
_s3_client = None
_session = None
_stdout = None
_session_id = ""
Expand Down Expand Up @@ -514,40 +505,36 @@ def _init_clients(self) -> None:
"""

self.verbosity_display(4, "INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option("profile") or ""
region_name = self.get_option("region")

# Initialize S3ClientManager
self.s3_manager = S3ClientManager(self)

# Initialize S3 client
s3_endpoint_url, s3_region_name = self.s3_manager.get_bucket_endpoint()
self.verbosity_display(4, f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
self.s3_manager.initialize_client(
region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name
# Create S3 and SSM clients
config = {"signature_version": "s3v4", "s3": {"addressing_style": self.get_option("s3_addressing_style")}}

bucket_endpoint_url = self.get_option("bucket_endpoint_url")
s3_endpoint_url, s3_region_name = S3ClientManager.get_bucket_endpoint(
bucket_name=self.get_option("bucket_name"),
bucket_endpoint_url=bucket_endpoint_url,
access_key_id=self.get_option("access_key_id"),
secret_key_id=self.get_option("secret_access_key"),
session_token=self.get_option("session_token"),
region_name=self.get_option("region"),
profile_name=self.get_option("profile"),
)
self._s3_client = self.s3_manager._s3_client

# Initialize SSM client
self._initialize_ssm_client(region_name, profile_name)
self.verbosity_display(4, f"BUCKET Information - Endpoint: {s3_endpoint_url} / Region: {s3_region_name}")

def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None:
"""
Initializes the SSM client used to manage sessions.
Args:
region_name (Optional[str]): AWS region for the SSM client.
profile_name (str): AWS profile name for authentication.
# Initialize S3ClientManager
s3_client = self._get_boto_client("s3", endpoint_url=s3_endpoint_url, region_name=s3_region_name, config=config)
self.s3_manager = S3ClientManager(s3_client)

Returns:
None
"""
# Initialize SSM client
self._client = self._get_boto_client("ssm", region_name=self.get_option("region"), config=config)

self.verbosity_display(4, "SETUP BOTO3 CLIENTS: SSM")
self._client = self._get_boto_client(
"ssm",
region_name=region_name,
profile_name=profile_name,
)
@property
def s3_client(self) -> None:
client = None
if self.s3_manager is not None:
client = self.s3_manager.client
return client

def verbosity_display(self, level: int, message: str) -> None:
"""
Expand Down Expand Up @@ -616,6 +603,7 @@ def start_session(self):
if document_name is not None:
start_session_args["DocumentName"] = document_name
response = self._client.start_session(**start_session_args)
self.verbosity_display(4, f"START SESSION RESPONSE: {response}")
self._session_id = response["SessionId"]

region_name = self.get_option("region")
Expand Down Expand Up @@ -803,153 +791,26 @@ def _flush_stderr(self, session_process) -> str:

return stderr

def _get_boto_client(self, service, region_name=None, profile_name=None, endpoint_url=None):
"""Gets a boto3 client based on the STS token"""

aws_access_key_id = self.get_option("access_key_id")
aws_secret_access_key = self.get_option("secret_access_key")
aws_session_token = self.get_option("session_token")

session_args = dict(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region_name,
)
if profile_name:
session_args["profile_name"] = profile_name
session = boto3.session.Session(**session_args)

client = session.client(
service,
endpoint_url=endpoint_url,
config=Config(
signature_version="s3v4",
s3={"addressing_style": self.get_option("s3_addressing_style")},
),
)
return client

def _escape_path(self, path: str) -> str:
return path.replace("\\", "/")

def _generate_commands(
self,
bucket_name: str,
s3_path: str,
in_path: str,
out_path: str,
) -> Tuple[List[Dict], dict]:
"""
Generate commands for the specified bucket, S3 path, input path, and output path.

:param bucket_name: The name of the S3 bucket used for file transfers.
:param s3_path: The S3 path to the file to be sent.
:param in_path: Input path
:param out_path: Output path
:param method: The request method to use for the command (can be "get" or "put").

:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
"""

put_args, put_headers = self.s3_manager.generate_encryption_settings()
commands = []

put_url = self.s3_manager.get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
get_url = self.s3_manager.get_url("get_object", bucket_name, s3_path, "GET")

if self.is_windows:
put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()])
commands.append({
"command":
(
"Invoke-WebRequest "
f"'{get_url}' "
f"-OutFile '{out_path}'"
),
# The "method" key indicates to _file_transport_command which commands are get_commands
"method": "get",
"headers": {},
}) # fmt: skip
commands.append({
"command":
(
"Invoke-WebRequest -Method PUT "
# @{'key' = 'value'; 'key2' = 'value2'}
f"-Headers @{{{put_command_headers}}} "
f"-InFile '{in_path}' "
f"-Uri '{put_url}' "
f"-UseBasicParsing"
),
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers,
}) # fmt: skip
else:
put_command_headers = " ".join([f"-H '{h}: {v}'" for h, v in put_headers.items()])
commands.append({
"command":
(
"curl "
f"-o '{out_path}' "
f"'{get_url}'"
),
# The "method" key indicates to _file_transport_command which commands are get_commands
"method": "get",
"headers": {},
}) # fmt: skip
# Due to https://github.com/curl/curl/issues/183 earlier
# versions of curl did not create the output file, when the
# response was empty. Although this issue was fixed in 2015,
# some actively maintained operating systems still use older
# versions of it (e.g. CentOS 7)
commands.append({
"command":
(
"touch "
f"'{out_path}'"
),
"method": "get",
"headers": {},
}) # fmt: skip
commands.append({
"command":
(
"curl --request PUT "
f"{put_command_headers} "
f"--upload-file '{in_path}' "
f"'{put_url}'"
),
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers,
}) # fmt: skip

return commands, put_args

def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[dict]) -> CommandResult:
def _exec_transport_commands(self, in_path: str, out_path: str, command: dict) -> CommandResult:
"""
Execute the provided transport commands.
Execute the provided transport command.

:param in_path: The input path.
:param out_path: The output path.
:param commands: A list of command dictionaries containing the command string and metadata.
:param command: A command to execute on the host.

:returns: A tuple containing the return code, stdout, and stderr.
"""

stdout_combined, stderr_combined = "", ""
for command in commands:
(returncode, stdout, stderr) = self.exec_command(command["command"], in_data=None, sudoable=False)

# Check the return code
if returncode != 0:
raise AnsibleError(f"failed to transfer file to {in_path} {out_path}:\n{stdout}\n{stderr}")
returncode, stdout, stderr = self.exec_command(command, in_data=None, sudoable=False)
# Check the return code
if returncode != 0:
raise AnsibleError(f"failed to transfer file to {in_path} {out_path}:\n{stdout}\n{stderr}")

stdout_combined += stdout
stderr_combined += stderr

return (returncode, stdout_combined, stderr_combined)
return returncode, stdout, stderr

@_ssm_retry
def _file_transport_command(
Expand All @@ -971,30 +832,30 @@ def _file_transport_command(
bucket_name = self.get_option("bucket_name")
s3_path = self._escape_path(f"{self.instance_id}/{out_path}")

client = self._s3_client

commands, put_args = self._generate_commands(
command, put_args = self.s3_manager.generate_host_commands(
bucket_name,
self.get_option("bucket_sse_mode"),
self.get_option("bucket_sse_kms_key_id"),
s3_path,
in_path,
out_path,
self.is_windows,
ssm_action,
)

try:
if ssm_action == "get":
put_commands = [cmd for cmd in commands if cmd.get("method") == "put"]
result = self._exec_transport_commands(in_path, out_path, put_commands)
result = self._exec_transport_commands(in_path, out_path, command)
with open(to_bytes(out_path, errors="surrogate_or_strict"), "wb") as data:
client.download_fileobj(bucket_name, s3_path, data)
self.s3_client.download_fileobj(bucket_name, s3_path, data)
else:
get_commands = [cmd for cmd in commands if cmd.get("method") == "get"]
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as data:
client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
result = self._exec_transport_commands(in_path, out_path, get_commands)
self.s3_client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
result = self._exec_transport_commands(in_path, out_path, command)
return result
finally:
# Remove the files from the bucket after they've been transferred
client.delete_object(Bucket=bucket_name, Key=s3_path)
self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path)

def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
"""transfer a file from local to remote"""
Expand All @@ -1019,12 +880,14 @@ def close(self) -> None:
"""terminate the connection"""
if self._session_id:
self.verbosity_display(3, f"CLOSING SSM CONNECTION TO: {self.instance_id}")
if self._has_timeout:
self._session.terminate()
else:
cmd = b"\nexit\n"
self._session.communicate(cmd)
if self._session is not None:
if self._has_timeout:
self._session.terminate()
else:
cmd = b"\nexit\n"
self._session.communicate(cmd)

self.verbosity_display(4, f"TERMINATE SSM SESSION: {self._session_id}")
self._client.terminate_session(SessionId=self._session_id)
if self._client:
self._client.terminate_session(SessionId=self._session_id)
self._session_id = ""
50 changes: 50 additions & 0 deletions plugins/plugin_utils/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-

# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from typing import Any
from typing import Dict
from typing import Optional

try:
from boto3.session import Session
from botocore.client import Config
except ImportError:
pass


class AwsConnectionPluginBase:
def __init__(self) -> None:
pass

def _get_boto_client(
self,
service: str,
region_name: Optional[str] = None,
endpoint_url: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
) -> Any:
"""Gets a boto3 client based on the STS token"""

aws_access_key_id = self.get_option("access_key_id")
aws_secret_access_key = self.get_option("secret_access_key")
aws_session_token = self.get_option("session_token")

session_args = dict(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region_name,
)
profile_name = self.get_option("profile")
if profile_name:
session_args["profile_name"] = profile_name
session = Session(**session_args)
params = {}
if endpoint_url:
params["endpoint_url"] = endpoint_url
if config:
params["config"] = Config(**config)

return session.client(service, **params)
Loading
Loading