Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,19 @@ import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strconv"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/eks"
ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
"github.com/aws/aws-sdk-go-v2/service/iam"
iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types"
stsrequest "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/sts"
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
"github.com/golang/mock/gomock"
. "github.com/onsi/gomega"
Expand All @@ -54,8 +53,8 @@ import (
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/iamauth/mock_iamauth"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/network"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/securitygroup"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
"sigs.k8s.io/cluster-api/util"
Expand All @@ -76,7 +75,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
ec2Mock *mocks.MockEC2API
eksMock *mock_eksiface.MockEKSAPI
iamMock *mock_iamauth.MockIAMAPI
stsMock *mock_stsiface.MockSTSAPI
stsMock *mock_stsiface.MockSTSClient
awsNodeMock *mock_services.MockAWSNodeInterface
iamAuthenticatorMock *mock_services.MockIAMAuthenticatorInterface
kubeProxyMock *mock_services.MockKubeProxyInterface
Expand All @@ -96,7 +95,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
ec2Mock = mocks.NewMockEC2API(mockCtrl)
eksMock = mock_eksiface.NewMockEKSAPI(mockCtrl)
iamMock = mock_iamauth.NewMockIAMAPI(mockCtrl)
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)

// Mocking these as well, since the actual implementation requires a remote client to an actual cluster
awsNodeMock = mock_services.NewMockAWSNodeInterface(mockCtrl)
Expand Down Expand Up @@ -854,7 +853,7 @@ func mockedEKSControlPlaneIAMRole(g *WithT, iamRec *mock_iamauth.MockIAMAPIMockR
}).After(getPolicyCall).Return(&iam.AttachRolePolicyOutput{}, nil)
}

func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSAPIMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSClientMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
describeClusterCall := eksRec.DescribeCluster(ctx, &eks.DescribeClusterInput{
Name: aws.String("test-cluster"),
}).Return(nil, &ekstypes.ResourceNotFoundException{
Expand Down Expand Up @@ -948,12 +947,14 @@ func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockE
})).Return(
clusterSgDesc, nil)

req, err := http.NewRequest(http.MethodGet, "foobar", http.NoBody)
g.Expect(err).To(BeNil())
stsRec.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}).Return(&stsrequest.Request{
HTTPRequest: req,
Operation: &stsrequest.Operation{},
}, &sts.GetCallerIdentityOutput{})
stsRec.PresignGetCallerIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(&signerv4.PresignedHTTPRequest{
URL: "https://example.com",
}, nil)
stsRec.GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&stsv2.GetCallerIdentityOutput{
Account: aws.String("123456789012"),
Arn: aws.String("arn:aws:iam::123456789012:user/test-user"),
UserId: aws.String("AIDACKCEVSQ6C2EXAMPLE"),
}, nil)

eksRec.TagResource(ctx, &eks.TagResourceInput{
ResourceArn: clusterARN,
Expand Down
26 changes: 12 additions & 14 deletions controlplane/rosa/controllers/rosacontrolplane_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ import (
"time"

stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
sts "github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/google/go-cmp/cmp"
idputils "github.com/openshift-online/ocm-common/pkg/idp/utils"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
Expand Down Expand Up @@ -62,6 +60,7 @@ import (
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/annotations"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/utils"
Expand Down Expand Up @@ -92,7 +91,7 @@ type ROSAControlPlaneReconciler struct {
WatchFilterValue string
WaitInfraPeriod time.Duration
Endpoints []scope.ServiceEndpoint
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
// Exposing the restClientConfig for integration test. No need to initialize.
restClientConfig *restclient.Config
Expand Down Expand Up @@ -221,7 +220,11 @@ func (r *ROSAControlPlaneReconciler) reconcileNormal(ctx context.Context, rosaSc
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
}

creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
Account: rosaScope.Identity.Account,
Arn: rosaScope.Identity.Arn,
UserId: rosaScope.Identity.UserId,
})
if err != nil {
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
}
Expand Down Expand Up @@ -354,7 +357,11 @@ func (r *ROSAControlPlaneReconciler) reconcileDelete(ctx context.Context, rosaSc
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
}

creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
Account: rosaScope.Identity.Account,
Arn: rosaScope.Identity.Arn,
UserId: rosaScope.Identity.UserId,
})
if err != nil {
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
}
Expand Down Expand Up @@ -1130,12 +1137,3 @@ func buildAPIEndpoint(cluster *cmv1.Cluster) (*clusterv1.APIEndpoint, error) {
Port: int32(port), //#nosec G109 G115
}, nil
}

// TODO: Remove this and update the aws-sdk lib to v2.
func convertStsV2(identity *sts.GetCallerIdentityOutput) *stsv2.GetCallerIdentityOutput {
return &stsv2.GetCallerIdentityOutput{
Account: identity.Account,
Arn: identity.Arn,
UserId: identity.UserId,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ import (
"testing"
"time"

stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go/aws"
sts "github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/golang/mock/gomock"
. "github.com/onsi/gomega"
v1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
Expand All @@ -48,7 +47,8 @@ import (
rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
Expand Down Expand Up @@ -292,10 +292,10 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
mockCtrl := gomock.NewController(t)
ctx := context.TODO()
ocmMock := mocks.NewMockOCMClient(mockCtrl)
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)

getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)
getCallerIdentityResult := &stsv2.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)

expect := func(m *mocks.MockOCMClientMockRecorder) {
m.ValidateHypershiftVersion(gomock.Any(), gomock.Any()).DoAndReturn(func(clusterId string, nodePoolID string) (bool, error) {
Expand Down Expand Up @@ -396,7 +396,9 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
Endpoints: []scope.ServiceEndpoint{},
Client: testEnv,
restClientConfig: cfg,
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
return stsMock
},
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
return ocmMock, nil
},
Expand Down
6 changes: 3 additions & 3 deletions exp/controllers/awsmachinepool_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import (
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
s3svc "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_s3iface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/userdata"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
Expand All @@ -71,7 +71,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
asgSvc *mock_services.MockASGInterface
reconSvc *mock_services.MockMachinePoolReconcileInterface
s3Mock *mock_s3iface.MockS3API
stsMock *mock_stsiface.MockSTSAPI
stsMock *mock_stsiface.MockSTSClient
recorder *record.FakeRecorder
awsMachinePool *expinfrav1.AWSMachinePool
secret *corev1.Secret
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
asgSvc = mock_services.NewMockASGInterface(mockCtrl)
reconSvc = mock_services.NewMockMachinePoolReconcileInterface(mockCtrl)
s3Mock = mock_s3iface.NewMockS3API(mockCtrl)
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)

// If the test hangs for 9 minutes, increase the value here to the number of events during a reconciliation loop
recorder = record.NewFakeRecorder(2)
Expand Down
4 changes: 2 additions & 2 deletions exp/controllers/rosamachinepool_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/blang/semver"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -35,6 +34,7 @@ import (
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
"sigs.k8s.io/cluster-api-provider-aws/v2/util/paused"
Expand All @@ -52,7 +52,7 @@ type ROSAMachinePoolReconciler struct {
Recorder record.EventRecorder
WatchFilterValue string
Endpoints []scope.ServiceEndpoint
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsservice.STSClient
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
}

Expand Down
20 changes: 12 additions & 8 deletions exp/controllers/rosamachinepool_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/golang/mock/gomock"
. "github.com/onsi/gomega"
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
Expand All @@ -26,7 +25,8 @@ import (
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
Expand Down Expand Up @@ -546,15 +546,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
ocmMock := mocks.NewMockOCMClient(mockCtrl)
test.expect(ocmMock.EXPECT())

stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)

r := ROSAMachinePoolReconciler{
Recorder: recorder,
WatchFilterValue: "",
Endpoints: []scope.ServiceEndpoint{},
Client: testEnv,
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
return stsMock
},
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
return ocmMock, nil
},
Expand Down Expand Up @@ -641,15 +643,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
}
expect(ocmMock.EXPECT())

stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)

r := ROSAMachinePoolReconciler{
Recorder: recorder,
WatchFilterValue: "",
Endpoints: []scope.ServiceEndpoint{},
Client: testEnv,
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
return stsMock
},
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
return ocmMock, nil
},
Expand Down
23 changes: 23 additions & 0 deletions pkg/cloud/endpointsv2/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyendpoints "github.com/aws/smithy-go/endpoints"

"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
Expand Down Expand Up @@ -303,3 +304,25 @@ func (s *SSMEndpointResolver) ResolveEndpoint(ctx context.Context, params ssm.En
params.Region = &endpoint.SigningRegion
return ssm.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
}

// STSEndpointResolver implements EndpointResolverV2 interface for STS.
type STSEndpointResolver struct {
*MultiServiceEndpointResolver
}

// ResolveEndpoint for STS.
func (s *STSEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (smithyendpoints.Endpoint, error) {
// If custom endpoint not found, return default endpoint for the service
log := logger.FromContext(ctx)
endpoint, ok := s.endpoints[sts.ServiceID]

if !ok {
log.Debug("Custom endpoint not found, using default endpoint")
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
}

log.Debug("Custom endpoint found, using custom endpoint", "endpoint", endpoint.URL)
params.Endpoint = &endpoint.URL
params.Region = &endpoint.SigningRegion
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
}
29 changes: 21 additions & 8 deletions pkg/cloud/scope/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/ssm"
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"k8s.io/apimachinery/pkg/runtime"

"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/endpointsv2"
awslogs "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/logs"
awsmetrics "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metrics"
awsmetricsv2 "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metricsv2"
stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/throttle"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/record"
Expand Down Expand Up @@ -270,13 +270,26 @@ func NewIAMClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logg
}

// NewSTSClient creates a new STS API client for a given session.
func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsiface.STSAPI {
stsClient := sts.New(session.Session(), aws.NewConfig().WithLogLevel(awslogs.GetAWSLogLevel(logger.GetLogger())).WithLogger(awslogs.NewWrapLogr(logger.GetLogger())))
stsClient.Handlers.Build.PushFrontNamed(getUserAgentHandler())
stsClient.Handlers.CompleteAttempt.PushFront(awsmetrics.CaptureRequestMetrics(scopeUser.ControllerName()))
stsClient.Handlers.Complete.PushBack(recordAWSPermissionsIssue(target))
func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsservice.STSClient {
cfg := session.SessionV2()
multiSvcEndpointResolver := endpointsv2.NewMultiServiceEndpointResolver()
stsEndpointResolver := &endpointsv2.STSEndpointResolver{
MultiServiceEndpointResolver: multiSvcEndpointResolver,
}

stsOpts := []func(*stsv2.Options){
func(o *stsv2.Options) {
o.Logger = logger.GetAWSLogger()
o.ClientLogMode = awslogs.GetAWSLogLevelV2(logger.GetLogger())
o.EndpointResolverV2 = stsEndpointResolver
},
stsv2.WithAPIOptions(
awsmetricsv2.WithMiddlewares(scopeUser.ControllerName(), target),
awsmetricsv2.WithCAPAUserAgentMiddleware(),
),
}

return stsClient
return stsservice.NewClientWrapper(stsv2.NewFromConfig(cfg, stsOpts...))
}

// NewSSMClient creates a new Secrets API client for a given session.
Expand Down
Loading