diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c index 33813bbc9..1baed43e9 100644 --- a/src/backends/tensorflow.c +++ b/src/backends/tensorflow.c @@ -530,8 +530,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error) outputTensorsValues, noutputs, NULL /* target_opers */, 0 /* ntargets */, NULL /* run_Metadata */, status); + bool delete_output = true; if (TF_GetCode(status) != TF_OK) { RAI_SetError(error, RAI_EMODELRUN, TF_Message(status)); + delete_output = false; goto cleanup; } @@ -575,8 +577,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error) } TF_DeleteTensor(inputTensorsValues[i]); } - for (size_t i = 0; i < noutputs; i++) { - TF_DeleteTensor(outputTensorsValues[i]); + if (delete_output) { + for (size_t i = 0; i < noutputs; i++) { + TF_DeleteTensor(outputTensorsValues[i]); + } } return res; } diff --git a/tests/flow/test_data/frozen_bad_model.pb b/tests/flow/test_data/frozen_bad_model.pb new file mode 100644 index 000000000..62be95700 --- /dev/null +++ b/tests/flow/test_data/frozen_bad_model.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d6bb6cdfe22525a894c67b3992dc3e88b06730661817ee79c62c20b9dd09dd +size 174203 diff --git a/tests/flow/tests_tensorflow.py b/tests/flow/tests_tensorflow.py index c590ae6f6..8ba8ef1ca 100644 --- a/tests/flow/tests_tensorflow.py +++ b/tests/flow/tests_tensorflow.py @@ -759,3 +759,14 @@ def run(): env.assertEqual(out_values, [b'this is', b'the first batch']) out_values = con.execute_command('AI.TENSORGET', 'second_batch{1}', 'VALUES') env.assertEqual(out_values, [b'that is', b'the second batch']) + +@skip_if_no_TF +def test_bad_execution_model(env): + con = get_connection(env, '{1}') + + model_pb = load_file_content('frozen_bad_model.pb') + ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 1, 'x', 'OUTPUTS', 1, 'Identity', 'BLOB', model_pb) + env.assertEqual(ret, b'OK') + con.execute_command('AI.TENSORSET', 'my_str_tensor{1}', 'STRING', 4, 'BLOB', "how do I extract keys from a dict into a list?\x00debug public static void main(string[] args) {...}\x00should I use def main()\x00type hinting for list?\x00") + env.assertEqual(ret, b'OK') + check_error(env, con, 'AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'my_str_tensor{1}', 'OUTPUTS', 1, 'foo{1}') \ No newline at end of file