Skip to content

Commit 92f8940

Browse files
authored
Revert "[Refactor]: Core/API Split (#1506)"
This reverts commit ab7b4d5.
1 parent ab7b4d5 commit 92f8940

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+3101
-3425
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ LocalAI
1919
local-ai
2020
# prevent above rules from omitting the helm chart
2121
!charts/*
22-
# prevent above rules from omitting the core/**/localai folder
23-
!core/**/localai
22+
# prevent above rules from omitting the api/localai folder
23+
!api/localai
2424

2525
# Ignore models
2626
models/*

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ ENV NVIDIA_VISIBLE_DEVICES=all
8888
WORKDIR /build
8989

9090
COPY . .
91-
COPY .git/ .git/
91+
COPY .git .
9292
RUN make prepare
9393

9494
# stablediffusion does not tolerate a newer version of abseil, build it first

api/api.go

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package api
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"fmt"
7+
"os"
8+
"path/filepath"
9+
"strings"
10+
11+
config "github.com/go-skynet/LocalAI/api/config"
12+
"github.com/go-skynet/LocalAI/api/localai"
13+
"github.com/go-skynet/LocalAI/api/openai"
14+
"github.com/go-skynet/LocalAI/api/options"
15+
"github.com/go-skynet/LocalAI/api/schema"
16+
"github.com/go-skynet/LocalAI/internal"
17+
"github.com/go-skynet/LocalAI/metrics"
18+
"github.com/go-skynet/LocalAI/pkg/assets"
19+
"github.com/go-skynet/LocalAI/pkg/model"
20+
"github.com/go-skynet/LocalAI/pkg/utils"
21+
22+
"github.com/gofiber/fiber/v2"
23+
"github.com/gofiber/fiber/v2/middleware/cors"
24+
"github.com/gofiber/fiber/v2/middleware/logger"
25+
"github.com/gofiber/fiber/v2/middleware/recover"
26+
"github.com/rs/zerolog"
27+
"github.com/rs/zerolog/log"
28+
)
29+
30+
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
31+
options := options.NewOptions(opts...)
32+
33+
zerolog.SetGlobalLevel(zerolog.InfoLevel)
34+
if options.Debug {
35+
zerolog.SetGlobalLevel(zerolog.DebugLevel)
36+
}
37+
38+
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
39+
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
40+
41+
modelPath := options.Loader.ModelPath
42+
if len(options.ModelsURL) > 0 {
43+
for _, url := range options.ModelsURL {
44+
if utils.LooksLikeURL(url) {
45+
// md5 of model name
46+
md5Name := utils.MD5(url)
47+
48+
// check if file exists
49+
if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) {
50+
err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) {
51+
utils.DisplayDownloadFunction(fileName, current, total, percent)
52+
})
53+
if err != nil {
54+
log.Error().Msgf("error loading model: %s", err.Error())
55+
}
56+
}
57+
}
58+
}
59+
}
60+
61+
cl := config.NewConfigLoader()
62+
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
63+
log.Error().Msgf("error loading config files: %s", err.Error())
64+
}
65+
66+
if options.ConfigFile != "" {
67+
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
68+
log.Error().Msgf("error loading config file: %s", err.Error())
69+
}
70+
}
71+
72+
if err := cl.Preload(options.Loader.ModelPath); err != nil {
73+
log.Error().Msgf("error downloading models: %s", err.Error())
74+
}
75+
76+
if options.PreloadJSONModels != "" {
77+
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
78+
return nil, nil, err
79+
}
80+
}
81+
82+
if options.PreloadModelsFromPath != "" {
83+
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
84+
return nil, nil, err
85+
}
86+
}
87+
88+
if options.Debug {
89+
for _, v := range cl.ListConfigs() {
90+
cfg, _ := cl.GetConfig(v)
91+
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
92+
}
93+
}
94+
95+
if options.AssetsDestination != "" {
96+
// Extract files from the embedded FS
97+
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
98+
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
99+
if err != nil {
100+
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
101+
}
102+
}
103+
104+
// turn off any process that was started by GRPC if the context is canceled
105+
go func() {
106+
<-options.Context.Done()
107+
log.Debug().Msgf("Context canceled, shutting down")
108+
options.Loader.StopAllGRPC()
109+
}()
110+
111+
if options.WatchDog {
112+
wd := model.NewWatchDog(
113+
options.Loader,
114+
options.WatchDogBusyTimeout,
115+
options.WatchDogIdleTimeout,
116+
options.WatchDogBusy,
117+
options.WatchDogIdle)
118+
options.Loader.SetWatchDog(wd)
119+
go wd.Run()
120+
go func() {
121+
<-options.Context.Done()
122+
log.Debug().Msgf("Context canceled, shutting down")
123+
wd.Shutdown()
124+
}()
125+
}
126+
127+
return options, cl, nil
128+
}
129+
130+
func App(opts ...options.AppOption) (*fiber.App, error) {
131+
132+
options, cl, err := Startup(opts...)
133+
if err != nil {
134+
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
135+
}
136+
137+
// Return errors as JSON responses
138+
app := fiber.New(fiber.Config{
139+
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
140+
DisableStartupMessage: options.DisableMessage,
141+
// Override default error handler
142+
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
143+
// Status code defaults to 500
144+
code := fiber.StatusInternalServerError
145+
146+
// Retrieve the custom status code if it's a *fiber.Error
147+
var e *fiber.Error
148+
if errors.As(err, &e) {
149+
code = e.Code
150+
}
151+
152+
// Send custom error page
153+
return ctx.Status(code).JSON(
154+
schema.ErrorResponse{
155+
Error: &schema.APIError{Message: err.Error(), Code: code},
156+
},
157+
)
158+
},
159+
})
160+
161+
if options.Debug {
162+
app.Use(logger.New(logger.Config{
163+
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
164+
}))
165+
}
166+
167+
// Default middleware config
168+
app.Use(recover.New())
169+
if options.Metrics != nil {
170+
app.Use(metrics.APIMiddleware(options.Metrics))
171+
}
172+
173+
// Auth middleware checking if API key is valid. If no API key is set, no auth is required.
174+
auth := func(c *fiber.Ctx) error {
175+
if len(options.ApiKeys) == 0 {
176+
return c.Next()
177+
}
178+
179+
// Check for api_keys.json file
180+
fileContent, err := os.ReadFile("api_keys.json")
181+
if err == nil {
182+
// Parse JSON content from the file
183+
var fileKeys []string
184+
err := json.Unmarshal(fileContent, &fileKeys)
185+
if err != nil {
186+
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"})
187+
}
188+
189+
// Add file keys to options.ApiKeys
190+
options.ApiKeys = append(options.ApiKeys, fileKeys...)
191+
}
192+
193+
if len(options.ApiKeys) == 0 {
194+
return c.Next()
195+
}
196+
197+
authHeader := c.Get("Authorization")
198+
if authHeader == "" {
199+
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"})
200+
}
201+
authHeaderParts := strings.Split(authHeader, " ")
202+
if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" {
203+
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"})
204+
}
205+
206+
apiKey := authHeaderParts[1]
207+
for _, key := range options.ApiKeys {
208+
if apiKey == key {
209+
return c.Next()
210+
}
211+
}
212+
213+
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"})
214+
215+
}
216+
217+
if options.CORS {
218+
var c func(ctx *fiber.Ctx) error
219+
if options.CORSAllowOrigins == "" {
220+
c = cors.New()
221+
} else {
222+
c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins})
223+
}
224+
225+
app.Use(c)
226+
}
227+
228+
// LocalAI API endpoints
229+
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
230+
galleryService.Start(options.Context, cl)
231+
232+
app.Get("/version", auth, func(c *fiber.Ctx) error {
233+
return c.JSON(struct {
234+
Version string `json:"version"`
235+
}{Version: internal.PrintableVersion()})
236+
})
237+
238+
modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService)
239+
app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint())
240+
app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint())
241+
app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint())
242+
app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint())
243+
app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint())
244+
app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint())
245+
app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint())
246+
247+
// openAI compatible API endpoint
248+
249+
// chat
250+
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
251+
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))
252+
253+
// edit
254+
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
255+
app.Post("/edits", auth, openai.EditEndpoint(cl, options))
256+
257+
// completion
258+
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
259+
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
260+
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))
261+
262+
// embeddings
263+
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
264+
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
265+
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
266+
267+
// audio
268+
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
269+
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))
270+
271+
// images
272+
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))
273+
274+
if options.ImageDir != "" {
275+
app.Static("/generated-images", options.ImageDir)
276+
}
277+
278+
if options.AudioDir != "" {
279+
app.Static("/generated-audio", options.AudioDir)
280+
}
281+
282+
ok := func(c *fiber.Ctx) error {
283+
return c.SendStatus(200)
284+
}
285+
286+
// Kubernetes health checks
287+
app.Get("/healthz", ok)
288+
app.Get("/readyz", ok)
289+
290+
// Experimental Backend Statistics Module
291+
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
292+
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))
293+
app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor))
294+
295+
// models
296+
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
297+
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
298+
299+
app.Get("/metrics", metrics.MetricsHandler())
300+
301+
return app, nil
302+
}

0 commit comments

Comments
 (0)