Skip to content

Commit 824612f

Browse files
authored
feat: initial watchdog implementation (#1341)
* feat: initial watchdog implementation Signed-off-by: Ettore Di Giacinto <[email protected]> * fixups * Add more output * wip: idletime checker * wire idle watchdog checks * enlarge watchdog time window * small fixes * Use stopmodel * Always delete process Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]> Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 9482acf commit 824612f

File tree

10 files changed

+341
-13
lines changed

10 files changed

+341
-13
lines changed

.env

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,18 @@ MODELS_PATH=/models
7272
# LLAMACPP_PARALLEL=1
7373

7474
### Enable to run parallel requests
75-
# PARALLEL_REQUESTS=true
75+
# PARALLEL_REQUESTS=true
76+
77+
### Watchdog settings
78+
###
79+
# Enables the watchdog to kill backends that have been inactive for too long
80+
# WATCHDOG_IDLE=true
81+
#
82+
# Enables the watchdog to kill backends that have been busy for too long
83+
# WATCHDOG_BUSY=true
84+
#
85+
# Time in duration format (e.g. 1h30m) after which a backend is considered idle
86+
# WATCHDOG_IDLE_TIMEOUT=5m
87+
#
88+
# Time in duration format (e.g. 1h30m) after which a backend is considered busy
89+
# WATCHDOG_BUSY_TIMEOUT=5m

api/api.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/go-skynet/LocalAI/internal"
1414
"github.com/go-skynet/LocalAI/metrics"
1515
"github.com/go-skynet/LocalAI/pkg/assets"
16+
"github.com/go-skynet/LocalAI/pkg/model"
1617

1718
"github.com/gofiber/fiber/v2"
1819
"github.com/gofiber/fiber/v2/middleware/cors"
@@ -79,6 +80,22 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
7980
options.Loader.StopAllGRPC()
8081
}()
8182

83+
if options.WatchDog {
84+
wd := model.NewWatchDog(
85+
options.Loader,
86+
options.WatchDogBusyTimeout,
87+
options.WatchDogIdleTimeout,
88+
options.WatchDogBusy,
89+
options.WatchDogIdle)
90+
options.Loader.SetWatchDog(wd)
91+
go wd.Run()
92+
go func() {
93+
<-options.Context.Done()
94+
log.Debug().Msgf("Context canceled, shutting down")
95+
wd.Shutdown()
96+
}()
97+
}
98+
8299
return options, cl, nil
83100
}
84101

api/localai/backend_monitor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func BackendMonitorEndpoint(bm BackendMonitor) func(c *fiber.Ctx) error {
128128
return fmt.Errorf("backend %s is not currently loaded", backendId)
129129
}
130130

131-
status, rpcErr := model.GRPC(false).Status(context.TODO())
131+
status, rpcErr := model.GRPC(false, nil).Status(context.TODO())
132132
if rpcErr != nil {
133133
log.Warn().Msgf("backend %s experienced an error retrieving status info: %s", backendId, rpcErr.Error())
134134
val, slbErr := bm.SampleLocalBackendProcess(backendId)

api/options/options.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"embed"
66
"encoding/json"
7+
"time"
78

89
"github.com/go-skynet/LocalAI/metrics"
910
"github.com/go-skynet/LocalAI/pkg/gallery"
@@ -38,6 +39,11 @@ type Option struct {
3839

3940
SingleBackend bool
4041
ParallelBackendRequests bool
42+
43+
WatchDogIdle bool
44+
WatchDogBusy bool
45+
WatchDog bool
46+
WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration
4147
}
4248

4349
type AppOption func(*Option)
@@ -63,6 +69,32 @@ func WithCors(b bool) AppOption {
6369
}
6470
}
6571

72+
var EnableWatchDog = func(o *Option) {
73+
o.WatchDog = true
74+
}
75+
76+
var EnableWatchDogIdleCheck = func(o *Option) {
77+
o.WatchDog = true
78+
o.WatchDogIdle = true
79+
}
80+
81+
var EnableWatchDogBusyCheck = func(o *Option) {
82+
o.WatchDog = true
83+
o.WatchDogBusy = true
84+
}
85+
86+
func SetWatchDogBusyTimeout(t time.Duration) AppOption {
87+
return func(o *Option) {
88+
o.WatchDogBusyTimeout = t
89+
}
90+
}
91+
92+
func SetWatchDogIdleTimeout(t time.Duration) AppOption {
93+
return func(o *Option) {
94+
o.WatchDogIdleTimeout = t
95+
}
96+
}
97+
6698
var EnableSingleBackend = func(o *Option) {
6799
o.SingleBackend = true
68100
}

main.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"path/filepath"
1111
"strings"
1212
"syscall"
13+
"time"
1314

1415
api "github.com/go-skynet/LocalAI/api"
1516
"github.com/go-skynet/LocalAI/api/backend"
@@ -154,6 +155,30 @@ func main() {
154155
Usage: "List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys.",
155156
EnvVars: []string{"API_KEY"},
156157
},
158+
&cli.BoolFlag{
159+
Name: "enable-watchdog-idle",
160+
Usage: "Enable watchdog for stopping idle backends. This will stop the backends if are in idle state for too long.",
161+
EnvVars: []string{"WATCHDOG_IDLE"},
162+
Value: false,
163+
},
164+
&cli.BoolFlag{
165+
Name: "enable-watchdog-busy",
166+
Usage: "Enable watchdog for stopping busy backends that exceed a defined threshold.",
167+
EnvVars: []string{"WATCHDOG_BUSY"},
168+
Value: false,
169+
},
170+
&cli.StringFlag{
171+
Name: "watchdog-busy-timeout",
172+
Usage: "Watchdog timeout. This will restart the backend if it crashes.",
173+
EnvVars: []string{"WATCHDOG_BUSY_TIMEOUT"},
174+
Value: "5m",
175+
},
176+
&cli.StringFlag{
177+
Name: "watchdog-idle-timeout",
178+
Usage: "Watchdog idle timeout. This will restart the backend if it crashes.",
179+
EnvVars: []string{"WATCHDOG_IDLE_TIMEOUT"},
180+
Value: "15m",
181+
},
157182
&cli.BoolFlag{
158183
Name: "preload-backend-only",
159184
Usage: "If set, the api is NOT launched, and only the preloaded models / backends are started. This is intended for multi-node setups.",
@@ -198,6 +223,28 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
198223
options.WithUploadLimitMB(ctx.Int("upload-limit")),
199224
options.WithApiKeys(ctx.StringSlice("api-keys")),
200225
}
226+
227+
idleWatchDog := ctx.Bool("enable-watchdog-idle")
228+
busyWatchDog := ctx.Bool("enable-watchdog-busy")
229+
if idleWatchDog || busyWatchDog {
230+
opts = append(opts, options.EnableWatchDog)
231+
if idleWatchDog {
232+
opts = append(opts, options.EnableWatchDogIdleCheck)
233+
dur, err := time.ParseDuration(ctx.String("watchdog-idle-timeout"))
234+
if err != nil {
235+
return err
236+
}
237+
opts = append(opts, options.SetWatchDogIdleTimeout(dur))
238+
}
239+
if busyWatchDog {
240+
opts = append(opts, options.EnableWatchDogBusyCheck)
241+
dur, err := time.ParseDuration(ctx.String("watchdog-busy-timeout"))
242+
if err != nil {
243+
return err
244+
}
245+
opts = append(opts, options.SetWatchDogBusyTimeout(dur))
246+
}
247+
}
201248
if ctx.Bool("parallel-requests") {
202249
opts = append(opts, options.EnableParallelBackendRequests)
203250
}

pkg/grpc/client.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,22 @@ type Client struct {
1919
parallel bool
2020
sync.Mutex
2121
opMutex sync.Mutex
22+
wd WatchDog
2223
}
2324

24-
func NewClient(address string, parallel bool) *Client {
25+
type WatchDog interface {
26+
Mark(address string)
27+
UnMark(address string)
28+
}
29+
30+
func NewClient(address string, parallel bool, wd WatchDog, enableWatchDog bool) *Client {
31+
if !enableWatchDog {
32+
wd = nil
33+
}
2534
return &Client{
2635
address: address,
2736
parallel: parallel,
37+
wd: wd,
2838
}
2939
}
3040

@@ -79,6 +89,10 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
7989
}
8090
c.setBusy(true)
8191
defer c.setBusy(false)
92+
if c.wd != nil {
93+
c.wd.Mark(c.address)
94+
defer c.wd.UnMark(c.address)
95+
}
8296
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
8397
if err != nil {
8498
return nil, err
@@ -96,6 +110,10 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
96110
}
97111
c.setBusy(true)
98112
defer c.setBusy(false)
113+
if c.wd != nil {
114+
c.wd.Mark(c.address)
115+
defer c.wd.UnMark(c.address)
116+
}
99117
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
100118
if err != nil {
101119
return nil, err
@@ -113,6 +131,10 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
113131
}
114132
c.setBusy(true)
115133
defer c.setBusy(false)
134+
if c.wd != nil {
135+
c.wd.Mark(c.address)
136+
defer c.wd.UnMark(c.address)
137+
}
116138
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
117139
if err != nil {
118140
return nil, err
@@ -129,6 +151,10 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
129151
}
130152
c.setBusy(true)
131153
defer c.setBusy(false)
154+
if c.wd != nil {
155+
c.wd.Mark(c.address)
156+
defer c.wd.UnMark(c.address)
157+
}
132158
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
133159
if err != nil {
134160
return err
@@ -164,6 +190,10 @@ func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest,
164190
}
165191
c.setBusy(true)
166192
defer c.setBusy(false)
193+
if c.wd != nil {
194+
c.wd.Mark(c.address)
195+
defer c.wd.UnMark(c.address)
196+
}
167197
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
168198
if err != nil {
169199
return nil, err
@@ -180,6 +210,10 @@ func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOp
180210
}
181211
c.setBusy(true)
182212
defer c.setBusy(false)
213+
if c.wd != nil {
214+
c.wd.Mark(c.address)
215+
defer c.wd.UnMark(c.address)
216+
}
183217
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
184218
if err != nil {
185219
return nil, err
@@ -196,6 +230,10 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
196230
}
197231
c.setBusy(true)
198232
defer c.setBusy(false)
233+
if c.wd != nil {
234+
c.wd.Mark(c.address)
235+
defer c.wd.UnMark(c.address)
236+
}
199237
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
200238
if err != nil {
201239
return nil, err
@@ -232,6 +270,10 @@ func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts
232270
}
233271
c.setBusy(true)
234272
defer c.setBusy(false)
273+
if c.wd != nil {
274+
c.wd.Mark(c.address)
275+
defer c.wd.UnMark(c.address)
276+
}
235277
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
236278
if err != nil {
237279
return nil, err

pkg/model/initializers.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
121121
// Wait for the service to start up
122122
ready := false
123123
for i := 0; i < o.grpcAttempts; i++ {
124-
if client.GRPC(o.parallelRequests).HealthCheck(context.Background()) {
124+
if client.GRPC(o.parallelRequests, ml.wd).HealthCheck(context.Background()) {
125125
log.Debug().Msgf("GRPC Service Ready")
126126
ready = true
127127
break
@@ -140,7 +140,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
140140

141141
log.Debug().Msgf("GRPC: Loading model with options: %+v", options)
142142

143-
res, err := client.GRPC(o.parallelRequests).LoadModel(o.context, &options)
143+
res, err := client.GRPC(o.parallelRequests, ml.wd).LoadModel(o.context, &options)
144144
if err != nil {
145145
return "", fmt.Errorf("could not load model: %w", err)
146146
}
@@ -154,11 +154,11 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
154154

155155
func (ml *ModelLoader) resolveAddress(addr ModelAddress, parallel bool) (*grpc.Client, error) {
156156
if parallel {
157-
return addr.GRPC(parallel), nil
157+
return addr.GRPC(parallel, ml.wd), nil
158158
}
159159

160160
if _, ok := ml.grpcClients[string(addr)]; !ok {
161-
ml.grpcClients[string(addr)] = addr.GRPC(parallel)
161+
ml.grpcClients[string(addr)] = addr.GRPC(parallel, ml.wd)
162162
}
163163
return ml.grpcClients[string(addr)], nil
164164
}

pkg/model/loader.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,17 @@ type ModelLoader struct {
6363
models map[string]ModelAddress
6464
grpcProcesses map[string]*process.Process
6565
templates map[TemplateType]map[string]*template.Template
66+
wd *WatchDog
6667
}
6768

6869
type ModelAddress string
6970

70-
func (m ModelAddress) GRPC(parallel bool) *grpc.Client {
71-
return grpc.NewClient(string(m), parallel)
71+
func (m ModelAddress) GRPC(parallel bool, wd *WatchDog) *grpc.Client {
72+
enableWD := false
73+
if wd != nil {
74+
enableWD = true
75+
}
76+
return grpc.NewClient(string(m), parallel, wd, enableWD)
7277
}
7378

7479
func NewModelLoader(modelPath string) *ModelLoader {
@@ -79,10 +84,15 @@ func NewModelLoader(modelPath string) *ModelLoader {
7984
templates: make(map[TemplateType]map[string]*template.Template),
8085
grpcProcesses: make(map[string]*process.Process),
8186
}
87+
8288
nml.initializeTemplateMap()
8389
return nml
8490
}
8591

92+
func (ml *ModelLoader) SetWatchDog(wd *WatchDog) {
93+
ml.wd = wd
94+
}
95+
8696
func (ml *ModelLoader) ExistsInModelPath(s string) bool {
8797
return existsInPath(ml.ModelPath, s)
8898
}
@@ -139,11 +149,17 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
139149
func (ml *ModelLoader) ShutdownModel(modelName string) error {
140150
ml.mu.Lock()
141151
defer ml.mu.Unlock()
152+
153+
return ml.StopModel(modelName)
154+
}
155+
156+
func (ml *ModelLoader) StopModel(modelName string) error {
157+
defer ml.deleteProcess(modelName)
142158
if _, ok := ml.models[modelName]; !ok {
143159
return fmt.Errorf("model %s not found", modelName)
144160
}
145-
146-
return ml.deleteProcess(modelName)
161+
return nil
162+
//return ml.deleteProcess(modelName)
147163
}
148164

149165
func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
@@ -153,7 +169,7 @@ func (ml *ModelLoader) CheckIsLoaded(s string) ModelAddress {
153169
if c, ok := ml.grpcClients[s]; ok {
154170
client = c
155171
} else {
156-
client = m.GRPC(false)
172+
client = m.GRPC(false, ml.wd)
157173
}
158174

159175
if !client.HealthCheck(context.Background()) {

0 commit comments

Comments
 (0)