Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions pkg/utils/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ func ConvertURL(s string) string {
return s
}

func removePartialFile(tmpFilePath string) error {
_, err := os.Stat(tmpFilePath)
if err == nil {
log.Debug().Msgf("Removing temporary file %s", tmpFilePath)
err = os.Remove(tmpFilePath)
if err != nil {
err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err)
log.Warn().Msg(err1.Error())
return err1
}
}
return nil
}

func DownloadFile(url string, filePath, sha string, downloadStatus func(string, string, string, float64)) error {
url = ConvertURL(url)
// Check if the file already exists
Expand Down Expand Up @@ -143,15 +157,24 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err)
}

// save partial download to dedicated file
tmpFilePath := filePath + ".partial"

// remove tmp file
err = removePartialFile(tmpFilePath)
if err != nil {
return err
}

// Create and write file content
outFile, err := os.Create(filePath)
outFile, err := os.Create(tmpFilePath)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", filePath, err)
return fmt.Errorf("failed to create file %q: %v", tmpFilePath, err)
}
defer outFile.Close()

progress := &progressWriter{
fileName: filePath,
fileName: tmpFilePath,
total: resp.ContentLength,
hash: sha256.New(),
downloadStatus: downloadStatus,
Expand All @@ -161,6 +184,11 @@ func DownloadFile(url string, filePath, sha string, downloadStatus func(string,
return fmt.Errorf("failed to write file %q: %v", filePath, err)
}

err = os.Rename(tmpFilePath, filePath)
if err != nil {
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
}

if sha != "" {
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
Expand Down