@@ -2,15 +2,12 @@ package auth_test
22
33import (
44 "context"
5- "net/url"
65 "regexp"
76 "strings"
87 "testing"
9- "time"
108
119 "github.com/aws/aws-sdk-go-v2/aws"
1210 "github.com/aws/aws-sdk-go-v2/feature/rds/auth"
13- "github.com/aws/aws-sdk-go-v2/internal/sdk"
1411)
1512
1613func TestBuildAuthToken (t * testing.T ) {
@@ -70,155 +67,14 @@ func TestBuildAuthToken(t *testing.T) {
7067 }
7168}
7269
73- type dbAuthTestCase struct {
74- endpoint string
75- region string
76- expires time.Duration
77- credsExpireIn time.Duration
78- expectedHost string
79- expectedQueryParams []string
80- expectedError string
81- }
82-
83- type tokenGenFunc func (ctx context.Context , endpoint , region string , creds aws.CredentialsProvider , optFns ... func (options * auth.BuildAuthTokenOptions )) (string , error )
84-
85- func TestGenerateDbConnectAuthToken (t * testing.T ) {
86- cases := map [string ]dbAuthTestCase {
87- "no region" : {
88- endpoint : "https://prod-instance.us-east-1.rds.amazonaws.com:3306" ,
89- expectedError : "no region" ,
90- },
91- "no endpoint" : {
92- region : "us-west-2" ,
93- expectedError : "port" ,
94- },
95- "endpoint with scheme" : {
96- endpoint : "https://prod-instance.us-east-1.rds.amazonaws.com:3306" ,
97- region : "us-east-1" ,
98- expectedHost : "prod-instance.us-east-1.rds.amazonaws.com:3306" ,
99- expectedQueryParams : []string {"Action=DbConnect" },
100- },
101- "endpoint without scheme" : {
102- endpoint : "prod-instance.us-east-1.rds.amazonaws.com:3306" ,
103- region : "us-east-1" ,
104- expectedHost : "prod-instance.us-east-1.rds.amazonaws.com:3306" ,
105- expectedQueryParams : []string {"Action=DbConnect" },
106- },
107- "endpoint without port" : {
108- endpoint : "prod-instance.us-east-1.rds.amazonaws.com" ,
109- region : "us-east-1" ,
110- expectedHost : "prod-instance.us-east-1.rds.amazonaws.com" ,
111- expectedQueryParams : []string {"Action=DbConnect" },
112- },
113- "endpoint with region and expires" : {
114- endpoint : "peccy.dsql.us-east-1.on.aws" ,
115- region : "us-east-1" ,
116- expires : time .Second * 450 ,
117- expectedHost : "peccy.dsql.us-east-1.on.aws" ,
118- expectedQueryParams : []string {
119- "Action=DbConnect" ,
120- "X-Amz-Algorithm=AWS4-HMAC-SHA256" ,
121- "X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request" ,
122- "X-Amz-Date=20240827T000000Z" ,
123- "X-Amz-Expires=450" },
124- },
125- "pick credential expires when less than expires" : {
126- endpoint : "peccy.dsql.us-east-1.on.aws" ,
127- region : "us-east-1" ,
128- credsExpireIn : time .Second * 100 ,
129- expires : time .Second * 450 ,
130- expectedHost : "peccy.dsql.us-east-1.on.aws" ,
131- expectedQueryParams : []string {
132- "Action=DbConnect" ,
133- "X-Amz-Algorithm=AWS4-HMAC-SHA256" ,
134- "X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request" ,
135- "X-Amz-Date=20240827T000000Z" ,
136- "X-Amz-Expires=100" },
137- },
138- }
139-
140- for _ , c := range cases {
141- creds := & staticCredentials {AccessKey : "akid" , SecretKey : "secret" , expiresIn : c .credsExpireIn }
142- defer withTempGlobalTime (time .Date (2024 , time .August , 27 , 0 , 0 , 0 , 0 , time .UTC ))()
143- optFns := func (options * auth.BuildAuthTokenOptions ) {}
144- if c .expires != 0 {
145- optFns = func (options * auth.BuildAuthTokenOptions ) {
146- options .ExpiresIn = c .expires
147- }
148- }
149- verifyTestCase (auth .GenerateDbConnectAuthToken , c , creds , optFns , t )
150-
151- // Update the test case to use Superuser variant
152- updated := []string {}
153- for _ , part := range c .expectedQueryParams {
154- if part == "Action=DbConnect" {
155- part = "Action=DbConnectAdmin"
156- }
157- updated = append (updated , part )
158- }
159- c .expectedQueryParams = updated
160-
161- verifyTestCase (auth .GenerateDBConnectSuperUserAuthToken , c , creds , optFns , t )
162- }
163- }
164-
165- func verifyTestCase (f tokenGenFunc , c dbAuthTestCase , creds aws.CredentialsProvider , optFns func (options * auth.BuildAuthTokenOptions ), t * testing.T ) {
166- token , err := f (context .Background (), c .endpoint , c .region , creds , optFns )
167- isErrorExpected := len (c .expectedError ) > 0
168- if err != nil && ! isErrorExpected {
169- t .Fatalf ("expect no err, got: %v" , err )
170- } else if err == nil && isErrorExpected {
171- t .Fatalf ("Expected error %v got none" , c .expectedError )
172- }
173- // adding a scheme so we can parse it back as a URL. This is because comparing
174- // just direct string comparison was failing since "Action=DbConnect" is a substring or
175- // "Action=DBConnectSuperuser"
176- parsed , err := url .Parse ("http://" + token )
177- if err != nil {
178- t .Fatalf ("Couldn't parse the token %v to URL after adding a scheme, got: %v" , token , err )
179- }
180- if parsed .Host != c .expectedHost {
181- t .Errorf ("expect host %v, got %v" , c .expectedHost , parsed .Host )
182- }
183-
184- q := parsed .Query ()
185- queryValuePair := map [string ]any {}
186- for k , v := range q {
187- pair := k + "=" + v [0 ]
188- queryValuePair [pair ] = struct {}{}
189- }
190-
191- for _ , part := range c .expectedQueryParams {
192- if _ , ok := queryValuePair [part ]; ! ok {
193- t .Errorf ("expect part %s to be present at token %s" , part , token )
194- }
195- }
196- if token != "" && c .expires == 0 {
197- if ! strings .Contains (token , "X-Amz-Expires=900" ) {
198- t .Errorf ("expect token to contain default X-Amz-Expires value of 900, got %v" , token )
199- }
200- }
201- }
202-
20370type staticCredentials struct {
20471 AccessKey , SecretKey , Session string
205- expiresIn time.Duration
20672}
20773
20874func (s * staticCredentials ) Retrieve (ctx context.Context ) (aws.Credentials , error ) {
209- c := aws.Credentials {
75+ return aws.Credentials {
21076 AccessKeyID : s .AccessKey ,
21177 SecretAccessKey : s .SecretKey ,
21278 SessionToken : s .Session ,
213- }
214- if s .expiresIn != 0 {
215- c .CanExpire = true
216- c .Expires = sdk .NowTime ().Add (s .expiresIn )
217- }
218- return c , nil
219- }
220-
221- func withTempGlobalTime (t time.Time ) func () {
222- sdk .NowTime = func () time.Time { return t }
223- return func () { sdk .NowTime = time .Now }
79+ }, nil
22480}
0 commit comments