diff --git a/Makefile b/Makefile index 47a6f74c66..65f0c63ddb 100644 --- a/Makefile +++ b/Makefile @@ -131,6 +131,7 @@ ci-build-images: @./build/build-image.sh images/tf-serve tf-serve @./build/build-image.sh images/tf-serve-gpu tf-serve-gpu @./build/build-image.sh images/tf-api tf-api + @./build/build-image.sh images/onnx-serve onnx-serve @./build/build-image.sh images/operator operator @./build/build-image.sh images/fluentd fluentd @./build/build-image.sh images/nginx-controller nginx-controller @@ -151,6 +152,7 @@ ci-push-images: @./build/push-image.sh tf-serve @./build/push-image.sh tf-serve-gpu @./build/push-image.sh tf-api + @./build/push-image.sh onnx-serve @./build/push-image.sh operator @./build/push-image.sh fluentd @./build/push-image.sh nginx-controller diff --git a/README.md b/README.md index cf23eb4d88..e22f1fd0ae 100644 --- a/README.md +++ b/README.md @@ -35,11 +35,11 @@ Cortex is actively maintained by Cortex Labs. We're a venture-backed team of inf ```python # handler.py -def preprocess(payload): +def pre_inference(sample, metadata): # Python code -def postprocess(prediction): +def post_inference(prediction, metadata): # Python code ``` diff --git a/cli/cmd/predict.go b/cli/cmd/predict.go index 4946021533..6069741d16 100644 --- a/cli/cmd/predict.go +++ b/cli/cmd/predict.go @@ -22,6 +22,7 @@ import ( "net/http" "strings" + "github.com/cortexlabs/yaml" "github.com/spf13/cobra" "github.com/cortexlabs/cortex/pkg/lib/cast" @@ -43,11 +44,11 @@ func init() { } type PredictResponse struct { - ResourceID string `json:"resource_id"` - Predictions []Prediction `json:"predictions"` + ResourceID string `json:"resource_id"` + Predictions []interface{} `json:"predictions"` } -type Prediction struct { +type DetailedPrediction struct { Prediction interface{} `json:"prediction"` PredictionReversed interface{} `json:"prediction_reversed"` TransformedSample interface{} `json:"transformed_sample"` @@ -97,9 +98,10 @@ var predictCmd = &cobra.Command{ } apiID := predictResponse.ResourceID - api := resourcesRes.APIStatuses[apiID] + apiStatus := resourcesRes.APIStatuses[apiID] + api := resourcesRes.Context.APIs[apiName] - apiStart := libtime.LocalTimestampHuman(api.Start) + apiStart := libtime.LocalTimestampHuman(apiStatus.Start) fmt.Println("\n" + apiName + " was last updated on " + apiStart + "\n") if len(predictResponse.Predictions) == 1 { @@ -109,8 +111,8 @@ var predictCmd = &cobra.Command{ } for _, prediction := range predictResponse.Predictions { - if prediction.Prediction == nil { - prettyResp, err := json.Pretty(prediction.Response) + if !yaml.StartsWithEscapedAtSymbol(api.Model) { + prettyResp, err := json.Pretty(prediction) if err != nil { errors.Exit(err) } @@ -119,9 +121,30 @@ var predictCmd = &cobra.Command{ continue } - value := prediction.Prediction - if prediction.PredictionReversed != nil { - value = prediction.PredictionReversed + predictionBytes, err := json.Marshal(prediction) + if err != nil { + errors.Exit(err) + } + + var detailedPrediction DetailedPrediction + err = json.DecodeWithNumber(predictionBytes, &detailedPrediction) + if err != nil { + errors.Exit(err, "prediction response") + } + + if detailedPrediction.Prediction == nil { + prettyResp, err := json.Pretty(detailedPrediction.Response) + if err != nil { + errors.Exit(err) + } + + fmt.Println(prettyResp) + continue + } + + value := detailedPrediction.Prediction + if detailedPrediction.PredictionReversed != nil { + value = detailedPrediction.PredictionReversed } if cast.IsFloatType(value) { diff --git
a/cortex.sh b/cortex.sh index 17603343ac..b4830c3613 100755 --- a/cortex.sh +++ b/cortex.sh @@ -131,6 +131,7 @@ export CORTEX_IMAGE_TF_API="${CORTEX_IMAGE_TF_API:-cortexlabs/tf-api:$CORTEX_VER export CORTEX_IMAGE_PYTHON_PACKAGER="${CORTEX_IMAGE_PYTHON_PACKAGER:-cortexlabs/python-packager:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_TF_SERVE_GPU="${CORTEX_IMAGE_TF_SERVE_GPU:-cortexlabs/tf-serve-gpu:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_TF_TRAIN_GPU="${CORTEX_IMAGE_TF_TRAIN_GPU:-cortexlabs/tf-train-gpu:$CORTEX_VERSION_STABLE}" +export CORTEX_IMAGE_ONNX_SERVE="${CORTEX_IMAGE_ONNX_SERVE:-cortexlabs/onnx-serve:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_CLUSTER_AUTOSCALER="${CORTEX_IMAGE_CLUSTER_AUTOSCALER:-cortexlabs/cluster-autoscaler:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_NVIDIA="${CORTEX_IMAGE_NVIDIA:-cortexlabs/nvidia:$CORTEX_VERSION_STABLE}" export CORTEX_IMAGE_METRICS_SERVER="${CORTEX_IMAGE_METRICS_SERVER:-cortexlabs/metrics-server:$CORTEX_VERSION_STABLE}" @@ -188,6 +189,7 @@ function install_cortex() { -e CORTEX_IMAGE_PYTHON_PACKAGER=$CORTEX_IMAGE_PYTHON_PACKAGER \ -e CORTEX_IMAGE_TF_SERVE_GPU=$CORTEX_IMAGE_TF_SERVE_GPU \ -e CORTEX_IMAGE_TF_TRAIN_GPU=$CORTEX_IMAGE_TF_TRAIN_GPU \ + -e CORTEX_IMAGE_ONNX_SERVE=$CORTEX_IMAGE_ONNX_SERVE \ -e CORTEX_IMAGE_CLUSTER_AUTOSCALER=$CORTEX_IMAGE_CLUSTER_AUTOSCALER \ -e CORTEX_IMAGE_NVIDIA=$CORTEX_IMAGE_NVIDIA \ -e CORTEX_IMAGE_METRICS_SERVER=$CORTEX_IMAGE_METRICS_SERVER \ diff --git a/dev/registry.sh b/dev/registry.sh index 1e38ba314a..b501e5844b 100755 --- a/dev/registry.sh +++ b/dev/registry.sh @@ -50,6 +50,7 @@ function create_registry() { aws ecr create-repository --repository-name=cortexlabs/python-packager --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/tf-train-gpu --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/tf-serve-gpu --region=$REGISTRY_REGION || true + aws ecr create-repository --repository-name=cortexlabs/onnx-serve --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/cluster-autoscaler --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/nvidia --region=$REGISTRY_REGION || true aws ecr create-repository --repository-name=cortexlabs/metrics-server --region=$REGISTRY_REGION || true @@ -130,7 +131,9 @@ elif [ "$cmd" = "update" ]; then cache_builder $ROOT/images/spark-operator spark-operator build_and_push $ROOT/images/spark-operator spark-operator latest - + build_and_push $ROOT/images/spark spark latest + build_and_push $ROOT/images/tf-train tf-train latest + build_and_push $ROOT/images/tf-train-gpu tf-train-gpu latest build_and_push $ROOT/images/nginx-controller nginx-controller latest build_and_push $ROOT/images/nginx-backend nginx-backend latest build_and_push $ROOT/images/fluentd fluentd latest @@ -144,10 +147,8 @@ elif [ "$cmd" = "update" ]; then build_and_push $ROOT/images/metrics-server metrics-server latest fi - build_and_push $ROOT/images/spark spark latest - build_and_push $ROOT/images/tf-train tf-train latest - build_and_push $ROOT/images/tf-train-gpu tf-train-gpu latest build_and_push $ROOT/images/tf-api tf-api latest + build_and_push $ROOT/images/onnx-serve onnx-serve latest cleanup fi diff --git a/docs/apis/apis.md b/docs/apis/apis.md index 4b6a27ec20..1415343142 100644 --- a/docs/apis/apis.md +++ b/docs/apis/apis.md @@ -8,6 +8,8 @@ Serve models at scale and use them to build smarter applications. 
- kind: api name: # API name (required) model: # path to a zipped model dir (e.g. s3://my-bucket/model.zip) + model_format: # model format, must be "tensorflow" or "onnx" + request_handler: # path to the request handler implementation file, relative to the cortex root compute: min_replicas: # minimum number of replicas (default: 1) max_replicas: # maximum number of replicas (default: 100) @@ -26,12 +28,19 @@ See [packaging models](packaging-models.md) for how to create the zipped model. - kind: api name: my-api model: s3://my-bucket/my-model.zip + request_handler: inference.py compute: min_replicas: 5 max_replicas: 20 cpu: "1" ``` +## Custom Request Handlers + +Request handlers are used to decouple the interface of an API endpoint from its model. A `pre_inference` request handler can be used to modify request payloads before they are sent to the model. A `post_inference` request handler can be used to modify model predictions in the server before they are sent to the client. + +See [request handlers](request-handlers.md) for a detailed guide. + ## Integration APIs can be integrated into other applications or services via their JSON endpoints. The endpoint for any API follows the following format: {apis_endpoint}/{deployment_name}/{api_name}. diff --git a/docs/apis/packaging-models.md b/docs/apis/packaging-models.md index 4837829420..7ade2e1fcd 100644 --- a/docs/apis/packaging-models.md +++ b/docs/apis/packaging-models.md @@ -2,7 +2,7 @@ ## TensorFlow -Zip the exported estimator output in your checkpoint directory, e.g. +Zip the exported estimator output in your checkpoint directory: ```text $ ls export/estimator @@ -11,16 +11,58 @@ saved_model.pb variables/ $ zip -r model.zip export/estimator ``` -Upload the zipped file to Amazon S3, e.g. +Upload the zipped file to Amazon S3: ```text $ aws s3 cp model.zip s3://my-bucket/model.zip ``` -Specify `model` in an API, e.g. +Reference your `model` in an API: ```yaml - kind: api name: my-api + model_format: tensorflow model: s3://my-bucket/model.zip ``` + +## ONNX + +Export your trained model to an ONNX model format. An example of an sklearn model being exported to ONNX is shown below: + +```Python +... 
+logreg_model = sklearn.linear_model.LogisticRegression(solver="lbfgs", multi_class="multinomial") + +# Train the model +logreg_model.fit(X_train, y_train) + +# Convert to ONNX model format +onnx_model = onnxmltools.convert_sklearn( + logreg_model, initial_types=[("input", onnxconverter_common.data_types.FloatTensorType([1, 4]))] +) +with open("model.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) +``` + +Here are examples of converting models from some of the common ML frameworks to ONNX: + +* [PyTorch](https://github.com/cortexlabs/cortex/blob/master/examples/iris/pytorch/model.py) +* [Sklearn](https://github.com/cortexlabs/cortex/blob/master/examples/iris/sklearn/model.py) +* [XGBoost](https://github.com/cortexlabs/cortex/blob/master/examples/iris/xgboost/model.py) +* [Keras](https://github.com/cortexlabs/cortex/blob/master/examples/iris/keras/model.py) + +Upload your trained model in ONNX format to Amazon S3: + +```text +$ aws s3 cp model.onnx s3://my-bucket/model.onnx +``` + +Reference your `model` in an API: + +```yaml +- kind: api + name: my-api + model_format: onnx + model: s3://my-bucket/model.onnx +``` diff --git a/docs/apis/request-handlers.md b/docs/apis/request-handlers.md new file mode 100644 index 0000000000..19b7cd6cbf --- /dev/null +++ b/docs/apis/request-handlers.md @@ -0,0 +1,85 @@ +# Request Handlers + +Request handlers are python files that can contain a `pre_inference` function and a `post_inference` function. Both functions are optional. + +## Implementation + +```python +def pre_inference(sample, metadata): + """Prepare a sample before it is passed into the model. + + Args: + sample: A sample from the request payload. + + metadata: Describes the expected shape and type of inputs to the model. + If API model_format is tensorflow: map + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto + If API model_format is onnx: list + https://microsoft.github.io/onnxruntime/api_summary.html#onnxruntime.NodeArg + + Returns: + A dictionary containing model input names as keys and python lists or numpy arrays as values. If the model only has a single input, then a python list or numpy array can be returned. + """ + pass + +def post_inference(prediction, metadata): + """Modify a prediction from the model before responding to the request. + + Args: + prediction: The output of the model. + + metadata: Describes the output shape and type of outputs from the model. + If API model_format is tensorflow: map + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/meta_graph.proto + If API model_format is onnx: list + https://microsoft.github.io/onnxruntime/api_summary.html#onnxruntime.NodeArg + + Returns: + A python dictionary or list. 
+ """ +``` + +## Example + +```python +import numpy as np + +iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] + +def pre_inference(sample, metadata): + # Convert a dictionary of features to a flattened list in the order expected by the model + return { + metadata[0].name : [ + sample["sepal_length"], + sample["sepal_width"], + sample["petal_length"], + sample["petal_width"], + ] + } + + +def post_inference(prediction, metadata): + # Update the model prediction to include the index and the label of the predicted class + probabilities = prediction[0][0] + predicted_class_id = int(np.argmax(probabilities)) + return { + "class_label": iris_labels[predicted_class_id], + "class_index": predicted_class_id, + "probabilities": probabilities, + } + +``` + +## Pre-installed Packages + +The following packages have been pre-installed and can be used in your implementations: + +```text +boto3==1.9.78 +msgpack==0.6.1 +numpy>=1.13.3,<2 +requirements-parser==0.2.0 +packaging==19.0.0 +``` + +You can install additional PyPI packages and import your own Python packages. See [Python Packages](../pipelines/python-packages.md) for more details. diff --git a/docs/cluster/config.md b/docs/cluster/config.md index 96e32251ef..7504101b32 100644 --- a/docs/cluster/config.md +++ b/docs/cluster/config.md @@ -53,6 +53,7 @@ export CORTEX_IMAGE_TF_TRAIN="cortexlabs/tf-train:master" export CORTEX_IMAGE_TF_API="cortexlabs/tf-api:master" export CORTEX_IMAGE_TF_TRAIN_GPU="cortexlabs/tf-train-gpu:master" export CORTEX_IMAGE_TF_SERVE_GPU="cortexlabs/tf-serve-gpu:master" +export CORTEX_IMAGE_ONNX_SERVE="cortexlabs/onnx-serve:master" export CORTEX_IMAGE_PYTHON_PACKAGER="cortexlabs/python-packager:master" export CORTEX_IMAGE_CLUSTER_AUTOSCALER="cortexlabs/cluster-autoscaler:master" export CORTEX_IMAGE_NVIDIA="cortexlabs/nvidia:master" diff --git a/docs/cluster/development.md b/docs/cluster/development.md index 7dffee0389..fc0af4e5fd 100644 --- a/docs/cluster/development.md +++ b/docs/cluster/development.md @@ -61,6 +61,7 @@ export CORTEX_IMAGE_ARGO_EXECUTOR="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cort export CORTEX_IMAGE_FLUENTD="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/fluentd:latest" export CORTEX_IMAGE_NGINX_BACKEND="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/nginx-backend:latest" export CORTEX_IMAGE_NGINX_CONTROLLER="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/nginx-controller:latest" +export CORTEX_IMAGE_ONNX_SERVE="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/onnx-serve:latest" export CORTEX_IMAGE_OPERATOR="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/operator:latest" export CORTEX_IMAGE_SPARK="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/spark:latest" export CORTEX_IMAGE_SPARK_OPERATOR="XXXXXXXX.dkr.ecr.us-west-2.amazonaws.com/cortexlabs/spark-operator:latest" diff --git a/examples/iris/cortex.yaml b/examples/iris/cortex.yaml index 41b9f504bc..16ba8b85fa 100644 --- a/examples/iris/cortex.yaml +++ b/examples/iris/cortex.yaml @@ -2,5 +2,30 @@ name: iris - kind: api - name: iris-type - model: s3://cortex-examples/iris-tensorflow.zip + name: tensorflow + model_format: tensorflow + model: s3://cortex-examples/iris/tensorflow.zip + +- kind: api + name: pytorch + model_format: onnx + request_handler: pytorch/request_handler.py + model: s3://cortex-examples/iris/pytorch.onnx + +- kind: api + name: xgboost + model_format: onnx + request_handler: xgboost/request_handler.py + model: s3://cortex-examples/iris/xgboost.onnx + +- kind: api + name: sklearn + model_format:
onnx + request_handler: sklearn/request_handler.py + model: s3://cortex-examples/iris/sklearn.onnx + +- kind: api + name: keras + model_format: onnx + request_handler: keras/request_handler.py + model: s3://cortex-examples/iris/keras.onnx diff --git a/examples/iris/keras/irises.json b/examples/iris/keras/irises.json new file mode 100644 index 0000000000..4b660c4325 --- /dev/null +++ b/examples/iris/keras/irises.json @@ -0,0 +1,16 @@ +{ + "samples": [ + [ + 5.9, + 3.0, + 5.1, + 1.8 + ], + [ + 5.6, + 2.5, + 3.9, + 1.1 + ] + ] +} diff --git a/examples/iris/keras/model.py b/examples/iris/keras/model.py new file mode 100644 index 0000000000..84e5af784d --- /dev/null +++ b/examples/iris/keras/model.py @@ -0,0 +1,27 @@ +import numpy as np +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from keras.models import Sequential +from keras.layers import Dense +from keras.utils import np_utils +import keras2onnx + +iris = load_iris() +X, y = iris.data, np_utils.to_categorical(iris.target) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +model = Sequential(name="iris") +model.add(Dense(30, input_dim=4, activation="relu", name="input")) +model.add(Dense(3, activation="softmax", name="last")) + +model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]) + +model.fit(X_train, y_train, epochs=100) + +scores = model.evaluate(X_test, y_test) +print("\n%s: %.2f%%" % (model.metrics_names[1], scores[1] * 100)) + +# Convert to ONNX model format +onnx_model = keras2onnx.convert_keras(model) +with open("keras.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) diff --git a/examples/iris/keras/request_handler.py b/examples/iris/keras/request_handler.py new file mode 100644 index 0000000000..8c7217f8ca --- /dev/null +++ b/examples/iris/keras/request_handler.py @@ -0,0 +1,13 @@ +import numpy as np + +iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] + + +def post_inference(prediction, metadata): + probabilites = prediction[0][0] + predicted_class_id = int(np.argmax(probabilites)) + return { + "class_label": iris_labels[predicted_class_id], + "class_index": predicted_class_id, + "probabilities": probabilites, + } diff --git a/examples/iris/keras/requirements.txt b/examples/iris/keras/requirements.txt new file mode 100644 index 0000000000..dbc2f28b5f --- /dev/null +++ b/examples/iris/keras/requirements.txt @@ -0,0 +1,4 @@ +scikit-learn +keras +keras2onnx +tensorflow diff --git a/examples/iris/irises.json b/examples/iris/pytorch/irises.json similarity index 100% rename from examples/iris/irises.json rename to examples/iris/pytorch/irises.json diff --git a/examples/iris/pytorch/model.py b/examples/iris/pytorch/model.py new file mode 100644 index 0000000000..dc7bafe119 --- /dev/null +++ b/examples/iris/pytorch/model.py @@ -0,0 +1,63 @@ +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, precision_score, recall_score +from sklearn.datasets import load_iris +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + + +class Net(nn.Module): + # define nn + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(4, 100) + self.fc2 = nn.Linear(100, 100) + self.fc3 = nn.Linear(100, 3) + self.softmax = nn.Softmax(dim=1) + + def forward(self, X): + X = F.relu(self.fc1(X)) + X = self.fc2(X) + X = self.fc3(X) + X = self.softmax(X) + + return X + + +iris = load_iris() +X, y = 
iris.data, iris.target +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42) + +# wrap up with Variable in pytorch +train_X = Variable(torch.Tensor(X_train).float()) +test_X = Variable(torch.Tensor(X_test).float()) +train_y = Variable(torch.Tensor(y_train).long()) +test_y = Variable(torch.Tensor(y_test).long()) + +model = Net() + +criterion = nn.CrossEntropyLoss() + +optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + +for epoch in range(1000): + optimizer.zero_grad() + out = model(train_X) + loss = criterion(out, train_y) + loss.backward() + optimizer.step() + + if epoch % 100 == 0: + print("number of epoch {} loss {}".format(epoch, loss)) + +predict_out = model(test_X) +_, predict_y = torch.max(predict_out, 1) + +print("prediction accuracy {}".format(accuracy_score(test_y.data, predict_y.data))) + +# Convert to ONNX model format +placeholder = torch.randn(1, 4) +torch.onnx.export( + model, placeholder, "pytorch.onnx", input_names=["input"], output_names=["species"] +) diff --git a/examples/iris/pytorch/request_handler.py b/examples/iris/pytorch/request_handler.py new file mode 100644 index 0000000000..cc0e7369e4 --- /dev/null +++ b/examples/iris/pytorch/request_handler.py @@ -0,0 +1,23 @@ +import numpy as np + +iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] + + +def pre_inference(sample, metadata): + return { + metadata[0].name: [ + sample["sepal_length"], + sample["sepal_width"], + sample["petal_length"], + sample["petal_width"], + ] + } + + +def post_inference(prediction, metadata): + predicted_class_id = int(np.argmax(prediction[0][0])) + return { + "class_label": iris_labels[predicted_class_id], + "class_index": predicted_class_id, + "probabilites": prediction[0][0], + } diff --git a/examples/iris/pytorch/requirements.txt b/examples/iris/pytorch/requirements.txt new file mode 100644 index 0000000000..04383c96ca --- /dev/null +++ b/examples/iris/pytorch/requirements.txt @@ -0,0 +1,2 @@ +scikit-learn +torch diff --git a/examples/iris/requirements.txt b/examples/iris/requirements.txt new file mode 100644 index 0000000000..1c122fe8fd --- /dev/null +++ b/examples/iris/requirements.txt @@ -0,0 +1 @@ +numpy==1.16.4 diff --git a/examples/iris/sklearn/irises.json b/examples/iris/sklearn/irises.json new file mode 100644 index 0000000000..33d1e6a5b5 --- /dev/null +++ b/examples/iris/sklearn/irises.json @@ -0,0 +1,10 @@ +{ + "samples": [ + { + "sepal_length": 5.2, + "sepal_width": 3.6, + "petal_length": 1.4, + "petal_width": 0.3 + } + ] +} diff --git a/examples/iris/sklearn/model.py b/examples/iris/sklearn/model.py new file mode 100644 index 0000000000..06fc5f5e28 --- /dev/null +++ b/examples/iris/sklearn/model.py @@ -0,0 +1,20 @@ +import numpy as np +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from onnxmltools import convert_sklearn +from onnxconverter_common.data_types import FloatTensorType + +iris = load_iris() +X, y = iris.data, iris.target +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42) + +logreg_model = LogisticRegression(solver="lbfgs", multi_class="multinomial") +logreg_model.fit(X_train, y_train) + +print("Test data accuracy: {:.2f}".format(logreg_model.score(X_test, y_test))) + +# Convert to ONNX model format +onnx_model = convert_sklearn(logreg_model, initial_types=[("input", FloatTensorType([1, 4]))]) +with open("sklearn.onnx", "wb") as f: + 
f.write(onnx_model.SerializeToString()) diff --git a/examples/iris/sklearn/request_handler.py b/examples/iris/sklearn/request_handler.py new file mode 100644 index 0000000000..45d24eb818 --- /dev/null +++ b/examples/iris/sklearn/request_handler.py @@ -0,0 +1,17 @@ +import numpy as np + +iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] + + +def pre_inference(sample, metadata): + return [ + sample["sepal_length"], + sample["sepal_width"], + sample["petal_length"], + sample["petal_width"], + ] + + +def post_inference(prediction, metadata): + predicted_class_id = prediction[0][0] + return {"class_label": iris_labels[predicted_class_id], "class_index": predicted_class_id} diff --git a/examples/iris/sklearn/requirements.txt b/examples/iris/sklearn/requirements.txt new file mode 100644 index 0000000000..c486fc151f --- /dev/null +++ b/examples/iris/sklearn/requirements.txt @@ -0,0 +1,4 @@ +onnxmltools +pandas +scikit-learn +skl2onnx diff --git a/examples/iris/tensorflow/irises.json b/examples/iris/tensorflow/irises.json new file mode 100644 index 0000000000..33d1e6a5b5 --- /dev/null +++ b/examples/iris/tensorflow/irises.json @@ -0,0 +1,10 @@ +{ + "samples": [ + { + "sepal_length": 5.2, + "sepal_width": 3.6, + "petal_length": 1.4, + "petal_width": 0.3 + } + ] +} diff --git a/examples/iris/xgboost/irises.json b/examples/iris/xgboost/irises.json new file mode 100644 index 0000000000..4b660c4325 --- /dev/null +++ b/examples/iris/xgboost/irises.json @@ -0,0 +1,16 @@ +{ + "samples": [ + [ + 5.9, + 3.0, + 5.1, + 1.8 + ], + [ + 5.6, + 2.5, + 3.9, + 1.1 + ] + ] +} diff --git a/examples/iris/xgboost/model.py b/examples/iris/xgboost/model.py new file mode 100644 index 0000000000..3bdef80b1d --- /dev/null +++ b/examples/iris/xgboost/model.py @@ -0,0 +1,20 @@ +import numpy as np +import xgboost as xgb +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from onnxmltools.convert import convert_xgboost +from onnxconverter_common.data_types import FloatTensorType + +iris = load_iris() +X, y = iris.data, iris.target +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42) + +xgb_model = xgb.XGBClassifier() +xgb_model = xgb_model.fit(X_train, y_train) + +print("Test data accuracy of the xgb classifier is {:.2f}".format(xgb_model.score(X_test, y_test))) + +# Convert to ONNX model format +onnx_model = convert_xgboost(xgb_model, initial_types=[("input", FloatTensorType([1, 4]))]) +with open("xgboost.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) diff --git a/examples/iris/xgboost/request_handler.py b/examples/iris/xgboost/request_handler.py new file mode 100644 index 0000000000..093a2db31e --- /dev/null +++ b/examples/iris/xgboost/request_handler.py @@ -0,0 +1,10 @@ +iris_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"] + + +def post_inference(prediction, metadata): + predicted_class_id = prediction[0][0] + return { + "class_label": iris_labels[predicted_class_id], + "class_index": predicted_class_id, + "probabilities": prediction[1][0], + } diff --git a/examples/iris/xgboost/requirements.txt b/examples/iris/xgboost/requirements.txt new file mode 100644 index 0000000000..10de6ba33c --- /dev/null +++ b/examples/iris/xgboost/requirements.txt @@ -0,0 +1,3 @@ +onnxmltools +scikit-learn +xgboost diff --git a/images/onnx-serve/Dockerfile b/images/onnx-serve/Dockerfile new file mode 100644 index 0000000000..4acabf1614 --- /dev/null +++ b/images/onnx-serve/Dockerfile @@ -0,0 +1,39 @@ +FROM ubuntu:16.04 + 
+RUN apt-get update -qq && apt-get install -y -q \ + python3 \ + python3-dev \ + python3-pip \ + && apt-get clean -qq && rm -rf /var/lib/apt/lists/* && \ + pip3 install --upgrade \ + pip \ + setuptools \ + && rm -rf /root/.cache/pip* + +RUN apt-get update -qq && apt-get install -y -q \ + build-essential \ + curl \ + libfreetype6-dev \ + libpng-dev \ + libzmq3-dev \ + pkg-config \ + rsync \ + software-properties-common \ + unzip \ + zlib1g-dev \ + && apt-get clean -qq && rm -rf /var/lib/apt/lists/* + + +ENV PYTHONPATH="/src:${PYTHONPATH}" + +COPY pkg/workloads/lib/requirements.txt /src/lib/requirements.txt +COPY pkg/workloads/onnx_serve/requirements.txt /src/onnx_serve/requirements.txt +RUN pip3 install -r /src/lib/requirements.txt && \ + pip3 install -r /src/onnx_serve/requirements.txt && \ + rm -rf /root/.cache/pip* + +COPY pkg/workloads/consts.py /src/ +COPY pkg/workloads/lib /src/lib +COPY pkg/workloads/onnx_serve /src/onnx_serve + +ENTRYPOINT ["/usr/bin/python3", "/src/onnx_serve/api.py"] diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 8b92580411..d8295de37d 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -47,6 +47,7 @@ var ( TransformersDir = "transformers" EstimatorsDir = "estimators" PythonPackagesDir = "python_packages" + RequestHandlersDir = "request_handlers" ModelsDir = "models" ConstantsDir = "constants" ContextsDir = "contexts" diff --git a/pkg/operator/api/context/apis.go b/pkg/operator/api/context/apis.go index 9fcf247a72..2f4d47c074 100644 --- a/pkg/operator/api/context/apis.go +++ b/pkg/operator/api/context/apis.go @@ -25,7 +25,8 @@ type APIs map[string]*API type API struct { *userconfig.API *ComputedResourceFields - Path string `json:"path"` + Path string `json:"path"` + RequestHandlerImplKey *string `json:"request_handler_impl_key"` } func APIPath(apiName string, appName string) string { diff --git a/pkg/operator/api/context/context.go b/pkg/operator/api/context/context.go index 533d888093..7dafe70b25 100644 --- a/pkg/operator/api/context/context.go +++ b/pkg/operator/api/context/context.go @@ -17,6 +17,8 @@ limitations under the License. 
package context import ( + "fmt" + "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/operator/api/resource" @@ -206,7 +208,7 @@ func (ctx *Context) PopulateWorkloadIDs(resourceWorkloadIDs map[string]string) { func (ctx *Context) CheckAllWorkloadIDsPopulated() error { for _, res := range ctx.ComputedResources() { if res.GetWorkloadID() == "" { - return errors.New(ctx.App.Name, "resource", res.GetID(), "workload ID is missing") // unexpected + return errors.New(ctx.App.Name, "workload ID missing", fmt.Sprintf("%s (ID: %s)", res.GetName(), res.GetID())) // unexpected } } return nil diff --git a/pkg/operator/api/context/dependencies.go b/pkg/operator/api/context/dependencies.go index 4ff2e80564..af10369d79 100644 --- a/pkg/operator/api/context/dependencies.go +++ b/pkg/operator/api/context/dependencies.go @@ -72,6 +72,7 @@ func (ctx *Context) DirectComputedResourceDependencies(resourceID string) strset return ctx.trainingDatasetDependencies(model) } } + for _, api := range ctx.APIs { if api.ID == resourceID { return ctx.apiDependencies(api) @@ -151,12 +152,20 @@ func (ctx *Context) modelDependencies(model *Model) strset.Set { } func (ctx *Context) apiDependencies(api *API) strset.Set { + dependencies := strset.New() + modelName, ok := yaml.ExtractAtSymbolText(api.Model) - if !ok { - return strset.New() + if ok { + model := ctx.Models[modelName] + dependencies.Add(model.ID) } - model := ctx.Models[modelName] - return strset.New(model.ID) + + if api.RequestHandler != nil { + for _, pythonPackage := range ctx.PythonPackages { + dependencies.Add(pythonPackage.GetID()) + } + } + return dependencies } func (ctx *Context) ExtractCortexResources( diff --git a/pkg/operator/api/userconfig/apis.go b/pkg/operator/api/userconfig/apis.go index e9f804f22c..51450ade41 100644 --- a/pkg/operator/api/userconfig/apis.go +++ b/pkg/operator/api/userconfig/apis.go @@ -33,9 +33,11 @@ type APIs []*API type API struct { ResourceFields - Model string `json:"model" yaml:"model"` - Compute *APICompute `json:"compute" yaml:"compute"` - Tags Tags `json:"tags" yaml:"tags"` + Model string `json:"model" yaml:"model"` + ModelFormat ModelFormat `json:"model_format" yaml:"model_format"` + RequestHandler *string `json:"request_handler" yaml:"request_handler"` + Compute *APICompute `json:"compute" yaml:"compute"` + Tags Tags `json:"tags" yaml:"tags"` } var apiValidation = &cr.StructValidation{ @@ -54,6 +56,20 @@ var apiValidation = &cr.StructValidation{ AllowCortexResources: true, }, }, + { + StructField: "RequestHandler", + StringPtrValidation: &cr.StringPtrValidation{}, + }, + { + StructField: "ModelFormat", + StringValidation: &cr.StringValidation{ + Required: true, + AllowedValues: ModelFormatStrings(), + }, + Parser: func(str string) (interface{}, error) { + return ModelFormatFromString(str), nil + }, + }, apiComputeFieldValidation, tagsFieldValidation, typeFieldValidation, @@ -64,6 +80,13 @@ func (api *API) UserConfigStr() string { var sb strings.Builder sb.WriteString(api.ResourceFields.UserConfigStr()) sb.WriteString(fmt.Sprintf("%s: %s\n", ModelKey, yaml.UnescapeAtSymbol(api.Model))) + + if api.ModelFormat != UnknownModelFormat { + sb.WriteString(fmt.Sprintf("%s: %s\n", ModelFormatKey, api.ModelFormat.String())) + } + if api.RequestHandler != nil { + sb.WriteString(fmt.Sprintf("%s: %s\n", RequestHandlerKey, *api.RequestHandler)) + } if api.Compute != nil { sb.WriteString(fmt.Sprintf("%s:\n", ComputeKey)) 
sb.WriteString(s.Indent(api.Compute.UserConfigStr(), " ")) diff --git a/pkg/operator/api/userconfig/config_key.go b/pkg/operator/api/userconfig/config_key.go index ff3d71ad7b..2f871a9792 100644 --- a/pkg/operator/api/userconfig/config_key.go +++ b/pkg/operator/api/userconfig/config_key.go @@ -91,7 +91,9 @@ const ( ThrottleSecsKey = "throttle_secs" // API - ModelKey = "model" + ModelKey = "model" + ModelFormatKey = "model_format" + RequestHandlerKey = "request_handler" // compute ComputeKey = "compute" diff --git a/pkg/operator/api/userconfig/model_format.go b/pkg/operator/api/userconfig/model_format.go new file mode 100644 index 0000000000..2398ae20dd --- /dev/null +++ b/pkg/operator/api/userconfig/model_format.go @@ -0,0 +1,78 @@ +/* +Copyright 2019 Cortex Labs, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package userconfig + +type ModelFormat int + +const ( + UnknownModelFormat ModelFormat = iota + TensorFlowModelFormat + ONNXModelFormat +) + +var modelFormats = []string{ + "unknown", + "tensorflow", + "onnx", +} + +func ModelFormatFromString(s string) ModelFormat { + for i := 0; i < len(modelFormats); i++ { + if s == modelFormats[i] { + return ModelFormat(i) + } + } + return UnknownModelFormat +} + +func ModelFormatStrings() []string { + return modelFormats[1:] +} + +func (t ModelFormat) String() string { + return modelFormats[t] +} + +// MarshalText satisfies TextMarshaler +func (t ModelFormat) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +// UnmarshalText satisfies TextUnmarshaler +func (t *ModelFormat) UnmarshalText(text []byte) error { + enum := string(text) + for i := 0; i < len(modelFormats); i++ { + if enum == modelFormats[i] { + *t = ModelFormat(i) + return nil + } + } + + *t = UnknownModelFormat + return nil +} + +// UnmarshalBinary satisfies BinaryUnmarshaler +// Needed for msgpack +func (t *ModelFormat) UnmarshalBinary(data []byte) error { + return t.UnmarshalText(data) +} + +// MarshalBinary satisfies BinaryMarshaler +func (t ModelFormat) MarshalBinary() ([]byte, error) { + return []byte(t.String()), nil +} diff --git a/pkg/operator/config/config.go b/pkg/operator/config/config.go index b36bd6b12d..c6bcc64a7a 100644 --- a/pkg/operator/config/config.go +++ b/pkg/operator/config/config.go @@ -53,9 +53,11 @@ type CortexConfig struct { PythonPackagerImage string `json:"python_packager_image"` TFTrainImageGPU string `json:"tf_train_image_gpu"` TFServeImageGPU string `json:"tf_serve_image_gpu"` - TelemetryURL string `json:"telemetry_url"` - EnableTelemetry bool `json:"enable_telemetry"` - OperatorInCluster bool `json:"operator_in_cluster"` + ONNXServeImage string `json:"onnx_serve_image"` + + TelemetryURL string `json:"telemetry_url"` + EnableTelemetry bool `json:"enable_telemetry"` + OperatorInCluster bool `json:"operator_in_cluster"` } func Init() error { @@ -73,9 +75,11 @@ func Init() error { PythonPackagerImage: getStr("IMAGE_PYTHON_PACKAGER"), TFTrainImageGPU: getStr("IMAGE_TF_TRAIN_GPU"), TFServeImageGPU: getStr("IMAGE_TF_SERVE_GPU"), - TelemetryURL: 
configreader.MustStringFromEnv("CONST_TELEMETRY_URL", &configreader.StringValidation{Required: false, Default: consts.TelemetryURL}), - EnableTelemetry: getBool("ENABLE_TELEMETRY"), - OperatorInCluster: configreader.MustBoolFromEnv("CONST_OPERATOR_IN_CLUSTER", &configreader.BoolValidation{Default: true}), + ONNXServeImage: getStr("IMAGE_ONNX_SERVE"), + + TelemetryURL: configreader.MustStringFromEnv("CONST_TELEMETRY_URL", &configreader.StringValidation{Required: false, Default: consts.TelemetryURL}), + EnableTelemetry: getBool("ENABLE_TELEMETRY"), + OperatorInCluster: configreader.MustBoolFromEnv("CONST_OPERATOR_IN_CLUSTER", &configreader.BoolValidation{Default: true}), } Cortex.ID = hash.String(Cortex.Bucket + Cortex.Region + Cortex.LogGroup) diff --git a/pkg/operator/context/apis.go b/pkg/operator/context/apis.go index 0c161b3f9b..76ea5c7671 100644 --- a/pkg/operator/context/apis.go +++ b/pkg/operator/context/apis.go @@ -18,25 +18,56 @@ package context import ( "bytes" + "path/filepath" + "github.com/cortexlabs/cortex/pkg/consts" "github.com/cortexlabs/cortex/pkg/lib/errors" "github.com/cortexlabs/cortex/pkg/lib/hash" + "github.com/cortexlabs/cortex/pkg/lib/pointer" + "github.com/cortexlabs/cortex/pkg/lib/sets/strset" "github.com/cortexlabs/cortex/pkg/operator/api/context" "github.com/cortexlabs/cortex/pkg/operator/api/resource" "github.com/cortexlabs/cortex/pkg/operator/api/userconfig" + "github.com/cortexlabs/cortex/pkg/operator/config" "github.com/cortexlabs/yaml" ) +var uploadedRequestHandlers = strset.New() + func getAPIs(config *userconfig.Config, models context.Models, datasetVersion string, + impls map[string][]byte, + pythonPackages context.PythonPackages, ) (context.APIs, error) { apis := context.APIs{} for _, apiConfig := range config.APIs { var buf bytes.Buffer + var requestHandlerImplKey *string buf.WriteString(apiConfig.Name) + buf.WriteString(apiConfig.ModelFormat.String()) + + if apiConfig.RequestHandler != nil { + for _, pythonPackage := range pythonPackages { + buf.WriteString(pythonPackage.GetID()) + } + + impl, ok := impls[*apiConfig.RequestHandler] + if !ok { + return nil, errors.Wrap(userconfig.ErrorImplDoesNotExist(*apiConfig.RequestHandler), userconfig.Identify(apiConfig), userconfig.RequestHandlerKey) + } + implID := hash.Bytes(impl) + buf.WriteString(implID) + + requestHandlerImplKey = pointer.String(filepath.Join(consts.RequestHandlersDir, implID)) + + err := uploadRequestHandler(*requestHandlerImplKey, impls[*apiConfig.RequestHandler]) + if err != nil { + return nil, errors.Wrap(err, userconfig.Identify(apiConfig)) + } + } if yaml.StartsWithEscapedAtSymbol(apiConfig.Model) { modelName, _ := yaml.ExtractAtSymbolText(apiConfig.Model) @@ -59,9 +90,27 @@ func getAPIs(config *userconfig.Config, ResourceType: resource.APIType, }, }, - API: apiConfig, - Path: context.APIPath(apiConfig.Name, config.App.Name), + API: apiConfig, + Path: context.APIPath(apiConfig.Name, config.App.Name), + RequestHandlerImplKey: requestHandlerImplKey, } } return apis, nil } + +func uploadRequestHandler(implKey string, impl []byte) error { + isUploaded, err := config.AWS.IsS3File(implKey) + if err != nil { + return errors.Wrap(err, "upload") + } + + if !isUploaded { + err = config.AWS.UploadBytesToS3(impl, implKey) + if err != nil { + return errors.Wrap(err, "upload") + } + } + + uploadedRequestHandlers.Add(implKey) + return nil +} diff --git a/pkg/operator/context/context.go b/pkg/operator/context/context.go index 9742c715da..bf931f99e3 100644 --- a/pkg/operator/context/context.go +++ 
b/pkg/operator/context/context.go @@ -231,7 +231,7 @@ func New( } ctx.Models = models - apis, err := getAPIs(userconf, ctx.Models, ctx.DatasetVersion) + apis, err := getAPIs(userconf, ctx.Models, ctx.DatasetVersion, files, pythonPackages) if err != nil { return nil, err } diff --git a/pkg/operator/workloads/api.go b/pkg/operator/workloads/api.go index d07adf9211..dc3f17bda9 100644 --- a/pkg/operator/workloads/api.go +++ b/pkg/operator/workloads/api.go @@ -40,13 +40,12 @@ const ( tfServingContainerName = "serve" ) -func apiSpec( +func tfAPISpec( ctx *context.Context, api *context.API, workloadID string, desiredReplicas int32, ) *appsv1b1.Deployment { - transformResourceList := corev1.ResourceList{} tfServingResourceList := corev1.ResourceList{} tfServingLimitsList := corev1.ResourceList{} @@ -166,6 +165,87 @@ func apiSpec( }) } +func onnxAPISpec( + ctx *context.Context, + api *context.API, + workloadID string, + desiredReplicas int32, +) *appsv1b1.Deployment { + resourceList := corev1.ResourceList{} + resourceList[corev1.ResourceCPU] = api.Compute.CPU.Quantity + + if api.Compute.Mem != nil { + resourceList[corev1.ResourceMemory] = api.Compute.Mem.Quantity + } + + return k8s.Deployment(&k8s.DeploymentSpec{ + Name: internalAPIName(api.Name, ctx.App.Name), + Replicas: desiredReplicas, + Labels: map[string]string{ + "appName": ctx.App.Name, + "workloadType": WorkloadTypeAPI, + "apiName": api.Name, + "resourceID": ctx.APIs[api.Name].ID, + "workloadID": workloadID, + }, + Selector: map[string]string{ + "appName": ctx.App.Name, + "workloadType": WorkloadTypeAPI, + "apiName": api.Name, + }, + PodSpec: k8s.PodSpec{ + Labels: map[string]string{ + "appName": ctx.App.Name, + "workloadType": WorkloadTypeAPI, + "apiName": api.Name, + "resourceID": ctx.APIs[api.Name].ID, + "workloadID": workloadID, + "userFacing": "true", + }, + K8sPodSpec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: apiContainerName, + Image: config.Cortex.ONNXServeImage, + ImagePullPolicy: "Always", + Args: []string{ + "--workload-id=" + workloadID, + "--port=" + defaultPortStr, + "--context=" + config.AWS.S3Path(ctx.Key), + "--api=" + ctx.APIs[api.Name].ID, + "--model-dir=" + path.Join(consts.EmptyDirMountPath, "model"), + "--cache-dir=" + consts.ContextCacheDir, + }, + Env: k8s.AWSCredentials(), + VolumeMounts: k8s.DefaultVolumeMounts(), + ReadinessProbe: &corev1.Probe{ + InitialDelaySeconds: 5, + TimeoutSeconds: 5, + PeriodSeconds: 5, + SuccessThreshold: 1, + FailureThreshold: 2, + Handler: corev1.Handler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/healthz", + Port: intstr.IntOrString{ + IntVal: defaultPortInt32, + }, + }, + }, + }, + Resources: corev1.ResourceRequirements{ + Requests: resourceList, + }, + }, + }, + Volumes: k8s.DefaultVolumes(), + ServiceAccountName: "default", + }, + }, + Namespace: config.Cortex.Namespace, + }) +} + func ingressSpec(ctx *context.Context, api *context.API) *k8s.IngressSpec { return &k8s.IngressSpec{ Name: internalAPIName(api.Name, ctx.App.Name), @@ -253,10 +333,21 @@ func apiWorkloadSpecs(ctx *context.Context) ([]*WorkloadSpec, error) { } } + var spec metav1.Object + + switch api.ModelFormat { + case userconfig.TensorFlowModelFormat: + spec = tfAPISpec(ctx, api, workloadID, desiredReplicas) + case userconfig.ONNXModelFormat: + spec = onnxAPISpec(ctx, api, workloadID, desiredReplicas) + default: + return nil, errors.New(api.Name, "unknown model format encountered") // unexpected + } + workloadSpecs = append(workloadSpecs, &WorkloadSpec{ WorkloadID: workloadID, ResourceIDs: 
strset.New(api.ID), - K8sSpecs: []metav1.Object{apiSpec(ctx, api, workloadID, desiredReplicas), hpaSpec(ctx, api)}, + K8sSpecs: []metav1.Object{spec, hpaSpec(ctx, api)}, K8sAction: "apply", WorkloadType: WorkloadTypeAPI, // SuccessCondition: k8s.DeploymentSuccessConditionAll, # Currently success conditions don't work for multi-resource config diff --git a/pkg/operator/workloads/workflow.go b/pkg/operator/workloads/workflow.go index 2ca7c7c842..526d71ee4f 100644 --- a/pkg/operator/workloads/workflow.go +++ b/pkg/operator/workloads/workflow.go @@ -67,13 +67,13 @@ func Create(ctx *context.Context) (*awfv1.Workflow, error) { var allSpecs []*WorkloadSpec - if ctx.Environment != nil { - pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx) - if err != nil { - return nil, err - } - allSpecs = append(allSpecs, pythonPackageJobSpecs...) + pythonPackageJobSpecs, err := pythonPackageWorkloadSpecs(ctx) + if err != nil { + return nil, err + } + allSpecs = append(allSpecs, pythonPackageJobSpecs...) + if ctx.Environment != nil { dataJobSpecs, err := dataWorkloadSpecs(ctx) if err != nil { return nil, err diff --git a/pkg/workloads/lib/context.py b/pkg/workloads/lib/context.py index b08bef5242..f2e81a815a 100644 --- a/pkg/workloads/lib/context.py +++ b/pkg/workloads/lib/context.py @@ -270,6 +270,26 @@ def get_estimator_impl(self, model_name): self._estimator_impls[estimator_name] = (impl, impl_path) return (impl, impl_path) + def get_request_handler_impl(self, api_name): + api = self.apis[api_name] + + module_prefix = "request_handler" + + try: + impl, impl_path = self.load_module( + module_prefix, api["name"], api["request_handler_impl_key"] + ) + except CortexException as e: + e.wrap("api " + api_name, "request_handler") + raise + + try: + _validate_impl(impl, REQUEST_HANDLER_IMPL_VALIDATION) + except CortexException as e: + e.wrap("api " + api_name, "request_handler " + api["request_handler"]) + raise + return impl + # Mode must be "training" or "evaluation" def get_training_data_parts(self, model_name, mode, part_prefix="part"): training_dataset = self.models[model_name]["dataset"] @@ -662,6 +682,13 @@ def cast_compound_type(value, type_str): ] } +REQUEST_HANDLER_IMPL_VALIDATION = { + "optional": [ + {"name": "pre_inference", "args": ["sample", "metadata"]}, + {"name": "post_inference", "args": ["prediction", "metadata"]}, + ] +} + def _validate_impl(impl, impl_req): for optional_func in impl_req.get("optional", []): diff --git a/pkg/workloads/lib/util.py b/pkg/workloads/lib/util.py index 3ee7ba2266..b4b8f3094d 100644 --- a/pkg/workloads/lib/util.py +++ b/pkg/workloads/lib/util.py @@ -929,3 +929,11 @@ def extract_resource_refs(input): return resources return set() + + +def has_function(impl, fn_name): + fn = getattr(impl, fn_name, None) + if fn is None: + return False + + return callable(fn) diff --git a/pkg/workloads/onnx_serve/api.py b/pkg/workloads/onnx_serve/api.py new file mode 100644 index 0000000000..d21219421c --- /dev/null +++ b/pkg/workloads/onnx_serve/api.py @@ -0,0 +1,237 @@ +# Copyright 2019 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import json +import argparse +import traceback +import time +from flask import Flask, request, jsonify +from flask_api import status +from waitress import serve +import onnxruntime as rt +from lib.storage import S3 +import numpy as np + +import consts +from lib import util, package, Context +from lib.log import get_logger +from lib.exceptions import CortexException, UserRuntimeException, UserException + +logger = get_logger() +logger.propagate = False # prevent double logging (flask modifies root logger) + +app = Flask(__name__) + +onnx_to_np = { + "tensor(float16)": "float16", + "tensor(float)": "float32", + "tensor(double)": "float64", + "tensor(int32)": "int32", + "tensor(int8)": "int8", + "tensor(uint8)": "uint8", + "tensor(int16)": "int16", + "tensor(uint16)": "uint16", + "tensor(int64)": "int64", + "tensor(uint64)": "uint64", + "tensor(bool)": "bool", + "tensor(string)": "string", +} + +local_cache = { + "ctx": None, + "api": None, + "sess": None, + "input_metadata": None, + "output_metadata": None, + "request_handler": None, +} + + +def prediction_failed(sample, reason=None): + message = "prediction failed for sample: {}".format(json.dumps(sample)) + if reason: + message += " ({})".format(reason) + + logger.error(message) + return message, status.HTTP_406_NOT_ACCEPTABLE + + +@app.route("/healthz", methods=["GET"]) +def health(): + return jsonify({"ok": True}) + + +def transform_to_numpy(input_pyobj, input_metadata): + target_dtype = onnx_to_np[input_metadata.type] + target_shape = input_metadata.shape + + for idx, dim in enumerate(target_shape): + if dim is None: + target_shape[idx] = 1 + + if type(input_pyobj) is not np.ndarray: + np_arr = np.array(input_pyobj, dtype=target_dtype) + else: + np_arr = input_pyobj + np_arr = np_arr.reshape(target_shape) + return np_arr + + +def convert_to_onnx_input(sample, input_metadata_list): + sess = local_cache["sess"] + + input_dict = {} + if len(input_metadata_list) == 1: + input_metadata = input_metadata_list[0] + if util.is_dict(sample): + if sample.get(input_metadata.name) is None: + raise ValueError("sample should be a dict containing key: " + input_metadata.name) + input_dict[input_metadata.name] = transform_to_numpy( + sample[input_metadata.name], input_metadata + ) + else: + input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata) + else: + for input_metadata in input_metadata_list: + if not util.is_dict(sample): + expected_keys = [metadata.name for metadata in input_metadata_list] + raise ValueError( + "sample should be a dict containing keys: " + ", ".join(expected_keys) + ) + + if sample.get(input_metadata.name) is None: + raise ValueError("sample should be a dict containing key: " + input_metadata.name) + + input_dict[input_metadata.name] = transform_to_numpy(sample[input_metadata.name], input_metadata) + return input_dict + + +@app.route("/<app_name>/<api_name>", methods=["POST"]) +def predict(app_name, api_name): + try: + payload = request.get_json() + except Exception as e: + return "Malformed JSON", status.HTTP_400_BAD_REQUEST + + sess = local_cache["sess"] + api = local_cache["api"] + request_handler = local_cache.get("request_handler") + input_metadata = local_cache["input_metadata"] + output_metadata = local_cache["output_metadata"] + + response = {} + + if not util.is_dict(payload) or "samples" not in payload: + util.log_pretty(payload, logging_func=logger.error) + return prediction_failed(payload, "top
level `samples` key not found in request") + + logger.info("Predicting " + util.pluralize(len(payload["samples"]), "sample", "samples")) + + predictions = [] + samples = payload["samples"] + if not util.is_list(samples): + util.log_pretty(samples, logging_func=logger.error) + return prediction_failed( + payload, "expected the value of key `samples` to be a list of json objects" + ) + + for i, sample in enumerate(payload["samples"]): + util.log_indent("sample {}".format(i + 1), 2) + try: + util.log_indent("Raw sample:", indent=4) + util.log_pretty(sample, indent=6) + + if request_handler is not None and util.has_function(request_handler, "pre_inference"): + sample = request_handler.pre_inference(sample, input_metadata) + + inference_input = convert_to_onnx_input(sample, input_metadata) + model_outputs = sess.run([], inference_input) + result = [] + for model_output in model_outputs: + if type(model_output) is np.ndarray: + result.append(model_output.tolist()) + else: + result.append(model_output) + + if request_handler is not None and util.has_function(request_handler, "post_inference"): + result = request_handler.post_inference(result, output_metadata) + util.log_indent("Prediction:", indent=4) + util.log_pretty(result, indent=6) + prediction = {"prediction": result} + except CortexException as e: + e.wrap("error", "sample {}".format(i + 1)) + logger.error(str(e)) + logger.exception( + "An error occurred, see `cx logs -v api {}` for more details.".format(api["name"]) + ) + return prediction_failed(sample, str(e)) + except Exception as e: + logger.exception( + "An error occurred, see `cx logs -v api {}` for more details.".format(api["name"]) + ) + return prediction_failed(sample, str(e)) + + predictions.append(prediction) + + response["predictions"] = predictions + response["resource_id"] = api["id"] + + return jsonify(response) + + +def start(args): + ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) + api = ctx.apis_id_map[args.api] + + local_cache["api"] = api + local_cache["ctx"] = ctx + if api.get("request_handler_impl_key") is not None: + package.install_packages(ctx.python_packages, ctx.storage) + local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"]) + + model_cache_path = os.path.join(args.model_dir, args.api) + if not os.path.exists(model_cache_path): + ctx.storage.download_file_external(api["model"], model_cache_path) + + sess = rt.InferenceSession(model_cache_path) + local_cache["sess"] = sess + local_cache["input_metadata"] = sess.get_inputs() + local_cache["output_metadata"] = sess.get_outputs() + serve(app, listen="*:{}".format(args.port)) + logger.info("Serving model") + + +def main(): + parser = argparse.ArgumentParser() + na = parser.add_argument_group("required named arguments") + na.add_argument("--workload-id", required=True, help="Workload ID") + na.add_argument("--port", type=int, required=True, help="Port (on localhost) to use") + na.add_argument( + "--context", + required=True, + help="S3 path to context (e.g. 
s3://bucket/path/to/context.json)", + ) + na.add_argument("--api", required=True, help="Resource id of api to serve") + na.add_argument("--model-dir", required=True, help="Directory to download the model to") + na.add_argument("--cache-dir", required=True, help="Local path for the context cache") + parser.set_defaults(func=start) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/pkg/workloads/onnx_serve/requirements.txt b/pkg/workloads/onnx_serve/requirements.txt new file mode 100644 index 0000000000..4712d29649 --- /dev/null +++ b/pkg/workloads/onnx_serve/requirements.txt @@ -0,0 +1,5 @@ +flask==1.0.2 +flask-api==1.1 +waitress==1.2.1 +onnxruntime==0.4.0 +numpy>=1.15.0 diff --git a/pkg/workloads/tf_api/api.py b/pkg/workloads/tf_api/api.py index 46fb2f8d40..760779966b 100644 --- a/pkg/workloads/tf_api/api.py +++ b/pkg/workloads/tf_api/api.py @@ -244,8 +244,16 @@ def parse_response_proto_raw(response_proto): def run_predict(sample): + request_handler = local_cache.get("request_handler") + + prepared_sample = sample + if request_handler is not None and util.has_function(request_handler, "pre_inference"): + prepared_sample = request_handler.pre_inference( + sample, local_cache["metadata"]["signatureDef"] + ) + if util.is_resource_ref(local_cache["api"]["model"]): - transformed_sample = transform_sample(sample) + transformed_sample = transform_sample(prepared_sample) prediction_request = create_prediction_request(transformed_sample) response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) result = parse_response_proto(response_proto) @@ -260,7 +268,7 @@ def run_predict(sample): result["transformed_sample"] = transformed_sample else: - prediction_request = create_raw_prediction_request(sample) + prediction_request = create_raw_prediction_request(prepared_sample) response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0) result = parse_response_proto_raw(response_proto) util.log_indent("Sample:", indent=4) @@ -268,6 +276,9 @@ def run_predict(sample): util.log_indent("Prediction:", indent=4) util.log_pretty(result, indent=6) + if request_handler is not None and util.has_function(request_handler, "post_inference"): + result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"]) + return result @@ -382,14 +393,22 @@ def predict(deployment_name, api_name): def start(args): ctx = Context(s3_path=args.context, cache_dir=args.cache_dir, workload_id=args.workload_id) - package.install_packages(ctx.python_packages, ctx.storage) api = ctx.apis_id_map[args.api] - local_cache["api"] = api local_cache["ctx"] = ctx + if api.get("request_handler_impl_key") is not None: + local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"]) + + if not util.is_resource_ref(api["model"]): + if api.get("request_handler") is not None: + package.install_packages(ctx.python_packages, ctx.storage) + if not os.path.isdir(args.model_dir): + ctx.storage.download_and_unzip_external(api["model"], args.model_dir) + if util.is_resource_ref(api["model"]): + package.install_packages(ctx.python_packages, ctx.storage) model_name = util.get_resource_ref(api["model"]) model = ctx.models[model_name] estimator = ctx.estimators[model["estimator"]] @@ -427,10 +446,6 @@ def start(args): model["input"]["target_vocab"], None, False ) - else: - if not os.path.isdir(args.model_dir): - ctx.storage.download_and_unzip_external(api["model"], args.model_dir) - channel = grpc.insecure_channel("localhost:" + 
str(args.tf_serve_port)) local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)
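For reference, the ONNX serving path above expects the same request shape as the TensorFlow path: a JSON body with a top-level `samples` key (see `onnx_serve/api.py` and the `irises.json` examples), exposed at `{apis_endpoint}/{deployment_name}/{api_name}`. Below is a minimal client sketch for the `sklearn` API defined in `examples/iris/cortex.yaml`; the base URL is a placeholder, not an actual endpoint, and the field names follow `examples/iris/sklearn/irises.json`.

```python
import requests

# Placeholder URL following the {apis_endpoint}/{deployment_name}/{api_name} format
url = "https://<apis_endpoint>/iris/sklearn"

payload = {
    "samples": [
        {
            "sepal_length": 5.2,
            "sepal_width": 3.6,
            "petal_length": 1.4,
            "petal_width": 0.3,
        }
    ]
}

# pre_inference flattens each sample, onnxruntime runs the model,
# and post_inference maps the predicted class index to its label
response = requests.post(url, json=payload)
response.raise_for_status()
print(response.json()["predictions"])
```

Each element of `predictions` is whatever `post_inference` returned for the corresponding sample, wrapped by the `{"prediction": ...}` structure built in `onnx_serve/api.py`.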