From 63da33d68fa50c7c9a698bc645d3ba0c8cde96cc Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Fri, 9 Jul 2021 15:42:06 +0200 Subject: [PATCH 1/5] Forward headers to async api --- pkg/async-gateway/endpoint.go | 2 +- pkg/async-gateway/service.go | 32 ++++++++++++++++++++++++++------ pkg/dequeuer/async_handler.go | 25 +++++++++++++++++++++++-- pkg/types/async/s3_paths.go | 4 ++++ 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/pkg/async-gateway/endpoint.go b/pkg/async-gateway/endpoint.go index 9e52b95e48..d18af1f05e 100644 --- a/pkg/async-gateway/endpoint.go +++ b/pkg/async-gateway/endpoint.go @@ -63,7 +63,7 @@ func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) { log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType)) - id, err := e.service.CreateWorkload(requestID, body, contentType) + id, err := e.service.CreateWorkload(requestID, body, r.Header) if err != nil { respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err)) logErrorWithTelemetry(log, errors.Wrap(err, "failed to create workload")) diff --git a/pkg/async-gateway/service.go b/pkg/async-gateway/service.go index bba7a611e9..afa049fa7e 100644 --- a/pkg/async-gateway/service.go +++ b/pkg/async-gateway/service.go @@ -17,18 +17,21 @@ limitations under the License. package gateway import ( + "bytes" "encoding/json" "fmt" "io" + "net/http" "strings" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/types/async" "go.uber.org/zap" ) // Service provides an interface to the async-gateway business logic type Service interface { - CreateWorkload(id string, payload io.Reader, contentType string) (string, error) + CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) GetWorkload(id string) (GetWorkloadResponse, error) } @@ -52,25 +55,42 @@ func NewService(clusterUID, apiName string, queue Queue, storage Storage, logger } // CreateWorkload enqueues an async workload request and uploads the request payload to S3 -func (s *service) CreateWorkload(id string, payload io.Reader, contentType string) (string, error) { +func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) { + contentType := headers.Get("Content-Type") + if contentType == "" { + return "", errors.ErrorUnexpected("missing content-type in headers") + } + headers.Del("Content-Type") + prefix := async.StoragePath(s.clusterUID, s.apiName) log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType)) + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(headers); err != nil { + return "", errors.Wrap(err, "failed to dump headers") + } + + headersPath := async.HeadersPath(prefix, id) + log.Debugw("uploading headers", zap.String("path", headersPath)) + if err := s.storage.Upload(headersPath, buf, "application/json"); err != nil { + return "", errors.Wrap(err, "failed to upload headers") + } + payloadPath := async.PayloadPath(prefix, id) - log.Debug("uploading payload", zap.String("path", payloadPath)) + log.Debugw("uploading payload", zap.String("path", payloadPath)) if err := s.storage.Upload(payloadPath, payload, contentType); err != nil { - return "", err + return "", errors.Wrap(err, "failed to upload payload") } log.Debug("sending message to queue") if err := s.queue.SendMessage(id, id); err != nil { - return "", err + return "", errors.Wrap(err, "failed to send message to queue") } statusPath := fmt.Sprintf("%s/%s/status/%s", prefix, id, async.StatusInQueue) log.Debug(fmt.Sprintf("setting status to %s", async.StatusInQueue)) if err := s.storage.Upload(statusPath, strings.NewReader(""), "text/plain"); err != nil { - return "", err + return "", errors.Wrap(err, "failed to upload workload status") } return id, nil diff --git a/pkg/dequeuer/async_handler.go b/pkg/dequeuer/async_handler.go index 063ef5289a..17852152bd 100644 --- a/pkg/dequeuer/async_handler.go +++ b/pkg/dequeuer/async_handler.go @@ -106,7 +106,16 @@ func (h *AsyncMessageHandler) handleMessage(requestID string) error { } defer h.deletePayload(requestID) - result, err := h.submitRequest(payload, requestID) + headers, err := h.getHeaders(requestID) + if err != nil { + updateStatusErr := h.updateStatus(requestID, async.StatusFailed) + if updateStatusErr != nil { + h.log.Errorw("failed to update status after failure to get headers", "id", requestID, "error", updateStatusErr) + } + return errors.Wrap(err, "failed to get payload") + } + + result, err := h.submitRequest(payload, headers, requestID) if err != nil { h.log.Errorw("failed to submit request to user container", "id", requestID, "error", err) updateStatusErr := h.updateStatus(requestID, async.StatusFailed) @@ -170,12 +179,13 @@ func (h *AsyncMessageHandler) deletePayload(requestID string) { } } -func (h *AsyncMessageHandler) submitRequest(payload *userPayload, requestID string) (interface{}, error) { +func (h *AsyncMessageHandler) submitRequest(payload *userPayload, headers http.Header, requestID string) (interface{}, error) { req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload.Body) if err != nil { return nil, errors.WithStack(err) } + req.Header = headers req.Header.Set("Content-Type", payload.ContentType) req.Header.Set(CortexRequestIDHeader, requestID) @@ -216,3 +226,14 @@ func (h *AsyncMessageHandler) uploadResult(requestID string, result interface{}) key := async.ResultPath(h.storagePath, requestID) return h.aws.UploadJSONToS3(result, h.config.Bucket, key) } + +func (h *AsyncMessageHandler) getHeaders(requestID string) (http.Header, error) { + key := async.HeadersPath(h.storagePath, requestID) + + var headers http.Header + if err := h.aws.ReadJSONFromS3(&headers, h.config.Bucket, key); err != nil { + return nil, err + } + + return headers, nil +} diff --git a/pkg/types/async/s3_paths.go b/pkg/types/async/s3_paths.go index dab5c7fa45..b61983dc64 100644 --- a/pkg/types/async/s3_paths.go +++ b/pkg/types/async/s3_paths.go @@ -28,6 +28,10 @@ func PayloadPath(storagePath string, requestID string) string { return fmt.Sprintf("%s/%s/payload", storagePath, requestID) } +func HeadersPath(storagePath string, requestID string) string { + return fmt.Sprintf("%s/%s/headers.json", storagePath, requestID) +} + func ResultPath(storagePath string, requestID string) string { return fmt.Sprintf("%s/%s/result.json", storagePath, requestID) } From d153dd6a469c97e495321819e07b0d3e1242a6b0 Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Fri, 9 Jul 2021 16:02:07 +0200 Subject: [PATCH 2/5] Fix async handler test in dequeuer --- pkg/dequeuer/async_handler_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/dequeuer/async_handler_test.go b/pkg/dequeuer/async_handler_test.go index e59c25791d..02401985b2 100644 --- a/pkg/dequeuer/async_handler_test.go +++ b/pkg/dequeuer/async_handler_test.go @@ -68,7 +68,10 @@ func TestAsyncMessageHandler_Handle(t *testing.T) { }) require.NoError(t, err) - err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, fmt.Sprintf("%s/%s/payload", asyncHandler.storagePath, requestID)) + err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.PayloadPath(asyncHandler.storagePath, requestID)) + require.NoError(t, err) + + err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.HeadersPath(asyncHandler.storagePath, requestID)) require.NoError(t, err) err = asyncHandler.Handle(&sqs.Message{ From 6c30b72167d7863ccd618da4ad80d5c80456b61c Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Fri, 16 Jul 2021 15:32:50 +0200 Subject: [PATCH 3/5] Allow random content types in async request headers --- pkg/async-gateway/endpoint.go | 8 +------- pkg/async-gateway/service.go | 9 ++------- pkg/dequeuer/async_handler.go | 1 - 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/pkg/async-gateway/endpoint.go b/pkg/async-gateway/endpoint.go index c3c80d6a75..76c1ee22ec 100644 --- a/pkg/async-gateway/endpoint.go +++ b/pkg/async-gateway/endpoint.go @@ -50,18 +50,12 @@ func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) { return } - contentType := r.Header.Get("Content-Type") - if contentType == "" { - respondPlainText(w, http.StatusBadRequest, "error: missing Content-Type key in request header") - return - } - body := r.Body defer func() { _ = r.Body.Close() }() - log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType)) + log := e.logger.With(zap.String("id", requestID)) id, err := e.service.CreateWorkload(requestID, body, r.Header) if err != nil { diff --git a/pkg/async-gateway/service.go b/pkg/async-gateway/service.go index afa049fa7e..344d252ad6 100644 --- a/pkg/async-gateway/service.go +++ b/pkg/async-gateway/service.go @@ -56,14 +56,8 @@ func NewService(clusterUID, apiName string, queue Queue, storage Storage, logger // CreateWorkload enqueues an async workload request and uploads the request payload to S3 func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) { - contentType := headers.Get("Content-Type") - if contentType == "" { - return "", errors.ErrorUnexpected("missing content-type in headers") - } - headers.Del("Content-Type") - prefix := async.StoragePath(s.clusterUID, s.apiName) - log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType)) + log := s.logger.With(zap.String("id", id)) buf := &bytes.Buffer{} if err := json.NewEncoder(buf).Encode(headers); err != nil { @@ -76,6 +70,7 @@ func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Head return "", errors.Wrap(err, "failed to upload headers") } + contentType := headers.Get("Content-Type") payloadPath := async.PayloadPath(prefix, id) log.Debugw("uploading payload", zap.String("path", payloadPath)) if err := s.storage.Upload(payloadPath, payload, contentType); err != nil { diff --git a/pkg/dequeuer/async_handler.go b/pkg/dequeuer/async_handler.go index 17852152bd..da68dc0bd7 100644 --- a/pkg/dequeuer/async_handler.go +++ b/pkg/dequeuer/async_handler.go @@ -186,7 +186,6 @@ func (h *AsyncMessageHandler) submitRequest(payload *userPayload, headers http.H } req.Header = headers - req.Header.Set("Content-Type", payload.ContentType) req.Header.Set(CortexRequestIDHeader, requestID) startTime := time.Now() From 21daf8bb01775a15443649506ad93406298ac3c2 Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Mon, 19 Jul 2021 12:30:46 +0200 Subject: [PATCH 4/5] Remove userPayload struct as it is no longer necessary --- pkg/dequeuer/async_handler.go | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/pkg/dequeuer/async_handler.go b/pkg/dequeuer/async_handler.go index da68dc0bd7..c2065cdb9b 100644 --- a/pkg/dequeuer/async_handler.go +++ b/pkg/dequeuer/async_handler.go @@ -55,11 +55,6 @@ type AsyncMessageHandlerConfig struct { TargetURL string } -type userPayload struct { - Body io.ReadCloser - ContentType string -} - func NewAsyncMessageHandler(config AsyncMessageHandlerConfig, awsClient *awslib.Client, eventHandler RequestEventHandler, logger *zap.SugaredLogger) *AsyncMessageHandler { return &AsyncMessageHandler{ config: config, @@ -104,7 +99,10 @@ func (h *AsyncMessageHandler) handleMessage(requestID string) error { } return errors.Wrap(err, "failed to get payload") } - defer h.deletePayload(requestID) + defer func() { + h.deletePayload(requestID) + _ = payload.Close() + }() headers, err := h.getHeaders(requestID) if err != nil { @@ -147,7 +145,7 @@ func (h *AsyncMessageHandler) updateStatus(requestID string, status async.Status return h.aws.UploadStringToS3("", h.config.Bucket, key) } -func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) { +func (h *AsyncMessageHandler) getPayload(requestID string) (io.ReadCloser, error) { key := async.PayloadPath(h.storagePath, requestID) output, err := h.aws.S3().GetObject( &s3.GetObjectInput{ @@ -158,16 +156,7 @@ func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) if err != nil { return nil, errors.WithStack(err) } - - contentType := "application/octet-stream" - if output.ContentType != nil { - contentType = *output.ContentType - } - - return &userPayload{ - Body: output.Body, - ContentType: contentType, - }, nil + return output.Body, nil } func (h *AsyncMessageHandler) deletePayload(requestID string) { @@ -179,8 +168,8 @@ func (h *AsyncMessageHandler) deletePayload(requestID string) { } } -func (h *AsyncMessageHandler) submitRequest(payload *userPayload, headers http.Header, requestID string) (interface{}, error) { - req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload.Body) +func (h *AsyncMessageHandler) submitRequest(payload io.Reader, headers http.Header, requestID string) (interface{}, error) { + req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload) if err != nil { return nil, errors.WithStack(err) } From e7465d3580388f10eb9f26cbe409357bdc19e4ee Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Mon, 19 Jul 2021 12:32:40 +0200 Subject: [PATCH 5/5] Fix e2e tests subtest ids --- test/e2e/tests/aws/test_autoscaling.py | 2 +- test/e2e/tests/aws/test_realtime.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/e2e/tests/aws/test_autoscaling.py b/test/e2e/tests/aws/test_autoscaling.py index b58164070b..6274c5b286 100644 --- a/test/e2e/tests/aws/test_autoscaling.py +++ b/test/e2e/tests/aws/test_autoscaling.py @@ -31,7 +31,7 @@ @pytest.mark.usefixtures("client") -@pytest.mark.parametrize("apis", TEST_APIS) +@pytest.mark.parametrize("apis", TEST_APIS, ids=[api["primary"] for api in TEST_APIS]) def test_autoscaling(printer: Callable, config: Dict, client: cx.Client, apis: Dict[str, Any]): skip_autoscaling_test = config["global"].get("skip_autoscaling", False) if skip_autoscaling_test: diff --git a/test/e2e/tests/aws/test_realtime.py b/test/e2e/tests/aws/test_realtime.py index 6b45e399b0..92c702b982 100644 --- a/test/e2e/tests/aws/test_realtime.py +++ b/test/e2e/tests/aws/test_realtime.py @@ -64,10 +64,8 @@ def test_realtime_api(printer: Callable, config: Dict, client: cx.Client, api: D @pytest.mark.usefixtures("client") -@pytest.mark.parametrize("api", TEST_APIS_ARM) +@pytest.mark.parametrize("api", TEST_APIS_ARM, ids=[api["name"] for api in TEST_APIS_ARM]) def test_realtime_api_arm(printer: Callable, config: Dict, client: cx.Client, api: Dict[str, str]): - - printer(f"testing {api['name']}") e2e.tests.test_realtime_api( printer=printer, client=client,