From 4d703cd8743cff6cb38d70fc923e88286772b335 Mon Sep 17 00:00:00 2001 From: vishal Date: Wed, 18 Sep 2019 22:22:15 -0400 Subject: [PATCH] Pass onnx model output directly to post_inference request handler --- examples/iris-classifier/handlers/pytorch.py | 2 +- pkg/workloads/cortex/onnx_serve/api.py | 14 ++++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/examples/iris-classifier/handlers/pytorch.py b/examples/iris-classifier/handlers/pytorch.py index f999028e19..ef7ee5f9bc 100644 --- a/examples/iris-classifier/handlers/pytorch.py +++ b/examples/iris-classifier/handlers/pytorch.py @@ -15,5 +15,5 @@ def pre_inference(sample, metadata): def post_inference(prediction, metadata): - predicted_class_id = int(np.argmax(prediction[0][0])) + predicted_class_id = int(np.argmax(prediction[0].squeeze())) return labels[predicted_class_id] diff --git a/pkg/workloads/cortex/onnx_serve/api.py b/pkg/workloads/cortex/onnx_serve/api.py index 3736fc2b9d..fbfd940650 100644 --- a/pkg/workloads/cortex/onnx_serve/api.py +++ b/pkg/workloads/cortex/onnx_serve/api.py @@ -192,19 +192,13 @@ def predict(app_name, api_name): ) from e inference_input = convert_to_onnx_input(prepared_sample, input_metadata) - model_outputs = sess.run([], inference_input) - result = [] - for model_output in model_outputs: - if type(model_output) is np.ndarray: - result.append(model_output.tolist()) - else: - result.append(model_output) - - debug_obj("inference", result, debug) + model_output = sess.run([], inference_input) + debug_obj("inference", model_output, debug) + result = model_output if request_handler is not None and util.has_function(request_handler, "post_inference"): try: - result = request_handler.post_inference(result, output_metadata) + result = request_handler.post_inference(model_output, output_metadata) except Exception as e: raise UserRuntimeException( api["request_handler"], "post_inference request handler", str(e)