@@ -12,12 +12,13 @@ import (
1212 "strings"
1313 "time"
1414
15- "github.com/aws/aws-sdk-go/aws"
16- "github.com/aws/aws-sdk-go/aws/credentials"
17- "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
18- "github.com/aws/aws-sdk-go/aws/ec2metadata"
19- "github.com/aws/aws-sdk-go/aws/session"
20- "github.com/aws/aws-sdk-go/service/s3"
15+ "github.com/aws/aws-sdk-go-v2/aws"
16+ "github.com/aws/aws-sdk-go-v2/config"
17+ "github.com/aws/aws-sdk-go-v2/credentials"
18+ "github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
19+ "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
20+ "github.com/aws/aws-sdk-go-v2/service/s3"
21+ "github.com/hashicorp/aws-sdk-go-base/v2/endpoints"
2122)
2223
2324// S3Getter is a Getter implementation that will download a module from
@@ -54,24 +55,27 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
5455 }
5556
5657 // List the object(s) at the given prefix
57- req := & s3.ListObjectsInput {
58+ req := & s3.ListObjectsV2Input {
5859 Bucket : aws .String (bucket ),
5960 Prefix : aws .String (path ),
6061 }
61- resp , err := client .ListObjectsWithContext (ctx , req )
62- if err != nil {
63- return 0 , err
64- }
65-
66- for _ , o := range resp .Contents {
67- // Use file mode on exact match.
68- if * o .Key == path {
69- return ClientModeFile , nil
62+ paginator := s3 .NewListObjectsV2Paginator (client , req )
63+ for paginator .HasMorePages () {
64+ output , err := paginator .NextPage (ctx )
65+ if err != nil {
66+ return 0 , err
7067 }
7168
72- // Use dir mode if child keys are found.
73- if strings .HasPrefix (* o .Key , path + "/" ) {
74- return ClientModeDir , nil
69+ for _ , o := range output .Contents {
70+ // Use file mode on exact match.
71+ if aws .ToString (o .Key ) == path {
72+ return ClientModeFile , nil
73+ }
74+
75+ // Use dir mode if child keys are found.
76+ if strings .HasPrefix (aws .ToString (o .Key ), path + "/" ) {
77+ return ClientModeDir , nil
78+ }
7579 }
7680 }
7781
@@ -119,28 +123,19 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
119123 }
120124
121125 // List files in path, keep listing until no more objects are found
122- lastMarker := ""
123- hasMore := true
124- for hasMore {
125- req := & s3.ListObjectsInput {
126- Bucket : aws .String (bucket ),
127- Prefix : aws .String (path ),
128- }
129- if lastMarker != "" {
130- req .Marker = aws .String (lastMarker )
131- }
132-
133- resp , err := client .ListObjectsWithContext (ctx , req )
126+ req := & s3.ListObjectsV2Input {
127+ Bucket : aws .String (bucket ),
128+ Prefix : aws .String (path ),
129+ }
130+ paginator := s3 .NewListObjectsV2Paginator (client , req )
131+ for paginator .HasMorePages () {
132+ output , err := paginator .NextPage (ctx )
134133 if err != nil {
135134 return err
136135 }
137136
138- hasMore = aws .BoolValue (resp .IsTruncated )
139-
140- // Get each object storing each file relative to the destination path
141- for _ , object := range resp .Contents {
142- lastMarker = aws .StringValue (object .Key )
143- objPath := aws .StringValue (object .Key )
137+ for _ , object := range output .Contents {
138+ objPath := aws .ToString (object .Key )
144139
145140 // If the key ends with a backslash assume it is a directory and ignore
146141 if strings .HasSuffix (objPath , "/" ) {
@@ -185,7 +180,7 @@ func (g *S3Getter) GetFile(dst string, u *url.URL) error {
185180 return g .getObject (ctx , client , dst , bucket , path , version )
186181}
187182
188- func (g * S3Getter ) getObject (ctx context.Context , client * s3.S3 , dst , bucket , key , version string ) error {
183+ func (g * S3Getter ) getObject (ctx context.Context , client * s3.Client , dst , bucket , key , version string ) error {
189184 req := & s3.GetObjectInput {
190185 Bucket : aws .String (bucket ),
191186 Key : aws .String (key ),
@@ -194,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
194189 req .VersionId = aws .String (version )
195190 }
196191
197- resp , err := client .GetObjectWithContext (ctx , req )
192+ resp , err := client .GetObject (ctx , req )
198193 if err != nil {
199194 return err
200195 }
@@ -208,57 +203,62 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
208203
209204 if g .client != nil && g .client .ProgressListener != nil {
210205 fn := filepath .Base (key )
211- body = g .client .ProgressListener .TrackProgress (fn , 0 , * resp .ContentLength , resp .Body )
206+ body = g .client .ProgressListener .TrackProgress (fn , 0 , aws . ToInt64 ( resp .ContentLength ) , resp .Body )
212207 }
213208 defer func () { _ = body .Close () }()
214209
215210 // There is no limit set for the size of an object from S3
216211 return copyReader (dst , body , 0666 , g .client .umask (), 0 )
217212}
218213
219- func (g * S3Getter ) getAWSConfig (region string , url * url.URL , creds * credentials.Credentials ) (* aws.Config , error ) {
220- conf := & aws.Config {}
221- metadataURLOverride := os .Getenv ("AWS_METADATA_URL" )
222- if creds == nil && metadataURLOverride != "" {
223- s , err := session .NewSession (& aws.Config {
224- Endpoint : aws .String (metadataURLOverride ),
225- })
226- if err != nil {
227- return nil , err
228- }
214+ func (g * S3Getter ) getAWSConfig (region string , url * url.URL , staticCreds * credentials.StaticCredentialsProvider ) (conf aws.Config , err error ) {
215+ var loadOptions []func (* config.LoadOptions ) error
216+ var creds aws.CredentialsProvider
229217
230- creds = credentials .NewChainCredentials (
231- []credentials.Provider {
232- & credentials.EnvProvider {},
233- & credentials.SharedCredentialsProvider {Filename : "" , Profile : "" },
234- & ec2rolecreds.EC2RoleProvider {
235- Client : ec2metadata .New (s ),
236- },
218+ metadataURLOverride := os .Getenv ("AWS_METADATA_URL" )
219+ if staticCreds == nil && metadataURLOverride != "" {
220+ creds = ec2rolecreds .New (func (o * ec2rolecreds.Options ) {
221+ o .Client = imds .New (imds.Options {
222+ Endpoint : metadataURLOverride ,
223+ ClientEnableState : imds .ClientEnabled ,
237224 })
225+ })
226+ } else if staticCreds != nil {
227+ creds = staticCreds
238228 }
239229
240230 if creds != nil {
241- conf .Endpoint = & url .Host
242- conf .S3ForcePathStyle = aws .Bool (true )
243- if url .Scheme == "http" {
244- conf .DisableSSL = aws .Bool (true )
245- }
231+ loadOptions = append (loadOptions ,
232+ config .WithEC2IMDSClientEnableState (imds .ClientEnabled ),
233+ config .WithCredentialsProvider (creds ),
234+ config .WithEndpointResolverWithOptions (aws .EndpointResolverWithOptionsFunc (
235+ func (service , region string , options ... interface {}) (aws.Endpoint , error ) {
236+ return aws.Endpoint {URL : url .Host }, nil
237+ },
238+ )))
246239 }
247240
248241 conf .Credentials = creds
249242 if region != "" {
250- conf . Region = aws . String (region )
243+ loadOptions = append ( loadOptions , config . WithRegion (region ) )
251244 }
252245
253- conf = conf .WithCredentialsChainVerboseErrors (true )
254- return conf , nil
246+ return config .LoadDefaultConfig (g .Context (), loadOptions ... )
255247}
256248
257- func (g * S3Getter ) parseUrl (u * url.URL ) (region , bucket , path , version string , creds * credentials.Credentials , err error ) {
249+ func (g * S3Getter ) parseUrl (u * url.URL ) (region , bucket , path , version string , creds * credentials.StaticCredentialsProvider , err error ) {
258250 // This just check whether we are dealing with S3 or
259251 // any other S3 compliant service. S3 has a predictable
260252 // url as others do not
261- if strings .HasSuffix (u .Host , ".amazonaws.com" ) {
253+ var awsDomain * string
254+ for _ , partition := range endpoints .DefaultPartitions () {
255+ if strings .HasSuffix (u .Host , partition .DNSSuffix ()) {
256+ awsDomain = aws .String (partition .DNSSuffix ())
257+ break
258+ }
259+ }
260+
261+ if awsDomain != nil {
262262 // Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated
263263 // In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use.
264264 // The same bucket could be reached with:
@@ -267,10 +267,10 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
267267 // s3.amazonaws.com/bucket/path
268268 // s3-region.amazonaws.com/bucket/path
269269
270- hostParts := strings .Split (u .Host , "." )
270+ hostParts := strings .Split (strings . TrimSuffix ( u .Host , * awsDomain ) , "." )
271271 switch len (hostParts ) {
272272 // path-style
273- case 3 :
273+ case 2 :
274274 // Parse the region out of the first part of the host
275275 region = strings .TrimPrefix (strings .TrimPrefix (hostParts [0 ], "s3-" ), "s3" )
276276 if region == "" {
@@ -284,7 +284,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
284284 bucket = pathParts [1 ]
285285 path = pathParts [2 ]
286286 // vhost-style, dash region indication
287- case 4 :
287+ case 3 :
288288 // Parse the region out of the second part of the host
289289 region = strings .TrimPrefix (strings .TrimPrefix (hostParts [1 ], "s3-" ), "s3" )
290290 if region == "" {
@@ -299,7 +299,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
299299 bucket = hostParts [0 ]
300300 path = pathParts [1 ]
301301 //vhost-style, dot region indication
302- case 5 :
302+ case 4 :
303303 region = hostParts [2 ]
304304 pathParts := strings .SplitN (u .Path , "/" , 2 )
305305 if len (pathParts ) < 2 {
@@ -310,7 +310,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
310310 path = pathParts [1 ]
311311
312312 }
313- if len (hostParts ) < 3 || len (hostParts ) > 5 {
313+ if len (hostParts ) < 2 || len (hostParts ) > 4 {
314314 err = fmt .Errorf ("URL is not a valid S3 URL" )
315315 return
316316 }
@@ -335,40 +335,36 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
335335 _ , hasAwsSecret := u .Query ()["aws_access_key_secret" ]
336336 _ , hasAwsToken := u .Query ()["aws_access_token" ]
337337 if hasAwsId || hasAwsSecret || hasAwsToken {
338- creds = credentials .NewStaticCredentials (
338+ provider : = credentials .NewStaticCredentialsProvider (
339339 u .Query ().Get ("aws_access_key_id" ),
340340 u .Query ().Get ("aws_access_key_secret" ),
341341 u .Query ().Get ("aws_access_token" ),
342342 )
343+ creds = & provider
343344 }
344345
345346 return
346347}
347348
348349func (g * S3Getter ) newS3Client (
349- region string , url * url.URL , creds * credentials.Credentials ,
350- ) (* s3.S3 , error ) {
351- var sess * session.Session
350+ region string , url * url.URL , creds * credentials.StaticCredentialsProvider ,
351+ ) (* s3.Client , error ) {
352+ var err error
353+ var cfg aws.Config
352354
353355 if profile := url .Query ().Get ("aws_profile" ); profile != "" {
354- var err error
355- sess , err = session .NewSessionWithOptions (session.Options {
356- Profile : profile ,
357- SharedConfigState : session .SharedConfigEnable ,
358- })
359- if err != nil {
360- return nil , err
361- }
356+ cfg , err = config .LoadDefaultConfig (g .Context (),
357+ config .WithSharedConfigProfile (profile ),
358+ )
362359 } else {
363- config , err := g .getAWSConfig (region , url , creds )
364- if err != nil {
365- return nil , err
366- }
367- sess , err = session .NewSession (config )
368- if err != nil {
369- return nil , err
370- }
360+ cfg , err = g .getAWSConfig (region , url , creds )
361+ }
362+
363+ if err != nil {
364+ return nil , err
371365 }
372366
373- return s3 .New (sess ), nil
367+ return s3 .NewFromConfig (cfg , func (opts * s3.Options ) {
368+ opts .UsePathStyle = true
369+ }), nil
374370}
0 commit comments