Skip to content

Commit 2c8f697

Browse files
authored
breakfix: revert bad feature/rds/auth api release (#2925)
1 parent ffbc1df commit 2c8f697

File tree

5 files changed

+23
-219
lines changed

5 files changed

+23
-219
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "658fc1c5-0afd-443f-a803-916695e3583e",
3+
"type": "bugfix",
4+
"description": "**BREAKFIX**: Revert bad API release.",
5+
"modules": [
6+
"feature/rds/auth"
7+
]
8+
}

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
## Module Highlights
3737
* `github.com/aws/aws-sdk-go-v2/feature/rds/auth`: [v1.5.0](feature/rds/auth/CHANGELOG.md#v150-2024-12-032)
38-
* **Feature**: feat: Add Xanadu Auth Token Generator
38+
* No change notes available for this release.
3939
* `github.com/aws/aws-sdk-go-v2/service/athena`: [v1.49.0](service/athena/CHANGELOG.md#v1490-2024-12-032)
4040
* **Feature**: Add FEDERATED type to CreateDataCatalog. This creates Athena Data Catalog, AWS Lambda connector, and AWS Glue connection. Create/DeleteDataCatalog returns DataCatalog. Add Status, ConnectionType, and Error to DataCatalog and DataCatalogSummary. Add DeleteCatalogOnly to delete Athena Catalog only.
4141
* `github.com/aws/aws-sdk-go-v2/service/bedrock`: [v1.24.0](service/bedrock/CHANGELOG.md#v1240-2024-12-032)

feature/rds/auth/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# v1.5.0 (2024-12-03.2)
22

3-
* **Feature**: feat: Add Xanadu Auth Token Generator
3+
* No change notes available for this release.
44

55
# v1.4.25 (2024-12-02)
66

feature/rds/auth/connect.go

Lines changed: 11 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,26 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7-
"net/url"
87
"strconv"
98
"strings"
109
"time"
1110

1211
"github.com/aws/aws-sdk-go-v2/aws"
1312
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
14-
"github.com/aws/aws-sdk-go-v2/internal/sdk"
1513
)
1614

1715
const (
18-
rdsAuthTokenID = "rds-db"
19-
rdsClusterTokenID = "dsql"
20-
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
21-
userAction = "DbConnect"
22-
adminUserAction = "DbConnectAdmin"
16+
signingID = "rds-db"
17+
emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
2318
)
2419

2520
// BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken
26-
type BuildAuthTokenOptions struct {
27-
ExpiresIn time.Duration
28-
}
21+
type BuildAuthTokenOptions struct{}
2922

3023
// BuildAuthToken will return an authorization token used as the password for a DB
3124
// connection.
3225
//
33-
// * endpoint - Endpoint consists of the hostname and port needed to connect to the DB. <host>:<port>
26+
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
3427
// * region - Region is the location of where the DB is
3528
// * dbUser - User account within the database to sign in with
3629
// * creds - Credentials to be signed with
@@ -57,64 +50,12 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds
5750
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
5851
}
5952

60-
values := url.Values{
61-
"Action": []string{"connect"},
62-
"DBUser": []string{dbUser},
63-
}
64-
65-
return generateAuthToken(ctx, endpoint, region, values, rdsAuthTokenID, creds, optFns...)
66-
}
67-
68-
// GenerateDbConnectAuthToken will return an authorization token as the password for a
69-
// DB connection.
70-
//
71-
// This is the regular user variant, see [GenerateDBConnectSuperUserAuthToken] for the superuser variant
72-
//
73-
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
74-
// * region - Region is the location of where the DB is
75-
// * creds - Credentials to be signed with
76-
func GenerateDbConnectAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
77-
values := url.Values{
78-
"Action": []string{userAction},
79-
}
80-
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
81-
}
82-
83-
// GenerateDBConnectSuperUserAuthToken will return an authorization token as the password for a
84-
// DB connection.
85-
//
86-
// This is the superuser user variant, see [GenerateDBConnectSuperUserAuthToken] for the regular user variant
87-
//
88-
// * endpoint - Endpoint is the hostname and optional port to connect to the DB
89-
// * region - Region is the location of where the DB is
90-
// * creds - Credentials to be signed with
91-
func GenerateDBConnectSuperUserAuthToken(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
92-
values := url.Values{
93-
"Action": []string{adminUserAction},
94-
}
95-
return generateAuthToken(ctx, endpoint, region, values, rdsClusterTokenID, creds, optFns...)
96-
}
97-
98-
// All generate token functions are presigned URLs behind the scenes with the scheme stripped.
99-
// This function abstracts generating this for all use cases
100-
func generateAuthToken(ctx context.Context, endpoint, region string, values url.Values, signingID string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
101-
if len(region) == 0 {
102-
return "", fmt.Errorf("region is required")
103-
}
104-
if len(endpoint) == 0 {
105-
return "", fmt.Errorf("endpoint is required")
106-
}
107-
10853
o := BuildAuthTokenOptions{}
10954

11055
for _, fn := range optFns {
11156
fn(&o)
11257
}
11358

114-
if o.ExpiresIn == 0 {
115-
o.ExpiresIn = 15 * time.Minute
116-
}
117-
11859
if creds == nil {
11960
return "", fmt.Errorf("credetials provider must not ne nil")
12061
}
@@ -128,25 +69,24 @@ func generateAuthToken(ctx context.Context, endpoint, region string, values url.
12869
if err != nil {
12970
return "", err
13071
}
72+
values := req.URL.Query()
73+
values.Set("Action", "connect")
74+
values.Set("DBUser", dbUser)
13175
req.URL.RawQuery = values.Encode()
76+
13277
signer := v4.NewSigner()
13378

13479
credentials, err := creds.Retrieve(ctx)
13580
if err != nil {
13681
return "", err
13782
}
13883

139-
expires := o.ExpiresIn
140-
// if creds expire before expiresIn, set that as the expiration time
141-
if credentials.CanExpire && !credentials.Expires.IsZero() {
142-
credsExpireIn := credentials.Expires.Sub(sdk.NowTime())
143-
expires = min(o.ExpiresIn, credsExpireIn)
144-
}
84+
// Expire Time: 15 minute
14585
query := req.URL.Query()
146-
query.Set("X-Amz-Expires", strconv.Itoa(int(expires.Seconds())))
86+
query.Set("X-Amz-Expires", "900")
14787
req.URL.RawQuery = query.Encode()
14888

149-
signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, sdk.NowTime().UTC())
89+
signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC())
15090
if err != nil {
15191
return "", err
15292
}

feature/rds/auth/connect_test.go

Lines changed: 2 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@ package auth_test
22

33
import (
44
"context"
5-
"net/url"
65
"regexp"
76
"strings"
87
"testing"
9-
"time"
108

119
"github.com/aws/aws-sdk-go-v2/aws"
1210
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
13-
"github.com/aws/aws-sdk-go-v2/internal/sdk"
1411
)
1512

1613
func TestBuildAuthToken(t *testing.T) {
@@ -70,155 +67,14 @@ func TestBuildAuthToken(t *testing.T) {
7067
}
7168
}
7269

73-
type dbAuthTestCase struct {
74-
endpoint string
75-
region string
76-
expires time.Duration
77-
credsExpireIn time.Duration
78-
expectedHost string
79-
expectedQueryParams []string
80-
expectedError string
81-
}
82-
83-
type tokenGenFunc func(ctx context.Context, endpoint, region string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error)
84-
85-
func TestGenerateDbConnectAuthToken(t *testing.T) {
86-
cases := map[string]dbAuthTestCase{
87-
"no region": {
88-
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
89-
expectedError: "no region",
90-
},
91-
"no endpoint": {
92-
region: "us-west-2",
93-
expectedError: "port",
94-
},
95-
"endpoint with scheme": {
96-
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
97-
region: "us-east-1",
98-
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
99-
expectedQueryParams: []string{"Action=DbConnect"},
100-
},
101-
"endpoint without scheme": {
102-
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
103-
region: "us-east-1",
104-
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com:3306",
105-
expectedQueryParams: []string{"Action=DbConnect"},
106-
},
107-
"endpoint without port": {
108-
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
109-
region: "us-east-1",
110-
expectedHost: "prod-instance.us-east-1.rds.amazonaws.com",
111-
expectedQueryParams: []string{"Action=DbConnect"},
112-
},
113-
"endpoint with region and expires": {
114-
endpoint: "peccy.dsql.us-east-1.on.aws",
115-
region: "us-east-1",
116-
expires: time.Second * 450,
117-
expectedHost: "peccy.dsql.us-east-1.on.aws",
118-
expectedQueryParams: []string{
119-
"Action=DbConnect",
120-
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
121-
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
122-
"X-Amz-Date=20240827T000000Z",
123-
"X-Amz-Expires=450"},
124-
},
125-
"pick credential expires when less than expires": {
126-
endpoint: "peccy.dsql.us-east-1.on.aws",
127-
region: "us-east-1",
128-
credsExpireIn: time.Second * 100,
129-
expires: time.Second * 450,
130-
expectedHost: "peccy.dsql.us-east-1.on.aws",
131-
expectedQueryParams: []string{
132-
"Action=DbConnect",
133-
"X-Amz-Algorithm=AWS4-HMAC-SHA256",
134-
"X-Amz-Credential=akid/20240827/us-east-1/dsql/aws4_request",
135-
"X-Amz-Date=20240827T000000Z",
136-
"X-Amz-Expires=100"},
137-
},
138-
}
139-
140-
for _, c := range cases {
141-
creds := &staticCredentials{AccessKey: "akid", SecretKey: "secret", expiresIn: c.credsExpireIn}
142-
defer withTempGlobalTime(time.Date(2024, time.August, 27, 0, 0, 0, 0, time.UTC))()
143-
optFns := func(options *auth.BuildAuthTokenOptions) {}
144-
if c.expires != 0 {
145-
optFns = func(options *auth.BuildAuthTokenOptions) {
146-
options.ExpiresIn = c.expires
147-
}
148-
}
149-
verifyTestCase(auth.GenerateDbConnectAuthToken, c, creds, optFns, t)
150-
151-
// Update the test case to use Superuser variant
152-
updated := []string{}
153-
for _, part := range c.expectedQueryParams {
154-
if part == "Action=DbConnect" {
155-
part = "Action=DbConnectAdmin"
156-
}
157-
updated = append(updated, part)
158-
}
159-
c.expectedQueryParams = updated
160-
161-
verifyTestCase(auth.GenerateDBConnectSuperUserAuthToken, c, creds, optFns, t)
162-
}
163-
}
164-
165-
func verifyTestCase(f tokenGenFunc, c dbAuthTestCase, creds aws.CredentialsProvider, optFns func(options *auth.BuildAuthTokenOptions), t *testing.T) {
166-
token, err := f(context.Background(), c.endpoint, c.region, creds, optFns)
167-
isErrorExpected := len(c.expectedError) > 0
168-
if err != nil && !isErrorExpected {
169-
t.Fatalf("expect no err, got: %v", err)
170-
} else if err == nil && isErrorExpected {
171-
t.Fatalf("Expected error %v got none", c.expectedError)
172-
}
173-
// adding a scheme so we can parse it back as a URL. This is because comparing
174-
// just direct string comparison was failing since "Action=DbConnect" is a substring or
175-
// "Action=DBConnectSuperuser"
176-
parsed, err := url.Parse("http://" + token)
177-
if err != nil {
178-
t.Fatalf("Couldn't parse the token %v to URL after adding a scheme, got: %v", token, err)
179-
}
180-
if parsed.Host != c.expectedHost {
181-
t.Errorf("expect host %v, got %v", c.expectedHost, parsed.Host)
182-
}
183-
184-
q := parsed.Query()
185-
queryValuePair := map[string]any{}
186-
for k, v := range q {
187-
pair := k + "=" + v[0]
188-
queryValuePair[pair] = struct{}{}
189-
}
190-
191-
for _, part := range c.expectedQueryParams {
192-
if _, ok := queryValuePair[part]; !ok {
193-
t.Errorf("expect part %s to be present at token %s", part, token)
194-
}
195-
}
196-
if token != "" && c.expires == 0 {
197-
if !strings.Contains(token, "X-Amz-Expires=900") {
198-
t.Errorf("expect token to contain default X-Amz-Expires value of 900, got %v", token)
199-
}
200-
}
201-
}
202-
20370
type staticCredentials struct {
20471
AccessKey, SecretKey, Session string
205-
expiresIn time.Duration
20672
}
20773

20874
func (s *staticCredentials) Retrieve(ctx context.Context) (aws.Credentials, error) {
209-
c := aws.Credentials{
75+
return aws.Credentials{
21076
AccessKeyID: s.AccessKey,
21177
SecretAccessKey: s.SecretKey,
21278
SessionToken: s.Session,
213-
}
214-
if s.expiresIn != 0 {
215-
c.CanExpire = true
216-
c.Expires = sdk.NowTime().Add(s.expiresIn)
217-
}
218-
return c, nil
219-
}
220-
221-
func withTempGlobalTime(t time.Time) func() {
222-
sdk.NowTime = func() time.Time { return t }
223-
return func() { sdk.NowTime = time.Now }
79+
}, nil
22480
}

0 commit comments

Comments
 (0)