|
| 1 | +package api |
| 2 | + |
| 3 | +import ( |
| 4 | + "encoding/json" |
| 5 | + "errors" |
| 6 | + "fmt" |
| 7 | + "os" |
| 8 | + "path/filepath" |
| 9 | + "strings" |
| 10 | + |
| 11 | + config "github.com/go-skynet/LocalAI/api/config" |
| 12 | + "github.com/go-skynet/LocalAI/api/localai" |
| 13 | + "github.com/go-skynet/LocalAI/api/openai" |
| 14 | + "github.com/go-skynet/LocalAI/api/options" |
| 15 | + "github.com/go-skynet/LocalAI/api/schema" |
| 16 | + "github.com/go-skynet/LocalAI/internal" |
| 17 | + "github.com/go-skynet/LocalAI/metrics" |
| 18 | + "github.com/go-skynet/LocalAI/pkg/assets" |
| 19 | + "github.com/go-skynet/LocalAI/pkg/model" |
| 20 | + "github.com/go-skynet/LocalAI/pkg/utils" |
| 21 | + |
| 22 | + "github.com/gofiber/fiber/v2" |
| 23 | + "github.com/gofiber/fiber/v2/middleware/cors" |
| 24 | + "github.com/gofiber/fiber/v2/middleware/logger" |
| 25 | + "github.com/gofiber/fiber/v2/middleware/recover" |
| 26 | + "github.com/rs/zerolog" |
| 27 | + "github.com/rs/zerolog/log" |
| 28 | +) |
| 29 | + |
| 30 | +func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) { |
| 31 | + options := options.NewOptions(opts...) |
| 32 | + |
| 33 | + zerolog.SetGlobalLevel(zerolog.InfoLevel) |
| 34 | + if options.Debug { |
| 35 | + zerolog.SetGlobalLevel(zerolog.DebugLevel) |
| 36 | + } |
| 37 | + |
| 38 | + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) |
| 39 | + log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) |
| 40 | + |
| 41 | + modelPath := options.Loader.ModelPath |
| 42 | + if len(options.ModelsURL) > 0 { |
| 43 | + for _, url := range options.ModelsURL { |
| 44 | + if utils.LooksLikeURL(url) { |
| 45 | + // md5 of model name |
| 46 | + md5Name := utils.MD5(url) |
| 47 | + |
| 48 | + // check if file exists |
| 49 | + if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { |
| 50 | + err := utils.DownloadFile(url, filepath.Join(modelPath, md5Name)+".yaml", "", func(fileName, current, total string, percent float64) { |
| 51 | + utils.DisplayDownloadFunction(fileName, current, total, percent) |
| 52 | + }) |
| 53 | + if err != nil { |
| 54 | + log.Error().Msgf("error loading model: %s", err.Error()) |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + cl := config.NewConfigLoader() |
| 62 | + if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil { |
| 63 | + log.Error().Msgf("error loading config files: %s", err.Error()) |
| 64 | + } |
| 65 | + |
| 66 | + if options.ConfigFile != "" { |
| 67 | + if err := cl.LoadConfigFile(options.ConfigFile); err != nil { |
| 68 | + log.Error().Msgf("error loading config file: %s", err.Error()) |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + if err := cl.Preload(options.Loader.ModelPath); err != nil { |
| 73 | + log.Error().Msgf("error downloading models: %s", err.Error()) |
| 74 | + } |
| 75 | + |
| 76 | + if options.PreloadJSONModels != "" { |
| 77 | + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil { |
| 78 | + return nil, nil, err |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + if options.PreloadModelsFromPath != "" { |
| 83 | + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil { |
| 84 | + return nil, nil, err |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + if options.Debug { |
| 89 | + for _, v := range cl.ListConfigs() { |
| 90 | + cfg, _ := cl.GetConfig(v) |
| 91 | + log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + if options.AssetsDestination != "" { |
| 96 | + // Extract files from the embedded FS |
| 97 | + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) |
| 98 | + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) |
| 99 | + if err != nil { |
| 100 | + log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) |
| 101 | + } |
| 102 | + } |
| 103 | + |
| 104 | + // turn off any process that was started by GRPC if the context is canceled |
| 105 | + go func() { |
| 106 | + <-options.Context.Done() |
| 107 | + log.Debug().Msgf("Context canceled, shutting down") |
| 108 | + options.Loader.StopAllGRPC() |
| 109 | + }() |
| 110 | + |
| 111 | + if options.WatchDog { |
| 112 | + wd := model.NewWatchDog( |
| 113 | + options.Loader, |
| 114 | + options.WatchDogBusyTimeout, |
| 115 | + options.WatchDogIdleTimeout, |
| 116 | + options.WatchDogBusy, |
| 117 | + options.WatchDogIdle) |
| 118 | + options.Loader.SetWatchDog(wd) |
| 119 | + go wd.Run() |
| 120 | + go func() { |
| 121 | + <-options.Context.Done() |
| 122 | + log.Debug().Msgf("Context canceled, shutting down") |
| 123 | + wd.Shutdown() |
| 124 | + }() |
| 125 | + } |
| 126 | + |
| 127 | + return options, cl, nil |
| 128 | +} |
| 129 | + |
| 130 | +func App(opts ...options.AppOption) (*fiber.App, error) { |
| 131 | + |
| 132 | + options, cl, err := Startup(opts...) |
| 133 | + if err != nil { |
| 134 | + return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error()) |
| 135 | + } |
| 136 | + |
| 137 | + // Return errors as JSON responses |
| 138 | + app := fiber.New(fiber.Config{ |
| 139 | + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB |
| 140 | + DisableStartupMessage: options.DisableMessage, |
| 141 | + // Override default error handler |
| 142 | + ErrorHandler: func(ctx *fiber.Ctx, err error) error { |
| 143 | + // Status code defaults to 500 |
| 144 | + code := fiber.StatusInternalServerError |
| 145 | + |
| 146 | + // Retrieve the custom status code if it's a *fiber.Error |
| 147 | + var e *fiber.Error |
| 148 | + if errors.As(err, &e) { |
| 149 | + code = e.Code |
| 150 | + } |
| 151 | + |
| 152 | + // Send custom error page |
| 153 | + return ctx.Status(code).JSON( |
| 154 | + schema.ErrorResponse{ |
| 155 | + Error: &schema.APIError{Message: err.Error(), Code: code}, |
| 156 | + }, |
| 157 | + ) |
| 158 | + }, |
| 159 | + }) |
| 160 | + |
| 161 | + if options.Debug { |
| 162 | + app.Use(logger.New(logger.Config{ |
| 163 | + Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", |
| 164 | + })) |
| 165 | + } |
| 166 | + |
| 167 | + // Default middleware config |
| 168 | + app.Use(recover.New()) |
| 169 | + if options.Metrics != nil { |
| 170 | + app.Use(metrics.APIMiddleware(options.Metrics)) |
| 171 | + } |
| 172 | + |
| 173 | + // Auth middleware checking if API key is valid. If no API key is set, no auth is required. |
| 174 | + auth := func(c *fiber.Ctx) error { |
| 175 | + if len(options.ApiKeys) == 0 { |
| 176 | + return c.Next() |
| 177 | + } |
| 178 | + |
| 179 | + // Check for api_keys.json file |
| 180 | + fileContent, err := os.ReadFile("api_keys.json") |
| 181 | + if err == nil { |
| 182 | + // Parse JSON content from the file |
| 183 | + var fileKeys []string |
| 184 | + err := json.Unmarshal(fileContent, &fileKeys) |
| 185 | + if err != nil { |
| 186 | + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Error parsing api_keys.json"}) |
| 187 | + } |
| 188 | + |
| 189 | + // Add file keys to options.ApiKeys |
| 190 | + options.ApiKeys = append(options.ApiKeys, fileKeys...) |
| 191 | + } |
| 192 | + |
| 193 | + if len(options.ApiKeys) == 0 { |
| 194 | + return c.Next() |
| 195 | + } |
| 196 | + |
| 197 | + authHeader := c.Get("Authorization") |
| 198 | + if authHeader == "" { |
| 199 | + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) |
| 200 | + } |
| 201 | + authHeaderParts := strings.Split(authHeader, " ") |
| 202 | + if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { |
| 203 | + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) |
| 204 | + } |
| 205 | + |
| 206 | + apiKey := authHeaderParts[1] |
| 207 | + for _, key := range options.ApiKeys { |
| 208 | + if apiKey == key { |
| 209 | + return c.Next() |
| 210 | + } |
| 211 | + } |
| 212 | + |
| 213 | + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) |
| 214 | + |
| 215 | + } |
| 216 | + |
| 217 | + if options.CORS { |
| 218 | + var c func(ctx *fiber.Ctx) error |
| 219 | + if options.CORSAllowOrigins == "" { |
| 220 | + c = cors.New() |
| 221 | + } else { |
| 222 | + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) |
| 223 | + } |
| 224 | + |
| 225 | + app.Use(c) |
| 226 | + } |
| 227 | + |
| 228 | + // LocalAI API endpoints |
| 229 | + galleryService := localai.NewGalleryService(options.Loader.ModelPath) |
| 230 | + galleryService.Start(options.Context, cl) |
| 231 | + |
| 232 | + app.Get("/version", auth, func(c *fiber.Ctx) error { |
| 233 | + return c.JSON(struct { |
| 234 | + Version string `json:"version"` |
| 235 | + }{Version: internal.PrintableVersion()}) |
| 236 | + }) |
| 237 | + |
| 238 | + modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) |
| 239 | + app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) |
| 240 | + app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) |
| 241 | + app.Get("/models/galleries", auth, modelGalleryService.ListModelGalleriesEndpoint()) |
| 242 | + app.Post("/models/galleries", auth, modelGalleryService.AddModelGalleryEndpoint()) |
| 243 | + app.Delete("/models/galleries", auth, modelGalleryService.RemoveModelGalleryEndpoint()) |
| 244 | + app.Get("/models/jobs/:uuid", auth, modelGalleryService.GetOpStatusEndpoint()) |
| 245 | + app.Get("/models/jobs", auth, modelGalleryService.GetAllStatusEndpoint()) |
| 246 | + |
| 247 | + // openAI compatible API endpoint |
| 248 | + |
| 249 | + // chat |
| 250 | + app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options)) |
| 251 | + app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options)) |
| 252 | + |
| 253 | + // edit |
| 254 | + app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) |
| 255 | + app.Post("/edits", auth, openai.EditEndpoint(cl, options)) |
| 256 | + |
| 257 | + // completion |
| 258 | + app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) |
| 259 | + app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) |
| 260 | + app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options)) |
| 261 | + |
| 262 | + // embeddings |
| 263 | + app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) |
| 264 | + app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) |
| 265 | + app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options)) |
| 266 | + |
| 267 | + // audio |
| 268 | + app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options)) |
| 269 | + app.Post("/tts", auth, localai.TTSEndpoint(cl, options)) |
| 270 | + |
| 271 | + // images |
| 272 | + app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options)) |
| 273 | + |
| 274 | + if options.ImageDir != "" { |
| 275 | + app.Static("/generated-images", options.ImageDir) |
| 276 | + } |
| 277 | + |
| 278 | + if options.AudioDir != "" { |
| 279 | + app.Static("/generated-audio", options.AudioDir) |
| 280 | + } |
| 281 | + |
| 282 | + ok := func(c *fiber.Ctx) error { |
| 283 | + return c.SendStatus(200) |
| 284 | + } |
| 285 | + |
| 286 | + // Kubernetes health checks |
| 287 | + app.Get("/healthz", ok) |
| 288 | + app.Get("/readyz", ok) |
| 289 | + |
| 290 | + // Experimental Backend Statistics Module |
| 291 | + backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now |
| 292 | + app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor)) |
| 293 | + app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitor)) |
| 294 | + |
| 295 | + // models |
| 296 | + app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) |
| 297 | + app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl)) |
| 298 | + |
| 299 | + app.Get("/metrics", metrics.MetricsHandler()) |
| 300 | + |
| 301 | + return app, nil |
| 302 | +} |
0 commit comments