From f867a41ed0030465cee2b1d3dbfdfdd70190bad2 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 1 Jun 2020 12:40:42 +0200 Subject: [PATCH 1/2] Add support for variadic arguments to SCRIPT --- docs/commands.md | 29 ++++++++++++++++++-- src/backends/torch.c | 1 + src/libtorch_c/torch_c.cpp | 22 ++++++++++++--- src/libtorch_c/torch_c.h | 2 +- src/script.c | 11 +++++--- src/script_struct.h | 1 + test/test_data/script.txt | 3 +++ test/tests_pytorch.py | 55 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 114 insertions(+), 10 deletions(-) diff --git a/docs/commands.md b/docs/commands.md index 69d264f0a..b70e04d2c 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -458,14 +458,16 @@ The **`AI.SCRIPTRUN`** command runs a script stored as a key's value on its spec **Redis API** ``` -AI.SCRIPTRUN INPUTS [input ...] OUTPUTS [output ...] +AI.SCRIPTRUN INPUTS [input ...] [$ input ...] OUTPUTS [output ...] ``` _Arguments_ * **key**: the script's key name * **function**: the name of the function to run -* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names +* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by one or more key names; + variadic arguments are supported by prepending the list with `$`, in this case the + script is expected an argument of type `List[Tensor]` as its last argument * **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by one or more key names _Return_ @@ -489,6 +491,29 @@ redis> AI.TENSORGET result VALUES 3) 1) "42" ``` +If 'myscript' supports variadic arguments: +```python +def addn(a, args : List[Tensor]): + return a + torch.stack(args).sum() +``` + +then one can provide an arbitrary number of inputs after the `$` sign: + +``` +redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40 +OK +redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 1 +OK +redis> AI.TENSORSET mytensor3 FLOAT 1 VALUES 1 +OK +redis> AI.SCRIPTRUN myscript addn INPUTS mytensor1 $ mytensor2 mytensor3 OUTPUTS result +OK +redis> AI.TENSORGET result VALUES +1) FLOAT +2) 1) (integer) 1 +3) 1) "42" +``` + !!! warning "Intermediate memory overhead" The execution of scripts may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis. diff --git a/src/backends/torch.c b/src/backends/torch.c index 61c676fc4..fa9753fd2 100644 --- a/src/backends/torch.c +++ b/src/backends/torch.c @@ -252,6 +252,7 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* error) { char* error_descr = NULL; torchRunScript(sctx->script->script, sctx->fnname, + sctx->variadic, nInputs, inputs, nOutputs, outputs, &error_descr, RedisModule_Alloc); diff --git a/src/libtorch_c/torch_c.cpp b/src/libtorch_c/torch_c.cpp index 75551319a..a748da525 100644 --- a/src/libtorch_c/torch_c.cpp +++ b/src/libtorch_c/torch_c.cpp @@ -190,7 +190,7 @@ struct ModuleContext { int64_t device_id; }; -void torchRunModule(ModuleContext* ctx, const char* fnName, +void torchRunModule(ModuleContext* ctx, const char* fnName, int variadic, long nInputs, DLManagedTensor** inputs, long nOutputs, DLManagedTensor** outputs) { // Checks device, if GPU then move input to GPU before running @@ -214,11 +214,25 @@ void torchRunModule(ModuleContext* ctx, const char* fnName, torch::jit::Stack stack; for (int i=0; idl_tensor); torch::Tensor tensor = fromDLPack(input); stack.push_back(tensor.to(device)); } + if (variadic != -1 ) { + std::vector args; + for (int i=variadic; idl_tensor); + torch::Tensor tensor = fromDLPack(input); + tensor.to(device); + args.emplace_back(tensor); + } + stack.push_back(args); + } + if (ctx->module) { torch::NoGradGuard guard; torch::jit::script::Method method = ctx->module->get_method(fnName); @@ -351,14 +365,14 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType return ctx; } -extern "C" void torchRunScript(void* scriptCtx, const char* fnName, +extern "C" void torchRunScript(void* scriptCtx, const char* fnName, int variadic, long nInputs, DLManagedTensor** inputs, long nOutputs, DLManagedTensor** outputs, char **error, void* (*alloc)(size_t)) { ModuleContext* ctx = (ModuleContext*)scriptCtx; try { - torchRunModule(ctx, fnName, nInputs, inputs, nOutputs, outputs); + torchRunModule(ctx, fnName, variadic, nInputs, inputs, nOutputs, outputs); } catch(std::exception& e) { const size_t len = strlen(e.what()); @@ -376,7 +390,7 @@ extern "C" void torchRunModel(void* modelCtx, { ModuleContext* ctx = (ModuleContext*)modelCtx; try { - torchRunModule(ctx, "forward", nInputs, inputs, nOutputs, outputs); + torchRunModule(ctx, "forward", -1, nInputs, inputs, nOutputs, outputs); } catch(std::exception& e) { const size_t len = strlen(e.what()); diff --git a/src/libtorch_c/torch_c.h b/src/libtorch_c/torch_c.h index 117785a73..617e82eea 100644 --- a/src/libtorch_c/torch_c.h +++ b/src/libtorch_c/torch_c.h @@ -19,7 +19,7 @@ void* torchCompileScript(const char* script, DLDeviceType device, int64_t device void* torchLoadModel(const char* model, size_t modellen, DLDeviceType device, int64_t device_id, char **error, void* (*alloc)(size_t)); -void torchRunScript(void* scriptCtx, const char* fnName, +void torchRunScript(void* scriptCtx, const char* fnName, int variadic, long nInputs, DLManagedTensor** inputs, long nOutputs, DLManagedTensor** outputs, char **error, void* (*alloc)(size_t)); diff --git a/src/script.c b/src/script.c index 194f5cb4e..a907e1955 100644 --- a/src/script.c +++ b/src/script.c @@ -150,6 +150,7 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script, sctx->inputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); sctx->outputs = array_new(RAI_ScriptCtxParam, PARAM_INITIAL_SIZE); sctx->fnname = RedisModule_Strdup(fnname); + sctx->variadic = -1; return sctx; } @@ -285,6 +286,10 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, is_input = 1; outputs_flag_count = 1; } else { + if (!strcasecmp(arg_string, "$")) { + (*sctx)->variadic = argpos - 4; + continue; + } RedisModule_RetainString(ctx, argv[argpos]); if (is_input == 0) { RAI_Tensor *inputTensor; @@ -299,18 +304,18 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModule_CloseKey(tensorKey); } else { const int get_result = RAI_getTensorFromLocalContext( - ctx, *localContextDict, arg_string, &inputTensor,error); + ctx, *localContextDict, arg_string, &inputTensor, error); if (get_result == REDISMODULE_ERR) { return -1; } } if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) { - RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Input key not found"); + RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found"); return -1; } } else { if (!RAI_ScriptRunCtxAddOutput(*sctx)) { - RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Output key not found"); + RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found"); return -1; } *outkeys=array_append(*outkeys,argv[argpos]); diff --git a/src/script_struct.h b/src/script_struct.h index 056150973..943325484 100644 --- a/src/script_struct.h +++ b/src/script_struct.h @@ -27,6 +27,7 @@ typedef struct RAI_ScriptRunCtx { char* fnname; RAI_ScriptCtxParam* inputs; RAI_ScriptCtxParam* outputs; + int variadic; } RAI_ScriptRunCtx; #endif /* SRC_SCRIPT_STRUCT_H_ */ diff --git a/test/test_data/script.txt b/test/test_data/script.txt index c3fbc1014..34e4b9317 100644 --- a/test/test_data/script.txt +++ b/test/test_data/script.txt @@ -1,2 +1,5 @@ def bar(a, b): return a + b + +def bar_variadic(a, args : List[Tensor]): + return args[0] + args[1] diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index f798a3453..912772981 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -426,6 +426,61 @@ def test_pytorch_scriptrun(env): values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES') env.assertEqual(values2, values) + +def test_pytorch_scriptrun_variadic(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + script_filename = os.path.join(test_data_path, 'script.txt') + + with open(script_filename, 'rb') as f: + script = f.read() + + ret = con.execute_command('AI.SCRIPTSET', 'myscript', DEVICE, 'TAG', 'version1', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b1', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b2', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + for _ in range( 0,100): + ret = con.execute_command('AI.SCRIPTRUN', 'myscript', 'bar_variadic', 'INPUTS', 'a', '$', 'b1', 'b2', 'OUTPUTS', 'c') + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + info = con.execute_command('AI.INFO', 'myscript') + info_dict_0 = info_to_dict(info) + + env.assertEqual(info_dict_0['key'], 'myscript') + env.assertEqual(info_dict_0['type'], 'SCRIPT') + env.assertEqual(info_dict_0['backend'], 'TORCH') + env.assertEqual(info_dict_0['tag'], 'version1') + env.assertTrue(info_dict_0['duration'] > 0) + env.assertEqual(info_dict_0['samples'], -1) + env.assertEqual(info_dict_0['calls'], 100) + env.assertEqual(info_dict_0['errors'], 0) + + values = con.execute_command('AI.TENSORGET', 'c', 'VALUES') + env.assertEqual(values, [b'4', b'6', b'4', b'6']) + + ensureSlaveSynced(con, env) + + if env.useSlaves: + con2 = env.getSlaveConnection() + values2 = con2.execute_command('AI.TENSORGET', 'c', 'VALUES') + env.assertEqual(values2, values) + + def test_pytorch_scriptrun_errors(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) From b7388a81f211e5f426d8704fd1e6cce08cf5e24c Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 1 Jun 2020 15:01:30 +0200 Subject: [PATCH 2/2] Add negative errors --- test/tests_pytorch.py | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/tests_pytorch.py b/test/tests_pytorch.py index 912772981..ff33b2715 100644 --- a/test/tests_pytorch.py +++ b/test/tests_pytorch.py @@ -583,6 +583,66 @@ def test_pytorch_scriptrun_errors(env): env.assertEqual(type(exception), redis.exceptions.ResponseError) +def test_pytorch_scriptrun_errors(env): + if not TEST_PT: + env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True) + return + + con = env.getConnection() + + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + script_filename = os.path.join(test_data_path, 'script.txt') + + with open(script_filename, 'rb') as f: + script = f.read() + + ret = con.execute_command('AI.SCRIPTSET', 'ket', DEVICE, 'TAG', 'asdf', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ensureSlaveSynced(con, env) + + # ERR Variadic input key is empty + try: + con.execute_command('DEL', 'EMPTY') + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$', 'EMPTY', 'b', 'OUTPUTS', 'c') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertEqual("tensor key is empty", exception.__str__()) + + # ERR Variadic input key not tensor + try: + con.execute_command('SET', 'NOT_TENSOR', 'BAR') + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'a', '$' , 'NOT_TENSOR', 'b', 'OUTPUTS', 'c') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__()) + + try: + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS', 'c') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + + try: + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', 'b', '$', 'OUTPUTS') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + + try: + con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'OUTPUTS') + except Exception as e: + exception = e + env.assertEqual(type(exception), redis.exceptions.ResponseError) + + def test_pytorch_scriptinfo(env): if not TEST_PT: env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)