diff --git a/docs/deployments/apis.md b/docs/deployments/apis.md index 285cabcc6d..033ca384c8 100644 --- a/docs/deployments/apis.md +++ b/docs/deployments/apis.md @@ -11,7 +11,7 @@ Serve models at scale. model_format: # model format, must be "tensorflow" or "onnx" (default: "onnx" if model path ends with .onnx, "tensorflow" if model path ends with .zip) request_handler: # path to the request handler implementation file, relative to the cortex root tracker: - key: # json key to track in the response payload + key: # json key to track if the response payload is a dictionary model_type: # model type, must be "classification" or "regression" compute: min_replicas: # minimum number of replicas (default: 1) diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index f5cfc1c369..0a47d34472 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -42,7 +42,7 @@ type API struct { } type Tracker struct { - Key string `json:"key" yaml:"key"` + Key *string `json:"key" yaml:"key"` ModelType ModelType `json:"model_type" yaml:"model_type"` } @@ -72,10 +72,8 @@ var apiValidation = &cr.StructValidation{ DefaultNil: true, StructFieldValidations: []*cr.StructFieldValidation{ { - StructField: "Key", - StringValidation: &cr.StringValidation{ - Required: true, - }, + StructField: "Key", + StringPtrValidation: &cr.StringPtrValidation{}, }, { StructField: "ModelType", diff --git a/pkg/workloads/cortex/lib/api_utils.py b/pkg/workloads/cortex/lib/api_utils.py index cb0c3dea81..4d711e7a9b 100644 --- a/pkg/workloads/cortex/lib/api_utils.py +++ b/pkg/workloads/cortex/lib/api_utils.py @@ -73,23 +73,34 @@ def extract_predicted_values(api, predictions): tracker = api.get("tracker") for prediction in predictions: - predicted_value = prediction.get(tracker["key"]) - if predicted_value is None: - raise ValueError( - "failed to track key '{}': not found in response payload".format(tracker["key"]) - ) + if tracker.get("key") is not None: + key = tracker["key"] + if type(prediction) != dict: + raise ValueError( + "failed to track key '{}': expected prediction to be of type dict but found '{}'".format( + key, type(prediction) + ) + ) + if prediction.get(key) is None: + raise ValueError( + "failed to track key '{}': not found in prediction".format(tracker["key"]) + ) + predicted_value = prediction[key] + else: + predicted_value = prediction + if tracker["model_type"] == "classification": if type(predicted_value) != str and type(predicted_value) != int: raise ValueError( - "failed to track key '{}': expected type 'str' or 'int' but encountered '{}'".format( - tracker["key"], type(predicted_value) + "failed to track classification prediction: expected type 'str' or 'int' but encountered '{}'".format( + type(predicted_value) ) ) else: if type(predicted_value) != float and type(predicted_value) != int: # allow ints raise ValueError( - "failed to track key '{}': expected type 'float' or 'int' but encountered '{}'".format( - tracker["key"], type(predicted_value) + "failed to track regression prediction: expected type 'float' or 'int' but encountered '{}'".format( + type(predicted_value) ) ) predicted_values.append(predicted_value)