Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion feature/s3/manager/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,14 @@ func isRangeMismatch(expectStart, expectEnd, actualStart, actualEnd int) bool {
return false // we don't know, one of the ranges was missing or unparseable
}

return expectStart != actualStart && expectEnd != actualEnd
// for the final chunk (or the first chunk if it's smaller) we still
// request a full chunk but we get back the actual final part of the
// object, which will be smaller
if expectStart == actualStart && actualEnd < expectEnd {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice capture

return false
}

return expectStart != actualStart || expectEnd != actualEnd
}

// getTotalBytes is a thread-safe getter for retrieving the total byte status.
Expand Down
24 changes: 24 additions & 0 deletions feature/s3/manager/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,30 @@ func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) {
return capture, &capture.GetObjectInvocations
}

func newDownloadBadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) {
capture := &downloadCaptureClient{}

capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
start, fin := parseRange(aws.ToString(params.Range))
fin++

if fin >= int64(len(data)) {
fin = int64(len(data))
}

bodyBytes := data[start:fin]

return &s3.GetObjectOutput{
Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)),
// offset start by 1 to make it wrong
ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start+1, fin-1, len(data))),
ContentLength: aws.Int64(int64(len(bodyBytes))),
}, nil
}

return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges
}

func newDownloadVersionClient(data []byte) (*downloadCaptureClient, *int, *[]string, *[]string) {
capture := &downloadCaptureClient{}

Expand Down
34 changes: 34 additions & 0 deletions feature/s3/manager/validate_parts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package manager_test

import (
"context"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
)

type invalidRangeClient struct {
}

func TestDownload_RangeMismatch(t *testing.T) {
c, _, _ := newDownloadBadRangeClient(buf12MB)

d := manager.NewDownloader(c, func(d *manager.Downloader) {
d.Concurrency = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test when we useDisableValidateParts and ensure this does not fail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

})

w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB)))
_, err := d.Download(context.Background(), w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
if err == nil {
t.Fatalf("expect err, got none")
}
if !strings.Contains(err.Error(), "invalid content range") {
t.Errorf("error mismatch:\n%v !=\n%v", err, "invalid content range")
}
}
Loading