diff --git a/async-gateway/endpoint.go b/async-gateway/endpoint.go new file mode 100644 index 0000000000..1e5ea2ece3 --- /dev/null +++ b/async-gateway/endpoint.go @@ -0,0 +1,110 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/gorilla/mux" + "go.uber.org/zap" +) + +// Endpoint wraps an async-gateway Service with HTTP logic +type Endpoint struct { + service Service + logger *zap.Logger +} + +// NewEndpoint creates and initializes a new Endpoint struct +func NewEndpoint(svc Service, logger *zap.Logger) *Endpoint { + return &Endpoint{ + service: svc, + logger: logger, + } +} + +// CreateWorkload is a handler for the async-gateway service workload creation route +func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("x-request-id") + if requestID == "" { + respondPlainText(w, http.StatusBadRequest, "error: missing x-request-id key in request header") + return + } + + contentType := r.Header.Get("Content-Type") + if contentType == "" { + respondPlainText(w, http.StatusBadRequest, "error: missing Content-Type key in request header") + return + } + + body := r.Body + defer func() { + _ = r.Body.Close() + }() + + log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType)) + + id, err := e.service.CreateWorkload(requestID, body, contentType) + if err != nil { + log.Error("failed to create workload", zap.Error(err)) + respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err)) + return + } + + if err = respondJSON(w, http.StatusOK, CreateWorkloadResponse{ID: id}); err != nil { + log.Error("failed to encode json response", zap.Error(err)) + return + } +} + +// GetWorkload is a handler for the async-gateway service workload retrieval route +func (e *Endpoint) GetWorkload(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + id, ok := vars["id"] + if !ok { + respondPlainText(w, http.StatusBadRequest, "error: missing request id in url path") + return + } + + log := e.logger.With(zap.String("id", id)) + + res, err := e.service.GetWorkload(id) + if err != nil { + log.Error("failed to get workload", zap.Error(err)) + respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err)) + return + } + + if err = respondJSON(w, http.StatusOK, res); err != nil { + log.Error("failed to encode json response", zap.Error(err)) + return + } +} + +func respondPlainText(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(statusCode) + _, _ = w.Write([]byte(message)) +} + +func respondJSON(w http.ResponseWriter, statusCode int, s interface{}) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + return json.NewEncoder(w).Encode(s) +} diff --git a/async-gateway/go.mod b/async-gateway/go.mod new file mode 100644 index 0000000000..610b4c9734 --- /dev/null +++ b/async-gateway/go.mod @@ -0,0
+1,9 @@ +module github.com/cortexlabs/async-gateway + +go 1.15 + +require ( + github.com/aws/aws-sdk-go v1.37.23 + github.com/gorilla/mux v1.8.0 + go.uber.org/zap v1.16.0 +) diff --git a/async-gateway/go.sum b/async-gateway/go.sum new file mode 100644 index 0000000000..dcc39f4523 --- /dev/null +++ b/async-gateway/go.sum @@ -0,0 +1,72 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/aws/aws-sdk-go v1.37.23 h1:bO80NcSmRv52w+GFpBegoLdlP/Z0OwUqQ9bbeCLCy/0= +github.com/aws/aws-sdk-go v1.37.23/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.16.0 
h1:uFRZXykJGK9lLY4HtgSw44DnIcAM+kRBP7x5m+NpAOM= +go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod 
h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/async-gateway/main.go b/async-gateway/main.go new file mode 100644 index 0000000000..b37a51661f --- /dev/null +++ b/async-gateway/main.go @@ -0,0 +1,135 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "flag" + "net/http" + "os" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/gorilla/mux" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const ( + _defaultPort = "8080" +) + +func createLogger() (*zap.Logger, error) { + logLevelEnv := os.Getenv("CORTEX_LOG_LEVEL") + disableJSONLogging := os.Getenv("CORTEX_DISABLE_JSON_LOGGING") + + var logLevelZap zapcore.Level + switch logLevelEnv { + case "DEBUG": + logLevelZap = zapcore.DebugLevel + case "WARNING": + logLevelZap = zapcore.WarnLevel + case "ERROR": + logLevelZap = zapcore.ErrorLevel + default: + logLevelZap = zapcore.InfoLevel + } + + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.MessageKey = "message" + + encoding := "json" + if strings.ToLower(disableJSONLogging) == "true" { + encoding = "console" + } + + return zap.Config{ + Level: zap.NewAtomicLevelAt(logLevelZap), + Encoding: encoding, + EncoderConfig: encoderConfig, + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + }.Build() +} + +// usage: ./gateway -bucket -region -port -queue queue +func main() { + log, err := createLogger() + if err != nil { + panic(err) + } + defer func() { + _ = log.Sync() + }() + + var ( + port = flag.String("port", _defaultPort, "port on which the gateway server runs on") + queueURL = flag.String("queue", "", "SQS queue URL") + region = flag.String("region", "", "AWS region") + bucket = flag.String("bucket", "", "AWS bucket") + clusterName = flag.String("cluster", "", "cluster name") + ) + flag.Parse() + + switch { + case *queueURL == "": + log.Fatal("missing required option: -queue") + case *region == "": + log.Fatal("missing required option: -region") + case *bucket == "": + log.Fatal("missing required option: -bucket") + case *clusterName == "": + log.Fatal("missing required option: -cluster") + } + + apiName := flag.Arg(0) + if apiName == "" { + log.Fatal("apiName argument was not provided") + } + + sess, err := session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + Region: region, + }, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + log.Fatal("failed to create AWS session: %s", zap.Error(err)) + } + + s3Storage := NewS3(sess, *bucket) + + sqsQueue := NewSQS(*queueURL, sess) + + svc := NewService(*clusterName, apiName, sqsQueue, s3Storage, log) + ep := NewEndpoint(svc, log) + + router := mux.NewRouter() + router.HandleFunc("/", ep.CreateWorkload).Methods("POST") + router.HandleFunc( + "/healthz", + func(w http.ResponseWriter, r *http.Request) { + respondPlainText(w, http.StatusOK, "ok") + }, + ) + router.HandleFunc("/{id}", ep.GetWorkload).Methods("GET") + + log.Info("Running on port " + *port) + if 
err = http.ListenAndServe(":"+*port, router); err != nil { + log.Fatal("failed to start server", zap.Error(err)) + } +} diff --git a/async-gateway/queue.go b/async-gateway/queue.go new file mode 100644 index 0000000000..23f9d7916d --- /dev/null +++ b/async-gateway/queue.go @@ -0,0 +1,51 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + awssqs "github.com/aws/aws-sdk-go/service/sqs" +) + +// Queue is an interface to abstract communication with event queues +type Queue interface { + SendMessage(message string, uniqueID string) error +} + +type sqs struct { + queueURL string + client *awssqs.SQS +} + +// NewSQS creates a new SQS client that satisfies the Queue interface +func NewSQS(queueURL string, sess *session.Session) Queue { + client := awssqs.New(sess) + + return &sqs{queueURL: queueURL, client: client} +} + +// SendMessage sends a message to the SQS FIFO queue, using uniqueID for deduplication and message grouping +func (q *sqs) SendMessage(message string, uniqueID string) error { + _, err := q.client.SendMessage(&awssqs.SendMessageInput{ + MessageBody: aws.String(message), + MessageDeduplicationId: aws.String(uniqueID), + MessageGroupId: aws.String(uniqueID), + QueueUrl: aws.String(q.queueURL), + }) + return err +} diff --git a/async-gateway/service.go b/async-gateway/service.go new file mode 100644 index 0000000000..e73630b50a --- /dev/null +++ b/async-gateway/service.go @@ -0,0 +1,132 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package main + +import ( + "encoding/json" + "fmt" + "io" + "strings" + + "go.uber.org/zap" +) + +// Service provides an interface to the async-gateway business logic +type Service interface { + CreateWorkload(id string, payload io.Reader, contentType string) (string, error) + GetWorkload(id string) (GetWorkloadResponse, error) +} + +type service struct { + logger *zap.Logger + queue Queue + storage Storage + clusterName string + apiName string +} + +// NewService creates a new async-gateway service +func NewService(clusterName, apiName string, queue Queue, storage Storage, logger *zap.Logger) Service { + return &service{ + logger: logger, + queue: queue, + storage: storage, + clusterName: clusterName, + apiName: apiName, + } +} + +// CreateWorkload enqueues an async workload request and uploads the request payload to cloud storage +func (s *service) CreateWorkload(id string, payload io.Reader, contentType string) (string, error) { + prefix := s.workloadStoragePrefix() + log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType)) + + payloadPath := fmt.Sprintf("%s/%s/payload", prefix, id) + log.Debug("uploading payload", zap.String("path", payloadPath)) + if err := s.storage.Upload(payloadPath, payload, contentType); err != nil { + return "", err + } + + log.Debug("sending message to queue") + if err := s.queue.SendMessage(id, id); err != nil { + return "", err + } + + statusPath := fmt.Sprintf("%s/%s/status", prefix, id) + log.Debug(fmt.Sprintf("setting status to %s", StatusInQueue)) + if err := s.storage.Upload(statusPath, strings.NewReader(string(StatusInQueue)), "text/plain"); err != nil { + return "", err + } + + return id, nil +} + +// GetWorkload retrieves the status and result, if available, of a given workload +func (s *service) GetWorkload(id string) (GetWorkloadResponse, error) { + prefix := s.workloadStoragePrefix() + log := s.logger.With(zap.String("id", id)) + + // download workload status + statusPath := fmt.Sprintf("%s/%s/status", prefix, id) + log.Debug("downloading status file", zap.String("path", statusPath)) + statusBuf, err := s.storage.Download(statusPath) + if err != nil { + return GetWorkloadResponse{}, err + } + + status := Status(statusBuf[:]) + switch status { + case StatusFailed, StatusInProgress, StatusInQueue: + return GetWorkloadResponse{ + ID: id, + Status: status, + }, nil + case StatusCompleted: // continues execution after switch/case, below + default: + return GetWorkloadResponse{}, fmt.Errorf("invalid workload status: %s", status) + } + + // attempt to download user result + resultPath := fmt.Sprintf("%s/%s/result.json", prefix, id) + log.Debug("downloading user result", zap.String("path", resultPath)) + resultBuf, err := s.storage.Download(resultPath) + if err != nil { + return GetWorkloadResponse{}, err + } + + var userResponse UserResponse + if err = json.Unmarshal(resultBuf, &userResponse); err != nil { + return GetWorkloadResponse{}, err + } + + log.Debug("getting workload timestamp") + timestamp, err := s.storage.GetLastModified(resultPath) + if err != nil { + return GetWorkloadResponse{}, err + } + + return GetWorkloadResponse{ + ID: id, + Status: status, + Result: &userResponse, + Timestamp: &timestamp, + }, nil +} + +func (s *service) workloadStoragePrefix() string { + return fmt.Sprintf("%s/apis/%s/workloads", s.clusterName, s.apiName) +} diff --git a/async-gateway/storage.go b/async-gateway/storage.go new file mode 100644 index 0000000000..f78c2c997b --- /dev/null +++ b/async-gateway/storage.go @@ -0,0 +1,95 @@ +/*
+Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "io" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + awss3 "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" +) + +// Storage is an interface that abstracts cloud storage uploading +type Storage interface { + Upload(key string, payload io.Reader, contentType string) error + Download(key string) ([]byte, error) + GetLastModified(key string) (time.Time, error) +} + +type s3 struct { + uploader *s3manager.Uploader + downloader *s3manager.Downloader + client *awss3.S3 + bucket string +} + +// NewS3 creates a new S3 client that satisfies the Storage interface +func NewS3(sess *session.Session, bucket string) Storage { + uploader := s3manager.NewUploader(sess) + downloader := s3manager.NewDownloader(sess) + client := awss3.New(sess) + return &s3{ + uploader: uploader, + bucket: bucket, + downloader: downloader, + client: client, + } +} + +// Upload uploads binary data to S3 +func (s *s3) Upload(key string, payload io.Reader, contentType string) error { + _, err := s.uploader.Upload(&s3manager.UploadInput{ + Key: aws.String(key), + Bucket: aws.String(s.bucket), + ContentType: aws.String(contentType), + Body: payload, + }) + return err +} + +// Download downloads a file from S3 into memory +func (s *s3) Download(key string) ([]byte, error) { + buff := &aws.WriteAtBuffer{} + input := awss3.GetObjectInput{ + Key: aws.String(key), + Bucket: aws.String(s.bucket), + } + + _, err := s.downloader.Download(buff, &input) + if err != nil { + return nil, err + } + + return buff.Bytes(), nil +} + +// GetLastModified retrieves the last modified timestamp of an S3 object +func (s *s3) GetLastModified(key string) (time.Time, error) { + input := awss3.GetObjectInput{ + Key: aws.String(key), + Bucket: aws.String(s.bucket), + } + obj, err := s.client.GetObject(&input) + if err != nil { + return time.Time{}, err + } + + return *obj.LastModified, nil +} diff --git a/async-gateway/types.go b/async-gateway/types.go new file mode 100644 index 0000000000..ccf26795e2 --- /dev/null +++ b/async-gateway/types.go @@ -0,0 +1,46 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package main + +import "time" + +// UserResponse represents the user's API response, which has to be JSON serializable +type UserResponse = map[string]interface{} + +// Status is an enum type for workload status +type Status string + +// Different possible workload status +const ( + StatusFailed Status = "failed" + StatusInProgress Status = "in_progress" + StatusInQueue Status = "in_queue" + StatusCompleted Status = "completed" +) + +//CreateWorkloadResponse represents the response returned to the user on workload creation +type CreateWorkloadResponse struct { + ID string `json:"id"` +} + +// GetWorkloadResponse represents the workload response that is returned to the user +type GetWorkloadResponse struct { + ID string `json:"id"` + Status Status `json:"status"` + Result *UserResponse `json:"result,omitempty"` + Timestamp *time.Time `json:"timestamp,omitempty"` +} diff --git a/build/images.sh b/build/images.sh index 6b3742e711..0f2bc6f4d6 100644 --- a/build/images.sh +++ b/build/images.sh @@ -38,6 +38,7 @@ dev_images_cluster=( "downloader" "manager" "request-monitor" + "async-gateway" ) dev_images_aws=( # includes dev_images_cluster diff --git a/cli/cmd/get.go b/cli/cmd/get.go index 391ba841fc..50f65a7177 100644 --- a/cli/cmd/get.go +++ b/cli/cmd/get.go @@ -197,6 +197,8 @@ func getAPIsInAllEnvironments() (string, error) { var allRealtimeAPIs []schema.APIResponse var allRealtimeAPIEnvs []string + var allAsyncAPIs []schema.APIResponse + var allAsyncAPIEnvs []string var allBatchAPIs []schema.APIResponse var allBatchAPIEnvs []string var allTaskAPIs []schema.APIResponse @@ -231,6 +233,9 @@ func getAPIsInAllEnvironments() (string, error) { case userconfig.RealtimeAPIKind: allRealtimeAPIEnvs = append(allRealtimeAPIEnvs, env.Name) allRealtimeAPIs = append(allRealtimeAPIs, api) + case userconfig.AsyncAPIKind: + allAsyncAPIEnvs = append(allAsyncAPIEnvs, env.Name) + allAsyncAPIs = append(allAsyncAPIs, api) case userconfig.TaskAPIKind: allTaskAPIEnvs = append(allTaskAPIEnvs, env.Name) allTaskAPIs = append(allTaskAPIs, api) @@ -258,7 +263,7 @@ func getAPIsInAllEnvironments() (string, error) { out := "" - if len(allRealtimeAPIs) == 0 && len(allBatchAPIs) == 0 && len(allTrafficSplitters) == 0 && len(allTaskAPIs) == 0 { + if len(allRealtimeAPIs) == 0 && len(allAsyncAPIs) == 0 && len(allBatchAPIs) == 0 && len(allTrafficSplitters) == 0 && len(allTaskAPIs) == 0 { // check if any environments errorred if len(errorsMap) != len(cliConfig.Environments) { if len(errorsMap) == 0 { @@ -301,11 +306,18 @@ func getAPIsInAllEnvironments() (string, error) { } out += t.MustFormat() } + if len(allAsyncAPIs) > 0 { + t := asyncAPIsTable(allAsyncAPIs, allAsyncAPIEnvs) + if len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 || len(allRealtimeAPIs) > 0 { + out += "\n" + } + out += t.MustFormat() + } if len(allTrafficSplitters) > 0 { t := trafficSplitterListTable(allTrafficSplitters, allTrafficSplitterEnvs) - if len(allRealtimeAPIs) > 0 || len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 { + if len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 || len(allRealtimeAPIs) > 0 || len(allAsyncAPIs) > 0 { out += "\n" } @@ -339,6 +351,7 @@ func getAPIsByEnv(env cliconfig.Environment, printEnv bool) (string, error) { } var allRealtimeAPIs []schema.APIResponse + var allAsyncAPIs []schema.APIResponse var allBatchAPIs []schema.APIResponse var allTaskAPIs []schema.APIResponse var allTrafficSplitters []schema.APIResponse @@ -351,6 +364,8 @@ func getAPIsByEnv(env cliconfig.Environment, printEnv bool) (string, error) { allTaskAPIs = 
append(allTaskAPIs, api) case userconfig.RealtimeAPIKind: allRealtimeAPIs = append(allRealtimeAPIs, api) + case userconfig.AsyncAPIKind: + allAsyncAPIs = append(allAsyncAPIs, api) case userconfig.TrafficSplitterKind: allTrafficSplitters = append(allTrafficSplitters, api) } @@ -406,6 +421,22 @@ func getAPIsByEnv(env cliconfig.Environment, printEnv bool) (string, error) { out += t.MustFormat() } + if len(allAsyncAPIs) > 0 { + envNames := []string{} + for range allAsyncAPIs { + envNames = append(envNames, env.Name) + } + + t := asyncAPIsTable(allAsyncAPIs, envNames) + t.FindHeaderByTitle(_titleEnvironment).Hidden = true + + if len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 || len(allRealtimeAPIs) > 0 { + out += "\n" + } + + out += t.MustFormat() + } + if len(allTrafficSplitters) > 0 { envNames := []string{} for range allTrafficSplitters { @@ -415,7 +446,7 @@ func getAPIsByEnv(env cliconfig.Environment, printEnv bool) (string, error) { t := trafficSplitterListTable(allTrafficSplitters, envNames) t.FindHeaderByTitle(_titleEnvironment).Hidden = true - if len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 || len(allRealtimeAPIs) > 0 { + if len(allBatchAPIs) > 0 || len(allTaskAPIs) > 0 || len(allRealtimeAPIs) > 0 || len(allAsyncAPIs) > 0 { out += "\n" } @@ -448,6 +479,8 @@ func getAPI(env cliconfig.Environment, apiName string) (string, error) { switch apiRes.Spec.Kind { case userconfig.RealtimeAPIKind: return realtimeAPITable(apiRes, env) + case userconfig.AsyncAPIKind: + return asyncAPITable(apiRes, env) case userconfig.TrafficSplitterKind: return trafficSplitterTable(apiRes, env) case userconfig.BatchAPIKind: diff --git a/cli/cmd/lib_async_apis.go b/cli/cmd/lib_async_apis.go new file mode 100644 index 0000000000..ab80f1684e --- /dev/null +++ b/cli/cmd/lib_async_apis.go @@ -0,0 +1,101 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package cmd + +import ( + "strings" + "time" + + "github.com/cortexlabs/cortex/cli/types/cliconfig" + "github.com/cortexlabs/cortex/pkg/lib/console" + "github.com/cortexlabs/cortex/pkg/lib/table" + libtime "github.com/cortexlabs/cortex/pkg/lib/time" + "github.com/cortexlabs/cortex/pkg/operator/schema" + "github.com/cortexlabs/cortex/pkg/types/userconfig" +) + +const ( + _titleAsyncAPI = "async api" +) + +func asyncAPITable(asyncAPI schema.APIResponse, env cliconfig.Environment) (string, error) { + var out string + + t := asyncAPIsTable([]schema.APIResponse{asyncAPI}, []string{env.Name}) + t.FindHeaderByTitle(_titleEnvironment).Hidden = true + t.FindHeaderByTitle(_titleAsyncAPI).Hidden = true + + out += t.MustFormat() + + if asyncAPI.DashboardURL != nil && *asyncAPI.DashboardURL != "" { + out += "\n" + console.Bold("metrics dashboard: ") + *asyncAPI.DashboardURL + "\n" + } + + out += "\n" + console.Bold("endpoint: ") + asyncAPI.Endpoint + "\n" + + if !(asyncAPI.Spec.Predictor.Type == userconfig.PythonPredictorType && asyncAPI.Spec.Predictor.MultiModelReloading == nil) { + out += "\n" + describeModelInput(asyncAPI.Status, asyncAPI.Spec.Predictor, asyncAPI.Endpoint) + } + + out += "\n" + apiHistoryTable(asyncAPI.APIVersions) + + if !_flagVerbose { + return out, nil + } + + out += titleStr("configuration") + strings.TrimSpace(asyncAPI.Spec.UserStr(env.Provider)) + + return out, nil +} + +func asyncAPIsTable(asyncAPIs []schema.APIResponse, envNames []string) table.Table { + rows := make([][]interface{}, 0, len(asyncAPIs)) + + var totalFailed int32 + var totalStale int32 + + for i, asyncAPI := range asyncAPIs { + lastUpdated := time.Unix(asyncAPI.Spec.LastUpdated, 0) + rows = append(rows, []interface{}{ + envNames[i], + asyncAPI.Spec.Name, + asyncAPI.Status.Message(), + asyncAPI.Status.Updated.Ready, + asyncAPI.Status.Stale.Ready, + asyncAPI.Status.Requested, + asyncAPI.Status.Updated.TotalFailed(), + libtime.SinceStr(&lastUpdated), + }) + + totalFailed += asyncAPI.Status.Updated.TotalFailed() + totalStale += asyncAPI.Status.Stale.Ready + } + + return table.Table{ + Headers: []table.Header{ + {Title: _titleEnvironment}, + {Title: _titleAsyncAPI}, + {Title: _titleStatus}, + {Title: _titleUpToDate}, + {Title: _titleStale, Hidden: totalStale == 0}, + {Title: _titleRequested}, + {Title: _titleFailed, Hidden: totalFailed == 0}, + {Title: _titleLastupdated}, + }, + Rows: rows, + } +} diff --git a/dev/operator_local.sh b/dev/operator_local.sh index b14743097b..693fbab504 100755 --- a/dev/operator_local.sh +++ b/dev/operator_local.sh @@ -95,6 +95,7 @@ fi export CORTEX_OPERATOR_IN_CLUSTER=false export CORTEX_CLUSTER_CONFIG_PATH=~/.cortex/cluster-dev.yaml export CORTEX_DISABLE_JSON_LOGGING=true +export CORTEX_OPERATOR_LOG_LEVEL=debug export CORTEX_PROMETHEUS_URL="http://localhost:9090" portForwardCMD="kubectl port-forward -n default prometheus-prometheus-0 9090" diff --git a/docs/clusters/aws/auth.md b/docs/clusters/aws/auth.md index e4fd96ce69..5090928a69 100644 --- a/docs/clusters/aws/auth.md +++ b/docs/clusters/aws/auth.md @@ -77,7 +77,7 @@ _NOTE: The policy created during `cortex cluster up` will automatically be delet { "Effect": "Allow", "Action": "sqs:*", - "Resource": "arn:aws:sqs:{{ .Region }}:{{ .AccountID }}:cortex-*" + "Resource": "arn:aws:sqs:{{ .Region }}:{{ .AccountID }}:cx-*" }, { "Effect": "Allow", diff --git a/go.sum b/go.sum index f047d6b1af..faaec8218e 100644 --- a/go.sum +++ b/go.sum @@ -81,6 +81,7 @@ github.com/aws/aws-sdk-go v1.36.2/go.mod 
h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zK github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= @@ -431,6 +432,7 @@ github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7q github.com/prometheus/common v0.4.0 h1:7etb9YClo3a6HjLzfl6rIQaU+FDfi0VSX39io3aQ+DM= github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084 h1:sofwID9zm4tzrgykg80hfFph1mryUeLRsUfoocVVmRY= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/tsdb v0.7.1 h1:YZcsG11NqnK4czYLrWd9mpEuAJIHVQLwdrleYfszMAA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= diff --git a/images/async-gateway/Dockerfile b/images/async-gateway/Dockerfile new file mode 100644 index 0000000000..10db56ad4e --- /dev/null +++ b/images/async-gateway/Dockerfile @@ -0,0 +1,16 @@ +FROM golang:1.15 as builder + +COPY async-gateway/go.mod async-gateway/go.sum /workspace/async-gateway/ +WORKDIR /workspace/async-gateway/ +RUN go mod download + +COPY async-gateway/*.go /workspace/async-gateway/ +RUN GO111MODULE=on CGO_ENABLED=0 GOOS=linux go build -installsuffix cgo -o async-gateway . 
+ +FROM alpine:3.12 + +RUN apk update && apk add ca-certificates + +COPY --from=builder /workspace/async-gateway/async-gateway /root/ + +ENTRYPOINT ["/root/async-gateway"] diff --git a/manager/manifests/operator.yaml.j2 b/manager/manifests/operator.yaml.j2 index 42ee9cb1e4..949a48af18 100644 --- a/manager/manifests/operator.yaml.j2 +++ b/manager/manifests/operator.yaml.j2 @@ -54,7 +54,7 @@ spec: serviceAccountName: operator containers: - name: operator - image: {{ env['CORTEX_IMAGE_OPERATOR'] }} + image: {{ config['image_operator'] }} imagePullPolicy: Always resources: requests: @@ -73,7 +73,7 @@ spec: mountPath: /configs/cluster - name: docker-client mountPath: /var/run/docker.sock - {% if env['CORTEX_PROVIDER'] == "gcp" %} + {% if config['provider'] == "gcp" %} - name: gcp-credentials mountPath: /var/secrets/google {% endif %} @@ -85,7 +85,7 @@ spec: hostPath: path: /var/run/docker.sock type: Socket - {% if env['CORTEX_PROVIDER'] == "gcp" %} + {% if config['provider'] == "gcp" %} - name: gcp-credentials secret: secretName: gcp-credentials @@ -97,6 +97,8 @@ kind: Service metadata: namespace: default name: operator + labels: + cortex.dev/name: operator spec: selector: workloadID: operator diff --git a/manager/manifests/prometheus-monitoring.yaml.j2 b/manager/manifests/prometheus-monitoring.yaml.j2 index d10bf75cd9..77f388d8dc 100644 --- a/manager/manifests/prometheus-monitoring.yaml.j2 +++ b/manager/manifests/prometheus-monitoring.yaml.j2 @@ -45,7 +45,7 @@ spec: matchExpressions: - key: "monitoring.cortex.dev" operator: "In" - values: [ "kubelet-exporter", "node-exporter" ] + values: [ "kubelet-exporter", "node-exporter", "operator" ] ruleSelector: matchLabels: prometheus: k8s @@ -236,3 +236,25 @@ spec: selector: matchLabels: name: prometheus-statsd-exporter + +--- + +apiVersion: monitoring.coreos.com/v1 +kind: ServiceMonitor +metadata: + name: operator + labels: + name: operator + monitoring.cortex.dev: "operator" +spec: + jobLabel: "operator" + endpoints: + - port: http + scheme: http + path: /metrics + interval: 10s + namespaceSelector: + any: true + selector: + matchLabels: + cortex.dev/name: operator diff --git a/pkg/cortex/serve/cortex_internal/lib/api/async.py b/pkg/cortex/serve/cortex_internal/lib/api/async.py new file mode 100644 index 0000000000..7e662c7e4c --- /dev/null +++ b/pkg/cortex/serve/cortex_internal/lib/api/async.py @@ -0,0 +1,188 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import imp +import inspect +import json +import os +from copy import deepcopy +from http import HTTPStatus +from typing import Any, Dict, Union + +import datadog +import dill + +from cortex_internal.lib.api.validations import validate_class_impl +from cortex_internal.lib.exceptions import CortexException, UserException, UserRuntimeException +from cortex_internal.lib.metrics import MetricsClient +from cortex_internal.lib.storage import S3 + +ASYNC_PYTHON_PREDICTOR_VALIDATION = { + "required": [ + { + "name": "__init__", + "required_args": ["self", "config"], + "optional_args": ["metrics_client"], + }, + { + "name": "predict", + "required_args": ["self"], + "optional_args": ["payload", "request_id"], + }, + ], +} + + +class AsyncAPI: + def __init__( + self, + api_spec: Dict[str, Any], + storage: S3, + storage_path: str, + statsd_host: str, + statsd_port: int, + ): + self.api_spec = api_spec + self.storage = storage + self.storage_path = storage_path + self.path = api_spec["predictor"]["path"] + self.config = api_spec["predictor"].get("config", {}) + + datadog.initialize(statsd_host=statsd_host, statsd_port=statsd_port) + self.__statsd = datadog.statsd + + @property + def statsd(self): + return self.__statsd + + def update_status(self, request_id: str, status: str): + self.storage.put_str(status, f"{self.storage_path}/{request_id}/status") + + def upload_result(self, request_id: str, result: Dict[str, Any]): + if not isinstance(result, dict): + raise UserRuntimeException( + f"user response must be json serializable dictionary, got {type(result)} instead" + ) + + try: + result_json = json.dumps(result) + except Exception: + raise UserRuntimeException("user response is not json serializable") + + self.storage.put_str(result_json, f"{self.storage_path}/{request_id}/result.json") + + def get_payload(self, request_id: str) -> Union[Dict, str, bytes]: + key = f"{self.storage_path}/{request_id}/payload" + obj = self.storage.get_object(key) + status_code = obj["ResponseMetadata"]["HTTPStatusCode"] + if status_code != HTTPStatus.OK: + raise CortexException( + f"failed to retrieve async payload (request_id: {request_id}, status_code: {status_code})" + ) + + content_type: str = obj["ResponseMetadata"]["HTTPHeaders"]["content-type"] + payload_bytes: bytes = obj["Body"].read() + + # decode payload + if content_type.startswith("application/json"): + try: + return json.loads(payload_bytes) + except Exception as err: + raise UserRuntimeException( + f"the uploaded payload, with content-type {content_type}, could not be decoded to JSON" + ) from err + elif content_type.startswith("text/plain"): + try: + return payload_bytes.decode("utf-8") + except Exception as err: + raise UserRuntimeException( + f"the uploaded payload, with content-type {content_type}, could not be decoded to a utf-8 string" + ) from err + else: + return payload_bytes + + def delete_payload(self, request_id: str): + key = f"{self.storage_path}/{request_id}/payload" + self.storage.delete(key) + + def initialize_impl(self, project_dir: str, metrics_client: MetricsClient): + predictor_impl = self._get_impl(project_dir) + constructor_args = inspect.getfullargspec(predictor_impl.__init__).args + config = deepcopy(self.config) + + args = {} + if "config" in constructor_args: + args["config"] = config + if "metrics_client" in constructor_args: + args["metrics_client"] = metrics_client + + try: + predictor = predictor_impl(**args) + except Exception as e: + raise UserRuntimeException(self.path, "__init__", str(e)) from e + + return predictor + + 
def _get_impl(self, project_dir: str): + try: + impl = self._read_impl( + "cortex_async_predictor", os.path.join(project_dir, self.path), "PythonPredictor" + ) + except CortexException as e: + e.wrap("error in " + self.path) + raise + + try: + self._validate_impl(impl) + except CortexException as e: + e.wrap("error in " + self.path) + raise + + return impl + + @staticmethod + def _read_impl(module_name: str, impl_path: str, target_class_name): + if impl_path.endswith(".pickle"): + try: + with open(impl_path, "rb") as pickle_file: + return dill.load(pickle_file) + except Exception as e: + raise UserException("unable to load pickle", str(e)) from e + + try: + impl = imp.load_source(module_name, impl_path) + except Exception as e: + raise UserException(str(e)) from e + + classes = inspect.getmembers(impl, inspect.isclass) + + predictor_class = None + for class_df in classes: + if class_df[0] == target_class_name: + if predictor_class is not None: + raise UserException( + f"multiple definitions for {target_class_name} class found; please check " + f"your imports and class definitions and ensure that there is only one " + f"predictor class definition" + ) + predictor_class = class_df[1] + + if predictor_class is None: + raise UserException(f"{target_class_name} class is not defined") + + return predictor_class + + @staticmethod + def _validate_impl(impl): + return validate_class_impl(impl, ASYNC_PYTHON_PREDICTOR_VALIDATION) diff --git a/pkg/cortex/serve/cortex_internal/lib/queue/__init__.py b/pkg/cortex/serve/cortex_internal/lib/queue/__init__.py new file mode 100644 index 0000000000..dcd1d9ae2f --- /dev/null +++ b/pkg/cortex/serve/cortex_internal/lib/queue/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pkg/cortex/serve/cortex_internal/lib/queue/sqs.py b/pkg/cortex/serve/cortex_internal/lib/queue/sqs.py new file mode 100644 index 0000000000..eb7ff600d7 --- /dev/null +++ b/pkg/cortex/serve/cortex_internal/lib/queue/sqs.py @@ -0,0 +1,181 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time +from typing import Callable, Dict, Any, Optional + +import botocore.exceptions + +from cortex_internal.lib.exceptions import UserRuntimeException +from cortex_internal.lib.log import logger as log +from cortex_internal.lib.signals import SignalHandler +from cortex_internal.lib.telemetry import capture_exception + + +def is_on_job_complete(message) -> bool: + return "MessageAttributes" in message and "job_complete" in message["MessageAttributes"] + + +class SQSHandler: + def __init__( + self, + sqs_client, + queue_url: str, + dead_letter_queue_url: str = None, + message_wait_time: int = 10, + visibility_timeout: int = 30, + not_found_sleep_time: int = 10, + renewal_period: int = 15, + stop_if_no_messages: bool = False, + ): + self.sqs_client = sqs_client + self.queue_url = queue_url + self.dead_letter_queue_url = dead_letter_queue_url + self.message_wait_time = message_wait_time + self.visibility_timeout = visibility_timeout + self.not_found_sleep_time = not_found_sleep_time + self.renewal_period = renewal_period + self.stop_if_no_messages = stop_if_no_messages + + self.receipt_handle_mutex = threading.Lock() + self.stop_renewal = set() + + def start( + self, + message_fn: Callable[[Dict[str, Any]], None], + message_failure_fn: Callable[[Dict[str, Any]], None], + on_job_complete_fn: Optional[Callable[[Dict[str, Any]], None]] = None, + ): + no_messages_found_in_previous_iteration = False + signal_handler = SignalHandler() + + while not signal_handler.received_signal(): + response = self.sqs_client.receive_message( + QueueUrl=self.queue_url, + MaxNumberOfMessages=1, + WaitTimeSeconds=self.message_wait_time, + VisibilityTimeout=self.visibility_timeout, + MessageAttributeNames=["All"], + ) + + if response.get("Messages") is None or len(response["Messages"]) == 0: + visible_messages, invisible_messages = self._get_total_messages_in_queue() + if visible_messages + invisible_messages == 0: + if no_messages_found_in_previous_iteration and self.stop_if_no_messages: + log.info("no messages left in queue, exiting...") + return + no_messages_found_in_previous_iteration = True + + time.sleep(self.not_found_sleep_time) + continue + + no_messages_found_in_previous_iteration = False + message = response["Messages"][0] + receipt_handle = message["ReceiptHandle"] + + renewer = threading.Thread( + target=self._renew_message_visibility, + args=(receipt_handle,), + daemon=True, + ) + renewer.start() + + if is_on_job_complete(message): + self._handle_on_job_complete(message, on_job_complete_fn) + else: + self._handle_message(message, message_fn, message_failure_fn) + + def _renew_message_visibility(self, receipt_handle: str): + interval = self.renewal_period + new_timeout = self.visibility_timeout + + cur_time = time.time() + while True: + time.sleep((cur_time + interval) - time.time()) + cur_time += interval + new_timeout += interval + + with self.receipt_handle_mutex: + if receipt_handle in self.stop_renewal: + self.stop_renewal.remove(receipt_handle) + break + + try: + self.sqs_client.change_message_visibility( + QueueUrl=self.queue_url, + ReceiptHandle=receipt_handle, + VisibilityTimeout=new_timeout, + ) + except botocore.exceptions.ClientError as err: + if err.response["Error"]["Code"] == "InvalidParameterValue": + # unexpected; this error is thrown when attempting to renew a message that has been deleted + continue + elif err.response["Error"]["Code"] == "AWS.SimpleQueueService.NonExistentQueue": + # there may be a delay between the cron may deleting the queue and this 
worker stopping + log.info( + "failed to renew message visibility because the queue was not found" + ) + else: + self.stop_renewal.remove(receipt_handle) + raise err + + def _get_total_messages_in_queue(self): + attributes = self.sqs_client.get_queue_attributes( + QueueUrl=self.queue_url, AttributeNames=["All"] + )["Attributes"] + visible_count = int(attributes.get("ApproximateNumberOfMessages", 0)) + not_visible_count = int(attributes.get("ApproximateNumberOfMessagesNotVisible", 0)) + return visible_count, not_visible_count + + def _handle_message(self, message, callback_fn, failure_callback_fn): + receipt_handle = message["ReceiptHandle"] + + try: + callback_fn(message) + except Exception as err: + if not isinstance(err, UserRuntimeException): + capture_exception(err) + + failure_callback_fn(message) + + with self.receipt_handle_mutex: + self.stop_renewal.add(receipt_handle) + if self.dead_letter_queue_url is not None: + self.sqs_client.change_message_visibility( # return message + QueueUrl=self.queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0 + ) + else: + self.sqs_client.delete_message( + QueueUrl=self.queue_url, ReceiptHandle=receipt_handle + ) + else: + with self.receipt_handle_mutex: + self.stop_renewal.add(receipt_handle) + self.sqs_client.delete_message( + QueueUrl=self.queue_url, ReceiptHandle=receipt_handle + ) + + def _handle_on_job_complete(self, message, callback_fn): + receipt_handle = message["ReceiptHandle"] + try: + callback_fn(message) + except Exception as err: + raise type(err)("failed to handle on_job_complete") from err + finally: + with self.receipt_handle_mutex: + self.stop_renewal.add(receipt_handle) + self.sqs_client.delete_message( + QueueUrl=self.queue_url, ReceiptHandle=receipt_handle + ) diff --git a/pkg/cortex/serve/cortex_internal/lib/signals.py b/pkg/cortex/serve/cortex_internal/lib/signals.py new file mode 100644 index 0000000000..f3ea829d41 --- /dev/null +++ b/pkg/cortex/serve/cortex_internal/lib/signals.py @@ -0,0 +1,31 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from signal import signal, SIGINT, SIGTERM + +from cortex_internal.lib.log import logger as log + + +class SignalHandler: + def __init__(self): + self.__received_signal = False + signal(SIGINT, self._signal_handler) + signal(SIGTERM, self._signal_handler) + + def _signal_handler(self, sys_signal, _): + log.info(f"handling signal {sys_signal}, exiting gracefully") + self.__received_signal = True + + def received_signal(self): + return self.__received_signal diff --git a/pkg/cortex/serve/cortex_internal/lib/storage/s3.py b/pkg/cortex/serve/cortex_internal/lib/storage/s3.py index dceba1d051..f9c642cfc3 100644 --- a/pkg/cortex/serve/cortex_internal/lib/storage/s3.py +++ b/pkg/cortex/serve/cortex_internal/lib/storage/s3.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime +import json import os +import time +from typing import List, Tuple + import boto3 import botocore -import pickle -import json import msgpack -import time -import datetime -from typing import Dict, List, Tuple from cortex_internal.lib import util from cortex_internal.lib.exceptions import CortexException @@ -120,6 +120,9 @@ def _get_matching_s3_keys_generator(self, prefix="", suffix="", include_dir_obje def put_object(self, body, key): self.s3.put_object(Bucket=self.bucket, Key=key, Body=body) + def get_object(self, key): + return self.s3.get_object(Bucket=self.bucket, Key=key) + def _read_bytes_from_s3( self, key, allow_missing=False, ext_bucket=None, num_retries=0, retry_delay_sec=2 ): @@ -242,3 +245,6 @@ def download(self, prefix, local_dir): self.download_dir(prefix, local_dir) else: self.download_file_to_dir(prefix, local_dir) + + def delete(self, key): + self.s3.delete_object(Bucket=self.bucket, Key=key) diff --git a/pkg/cortex/serve/init/bootloader.sh b/pkg/cortex/serve/init/bootloader.sh index 74cc214c1c..4dcbd059b8 100755 --- a/pkg/cortex/serve/init/bootloader.sh +++ b/pkg/cortex/serve/init/bootloader.sh @@ -159,6 +159,9 @@ if [ "$CORTEX_KIND" = "RealtimeAPI" ]; then elif [ "$CORTEX_KIND" = "BatchAPI" ]; then create_s6_service "py_init" "cd /mnt/project && exec /opt/conda/envs/env/bin/python /src/cortex/serve/init/script.py" create_s6_service "batch" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/batch.py" +elif [ "$CORTEX_KIND" = "AsyncAPI" ]; then + create_s6_service "py_init" "cd /mnt/project && exec /opt/conda/envs/env/bin/python /src/cortex/serve/init/script.py" + create_s6_service "async" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/async.py" elif [ "$CORTEX_KIND" = "TaskAPI" ]; then create_s6_service "task" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/task.py" fi diff --git a/pkg/cortex/serve/start/async.py b/pkg/cortex/serve/start/async.py new file mode 100644 index 0000000000..ec73c233bc --- /dev/null +++ b/pkg/cortex/serve/start/async.py @@ -0,0 +1,160 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +import os +import sys +from typing import Dict, Any + +import boto3 + +from cortex_internal.lib.api import get_spec +from cortex_internal.lib.api.async import AsyncAPI +from cortex_internal.lib.exceptions import UserRuntimeException +from cortex_internal.lib.log import configure_logger +from cortex_internal.lib.metrics import MetricsClient +from cortex_internal.lib.queue.sqs import SQSHandler +from cortex_internal.lib.telemetry import init_sentry, get_default_tags, capture_exception + +init_sentry(tags=get_default_tags()) +log = configure_logger("cortex", os.environ["CORTEX_LOG_CONFIG_FILE"]) + +SQS_POLL_WAIT_TIME = 10 # seconds +MESSAGE_NOT_FOUND_SLEEP = 10 # seconds +INITIAL_MESSAGE_VISIBILITY = 30 # seconds +MESSAGE_RENEWAL_PERIOD = 15 # seconds +JOB_COMPLETE_MESSAGE_RENEWAL = 10 # seconds + +local_cache: Dict[str, Any] = { + "api": None, + "provider": None, + "predictor_impl": None, + "predict_fn_args": None, + "sqs_client": None, + "storage_client": None, +} + + +def handle_workload(message): + api: AsyncAPI = local_cache["api"] + predictor_impl = local_cache["predictor_impl"] + + request_id = message["Body"] + log.info(f"processing workload...", extra={"id": request_id}) + + api.update_status(request_id, "in_progress") + payload = api.get_payload(request_id) + + try: + result = predictor_impl.predict(**build_predict_args(payload, request_id)) + except Exception as err: + raise UserRuntimeException from err + + log.debug("uploading result", extra={"id": request_id}) + api.upload_result(request_id, result) + + log.debug("updating status to completed", extra={"id": request_id}) + api.update_status(request_id, "completed") + + log.debug("deleting payload from s3") + api.delete_payload(request_id=request_id) + + log.info("workload processing complete", extra={"id": request_id}) + + +def handle_workload_failure(message): + api: AsyncAPI = local_cache["api"] + request_id = message["Body"] + + log.error("failed to process workload", exc_info=True, extra={"id": request_id}) + api.update_status(request_id, "failed") + + log.debug("deleting payload from s3") + api.delete_payload(request_id=request_id) + + +def build_predict_args(payload, request_id): + args = {} + if "payload" in local_cache["predict_fn_args"]: + args["payload"] = payload + if "request_id" in local_cache["predict_fn_args"]: + args["request_id"] = request_id + return args + + +def main(): + cache_dir = os.environ["CORTEX_CACHE_DIR"] + provider = os.environ["CORTEX_PROVIDER"] + api_spec_path = os.environ["CORTEX_API_SPEC"] + workload_path = os.environ["CORTEX_ASYNC_WORKLOAD_PATH"] + project_dir = os.environ["CORTEX_PROJECT_DIR"] + readiness_file = os.getenv("CORTEX_READINESS_FILE", "/mnt/workspace/api_readiness.txt") + region = os.getenv("AWS_REGION") + queue_url = os.environ["CORTEX_QUEUE_URL"] + statsd_host = os.getenv("HOST_IP") + statsd_port = os.getenv("CORTEX_STATSD_PORT", "9125") + + storage, api_spec = get_spec(provider, api_spec_path, cache_dir, region) + sqs_client = boto3.client("sqs", region_name=region) + api = AsyncAPI( + api_spec=api_spec, + storage=storage, + storage_path=workload_path, + statsd_host=statsd_host, + statsd_port=int(statsd_port), + ) + + try: + log.info("loading the predictor from {}".format(api.path)) + metrics_client = MetricsClient(api.statsd) + predictor_impl = api.initialize_impl(project_dir, metrics_client) + except UserRuntimeException as err: + err.wrap(f"failed to initialize predictor implementation") + log.error(str(err), exc_info=True) + sys.exit(1) + except Exception as 
err: + capture_exception(err) + log.error(f"failed to initialize predictor implementation", exc_info=True) + sys.exit(1) + + local_cache["api"] = api + local_cache["provider"] = provider + local_cache["predictor_impl"] = predictor_impl + local_cache["sqs_client"] = sqs_client + local_cache["storage_client"] = storage + local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args + + open(readiness_file, "a").close() + + log.info("polling for workloads...") + try: + sqs_handler = SQSHandler( + sqs_client=sqs_client, + queue_url=queue_url, + renewal_period=MESSAGE_RENEWAL_PERIOD, + visibility_timeout=INITIAL_MESSAGE_VISIBILITY, + not_found_sleep_time=MESSAGE_NOT_FOUND_SLEEP, + message_wait_time=SQS_POLL_WAIT_TIME, + ) + sqs_handler.start(message_fn=handle_workload, message_failure_fn=handle_workload_failure) + except UserRuntimeException as err: + log.error(str(err), exc_info=True) + sys.exit(1) + except Exception as err: + capture_exception(err) + log.error(str(err), exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/pkg/cortex/serve/start/batch.py b/pkg/cortex/serve/start/batch.py index 69327e447e..2dd3e9ed03 100644 --- a/pkg/cortex/serve/start/batch.py +++ b/pkg/cortex/serve/start/batch.py @@ -25,10 +25,11 @@ import boto3 import botocore +import botocore.exceptions from cortex_internal.lib.api import get_api, get_spec from cortex_internal.lib.concurrency import LockedFile -from cortex_internal.lib.exceptions import UserException, UserRuntimeException +from cortex_internal.lib.exceptions import UserRuntimeException from cortex_internal.lib.log import configure_logger from cortex_internal.lib.metrics import MetricsClient from cortex_internal.lib.storage import S3 diff --git a/pkg/lib/k8s/virtual_service.go b/pkg/lib/k8s/virtual_service.go index 40f1cc6137..4665f00e73 100644 --- a/pkg/lib/k8s/virtual_service.go +++ b/pkg/lib/k8s/virtual_service.go @@ -124,7 +124,7 @@ func VirtualService(spec *VirtualServiceSpec) *istioclientnetworking.VirtualServ { Uri: &istionetworking.StringMatch{ MatchType: &istionetworking.StringMatch_Prefix{ - Prefix: urls.CanonicalizeEndpoint(*spec.PrefixPath) + "/", + Prefix: urls.CanonicalizeEndpointWithTrailingSlash(*spec.PrefixPath), }, }, }, @@ -140,7 +140,7 @@ func VirtualService(spec *VirtualServiceSpec) *istioclientnetworking.VirtualServ } prefixMatch.Rewrite = &istionetworking.HTTPRewrite{ - Uri: urls.CanonicalizeEndpoint(*spec.Rewrite) + "/", + Uri: urls.CanonicalizeEndpointWithTrailingSlash(*spec.Rewrite), } } diff --git a/pkg/lib/urls/urls.go b/pkg/lib/urls/urls.go index 4797b9d0b9..eeb387aac6 100644 --- a/pkg/lib/urls/urls.go +++ b/pkg/lib/urls/urls.go @@ -87,6 +87,13 @@ func CanonicalizeEndpoint(str string) string { return strings.TrimSuffix(s.EnsurePrefix(str, "/"), "/") } +func CanonicalizeEndpointWithTrailingSlash(str string) string { + if str == "" || str == "/" { + return "/" + } + return s.EnsureSuffix(s.EnsurePrefix(str, "/"), "/") +} + func TrimQueryParamsURL(u url.URL) string { u.RawQuery = "" return u.String() diff --git a/pkg/operator/endpoints/logs.go b/pkg/operator/endpoints/logs.go index ac1111fc65..4592d26f2f 100644 --- a/pkg/operator/endpoints/logs.go +++ b/pkg/operator/endpoints/logs.go @@ -44,7 +44,7 @@ func ReadLogs(w http.ResponseWriter, r *http.Request) { if deployedResource.Kind == userconfig.BatchAPIKind || deployedResource.Kind == userconfig.TaskAPIKind { respondError(w, r, ErrorLogsJobIDRequired(*deployedResource)) return - } else if deployedResource.Kind != 
userconfig.RealtimeAPIKind { + } else if deployedResource.Kind != userconfig.RealtimeAPIKind && deployedResource.Kind != userconfig.AsyncAPIKind { respondError(w, r, resources.ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind)) return } @@ -60,5 +60,10 @@ func ReadLogs(w http.ResponseWriter, r *http.Request) { } defer socket.Close() - operator.StreamLogsFromRandomPod(map[string]string{"apiName": apiName, "deploymentID": deploymentID, "predictorID": predictorID}, socket) + labels := map[string]string{"apiName": apiName, "deploymentID": deploymentID, "predictorID": predictorID} + + if deployedResource.Kind == userconfig.AsyncAPIKind { + labels["cortex.dev/async"] = "api" + } + operator.StreamLogsFromRandomPod(labels, socket) } diff --git a/pkg/operator/lib/autoscaler/autoscaler.go b/pkg/operator/lib/autoscaler/autoscaler.go new file mode 100644 index 0000000000..1cb68189f8 --- /dev/null +++ b/pkg/operator/lib/autoscaler/autoscaler.go @@ -0,0 +1,234 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package autoscaler + +import ( + "fmt" + "math" + "time" + + "github.com/cortexlabs/cortex/pkg/lib/errors" + math2 "github.com/cortexlabs/cortex/pkg/lib/math" + "github.com/cortexlabs/cortex/pkg/lib/strings" + time2 "github.com/cortexlabs/cortex/pkg/lib/time" + "github.com/cortexlabs/cortex/pkg/operator/config" + "github.com/cortexlabs/cortex/pkg/operator/operator" + "github.com/cortexlabs/cortex/pkg/types/spec" + "github.com/cortexlabs/cortex/pkg/types/userconfig" + kapps "k8s.io/api/apps/v1" +) + +// GetInFlightFunc is the function signature used by the autoscaler to retrieve +// the number of in-flight requests / messages +type GetInFlightFunc func(apiName string, window time.Duration) (*float64, error) + +type recommendations map[time.Time]int32 + +func (recs recommendations) add(rec int32) { + recs[time.Now()] = rec +} + +func (recs recommendations) deleteOlderThan(period time.Duration) { + for t := range recs { + if time.Since(t) > period { + delete(recs, t) + } + } +} + +// Returns nil if no recommendations in the period +func (recs recommendations) maxSince(period time.Duration) *int32 { + max := int32(math.MinInt32) + foundRecommendation := false + + for t, rec := range recs { + if time.Since(t) <= period && rec > max { + max = rec + foundRecommendation = true + } + } + + if !foundRecommendation { + return nil + } + + return &max +} + +// Returns nil if no recommendations in the period +func (recs recommendations) minSince(period time.Duration) *int32 { + min := int32(math.MaxInt32) + foundRecommendation := false + + for t, rec := range recs { + if time.Since(t) <= period && rec < min { + min = rec + foundRecommendation = true + } + } + + if !foundRecommendation { + return nil + } + + return &min +} + +// AutoscaleFn returns the autoscaler function +func AutoscaleFn(initialDeployment *kapps.Deployment, apiSpec *spec.API, getInFlightFn GetInFlightFunc) (func() error, error) { + autoscalingSpec, err := 
userconfig.AutoscalingFromAnnotations(initialDeployment) + if err != nil { + return nil, err + } + + apiName := apiSpec.Name + currentReplicas := *initialDeployment.Spec.Replicas + + apiLogger, err := operator.GetRealtimeAPILoggerFromSpec(apiSpec) + if err != nil { + return nil, err + } + + apiLogger.Infof("%s autoscaler init", apiName) + + var startTime time.Time + recs := make(recommendations) + + return func() error { + if startTime.IsZero() { + startTime = time.Now() + } + + avgInFlight, err := getInFlightFn(apiName, autoscalingSpec.Window) + if err != nil { + return err + } + if avgInFlight == nil { + apiLogger.Debugf("%s autoscaler tick: metrics not available yet", apiName) + return nil + } + + rawRecommendation := *avgInFlight / *autoscalingSpec.TargetReplicaConcurrency + recommendation := int32(math.Ceil(rawRecommendation)) + + if rawRecommendation < float64(currentReplicas) && rawRecommendation > float64(currentReplicas)*(1-autoscalingSpec.DownscaleTolerance) { + recommendation = currentReplicas + } + + if rawRecommendation > float64(currentReplicas) && rawRecommendation < float64(currentReplicas)*(1+autoscalingSpec.UpscaleTolerance) { + recommendation = currentReplicas + } + + // always allow subtraction of 1 + downscaleFactorFloor := math2.MinInt32(currentReplicas-1, int32(math.Ceil(float64(currentReplicas)*autoscalingSpec.MaxDownscaleFactor))) + if recommendation < downscaleFactorFloor { + recommendation = downscaleFactorFloor + } + + // always allow addition of 1 + upscaleFactorCeil := math2.MaxInt32(currentReplicas+1, int32(math.Ceil(float64(currentReplicas)*autoscalingSpec.MaxUpscaleFactor))) + if recommendation > upscaleFactorCeil { + recommendation = upscaleFactorCeil + } + + if recommendation < 1 { + recommendation = 1 + } + + if recommendation < autoscalingSpec.MinReplicas { + recommendation = autoscalingSpec.MinReplicas + } + + if recommendation > autoscalingSpec.MaxReplicas { + recommendation = autoscalingSpec.MaxReplicas + } + + // Rule of thumb: any modifications that don't consider historical recommendations should be performed before + // recording the recommendation, any modifications that use historical recommendations should be performed after + recs.add(recommendation) + + // This is just for garbage collection + recs.deleteOlderThan(time2.MaxDuration(autoscalingSpec.DownscaleStabilizationPeriod, autoscalingSpec.UpscaleStabilizationPeriod)) + + request := recommendation + var downscaleStabilizationFloor *int32 + var upscaleStabilizationCeil *int32 + + if request < currentReplicas { + downscaleStabilizationFloor = recs.maxSince(autoscalingSpec.DownscaleStabilizationPeriod) + if time.Since(startTime) < autoscalingSpec.DownscaleStabilizationPeriod { + request = currentReplicas + } else if downscaleStabilizationFloor != nil && request < *downscaleStabilizationFloor { + request = *downscaleStabilizationFloor + } + } + if request > currentReplicas { + upscaleStabilizationCeil = recs.minSince(autoscalingSpec.UpscaleStabilizationPeriod) + if time.Since(startTime) < autoscalingSpec.UpscaleStabilizationPeriod { + request = currentReplicas + } else if upscaleStabilizationCeil != nil && request > *upscaleStabilizationCeil { + request = *upscaleStabilizationCeil + } + } + + apiLogger.Debugw(fmt.Sprintf("%s autoscaler tick", apiName), + "autoscaling", map[string]interface{}{ + "avg_in_flight": strings.Round(*avgInFlight, 2, 0), + "target_replica_concurrency": strings.Float64(*autoscalingSpec.TargetReplicaConcurrency), + "raw_recommendation": strings.Round(rawRecommendation, 
2, 0), + "current_replicas": currentReplicas, + "downscale_tolerance": strings.Float64(autoscalingSpec.DownscaleTolerance), + "upscale_tolerance": strings.Float64(autoscalingSpec.UpscaleTolerance), + "max_downscale_factor": strings.Float64(autoscalingSpec.MaxDownscaleFactor), + "downscale_factor_floor": downscaleFactorFloor, + "max_upscale_factor": strings.Float64(autoscalingSpec.MaxUpscaleFactor), + "upscale_factor_ceil": upscaleFactorCeil, + "min_replicas": autoscalingSpec.MinReplicas, + "max_replicas": autoscalingSpec.MaxReplicas, + "recommendation": recommendation, + "downscale_stabilization_period": autoscalingSpec.DownscaleStabilizationPeriod, + "downscale_stabilization_floor": strings.ObjFlatNoQuotes(downscaleStabilizationFloor), + "upscale_stabilization_period": autoscalingSpec.UpscaleStabilizationPeriod, + "upscale_stabilization_ceil": strings.ObjFlatNoQuotes(upscaleStabilizationCeil), + "request": request, + }, + ) + + if currentReplicas != request { + apiLogger.Infof("%s autoscaling event: %d -> %d", apiName, currentReplicas, request) + + deployment, err := config.K8s.GetDeployment(initialDeployment.Name) + if err != nil { + return err + } + + if deployment == nil { + return errors.ErrorUnexpected("unable to find k8s deployment", apiName) + } + + deployment.Spec.Replicas = &request + + if _, err := config.K8s.UpdateDeployment(deployment); err != nil { + return err + } + + currentReplicas = request + } + + return nil + }, nil +} diff --git a/pkg/operator/main.go b/pkg/operator/main.go index 5697c0d6d3..00ce9a8a12 100644 --- a/pkg/operator/main.go +++ b/pkg/operator/main.go @@ -28,12 +28,14 @@ import ( "github.com/cortexlabs/cortex/pkg/operator/lib/exit" "github.com/cortexlabs/cortex/pkg/operator/lib/logging" "github.com/cortexlabs/cortex/pkg/operator/operator" + "github.com/cortexlabs/cortex/pkg/operator/resources/asyncapi" "github.com/cortexlabs/cortex/pkg/operator/resources/job/batchapi" "github.com/cortexlabs/cortex/pkg/operator/resources/job/taskapi" "github.com/cortexlabs/cortex/pkg/operator/resources/realtimeapi" "github.com/cortexlabs/cortex/pkg/types" "github.com/cortexlabs/cortex/pkg/types/userconfig" "github.com/gorilla/mux" + "github.com/prometheus/client_golang/prometheus/promhttp" ) var operatorLogger = logging.GetOperatorLogger() @@ -70,15 +72,29 @@ func main() { } for _, deployment := range deployments { - if userconfig.KindFromString(deployment.Labels["apiKind"]) == userconfig.RealtimeAPIKind { + apiKind := deployment.Labels["apiKind"] + if userconfig.KindFromString(apiKind) == userconfig.RealtimeAPIKind || + userconfig.KindFromString(apiKind) == userconfig.AsyncAPIKind { apiID := deployment.Labels["apiID"] apiName := deployment.Labels["apiName"] api, err := operator.DownloadAPISpec(apiName, apiID) if err != nil { exit.Error(errors.Wrap(err, "init")) } - if err := realtimeapi.UpdateAutoscalerCron(&deployment, api); err != nil { - operatorLogger.Fatal(errors.Wrap(err, "init")) + + switch apiKind { + case userconfig.RealtimeAPIKind.String(): + if err := realtimeapi.UpdateAutoscalerCron(&deployment, api); err != nil { + operatorLogger.Fatal(errors.Wrap(err, "init")) + } + case userconfig.AsyncAPIKind.String(): + if err := asyncapi.UpdateMetricsCron(&deployment); err != nil { + operatorLogger.Fatal(errors.Wrap(err, "init")) + } + + if err := asyncapi.UpdateAutoscalerCron(&deployment, *api); err != nil { + operatorLogger.Fatal(errors.Wrap(err, "init")) + } } } } @@ -102,6 +118,9 @@ func main() { routerWithoutAuth.HandleFunc("/tasks/{apiName}", 
endpoints.GetTaskJob).Methods("GET") routerWithoutAuth.HandleFunc("/tasks/{apiName}", endpoints.StopTaskJob).Methods("DELETE") + // prometheus metrics + routerWithoutAuth.Handle("/metrics", promhttp.Handler()).Methods("GET") + routerWithAuth := router.NewRoute().Subrouter() routerWithAuth.Use(endpoints.PanicMiddleware) diff --git a/pkg/operator/operator/k8s.go b/pkg/operator/operator/k8s.go index 57be71b540..44ed935033 100644 --- a/pkg/operator/operator/k8s.go +++ b/pkg/operator/operator/k8s.go @@ -53,6 +53,7 @@ const ( _emptyDirVolumeName = "mnt" _tfServingContainerName = "serve" _requestMonitorContainerName = "request-monitor" + _gatewayContainerName = "gateway" _downloaderInitContainerName = "downloader" _downloaderLastLog = "downloading the %s serving image" _neuronRTDContainerName = "neuron-rtd" @@ -71,6 +72,9 @@ var ( _requestMonitorCPURequest = kresource.MustParse("10m") _requestMonitorMemRequest = kresource.MustParse("10Mi") + _asyncGatewayCPURequest = kresource.MustParse("100m") + _asyncGatewayMemRequest = kresource.MustParse("100Mi") + // each Inferentia chip requires 128 HugePages with each HugePage having a size of 2Mi _hugePagesMemPerInf = int64(128 * 2 * 1024 * 1024) // bytes ) @@ -197,7 +201,67 @@ func TaskContainers(api *spec.API) ([]kcore.Container, []kcore.Volume) { return containers, volumes } +func AsyncPythonPredictorContainers(api spec.API, queueURL string) ([]kcore.Container, []kcore.Volume) { + return pythonPredictorContainers(&api, getAsyncEnvVars(api, APIContainerName, queueURL)) +} + +func AsyncGatewayContainers(api spec.API, queueURL string) kcore.Container { + image := config.CoreConfig.ImageAsyncGateway + region := config.CoreConfig.Region + bucket := config.Bucket() + clusterName := config.ClusterName() + + return kcore.Container{ + Name: _gatewayContainerName, + Image: image, + ImagePullPolicy: kcore.PullAlways, + Args: []string{ + "-queue", queueURL, + "-region", region, + "-bucket", bucket, + "-cluster", clusterName, + "-port", s.Int32(DefaultPortInt32), + api.Name, + }, + Ports: []kcore.ContainerPort{ + {ContainerPort: DefaultPortInt32}, + }, + Env: []kcore.EnvVar{ + { + Name: "CORTEX_LOG_LEVEL", + Value: strings.ToUpper(api.Predictor.LogLevel.String()), + }, + }, + Resources: kcore.ResourceRequirements{ + Requests: kcore.ResourceList{ + kcore.ResourceCPU: _asyncGatewayCPURequest, + kcore.ResourceMemory: _asyncGatewayMemRequest, + }, + }, + LivenessProbe: &kcore.Probe{ + Handler: kcore.Handler{ + HTTPGet: &kcore.HTTPGetAction{ + Path: "/healthz", + Port: intstr.FromInt(8888), + }, + }, + }, + ReadinessProbe: &kcore.Probe{ + Handler: kcore.Handler{ + HTTPGet: &kcore.HTTPGetAction{ + Path: "/healthz", + Port: intstr.FromInt(8888), + }, + }, + }, + } +} + func PythonPredictorContainers(api *spec.API) ([]kcore.Container, []kcore.Volume) { + return pythonPredictorContainers(api, getEnvVars(api, APIContainerName)) +} + +func pythonPredictorContainers(api *spec.API, envVars []kcore.EnvVar) ([]kcore.Container, []kcore.Volume) { apiPodResourceList := kcore.ResourceList{} apiPodResourceLimitsList := kcore.ResourceList{} apiPodVolumeMounts := defaultVolumeMounts() @@ -281,7 +345,7 @@ func PythonPredictorContainers(api *spec.API) ([]kcore.Container, []kcore.Volume Name: APIContainerName, Image: api.Predictor.Image, ImagePullPolicy: kcore.PullAlways, - Env: getEnvVars(api, APIContainerName), + Env: envVars, EnvFrom: baseEnvVars(), VolumeMounts: apiPodVolumeMounts, ReadinessProbe: FileExistsProbe(_apiReadinessFile), @@ -568,6 +632,129 @@ func getTaskEnvVars(api 
*spec.API, container string) []kcore.EnvVar { return envVars } +func getAsyncEnvVars(api spec.API, container string, queueURL string) []kcore.EnvVar { + if container == _downloaderInitContainerName { + return []kcore.EnvVar{ + { + Name: "CORTEX_LOG_LEVEL", + Value: strings.ToUpper(api.Predictor.LogLevel.String()), + }, + } + } + + envVars := []kcore.EnvVar{ + { + Name: "CORTEX_KIND", + Value: api.Kind.String(), + }, + { + Name: "CORTEX_TELEMETRY_SENTRY_USER_ID", + Value: config.OperatorMetadata.OperatorID, + }, + { + Name: "CORTEX_TELEMETRY_SENTRY_ENVIRONMENT", + Value: "api", + }, + } + + for name, val := range api.Predictor.Env { + envVars = append(envVars, kcore.EnvVar{ + Name: name, + Value: val, + }) + } + + if container == APIContainerName { + envVars = append(envVars, + kcore.EnvVar{ + Name: "CORTEX_LOG_LEVEL", + Value: strings.ToUpper(api.Predictor.LogLevel.String()), + }, + kcore.EnvVar{ + Name: "HOST_IP", + ValueFrom: &kcore.EnvVarSource{ + FieldRef: &kcore.ObjectFieldSelector{ + FieldPath: "status.hostIP", + }, + }, + }, + kcore.EnvVar{ + Name: "CORTEX_SERVING_PORT", + Value: DefaultPortStr, + }, + kcore.EnvVar{ + Name: "CORTEX_CACHE_DIR", + Value: _specCacheDir, + }, + kcore.EnvVar{ + Name: "CORTEX_PROJECT_DIR", + Value: path.Join(_emptyDirMountPath, "project"), + }, + kcore.EnvVar{ + Name: "CORTEX_DEPENDENCIES_PIP", + Value: api.Predictor.Dependencies.Pip, + }, + kcore.EnvVar{ + Name: "CORTEX_DEPENDENCIES_CONDA", + Value: api.Predictor.Dependencies.Conda, + }, + kcore.EnvVar{ + Name: "CORTEX_DEPENDENCIES_SHELL", + Value: api.Predictor.Dependencies.Shell, + }, + kcore.EnvVar{ + Name: "CORTEX_QUEUE_URL", + Value: queueURL, + }, + kcore.EnvVar{ + Name: "CORTEX_ASYNC_WORKLOAD_PATH", + Value: fmt.Sprintf("%s/apis/%s/workloads", config.ClusterName(), api.Name), + }, + kcore.EnvVar{ + Name: "CORTEX_API_SPEC", + Value: config.BucketPath(api.PredictorKey), + }, + kcore.EnvVar{ + Name: "CORTEX_PROCESSES_PER_REPLICA", + Value: s.Int32(api.Predictor.ProcessesPerReplica), + }, + ) + + if api.Autoscaling != nil { + envVars = append(envVars, + kcore.EnvVar{ + Name: "CORTEX_MAX_REPLICA_CONCURRENCY", + Value: s.Int64(api.Autoscaling.MaxReplicaConcurrency), + }, + ) + } + + cortexPythonPath := path.Join(_emptyDirMountPath, "project") + if api.Predictor.PythonPath != nil { + cortexPythonPath = path.Join(_emptyDirMountPath, "project", *api.Predictor.PythonPath) + } + envVars = append(envVars, kcore.EnvVar{ + Name: "CORTEX_PYTHON_PATH", + Value: cortexPythonPath, + }) + } + + if api.Compute.Inf > 0 && container == APIContainerName { + envVars = append(envVars, + kcore.EnvVar{ + Name: "NEURONCORE_GROUP_SIZES", + Value: s.Int64(api.Compute.Inf * consts.NeuronCoresPerInf / int64(api.Predictor.ProcessesPerReplica)), + }, + kcore.EnvVar{ + Name: "NEURON_RTD_ADDRESS", + Value: fmt.Sprintf("unix:%s", _neuronRTDSocket), + }, + ) + } + + return envVars +} + func getEnvVars(api *spec.API, container string) []kcore.EnvVar { if container == _requestMonitorContainerName || container == _downloaderInitContainerName { return []kcore.EnvVar{ diff --git a/pkg/operator/resources/asyncapi/api.go b/pkg/operator/resources/asyncapi/api.go new file mode 100644 index 0000000000..e354f3d779 --- /dev/null +++ b/pkg/operator/resources/asyncapi/api.go @@ -0,0 +1,462 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asyncapi + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/cortexlabs/cortex/pkg/lib/cron" + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/lib/k8s" + "github.com/cortexlabs/cortex/pkg/lib/parallel" + "github.com/cortexlabs/cortex/pkg/operator/config" + autoscalerlib "github.com/cortexlabs/cortex/pkg/operator/lib/autoscaler" + "github.com/cortexlabs/cortex/pkg/operator/lib/routines" + "github.com/cortexlabs/cortex/pkg/operator/operator" + "github.com/cortexlabs/cortex/pkg/operator/schema" + "github.com/cortexlabs/cortex/pkg/types/spec" + "github.com/cortexlabs/cortex/pkg/types/userconfig" + istioclientnetworking "istio.io/client-go/pkg/apis/networking/v1beta1" + kapps "k8s.io/api/apps/v1" + kcore "k8s.io/api/core/v1" +) + +const ( + _stalledPodTimeout = 10 * time.Minute + _tickPeriodMetrics = 10 * time.Second +) + +var ( + _autoscalerCrons = make(map[string]cron.Cron) + _metricsCrons = make(map[string]cron.Cron) +) + +type resources struct { + apiDeployment *kapps.Deployment + gatewayDeployment *kapps.Deployment + gatewayService *kcore.Service + gatewayVirtualService *istioclientnetworking.VirtualService +} + +func getGatewayK8sName(apiName string) string { + return "gateway-" + apiName +} + +func deploymentID() string { + return k8s.RandomName()[:10] +} + +func UpdateAPI(apiConfig userconfig.API, projectID string, force bool) (*spec.API, string, error) { + prevK8sResources, err := getK8sResources(apiConfig) + if err != nil { + return nil, "", err + } + + deployID := deploymentID() + if prevK8sResources.apiDeployment != nil && prevK8sResources.apiDeployment.Labels["deploymentID"] != "" { + deployID = prevK8sResources.apiDeployment.Labels["deploymentID"] + } + + api := spec.GetAPISpec(&apiConfig, projectID, deployID, config.ClusterName()) + + // resource creation + if prevK8sResources.apiDeployment == nil { + if err = uploadAPItoS3(*api); err != nil { + return nil, "", err + } + + tags := map[string]string{ + "apiName": apiConfig.Name, + } + + queueURL, err := createFIFOQueue(apiConfig.Name, deployID, tags) + if err != nil { + return nil, "", err + } + + if err = applyK8sResources(*api, prevK8sResources, queueURL); err != nil { + routines.RunWithPanicHandler(func() { + _ = parallel.RunFirstErr( + func() error { + return deleteQueueByURL(queueURL) + }, + func() error { + return deleteK8sResources(api.Name) + }, + ) + }) + return nil, "", err + } + + return api, fmt.Sprintf("creating %s", api.Resource.UserString()), nil + } + + // resource update + if prevK8sResources.gatewayVirtualService.Labels["specID"] != api.SpecID { + isUpdating, err := isAPIUpdating(prevK8sResources.apiDeployment) + if err != nil { + return nil, "", err + } + if isUpdating && !force { + return nil, "", ErrorAPIUpdating(api.Name) + } + + if err = uploadAPItoS3(*api); err != nil { + return nil, "", err + } + + queueURL, err := getQueueURL(api.Name, prevK8sResources.gatewayVirtualService.Labels["deploymentID"]) + if err != nil { + return nil, "", err + } + + if err = applyK8sResources(*api, prevK8sResources, queueURL); err != nil { + return nil, "", err + } + + 
return api, fmt.Sprintf("updating %s", api.Resource.UserString()), nil + } + + // nothing changed + isUpdating, err := isAPIUpdating(prevK8sResources.apiDeployment) + if err != nil { + return nil, "", err + } + if isUpdating { + return api, fmt.Sprintf("%s is already updating", api.Resource.UserString()), nil + } + return api, fmt.Sprintf("%s is up to date", api.Resource.UserString()), nil +} + +func DeleteAPI(apiName string, keepCache bool) error { + err := parallel.RunFirstErr( + func() error { + deployment, err := config.K8s.GetVirtualService(operator.K8sName(apiName)) + if err != nil { + return err + } + queueURL, err := getQueueURL(apiName, deployment.Labels["deploymentID"]) + if err != nil { + return err + } + // best effort deletion + _ = deleteQueueByURL(queueURL) + return nil + }, + func() error { + return deleteK8sResources(apiName) + }, + func() error { + if keepCache { + return nil + } + // best effort deletion, swallow errors because there could be weird error messages + _ = deleteBucketResources(apiName) + return nil + }, + ) + + if err != nil { + return err + } + + return nil +} + +func GetAPIByName(deployedResource *operator.DeployedResource) ([]schema.APIResponse, error) { + status, err := GetStatus(deployedResource.Name) + if err != nil { + return nil, err + } + + api, err := operator.DownloadAPISpec(status.APIName, status.APIID) + if err != nil { + return nil, err + } + + apiEndpoint, err := operator.APIEndpoint(api) + if err != nil { + return nil, err + } + + return []schema.APIResponse{ + { + Spec: *api, + Status: status, + Endpoint: apiEndpoint, + }, + }, nil +} + +func GetAllAPIs(pods []kcore.Pod, deployments []kapps.Deployment) ([]schema.APIResponse, error) { + statuses, err := GetAllStatuses(deployments, pods) + if err != nil { + return nil, err + } + + apiNames, apiIDs := namesAndIDsFromStatuses(statuses) + apis, err := operator.DownloadAPISpecs(apiNames, apiIDs) + if err != nil { + return nil, err + } + + //allMetrics, err := GetMultipleMetrics(apis) + //if err != nil { + // return nil, err + //} + + realtimeAPIs := make([]schema.APIResponse, len(apis)) + + for i, api := range apis { + endpoint, err := operator.APIEndpoint(&api) + if err != nil { + return nil, err + } + + realtimeAPIs[i] = schema.APIResponse{ + Spec: api, + Status: &statuses[i], + //Metrics: &allMetrics[i], + Endpoint: endpoint, + } + } + + return realtimeAPIs, nil +} + +func UpdateMetricsCron(deployment *kapps.Deployment) error { + // skip gateway deployments + if deployment.Labels["cortex.dev/async"] != "api" { + return nil + } + + apiName := deployment.Labels["apiName"] + deployID := deployment.Labels["deploymentID"] + + if prevMetricsCron, ok := _metricsCrons[apiName]; ok { + prevMetricsCron.Cancel() + } + + queueURL, err := getQueueURL(apiName, deployID) + if err != nil { + return err + } + + metricsCron := updateQueueLengthMetricsFn(apiName, queueURL) + + _metricsCrons[apiName] = cron.Run(metricsCron, operator.ErrorHandler(apiName+" metrics"), _tickPeriodMetrics) + + return nil +} + +func UpdateAutoscalerCron(deployment *kapps.Deployment, apiSpec spec.API) error { + // skip gateway deployments + if deployment.Labels["cortex.dev/async"] != "api" { + return nil + } + + apiName := deployment.Labels["apiName"] + if prevAutoscalerCron, ok := _autoscalerCrons[apiName]; ok { + prevAutoscalerCron.Cancel() + } + + autoscaler, err := autoscalerlib.AutoscaleFn(deployment, &apiSpec, getMessagesInQueue) + if err != nil { + return err + } + + _autoscalerCrons[apiName] = cron.Run(autoscaler, 
operator.ErrorHandler(apiName+" autoscaler"), spec.AutoscalingTickInterval) + + return nil +} + +func getK8sResources(apiConfig userconfig.API) (resources, error) { + var deployment *kapps.Deployment + var gatewayDeployment *kapps.Deployment + var gatewayService *kcore.Service + var gatewayVirtualService *istioclientnetworking.VirtualService + + gatewayK8sName := getGatewayK8sName(apiConfig.Name) + apiK8sName := operator.K8sName(apiConfig.Name) + + err := parallel.RunFirstErr( + func() error { + var err error + deployment, err = config.K8s.GetDeployment(apiK8sName) + return err + }, + func() error { + var err error + gatewayDeployment, err = config.K8s.GetDeployment(gatewayK8sName) + return err + }, + func() error { + var err error + gatewayService, err = config.K8s.GetService(apiK8sName) + return err + }, + func() error { + var err error + gatewayVirtualService, err = config.K8s.GetVirtualService(apiK8sName) + return err + }, + ) + + return resources{ + apiDeployment: deployment, + gatewayDeployment: gatewayDeployment, + gatewayService: gatewayService, + gatewayVirtualService: gatewayVirtualService, + }, err +} + +func applyK8sResources(api spec.API, prevK8sResources resources, queueURL string) error { + apiDeployment := apiDeploymentSpec(api, prevK8sResources.apiDeployment, queueURL) + gatewayDeployment := gatewayDeploymentSpec(api, prevK8sResources.gatewayDeployment, queueURL) + gatewayService := gatewayServiceSpec(api) + gatewayVirtualService := gatewayVirtualServiceSpec(api) + + return parallel.RunFirstErr( + func() error { + return applyK8sDeployment(api, prevK8sResources.apiDeployment, &apiDeployment) + }, + func() error { + return applyK8sDeployment(api, prevK8sResources.gatewayDeployment, &gatewayDeployment) + }, + func() error { + return applyK8sService(prevK8sResources.gatewayService, &gatewayService) + }, + func() error { + return applyK8sVirtualService(prevK8sResources.gatewayVirtualService, &gatewayVirtualService) + }, + ) +} + +func applyK8sDeployment(api spec.API, prevDeployment *kapps.Deployment, newDeployment *kapps.Deployment) error { + if prevDeployment == nil { + _, err := config.K8s.CreateDeployment(newDeployment) + if err != nil { + return err + } + } else if prevDeployment.Status.ReadyReplicas == 0 { + // Delete deployment if it never became ready + _, _ = config.K8s.DeleteDeployment(operator.K8sName(api.Name)) + _, err := config.K8s.CreateDeployment(newDeployment) + if err != nil { + return err + } + } else { + _, err := config.K8s.UpdateDeployment(newDeployment) + if err != nil { + return err + } + } + + if err := UpdateMetricsCron(newDeployment); err != nil { + return err + } + + if err := UpdateAutoscalerCron(newDeployment, api); err != nil { + return err + } + + return nil +} + +func applyK8sService(prevService *kcore.Service, newService *kcore.Service) error { + if prevService == nil { + _, err := config.K8s.CreateService(newService) + return err + } + + _, err := config.K8s.UpdateService(prevService, newService) + return err +} + +func applyK8sVirtualService(prevVirtualService *istioclientnetworking.VirtualService, newVirtualService *istioclientnetworking.VirtualService) error { + if prevVirtualService == nil { + _, err := config.K8s.CreateVirtualService(newVirtualService) + return err + } + + _, err := config.K8s.UpdateVirtualService(prevVirtualService, newVirtualService) + return err +} + +func deleteBucketResources(apiName string) error { + prefix := filepath.Join(config.ClusterName(), "apis", apiName) + return config.DeleteBucketDir(prefix, true) +} + 
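A minimal usage sketch of the shared autoscaler library introduced in pkg/operator/lib/autoscaler, assuming a hypothetical example package, a stubbed in-flight source (constantInFlight), and a local autoscalerCrons map; none of these names are part of the PR. The point it illustrates is the design choice above: the in-flight source is pluggable, so realtimeapi can pass its Prometheus in-flight-requests query and asyncapi can pass the SQS-backed queue-length query while both reuse the same recommendation, tolerance, and stabilization logic.

package example

import (
	"time"

	"github.com/cortexlabs/cortex/pkg/lib/cron"
	autoscalerlib "github.com/cortexlabs/cortex/pkg/operator/lib/autoscaler"
	"github.com/cortexlabs/cortex/pkg/operator/operator"
	"github.com/cortexlabs/cortex/pkg/types/spec"
	kapps "k8s.io/api/apps/v1"
)

// autoscalerCrons mirrors the per-API cron bookkeeping used in this file, so the cron
// can be Cancel()ed when the API is deleted or redeployed.
var autoscalerCrons = map[string]cron.Cron{}

// constantInFlight is a stand-in GetInFlightFunc: it always reports an average backlog
// of 4 in-flight requests/messages over the window. A real implementation queries
// Prometheus (realtime) or the SQS queue-length gauge (async).
func constantInFlight(apiName string, window time.Duration) (*float64, error) {
	inFlight := 4.0
	return &inFlight, nil
}

// startAutoscaler builds the recommendation closure for a deployment and schedules it
// on the operator's autoscaling tick, mirroring UpdateAutoscalerCron above.
func startAutoscaler(deployment *kapps.Deployment, apiSpec *spec.API) error {
	tick, err := autoscalerlib.AutoscaleFn(deployment, apiSpec, constantInFlight)
	if err != nil {
		return err
	}
	autoscalerCrons[apiSpec.Name] = cron.Run(tick, operator.ErrorHandler(apiSpec.Name+" autoscaler"), spec.AutoscalingTickInterval)
	return nil
}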
+func deleteK8sResources(apiName string) error { + apiK8sName := operator.K8sName(apiName) + + err := parallel.RunFirstErr( + func() error { + if metricsCron, ok := _metricsCrons[apiName]; ok { + metricsCron.Cancel() + delete(_metricsCrons, apiName) + } + + if autoscalerCron, ok := _autoscalerCrons[apiName]; ok { + autoscalerCron.Cancel() + delete(_autoscalerCrons, apiName) + } + _, err := config.K8s.DeleteDeployment(apiK8sName) + return err + }, + func() error { + gatewayK8sName := getGatewayK8sName(apiName) + _, err := config.K8s.DeleteDeployment(gatewayK8sName) + return err + }, + func() error { + _, err := config.K8s.DeleteService(apiK8sName) + return err + }, + func() error { + _, err := config.K8s.DeleteVirtualService(apiK8sName) + return err + }, + ) + + return err +} + +func uploadAPItoS3(api spec.API) error { + return parallel.RunFirstErr( + func() error { + var err error + err = config.UploadJSONToBucket(api, api.Key) + if err != nil { + err = errors.Wrap(err, "upload api spec") + } + return err + }, + func() error { + var err error + // Use api spec indexed by PredictorID for replicas to prevent rolling updates when SpecID changes without PredictorID changing + err = config.UploadJSONToBucket(api, api.PredictorKey) + if err != nil { + err = errors.Wrap(err, "upload predictor spec") + } + return err + }, + ) +} diff --git a/pkg/operator/resources/asyncapi/autoscaler.go b/pkg/operator/resources/asyncapi/autoscaler.go new file mode 100644 index 0000000000..44eeb2b94b --- /dev/null +++ b/pkg/operator/resources/asyncapi/autoscaler.go @@ -0,0 +1,122 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package asyncapi + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/operator/config" + "github.com/cortexlabs/cortex/pkg/types/userconfig" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/common/model" +) + +const ( + _sqsQueryTimeoutSeconds = 10 + _prometheusQueryTimeoutSeconds = 10 +) + +var queueLengthGauge = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "cortex_async_queue_length", + Help: "The number of in-queue messages for a cortex AsyncAPI", + ConstLabels: map[string]string{"api_kind": userconfig.AsyncAPIKind.String()}, + }, []string{"api_name"}, +) + +func updateQueueLengthMetricsFn(apiName, queueURL string) func() error { + return func() error { + sqsClient := config.AWS.SQS() + + ctx, cancel := context.WithTimeout(context.Background(), _sqsQueryTimeoutSeconds*time.Second) + defer cancel() + + input := &sqs.GetQueueAttributesInput{ + AttributeNames: []*string{ + aws.String("ApproximateNumberOfMessages"), + aws.String("ApproximateNumberOfMessagesNotVisible"), + }, + QueueUrl: aws.String(queueURL), + } + + output, err := sqsClient.GetQueueAttributesWithContext(ctx, input) + if err != nil { + return err + } + + visibleMessagesStr := output.Attributes["ApproximateNumberOfMessages"] + invisibleMessagesStr := output.Attributes["ApproximateNumberOfMessagesNotVisible"] + + visibleMessages, err := strconv.ParseFloat(*visibleMessagesStr, 64) + if err != nil { + return err + } + + invisibleMessages, err := strconv.ParseFloat(*invisibleMessagesStr, 64) + if err != nil { + return err + } + + queueLength := visibleMessages + invisibleMessages + queueLengthGauge.WithLabelValues(apiName).Set(queueLength) + + return nil + } +} + +func getMessagesInQueue(apiName string, window time.Duration) (*float64, error) { + windowSeconds := int64(window.Seconds()) + + // PromQL query: + // sum(sum_over_time(cortex_async_queue_length{api_name=""}[60s])) / + // sum(count_over_time(cortex_async_queue_length{api_name=""}[60s])) + query := fmt.Sprintf( + "sum(sum_over_time(cortex_async_queue_length{api_name=\"%s\"}[%ds])) / "+ + "max(count_over_time(cortex_async_queue_length{api_name=\"%s\"}[%ds]))", + apiName, windowSeconds, + apiName, windowSeconds, + ) + + ctx, cancel := context.WithTimeout(context.Background(), _prometheusQueryTimeoutSeconds*time.Second) + defer cancel() + + valuesQuery, err := config.Prometheus.Query(ctx, query, time.Now()) + if err != nil { + return nil, err + } + + values, ok := valuesQuery.(model.Vector) + if !ok { + return nil, errors.ErrorUnexpected("failed to convert prometheus metric to vector") + } + + // no values available + if values.Len() == 0 { + return nil, nil + } + + avgMessagesInQueue := float64(values[0].Value) + + return &avgMessagesInQueue, nil +} diff --git a/pkg/operator/resources/asyncapi/errors.go b/pkg/operator/resources/asyncapi/errors.go new file mode 100644 index 0000000000..aa9560a433 --- /dev/null +++ b/pkg/operator/resources/asyncapi/errors.go @@ -0,0 +1,34 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asyncapi + +import ( + "fmt" + + "github.com/cortexlabs/cortex/pkg/lib/errors" +) + +const ( + ErrAPIUpdating = "ayncapi.api_updating" +) + +func ErrorAPIUpdating(apiName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrAPIUpdating, + Message: fmt.Sprintf("%s is updating (override with --force)", apiName), + }) +} diff --git a/pkg/operator/resources/asyncapi/k8s_specs.go b/pkg/operator/resources/asyncapi/k8s_specs.go new file mode 100644 index 0000000000..73d7252fbd --- /dev/null +++ b/pkg/operator/resources/asyncapi/k8s_specs.go @@ -0,0 +1,193 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asyncapi + +import ( + "github.com/cortexlabs/cortex/pkg/lib/k8s" + "github.com/cortexlabs/cortex/pkg/lib/pointer" + "github.com/cortexlabs/cortex/pkg/operator/operator" + "github.com/cortexlabs/cortex/pkg/types/spec" + "istio.io/client-go/pkg/apis/networking/v1beta1" + kapps "k8s.io/api/apps/v1" + kcore "k8s.io/api/core/v1" +) + +var _terminationGracePeriodSeconds int64 = 60 // seconds + +func gatewayDeploymentSpec(api spec.API, prevDeployment *kapps.Deployment, queueURL string) kapps.Deployment { + container := operator.AsyncGatewayContainers(api, queueURL) + return *k8s.Deployment(&k8s.DeploymentSpec{ + Name: getGatewayK8sName(api.Name), + Replicas: getRequestedReplicasFromDeployment(api, prevDeployment), + MaxSurge: pointer.String(api.UpdateStrategy.MaxSurge), + MaxUnavailable: pointer.String(api.UpdateStrategy.MaxUnavailable), + Selector: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "cortex.dev/async": "gateway", + }, + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", + "cortex.dev/async": "gateway", + }, + PodSpec: k8s.PodSpec{ + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", + "cortex.dev/async": "gateway", + }, + K8sPodSpec: kcore.PodSpec{ + RestartPolicy: "Always", + TerminationGracePeriodSeconds: pointer.Int64(_terminationGracePeriodSeconds), + Containers: []kcore.Container{container}, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, ServiceAccountName: operator.ServiceAccountName, + }, + 
}, + }) +} + +func gatewayServiceSpec(api spec.API) kcore.Service { + return *k8s.Service(&k8s.ServiceSpec{ + Name: operator.K8sName(api.Name), + Port: operator.DefaultPortInt32, + TargetPort: operator.DefaultPortInt32, + Annotations: api.ToK8sAnnotations(), + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "cortex.dev/api": "true", + "cortex.dev/async": "gateway", + }, + Selector: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "cortex.dev/async": "gateway", + }, + }) +} + +func gatewayVirtualServiceSpec(api spec.API) v1beta1.VirtualService { + return *k8s.VirtualService(&k8s.VirtualServiceSpec{ + Name: operator.K8sName(api.Name), + Gateways: []string{"apis-gateway"}, + Destinations: []k8s.Destination{{ + ServiceName: operator.K8sName(api.Name), + Weight: 100, + Port: uint32(operator.DefaultPortInt32), + }}, + PrefixPath: api.Networking.Endpoint, + Rewrite: pointer.String("/"), + Annotations: api.ToK8sAnnotations(), + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", + "cortex.dev/async": "gateway", + }, + }) +} + +func apiDeploymentSpec(api spec.API, prevDeployment *kapps.Deployment, queueURL string) kapps.Deployment { + containers, volumes := operator.AsyncPythonPredictorContainers(api, queueURL) + + return *k8s.Deployment(&k8s.DeploymentSpec{ + Name: operator.K8sName(api.Name), + Replicas: getRequestedReplicasFromDeployment(api, prevDeployment), + MaxSurge: pointer.String(api.UpdateStrategy.MaxSurge), + MaxUnavailable: pointer.String(api.UpdateStrategy.MaxUnavailable), + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", + "cortex.dev/async": "api", + }, + Annotations: api.ToK8sAnnotations(), + Selector: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "cortex.dev/async": "api", + }, + PodSpec: k8s.PodSpec{ + Labels: map[string]string{ + "apiName": api.Name, + "apiKind": api.Kind.String(), + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", + "cortex.dev/async": "api", + }, + K8sPodSpec: kcore.PodSpec{ + RestartPolicy: "Always", + TerminationGracePeriodSeconds: pointer.Int64(_terminationGracePeriodSeconds), + InitContainers: []kcore.Container{ + operator.InitContainer(&api), + }, + Containers: containers, + NodeSelector: operator.NodeSelectors(), + Tolerations: operator.GenerateResourceTolerations(), + Affinity: &kcore.Affinity{ + NodeAffinity: &kcore.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: operator.GeneratePreferredNodeAffinities(), + }, + }, + Volumes: volumes, + ServiceAccountName: operator.ServiceAccountName, + }, + }, + }) +} + +func getRequestedReplicasFromDeployment(api spec.API, deployment *kapps.Deployment) int32 { + requestedReplicas := api.Autoscaling.InitReplicas + + if deployment != nil && deployment.Spec.Replicas != nil && *deployment.Spec.Replicas > 0 { + requestedReplicas = *deployment.Spec.Replicas + } + + if requestedReplicas < api.Autoscaling.MinReplicas { + requestedReplicas = api.Autoscaling.MinReplicas + } + + if requestedReplicas > api.Autoscaling.MaxReplicas { + requestedReplicas = api.Autoscaling.MaxReplicas + } + + return requestedReplicas +} diff --git 
a/pkg/operator/resources/asyncapi/queue.go b/pkg/operator/resources/asyncapi/queue.go new file mode 100644 index 0000000000..b21123207f --- /dev/null +++ b/pkg/operator/resources/asyncapi/queue.go @@ -0,0 +1,84 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package asyncapi + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/operator/config" +) + +func createFIFOQueue(apiName string, deploymentID string, tags map[string]string) (string, error) { + if config.CoreConfig.IsManaged { + managedConfig := config.ManagedConfigOrNil() + if managedConfig != nil { + for key, value := range managedConfig.Tags { + tags[key] = value + } + } + } + + queueName := apiQueueName(apiName, deploymentID) + + attributes := map[string]string{ + sqs.QueueAttributeNameFifoQueue: "true", + sqs.QueueAttributeNameVisibilityTimeout: "60", + } + + output, err := config.AWS.SQS().CreateQueue( + &sqs.CreateQueueInput{ + Attributes: aws.StringMap(attributes), + QueueName: aws.String(queueName), + Tags: aws.StringMap(tags), + }, + ) + if err != nil { + return "", errors.Wrap(err, "failed to create sqs queue", queueName) + } + + return *output.QueueUrl, nil +} + +func apiQueueName(apiName string, deploymentID string) string { + return config.CoreConfig.SQSNamePrefix() + apiName + "-" + deploymentID + ".fifo" +} + +func deleteQueueByURL(queueURL string) error { + _, err := config.AWS.SQS().DeleteQueue(&sqs.DeleteQueueInput{ + QueueUrl: aws.String(queueURL), + }) + if err != nil { + return errors.Wrap(err, "failed to delete queue", queueURL) + } + + return err +} + +func getQueueURL(apiName string, deploymentID string) (string, error) { + operatorAccountID, _, err := config.AWS.GetCachedAccountID() + if err != nil { + return "", errors.Wrap(err, "failed to construct queue url", "unable to get account id") + } + + return fmt.Sprintf( + "https://sqs.%s.amazonaws.com/%s/%s", + config.AWS.Region, operatorAccountID, apiQueueName(apiName, deploymentID), + ), nil +} diff --git a/pkg/operator/resources/asyncapi/status.go b/pkg/operator/resources/asyncapi/status.go new file mode 100644 index 0000000000..834a9fa1f4 --- /dev/null +++ b/pkg/operator/resources/asyncapi/status.go @@ -0,0 +1,292 @@ +/* +Copyright 2021 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package asyncapi + +import ( + "sort" + "time" + + "github.com/cortexlabs/cortex/pkg/lib/errors" + "github.com/cortexlabs/cortex/pkg/lib/k8s" + "github.com/cortexlabs/cortex/pkg/lib/parallel" + "github.com/cortexlabs/cortex/pkg/operator/config" + "github.com/cortexlabs/cortex/pkg/operator/operator" + "github.com/cortexlabs/cortex/pkg/types/status" + "github.com/cortexlabs/cortex/pkg/types/userconfig" + kapps "k8s.io/api/apps/v1" + kcore "k8s.io/api/core/v1" +) + +type asyncResourceGroup struct { + APIDeployment *kapps.Deployment + APIPods []kcore.Pod + GatewayDeployment *kapps.Deployment + GatewayPods []kcore.Pod +} + +func GetStatus(apiName string) (*status.Status, error) { + var apiDeployment *kapps.Deployment + var gatewayDeployment *kapps.Deployment + var gatewayPods []kcore.Pod + var apiPods []kcore.Pod + + err := parallel.RunFirstErr( + func() error { + var err error + apiDeployment, err = config.K8s.GetDeployment(operator.K8sName(apiName)) + return err + }, + func() error { + var err error + gatewayDeployment, err = config.K8s.GetDeployment(getGatewayK8sName(apiName)) + return err + }, + func() error { + var err error + gatewayPods, err = config.K8s.ListPodsByLabels( + map[string]string{ + "apiName": apiName, + "cortex.dev/async": "gateway", + }, + ) + return err + }, + func() error { + var err error + apiPods, err = config.K8s.ListPodsByLabels( + map[string]string{ + "apiName": apiName, + "cortex.dev/async": "api", + }, + ) + return err + }, + ) + if err != nil { + return nil, err + } + + if apiDeployment == nil { + return nil, errors.ErrorUnexpected("unable to find api deployment", apiName) + } + + if gatewayDeployment == nil { + return nil, errors.ErrorUnexpected("unable to find gateway deployment", apiName) + } + + return apiStatus(apiDeployment, apiPods, gatewayDeployment, gatewayPods) +} + +func GetAllStatuses(deployments []kapps.Deployment, pods []kcore.Pod) ([]status.Status, error) { + resourcesByAPI := groupResourcesByAPI(deployments, pods) + statuses := make([]status.Status, len(resourcesByAPI)) + + var i int + for _, k8sResources := range resourcesByAPI { + st, err := apiStatus(k8sResources.APIDeployment, k8sResources.APIPods, k8sResources.GatewayDeployment, k8sResources.GatewayPods) + if err != nil { + return nil, err + } + statuses[i] = *st + i++ + } + + sort.Slice(statuses, func(i, j int) bool { + return statuses[i].APIName < statuses[j].APIName + }) + + return statuses, nil +} + +func namesAndIDsFromStatuses(statuses []status.Status) ([]string, []string) { + apiNames := make([]string, len(statuses)) + apiIDs := make([]string, len(statuses)) + + for i, st := range statuses { + apiNames[i] = st.APIName + apiIDs[i] = st.APIID + } + + return apiNames, apiIDs +} + +// let's do CRDs instead, to avoid this +func groupResourcesByAPI(deployments []kapps.Deployment, pods []kcore.Pod) map[string]*asyncResourceGroup { + resourcesByAPI := map[string]*asyncResourceGroup{} + for i := range deployments { + deployment := deployments[i] + apiName := deployment.Labels["apiName"] + asyncType := deployment.Labels["cortex.dev/async"] + apiResources, exists := resourcesByAPI[apiName] + if exists { + if asyncType == "api" { + apiResources.APIDeployment = &deployment + } else { + apiResources.GatewayDeployment = &deployment + } + } else { + if asyncType == "api" { + resourcesByAPI[apiName] = &asyncResourceGroup{APIDeployment: &deployment} + } else { + resourcesByAPI[apiName] = &asyncResourceGroup{GatewayDeployment: &deployment} + } + } + } + + for _, pod := range pods { + apiName := 
pod.Labels["apiName"] + asyncType := pod.Labels["cortex.dev/async"] + apiResources, exists := resourcesByAPI[apiName] + if !exists { + // ignore pods that might still be waiting to be deleted while the deployment has already been deleted + continue + } + + if asyncType == "api" { + apiResources.APIPods = append(resourcesByAPI[apiName].APIPods, pod) + } else { + apiResources.GatewayPods = append(resourcesByAPI[apiName].GatewayPods, pod) + } + } + return resourcesByAPI +} + +func apiStatus(apiDeployment *kapps.Deployment, apiPods []kcore.Pod, gatewayDeployment *kapps.Deployment, gatewayPods []kcore.Pod) (*status.Status, error) { + autoscalingSpec, err := userconfig.AutoscalingFromAnnotations(apiDeployment) + if err != nil { + return nil, err + } + + apiReplicaCounts := getReplicaCounts(apiDeployment, apiPods) + gatewayReplicaCounts := getReplicaCounts(gatewayDeployment, gatewayPods) + + st := &status.Status{} + st.APIName = apiDeployment.Labels["apiName"] + st.APIID = apiDeployment.Labels["apiID"] + st.ReplicaCounts = apiReplicaCounts + st.Code = getStatusCode(apiReplicaCounts, gatewayReplicaCounts, autoscalingSpec.MinReplicas) + + return st, nil +} + +func getStatusCode(apiCounts status.ReplicaCounts, gatewayCounts status.ReplicaCounts, apiMinReplicas int32) status.Code { + if apiCounts.Updated.Ready >= apiCounts.Requested && gatewayCounts.Updated.Ready >= 1 { + return status.Live + } + + if apiCounts.Updated.ErrImagePull > 0 || gatewayCounts.Updated.ErrImagePull > 0 { + return status.ErrorImagePull + } + + if apiCounts.Updated.Failed > 0 || apiCounts.Updated.Killed > 0 || + gatewayCounts.Updated.Failed > 0 || gatewayCounts.Updated.Killed > 0 { + return status.Error + } + + if apiCounts.Updated.KilledOOM > 0 || gatewayCounts.Updated.KilledOOM > 0 { + return status.OOM + } + + if apiCounts.Updated.Stalled > 0 || gatewayCounts.Updated.Stalled > 0 { + return status.Stalled + } + + if apiCounts.Updated.Ready >= apiMinReplicas && gatewayCounts.Updated.Ready >= 1 { + return status.Live + } + + return status.Updating +} + +// returns true if min_replicas are not ready and no updated replicas have errored +func isAPIUpdating(deployment *kapps.Deployment) (bool, error) { + pods, err := config.K8s.ListPodsByLabel("apiName", deployment.Labels["apiName"]) + if err != nil { + return false, err + } + + replicaCounts := getReplicaCounts(deployment, pods) + + autoscalingSpec, err := userconfig.AutoscalingFromAnnotations(deployment) + if err != nil { + return false, err + } + + if replicaCounts.Updated.Ready < autoscalingSpec.MinReplicas && replicaCounts.Updated.TotalFailed() == 0 { + return true, nil + } + + return false, nil +} + +func getReplicaCounts(deployment *kapps.Deployment, pods []kcore.Pod) status.ReplicaCounts { + counts := status.ReplicaCounts{} + counts.Requested = *deployment.Spec.Replicas + + for _, pod := range pods { + if pod.Labels["apiName"] != deployment.Labels["apiName"] { + continue + } + addPodToReplicaCounts(&pod, deployment, &counts) + } + + return counts +} + +func addPodToReplicaCounts(pod *kcore.Pod, deployment *kapps.Deployment, counts *status.ReplicaCounts) { + var subCounts *status.SubReplicaCounts + if isPodSpecLatest(deployment, pod) { + subCounts = &counts.Updated + } else { + subCounts = &counts.Stale + } + + if k8s.IsPodReady(pod) { + subCounts.Ready++ + return + } + + switch k8s.GetPodStatus(pod) { + case k8s.PodStatusPending: + if time.Since(pod.CreationTimestamp.Time) > _stalledPodTimeout { + subCounts.Stalled++ + } else { + subCounts.Pending++ + } + case 
+		subCounts.Initializing++
+	case k8s.PodStatusRunning:
+		subCounts.Initializing++
+	case k8s.PodStatusErrImagePull:
+		subCounts.ErrImagePull++
+	case k8s.PodStatusTerminating:
+		subCounts.Terminating++
+	case k8s.PodStatusFailed:
+		subCounts.Failed++
+	case k8s.PodStatusKilled:
+		subCounts.Killed++
+	case k8s.PodStatusKilledOOM:
+		subCounts.KilledOOM++
+	default:
+		subCounts.Unknown++
+	}
+}
+
+func isPodSpecLatest(deployment *kapps.Deployment, pod *kcore.Pod) bool {
+	return deployment.Spec.Template.Labels["predictorID"] == pod.Labels["predictorID"] &&
+		deployment.Spec.Template.Labels["deploymentID"] == pod.Labels["deploymentID"]
+}
diff --git a/pkg/operator/resources/job/batchapi/queue.go b/pkg/operator/resources/job/batchapi/queue.go
index 2c8b4eab98..74da31df33 100644
--- a/pkg/operator/resources/job/batchapi/queue.go
+++ b/pkg/operator/resources/job/batchapi/queue.go
@@ -39,10 +39,10 @@ const (
 )
 
 func apiQueueNamePrefix(apiName string) string {
-	return config.CoreConfig.SQSNamePrefix() + apiName + "-"
+	return fmt.Sprintf("%sb-%s-", config.CoreConfig.SQSNamePrefix(), apiName)
 }
 
-// QueueName is cortex-<cluster_hash>-<api_name>-<job_id>.fifo
+// QueueName is cx-<cluster_hash>-b-<api_name>-<job_id>.fifo
 func getJobQueueName(jobKey spec.JobKey) string {
 	return apiQueueNamePrefix(jobKey.APIName) + jobKey.ID + ".fifo"
 }
@@ -64,7 +64,7 @@ func jobKeyFromQueueURL(queueURL string) spec.JobKey {
 
 	jobID := strings.TrimSuffix(dashSplit[len(dashSplit)-1], ".fifo")
 
-	apiNameSplit := dashSplit[2 : len(dashSplit)-1]
+	apiNameSplit := dashSplit[3 : len(dashSplit)-1]
 	apiName := strings.Join(apiNameSplit, "-")
 
 	return spec.JobKey{APIName: apiName, ID: jobID}
@@ -120,7 +120,7 @@ func doesQueueExist(jobKey spec.JobKey) (bool, error) {
 }
 
 func listQueueURLsForAllAPIs() ([]string, error) {
-	queueURLs, err := config.AWS.ListQueuesByQueueNamePrefix(config.CoreConfig.SQSNamePrefix())
+	queueURLs, err := config.AWS.ListQueuesByQueueNamePrefix(config.CoreConfig.SQSNamePrefix() + "b-")
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/operator/resources/realtimeapi/api.go b/pkg/operator/resources/realtimeapi/api.go
index eeeae90f0d..d973efef9f 100644
--- a/pkg/operator/resources/realtimeapi/api.go
+++ b/pkg/operator/resources/realtimeapi/api.go
@@ -26,6 +26,7 @@ import (
 	"github.com/cortexlabs/cortex/pkg/lib/parallel"
 	"github.com/cortexlabs/cortex/pkg/lib/pointer"
 	"github.com/cortexlabs/cortex/pkg/operator/config"
+	autoscalerlib "github.com/cortexlabs/cortex/pkg/operator/lib/autoscaler"
 	"github.com/cortexlabs/cortex/pkg/operator/lib/routines"
 	"github.com/cortexlabs/cortex/pkg/operator/operator"
 	"github.com/cortexlabs/cortex/pkg/operator/schema"
@@ -338,7 +339,7 @@ func UpdateAutoscalerCron(deployment *kapps.Deployment, apiSpec *spec.API) error
 		prevAutoscalerCron.Cancel()
 	}
 
-	autoscaler, err := autoscaleFn(deployment, apiSpec)
+	autoscaler, err := autoscalerlib.AutoscaleFn(deployment, apiSpec, getInflightRequests)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/operator/resources/realtimeapi/autoscaler.go b/pkg/operator/resources/realtimeapi/autoscaler.go
index 5a41b3d904..b1c5c173df 100644
--- a/pkg/operator/resources/realtimeapi/autoscaler.go
+++ b/pkg/operator/resources/realtimeapi/autoscaler.go
@@ -19,221 +19,17 @@ package realtimeapi
 import (
 	"context"
 	"fmt"
-	"math"
 	"time"
 
 	"github.com/cortexlabs/cortex/pkg/lib/errors"
-	libmath "github.com/cortexlabs/cortex/pkg/lib/math"
-	s "github.com/cortexlabs/cortex/pkg/lib/strings"
-	libtime "github.com/cortexlabs/cortex/pkg/lib/time"
 	"github.com/cortexlabs/cortex/pkg/operator/config"
-	"github.com/cortexlabs/cortex/pkg/operator/operator"
-	"github.com/cortexlabs/cortex/pkg/types/spec"
-	"github.com/cortexlabs/cortex/pkg/types/userconfig"
 	"github.com/prometheus/common/model"
-	kapps "k8s.io/api/apps/v1"
 )
 
 const (
 	_prometheusQueryTimeoutSeconds = 10
 )
 
-type recommendations map[time.Time]int32
-
-func (recs recommendations) add(rec int32) {
-	recs[time.Now()] = rec
-}
-
-func (recs recommendations) deleteOlderThan(period time.Duration) {
-	for t := range recs {
-		if time.Since(t) > period {
-			delete(recs, t)
-		}
-	}
-}
-
-// Returns nil if no recommendations in the period
-func (recs recommendations) maxSince(period time.Duration) *int32 {
-	max := int32(math.MinInt32)
-	foundRecommendation := false
-
-	for t, rec := range recs {
-		if time.Since(t) <= period && rec > max {
-			max = rec
-			foundRecommendation = true
-		}
-	}
-
-	if !foundRecommendation {
-		return nil
-	}
-
-	return &max
-}
-
-// Returns nil if no recommendations in the period
-func (recs recommendations) minSince(period time.Duration) *int32 {
-	min := int32(math.MaxInt32)
-	foundRecommendation := false
-
-	for t, rec := range recs {
-		if time.Since(t) <= period && rec < min {
-			min = rec
-			foundRecommendation = true
-		}
-	}
-
-	if !foundRecommendation {
-		return nil
-	}
-
-	return &min
-}
-
-func autoscaleFn(initialDeployment *kapps.Deployment, apiSpec *spec.API) (func() error, error) {
-	autoscalingSpec, err := userconfig.AutoscalingFromAnnotations(initialDeployment)
-	if err != nil {
-		return nil, err
-	}
-
-	apiName := apiSpec.Name
-	currentReplicas := *initialDeployment.Spec.Replicas
-
-	apiLogger, err := operator.GetRealtimeAPILoggerFromSpec(apiSpec)
-	if err != nil {
-		return nil, err
-	}
-
-	apiLogger.Infof("%s autoscaler init", apiName)
-
-	var startTime time.Time
-	recs := make(recommendations)
-
-	return func() error {
-		if startTime.IsZero() {
-			startTime = time.Now()
-		}
-
-		avgInFlight, err := getInflightRequests(apiName, autoscalingSpec.Window)
-		if err != nil {
-			return err
-		}
-		if avgInFlight == nil {
-			apiLogger.Debugf("%s autoscaler tick: metrics not available yet", apiName)
-			return nil
-		}
-
-		rawRecommendation := *avgInFlight / *autoscalingSpec.TargetReplicaConcurrency
-		recommendation := int32(math.Ceil(rawRecommendation))
-
-		if rawRecommendation < float64(currentReplicas) && rawRecommendation > float64(currentReplicas)*(1-autoscalingSpec.DownscaleTolerance) {
-			recommendation = currentReplicas
-		}
-
-		if rawRecommendation > float64(currentReplicas) && rawRecommendation < float64(currentReplicas)*(1+autoscalingSpec.UpscaleTolerance) {
-			recommendation = currentReplicas
-		}
-
-		// always allow subtraction of 1
-		downscaleFactorFloor := libmath.MinInt32(currentReplicas-1, int32(math.Ceil(float64(currentReplicas)*autoscalingSpec.MaxDownscaleFactor)))
-		if recommendation < downscaleFactorFloor {
-			recommendation = downscaleFactorFloor
-		}
-
-		// always allow addition of 1
-		upscaleFactorCeil := libmath.MaxInt32(currentReplicas+1, int32(math.Ceil(float64(currentReplicas)*autoscalingSpec.MaxUpscaleFactor)))
-		if recommendation > upscaleFactorCeil {
-			recommendation = upscaleFactorCeil
-		}
-
-		if recommendation < 1 {
-			recommendation = 1
-		}
-
-		if recommendation < autoscalingSpec.MinReplicas {
-			recommendation = autoscalingSpec.MinReplicas
-		}
-
-		if recommendation > autoscalingSpec.MaxReplicas {
-			recommendation = autoscalingSpec.MaxReplicas
-		}
-
-		// Rule of thumb: any modifications that don't consider historical recommendations should be performed before
-		// recording the recommendation, any modifications that use historical recommendations should be performed after
-		recs.add(recommendation)
-
-		// This is just for garbage collection
-		recs.deleteOlderThan(libtime.MaxDuration(autoscalingSpec.DownscaleStabilizationPeriod, autoscalingSpec.UpscaleStabilizationPeriod))
-
-		request := recommendation
-		var downscaleStabilizationFloor *int32
-		var upscaleStabilizationCeil *int32
-
-		if request < currentReplicas {
-			downscaleStabilizationFloor = recs.maxSince(autoscalingSpec.DownscaleStabilizationPeriod)
-			if time.Since(startTime) < autoscalingSpec.DownscaleStabilizationPeriod {
-				request = currentReplicas
-			} else if downscaleStabilizationFloor != nil && request < *downscaleStabilizationFloor {
-				request = *downscaleStabilizationFloor
-			}
-		}
-		if request > currentReplicas {
-			upscaleStabilizationCeil = recs.minSince(autoscalingSpec.UpscaleStabilizationPeriod)
-			if time.Since(startTime) < autoscalingSpec.UpscaleStabilizationPeriod {
-				request = currentReplicas
-			} else if upscaleStabilizationCeil != nil && request > *upscaleStabilizationCeil {
-				request = *upscaleStabilizationCeil
-			}
-		}
-
-		apiLogger.Debugw(fmt.Sprintf("%s autoscaler tick", apiName),
-			"autoscaling", map[string]interface{}{
-				"avg_in_flight":                  s.Round(*avgInFlight, 2, 0),
-				"target_replica_concurrency":     s.Float64(*autoscalingSpec.TargetReplicaConcurrency),
-				"raw_recommendation":             s.Round(rawRecommendation, 2, 0),
-				"current_replicas":               currentReplicas,
-				"downscale_tolerance":            s.Float64(autoscalingSpec.DownscaleTolerance),
-				"upscale_tolerance":              s.Float64(autoscalingSpec.UpscaleTolerance),
-				"max_downscale_factor":           s.Float64(autoscalingSpec.MaxDownscaleFactor),
-				"downscale_factor_floor":         downscaleFactorFloor,
-				"max_upscale_factor":             s.Float64(autoscalingSpec.MaxUpscaleFactor),
-				"upscale_factor_ceil":            upscaleFactorCeil,
-				"min_replicas":                   autoscalingSpec.MinReplicas,
-				"max_replicas":                   autoscalingSpec.MaxReplicas,
-				"recommendation":                 recommendation,
-				"downscale_stabilization_period": autoscalingSpec.DownscaleStabilizationPeriod,
-				"downscale_stabilization_floor":  s.ObjFlatNoQuotes(downscaleStabilizationFloor),
-				"upscale_stabilization_period":   autoscalingSpec.UpscaleStabilizationPeriod,
-				"upscale_stabilization_ceil":     s.ObjFlatNoQuotes(upscaleStabilizationCeil),
-				"request":                        request,
-			},
-		)
-
-		if currentReplicas != request {
-			apiLogger.Infof("%s autoscaling event: %d -> %d", apiName, currentReplicas, request)
-
-			deployment, err := config.K8s.GetDeployment(initialDeployment.Name)
-			if err != nil {
-				return err
-			}
-
-			if deployment == nil {
-				return errors.ErrorUnexpected("unable to find k8s deployment", apiName)
-			}
-
-			deployment.Spec.Replicas = &request
-
-			if _, err := config.K8s.UpdateDeployment(deployment); err != nil {
-				return err
-			}
-
-			currentReplicas = request
-		}
-
-		return nil
-	}, nil
-}
-
 func getInflightRequests(apiName string, window time.Duration) (*float64, error) {
 	windowSeconds := int64(window.Seconds())
diff --git a/pkg/operator/resources/resources.go b/pkg/operator/resources/resources.go
index 2e4fdffd4e..b6286df365 100644
--- a/pkg/operator/resources/resources.go
+++ b/pkg/operator/resources/resources.go
@@ -30,6 +30,7 @@ import (
 	"github.com/cortexlabs/cortex/pkg/operator/config"
 	"github.com/cortexlabs/cortex/pkg/operator/lib/routines"
 	"github.com/cortexlabs/cortex/pkg/operator/operator"
+	"github.com/cortexlabs/cortex/pkg/operator/resources/asyncapi"
 	"github.com/cortexlabs/cortex/pkg/operator/resources/job/batchapi"
 	"github.com/cortexlabs/cortex/pkg/operator/resources/job/taskapi"
 	"github.com/cortexlabs/cortex/pkg/operator/resources/realtimeapi"
@@ -164,11 +165,14 @@ func UpdateAPI(apiConfig *userconfig.API, projectID string, force bool) (*schema
 		api, msg, err = batchapi.UpdateAPI(apiConfig, projectID)
 	case userconfig.TaskAPIKind:
 		api, msg, err = taskapi.UpdateAPI(apiConfig, projectID)
+	case userconfig.AsyncAPIKind:
+		api, msg, err = asyncapi.UpdateAPI(*apiConfig, projectID, force)
 	case userconfig.TrafficSplitterKind:
 		api, msg, err = trafficsplitter.UpdateAPI(apiConfig)
 	default:
 		return nil, "", ErrorOperationIsOnlySupportedForKind(
 			*deployedResource,
 			userconfig.RealtimeAPIKind,
+			userconfig.AsyncAPIKind,
 			userconfig.BatchAPIKind,
 			userconfig.TrafficSplitterKind,
 			userconfig.TaskAPIKind,
@@ -235,7 +239,7 @@ func patchAPI(apiConfig *userconfig.API, force bool) (*spec.API, string, error)
 	}
 
 	if deployedResource.Kind == userconfig.UnknownKind {
-		return nil, "", ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind, userconfig.BatchAPIKind, userconfig.TrafficSplitterKind) // unexpected
+		return nil, "", ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind, userconfig.AsyncAPIKind, userconfig.BatchAPIKind, userconfig.TaskAPIKind, userconfig.TrafficSplitterKind) // unexpected
 	}
 
 	var projectFiles ProjectFiles
@@ -274,6 +278,8 @@ func patchAPI(apiConfig *userconfig.API, force bool) (*spec.API, string, error)
 		return batchapi.UpdateAPI(apiConfig, prevAPISpec.ProjectID)
 	case userconfig.TaskAPIKind:
 		return taskapi.UpdateAPI(apiConfig, prevAPISpec.ProjectID)
+	case userconfig.AsyncAPIKind:
+		return asyncapi.UpdateAPI(*apiConfig, prevAPISpec.ProjectID, force)
 	default:
 		return trafficsplitter.UpdateAPI(apiConfig)
 	}
@@ -320,6 +326,12 @@ func DeleteAPI(apiName string, keepCache bool) (*schema.DeleteResponse, error) {
 		func() error {
 			return taskapi.DeleteAPI(apiName, keepCache)
 		},
+		func() error {
+			if config.Provider == types.AWSProviderType {
+				return asyncapi.DeleteAPI(apiName, keepCache)
+			}
+			return nil
+		},
 	)
 	if err != nil {
 		telemetry.Error(err)
@@ -353,8 +365,13 @@ func DeleteAPI(apiName string, keepCache bool) (*schema.DeleteResponse, error) {
 		if err != nil {
 			return nil, err
 		}
+	case userconfig.AsyncAPIKind:
+		err = asyncapi.DeleteAPI(apiName, keepCache)
+		if err != nil {
+			return nil, err
+		}
 	default:
-		return nil, ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind, userconfig.BatchAPIKind, userconfig.TrafficSplitterKind) // unexpected
+		return nil, ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind, userconfig.AsyncAPIKind, userconfig.BatchAPIKind, userconfig.TrafficSplitterKind) // unexpected
 	}
 
 	return &schema.DeleteResponse{
@@ -416,9 +433,21 @@ func GetAPIs() ([]schema.APIResponse, error) {
 		return nil, err
 	}
 
-	realtimeAPIPods := []kcore.Pod{}
-	batchAPIPods := []kcore.Pod{}
-	taskAPIPods := []kcore.Pod{}
+	var realtimeAPIDeployments []kapps.Deployment
+	var asyncAPIDeployments []kapps.Deployment
+	for _, deployment := range deployments {
+		switch deployment.Labels["apiKind"] {
+		case userconfig.RealtimeAPIKind.String():
+			realtimeAPIDeployments = append(realtimeAPIDeployments, deployment)
+		case userconfig.AsyncAPIKind.String():
+			asyncAPIDeployments = append(asyncAPIDeployments, deployment)
+		}
+	}
+
+	var realtimeAPIPods []kcore.Pod
+	var batchAPIPods []kcore.Pod
+	var taskAPIPods []kcore.Pod
+	var asyncAPIPods []kcore.Pod
 	for _, pod := range pods {
 		switch pod.Labels["apiKind"] {
 		case userconfig.RealtimeAPIKind.String():
@@ -427,6 +456,8 @@ func GetAPIs() ([]schema.APIResponse, error) {
 			batchAPIPods = append(batchAPIPods, pod)
 		case userconfig.TaskAPIKind.String():
 			taskAPIPods = append(taskAPIPods, pod)
+		case userconfig.AsyncAPIKind.String():
+			asyncAPIPods = append(asyncAPIPods, pod)
 		}
 	}
@@ -445,7 +476,7 @@ func GetAPIs() ([]schema.APIResponse, error) {
 		}
 	}
 
-	realtimeAPIList, err := realtimeapi.GetAllAPIs(realtimeAPIPods, deployments)
+	realtimeAPIList, err := realtimeapi.GetAllAPIs(realtimeAPIPods, realtimeAPIDeployments)
 	if err != nil {
 		return nil, err
 	}
@@ -464,6 +495,11 @@ func GetAPIs() ([]schema.APIResponse, error) {
 		}
 	}
 
+	asyncAPIList, err := asyncapi.GetAllAPIs(asyncAPIPods, asyncAPIDeployments)
+	if err != nil {
+		return nil, err
+	}
+
 	trafficSplitterList, err := trafficsplitter.GetAllAPIs(trafficSplitterVirtualServices)
 	if err != nil {
 		return nil, err
 	}
@@ -474,6 +510,7 @@ func GetAPIs() ([]schema.APIResponse, error) {
 	response = append(response, realtimeAPIList...)
 	response = append(response, batchAPIList...)
 	response = append(response, taskAPIList...)
+	response = append(response, asyncAPIList...)
 	response = append(response, trafficSplitterList...)
 
 	return response, nil
@@ -503,13 +540,23 @@ func GetAPI(apiName string) ([]schema.APIResponse, error) {
 		if err != nil {
 			return nil, err
 		}
+	case userconfig.AsyncAPIKind:
+		apiResponse, err = asyncapi.GetAPIByName(deployedResource)
+		if err != nil {
+			return nil, err
+		}
 	case userconfig.TrafficSplitterKind:
 		apiResponse, err = trafficsplitter.GetAPIByName(deployedResource)
 		if err != nil {
 			return nil, err
 		}
 	default:
-		return nil, ErrorOperationIsOnlySupportedForKind(*deployedResource, userconfig.RealtimeAPIKind, userconfig.BatchAPIKind) // unexpected
+		return nil, ErrorOperationIsOnlySupportedForKind(
+			*deployedResource,
+			userconfig.RealtimeAPIKind, userconfig.BatchAPIKind,
+			userconfig.TaskAPIKind, userconfig.TrafficSplitterKind,
+			userconfig.AsyncAPIKind,
+		) // unexpected
 	}
 
 	// Get past API deploy times
diff --git a/pkg/operator/resources/validations.go b/pkg/operator/resources/validations.go
index c479373160..7e62417c2a 100644
--- a/pkg/operator/resources/validations.go
+++ b/pkg/operator/resources/validations.go
@@ -91,7 +91,9 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles)
 	for i := range apis {
 		api := &apis[i]
-		if api.Kind == userconfig.RealtimeAPIKind || api.Kind == userconfig.BatchAPIKind || api.Kind == userconfig.TaskAPIKind {
+		if api.Kind == userconfig.RealtimeAPIKind || api.Kind == userconfig.BatchAPIKind ||
+			api.Kind == userconfig.TaskAPIKind || api.Kind == userconfig.AsyncAPIKind {
+
 			if err := spec.ValidateAPI(api, nil, projectFiles, config.Provider, config.AWS, config.GCP, config.K8s); err != nil {
 				return errors.Wrap(err, api.Identify())
 			}
@@ -122,7 +124,7 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles)
 	for i := range apis {
 		api := &apis[i]
-		if api.Kind == userconfig.RealtimeAPIKind || api.Kind == userconfig.BatchAPIKind || api.Kind == userconfig.TaskAPIKind {
+		if api.Kind == userconfig.RealtimeAPIKind || api.Kind == userconfig.AsyncAPIKind || api.Kind == userconfig.BatchAPIKind || api.Kind == userconfig.TaskAPIKind {
 			if err := awsManagedValidateK8sCompute(api.Compute, maxMemMap); err != nil {
 				return err
 			}
diff --git a/pkg/types/clusterconfig/aws_policy.go b/pkg/types/clusterconfig/aws_policy.go
index 8dba939aef..670b4744de 100644
--- a/pkg/types/clusterconfig/aws_policy.go
+++ b/pkg/types/clusterconfig/aws_policy.go
@@ -53,7 +53,7 @@ var _cortexPolicy = `
         {
             "Effect": "Allow",
             "Action": "sqs:*",
-            "Resource": "arn:aws:sqs:{{ .Region }}:{{ .AccountID }}:cortex-*"
+            "Resource": "arn:aws:sqs:{{ .Region }}:{{ .AccountID }}:cx-*"
         },
         {
             "Effect": "Allow",
diff --git a/pkg/types/clusterconfig/cluster_config_aws.go b/pkg/types/clusterconfig/cluster_config_aws.go
index 6a4473318a..6806e7dbda 100644
--- a/pkg/types/clusterconfig/cluster_config_aws.go
+++ b/pkg/types/clusterconfig/cluster_config_aws.go
@@ -70,6 +70,7 @@ type CoreConfig struct {
 	ImageManager           string `json:"image_manager" yaml:"image_manager"`
 	ImageDownloader        string `json:"image_downloader" yaml:"image_downloader"`
 	ImageRequestMonitor    string `json:"image_request_monitor" yaml:"image_request_monitor"`
+	ImageAsyncGateway      string `json:"image_async_gateway" yaml:"image_async_gateway"`
 	ImageClusterAutoscaler string `json:"image_cluster_autoscaler" yaml:"image_cluster_autoscaler"`
 	ImageMetricsServer     string `json:"image_metrics_server" yaml:"image_metrics_server"`
 	ImageInferentia        string `json:"image_inferentia" yaml:"image_inferentia"`
@@ -286,6 +287,13 @@ var CoreConfigStructFieldValidations = []*cr.StructFieldValidation{
 			Validator: validateImageVersion,
 		},
 	},
+	{
+		StructField: "ImageAsyncGateway",
+		StringValidation: &cr.StringValidation{
+			Default:   "quay.io/cortexlabs/async-gateway:" + consts.CortexVersion,
+			Validator: validateImageVersion,
+		},
+	},
 	{
 		StructField: "ImageClusterAutoscaler",
 		StringValidation: &cr.StringValidation{
@@ -729,15 +737,10 @@ func (cc *Config) ToAccessConfig() AccessConfig {
 func SQSNamePrefix(clusterName string) string {
 	// 8 was chosen to make sure that other identifiers can be added to the full queue name before reaching the 80 char SQS name limit
-	return "cortex-" + hash.String(clusterName)[:8] + "-"
-}
-
-// returns hash of cluster name and adds trailing "-"
-func (cc *Config) SQSNamePrefix() string {
-	return SQSNamePrefix(cc.ClusterName)
+	return "cx-" + hash.String(clusterName)[:8] + "-"
 }
 
-// returns hash of cluster name and adds trailing "-"
+// returns hash of cluster name and adds trailing "-" e.g. cx-abcd1234-
 func (cc *CoreConfig) SQSNamePrefix() string {
 	return SQSNamePrefix(cc.ClusterName)
 }
diff --git a/pkg/types/spec/errors.go b/pkg/types/spec/errors.go
index 52522476eb..75cd47b6c5 100644
--- a/pkg/types/spec/errors.go
+++ b/pkg/types/spec/errors.go
@@ -76,6 +76,7 @@ const (
 	ErrFieldMustBeDefinedForPredictorType    = "spec.field_must_be_defined_for_predictor_type"
 	ErrFieldNotSupportedByPredictorType      = "spec.field_not_supported_by_predictor_type"
+	ErrPredictorTypeNotSupportedForKind      = "spec.predictor_type_not_supported_by_kind"
 	ErrNoAvailableNodeComputeLimit           = "spec.no_available_node_compute_limit"
 	ErrCortexPrefixedEnvVarNotAllowed        = "spec.cortex_prefixed_env_var_not_allowed"
 	ErrUnsupportedComputeResourceForProvider = "spec.unsupported_compute_resource_for_provider"
@@ -537,6 +538,13 @@ func ErrorKeyIsNotSupportedForKind(key string, kind userconfig.Kind) error {
 	})
 }
 
+func ErrorPredictorTypeNotSupportedForKind(predictorType userconfig.PredictorType, kind userconfig.Kind) error {
+	return errors.WithStack(&errors.Error{
+		Kind:    ErrPredictorTypeNotSupportedForKind,
+		Message: fmt.Sprintf("%s predictor type is not supported for %s kind", predictorType.String(), kind.String()),
+	})
+}
+
 func ErrorComputeResourceConflict(resourceA, resourceB string) error {
 	return errors.WithStack(&errors.Error{
 		Kind: ErrComputeResourceConflict,
diff --git a/pkg/types/spec/validations.go b/pkg/types/spec/validations.go
index 3c476ffdb3..0aa4860a74 100644
--- a/pkg/types/spec/validations.go
+++ b/pkg/types/spec/validations.go
@@ -64,6 +64,14 @@ func apiValidation(
 			autoscalingValidation(),
 			updateStrategyValidation(),
 		)
+	case userconfig.AsyncAPIKind:
+		structFieldValidations = append(resourceStructValidations,
+			predictorValidation(),
+			networkingValidation(),
+			computeValidation(provider),
+			autoscalingValidation(),
+			updateStrategyValidation(),
+		)
 	case userconfig.BatchAPIKind:
 		structFieldValidations = append(resourceStructValidations,
 			predictorValidation(),
@@ -91,9 +99,10 @@ var resourceStructValidations = []*cr.StructFieldValidation{
 	{
 		StructField: "Name",
 		StringValidation: &cr.StringValidation{
-			Required:  true,
-			DNS1035:   true,
-			MaxLength: 42, // k8s adds 21 characters to the pod name, and 63 is the max before it starts to truncate
+			Required:        true,
+			DNS1035:         true,
+			InvalidPrefixes: []string{"b-"}, // collides with our sqs names
+			MaxLength:       42, // k8s adds 21 characters to the pod name, and 63 is the max before it starts to truncate
 		},
 	},
 	{
@@ -710,7 +719,7 @@ func ExtractAPIConfigs(
 		return nil, errors.Append(err, fmt.Sprintf("\n\napi configuration schema can be found at https://docs.cortex.dev/v/%s/", consts.CortexVersionMinor))
 	}
 
-	if resourceStruct.Kind == userconfig.BatchAPIKind {
+	if resourceStruct.Kind == userconfig.BatchAPIKind || resourceStruct.Kind == userconfig.AsyncAPIKind {
 		if provider == types.GCPProviderType {
 			return nil, errors.Wrap(
 				ErrorKindIsNotSupportedByProvider(resourceStruct.Kind, provider),
@@ -739,7 +748,8 @@ func ExtractAPIConfigs(
 	if resourceStruct.Kind == userconfig.RealtimeAPIKind ||
 		resourceStruct.Kind == userconfig.BatchAPIKind ||
-		resourceStruct.Kind == userconfig.TaskAPIKind {
+		resourceStruct.Kind == userconfig.TaskAPIKind ||
+		resourceStruct.Kind == userconfig.AsyncAPIKind {
 		api.ApplyDefaultDockerPaths()
 	}
 
@@ -873,6 +883,9 @@ func validatePredictor(
 	k8sClient *k8s.Client,
 ) error {
 	predictor := api.Predictor
+	if api.Kind == userconfig.AsyncAPIKind && predictor.Type != userconfig.PythonPredictorType {
+		return ErrorPredictorTypeNotSupportedForKind(predictor.Type, api.Kind)
+	}
 
 	if err := validateMultiModelsFields(api); err != nil {
 		return err
@@ -896,21 +909,21 @@ func validatePredictor(
 		}
 	}
 
-	if api.Kind == userconfig.BatchAPIKind {
+	if api.Kind == userconfig.BatchAPIKind || api.Kind == userconfig.AsyncAPIKind {
 		if predictor.MultiModelReloading != nil {
-			return ErrorKeyIsNotSupportedForKind(userconfig.MultiModelReloadingKey, userconfig.BatchAPIKind)
+			return ErrorKeyIsNotSupportedForKind(userconfig.MultiModelReloadingKey, api.Kind)
 		}
 
 		if predictor.ServerSideBatching != nil {
-			return ErrorKeyIsNotSupportedForKind(userconfig.ServerSideBatchingKey, userconfig.BatchAPIKind)
+			return ErrorKeyIsNotSupportedForKind(userconfig.ServerSideBatchingKey, api.Kind)
 		}
 
 		if predictor.ProcessesPerReplica > 1 {
-			return ErrorKeyIsNotSupportedForKind(userconfig.ProcessesPerReplicaKey, userconfig.BatchAPIKind)
+			return ErrorKeyIsNotSupportedForKind(userconfig.ProcessesPerReplicaKey, api.Kind)
		}
 
 		if predictor.ThreadsPerProcess > 1 {
-			return ErrorKeyIsNotSupportedForKind(userconfig.ThreadsPerProcessKey, userconfig.BatchAPIKind)
+			return ErrorKeyIsNotSupportedForKind(userconfig.ThreadsPerProcessKey, api.Kind)
 		}
 	}
diff --git a/pkg/types/userconfig/api.go b/pkg/types/userconfig/api.go
index 238a85bd79..2a0a82fe6c 100644
--- a/pkg/types/userconfig/api.go
+++ b/pkg/types/userconfig/api.go
@@ -154,7 +154,7 @@ func (api *API) ApplyDefaultDockerPaths() {
 	usesInf := api.Compute.Inf > 0
 
 	switch api.Kind {
-	case RealtimeAPIKind, BatchAPIKind:
+	case RealtimeAPIKind, BatchAPIKind, AsyncAPIKind:
 		api.applyPredictorDefaultDockerPaths(usesGPU, usesInf)
 	case TaskAPIKind:
 		api.applyTaskDefaultDockerPaths(usesGPU, usesInf)
diff --git a/pkg/types/userconfig/kind.go b/pkg/types/userconfig/kind.go
index e2d9ded10c..46a2e566be 100644
--- a/pkg/types/userconfig/kind.go
+++ b/pkg/types/userconfig/kind.go
@@ -24,6 +24,7 @@ const (
 	BatchAPIKind
 	TrafficSplitterKind
 	TaskAPIKind
+	AsyncAPIKind
 )
 
 var _kinds = []string{
@@ -32,6 +33,7 @@ var _kinds = []string{
 	"BatchAPI",
 	"TrafficSplitter",
 	"TaskAPI",
+	"AsyncAPI",
 }
 
 func KindFromString(s string) Kind {
diff --git a/test/apis/async/cortex.yaml b/test/apis/async/cortex.yaml
new file mode 100644
index 0000000000..249acd3a0c
--- /dev/null
+++ b/test/apis/async/cortex.yaml
@@ -0,0 +1,11 @@
+- name: iris-classifier
+  kind: AsyncAPI
+  predictor:
+    type: python
+    path: predictor.py
+    config:
+      bucket: cortex-examples
+      key: sklearn/iris-classifier/model.pkl
+  compute:
+    cpu: 0.2
+    mem: 200M
diff --git a/test/apis/async/predictor.py b/test/apis/async/predictor.py
new file mode 100644
index 0000000000..50fb50e1a6
--- /dev/null
+++ b/test/apis/async/predictor.py
@@ -0,0 +1,30 @@
+import os
+import pickle
+
+import boto3
+from botocore import UNSIGNED
+from botocore.client import Config
+
+labels = ["setosa", "versicolor", "virginica"]
+
+
+class PythonPredictor:
+    def __init__(self, config):
+        if os.environ.get("AWS_ACCESS_KEY_ID"):
+            s3 = boto3.client("s3")  # client will use your credentials if available
+        else:
+            s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))  # anonymous client
+
+        s3.download_file(config["bucket"], config["key"], "/tmp/model.pkl")
+        self.model = pickle.load(open("/tmp/model.pkl", "rb"))
+
+    def predict(self, payload):
+        measurements = [
+            payload["sepal_length"],
+            payload["sepal_width"],
+            payload["petal_length"],
+            payload["petal_width"],
+        ]
+
+        label_id = self.model.predict([measurements])[0]
+        return {"result": labels[label_id]}
diff --git a/test/apis/async/requirements.txt b/test/apis/async/requirements.txt
new file mode 100644
index 0000000000..bbc213cf3e
--- /dev/null
+++ b/test/apis/async/requirements.txt
@@ -0,0 +1,2 @@
+boto3
+scikit-learn==0.21.3
diff --git a/test/apis/async/sample.json b/test/apis/async/sample.json
new file mode 100644
index 0000000000..1e1eda2251
--- /dev/null
+++ b/test/apis/async/sample.json
@@ -0,0 +1,6 @@
+{
+    "sepal_length": 5.2,
+    "sepal_width": 3.6,
+    "petal_length": 1.5,
+    "petal_width": 0.3
+}
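
A note on the queue-naming change in pkg/operator/resources/job/batchapi/queue.go and pkg/types/clusterconfig/cluster_config_aws.go above: the SQS prefix shrinks from "cortex-" to "cx-" and batch queues gain a "b-" marker, which is why jobKeyFromQueueURL now skips three dash-separated segments instead of two, why listQueueURLsForAllAPIs filters on the "b-" prefix, and why API names may no longer start with "b-". The standalone Go sketch below is not part of the patch; the cluster hash, API name, and job ID are made-up example values used only to show how the pieces compose and parse.

package main

import (
	"fmt"
	"strings"
)

// sqsNamePrefix mirrors clusterconfig.SQSNamePrefix: "cx-" plus the first 8 characters
// of the cluster-name hash plus a trailing "-". The hash passed in below is illustrative.
func sqsNamePrefix(clusterNameHash string) string {
	return "cx-" + clusterNameHash[:8] + "-"
}

// jobQueueName mirrors apiQueueNamePrefix + getJobQueueName:
// <cluster prefix>b-<api name>-<job id>.fifo
func jobQueueName(clusterNameHash, apiName, jobID string) string {
	return sqsNamePrefix(clusterNameHash) + "b-" + apiName + "-" + jobID + ".fifo"
}

// jobKeyFromQueueName mirrors jobKeyFromQueueURL: the last dash-separated field is the
// job ID, and fields 3..n-1 are rejoined so API names containing dashes survive the split.
func jobKeyFromQueueName(queueName string) (apiName string, jobID string) {
	dashSplit := strings.Split(queueName, "-")
	jobID = strings.TrimSuffix(dashSplit[len(dashSplit)-1], ".fifo")
	apiName = strings.Join(dashSplit[3:len(dashSplit)-1], "-")
	return apiName, jobID
}

func main() {
	name := jobQueueName("abcd1234", "image-classifier", "69b5bbf5e20ded13")
	fmt.Println(name) // cx-abcd1234-b-image-classifier-69b5bbf5e20ded13.fifo

	apiName, jobID := jobKeyFromQueueName(name)
	fmt.Println(apiName, jobID) // image-classifier 69b5bbf5e20ded13
}

The "b-" segment is also what makes the new InvalidPrefixes validation necessary: an API whose name began with "b-" would produce queue names that collide with this layout.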
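
For context on how an AsyncAPI deployed from test/apis/async/cortex.yaml might be exercised end to end, here is a rough client sketch. The endpoint URL, the {"id": ...} response shape, and the "status" field checked while polling are assumptions made for illustration only; the actual routes and response schema are defined by the async gateway and the Cortex docs, not by this section of the patch.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"
)

func main() {
	// placeholder endpoint; substitute the operator's load balancer URL for your cluster
	baseURL := "http://<load-balancer-endpoint>/iris-classifier"

	payload, err := ioutil.ReadFile("test/apis/async/sample.json")
	if err != nil {
		panic(err)
	}

	// submit the workload; the gateway is expected to respond immediately with a request ID
	resp, err := http.Post(baseURL, "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var submitted struct {
		ID string `json:"id"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&submitted); err != nil {
		panic(err)
	}

	// poll for the result until the workload has been processed;
	// the route and the "status" values are assumed field names for illustration
	for {
		res, err := http.Get(fmt.Sprintf("%s/%s", baseURL, submitted.ID))
		if err != nil {
			panic(err)
		}
		var body map[string]interface{}
		if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
			panic(err)
		}
		res.Body.Close()

		if body["status"] == "completed" || body["status"] == "failed" {
			fmt.Println(body)
			return
		}
		time.Sleep(time.Second)
	}
}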