@@ -61,11 +61,11 @@ var AutoLoadBackends []string = []string{
 
 // starts the grpcModelProcess for the backend, and returns a grpc client
 // It also loads the model
-func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (*grpc.Client, error) {
-	return func(modelName, modelFile string) (*grpc.Client, error) {
+func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string) (ModelAddress, error) {
+	return func(modelName, modelFile string) (ModelAddress, error) {
 		log.Debug().Msgf("Loading Model %s with gRPC (file: %s) (backend: %s): %+v", modelName, modelFile, backend, *o)
 
-		var client *grpc.Client
+		var client ModelAddress
 
 		getFreeAddress := func() (string, error) {
 			port, err := freeport.GetFreePort()
@@ -82,46 +82,46 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 			if _, err := os.Stat(uri); err == nil {
 				serverAddress, err := getFreeAddress()
 				if err != nil {
-					return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
+					return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
 				}
 				// Make sure the process is executable
 				if err := ml.startProcess(uri, o.model, serverAddress); err != nil {
-					return nil, err
+					return "", err
 				}
 
 				log.Debug().Msgf("GRPC Service Started")
 
-				client = grpc.NewClient(serverAddress)
+				client = ModelAddress(serverAddress)
 			} else {
 				// address
-				client = grpc.NewClient(uri)
+				client = ModelAddress(uri)
 			}
 		} else {
 			grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
 			// Check if the file exists
 			if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
-				return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
+				return "", fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
 			}
 
 			serverAddress, err := getFreeAddress()
 			if err != nil {
-				return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
+				return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
 			}
 
 			// Make sure the process is executable
 			if err := ml.startProcess(grpcProcess, o.model, serverAddress); err != nil {
-				return nil, err
+				return "", err
 			}
 
 			log.Debug().Msgf("GRPC Service Started")
 
-			client = grpc.NewClient(serverAddress)
+			client = ModelAddress(serverAddress)
 		}
 
 		// Wait for the service to start up
 		ready := false
 		for i := 0; i < o.grpcAttempts; i++ {
-			if client.HealthCheck(context.Background()) {
+			if client.GRPC().HealthCheck(context.Background()) {
 				log.Debug().Msgf("GRPC Service Ready")
 				ready = true
 				break
@@ -131,7 +131,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 
 		if !ready {
 			log.Debug().Msgf("GRPC Service NOT ready")
-			return nil, fmt.Errorf("grpc service not ready")
+			return "", fmt.Errorf("grpc service not ready")
 		}
 
 		options := *o.gRPCOptions
@@ -140,19 +140,30 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
 
 		log.Debug().Msgf("GRPC: Loading model with options: %+v", options)
 
-		res, err := client.LoadModel(o.context, &options)
+		res, err := client.GRPC().LoadModel(o.context, &options)
 		if err != nil {
-			return nil, fmt.Errorf("could not load model: %w", err)
+			return "", fmt.Errorf("could not load model: %w", err)
 		}
 		if !res.Success {
-			return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
+			return "", fmt.Errorf("could not load model (no success): %s", res.Message)
 		}
 
 		return client, nil
 	}
 }
 
-func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err error) {
+func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
+	if parallel {
+		return addr.GRPC(), nil
+	}
+
+	if _, ok := ml.grpcClients[string(addr)]; !ok {
+		ml.grpcClients[string(addr)] = addr.GRPC()
+	}
+	return ml.grpcClients[string(addr)], nil
+}
+
+func (ml *ModelLoader) BackendLoader(opts ...Option) (client *grpc.Client, err error) {
 	o := NewOptions(opts...)
 
 	log.Debug().Msgf("Loading model %s from %s", o.backendString, o.model)
@@ -166,22 +177,25 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
 		ml.mu.Unlock()
 	}
 
-	// if an external backend is provided, use it
-	_, externalBackendExists := o.externalBackends[backend]
-	if externalBackendExists {
-		return ml.LoadModel(o.model, ml.grpcModel(backend, o))
-	}
+	var backendToConsume string
 
 	switch backend {
 	case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
 		o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all")
-		return ml.LoadModel(o.model, ml.grpcModel(Gpt4All, o))
+		backendToConsume = Gpt4All
 	case PiperBackend:
 		o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")
-		return ml.LoadModel(o.model, ml.grpcModel(PiperBackend, o))
+		backendToConsume = PiperBackend
 	default:
-		return ml.LoadModel(o.model, ml.grpcModel(backend, o))
+		backendToConsume = backend
+	}
+
+	addr, err := ml.LoadModel(o.model, ml.grpcModel(backendToConsume, o))
+	if err != nil {
+		return nil, err
 	}
+
+	return ml.resolveAddress(addr, o.parallelRequests)
 }
 
 func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
@@ -190,10 +204,11 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
 	ml.mu.Lock()
 	// Return earlier if we have a model already loaded
 	// (avoid looping through all the backends)
-	if m := ml.CheckIsLoaded(o.model); m != nil {
+	if m := ml.CheckIsLoaded(o.model); m != "" {
 		log.Debug().Msgf("Model '%s' already loaded", o.model)
 		ml.mu.Unlock()
-		return m, nil
+
+		return ml.resolveAddress(m, o.parallelRequests)
 	}
 	// If we can have only one backend active, kill all the others (except external backends)
 	if o.singleActiveBackend {
0 commit comments