Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ service Backend {
rpc PredictStream(PredictOptions) returns (stream Reply) {}
rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
Expand Down Expand Up @@ -301,6 +302,19 @@ message GenerateImageRequest {
int32 CLIPSkip = 11;
}

// GenerateVideoRequest is the request message of the Backend.GenerateVideo RPC.
// Both start_image and end_image accept either a filesystem path or base64
// data; the result is written by the backend to dst.
message GenerateVideoRequest {
// Text prompt for the generation request.
string prompt = 1;
string start_image = 2; // Path or base64 encoded image for the start frame
string end_image = 3; // Path or base64 encoded image for the end frame
// Requested output width (unit not stated here; presumably pixels — confirm with backends).
int32 width = 4;
// Requested output height (unit not stated here; presumably pixels — confirm with backends).
int32 height = 5;
int32 num_frames = 6; // Number of frames to generate
int32 fps = 7; // Frames per second
// Seed value forwarded to the backend (presumably for reproducible sampling — confirm).
int32 seed = 8;
float cfg_scale = 9; // Classifier-free guidance scale
string dst = 10; // Output path for the generated video
}

message TTSRequest {
string text = 1;
string model = 2;
Expand Down
10 changes: 2 additions & 8 deletions core/application/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,12 @@ func New(opts ...config.AppOption) (*Application, error) {
if err != nil {
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
}
if options.ImageDir != "" {
err := os.MkdirAll(options.ImageDir, 0750)
if options.GeneratedContentDir != "" {
err := os.MkdirAll(options.GeneratedContentDir, 0750)
if err != nil {
return nil, fmt.Errorf("unable to create ImageDir: %q", err)
}
}
if options.AudioDir != "" {
err := os.MkdirAll(options.AudioDir, 0750)
if err != nil {
return nil, fmt.Errorf("unable to create AudioDir: %q", err)
}
}
if options.UploadDir != "" {
err := os.MkdirAll(options.UploadDir, 0750)
if err != nil {
Expand Down
11 changes: 8 additions & 3 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ func SoundGeneration(
return "", nil, fmt.Errorf("could not load sound generation model")
}

if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
if err := os.MkdirAll(audioDir, 0750); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav")
filePath := filepath.Join(audioDir, fileName)

res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
Expand Down
7 changes: 4 additions & 3 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ func ModelTTS(
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
}

if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio")
if err := os.MkdirAll(audioDir, 0750); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)
fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav")
filePath := filepath.Join(audioDir, fileName)

// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
// This should be addressed in a follow up PR soon.
Expand Down
36 changes: 36 additions & 0 deletions core/backend/video.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package backend

import (
"github.com/mudler/LocalAI/core/config"

"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
)

// VideoGeneration loads the model selected by backendConfig and returns a
// closure that performs the actual video generation when invoked.
//
// The model is loaded eagerly through loader.Load using the options built by
// ModelOptions; a load failure is returned as-is with a nil closure.
// loader.Close is deferred, so it runs when VideoGeneration itself returns,
// which is before the returned closure is called.
//
// The returned closure sends a GenerateVideoRequest carrying height, width,
// prompt, startImage, endImage and dst to the backend using appConfig.Context,
// discards the reply and reports only the error.
func VideoGeneration(height, width int32, prompt, startImage, endImage, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) {
	videoModel, loadErr := loader.Load(ModelOptions(backendConfig, appConfig)...)
	if loadErr != nil {
		return nil, loadErr
	}
	defer loader.Close()

	generate := func() error {
		// The request is assembled on every invocation, exactly as the
		// backend call needs it; only the error of the RPC is surfaced.
		request := &proto.GenerateVideoRequest{
			Prompt:     prompt,
			StartImage: startImage,
			EndImage:   endImage,
			Width:      width,
			Height:     height,
			Dst:        dst,
		}
		_, genErr := videoModel.GenerateVideo(appConfig.Context, request)
		return genErr
	}

	return generate, nil
}
6 changes: 2 additions & 4 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ type RunCMD struct {

ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
ImagePath string `env:"LOCALAI_IMAGE_PATH,IMAGE_PATH" type:"path" default:"/tmp/generated/images" help:"Location for images generated by backends (e.g. stablediffusion)" group:"storage"`
AudioPath string `env:"LOCALAI_AUDIO_PATH,AUDIO_PATH" type:"path" default:"/tmp/generated/audio" help:"Location for audio generated by backends (e.g. piper)" group:"storage"`
GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"`
UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"`
ConfigPath string `env:"LOCALAI_CONFIG_PATH,CONFIG_PATH" default:"/tmp/localai/config" group:"storage"`
LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"`
Expand Down Expand Up @@ -81,8 +80,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithModelPath(r.ModelsPath),
config.WithContextSize(r.ContextSize),
config.WithDebug(zerolog.GlobalLevel() <= zerolog.DebugLevel),
config.WithImageDir(r.ImagePath),
config.WithAudioDir(r.AudioPath),
config.WithGeneratedContentDir(r.GeneratedContentPath),
config.WithUploadDir(r.UploadPath),
config.WithConfigsDir(r.ConfigPath),
config.WithDynamicConfigDir(r.LocalaiConfigDir),
Expand Down
2 changes: 1 addition & 1 deletion core/cli/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
opts := &config.ApplicationConfig{
ModelPath: t.ModelsPath,
Context: context.Background(),
AudioDir: outputDir,
GeneratedContentDir: outputDir,
AssetsDestination: t.BackendAssetsPath,
ExternalGRPCBackends: externalBackends,
}
Expand Down
8 changes: 4 additions & 4 deletions core/cli/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
text := strings.Join(t.Text, " ")

opts := &config.ApplicationConfig{
ModelPath: t.ModelsPath,
Context: context.Background(),
AudioDir: outputDir,
AssetsDestination: t.BackendAssetsPath,
ModelPath: t.ModelsPath,
Context: context.Background(),
GeneratedContentDir: outputDir,
AssetsDestination: t.BackendAssetsPath,
}
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)

Expand Down
39 changes: 17 additions & 22 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,21 @@ type ApplicationConfig struct {
UploadLimitMB, Threads, ContextSize int
F16 bool
Debug bool
ImageDir string
AudioDir string
UploadDir string
ConfigsDir string
DynamicConfigsDir string
DynamicConfigsDirPollInterval time.Duration
CORS bool
CSRF bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
P2PToken string
P2PNetworkID string
GeneratedContentDir string

ConfigsDir string
UploadDir string

DynamicConfigsDir string
DynamicConfigsDirPollInterval time.Duration
CORS bool
CSRF bool
PreloadJSONModels string
PreloadModelsFromPath string
CORSAllowOrigins string
ApiKeys []string
P2PToken string
P2PNetworkID string

DisableWebUI bool
EnforcePredownloadScans bool
Expand Down Expand Up @@ -279,15 +280,9 @@ func WithDebug(debug bool) AppOption {
}
}

func WithAudioDir(audioDir string) AppOption {
func WithGeneratedContentDir(generatedContentDir string) AppOption {
return func(o *ApplicationConfig) {
o.AudioDir = audioDir
}
}

func WithImageDir(imageDir string) AppOption {
return func(o *ApplicationConfig) {
o.ImageDir = imageDir
o.GeneratedContentDir = generatedContentDir
}
}

Expand Down
37 changes: 25 additions & 12 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,18 +436,19 @@ func (c *BackendConfig) HasTemplate() bool {
type BackendConfigUsecases int

const (
FLAG_ANY BackendConfigUsecases = 0b00000000000
FLAG_CHAT BackendConfigUsecases = 0b00000000001
FLAG_COMPLETION BackendConfigUsecases = 0b00000000010
FLAG_EDIT BackendConfigUsecases = 0b00000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b00000001000
FLAG_RERANK BackendConfigUsecases = 0b00000010000
FLAG_IMAGE BackendConfigUsecases = 0b00000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b00001000000
FLAG_TTS BackendConfigUsecases = 0b00010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b01000000000
FLAG_VAD BackendConfigUsecases = 0b10000000000
FLAG_ANY BackendConfigUsecases = 0b000000000000
FLAG_CHAT BackendConfigUsecases = 0b000000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000000010
FLAG_EDIT BackendConfigUsecases = 0b000000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000000001000
FLAG_RERANK BackendConfigUsecases = 0b000000010000
FLAG_IMAGE BackendConfigUsecases = 0b000000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b000001000000
FLAG_TTS BackendConfigUsecases = 0b000010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b000100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b001000000000
FLAG_VAD BackendConfigUsecases = 0b010000000000
FLAG_VIDEO BackendConfigUsecases = 0b100000000000

// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
Expand All @@ -468,6 +469,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_TOKENIZE": FLAG_TOKENIZE,
"FLAG_VAD": FLAG_VAD,
"FLAG_LLM": FLAG_LLM,
"FLAG_VIDEO": FLAG_VIDEO,
}
}

Expand Down Expand Up @@ -532,6 +534,17 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
return false
}

}
if (u & FLAG_VIDEO) == FLAG_VIDEO {
videoBackends := []string{"diffusers", "stablediffusion"}
if !slices.Contains(videoBackends, c.Backend) {
return false
}

if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
return false
}

}
if (u & FLAG_RERANK) == FLAG_RERANK {
if c.Backend != "rerankers" {
Expand Down
19 changes: 14 additions & 5 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"errors"
"fmt"
"net/http"
"os"
"path/filepath"

"github.com/dave-gray101/v2keyauth"
"github.com/mudler/LocalAI/pkg/utils"
Expand Down Expand Up @@ -153,12 +155,19 @@
Browse: true,
}))

if application.ApplicationConfig().ImageDir != "" {
router.Static("/generated-images", application.ApplicationConfig().ImageDir)
}
if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images")
videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos")

os.MkdirAll(audioPath, 0750)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
os.MkdirAll(imagePath, 0750)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled
os.MkdirAll(videoPath, 0750)

Check warning

Code scanning / gosec

Errors unhandled Warning

Errors unhandled

if application.ApplicationConfig().AudioDir != "" {
router.Static("/generated-audio", application.ApplicationConfig().AudioDir)
router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
}

// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
Expand Down
3 changes: 1 addition & 2 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,7 @@ var _ = Describe("API test", func() {
application, err := application.New(
append(commonOpts,
config.WithContext(c),
config.WithAudioDir(tmpdir),
config.WithImageDir(tmpdir),
config.WithGeneratedContentDir(tmpdir),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithBackendAssets(backendAssets),
Expand Down
Loading
Loading