Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func init() {
new(S3Detector),
new(GCSDetector),
new(FileDetector),
new(AzureBlobDetector),
}
}

Expand Down
67 changes: 67 additions & 0 deletions detect_azure_blob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package getter

import (
"fmt"
"net/url"
"strings"
)

type AzureBlobDetector struct{}

func (d *AzureBlobDetector) Detect(src, pwd string) (string, bool, error) {
if len(src) == 0 {
return "", false, nil
}

if strings.Contains(src, ".blob.core.windows.net") {
return d.detectURL(src)
}

return "", false, nil
}

func (d *AzureBlobDetector) detectURL(src string) (string, bool, error) {
u, err := url.Parse(src)
if err != nil {
return "", false, err
}

if err := validateScheme(u.Scheme); err != nil {
return "", false, err
}

if err := validateAzureBlobHost(u.Host); err != nil {
return "", false, err
}

if err := validateBlobPath(u.Path); err != nil {
return "", false, err
}

u.Scheme = "https"

return fmt.Sprintf("azureblob::%s", u.String()), true, nil
}

func validateScheme(scheme string) error {
if scheme != "http" && scheme != "https" {
return fmt.Errorf("invalid scheme: %s, must be http or https", scheme)
}
return nil
}

func validateAzureBlobHost(host string) error {
hostParts := strings.Split(host, ".")
if len(hostParts) != 4 || hostParts[1] != "blob" || hostParts[2] != "core" || hostParts[3] != "windows.net" {
return fmt.Errorf("invalid Azure Blob Storage hostname: %s", host)
}
return nil
}

func validateBlobPath(path string) error {
parts := strings.Split(strings.TrimPrefix(path, "/"), "/")
if len(parts) < 1 || parts[0] == "" {
return fmt.Errorf("path to blob must contain at least a container name")
}
return nil
}
86 changes: 86 additions & 0 deletions detect_azure_blob_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package getter

import (
"testing"
)

func TestDetectURL(t *testing.T) {
d := &AzureBlobDetector{} // Assuming this struct exists

tests := []struct {
name string
input string
want string
wantOk bool
wantErr bool
}{
// Valid Cases
{
name: "Valid HTTPS URL",
input: "https://myaccount.blob.core.windows.net/mycontainer",
want: "azureblob::https://myaccount.blob.core.windows.net/mycontainer",
wantOk: true,
},
{
name: "Valid HTTP URL",
input: "http://myaccount.blob.core.windows.net/mycontainer",
want: "azureblob::https://myaccount.blob.core.windows.net/mycontainer",
wantOk: true,
},
{
name: "Valid URL with blob path",
input: "https://myaccount.blob.core.windows.net/mycontainer/mypath/file.txt",
want: "azureblob::https://myaccount.blob.core.windows.net/mycontainer/mypath/file.txt",
wantOk: true,
},

// Invalid Cases
{
name: "Invalid Scheme",
input: "ftp://myaccount.blob.core.windows.net/mycontainer",
wantErr: true,
},
{
name: "Invalid Hostname",
input: "https://myaccount.blob.azure.com/mycontainer",
wantErr: true,
},
{
name: "Missing Container Name",
input: "https://myaccount.blob.core.windows.net/",
wantErr: true,
},
{
name: "Completely Invalid URL",
input: "not_a_url",
wantErr: true,
},
{
name: "Empty Input",
input: "",
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok, err := d.detectURL(tt.input)

if tt.wantErr {
if err == nil {
t.Errorf("Expected error, got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("Expected %q, got %q", tt.want, got)
}
if ok != tt.wantOk {
t.Errorf("Expected ok = %v, got %v", tt.wantOk, ok)
}
}
})
}
}
15 changes: 8 additions & 7 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ func init() {
}

Getters = map[string]Getter{
"file": new(FileGetter),
"git": new(GitGetter),
"gcs": new(GCSGetter),
"hg": new(HgGetter),
"s3": new(S3Getter),
"http": httpGetter,
"https": httpGetter,
"file": new(FileGetter),
"git": new(GitGetter),
"gcs": new(GCSGetter),
"hg": new(HgGetter),
"s3": new(S3Getter),
"azureblob": new(AzureBlobGetter),
"http": httpGetter,
"https": httpGetter,
}
}

Expand Down
Loading