Skip to content

Commit b6b8ab6

Browse files
authored
feat(models): pull models from urls (#2750)
* feat(models): pull models from urls When using `run` now we can point directly to hf models via URL, for instance: ```bash local-ai run huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf ``` Will pull the gguf model and place it in the models folder - of course this depends on the fact that the gguf file should be automatically detected by our guesser mechanism in order to this to make effective. Similarly now galleries can refer to single files in the API requests. This also changes the download code and `yaml` files now are treated in the same way, so now config files are saved with the appropriate name (and not hashed anymore). Signed-off-by: Ettore Di Giacinto <[email protected]> * Adapt tests Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent b60acab commit b6b8ab6

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

pkg/startup/model_preload.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package startup
33
import (
44
"errors"
55
"fmt"
6+
"net/url"
67
"os"
78
"path/filepath"
89
"strings"
@@ -77,19 +78,35 @@ func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath
7778

7879
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName)
7980
case downloader.LooksLikeURL(url):
80-
log.Debug().Msgf("[startup] resolved model to download: %s", url)
81+
log.Debug().Msgf("[startup] downloading %s", url)
82+
83+
// Extract filename from URL
84+
fileName, e := filenameFromUrl(url)
85+
if e != nil || fileName == "" {
86+
fileName = utils.MD5(url)
87+
if strings.HasSuffix(url, ".yaml") || strings.HasSuffix(url, ".yml") {
88+
fileName = fileName + ".yaml"
89+
}
90+
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL")
91+
//err = errors.Join(err, e)
92+
//continue
93+
}
8194

82-
// md5 of model name
83-
md5Name := utils.MD5(url)
95+
modelPath := filepath.Join(modelPath, fileName)
96+
97+
if e := utils.VerifyPath(fileName, modelPath); e != nil {
98+
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path")
99+
err = errors.Join(err, e)
100+
continue
101+
}
84102

85103
// check if file exists
86-
if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) {
87-
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml"
88-
e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) {
104+
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) {
105+
e := downloader.DownloadFile(url, modelPath, "", 0, 0, func(fileName, current, total string, percent float64) {
89106
utils.DisplayDownloadFunction(fileName, current, total, percent)
90107
})
91108
if e != nil {
92-
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model")
109+
log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model")
93110
err = errors.Join(err, e)
94111
}
95112
}
@@ -150,3 +167,20 @@ func installModel(galleries []config.Gallery, modelName, modelPath string, downl
150167

151168
return nil, true
152169
}
170+
171+
func filenameFromUrl(urlstr string) (string, error) {
172+
// strip anything after @
173+
if strings.Contains(urlstr, "@") {
174+
urlstr = strings.Split(urlstr, "@")[0]
175+
}
176+
177+
u, err := url.Parse(urlstr)
178+
if err != nil {
179+
return "", fmt.Errorf("error due to parsing url: %w", err)
180+
}
181+
x, err := url.QueryUnescape(u.EscapedPath())
182+
if err != nil {
183+
return "", fmt.Errorf("error due to escaping: %w", err)
184+
}
185+
return filepath.Base(x), nil
186+
}

pkg/startup/model_preload_test.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ var _ = Describe("Preload test", func() {
2020
tmpdir, err := os.MkdirTemp("", "")
2121
Expect(err).ToNot(HaveOccurred())
2222
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
23-
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")
23+
fileName := fmt.Sprintf("%s.yaml", "phi-2")
2424

2525
InstallModels([]config.Gallery{}, libraryURL, tmpdir, true, nil, "phi-2")
2626

@@ -36,7 +36,7 @@ var _ = Describe("Preload test", func() {
3636
tmpdir, err := os.MkdirTemp("", "")
3737
Expect(err).ToNot(HaveOccurred())
3838
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
39-
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))
39+
fileName := fmt.Sprintf("%s.yaml", "phi-2")
4040

4141
InstallModels([]config.Gallery{}, "", tmpdir, true, nil, url)
4242

@@ -79,5 +79,19 @@ var _ = Describe("Preload test", func() {
7979

8080
Expect(string(content)).To(ContainSubstring("name: mistral-openorca"))
8181
})
82+
It("downloads from urls", func() {
83+
tmpdir, err := os.MkdirTemp("", "")
84+
Expect(err).ToNot(HaveOccurred())
85+
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
86+
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
87+
88+
err = InstallModels([]config.Gallery{}, "", tmpdir, false, nil, url)
89+
Expect(err).ToNot(HaveOccurred())
90+
91+
resultFile := filepath.Join(tmpdir, fileName)
92+
93+
_, err = os.Stat(resultFile)
94+
Expect(err).ToNot(HaveOccurred())
95+
})
8296
})
8397
})

0 commit comments

Comments
 (0)