diff --git a/cli/cmd/lib_realtime_apis.go b/cli/cmd/lib_realtime_apis.go index 3e2f5716db..9ff3f1a883 100644 --- a/cli/cmd/lib_realtime_apis.go +++ b/cli/cmd/lib_realtime_apis.go @@ -52,9 +52,14 @@ func realtimeAPITable(realtimeAPI schema.APIResponse, env cliconfig.Environment) out += "\n" + console.Bold("metrics dashboard: ") + *realtimeAPI.DashboardURL + "\n" } - out += "\n" + console.Bold("endpoint: ") + realtimeAPI.Endpoint + "\n" + if realtimeAPI.Spec.Predictor.IsGRPC() { + out += "\n" + console.Bold("insecure endpoint: ") + fmt.Sprintf("%s:%d", realtimeAPI.Endpoint, realtimeAPI.GRPCPorts["insecure"]) + out += "\n" + console.Bold("secure endpoint: ") + fmt.Sprintf("%s:%d", realtimeAPI.Endpoint, realtimeAPI.GRPCPorts["secure"]) + "\n" + } else { + out += "\n" + console.Bold("endpoint: ") + realtimeAPI.Endpoint + "\n" + } - if !(realtimeAPI.Spec.Predictor.Type == userconfig.PythonPredictorType && realtimeAPI.Spec.Predictor.MultiModelReloading == nil) { + if !(realtimeAPI.Spec.Predictor.Type == userconfig.PythonPredictorType && realtimeAPI.Spec.Predictor.MultiModelReloading == nil) && realtimeAPI.Spec.Predictor.ProtobufPath == nil { out += "\n" + describeModelInput(realtimeAPI.Status, realtimeAPI.Spec.Predictor, realtimeAPI.Endpoint) } diff --git a/go.mod b/go.mod index 1a7277481f..7b7579e480 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/docker/docker v0.0.0-00010101000000-000000000000 github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.4.0 // indirect + github.com/emicklei/proto v1.9.0 github.com/fatih/color v1.10.0 github.com/getsentry/sentry-go v0.8.0 github.com/go-ole/go-ole v1.2.4 // indirect diff --git a/go.sum b/go.sum index 86ac05a8f3..3ea03a6ff5 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZi github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153 h1:yUdfgN0XgIJw7foRItutHYUIhlcKzcSf5vDpdhQAKTc= 
github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= +github.com/emicklei/proto v1.9.0 h1:l0QiNT6Qs7Yj0Mb4X6dnWBQer4ebei2BFcgQLbGqUDc= +github.com/emicklei/proto v1.9.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= diff --git a/manager/install.sh b/manager/install.sh index e8f9bf97a3..2206d4a88c 100755 --- a/manager/install.sh +++ b/manager/install.sh @@ -63,12 +63,12 @@ function cluster_up() { setup_grafana echo "✓" - echo -n "○ configuring gpu support (for the nodegroups that may require it)" + echo -n "○ configuring gpu support (for the nodegroups that may require it) " envsubst < manifests/nvidia.yaml | kubectl apply -f - >/dev/null NVIDIA_COM_GPU_VALUE=true envsubst < manifests/prometheus-dcgm-exporter.yaml | kubectl apply -f - >/dev/null echo "✓" - echo -n "○ configuring inf support (for the nodegroups that may require it)" + echo -n "○ configuring inf support (for the nodegroups that may require it) " envsubst < manifests/inferentia.yaml | kubectl apply -f - >/dev/null echo "✓" diff --git a/pkg/cortex/serve/cortex_internal.requirements.txt b/pkg/cortex/serve/cortex_internal.requirements.txt index 52c1e6467a..2567901317 100644 --- a/pkg/cortex/serve/cortex_internal.requirements.txt +++ b/pkg/cortex/serve/cortex_internal.requirements.txt @@ -1,4 +1,4 @@ -grpcio==1.32.0 +grpcio==1.36.0 boto3==1.14.53 datadog==0.39.0 dill>=0.3.1.1 diff --git a/pkg/cortex/serve/cortex_internal/lib/api/api.py b/pkg/cortex/serve/cortex_internal/lib/api/api.py index 
05f5d8b90b..90e79aa59c 100644 --- a/pkg/cortex/serve/cortex_internal/lib/api/api.py +++ b/pkg/cortex/serve/cortex_internal/lib/api/api.py @@ -83,6 +83,21 @@ def post_request_metrics(self, status_code, total_time): ] self.post_metrics(metrics) + def post_status_code_request_metrics(self, status_code): + metrics = [ + self.status_code_metric(self.metric_dimensions(), status_code), + self.status_code_metric(self.metric_dimensions_with_id(), status_code), + ] + self.post_metrics(metrics) + + def post_latency_request_metrics(self, total_time): + total_time_ms = total_time * 1000 + metrics = [ + self.latency_metric(self.metric_dimensions(), total_time_ms), + self.latency_metric(self.metric_dimensions_with_id(), total_time_ms), + ] + self.post_metrics(metrics) + def post_metrics(self, metrics): try: if self.statsd is None: diff --git a/pkg/cortex/serve/cortex_internal/lib/api/predictor.py b/pkg/cortex/serve/cortex_internal/lib/api/predictor.py index 9ff57f0caa..48d540f92a 100644 --- a/pkg/cortex/serve/cortex_internal/lib/api/predictor.py +++ b/pkg/cortex/serve/cortex_internal/lib/api/predictor.py @@ -28,6 +28,7 @@ from cortex_internal.lib.api.validations import ( validate_class_impl, validate_python_predictor_with_models, + validate_predictor_with_grpc, are_models_specified, ) from cortex_internal.lib.client.onnx import ONNXClient @@ -61,12 +62,12 @@ { "name": "__init__", "required_args": ["self", "config"], - "optional_args": ["job_spec", "python_client", "metrics_client"], + "optional_args": ["job_spec", "python_client", "metrics_client", "proto_module_pb2"], }, { "name": "predict", "required_args": ["self"], - "optional_args": ["payload", "query_params", "headers", "batch_id"], + "optional_args": ["payload", "query_params", "headers", "batch_id", "context"], }, ], "optional": [ @@ -88,12 +89,12 @@ { "name": "__init__", "required_args": ["self", "tensorflow_client", "config"], - "optional_args": ["job_spec", "metrics_client"], + "optional_args": ["job_spec", 
"metrics_client", "proto_module_pb2"], }, { "name": "predict", "required_args": ["self"], - "optional_args": ["payload", "query_params", "headers", "batch_id"], + "optional_args": ["payload", "query_params", "headers", "batch_id", "context"], }, ], "optional": [ @@ -111,12 +112,12 @@ { "name": "__init__", "required_args": ["self", "onnx_client", "config"], - "optional_args": ["job_spec", "metrics_client"], + "optional_args": ["job_spec", "metrics_client", "proto_module_pb2"], }, { "name": "predict", "required_args": ["self"], - "optional_args": ["payload", "query_params", "headers", "batch_id"], + "optional_args": ["payload", "query_params", "headers", "batch_id", "context"], }, ], "optional": [ @@ -146,6 +147,7 @@ def __init__(self, api_spec: dict, model_dir: str): self.type = predictor_type_from_api_spec(api_spec) self.path = api_spec["predictor"]["path"] self.config = api_spec["predictor"].get("config", {}) + self.protobuf_path = api_spec["predictor"].get("protobuf_path") self.api_spec = api_spec @@ -234,12 +236,14 @@ def initialize_impl( project_dir: str, client: Union[PythonClient, TensorFlowClient, ONNXClient], metrics_client: DogStatsd, - job_spec: Dict[str, Any] = None, + job_spec: Optional[Dict[str, Any]] = None, + proto_module_pb2: Optional[Any] = None, ): """ Initialize predictor class as provided by the user. job_spec is a dictionary when the "kind" of the API is set to "BatchAPI". Otherwise, it's None. + proto_module_pb2 is a module of the compiled proto when grpc is enabled for the "RealtimeAPI" kind. Otherwise, it's None. Can raise UserRuntimeException/UserException/CortexException. 
""" @@ -257,6 +261,8 @@ def initialize_impl( args["job_spec"] = job_spec if "metrics_client" in constructor_args: args["metrics_client"] = metrics_client + if "proto_module_pb2" in constructor_args: + args["proto_module_pb2"] = proto_module_pb2 # initialize predictor class try: @@ -328,6 +334,7 @@ def class_impl(self, project_dir): try: validate_class_impl(predictor_class, validations) + validate_predictor_with_grpc(predictor_class, self.api_spec) if self.type == PythonPredictorType: validate_python_predictor_with_models(predictor_class, self.api_spec) except Exception as e: diff --git a/pkg/cortex/serve/cortex_internal/lib/api/validations.py b/pkg/cortex/serve/cortex_internal/lib/api/validations.py index e4b10209ac..e1f84dc748 100644 --- a/pkg/cortex/serve/cortex_internal/lib/api/validations.py +++ b/pkg/cortex/serve/cortex_internal/lib/api/validations.py @@ -15,6 +15,7 @@ import inspect from typing import Dict +from cortex_internal.lib import util from cortex_internal.lib.exceptions import UserException from cortex_internal.lib.type import predictor_type_from_api_spec, PythonPredictorType @@ -90,27 +91,28 @@ def validate_required_method_args(impl, func_signature): def validate_python_predictor_with_models(impl, api_spec): - target_class_name = impl.__name__ + if not are_models_specified(api_spec): + return - if are_models_specified(api_spec): - constructor = getattr(impl, "__init__") - constructor_arg_spec = inspect.getfullargspec(constructor) - if "python_client" not in constructor_arg_spec.args: - raise UserException( - f"class {target_class_name}", - f'invalid signature for method "__init__"', - f'"python_client" is a required argument, but was not provided', - f"when the python predictor type is used and models are specified in the api spec, " - f'adding the "python_client" argument is required', - ) + target_class_name = impl.__name__ + constructor = getattr(impl, "__init__") + constructor_arg_spec = inspect.getfullargspec(constructor) + if "python_client" 
not in constructor_arg_spec.args: + raise UserException( + f"class {target_class_name}", + f'invalid signature for method "__init__"', + f'"python_client" is a required argument, but was not provided', + f"when the python predictor type is used and models are specified in the api spec, " + f'adding the "python_client" argument is required', + ) - if getattr(impl, "load_model", None) is None: - raise UserException( - f"class {target_class_name}", - f'required method "load_model" is not defined', - f"when the python predictor type is used and models are specified in the api spec, " - f'adding the "load_model" method is required', - ) + if getattr(impl, "load_model", None) is None: + raise UserException( + f"class {target_class_name}", + f'required method "load_model" is not defined', + f"when the python predictor type is used and models are specified in the api spec, " + f'adding the "load_model" method is required', + ) def are_models_specified(api_spec: Dict) -> bool: @@ -130,3 +132,42 @@ def are_models_specified(api_spec: Dict) -> bool: return False return models is not None + + +def is_grpc_enabled(api_spec: Dict) -> bool: + """ + Checks if the API has the grpc protocol enabled (cortex.yaml). + + Args: + api_spec: API configuration. 
+ """ + return api_spec["predictor"]["protobuf_path"] is not None + + +def validate_predictor_with_grpc(impl, api_spec): + if not is_grpc_enabled(api_spec): + return + + target_class_name = impl.__name__ + constructor = getattr(impl, "__init__") + constructor_arg_spec = inspect.getfullargspec(constructor) + if "proto_module_pb2" not in constructor_arg_spec.args: + raise UserException( + f"class {target_class_name}", + f'invalid signature for method "__init__"', + f'"proto_module_pb2" is a required argument, but was not provided', + f"when a protobuf is specified in the api spec, then that means the grpc protocol is enabled, " + f'which means that adding the "proto_module_pb2" argument is required', + ) + + predictor = getattr(impl, "predict") + predictor_arg_spec = inspect.getfullargspec(predictor) + disallowed_params = list( + set(["query_params", "headers", "batch_id"]).intersection(predictor_arg_spec.args) + ) + if len(disallowed_params) > 0: + raise UserException( + f"class {target_class_name}", + f'invalid signature for method "predict"', + f'{util.string_plural_with_s("argument", len(disallowed_params))} {util.and_list_with_quotes(disallowed_params)} cannot be used when the grpc protocol is enabled', + ) diff --git a/pkg/cortex/serve/cortex_internal/lib/util.py b/pkg/cortex/serve/cortex_internal/lib/util.py index 7f4e1503ff..ec6c09baea 100644 --- a/pkg/cortex/serve/cortex_internal/lib/util.py +++ b/pkg/cortex/serve/cortex_internal/lib/util.py @@ -318,6 +318,47 @@ def is_float_or_int_list(var): return True +def and_list_with_quotes(values: List) -> str: + """ + Converts a list like ["a", "b", "c"] to '"a", "b" and "c"'". 
+ """ + string = "" + + if len(values) == 1: + string = '"' + values[0] + '"' + elif len(values) > 1: + for val in values[:-2]: + string += '"' + val + '", ' + string += '"' + values[-2] + '" and "' + values[-1] + '"' + + return string + + +def or_list_with_quotes(values: List) -> str: + """ + Converts a list like ["a", "b", "c"] to '"a", "b" or "c"'. + """ + string = "" + + if len(values) == 1: + string = '"' + values[0] + '"' + elif len(values) > 1: + for val in values[:-2]: + string += '"' + val + '", ' + string += '"' + values[-2] + '" or "' + values[-1] + '"' + + return string + + +def string_plural_with_s(string: str, count: int) -> str: + """ + Pluralize the word with an "s" character if the count is greater than 1. + """ + if count > 1: + string += "s" + return string + + def render_jinja_template(jinja_template_file: str, context: dict) -> str: from jinja2 import Environment, FileSystemLoader diff --git a/pkg/cortex/serve/init/bootloader.sh b/pkg/cortex/serve/init/bootloader.sh index 4dcbd059b8..f3a2d9d416 100755 --- a/pkg/cortex/serve/init/bootloader.sh +++ b/pkg/cortex/serve/init/bootloader.sh @@ -142,11 +142,25 @@ create_s6_service_from_file() { # prepare webserver if [ "$CORTEX_KIND" = "RealtimeAPI" ]; then + if [ $CORTEX_SERVING_PROTOCOL = "http" ]; then + mkdir /run/servers + fi + + if [ $CORTEX_SERVING_PROTOCOL = "grpc" ]; then + /opt/conda/envs/env/bin/python -m grpc_tools.protoc --proto_path=$CORTEX_PROJECT_DIR --python_out=$CORTEX_PYTHON_PATH --grpc_python_out=$CORTEX_PYTHON_PATH $CORTEX_PROTOBUF_FILE + fi - # prepare uvicorn workers - mkdir /run/uvicorn + # prepare servers for i in $(seq 1 $CORTEX_PROCESSES_PER_REPLICA); do - create_s6_service "uvicorn-$((i-1))" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/server.py /run/uvicorn/proc-$((i-1)).sock" + # prepare uvicorn workers + if [ $CORTEX_SERVING_PROTOCOL = "http" ]; then 
+ create_s6_service "uvicorn-$((i-1))" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/server.py /run/servers/proc-$((i-1)).sock" + fi + + # prepare grpc workers + if [ $CORTEX_SERVING_PROTOCOL = "grpc" ]; then + create_s6_service "grpc-$((i-1))" "cd /mnt/project && $source_env_file_cmd && PYTHONUNBUFFERED=TRUE PYTHONPATH=$PYTHONPATH:$CORTEX_PYTHON_PATH exec /opt/conda/envs/env/bin/python /src/cortex/serve/start/server_grpc.py localhost:$((i-1+20000))" + fi done # generate nginx conf diff --git a/pkg/cortex/serve/nginx.conf.j2 b/pkg/cortex/serve/nginx.conf.j2 index 1f379fceaa..b9c58f9581 100644 --- a/pkg/cortex/serve/nginx.conf.j2 +++ b/pkg/cortex/serve/nginx.conf.j2 @@ -56,22 +56,57 @@ http { # to distribute load aio threads=pool; - # how much time an inference can take - proxy_read_timeout 3600s; - - upstream uvicorn { + upstream servers { # load balancing policy least_conn; {% for i in range(CORTEX_PROCESSES_PER_REPLICA | int) %} - server unix:/run/uvicorn/proc-{{ i }}.sock; + {% if CORTEX_SERVING_PROTOCOL == 'http' %} + server unix:/run/servers/proc-{{ i }}.sock; + {% endif %} + {% if CORTEX_SERVING_PROTOCOL == 'grpc' %} + server localhost:{{ i + 20000 }}; + {% endif %} {% endfor %} } + {% if CORTEX_SERVING_PROTOCOL == 'grpc' %} + server { + listen {{ CORTEX_SERVING_PORT | int }} http2; + default_type application/grpc; + underscores_in_headers on; + + grpc_read_timeout 3600s; + + location /nginx_status { + stub_status on; + allow 127.0.0.1; + deny all; + } + + location / { + limit_conn inflights {{ CORTEX_MAX_REPLICA_CONCURRENCY | int }}; + + grpc_set_header Upgrade $http_upgrade; + grpc_set_header Connection "Upgrade"; + grpc_set_header Connection keep-alive; + grpc_set_header Host $host:$server_port; + grpc_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + grpc_set_header X-Forwarded-Proto $scheme; + + grpc_pass grpc://servers; + } + 
} + {% endif %} + + {% if CORTEX_SERVING_PROTOCOL == 'http' %} server { listen {{ CORTEX_SERVING_PORT | int }}; underscores_in_headers on; + # how much time an inference can take + proxy_read_timeout 3600s; + location /nginx_status { stub_status on; allow 127.0.0.1; @@ -119,7 +154,8 @@ http { proxy_redirect off; proxy_buffering off; - proxy_pass http://uvicorn; + proxy_pass http://servers; } } + {% endif %} } diff --git a/pkg/cortex/serve/poll/readiness.sh b/pkg/cortex/serve/poll/readiness.sh index 248d5239a1..5f41121210 100644 --- a/pkg/cortex/serve/poll/readiness.sh +++ b/pkg/cortex/serve/poll/readiness.sh @@ -16,7 +16,7 @@ while true; do procs_ready="$(ls /mnt/workspace/proc-*-ready.txt 2>/dev/null | wc -l)" - if [ "$CORTEX_PROCESSES_PER_REPLICA" = "$procs_ready" ] && curl --silent "localhost:${CORTEX_SERVING_PORT}/nginx_status" >/dev/null; then + if [ "$CORTEX_PROCESSES_PER_REPLICA" = "$procs_ready" ] && curl --silent "localhost:$CORTEX_SERVING_PORT/nginx_status" --output /dev/null; then touch /mnt/workspace/api_readiness.txt break fi diff --git a/pkg/cortex/serve/serve.requirements.txt b/pkg/cortex/serve/serve.requirements.txt index 0dca76ba5b..eb80cbd963 100644 --- a/pkg/cortex/serve/serve.requirements.txt +++ b/pkg/cortex/serve/serve.requirements.txt @@ -1,4 +1,6 @@ -grpcio==1.32.0 +grpcio==1.36.0 +grpcio-tools==1.36.0 +grpcio-reflection==1.36.0 python-multipart==0.0.5 requests==2.24.0 uvicorn==0.11.8 diff --git a/pkg/cortex/serve/start/server_grpc.py b/pkg/cortex/serve/start/server_grpc.py new file mode 100644 index 0000000000..0d5a7bd3c2 --- /dev/null +++ b/pkg/cortex/serve/start/server_grpc.py @@ -0,0 +1,249 @@ +# Copyright 2021 Cortex Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import json +import time +import uuid +import signal +import threading +import traceback +import pathlib +import importlib +import inspect +from typing import Callable, Dict, Any +from concurrent import futures + +import grpc +from grpc_reflection.v1alpha import reflection + +from cortex_internal.lib.api import get_api +from cortex_internal.lib.concurrency import FileLock, LockedFile +from cortex_internal.lib.exceptions import UserRuntimeException +from cortex_internal.lib.log import configure_logger +from cortex_internal.lib.metrics import MetricsClient +from cortex_internal.lib.telemetry import capture_exception, get_default_tags, init_sentry + +NANOSECONDS_IN_SECOND = 1e9 + + +class ThreadPoolExecutorWithRequestMonitor: + def __init__(self, post_latency_metrics_fn: Callable[[int, float], None], *args, **kwargs): + self._post_latency_metrics_fn = post_latency_metrics_fn + self._thread_pool_executor = futures.ThreadPoolExecutor(*args, **kwargs) + + def submit(self, fn, *args, **kwargs): + request_id = uuid.uuid1() + file_id = f"/mnt/requests/{request_id}" + open(file_id, "a").close() + + start_time = time.time() + + def wrapper_fn(*args, **kwargs): + try: + result = fn(*args, **kwargs) + except: + raise + finally: + try: + os.remove(file_id) + except FileNotFoundError: + pass + self._post_latency_metrics_fn(time.time() - start_time) + + return result + + self._thread_pool_executor.submit(wrapper_fn, *args, **kwargs) + + def map(self, *args, **kwargs): + return self._thread_pool_executor.map(*args, **kwargs) + + def shutdown(self, 
*args, **kwargs): + return self._thread_pool_executor.shutdown(*args, **kwargs) + + +def get_service_name_from_module(module_proto_pb2_grpc) -> Any: + classes = inspect.getmembers(module_proto_pb2_grpc, inspect.isclass) + for class_name, _ in classes: + if class_name.endswith("Servicer"): + return class_name[: -len("Servicer")] + # this line will never be reached because we're guaranteed to have one servicer class in the module + + +def get_servicer_from_module(module_proto_pb2_grpc) -> Any: + classes = inspect.getmembers(module_proto_pb2_grpc, inspect.isclass) + for class_name, module_class in classes: + if class_name.endswith("Servicer"): + return module_class + # this line will never be reached because we're guaranteed to have one servicer class in the module + + +def get_servicer_to_server_from_module(module_proto_pb2_grpc) -> Any: + functions = inspect.getmembers(module_proto_pb2_grpc, inspect.isfunction) + for function_name, function in functions: + if function_name.endswith("_to_server"): + return function + # this line will never be reached because we're guaranteed to have one servicer adder in the module + + +def build_predict_kwargs(predict_fn_args, payload, context) -> Dict[str, Any]: + predict_kwargs = {} + if "payload" in predict_fn_args: + predict_kwargs["payload"] = payload + if "context" in predict_fn_args: + predict_kwargs["context"] = context + return predict_kwargs + + +def init(): + project_dir = os.environ["CORTEX_PROJECT_DIR"] + spec_path = os.environ["CORTEX_API_SPEC"] + + model_dir = os.getenv("CORTEX_MODEL_DIR") + cache_dir = os.getenv("CORTEX_CACHE_DIR") + region = os.getenv("AWS_REGION") + + tf_serving_port = os.getenv("CORTEX_TF_BASE_SERVING_PORT", "9000") + tf_serving_host = os.getenv("CORTEX_TF_SERVING_HOST", "localhost") + + has_multiple_servers = os.getenv("CORTEX_MULTIPLE_TF_SERVERS") + if has_multiple_servers: + with LockedFile("/run/used_ports.json", "r+") as f: + used_ports = json.load(f) + for port in used_ports.keys(): + if not 
used_ports[port]: + tf_serving_port = port + used_ports[port] = True + break + f.seek(0) + json.dump(used_ports, f) + f.truncate() + + api = get_api(spec_path, model_dir, cache_dir, region) + + config: Dict[str, Any] = { + "api": None, + "client": None, + "predictor_impl": None, + "module_proto_pb2_grpc": None, + } + + proto_without_ext = pathlib.Path(api.predictor.protobuf_path).stem + module_proto_pb2 = importlib.import_module(proto_without_ext + "_pb2") + module_proto_pb2_grpc = importlib.import_module(proto_without_ext + "_pb2_grpc") + + client = api.predictor.initialize_client( + tf_serving_host=tf_serving_host, tf_serving_port=tf_serving_port + ) + + with FileLock("/run/init_stagger.lock"): + logger.info("loading the predictor from {}".format(api.predictor.path)) + metrics_client = MetricsClient(api.statsd) + predictor_impl = api.predictor.initialize_impl( + project_dir=project_dir, + client=client, + metrics_client=metrics_client, + proto_module_pb2=module_proto_pb2, + ) + + # crons only stop if an unhandled exception occurs + def check_if_crons_have_failed(): + while True: + for cron in api.predictor.crons: + if not cron.is_alive(): + os.kill(os.getpid(), signal.SIGQUIT) + time.sleep(1) + + threading.Thread(target=check_if_crons_have_failed, daemon=True).start() + + ServicerClass = get_servicer_from_module(module_proto_pb2_grpc) + + class PredictorServicer(ServicerClass): + def __init__(self, predict_fn_args, predictor_impl, api): + self.predict_fn_args = predict_fn_args + self.predictor_impl = predictor_impl + self.api = api + + def Predict(self, payload, context): + try: + kwargs = build_predict_kwargs(self.predict_fn_args, payload, context) + response = self.predictor_impl.predict(**kwargs) + self.api.post_status_code_request_metrics(200) + except Exception: + logger.error(traceback.format_exc()) + self.api.post_status_code_request_metrics(500) + context.abort(grpc.StatusCode.INTERNAL, "internal server error") + return response + + config["api"] = api + 
config["client"] = client + config["predictor_impl"] = predictor_impl + config["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args + config["module_proto_pb2"] = module_proto_pb2 + config["module_proto_pb2_grpc"] = module_proto_pb2_grpc + config["predictor_servicer"] = PredictorServicer + + return config + + +def main(): + address = sys.argv[1] + threads_per_process = int(os.environ["CORTEX_THREADS_PER_PROCESS"]) + + try: + config = init() + except Exception as err: + if not isinstance(err, UserRuntimeException): + capture_exception(err) + logger.exception("failed to start api") + sys.exit(1) + + module_proto_pb2 = config["module_proto_pb2"] + module_proto_pb2_grpc = config["module_proto_pb2_grpc"] + PredictorServicer = config["predictor_servicer"] + + api = config["api"] + predictor_impl = config["predictor_impl"] + predict_fn_args = config["predict_fn_args"] + + server = grpc.server( + ThreadPoolExecutorWithRequestMonitor( + post_latency_metrics_fn=api.post_latency_request_metrics, + max_workers=threads_per_process, + ) + ) + + add_PredictorServicer_to_server = get_servicer_to_server_from_module(module_proto_pb2_grpc) + add_PredictorServicer_to_server(PredictorServicer(predict_fn_args, predictor_impl, api), server) + + service_name = get_service_name_from_module(module_proto_pb2_grpc) + SERVICE_NAMES = ( + module_proto_pb2.DESCRIPTOR.services_by_name[service_name].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + + server.add_insecure_port(address) + server.start() + + time.sleep(5.0) + open(f"/mnt/workspace/proc-{os.getpid()}-ready.txt", "a").close() + server.wait_for_termination() + + +if __name__ == "__main__": + init_sentry(tags=get_default_tags()) + logger = configure_logger("cortex", os.environ["CORTEX_LOG_CONFIG_FILE"]) + main() diff --git a/pkg/lib/configreader/errors.go b/pkg/lib/configreader/errors.go index b06abda4b6..557dc44028 100644 --- a/pkg/lib/configreader/errors.go +++ 
b/pkg/lib/configreader/errors.go @@ -35,10 +35,12 @@ const ( ErrLeadingWhitespace = "configreader.leading_whitespace" ErrTrailingWhitespace = "configreader.trailing_whitespace" ErrAlphaNumericDashUnderscore = "configreader.alpha_numeric_dash_underscore" + ErrAlphaNumericDotUnderscore = "configreader.alpha_numeric_dot_underscore" ErrAlphaNumericDashDotUnderscore = "configreader.alpha_numeric_dash_dot_underscore" ErrInvalidAWSTag = "configreader.invalid_aws_tag" ErrInvalidDockerImage = "configreader.invalid_docker_image" ErrMustHavePrefix = "configreader.must_have_prefix" + ErrMustHaveSuffix = "configreader.must_have_suffix" ErrCantHavePrefix = "configreader.cant_have_prefix" ErrInvalidInterface = "configreader.invalid_interface" ErrInvalidFloat64 = "configreader.invalid_float64" @@ -138,6 +140,13 @@ func ErrorAlphaNumericDashUnderscore(provided string) error { }) } +func ErrorAlphaNumericDotUnderscore(provided string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrAlphaNumericDotUnderscore, + Message: fmt.Sprintf("%s must contain only letters, numbers, underscores and periods", s.UserStr(provided)), + }) +} + func ErrorAlphaNumericDashDotUnderscore(provided string) error { return errors.WithStack(&errors.Error{ Kind: ErrAlphaNumericDashDotUnderscore, @@ -166,6 +175,13 @@ func ErrorMustHavePrefix(provided string, prefix string) error { }) } +func ErrorMustHaveSuffix(provided string, suffix string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrMustHaveSuffix, + Message: fmt.Sprintf("%s must end with %s", s.UserStr(provided), s.UserStr(suffix)), + }) +} + func ErrorCantHavePrefix(provided string, prefix string) error { return errors.WithStack(&errors.Error{ Kind: ErrCantHavePrefix, diff --git a/pkg/lib/configreader/string.go b/pkg/lib/configreader/string.go index f5de1d07c5..85140037bd 100644 --- a/pkg/lib/configreader/string.go +++ b/pkg/lib/configreader/string.go @@ -40,6 +40,7 @@ type StringValidation struct { DisallowedValues []string 
CantBeSpecifiedErrStr *string Prefix string + Suffix string InvalidPrefixes []string MaxLength int MinLength int @@ -49,6 +50,7 @@ type StringValidation struct { AlphaNumericDashDotUnderscore bool AlphaNumericDashUnderscoreOrEmpty bool AlphaNumericDashUnderscore bool + AlphaNumericDotUnderscore bool AWSTag bool DNS1035 bool DNS1123 bool @@ -262,6 +264,12 @@ func ValidateStringVal(val string, v *StringValidation) error { } } + if v.Suffix != "" { + if !strings.HasSuffix(val, v.Suffix) { + return ErrorMustHaveSuffix(val, v.Suffix) + } + } + for _, invalidPrefix := range v.InvalidPrefixes { if strings.HasPrefix(val, invalidPrefix) { return ErrorCantHavePrefix(val, invalidPrefix) @@ -292,6 +300,12 @@ func ValidateStringVal(val string, v *StringValidation) error { } } + if v.AlphaNumericDotUnderscore { + if !regex.IsAlphaNumericDotUnderscore(val) { + return ErrorAlphaNumericDotUnderscore(val) + } + } + if v.AlphaNumericDashUnderscoreOrEmpty { if !regex.IsAlphaNumericDashUnderscore(val) && val != "" { return ErrorAlphaNumericDashUnderscore(val) diff --git a/pkg/lib/configreader/string_ptr.go b/pkg/lib/configreader/string_ptr.go index 823720b4b8..9dfb0b31bd 100644 --- a/pkg/lib/configreader/string_ptr.go +++ b/pkg/lib/configreader/string_ptr.go @@ -33,6 +33,7 @@ type StringPtrValidation struct { DisallowedValues []string CantBeSpecifiedErrStr *string Prefix string + Suffix string InvalidPrefixes []string MaxLength int MinLength int @@ -41,6 +42,7 @@ type StringPtrValidation struct { AlphaNumericDashDotUnderscoreOrEmpty bool AlphaNumericDashDotUnderscore bool AlphaNumericDashUnderscore bool + AlphaNumericDotUnderscore bool AWSTag bool DNS1035 bool DNS1123 bool @@ -59,6 +61,7 @@ func makeStringValValidation(v *StringPtrValidation) *StringValidation { AllowedValues: v.AllowedValues, DisallowedValues: v.DisallowedValues, Prefix: v.Prefix, + Suffix: v.Suffix, InvalidPrefixes: v.InvalidPrefixes, MaxLength: v.MaxLength, MinLength: v.MinLength, @@ -67,6 +70,7 @@ func 
makeStringValValidation(v *StringPtrValidation) *StringValidation { AlphaNumericDashDotUnderscoreOrEmpty: v.AlphaNumericDashDotUnderscoreOrEmpty, AlphaNumericDashDotUnderscore: v.AlphaNumericDashDotUnderscore, AlphaNumericDashUnderscore: v.AlphaNumericDashUnderscore, + AlphaNumericDotUnderscore: v.AlphaNumericDotUnderscore, AWSTag: v.AWSTag, DNS1035: v.DNS1035, DNS1123: v.DNS1123, diff --git a/pkg/lib/k8s/service.go b/pkg/lib/k8s/service.go index 598ec4c9c8..66d86a8027 100644 --- a/pkg/lib/k8s/service.go +++ b/pkg/lib/k8s/service.go @@ -34,6 +34,7 @@ var _serviceTypeMeta = kmeta.TypeMeta{ type ServiceSpec struct { Name string + PortName string Port int32 TargetPort int32 ServiceType kcore.ServiceType @@ -56,7 +57,7 @@ func Service(spec *ServiceSpec) *kcore.Service { Ports: []kcore.ServicePort{ { Protocol: kcore.ProtocolTCP, - Name: "http", + Name: spec.PortName, Port: spec.Port, TargetPort: intstr.IntOrString{ IntVal: spec.TargetPort, diff --git a/pkg/lib/regex/regex.go b/pkg/lib/regex/regex.go index f7572b545e..7ee75f8677 100644 --- a/pkg/lib/regex/regex.go +++ b/pkg/lib/regex/regex.go @@ -61,6 +61,12 @@ func IsAlphaNumericDashUnderscore(s string) bool { return _alphaNumericDashUnderscoreRegex.MatchString(s) } +var _alphaNumericDotUnderscoreRegex = regexp.MustCompile(`^[a-zA-Z0-9_\.]+$`) + +func IsAlphaNumericDotUnderscore(s string) bool { + return _alphaNumericDotUnderscoreRegex.MatchString(s) +} + // used the evaluated form of // https://github.com/docker/distribution/blob/3150937b9f2b1b5b096b2634d0e7c44d4a0f89fb/reference/regexp.go#L68-L70 var _dockerValidImage = regexp.MustCompile( diff --git a/pkg/operator/operator/k8s.go b/pkg/operator/operator/k8s.go index 1c352c9804..b0720f06c1 100644 --- a/pkg/operator/operator/k8s.go +++ b/pkg/operator/operator/k8s.go @@ -885,12 +885,26 @@ func getEnvVars(api *spec.API, container string) []kcore.EnvVar { ) if api.Kind == userconfig.RealtimeAPIKind { + servingProtocol := "http" + if api.Predictor != nil && 
api.Predictor.IsGRPC() { + servingProtocol = "grpc" + } envVars = append(envVars, kcore.EnvVar{ Name: "CORTEX_API_SPEC", Value: aws.S3Path(config.CoreConfig.Bucket, api.PredictorKey), }, + kcore.EnvVar{ + Name: "CORTEX_SERVING_PROTOCOL", + Value: servingProtocol, + }, ) + if servingProtocol == "grpc" { + envVars = append(envVars, kcore.EnvVar{ + Name: "CORTEX_PROTOBUF_FILE", + Value: path.Join(_emptyDirMountPath, "project", *api.Predictor.ProtobufPath), + }) + } } else { envVars = append(envVars, kcore.EnvVar{ @@ -1445,5 +1459,10 @@ func APIEndpoint(api *spec.API) (string, error) { } baseAPIEndpoint = strings.Replace(baseAPIEndpoint, "https://", "http://", 1) + if api.Predictor != nil && api.Predictor.IsGRPC() { + baseAPIEndpoint = strings.Replace(baseAPIEndpoint, "http://", "", 1) + return baseAPIEndpoint, nil + } + return urls.Join(baseAPIEndpoint, *api.Networking.Endpoint), nil } diff --git a/pkg/operator/resources/asyncapi/k8s_specs.go b/pkg/operator/resources/asyncapi/k8s_specs.go index cfa69ac274..9400082aae 100644 --- a/pkg/operator/resources/asyncapi/k8s_specs.go +++ b/pkg/operator/resources/asyncapi/k8s_specs.go @@ -81,6 +81,7 @@ func gatewayDeploymentSpec(api spec.API, prevDeployment *kapps.Deployment, queue func gatewayServiceSpec(api spec.API) kcore.Service { return *k8s.Service(&k8s.ServiceSpec{ Name: operator.K8sName(api.Name), + PortName: "http", Port: operator.DefaultPortInt32, TargetPort: operator.DefaultPortInt32, Annotations: api.ToK8sAnnotations(), diff --git a/pkg/operator/resources/errors.go b/pkg/operator/resources/errors.go index 3df7a2143d..6a829d3145 100644 --- a/pkg/operator/resources/errors.go +++ b/pkg/operator/resources/errors.go @@ -27,14 +27,16 @@ import ( ) const ( - ErrOperationIsOnlySupportedForKind = "resources.operation_is_only_supported_for_kind" - ErrAPINotDeployed = "resources.api_not_deployed" - ErrAPIIDNotFound = "resources.api_id_not_found" - ErrCannotChangeTypeOfDeployedAPI = "resources.cannot_change_kind_of_deployed_api" 
- ErrNoAvailableNodeComputeLimit = "resources.no_available_node_compute_limit" - ErrJobIDRequired = "resources.job_id_required" - ErrRealtimeAPIUsedByTrafficSplitter = "resources.realtime_api_used_by_traffic_splitter" - ErrAPIsNotDeployed = "resources.apis_not_deployed" + ErrOperationIsOnlySupportedForKind = "resources.operation_is_only_supported_for_kind" + ErrAPINotDeployed = "resources.api_not_deployed" + ErrAPIIDNotFound = "resources.api_id_not_found" + ErrCannotChangeTypeOfDeployedAPI = "resources.cannot_change_kind_of_deployed_api" + ErrCannotChangeProtocolWhenUsedByTrafficSplitter = "resources.cannot_change_protocol_when_used_by_traffic_splitter" + ErrNoAvailableNodeComputeLimit = "resources.no_available_node_compute_limit" + ErrJobIDRequired = "resources.job_id_required" + ErrRealtimeAPIUsedByTrafficSplitter = "resources.realtime_api_used_by_traffic_splitter" + ErrAPIsNotDeployed = "resources.apis_not_deployed" + ErrGRPCNotSupportedForTrafficSplitter = "resources.grpc_not_supported_for_traffic_splitter" ) func ErrorOperationIsOnlySupportedForKind(resource operator.DeployedResource, supportedKind userconfig.Kind, supportedKinds ...userconfig.Kind) error { @@ -72,6 +74,13 @@ func ErrorCannotChangeKindOfDeployedAPI(name string, newKind, prevKind userconfi }) } +func ErrorCannotChangeProtocolWhenUsedByTrafficSplitter(protocolChangingAPIName string, trafficSplitters []string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrCannotChangeProtocolWhenUsedByTrafficSplitter, + Message: fmt.Sprintf("cannot change the serving protocol (http -> grpc) of api %s because it is used by the following %s: %s", protocolChangingAPIName, strings.PluralS("TrafficSplitter", len(trafficSplitters)), strings.StrsSentence(trafficSplitters, "")), + }) +} + func ErrorNoAvailableNodeComputeLimit(resource string, reqStr string, maxStr string) error { message := fmt.Sprintf("no instances can satisfy the requested %s quantity - requested %s %s but instances only have %s %s 
available", resource, reqStr, resource, maxStr, resource) if maxStr == "0" { @@ -100,3 +109,10 @@ func ErrorAPIsNotDeployed(notDeployedAPIs []string) error { Message: message, }) } + +func ErrorGRPCNotSupportedForTrafficSplitter(grpcAPIName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrGRPCNotSupportedForTrafficSplitter, + Message: fmt.Sprintf("api %s (of kind %s) is served using the grpc protocol and therefore, it cannot be used for the %s kind", grpcAPIName, userconfig.RealtimeAPIKind, userconfig.TrafficSplitterKind), + }) +} diff --git a/pkg/operator/resources/realtimeapi/api.go b/pkg/operator/resources/realtimeapi/api.go index 9d6f0fd65b..4a9b678c1a 100644 --- a/pkg/operator/resources/realtimeapi/api.go +++ b/pkg/operator/resources/realtimeapi/api.go @@ -253,12 +253,19 @@ func GetAPIByName(deployedResource *operator.DeployedResource) ([]schema.APIResp dashboardURL := pointer.String(getDashboardURL(api.Name)) + grpcPorts := map[string]int64{} + if api.Predictor != nil && api.Predictor.IsGRPC() { + grpcPorts["insecure"] = 80 + grpcPorts["secure"] = 443 + } + return []schema.APIResponse{ { Spec: *api, Status: status, Metrics: metrics, Endpoint: apiEndpoint, + GRPCPorts: grpcPorts, DashboardURL: dashboardURL, }, }, nil diff --git a/pkg/operator/resources/realtimeapi/k8s_specs.go b/pkg/operator/resources/realtimeapi/k8s_specs.go index ac64ed03ae..3afd6ee9e6 100644 --- a/pkg/operator/resources/realtimeapi/k8s_specs.go +++ b/pkg/operator/resources/realtimeapi/k8s_specs.go @@ -46,19 +46,25 @@ func tensorflowAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.D containers, volumes := operator.TensorFlowPredictorContainers(api) containers = append(containers, operator.RequestMonitorContainer(api)) + servingProtocol := "http" + if api.Predictor != nil && api.Predictor.IsGRPC() { + servingProtocol = "grpc" + } + return k8s.Deployment(&k8s.DeploymentSpec{ Name: operator.K8sName(api.Name), Replicas: getRequestedReplicasFromDeployment(api, 
prevDeployment), MaxSurge: pointer.String(api.UpdateStrategy.MaxSurge), MaxUnavailable: pointer.String(api.UpdateStrategy.MaxUnavailable), Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "apiID": api.ID, - "specID": api.SpecID, - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: api.ToK8sAnnotations(), Selector: map[string]string{ @@ -67,11 +73,12 @@ func tensorflowAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.D }, PodSpec: k8s.PodSpec{ Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: map[string]string{ "traffic.sidecar.istio.io/excludeOutboundIPRanges": "0.0.0.0/0", @@ -101,19 +108,25 @@ func pythonAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deplo containers, volumes := operator.PythonPredictorContainers(api) containers = append(containers, operator.RequestMonitorContainer(api)) + servingProtocol := "http" + if api.Predictor != nil && api.Predictor.IsGRPC() { + servingProtocol = "grpc" + } + return k8s.Deployment(&k8s.DeploymentSpec{ Name: operator.K8sName(api.Name), Replicas: getRequestedReplicasFromDeployment(api, prevDeployment), MaxSurge: pointer.String(api.UpdateStrategy.MaxSurge), MaxUnavailable: pointer.String(api.UpdateStrategy.MaxUnavailable), Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "apiID": api.ID, - "specID": 
api.SpecID, - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: api.ToK8sAnnotations(), Selector: map[string]string{ @@ -122,11 +135,12 @@ func pythonAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deplo }, PodSpec: k8s.PodSpec{ Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: map[string]string{ "traffic.sidecar.istio.io/excludeOutboundIPRanges": "0.0.0.0/0", @@ -156,19 +170,25 @@ func onnxAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deploym containers, volumes := operator.ONNXPredictorContainers(api) containers = append(containers, operator.RequestMonitorContainer(api)) + servingProtocol := "http" + if api.Predictor != nil && api.Predictor.IsGRPC() { + servingProtocol = "grpc" + } + return k8s.Deployment(&k8s.DeploymentSpec{ Name: operator.K8sName(api.Name), Replicas: getRequestedReplicasFromDeployment(api, prevDeployment), MaxSurge: pointer.String(api.UpdateStrategy.MaxSurge), MaxUnavailable: pointer.String(api.UpdateStrategy.MaxUnavailable), Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "apiID": api.ID, - "specID": api.SpecID, - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + 
"predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: api.ToK8sAnnotations(), Selector: map[string]string{ @@ -177,11 +197,12 @@ func onnxAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deploym }, PodSpec: k8s.PodSpec{ Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Annotations: map[string]string{ "traffic.sidecar.istio.io/excludeOutboundIPRanges": "0.0.0.0/0", @@ -207,15 +228,21 @@ func onnxAPISpec(api *spec.API, prevDeployment *kapps.Deployment) *kapps.Deploym } func serviceSpec(api *spec.API) *kcore.Service { + servingProtocol := "http" + if api.Predictor != nil && api.Predictor.IsGRPC() { + servingProtocol = "grpc" + } return k8s.Service(&k8s.ServiceSpec{ Name: operator.K8sName(api.Name), + PortName: servingProtocol, Port: operator.DefaultPortInt32, TargetPort: operator.DefaultPortInt32, Annotations: api.ToK8sAnnotations(), Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "servingProtocol": servingProtocol, + "cortex.dev/api": "true", }, Selector: map[string]string{ "apiName": api.Name, @@ -225,6 +252,14 @@ func serviceSpec(api *spec.API) *kcore.Service { } func virtualServiceSpec(api *spec.API) *istioclientnetworking.VirtualService { + servingProtocol := "http" + rewritePath := pointer.String("predict") + + if api.Predictor != nil && api.Predictor.IsGRPC() { + servingProtocol = "grpc" + rewritePath = nil + } + return k8s.VirtualService(&k8s.VirtualServiceSpec{ Name: operator.K8sName(api.Name), Gateways: []string{"apis-gateway"}, @@ -234,16 +269,17 @@ 
func virtualServiceSpec(api *spec.API) *istioclientnetworking.VirtualService { Port: uint32(operator.DefaultPortInt32), }}, ExactPath: api.Networking.Endpoint, - Rewrite: pointer.String("predict"), + Rewrite: rewritePath, Annotations: api.ToK8sAnnotations(), Labels: map[string]string{ - "apiName": api.Name, - "apiKind": api.Kind.String(), - "apiID": api.ID, - "specID": api.SpecID, - "deploymentID": api.DeploymentID, - "predictorID": api.PredictorID, - "cortex.dev/api": "true", + "apiName": api.Name, + "apiKind": api.Kind.String(), + "servingProtocol": servingProtocol, + "apiID": api.ID, + "specID": api.SpecID, + "deploymentID": api.DeploymentID, + "predictorID": api.PredictorID, + "cortex.dev/api": "true", }, }) } diff --git a/pkg/operator/resources/resources.go b/pkg/operator/resources/resources.go index fd13afaf4f..2d462d9170 100644 --- a/pkg/operator/resources/resources.go +++ b/pkg/operator/resources/resources.go @@ -118,6 +118,7 @@ func Deploy(projectBytes []byte, configFileName string, configBytes []byte, forc results := make([]schema.DeployResult, 0, len(apiConfigs)) for i := range apiConfigs { apiConfig := apiConfigs[i] + api, msg, err := UpdateAPI(&apiConfig, projectID, force) result := schema.DeployResult{ @@ -145,6 +146,40 @@ func UpdateAPI(apiConfig *userconfig.API, projectID string, force bool) (*schema return nil, "", ErrorCannotChangeKindOfDeployedAPI(apiConfig.Name, apiConfig.Kind, deployedResource.Kind) } + if deployedResource != nil { + prevAPISpec, err := operator.DownloadAPISpec(deployedResource.Name, deployedResource.ID()) + if err != nil { + return nil, "", err + } + + if deployedResource.Kind == userconfig.RealtimeAPIKind && prevAPISpec != nil && !prevAPISpec.Predictor.IsGRPC() && apiConfig.Predictor.IsGRPC() { + realtimeAPIName := deployedResource.Name + + virtualServices, err := config.K8s.ListVirtualServicesByLabel("apiKind", userconfig.TrafficSplitterKind.String()) + if err != nil { + return nil, "", err + } + + trafficSplitterList, err 
:= trafficsplitter.GetAllAPIs(virtualServices) + if err != nil { + return nil, "", err + } + + dependentTrafficSplitters := []string{} + for _, trafficSplitter := range trafficSplitterList { + for _, api := range trafficSplitter.Spec.APIs { + if realtimeAPIName == api.Name { + dependentTrafficSplitters = append(dependentTrafficSplitters, api.Name) + } + } + } + + if len(dependentTrafficSplitters) > 0 { + return nil, "", ErrorCannotChangeProtocolWhenUsedByTrafficSplitter(realtimeAPIName, dependentTrafficSplitters) + } + } + } + telemetry.Event("operator.deploy", apiConfig.TelemetryEvent()) var api *spec.API @@ -252,6 +287,33 @@ func patchAPI(apiConfig *userconfig.API, force bool) (*spec.API, string, error) return nil, "", err } + if deployedResource.Kind == userconfig.RealtimeAPIKind && !prevAPISpec.Predictor.IsGRPC() && apiConfig.Predictor.IsGRPC() { + realtimeAPIName := deployedResource.Name + + virtualServices, err := config.K8s.ListVirtualServicesByLabel("apiKind", userconfig.TrafficSplitterKind.String()) + if err != nil { + return nil, "", err + } + + trafficSplitterList, err := trafficsplitter.GetAllAPIs(virtualServices) + if err != nil { + return nil, "", err + } + + dependentTrafficSplitters := []string{} + for _, trafficSplitter := range trafficSplitterList { + for _, api := range trafficSplitter.Spec.APIs { + if realtimeAPIName == api.Name { + dependentTrafficSplitters = append(dependentTrafficSplitters, api.Name) + } + } + } + + if len(dependentTrafficSplitters) > 0 { + return nil, "", ErrorCannotChangeProtocolWhenUsedByTrafficSplitter(realtimeAPIName, dependentTrafficSplitters) + } + } + switch deployedResource.Kind { case userconfig.RealtimeAPIKind: return realtimeapi.UpdateAPI(apiConfig, prevAPISpec.ProjectID, force) diff --git a/pkg/operator/resources/validations.go b/pkg/operator/resources/validations.go index 658a8462f1..3caaad20d6 100644 --- a/pkg/operator/resources/validations.go +++ b/pkg/operator/resources/validations.go @@ -31,6 +31,8 @@ import 
( "github.com/cortexlabs/cortex/pkg/types/userconfig" istioclientnetworking "istio.io/client-go/pkg/apis/networking/v1beta1" kresource "k8s.io/apimachinery/pkg/api/resource" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + klabels "k8s.io/apimachinery/pkg/labels" ) type ProjectFiles struct { @@ -77,12 +79,26 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles) if err != nil { return err } - - deployedRealtimeAPIs := strset.New() - + httpDeployedRealtimeAPIs := strset.New() for _, virtualService := range virtualServices { if virtualService.Labels["apiKind"] == userconfig.RealtimeAPIKind.String() { - deployedRealtimeAPIs.Add(virtualService.Labels["apiName"]) + httpDeployedRealtimeAPIs.Add(virtualService.Labels["apiName"]) + } + } + + virtualServicesForGrpc, err := config.K8s.ListVirtualServices(&v1.ListOptions{ + LabelSelector: klabels.SelectorFromSet( + map[string]string{ + "servingProtocol": "grpc", + }).String(), + }) + if err != nil { + return err + } + grpcDeployedRealtimeAPIs := strset.New() + for _, virtualService := range virtualServicesForGrpc { + if virtualService.Labels["apiKind"] == userconfig.RealtimeAPIKind.String() { + grpcDeployedRealtimeAPIs.Add(virtualService.Labels["apiName"]) } } @@ -106,7 +122,7 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFiles spec.ProjectFiles) if err := spec.ValidateTrafficSplitter(api); err != nil { return errors.Wrap(err, api.Identify()) } - if err := checkIfAPIExists(api.APIs, realtimeAPIs, deployedRealtimeAPIs); err != nil { + if err := checkIfAPIExists(api.APIs, realtimeAPIs, httpDeployedRealtimeAPIs, grpcDeployedRealtimeAPIs); err != nil { return errors.Wrap(err, api.Identify()) } if err := validateEndpointCollisions(api, virtualServices); err != nil { @@ -290,13 +306,19 @@ func ExclusiveFilterAPIsByKind(apis []userconfig.API, kindsToExclude ...userconf return fileredAPIs } -// checkIfAPIExists checks if referenced apis in trafficsplitter are either defined in yaml or already 
deployed -func checkIfAPIExists(trafficSplitterAPIs []*userconfig.TrafficSplit, apis []userconfig.API, deployedRealtimeAPIs strset.Set) error { +// checkIfAPIExists checks if referenced apis in trafficsplitter are either defined in yaml or already deployed. +// Also prevents traffic splitting apis that use grpc. +func checkIfAPIExists(trafficSplitterAPIs []*userconfig.TrafficSplit, apis []userconfig.API, httpDeployedRealtimeAPIs strset.Set, grpcDeployedRealtimeAPIs strset.Set) error { var missingAPIs []string // check if apis named in trafficsplitter are either defined in same yaml or already deployed for _, trafficSplitAPI := range trafficSplitterAPIs { - //check if already deployed - deployed := deployedRealtimeAPIs.Has(trafficSplitAPI.Name) + // don't allow apis that use grpc + if grpcDeployedRealtimeAPIs.Has(trafficSplitAPI.Name) { + return ErrorGRPCNotSupportedForTrafficSplitter(trafficSplitAPI.Name) + } + + // check if already deployed + deployed := httpDeployedRealtimeAPIs.Has(trafficSplitAPI.Name) // check defined apis for _, definedAPI := range apis { diff --git a/pkg/operator/schema/schema.go b/pkg/operator/schema/schema.go index 4f5682d55e..15a87d46d0 100644 --- a/pkg/operator/schema/schema.go +++ b/pkg/operator/schema/schema.go @@ -54,6 +54,7 @@ type APIResponse struct { Status *status.Status `json:"status,omitempty"` Metrics *metrics.Metrics `json:"metrics,omitempty"` Endpoint string `json:"endpoint"` + GRPCPorts map[string]int64 `json:"grpc_ports,omitempty"` DashboardURL *string `json:"dashboard_url,omitempty"` BatchJobStatuses []status.BatchJobStatus `json:"batch_job_statuses,omitempty"` TaskJobStatuses []status.TaskJobStatus `json:"task_job_statuses,omitempty"` diff --git a/pkg/types/spec/errors.go b/pkg/types/spec/errors.go index 9458f4c138..af43f6dbda 100644 --- a/pkg/types/spec/errors.go +++ b/pkg/types/spec/errors.go @@ -73,6 +73,13 @@ const ( ErrDuplicateModelNames = "spec.duplicate_model_names" ErrReservedModelName = "spec.reserved_model_name" 
+ ErrProtoNumServicesExceeded = "spec.proto_num_services_exceeded" + ErrProtoNumServiceMethodsExceeded = "spec.proto_num_service_methods_exceeded" + ErrProtoInvalidServiceMethod = "spec.proto_invalid_service_method" + ErrProtoMissingPackageName = "spec.proto_missing_package_name" + ErrProtoInvalidPackageName = "spec.proto_invalid_package_name" + ErrProtoInvalidNetworkingEndpoint = "spec.proto_invalid_networking_endpoint" + ErrFieldMustBeDefinedForPredictorType = "spec.field_must_be_defined_for_predictor_type" ErrFieldNotSupportedByPredictorType = "spec.field_not_supported_by_predictor_type" ErrPredictorTypeNotSupportedForKind = "spec.predictor_type_not_supported_by_kind" @@ -479,6 +486,48 @@ func ErrorReservedModelName(reservedModel string) error { }) } +func ErrorProtoNumServicesExceeded(requested int) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoNumServicesExceeded, + Message: fmt.Sprintf("cannot have more than one service defined; there are currently %d services defined", requested), + }) +} + +func ErrorProtoNumServiceMethodsExceeded(requested int, serviceName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoNumServiceMethodsExceeded, + Message: fmt.Sprintf("cannot have more than one service method for service %s; there are currently %d service methods defined", serviceName, requested), + }) +} + +func ErrorProtoInvalidServiceMethod(requestedServiceMethodName, allowedServiceMethodName, serviceName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoInvalidServiceMethod, + Message: fmt.Sprintf("found %s service method in service %s; service %s can only have a single service method defined that must be called %s", requestedServiceMethodName, serviceName, serviceName, allowedServiceMethodName), + }) +} + +func ErrorProtoMissingPackageName(allowedPackageName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoMissingPackageName, + Message: fmt.Sprintf("your protobuf definition 
must have the %s package defined", allowedPackageName), + }) +} + +func ErrorProtoInvalidPackageName(requestedPackageName, allowedPackageName string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoInvalidPackageName, + Message: fmt.Sprintf("found invalid package %s; your package must be named %s", requestedPackageName, allowedPackageName), + }) +} + +func ErrorProtoInvalidNetworkingEndpoint(allowedValue string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrProtoInvalidNetworkingEndpoint, + Message: fmt.Sprintf("because of the protobuf definition from section %s and field %s, the only permitted value is %s", userconfig.PredictorKey, userconfig.ProtobufPathKey, allowedValue), + }) +} + func ErrorFieldMustBeDefinedForPredictorType(fieldKey string, predictorType userconfig.PredictorType) error { return errors.WithStack(&errors.Error{ Kind: ErrFieldMustBeDefinedForPredictorType, diff --git a/pkg/types/spec/validations.go b/pkg/types/spec/validations.go index 064e9fe52f..41f6d0aa7d 100644 --- a/pkg/types/spec/validations.go +++ b/pkg/types/spec/validations.go @@ -17,6 +17,7 @@ limitations under the License. 
package spec import ( + "bytes" "context" "fmt" "strings" @@ -40,6 +41,7 @@ import ( "github.com/cortexlabs/cortex/pkg/lib/urls" "github.com/cortexlabs/cortex/pkg/types/userconfig" dockertypes "github.com/docker/docker/api/types" + pbparser "github.com/emicklei/proto" kresource "k8s.io/apimachinery/pkg/api/resource" ) @@ -165,6 +167,16 @@ func predictorValidation() *cr.StructFieldValidation { StructField: "Path", StringValidation: &cr.StringValidation{ Required: true, + Suffix: ".py", + }, + }, + { + StructField: "ProtobufPath", + StringPtrValidation: &cr.StringPtrValidation{ + Default: nil, + AllowExplicitNull: true, + AlphaNumericDotUnderscore: true, + Suffix: ".proto", }, }, { @@ -737,7 +749,7 @@ func ValidateAPI( models = &[]CuratedModelResource{} } - if api.Networking.Endpoint == nil { + if api.Networking.Endpoint == nil && (api.Predictor == nil || (api.Predictor != nil && api.Predictor.ProtobufPath == nil)) { api.Networking.Endpoint = pointer.String("/" + api.Name) } @@ -748,6 +760,9 @@ func ValidateAPI( } default: if err := validatePredictor(api, models, projectFiles, awsClient, k8sClient); err != nil { + if errors.GetKind(err) == ErrProtoInvalidNetworkingEndpoint { + return errors.Wrap(err, userconfig.NetworkingKey, userconfig.EndpointKey) + } return errors.Wrap(err, userconfig.PredictorKey) } } @@ -843,6 +858,27 @@ func validatePredictor( k8sClient *k8s.Client, ) error { predictor := api.Predictor + + if !projectFiles.HasFile(predictor.Path) { + return errors.Wrap(files.ErrorFileDoesNotExist(predictor.Path), userconfig.PathKey) + } + + if predictor.PythonPath != nil { + if err := validatePythonPath(predictor, projectFiles); err != nil { + return errors.Wrap(err, userconfig.PythonPathKey) + } + } + + if predictor.IsGRPC() { + if api.Kind != userconfig.RealtimeAPIKind { + return ErrorKeyIsNotSupportedForKind(userconfig.ProtobufPathKey, api.Kind) + } + + if err := validateProtobufPath(api, projectFiles); err != nil { + return err + } + } + if err := 
validateMultiModelsFields(api); err != nil { return err } @@ -893,16 +929,6 @@ func validatePredictor( } } - if !projectFiles.HasFile(predictor.Path) { - return errors.Wrap(files.ErrorFileDoesNotExist(predictor.Path), userconfig.PathKey) - } - - if predictor.PythonPath != nil { - if err := validatePythonPath(predictor, projectFiles); err != nil { - return errors.Wrap(err, userconfig.PythonPathKey) - } - } - return nil } @@ -1300,6 +1326,85 @@ func validateONNXModelFilePath(modelPath string, awsClient *aws.Client) error { return nil } +func validateProtobufPath(api *userconfig.API, projectFiles ProjectFiles) error { + apiName := api.Name + protobufPath := *api.Predictor.ProtobufPath + + if !projectFiles.HasFile(protobufPath) { + return errors.Wrap(files.ErrorFileDoesNotExist(protobufPath), userconfig.ProtobufPathKey) + } + protoBytes, err := projectFiles.GetFile(protobufPath) + if err != nil { + return errors.Wrap(err, userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + protoReader := bytes.NewReader(protoBytes) + parser := pbparser.NewParser(protoReader) + proto, err := parser.Parse() + if err != nil { + return errors.Wrap(errors.WithStack(err), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + var packageName string + var serviceName string + var serviceMethodName string = "Predict" + var detectedMethodName string + + numServices := 0 + numRPCs := 0 + pbparser.Walk(proto, + pbparser.WithPackage(func(pkg *pbparser.Package) { + packageName = pkg.Name + }), + pbparser.WithService(func(service *pbparser.Service) { + numServices++ + serviceName = service.Name + for _, elem := range service.Elements { + if s, ok := elem.(*pbparser.RPC); ok { + numRPCs++ + detectedMethodName = s.Name + } + } + }), + ) + + if numServices > 1 { + return errors.Wrap(ErrorProtoNumServicesExceeded(numServices), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + if numRPCs > 1 { + return errors.Wrap(ErrorProtoNumServiceMethodsExceeded(numRPCs, 
serviceName), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + if serviceMethodName != detectedMethodName { + return errors.Wrap(ErrorProtoInvalidServiceMethod(detectedMethodName, serviceMethodName, serviceName), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + var requiredPackageName string + requiredPackageName = strings.ReplaceAll(apiName, "-", "_") + + if api.Predictor.ServerSideBatching != nil { + return ErrorConflictingFields(userconfig.ProtobufPathKey, userconfig.ServerSideBatchingKey) + } + + if packageName == "" { + return errors.Wrap(ErrorProtoMissingPackageName(requiredPackageName), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + if packageName != requiredPackageName { + return errors.Wrap(ErrorProtoInvalidPackageName(packageName, requiredPackageName), userconfig.ProtobufPathKey, *api.Predictor.ProtobufPath) + } + + requiredEndpoint := "/" + requiredPackageName + "." + serviceName + "/" + serviceMethodName + if api.Networking.Endpoint == nil { + api.Networking.Endpoint = pointer.String(requiredEndpoint) + } + if *api.Networking.Endpoint != requiredEndpoint { + return ErrorProtoInvalidNetworkingEndpoint(requiredEndpoint) + } + + return nil +} + func validatePythonPath(predictor *userconfig.Predictor, projectFiles ProjectFiles) error { if !projectFiles.HasDir(*predictor.PythonPath) { return ErrorPythonPathNotFound(*predictor.PythonPath) diff --git a/pkg/types/userconfig/api.go b/pkg/types/userconfig/api.go index 6c83536b86..74f0ab01ae 100644 --- a/pkg/types/userconfig/api.go +++ b/pkg/types/userconfig/api.go @@ -44,8 +44,9 @@ type API struct { } type Predictor struct { - Type PredictorType `json:"type" yaml:"type"` - Path string `json:"path" yaml:"path"` + Type PredictorType `json:"type" yaml:"type"` + Path string `json:"path" yaml:"path"` + ProtobufPath *string `json:"protobuf_path" yaml:"protobuf_path"` MultiModelReloading *MultiModels `json:"multi_model_reloading" yaml:"multi_model_reloading"` Models 
*MultiModels `json:"models" yaml:"models"` @@ -210,6 +211,10 @@ func (api *API) applyTaskDefaultDockerPaths(usesGPU, usesInf bool) { } } +func (predictor *Predictor) IsGRPC() bool { + return predictor.ProtobufPath != nil +} + func IdentifyAPI(filePath string, name string, kind Kind, index int) string { str := "" @@ -421,6 +426,10 @@ func (predictor *Predictor) UserStr() string { sb.WriteString(fmt.Sprintf("%s: %s\n", TypeKey, predictor.Type)) sb.WriteString(fmt.Sprintf("%s: %s\n", PathKey, predictor.Path)) + if predictor.ProtobufPath != nil { + sb.WriteString(fmt.Sprintf("%s: %s\n", ProtobufPathKey, *predictor.ProtobufPath)) + } + if predictor.Models != nil { sb.WriteString(fmt.Sprintf("%s:\n", ModelsKey)) sb.WriteString(s.Indent(predictor.Models.UserStr(), " ")) @@ -668,6 +677,9 @@ func (api *API) TelemetryEvent() map[string]interface{} { event["predictor.log_level"] = api.Predictor.LogLevel + if api.Predictor.ProtobufPath != nil { + event["predictor.protobuf_path._is_defined"] = true + } if api.Predictor.PythonPath != nil { event["predictor.python_path._is_defined"] = true } diff --git a/pkg/types/userconfig/config_key.go b/pkg/types/userconfig/config_key.go index 7ba49e631d..8a6d2a7912 100644 --- a/pkg/types/userconfig/config_key.go +++ b/pkg/types/userconfig/config_key.go @@ -35,6 +35,7 @@ const ( // Predictor TypeKey = "type" PathKey = "path" + ProtobufPathKey = "protobuf_path" ServerSideBatchingKey = "server_side_batching" PythonPathKey = "python_path" ImageKey = "image" diff --git a/test/apis/grpc/iris-classifier-sklearn/README.md b/test/apis/grpc/iris-classifier-sklearn/README.md new file mode 100644 index 0000000000..de87b54b09 --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/README.md @@ -0,0 +1,32 @@ +## gRPC client + +#### Step 1 + +```bash +pip install grpcio grpcio-tools +python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. 
iris_classifier.proto +``` + +#### Step 2 + +```python +import cortex +import iris_classifier_pb2 +import iris_classifier_pb2_grpc + +sample = iris_classifier_pb2.Sample( + sepal_length=5.2, + sepal_width=3.6, + petal_length=1.4, + petal_width=0.3 +) + +cx = cortex.client("aws") +api = cx.get_api("iris-classifier") +grpc_endpoint = api["endpoint"] + ":" + str(api["grpc_ports"]["insecure"]) +channel = grpc.insecure_channel(grpc_endpoint) +stub = iris_classifier_pb2_grpc.PredictorStub(channel) + +response = stub.Predict(sample) +print("prediction:", response.classification) +``` diff --git a/test/apis/grpc/iris-classifier-sklearn/cortex.yaml b/test/apis/grpc/iris-classifier-sklearn/cortex.yaml new file mode 100644 index 0000000000..dd3e956ea2 --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/cortex.yaml @@ -0,0 +1,12 @@ +- name: iris-classifier + kind: RealtimeAPI + predictor: + type: python + path: predictor.py + protobuf_path: iris_classifier.proto + config: + bucket: cortex-examples + key: sklearn/iris-classifier/model.pkl + compute: + cpu: 0.2 + mem: 200M diff --git a/test/apis/grpc/iris-classifier-sklearn/expectations.yaml b/test/apis/grpc/iris-classifier-sklearn/expectations.yaml new file mode 100644 index 0000000000..d7a1ec85e1 --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/expectations.yaml @@ -0,0 +1,16 @@ +grpc: + proto_module_pb2: "test_proto/iris_classifier_pb2.py" + proto_module_pb2_grpc: "test_proto/iris_classifier_pb2_grpc.py" + stub_service_name: "Predictor" + input_spec: + class_name: "Sample" + input: + sepal_length: 5.2 + sepal_width: 3.6 + petal_length: 1.4 + petal_width: 0.3 + output_spec: + class_name: "Response" + stream: false + output: + classification: "setosa" diff --git a/test/apis/grpc/iris-classifier-sklearn/iris_classifier.proto b/test/apis/grpc/iris-classifier-sklearn/iris_classifier.proto new file mode 100644 index 0000000000..d38660512a --- /dev/null +++ 
b/test/apis/grpc/iris-classifier-sklearn/iris_classifier.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package iris_classifier; + +service Predictor { + rpc Predict (Sample) returns (Response); +} + +message Sample { + float sepal_length = 1; + float sepal_width = 2; + float petal_length = 3; + float petal_width = 4; +} + +message Response { + string classification = 1; +} diff --git a/test/apis/grpc/iris-classifier-sklearn/predictor.py b/test/apis/grpc/iris-classifier-sklearn/predictor.py new file mode 100644 index 0000000000..c15cc7470c --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/predictor.py @@ -0,0 +1,30 @@ +import os +import boto3 +from botocore import UNSIGNED +from botocore.client import Config +import pickle + +labels = ["setosa", "versicolor", "virginica"] + + +class PythonPredictor: + def __init__(self, config, proto_module_pb2): + if os.environ.get("AWS_ACCESS_KEY_ID"): + s3 = boto3.client("s3") # client will use your credentials if available + else: + s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) # anonymous client + + s3.download_file(config["bucket"], config["key"], "/tmp/model.pkl") + self.model = pickle.load(open("/tmp/model.pkl", "rb")) + self.proto_module_pb2 = proto_module_pb2 + + def predict(self, payload): + measurements = [ + payload.sepal_length, + payload.sepal_width, + payload.petal_length, + payload.petal_width, + ] + + label_id = self.model.predict([measurements])[0] + return self.proto_module_pb2.Response(classification=labels[label_id]) diff --git a/test/apis/grpc/iris-classifier-sklearn/requirements.txt b/test/apis/grpc/iris-classifier-sklearn/requirements.txt new file mode 100644 index 0000000000..bbc213cf3e --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/requirements.txt @@ -0,0 +1,2 @@ +boto3 +scikit-learn==0.21.3 diff --git a/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2.py b/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2.py new file mode 
100644 index 0000000000..ddd9fdd1e6 --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: iris_classifier.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="iris_classifier.proto", + package="iris_classifier", + syntax="proto3", + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x15iris_classifier.proto\x12\x0firis_classifier"^\n\x06Sample\x12\x14\n\x0csepal_length\x18\x01 \x01(\x02\x12\x13\n\x0bsepal_width\x18\x02 \x01(\x02\x12\x14\n\x0cpetal_length\x18\x03 \x01(\x02\x12\x13\n\x0bpetal_width\x18\x04 \x01(\x02""\n\x08Response\x12\x16\n\x0e\x63lassification\x18\x01 \x01(\t2J\n\tPredictor\x12=\n\x07Predict\x12\x17.iris_classifier.Sample\x1a\x19.iris_classifier.Responseb\x06proto3', +) + + +_SAMPLE = _descriptor.Descriptor( + name="Sample", + full_name="iris_classifier.Sample", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="sepal_length", + full_name="iris_classifier.Sample.sepal_length", + index=0, + number=1, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="sepal_width", + full_name="iris_classifier.Sample.sepal_width", + index=1, + number=2, 
+ type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="petal_length", + full_name="iris_classifier.Sample.petal_length", + index=2, + number=3, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name="petal_width", + full_name="iris_classifier.Sample.petal_width", + index=3, + number=4, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=42, + serialized_end=136, +) + + +_RESPONSE = _descriptor.Descriptor( + name="Response", + full_name="iris_classifier.Response", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="classification", + full_name="iris_classifier.Response.classification", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + 
create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=138, + serialized_end=172, +) + +DESCRIPTOR.message_types_by_name["Sample"] = _SAMPLE +DESCRIPTOR.message_types_by_name["Response"] = _RESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Sample = _reflection.GeneratedProtocolMessageType( + "Sample", + (_message.Message,), + { + "DESCRIPTOR": _SAMPLE, + "__module__": "iris_classifier_pb2" + # @@protoc_insertion_point(class_scope:iris_classifier.Sample) + }, +) +_sym_db.RegisterMessage(Sample) + +Response = _reflection.GeneratedProtocolMessageType( + "Response", + (_message.Message,), + { + "DESCRIPTOR": _RESPONSE, + "__module__": "iris_classifier_pb2" + # @@protoc_insertion_point(class_scope:iris_classifier.Response) + }, +) +_sym_db.RegisterMessage(Response) + + +_PREDICTOR = _descriptor.ServiceDescriptor( + name="Predictor", + full_name="iris_classifier.Predictor", + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=174, + serialized_end=248, + methods=[ + _descriptor.MethodDescriptor( + name="Predict", + full_name="iris_classifier.Predictor.Predict", + index=0, + containing_service=None, + input_type=_SAMPLE, + output_type=_RESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + ], +) +_sym_db.RegisterServiceDescriptor(_PREDICTOR) + +DESCRIPTOR.services_by_name["Predictor"] = _PREDICTOR + +# @@protoc_insertion_point(module_scope) diff --git a/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2_grpc.py b/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2_grpc.py new file mode 100644 index 0000000000..a3c3a550a6 --- /dev/null +++ b/test/apis/grpc/iris-classifier-sklearn/test_proto/iris_classifier_pb2_grpc.py @@ -0,0 +1,79 @@ +# Generated by 
the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import iris_classifier_pb2 as iris__classifier__pb2 + + +class PredictorStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Predict = channel.unary_unary( + "/iris_classifier.Predictor/Predict", + request_serializer=iris__classifier__pb2.Sample.SerializeToString, + response_deserializer=iris__classifier__pb2.Response.FromString, + ) + + +class PredictorServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_PredictorServicer_to_server(servicer, server): + rpc_method_handlers = { + "Predict": grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=iris__classifier__pb2.Sample.FromString, + response_serializer=iris__classifier__pb2.Response.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "iris_classifier.Predictor", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class Predictor(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Predict( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/iris_classifier.Predictor/Predict", + iris__classifier__pb2.Sample.SerializeToString, + iris__classifier__pb2.Response.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/test/apis/grpc/prime-number-generator/README.md b/test/apis/grpc/prime-number-generator/README.md new file mode 100644 index 0000000000..e98c11d49b --- /dev/null +++ b/test/apis/grpc/prime-number-generator/README.md @@ -0,0 +1,26 @@ +## Prime number generator + +#### Step 1 + +```bash +pip install grpc +python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. 
generator.proto +``` + +#### Step 2 + +```python +import cortex +import generator_pb2 +import generator_pb2_grpc + +cx = cortex.client("aws") +api = cx.get_api("prime-generator") +grpc_endpoint = api["endpoint"] + ":" + str(api["grpc_ports"]["insecure"]) + +prime_numbers_to_generate = 5 + +channel = grpc.insecure_channel(grpc_endpoint) +for r in stub.Predict(generator_pb2.Input(prime_numbers_to_generate=prime_numbers_to_generate)): + print(r) +``` diff --git a/test/apis/grpc/prime-number-generator/cortex.yaml b/test/apis/grpc/prime-number-generator/cortex.yaml new file mode 100644 index 0000000000..640172759c --- /dev/null +++ b/test/apis/grpc/prime-number-generator/cortex.yaml @@ -0,0 +1,9 @@ +- name: prime-generator + kind: RealtimeAPI + predictor: + type: python + path: generator.py + protobuf_path: generator.proto + compute: + cpu: 200m + mem: 200Mi diff --git a/test/apis/grpc/prime-number-generator/expectations.yaml b/test/apis/grpc/prime-number-generator/expectations.yaml new file mode 100644 index 0000000000..15265c06a2 --- /dev/null +++ b/test/apis/grpc/prime-number-generator/expectations.yaml @@ -0,0 +1,17 @@ +grpc: + proto_module_pb2: "test_proto/generator_pb2.py" + proto_module_pb2_grpc: "test_proto/generator_pb2_grpc.py" + stub_service_name: "Generator" + input_spec: + class_name: "Input" + input: + prime_numbers_to_generate: 5 + output_spec: + class_name: "Output" + stream: true + output: + - prime_number: 2 + - prime_number: 3 + - prime_number: 5 + - prime_number: 7 + - prime_number: 11 diff --git a/test/apis/grpc/prime-number-generator/generator.proto b/test/apis/grpc/prime-number-generator/generator.proto new file mode 100644 index 0000000000..6be4f808ac --- /dev/null +++ b/test/apis/grpc/prime-number-generator/generator.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package prime_generator; + +service Generator { + rpc Predict (Input) returns (stream Output); +} + +message Input { + int64 prime_numbers_to_generate = 1; +} + +message Output { + int64 
prime_number = 1; +} diff --git a/test/apis/grpc/prime-number-generator/generator.py b/test/apis/grpc/prime-number-generator/generator.py new file mode 100644 index 0000000000..d8a803f3ea --- /dev/null +++ b/test/apis/grpc/prime-number-generator/generator.py @@ -0,0 +1,28 @@ +from collections import defaultdict + + +class PythonPredictor: + def __init__(self, config, proto_module_pb2): + self.proto_module_pb2 = proto_module_pb2 + + def predict(self, payload): + prime_numbers_to_generate: int = payload.prime_numbers_to_generate + for prime_number in self.gen_primes(): + if prime_numbers_to_generate == 0: + break + prime_numbers_to_generate -= 1 + yield self.proto_module_pb2.Output(prime_number=prime_number) + + def gen_primes(self, limit=None): + """Sieve of Eratosthenes""" + not_prime = defaultdict(list) + num = 2 + while limit is None or num <= limit: + if num in not_prime: + for prime in not_prime[num]: + not_prime[prime + num].append(prime) + del not_prime[num] + else: + yield num + not_prime[num * num] = [num] + num += 1 diff --git a/test/apis/grpc/prime-number-generator/test_proto/generator_pb2.py b/test/apis/grpc/prime-number-generator/test_proto/generator_pb2.py new file mode 100644 index 0000000000..bf458f4c80 --- /dev/null +++ b/test/apis/grpc/prime-number-generator/test_proto/generator_pb2.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: generator.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name="generator.proto", + package="prime_generator", + syntax="proto3", + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x0fgenerator.proto\x12\x0fprime_generator"*\n\x05Input\x12!\n\x19prime_numbers_to_generate\x18\x01 \x01(\x03"\x1e\n\x06Output\x12\x14\n\x0cprime_number\x18\x01 \x01(\x03\x32I\n\tGenerator\x12<\n\x07Predict\x12\x16.prime_generator.Input\x1a\x17.prime_generator.Output0\x01\x62\x06proto3', +) + + +_INPUT = _descriptor.Descriptor( + name="Input", + full_name="prime_generator.Input", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name="prime_numbers_to_generate", + full_name="prime_generator.Input.prime_numbers_to_generate", + index=0, + number=1, + type=3, + cpp_type=2, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=36, + serialized_end=78, +) + + +_OUTPUT = _descriptor.Descriptor( + name="Output", + full_name="prime_generator.Output", + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + 
name="prime_number", + full_name="prime_generator.Output.prime_number", + index=0, + number=1, + type=3, + cpp_type=2, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=80, + serialized_end=110, +) + +DESCRIPTOR.message_types_by_name["Input"] = _INPUT +DESCRIPTOR.message_types_by_name["Output"] = _OUTPUT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Input = _reflection.GeneratedProtocolMessageType( + "Input", + (_message.Message,), + { + "DESCRIPTOR": _INPUT, + "__module__": "generator_pb2" + # @@protoc_insertion_point(class_scope:prime_generator.Input) + }, +) +_sym_db.RegisterMessage(Input) + +Output = _reflection.GeneratedProtocolMessageType( + "Output", + (_message.Message,), + { + "DESCRIPTOR": _OUTPUT, + "__module__": "generator_pb2" + # @@protoc_insertion_point(class_scope:prime_generator.Output) + }, +) +_sym_db.RegisterMessage(Output) + + +_GENERATOR = _descriptor.ServiceDescriptor( + name="Generator", + full_name="prime_generator.Generator", + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=112, + serialized_end=185, + methods=[ + _descriptor.MethodDescriptor( + name="Predict", + full_name="prime_generator.Generator.Predict", + index=0, + containing_service=None, + input_type=_INPUT, + output_type=_OUTPUT, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + ], +) +_sym_db.RegisterServiceDescriptor(_GENERATOR) + +DESCRIPTOR.services_by_name["Generator"] = _GENERATOR + +# @@protoc_insertion_point(module_scope) diff --git 
a/test/apis/grpc/prime-number-generator/test_proto/generator_pb2_grpc.py b/test/apis/grpc/prime-number-generator/test_proto/generator_pb2_grpc.py new file mode 100644 index 0000000000..1eb11b6083 --- /dev/null +++ b/test/apis/grpc/prime-number-generator/test_proto/generator_pb2_grpc.py @@ -0,0 +1,79 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import generator_pb2 as generator__pb2 + + +class GeneratorStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Predict = channel.unary_stream( + "/prime_generator.Generator/Predict", + request_serializer=generator__pb2.Input.SerializeToString, + response_deserializer=generator__pb2.Output.FromString, + ) + + +class GeneratorServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_GeneratorServicer_to_server(servicer, server): + rpc_method_handlers = { + "Predict": grpc.unary_stream_rpc_method_handler( + servicer.Predict, + request_deserializer=generator__pb2.Input.FromString, + response_serializer=generator__pb2.Output.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "prime_generator.Generator", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
+class Generator(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Predict( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/prime_generator.Generator/Predict", + generator__pb2.Input.SerializeToString, + generator__pb2.Output.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/test/e2e/e2e/expectations.py b/test/e2e/e2e/expectations.py index 62477b4540..bf334f4ce4 100644 --- a/test/e2e/e2e/expectations.py +++ b/test/e2e/e2e/expectations.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pathlib import types from typing import Dict, Any @@ -79,6 +80,38 @@ def validate_response_expectations(expectations: Dict[str, Any]): except Exception as e: raise ExpectationsValidationException("json_schema is invalid") from e + if "grpc" in expectations: + grpc = expectations["grpc"] + required_fields = [ + "proto_module_pb2", + "proto_module_pb2_grpc", + "stub_service_name", + "input_spec", + "output_spec", + ] + for required_field in required_fields: + if required_field not in grpc: + raise ExpectationsValidationException(f"missing grpc.{required_field} field") + + p1 = str(pathlib.Path(grpc["proto_module_pb2"]).parent) + p2 = str(pathlib.Path(grpc["proto_module_pb2_grpc"]).parent) + if p1 != p2: + raise ExpectationsValidationException( + "the parent directories of proto_module_pb2 and proto_module_pb2_grpc don't match" + ) + + input_spec = grpc["input_spec"] + if "class_name" not in input_spec: + raise ExpectationsValidationException("missing grpc.input_spec.class_name field") + if "input" not in input_spec: + raise 
ExpectationsValidationException("missing grpc.input_spec.input field") + + output_spec = grpc["output_spec"] + if "class_name" not in output_spec: + raise ExpectationsValidationException("missing grpc.output_spec.class_name field") + if "stream" not in output_spec: + raise ExpectationsValidationException("missing grpc.output_spec.stream field") + def _get_response_content(response: requests.Response, content_type: str) -> str: attr = CONTENT_TO_ATTR.get(content_type, "content") diff --git a/test/e2e/e2e/tests.py b/test/e2e/e2e/tests.py index b1c6582c23..de8373202d 100644 --- a/test/e2e/e2e/tests.py +++ b/test/e2e/e2e/tests.py @@ -30,6 +30,7 @@ apis_ready, endpoint_ready, request_prediction, + generate_grpc, job_done, request_batch_prediction, request_task, @@ -65,17 +66,36 @@ def test_realtime_api( client=client, api_names=[api_name], timeout=timeout ), f"apis {api_name} not ready" - with open(str(api_dir / "sample.json")) as f: - payload = json.load(f) + if not expectations or "grpc" not in expectations: + with open(str(api_dir / "sample.json")) as f: + payload = json.load(f) + response = request_prediction(client, api_name, payload) - response = request_prediction(client, api_name, payload) + assert ( + response.status_code == HTTPStatus.OK + ), f"status code: got {response.status_code}, expected {HTTPStatus.OK}" - assert ( - response.status_code == HTTPStatus.OK - ), f"status code: got {response.status_code}, expected {HTTPStatus.OK}" + if expectations and "response" in expectations: + assert_response_expectations(response, expectations["response"]) - if expectations and "response" in expectations: - assert_response_expectations(response, expectations["response"]) + if expectations and "grpc" in expectations: + stub, input_sample, output_values, output_type, is_output_stream = generate_grpc( + client, api_name, api_dir, expectations["grpc"] + ) + if is_output_stream: + for response, output_val in zip(stub.Predict(input_sample), output_values): + assert ( + 
type(response) == output_type + ), f"didn't receive response of type {str(output_type)}, but received {str(type(response))}" + assert response == output_val, f"received {response} instead of {output_val}" + else: + response = stub.Predict(input_sample) + assert ( + type(stub.Predict(input_sample)) == output_type + ), f"didn't receive response of type {str(output_type)}, but received {str(type(response))}" + assert ( + response == output_values[0] + ), f"received {response} instead of {output_values[0]}" finally: delete_apis(client, [api_name]) @@ -107,7 +127,7 @@ def test_batch_api( payload = json.load(f) response = None - for i in range(retry_attempts + 1): + for _ in range(retry_attempts + 1): response = request_batch_prediction( client, api_name, diff --git a/test/e2e/e2e/utils.py b/test/e2e/e2e/utils.py index 5c07925b92..1eb7e69374 100644 --- a/test/e2e/e2e/utils.py +++ b/test/e2e/e2e/utils.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import sys import time +import importlib +import pathlib from http import HTTPStatus -from typing import List, Optional, Dict, Union, Callable +from typing import Any, List, Optional, Tuple, Union, Dict, Callable +import grpc import cortex as cx import requests import yaml @@ -60,6 +64,45 @@ def _is_ready(): return wait_for(_is_ready, timeout=timeout) +def generate_grpc( + client: cx.Client, api_name: str, api_dir: pathlib.Path, config: Dict[str, Any] +) -> Tuple[Any, Any, List, Any, bool]: + api_info = client.get_api(api_name) + + test_proto_dir = pathlib.Path(config["proto_module_pb2"]).parent + sys.path.append(str(api_dir / test_proto_dir)) + proto_module_pb2 = importlib.import_module(str(pathlib.Path(config["proto_module_pb2"]).stem)) + proto_module_pb2_grpc = importlib.import_module( + str(pathlib.Path(config["proto_module_pb2_grpc"]).stem) + ) + sys.path.pop() + + endpoint = api_info["endpoint"] + ":" + str(api_info["grpc_ports"]["insecure"]) + channel = grpc.insecure_channel(endpoint) + stub = getattr(proto_module_pb2_grpc, config["stub_service_name"] + "Stub")(channel) + + input_sample = getattr(proto_module_pb2, config["input_spec"]["class_name"])() + for k, v in config["input_spec"]["input"].items(): + setattr(input_sample, k, v) + + SampleClass = getattr(proto_module_pb2, config["output_spec"]["class_name"]) + output_values = [] + is_output_stream = config["output_spec"]["stream"] + if is_output_stream: + for entry in config["output_spec"]["output"]: + output_val = SampleClass() + for k, v in entry.items(): + setattr(output_val, k, v) + output_values.append(output_val) + else: + output_val = SampleClass() + for k, v in config["output_spec"]["output"].items(): + setattr(output_val, k, v) + output_values.append(output_val) + + return stub, input_sample, output_values, SampleClass, is_output_stream + + def request_prediction( client: cx.Client, api_name: str, payload: Union[List, Dict] ) -> requests.Response: diff --git a/test/e2e/setup.py 
b/test/e2e/setup.py index a2a03350c3..ec296c0a97 100644 --- a/test/e2e/setup.py +++ b/test/e2e/setup.py @@ -30,6 +30,7 @@ license="Apache License 2.0", python_requires=">=3.6", install_requires=[ + "grpcio==1.36.0", "requests==2.24.0", "jsonschema==3.2.0", "pytest==6.1.*", diff --git a/test/e2e/tests/aws/test_realtime.py b/test/e2e/tests/aws/test_realtime.py index 1b4bb7e225..506df3de21 100644 --- a/test/e2e/tests/aws/test_realtime.py +++ b/test/e2e/tests/aws/test_realtime.py @@ -18,41 +18,52 @@ import e2e.tests -TEST_APIS = ["pytorch/iris-classifier", "onnx/iris-classifier", "tensorflow/iris-classifier"] -TEST_APIS_GPU = ["pytorch/text-generator", "tensorflow/text-generator"] -TEST_APIS_INF = ["pytorch/image-classifier-resnet50"] +# TEST_APIS = ["pytorch/iris-classifier", "onnx/iris-classifier", "tensorflow/iris-classifier"] +TEST_APIS_GRPC = ["grpc/iris-classifier-sklearn", "grpc/prime-number-generator"] +# TEST_APIS_GPU = ["pytorch/text-generator", "tensorflow/text-generator"] +# TEST_APIS_INF = ["pytorch/image-classifier-resnet50"] + + +# @pytest.mark.usefixtures("client") +# @pytest.mark.parametrize("api", TEST_APIS) +# def test_realtime_api(config: Dict, client: cx.Client, api: str): +# e2e.tests.test_realtime_api( +# client=client, api=api, timeout=config["global"]["realtime_deploy_timeout"] +# ) @pytest.mark.usefixtures("client") -@pytest.mark.parametrize("api", TEST_APIS) -def test_realtime_api(config: Dict, client: cx.Client, api: str): +@pytest.mark.parametrize("api", TEST_APIS_GRPC) +def test_realtime_api_grpc(config: Dict, client: cx.Client, api: str): e2e.tests.test_realtime_api( - client=client, api=api, timeout=config["global"]["realtime_deploy_timeout"] + client=client, + api=api, + timeout=config["global"]["realtime_deploy_timeout"], ) -@pytest.mark.usefixtures("client") -@pytest.mark.parametrize("api", TEST_APIS_GPU) -def test_realtime_api_gpu(config: Dict, client: cx.Client, api: str): - skip_gpus = config["global"].get("skip_gpus", False) - if 
skip_gpus: - pytest.skip("--skip-gpus flag detected, skipping GPU tests") +# @pytest.mark.usefixtures("client") +# @pytest.mark.parametrize("api", TEST_APIS_GPU) +# def test_realtime_api_gpu(config: Dict, client: cx.Client, api: str): +# skip_gpus = config["global"].get("skip_gpus", False) +# if skip_gpus: +# pytest.skip("--skip-gpus flag detected, skipping GPU tests") - e2e.tests.test_realtime_api( - client=client, api=api, timeout=config["global"]["realtime_deploy_timeout"] - ) +# e2e.tests.test_realtime_api( +# client=client, api=api, timeout=config["global"]["realtime_deploy_timeout"] +# ) -@pytest.mark.usefixtures("client") -@pytest.mark.parametrize("api", TEST_APIS_INF) -def test_realtime_api_inf(config: Dict, client: cx.Client, api: str): - skip_infs = config["global"].get("skip_infs", False) - if skip_infs: - pytest.skip("--skip-infs flag detected, skipping Inferentia tests") +# @pytest.mark.usefixtures("client") +# @pytest.mark.parametrize("api", TEST_APIS_INF) +# def test_realtime_api_inf(config: Dict, client: cx.Client, api: str): +# skip_infs = config["global"].get("skip_infs", False) +# if skip_infs: +# pytest.skip("--skip-infs flag detected, skipping Inferentia tests") - e2e.tests.test_realtime_api( - client=client, - api=api, - timeout=config["global"]["realtime_deploy_timeout"], - api_config_name="cortex_inf.yaml", - ) +# e2e.tests.test_realtime_api( +# client=client, +# api=api, +# timeout=config["global"]["realtime_deploy_timeout"], +# api_config_name="cortex_inf.yaml", +# )