diff --git a/pkg/async-gateway/endpoint.go b/pkg/async-gateway/endpoint.go index 26e628c1cb..76c1ee22ec 100644 --- a/pkg/async-gateway/endpoint.go +++ b/pkg/async-gateway/endpoint.go @@ -50,20 +50,14 @@ func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) { return } - contentType := r.Header.Get("Content-Type") - if contentType == "" { - respondPlainText(w, http.StatusBadRequest, "error: missing Content-Type key in request header") - return - } - body := r.Body defer func() { _ = r.Body.Close() }() - log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType)) + log := e.logger.With(zap.String("id", requestID)) - id, err := e.service.CreateWorkload(requestID, body, contentType) + id, err := e.service.CreateWorkload(requestID, body, r.Header) if err != nil { respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err)) logErrorWithTelemetry(log, errors.Wrap(err, "failed to create workload")) diff --git a/pkg/async-gateway/service.go b/pkg/async-gateway/service.go index bba7a611e9..344d252ad6 100644 --- a/pkg/async-gateway/service.go +++ b/pkg/async-gateway/service.go @@ -17,18 +17,21 @@ limitations under the License. package gateway import ( + "bytes" "encoding/json" "fmt" "io" + "net/http" "strings" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/types/async" "go.uber.org/zap" ) // Service provides an interface to the async-gateway business logic type Service interface { - CreateWorkload(id string, payload io.Reader, contentType string) (string, error) + CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) GetWorkload(id string) (GetWorkloadResponse, error) } @@ -52,25 +55,37 @@ func NewService(clusterUID, apiName string, queue Queue, storage Storage, logger } // CreateWorkload enqueues an async workload request and uploads the request payload to S3 -func (s *service) CreateWorkload(id string, payload io.Reader, contentType string) (string, error) { +func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) { prefix := async.StoragePath(s.clusterUID, s.apiName) - log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType)) + log := s.logger.With(zap.String("id", id)) + + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(headers); err != nil { + return "", errors.Wrap(err, "failed to dump headers") + } + + headersPath := async.HeadersPath(prefix, id) + log.Debugw("uploading headers", zap.String("path", headersPath)) + if err := s.storage.Upload(headersPath, buf, "application/json"); err != nil { + return "", errors.Wrap(err, "failed to upload headers") + } + contentType := headers.Get("Content-Type") payloadPath := async.PayloadPath(prefix, id) - log.Debug("uploading payload", zap.String("path", payloadPath)) + log.Debugw("uploading payload", zap.String("path", payloadPath)) if err := s.storage.Upload(payloadPath, payload, contentType); err != nil { - return "", err + return "", errors.Wrap(err, "failed to upload payload") } log.Debug("sending message to queue") if err := s.queue.SendMessage(id, id); err != nil { - return "", err + return "", errors.Wrap(err, "failed to send message to queue") } statusPath := fmt.Sprintf("%s/%s/status/%s", prefix, id, async.StatusInQueue) log.Debug(fmt.Sprintf("setting status to %s", async.StatusInQueue)) if err := s.storage.Upload(statusPath, strings.NewReader(""), "text/plain"); err != nil { - return "", err + return "", errors.Wrap(err, "failed to upload workload status") } return id, nil diff --git a/pkg/dequeuer/async_handler.go b/pkg/dequeuer/async_handler.go index 063ef5289a..c2065cdb9b 100644 --- a/pkg/dequeuer/async_handler.go +++ b/pkg/dequeuer/async_handler.go @@ -55,11 +55,6 @@ type AsyncMessageHandlerConfig struct { TargetURL string } -type userPayload struct { - Body io.ReadCloser - ContentType string -} - func NewAsyncMessageHandler(config AsyncMessageHandlerConfig, awsClient *awslib.Client, eventHandler RequestEventHandler, logger *zap.SugaredLogger) *AsyncMessageHandler { return &AsyncMessageHandler{ config: config, @@ -104,9 +99,21 @@ func (h *AsyncMessageHandler) handleMessage(requestID string) error { } return errors.Wrap(err, "failed to get payload") } - defer h.deletePayload(requestID) + defer func() { + h.deletePayload(requestID) + _ = payload.Close() + }() - result, err := h.submitRequest(payload, requestID) + headers, err := h.getHeaders(requestID) + if err != nil { + updateStatusErr := h.updateStatus(requestID, async.StatusFailed) + if updateStatusErr != nil { + h.log.Errorw("failed to update status after failure to get headers", "id", requestID, "error", updateStatusErr) + } + return errors.Wrap(err, "failed to get payload") + } + + result, err := h.submitRequest(payload, headers, requestID) if err != nil { h.log.Errorw("failed to submit request to user container", "id", requestID, "error", err) updateStatusErr := h.updateStatus(requestID, async.StatusFailed) @@ -138,7 +145,7 @@ func (h *AsyncMessageHandler) updateStatus(requestID string, status async.Status return h.aws.UploadStringToS3("", h.config.Bucket, key) } -func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) { +func (h *AsyncMessageHandler) getPayload(requestID string) (io.ReadCloser, error) { key := async.PayloadPath(h.storagePath, requestID) output, err := h.aws.S3().GetObject( &s3.GetObjectInput{ @@ -149,16 +156,7 @@ func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) if err != nil { return nil, errors.WithStack(err) } - - contentType := "application/octet-stream" - if output.ContentType != nil { - contentType = *output.ContentType - } - - return &userPayload{ - Body: output.Body, - ContentType: contentType, - }, nil + return output.Body, nil } func (h *AsyncMessageHandler) deletePayload(requestID string) { @@ -170,13 +168,13 @@ func (h *AsyncMessageHandler) deletePayload(requestID string) { } } -func (h *AsyncMessageHandler) submitRequest(payload *userPayload, requestID string) (interface{}, error) { - req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload.Body) +func (h *AsyncMessageHandler) submitRequest(payload io.Reader, headers http.Header, requestID string) (interface{}, error) { + req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload) if err != nil { return nil, errors.WithStack(err) } - req.Header.Set("Content-Type", payload.ContentType) + req.Header = headers req.Header.Set(CortexRequestIDHeader, requestID) startTime := time.Now() @@ -216,3 +214,14 @@ func (h *AsyncMessageHandler) uploadResult(requestID string, result interface{}) key := async.ResultPath(h.storagePath, requestID) return h.aws.UploadJSONToS3(result, h.config.Bucket, key) } + +func (h *AsyncMessageHandler) getHeaders(requestID string) (http.Header, error) { + key := async.HeadersPath(h.storagePath, requestID) + + var headers http.Header + if err := h.aws.ReadJSONFromS3(&headers, h.config.Bucket, key); err != nil { + return nil, err + } + + return headers, nil +} diff --git a/pkg/dequeuer/async_handler_test.go b/pkg/dequeuer/async_handler_test.go index e59c25791d..02401985b2 100644 --- a/pkg/dequeuer/async_handler_test.go +++ b/pkg/dequeuer/async_handler_test.go @@ -68,7 +68,10 @@ func TestAsyncMessageHandler_Handle(t *testing.T) { }) require.NoError(t, err) - err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, fmt.Sprintf("%s/%s/payload", asyncHandler.storagePath, requestID)) + err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.PayloadPath(asyncHandler.storagePath, requestID)) + require.NoError(t, err) + + err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.HeadersPath(asyncHandler.storagePath, requestID)) require.NoError(t, err) err = asyncHandler.Handle(&sqs.Message{ diff --git a/pkg/types/async/s3_paths.go b/pkg/types/async/s3_paths.go index dab5c7fa45..b61983dc64 100644 --- a/pkg/types/async/s3_paths.go +++ b/pkg/types/async/s3_paths.go @@ -28,6 +28,10 @@ func PayloadPath(storagePath string, requestID string) string { return fmt.Sprintf("%s/%s/payload", storagePath, requestID) } +func HeadersPath(storagePath string, requestID string) string { + return fmt.Sprintf("%s/%s/headers.json", storagePath, requestID) +} + func ResultPath(storagePath string, requestID string) string { return fmt.Sprintf("%s/%s/result.json", storagePath, requestID) } diff --git a/test/e2e/tests/aws/test_autoscaling.py b/test/e2e/tests/aws/test_autoscaling.py index b58164070b..6274c5b286 100644 --- a/test/e2e/tests/aws/test_autoscaling.py +++ b/test/e2e/tests/aws/test_autoscaling.py @@ -31,7 +31,7 @@ @pytest.mark.usefixtures("client") -@pytest.mark.parametrize("apis", TEST_APIS) +@pytest.mark.parametrize("apis", TEST_APIS, ids=[api["primary"] for api in TEST_APIS]) def test_autoscaling(printer: Callable, config: Dict, client: cx.Client, apis: Dict[str, Any]): skip_autoscaling_test = config["global"].get("skip_autoscaling", False) if skip_autoscaling_test: diff --git a/test/e2e/tests/aws/test_realtime.py b/test/e2e/tests/aws/test_realtime.py index 6b45e399b0..92c702b982 100644 --- a/test/e2e/tests/aws/test_realtime.py +++ b/test/e2e/tests/aws/test_realtime.py @@ -64,10 +64,8 @@ def test_realtime_api(printer: Callable, config: Dict, client: cx.Client, api: D @pytest.mark.usefixtures("client") -@pytest.mark.parametrize("api", TEST_APIS_ARM) +@pytest.mark.parametrize("api", TEST_APIS_ARM, ids=[api["name"] for api in TEST_APIS_ARM]) def test_realtime_api_arm(printer: Callable, config: Dict, client: cx.Client, api: Dict[str, str]): - - printer(f"testing {api['name']}") e2e.tests.test_realtime_api( printer=printer, client=client,