Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
}
}

if err := cl.Preload(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error downloading models: %s", err.Error())
}

if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
Expand Down
5 changes: 3 additions & 2 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings"))
})

It("runs openllama", Label("llama"), func() {
It("runs openllama(llama-ggml backend)", Label("llama"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
Expand Down Expand Up @@ -362,9 +362,10 @@ var _ = Describe("API test", func() {
Expect(res["location"]).To(Equal("San Francisco, California, United States"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))

})

It("runs openllama gguf", Label("llama-gguf"), func() {
It("runs openllama gguf(llama-cpp)", Label("llama-gguf"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
Expand Down
32 changes: 32 additions & 0 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"strings"
"sync"

"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -264,6 +266,36 @@ func (cm *ConfigLoader) ListConfigs() []string {
return res
}

func (cm *ConfigLoader) Preload(modelPath string) error {
cm.Lock()
defer cm.Unlock()

for i, config := range cm.configs {
modelURL := config.PredictionOptions.Model
modelURL = utils.ConvertURL(modelURL)
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
// md5 of model name
md5Name := utils.MD5(modelURL)

// check if file exists
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); err == os.ErrNotExist {
err := utils.DownloadFile(modelURL, filepath.Join(modelPath, md5Name), "", func(fileName, current, total string, percent float64) {
log.Info().Msgf("Downloading %s: %s/%s (%.2f%%)", fileName, current, total, percent)
})
if err != nil {
return err
}
}

cc := cm.configs[i]
c := &cc
c.PredictionOptions.Model = md5Name
cm.configs[i] = *c
}
}
return nil
}

func (cm *ConfigLoader) LoadConfigs(path string) error {
cm.Lock()
defer cm.Unlock()
Expand Down
32 changes: 19 additions & 13 deletions api/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,12 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
c.Set("Transfer-Encoding", "chunked")
}

templateFile := config.Model
templateFile := ""

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}

if config.TemplateConfig.Chat != "" && !processFunctions {
templateFile = config.TemplateConfig.Chat
Expand All @@ -229,18 +234,19 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
templateFile = config.TemplateConfig.Functions
}

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.ChatPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
SuppressSystemPrompt: suppressConfigSystemPrompt,
Input: predInput,
Functions: funcs,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
} else {
log.Debug().Msgf("Template failed loading: %s", err.Error())
}
}

log.Debug().Msgf("Prompt (after templating): %s", predInput)
Expand Down
40 changes: 24 additions & 16 deletions api/openai/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
c.Set("Transfer-Encoding", "chunked")
}

templateFile := config.Model
templateFile := ""

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}

if config.TemplateConfig.Completion != "" {
templateFile = config.TemplateConfig.Completion
Expand All @@ -94,13 +99,14 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe

predInput := config.PromptStrings[0]

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
Input: predInput,
})
if err == nil {
predInput = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}
}

responses := make(chan schema.OpenAIResponse)
Expand Down Expand Up @@ -145,14 +151,16 @@ func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
totalTokenUsage := backend.TokenUsage{}

for k, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
if templateFile != "" {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.CompletionPromptTemplate, templateFile, model.PromptTemplateData{
SystemPrompt: config.SystemPrompt,
Input: i,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}

r, tokenUsage, err := ComputeChoices(
Expand Down
26 changes: 16 additions & 10 deletions api/openai/edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)

log.Debug().Msgf("Parameter Config: %+v", config)

templateFile := config.Model
templateFile := ""

// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
if o.Loader.ExistsInModelPath(fmt.Sprintf("%s.tmpl", config.Model)) {
templateFile = config.Model
}

if config.TemplateConfig.Edit != "" {
templateFile = config.TemplateConfig.Edit
Expand All @@ -40,15 +45,16 @@ func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
totalTokenUsage := backend.TokenUsage{}

for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
if templateFile != "" {
templatedInput, err := o.Loader.EvaluateTemplateForPrompt(model.EditPromptTemplate, templateFile, model.PromptTemplateData{
Input: i,
Instruction: input.Instruction,
SystemPrompt: config.SystemPrompt,
})
if err == nil {
i = templatedInput
log.Debug().Msgf("Template found, input modified to: %s", i)
}
}

r, tokenUsage, err := ComputeChoices(input, i, config, o, o.Loader, func(s string, c *[]schema.Choice) {
Expand Down
86 changes: 2 additions & 84 deletions pkg/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"hash"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
Expand Down Expand Up @@ -115,89 +114,8 @@ func InstallModel(basePath, nameOverride string, config *Config, configOverrides
// Create file path
filePath := filepath.Join(basePath, file.Filename)

// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
// File exists, check SHA
if file.SHA256 != "" {
// Verify SHA
calculatedSHA, err := calculateSHA(filePath)
if err != nil {
return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
}
if calculatedSHA == file.SHA256 {
// SHA matches, skip downloading
log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
continue
}
// SHA doesn't match, delete the file and download again
err = os.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
}
log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)

} else {
// SHA is missing, skip downloading
log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
continue
}
} else if !os.IsNotExist(err) {
// Error occurred while checking file existence
return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err)
}

log.Debug().Msgf("Downloading %q", file.URI)

// Download file
resp, err := http.Get(file.URI)
if err != nil {
return fmt.Errorf("failed to download file %q: %v", file.Filename, err)
}
defer resp.Body.Close()

// Create parent directory
err = os.MkdirAll(filepath.Dir(filePath), 0755)
if err != nil {
return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err)
}

// Create and write file content
outFile, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file %q: %v", file.Filename, err)
}
defer outFile.Close()

progress := &progressWriter{
fileName: file.Filename,
total: resp.ContentLength,
hash: sha256.New(),
downloadStatus: downloadStatus,
}
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
if err != nil {
return fmt.Errorf("failed to write file %q: %v", file.Filename, err)
}

if file.SHA256 != "" {
// Verify SHA
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
if calculatedSHA != file.SHA256 {
log.Debug().Msgf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256)
}
} else {
log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename)
}

log.Debug().Msgf("File %q downloaded and verified", file.Filename)
if utils.IsArchive(filePath) {
log.Debug().Msgf("File %q is an archive, uncompressing to %s", file.Filename, basePath)
if err := utils.ExtractArchive(filePath, basePath); err != nil {
log.Debug().Msgf("Failed decompressing %q: %s", file.Filename, err.Error())
return err
}
if err := utils.DownloadFile(file.URI, filePath, file.SHA256, downloadStatus); err != nil {
return err
}
}

Expand Down
18 changes: 10 additions & 8 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,19 @@ func (ml *ModelLoader) loadTemplateIfExists(templateType TemplateType, templateN
// skip any error here - we run anyway if a template does not exist
modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName)

if !ml.ExistsInModelPath(modelTemplateFile) {
return nil
}

dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
if err != nil {
return err
dat := ""
if ml.ExistsInModelPath(modelTemplateFile) {
d, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile))
if err != nil {
return err
}
dat = string(d)
} else {
dat = templateName
}

// Parse the template
tmpl, err := template.New("prompt").Parse(string(dat))
tmpl, err := template.New("prompt").Parse(dat)
if err != nil {
return err
}
Expand Down
Loading