Skip to content

Commit c52d97c

Browse files
authored
Upgrade to aws-sdk-go-v2, v1 is unsupported (#548)
* Upgrade to aws-sdk-go-v2, v1 is unsupported * Upgrade tests * Dependencies
1 parent bb1ddfb commit c52d97c

File tree

4 files changed

+188
-1562
lines changed

4 files changed

+188
-1562
lines changed

get_s3.go

Lines changed: 89 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ import (
1212
"strings"
1313
"time"
1414

15-
"github.com/aws/aws-sdk-go/aws"
16-
"github.com/aws/aws-sdk-go/aws/credentials"
17-
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
18-
"github.com/aws/aws-sdk-go/aws/ec2metadata"
19-
"github.com/aws/aws-sdk-go/aws/session"
20-
"github.com/aws/aws-sdk-go/service/s3"
15+
"github.com/aws/aws-sdk-go-v2/aws"
16+
"github.com/aws/aws-sdk-go-v2/config"
17+
"github.com/aws/aws-sdk-go-v2/credentials"
18+
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
19+
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
20+
"github.com/aws/aws-sdk-go-v2/service/s3"
21+
"github.com/hashicorp/aws-sdk-go-base/v2/endpoints"
2122
)
2223

2324
// S3Getter is a Getter implementation that will download a module from
@@ -54,24 +55,27 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
5455
}
5556

5657
// List the object(s) at the given prefix
57-
req := &s3.ListObjectsInput{
58+
req := &s3.ListObjectsV2Input{
5859
Bucket: aws.String(bucket),
5960
Prefix: aws.String(path),
6061
}
61-
resp, err := client.ListObjectsWithContext(ctx, req)
62-
if err != nil {
63-
return 0, err
64-
}
65-
66-
for _, o := range resp.Contents {
67-
// Use file mode on exact match.
68-
if *o.Key == path {
69-
return ClientModeFile, nil
62+
paginator := s3.NewListObjectsV2Paginator(client, req)
63+
for paginator.HasMorePages() {
64+
output, err := paginator.NextPage(ctx)
65+
if err != nil {
66+
return 0, err
7067
}
7168

72-
// Use dir mode if child keys are found.
73-
if strings.HasPrefix(*o.Key, path+"/") {
74-
return ClientModeDir, nil
69+
for _, o := range output.Contents {
70+
// Use file mode on exact match.
71+
if aws.ToString(o.Key) == path {
72+
return ClientModeFile, nil
73+
}
74+
75+
// Use dir mode if child keys are found.
76+
if strings.HasPrefix(aws.ToString(o.Key), path+"/") {
77+
return ClientModeDir, nil
78+
}
7579
}
7680
}
7781

@@ -119,28 +123,19 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
119123
}
120124

121125
// List files in path, keep listing until no more objects are found
122-
lastMarker := ""
123-
hasMore := true
124-
for hasMore {
125-
req := &s3.ListObjectsInput{
126-
Bucket: aws.String(bucket),
127-
Prefix: aws.String(path),
128-
}
129-
if lastMarker != "" {
130-
req.Marker = aws.String(lastMarker)
131-
}
132-
133-
resp, err := client.ListObjectsWithContext(ctx, req)
126+
req := &s3.ListObjectsV2Input{
127+
Bucket: aws.String(bucket),
128+
Prefix: aws.String(path),
129+
}
130+
paginator := s3.NewListObjectsV2Paginator(client, req)
131+
for paginator.HasMorePages() {
132+
output, err := paginator.NextPage(ctx)
134133
if err != nil {
135134
return err
136135
}
137136

138-
hasMore = aws.BoolValue(resp.IsTruncated)
139-
140-
// Get each object storing each file relative to the destination path
141-
for _, object := range resp.Contents {
142-
lastMarker = aws.StringValue(object.Key)
143-
objPath := aws.StringValue(object.Key)
137+
for _, object := range output.Contents {
138+
objPath := aws.ToString(object.Key)
144139

145140
// If the key ends with a backslash assume it is a directory and ignore
146141
if strings.HasSuffix(objPath, "/") {
@@ -185,7 +180,7 @@ func (g *S3Getter) GetFile(dst string, u *url.URL) error {
185180
return g.getObject(ctx, client, dst, bucket, path, version)
186181
}
187182

188-
func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error {
183+
func (g *S3Getter) getObject(ctx context.Context, client *s3.Client, dst, bucket, key, version string) error {
189184
req := &s3.GetObjectInput{
190185
Bucket: aws.String(bucket),
191186
Key: aws.String(key),
@@ -194,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
194189
req.VersionId = aws.String(version)
195190
}
196191

197-
resp, err := client.GetObjectWithContext(ctx, req)
192+
resp, err := client.GetObject(ctx, req)
198193
if err != nil {
199194
return err
200195
}
@@ -208,57 +203,62 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
208203

209204
if g.client != nil && g.client.ProgressListener != nil {
210205
fn := filepath.Base(key)
211-
body = g.client.ProgressListener.TrackProgress(fn, 0, *resp.ContentLength, resp.Body)
206+
body = g.client.ProgressListener.TrackProgress(fn, 0, aws.ToInt64(resp.ContentLength), resp.Body)
212207
}
213208
defer func() { _ = body.Close() }()
214209

215210
// There is no limit set for the size of an object from S3
216211
return copyReader(dst, body, 0666, g.client.umask(), 0)
217212
}
218213

219-
func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) (*aws.Config, error) {
220-
conf := &aws.Config{}
221-
metadataURLOverride := os.Getenv("AWS_METADATA_URL")
222-
if creds == nil && metadataURLOverride != "" {
223-
s, err := session.NewSession(&aws.Config{
224-
Endpoint: aws.String(metadataURLOverride),
225-
})
226-
if err != nil {
227-
return nil, err
228-
}
214+
func (g *S3Getter) getAWSConfig(region string, url *url.URL, staticCreds *credentials.StaticCredentialsProvider) (conf aws.Config, err error) {
215+
var loadOptions []func(*config.LoadOptions) error
216+
var creds aws.CredentialsProvider
229217

230-
creds = credentials.NewChainCredentials(
231-
[]credentials.Provider{
232-
&credentials.EnvProvider{},
233-
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
234-
&ec2rolecreds.EC2RoleProvider{
235-
Client: ec2metadata.New(s),
236-
},
218+
metadataURLOverride := os.Getenv("AWS_METADATA_URL")
219+
if staticCreds == nil && metadataURLOverride != "" {
220+
creds = ec2rolecreds.New(func(o *ec2rolecreds.Options) {
221+
o.Client = imds.New(imds.Options{
222+
Endpoint: metadataURLOverride,
223+
ClientEnableState: imds.ClientEnabled,
237224
})
225+
})
226+
} else if staticCreds != nil {
227+
creds = staticCreds
238228
}
239229

240230
if creds != nil {
241-
conf.Endpoint = &url.Host
242-
conf.S3ForcePathStyle = aws.Bool(true)
243-
if url.Scheme == "http" {
244-
conf.DisableSSL = aws.Bool(true)
245-
}
231+
loadOptions = append(loadOptions,
232+
config.WithEC2IMDSClientEnableState(imds.ClientEnabled),
233+
config.WithCredentialsProvider(creds),
234+
config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(
235+
func(service, region string, options ...interface{}) (aws.Endpoint, error) {
236+
return aws.Endpoint{URL: url.Host}, nil
237+
},
238+
)))
246239
}
247240

248241
conf.Credentials = creds
249242
if region != "" {
250-
conf.Region = aws.String(region)
243+
loadOptions = append(loadOptions, config.WithRegion(region))
251244
}
252245

253-
conf = conf.WithCredentialsChainVerboseErrors(true)
254-
return conf, nil
246+
return config.LoadDefaultConfig(g.Context(), loadOptions...)
255247
}
256248

257-
func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
249+
func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.StaticCredentialsProvider, err error) {
258250
// This just check whether we are dealing with S3 or
259251
// any other S3 compliant service. S3 has a predictable
260252
// url as others do not
261-
if strings.HasSuffix(u.Host, ".amazonaws.com") {
253+
var awsDomain *string
254+
for _, partition := range endpoints.DefaultPartitions() {
255+
if strings.HasSuffix(u.Host, partition.DNSSuffix()) {
256+
awsDomain = aws.String(partition.DNSSuffix())
257+
break
258+
}
259+
}
260+
261+
if awsDomain != nil {
262262
// Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated
263263
// In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use.
264264
// The same bucket could be reached with:
@@ -267,10 +267,10 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
267267
// s3.amazonaws.com/bucket/path
268268
// s3-region.amazonaws.com/bucket/path
269269

270-
hostParts := strings.Split(u.Host, ".")
270+
hostParts := strings.Split(strings.TrimSuffix(u.Host, *awsDomain), ".")
271271
switch len(hostParts) {
272272
// path-style
273-
case 3:
273+
case 2:
274274
// Parse the region out of the first part of the host
275275
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
276276
if region == "" {
@@ -284,7 +284,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
284284
bucket = pathParts[1]
285285
path = pathParts[2]
286286
// vhost-style, dash region indication
287-
case 4:
287+
case 3:
288288
// Parse the region out of the second part of the host
289289
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3")
290290
if region == "" {
@@ -299,7 +299,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
299299
bucket = hostParts[0]
300300
path = pathParts[1]
301301
//vhost-style, dot region indication
302-
case 5:
302+
case 4:
303303
region = hostParts[2]
304304
pathParts := strings.SplitN(u.Path, "/", 2)
305305
if len(pathParts) < 2 {
@@ -310,7 +310,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
310310
path = pathParts[1]
311311

312312
}
313-
if len(hostParts) < 3 || len(hostParts) > 5 {
313+
if len(hostParts) < 2 || len(hostParts) > 4 {
314314
err = fmt.Errorf("URL is not a valid S3 URL")
315315
return
316316
}
@@ -335,40 +335,36 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
335335
_, hasAwsSecret := u.Query()["aws_access_key_secret"]
336336
_, hasAwsToken := u.Query()["aws_access_token"]
337337
if hasAwsId || hasAwsSecret || hasAwsToken {
338-
creds = credentials.NewStaticCredentials(
338+
provider := credentials.NewStaticCredentialsProvider(
339339
u.Query().Get("aws_access_key_id"),
340340
u.Query().Get("aws_access_key_secret"),
341341
u.Query().Get("aws_access_token"),
342342
)
343+
creds = &provider
343344
}
344345

345346
return
346347
}
347348

348349
func (g *S3Getter) newS3Client(
349-
region string, url *url.URL, creds *credentials.Credentials,
350-
) (*s3.S3, error) {
351-
var sess *session.Session
350+
region string, url *url.URL, creds *credentials.StaticCredentialsProvider,
351+
) (*s3.Client, error) {
352+
var err error
353+
var cfg aws.Config
352354

353355
if profile := url.Query().Get("aws_profile"); profile != "" {
354-
var err error
355-
sess, err = session.NewSessionWithOptions(session.Options{
356-
Profile: profile,
357-
SharedConfigState: session.SharedConfigEnable,
358-
})
359-
if err != nil {
360-
return nil, err
361-
}
356+
cfg, err = config.LoadDefaultConfig(g.Context(),
357+
config.WithSharedConfigProfile(profile),
358+
)
362359
} else {
363-
config, err := g.getAWSConfig(region, url, creds)
364-
if err != nil {
365-
return nil, err
366-
}
367-
sess, err = session.NewSession(config)
368-
if err != nil {
369-
return nil, err
370-
}
360+
cfg, err = g.getAWSConfig(region, url, creds)
361+
}
362+
363+
if err != nil {
364+
return nil, err
371365
}
372366

373-
return s3.New(sess), nil
367+
return s3.NewFromConfig(cfg, func(opts *s3.Options) {
368+
opts.UsePathStyle = true
369+
}), nil
374370
}

get_s3_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
package getter
55

66
import (
7+
"context"
8+
"errors"
79
"net/url"
810
"os"
911
"path/filepath"
1012
"testing"
1113

12-
"github.com/aws/aws-sdk-go/aws/awserr"
14+
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
1315
)
1416

1517
// Note for external contributors: In order to run the S3 test suite, you will only be able to be run
@@ -87,7 +89,8 @@ func TestS3Getter_GetFile_badParams(t *testing.T) {
8789
t.Fatalf("expected error, got none")
8890
}
8991

90-
if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.StatusCode() != 403 {
92+
var respErr *awshttp.ResponseError
93+
if errors.As(err, &respErr) && respErr.HTTPStatusCode() != 403 {
9194
t.Fatalf("expected InvalidAccessKeyId error")
9295
}
9396
}
@@ -285,7 +288,7 @@ func TestS3Getter_Url(t *testing.T) {
285288
return
286289
}
287290

288-
credV, err := creds.Get()
291+
credV, err := creds.Retrieve(context.Background())
289292
if err != nil {
290293
t.Fatalf("failed to get credentials: %s", err)
291294
}

0 commit comments

Comments
 (0)