Skip to content

Commit a8eeda5

Browse files
committed
feat: support 'huggingface://' format
1 parent 99c352d commit a8eeda5

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

api/config/config.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,18 +265,43 @@ func (cm *ConfigLoader) ListConfigs() []string {
265265
return res
266266
}
267267

268+
func convertURL(s string) string {
269+
if strings.HasPrefix(s, "huggingface://") {
270+
repository := strings.Replace(s, "huggingface://", "", 1)
271+
// convert repository to a full URL.
272+
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main -> https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
273+
owner := strings.Split(repository, "/")[0]
274+
repo := strings.Split(repository, "/")[1]
275+
branch := "main"
276+
if strings.Contains(repo, "@") {
277+
branch = strings.Split(repository, "@")[1]
278+
}
279+
filepath := strings.Split(repository, "/")[2]
280+
if strings.Contains(filepath, "@") {
281+
filepath = strings.Split(filepath, "@")[0]
282+
}
283+
284+
return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath)
285+
286+
}
287+
288+
return s
289+
}
290+
268291
func (cm *ConfigLoader) Preload(modelPath string) error {
269292
cm.Lock()
270293
defer cm.Unlock()
271294

272295
for i, config := range cm.configs {
273-
if strings.HasPrefix(config.PredictionOptions.Model, "http://") || strings.HasPrefix(config.PredictionOptions.Model, "https://") {
296+
modelURL := config.PredictionOptions.Model
297+
modelURL = convertURL(modelURL)
298+
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
274299
// md5 of model name
275-
md5Name := utils.MD5(config.PredictionOptions.Model)
300+
md5Name := utils.MD5(modelURL)
276301

277302
// check if file exists
278303
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
279-
err := utils.DownloadFile(config.PredictionOptions.Model, filepath.Join(modelPath, md5Name))
304+
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name))
280305
if err != nil {
281306
return err
282307
}
@@ -287,6 +312,7 @@ func (cm *ConfigLoader) Preload(modelPath string) error {
287312
c.PredictionOptions.Model = md5Name
288313
cm.configs[i] = *c
289314
}
315+
290316
}
291317
return nil
292318
}

0 commit comments

Comments
 (0)