From bf2b13ec0d560ac582b81cc77fccd7f43ac38a4e Mon Sep 17 00:00:00 2001 From: alonre24 Date: Sun, 27 Dec 2020 10:30:18 +0200 Subject: [PATCH 1/3] Refactor script run command (parsing, dag run step and reply). --- src/DAG/dag.c | 106 +++++++++++++++------------- src/DAG/dag_parser.c | 44 ++---------- src/command_parser.c | 162 ++++++++++++++++++++++++++++++++++++++++++- src/model.c | 1 + src/modelRun_ctx.c | 7 +- src/redisai.c | 14 +--- src/run_info.c | 2 +- src/script.c | 29 ++++---- src/script.h | 11 +-- 9 files changed, 246 insertions(+), 130 deletions(-) diff --git a/src/DAG/dag.c b/src/DAG/dag.c index ad7b8380f..aa24b8e50 100644 --- a/src/DAG/dag.c +++ b/src/DAG/dag.c @@ -237,53 +237,52 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur uint n_inkeys = array_len(currentOp->inkeys); uint n_outkeys = array_len(currentOp->outkeys); - RAI_ContextReadLock(rinfo); + if (!rinfo->single_op_dag) { - RAI_Tensor *inputTensors[n_inkeys]; - for (uint i = 0; i < n_inkeys; i++) { - RAI_Tensor *inputTensor; - const int get_result = RAI_getTensorFromLocalContext( - NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); - if (get_result == REDISMODULE_ERR) { - // We check for this outside the function - // this check cannot be covered by tests - currentOp->result = REDISMODULE_ERR; - RAI_ContextUnlock(rinfo); - return; + RAI_ContextReadLock(rinfo); + RAI_Tensor *inputTensors[n_inkeys]; + for (uint i = 0; i < n_inkeys; i++) { + RAI_Tensor *inputTensor; + const int get_result = RAI_getTensorFromLocalContext( + NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err); + if (get_result == REDISMODULE_ERR) { + // We check for this outside the function + // this check cannot be covered by tests + currentOp->result = REDISMODULE_ERR; + RAI_ContextUnlock(rinfo); + return; + } + inputTensors[i] = inputTensor; } - inputTensors[i] = inputTensor; - } - - RAI_ContextUnlock(rinfo); - - for (uint i = 0; i < n_inkeys; i++) { - RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err); - } + RAI_ContextUnlock(rinfo); - for (uint i = 0; i < n_outkeys; i++) { - RAI_ScriptRunCtxAddOutput(currentOp->sctx); + for (uint i = 0; i < n_inkeys; i++) { + RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i]); + } + for (uint i = 0; i < n_outkeys; i++) { + RAI_ScriptRunCtxAddOutput(currentOp->sctx); + } } const long long start = ustime(); int result = RAI_ScriptRun(currentOp->sctx, currentOp->err); const long long end = ustime(); - RAI_ContextWriteLock(rinfo); - - const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx); - for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { - RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber); - RedisModuleString *key_string = currentOp->outkeys[outputNumber]; - tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; - AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); - } - currentOp->result = result; currentOp->duration_us = end - start; - RAI_ContextUnlock(rinfo); + if (!rinfo->single_op_dag) { - return; + RAI_ContextWriteLock(rinfo); + const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx); + for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { + RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber); + RedisModuleString *key_string = currentOp->outkeys[outputNumber]; + tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; + AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor); + } + RAI_ContextUnlock(rinfo); + } } size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) { @@ -572,17 +571,16 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor, return ret; } -static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { +static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { + AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext); AI_dictEntry *persist_entry = AI_dictNext(persist_iter); + while (persist_entry) { RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry); - AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name); - if (tensor_entry) { RAI_Tensor *tensor = AI_dictGetVal(tensor_entry); - if (tensor == NULL) { persist_entry = AI_dictNext(persist_iter); continue; @@ -594,13 +592,13 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { RedisModule_ReplyWithError(ctx, "ERR specified persistent key that was not used in DAG"); rinfo->dagReplyLength++; - RedisModule_Log(ctx, "warning", "on DAGRUN's PERSIST pecified persistent key (%s) that " "was not used on DAG. Logging all local context keys", persist_key_name); AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); AI_dictEntry *local_entry = AI_dictNext(local_iter); + while (local_entry) { RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry); RedisModule_Log(ctx, "warning", "DAG's local context key (%s)", @@ -619,7 +617,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { AI_dictReleaseIterator(persist_iter); } -static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { +static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx); for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber); @@ -629,6 +627,16 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { } } +static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) { + const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx); + for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) { + RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber); + tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL; + if (tensor) + _StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false); + } +} + int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { REDISMODULE_NOT_USED(argv); REDISMODULE_NOT_USED(argc); @@ -650,7 +658,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc return REDISMODULE_OK; } - if (rinfo->single_op_dag == 0) { + if (!rinfo->single_op_dag) { RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN); } @@ -745,18 +753,20 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc return REDISMODULE_ERR; } - // TODO: Take care of script single op - if (rinfo->single_op_dag == 0 || rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) { + if (!rinfo->single_op_dag) { // Save the required tensors in redis key space. - PersistTensors(ctx, rinfo); - if (rinfo->single_op_dag == 0) - RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength); + _PersistTensors(ctx, rinfo); + RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength); } else { - ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]); + if (rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_MODELRUN) { + _ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]); + } else { + RedisModule_Assert(rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN); + _ScriptSingleOp_PersistTensors(ctx, rinfo->dagOps[0]); + } } RAI_FreeRunInfo(rinfo); - return REDISMODULE_OK; } diff --git a/src/DAG/dag_parser.c b/src/DAG/dag_parser.c index f81f59ce6..13a0f3412 100644 --- a/src/DAG/dag_parser.c +++ b/src/DAG/dag_parser.c @@ -136,18 +136,9 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b int chainingOpCount = 0; bool load_complete = false; bool persist_complete = false; - int arg_pos = 1; - - // If we're parsing a AI.SCRIPTRUN command, we don't expect there to be a chaining |> operator - if (!strcasecmp(RedisModule_StringPtrLen(argv[0], NULL), "AI.SCRIPTRUN")) { - arg_pos = 0; - chainingOpCount++; - rinfo->single_op_dag = 1; - rinfo->single_device_dag = 1; - } // The first arg is "AI.DAGRUN", so we go over from the next arg. - for (; arg_pos < argc; arg_pos++) { + for (int arg_pos = 1; arg_pos < argc; arg_pos++) { const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL); if (!strcasecmp(arg_string, "LOAD") && !load_complete) { @@ -232,6 +223,7 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b RedisModuleKey *modelKey; const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos + 1], &modelKey, &mto, REDISMODULE_READ); + RedisModule_OpenKey(ctx, argv[arg_pos + 1], REDISMODULE_READ); if (status == REDISMODULE_ERR) { RAI_FreeRunInfo(rinfo); RedisModule_ReplyWithError(ctx, "ERR Model not found"); @@ -252,6 +244,7 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b RedisModuleKey *scriptKey; const int status = RAI_GetScriptFromKeyspace(ctx, argv[arg_pos + 1], &scriptKey, &sto, REDISMODULE_READ); + RedisModule_OpenKey(ctx, argv[arg_pos + 1], REDISMODULE_READ); if (status == REDISMODULE_ERR) { RAI_FreeRunInfo(rinfo); return REDISMODULE_ERR; @@ -303,35 +296,6 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b } } - if (rinfo->single_op_dag && rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) { - RAI_DagOp *op = rinfo->dagOps[0]; - RAI_Tensor *t; - RedisModuleKey *key; - for (size_t i = 0; i < array_len(op->inkeys); i++) { - RedisModuleString *inkey = op->inkeys[i]; - const int status = RAI_GetTensorFromKeyspace(ctx, inkey, &key, &t, REDISMODULE_READ); - if (status == REDISMODULE_ERR) { - RedisModule_Log(ctx, "warning", - "on DAGRUN's LOAD could not load tensor %s from keyspace", - RedisModule_StringPtrLen(inkey, NULL)); - return REDISMODULE_ERR; - } - char buf[16]; - sprintf(buf, "%04d", 1); - RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, inkey); - RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf)); - AI_dictAdd(rinfo->dagTensorsContext, (void *)dictKey, - (void *)RAI_TensorGetShallowCopy(t)); - AI_dictAdd(rinfo->dagTensorsLoadedContext, (void *)dictKey, (void *)1); - RedisModule_Free(dictKey); - } - - for (size_t i = 0; i < array_len(op->outkeys); i++) { - RedisModuleString *outkey = op->outkeys[i]; - AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)outkey, (void *)1); - } - } - // At this point, we have built a sequence of DAG operations, each with its own // input and output keys. The names of the keys will be used to look whether the // inputs to a DAG operation have all been realized by previous operations (or if @@ -462,4 +426,4 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b } } return REDISMODULE_OK; -} \ No newline at end of file +} diff --git a/src/command_parser.c b/src/command_parser.c index c3aadf2f8..22e2d5a18 100644 --- a/src/command_parser.c +++ b/src/command_parser.c @@ -175,6 +175,166 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModul return REDISMODULE_ERR; } +static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + RAI_Script **script, RAI_Error *error, + RedisModuleString ***inkeys, RedisModuleString ***outkeys, + RedisModuleString **runkey, char const **func_name, + long long *timeout, int *variadic) { + + if (argc < 5) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR wrong number of arguments for 'AI.SCRIPTRUN' command"); + return REDISMODULE_ERR; + } + size_t argpos = 1; + RedisModuleKey *scriptKey; + const int status = + RAI_GetScriptFromKeyspace(ctx, argv[argpos], &scriptKey, script, REDISMODULE_READ); + if (status == REDISMODULE_ERR) { + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR Script not found"); + return REDISMODULE_ERR; + } + RAI_HoldString(NULL, argv[argpos]); + *runkey = argv[argpos]; + + const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); + if (!strcasecmp(arg_string, "TIMEOUT") || !strcasecmp(arg_string, "INPUTS")) { + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR function name not specified"); + return REDISMODULE_ERR; + } + *func_name = arg_string; + arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); + + // Parse timeout arg if given and store it in timeout + if (!strcasecmp(arg_string, "TIMEOUT")) { + if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR) + return REDISMODULE_ERR; + arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); + } + if (strcasecmp(arg_string, "INPUTS") != 0) { + RAI_SetError(error, RAI_ESCRIPTRUN, "ERR INPUTS not specified"); + return REDISMODULE_ERR; + } + + bool is_input = true, is_output = false; + size_t ninputs = 0, noutputs = 0; + int varidic_start_pos = -1; + + while (++argpos < argc) { + arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); + if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) { + is_input = false; + is_output = true; + } else if (!strcasecmp(arg_string, "$")) { + if (varidic_start_pos > -1) { + RAI_SetError(error, RAI_ESCRIPTRUN, + "ERR Already encountered a variable size list of tensors"); + return REDISMODULE_ERR; + } + varidic_start_pos = ninputs; + } else { + RAI_HoldString(NULL, argv[argpos]); + if (is_input) { + ninputs++; + *inkeys = array_append(*inkeys, argv[argpos]); + } else { + noutputs++; + *outkeys = array_append(*outkeys, argv[argpos]); + } + } + } + *variadic = varidic_start_pos; + + return REDISMODULE_OK; +} + +/** + * Extract the params for the ScriptCtxRun object from AI.SCRIPTRUN arguments. + * + * @param ctx Context in which Redis modules operate. + * @param inkeys Script input tensors keys, as an array of strings. + * @param outkeys Script output tensors keys, as an array of strings. + * @param sctx Destination Script context to store the parsed data. + * @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise. + */ + +static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys, + RedisModuleString **outkeys, RAI_ScriptRunCtx *sctx) { + + RAI_Tensor *t; + RedisModuleKey *key; + size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys); + for (size_t i = 0; i < ninputs; i++) { + const int status = RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ); + if (status == REDISMODULE_ERR) { + RedisModule_Log(ctx, "warning", "could not load tensor %s from keyspace", + RedisModule_StringPtrLen(inkeys[i], NULL)); + return REDISMODULE_ERR; + } + RAI_ScriptRunCtxAddInput(sctx, t); + } + for (size_t i = 0; i < noutputs; i++) { + RAI_ScriptRunCtxAddOutput(sctx); + } + return REDISMODULE_OK; +} + +int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleString **argv, + int argc) { + + // Build a ScriptRunCtx from command. + RAI_Error error = {0}; + RAI_Script *script; + RedisModuleString **inkeys = array_new(RedisModuleString *, 1); + RedisModuleString **outkeys = array_new(RedisModuleString *, 1); + RedisModuleString *runkey = NULL; + const char *func_name = NULL; + RAI_ScriptRunCtx *sctx = NULL; + RAI_DagOp *currentOp; + + long long timeout = 0; + int variadic = -1; + if (_ScriptRunCommand_ParseArgs(ctx, argv, argc, &script, &error, &inkeys, &outkeys, &runkey, + &func_name, &timeout, &variadic) == REDISMODULE_ERR) { + RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&error)); + goto cleanup; + } + sctx = RAI_ScriptRunCtxCreate(script, func_name); + sctx->variadic = variadic; + + if (rinfo->single_op_dag) { + rinfo->timeout = timeout; + // Set params in ScriptRunCtx, bring inputs from key space. + if (_ScriptRunCtx_SetParams(ctx, inkeys, outkeys, sctx) == REDISMODULE_ERR) + goto cleanup; + } + if (RAI_InitDagOp(¤tOp) == REDISMODULE_ERR) { + RedisModule_ReplyWithError( + ctx, "ERR Unable to allocate the memory and initialise the RAI_dagOp structure"); + goto cleanup; + } + currentOp->commandType = REDISAI_DAG_CMD_SCRIPTRUN; + Dag_PopulateOp(currentOp, sctx, inkeys, outkeys, runkey); + rinfo->dagOps = array_append(rinfo->dagOps, currentOp); + return REDISMODULE_OK; + +cleanup: + for (size_t i = 0; i < array_len(inkeys); i++) { + RedisModule_FreeString(NULL, inkeys[i]); + } + array_free(inkeys); + for (size_t i = 0; i < array_len(outkeys); i++) { + RedisModule_FreeString(NULL, outkeys[i]); + } + array_free(outkeys); + if (runkey) + RedisModule_FreeString(NULL, runkey); + if (sctx) + RAI_ScriptRunCtxFree(sctx); + RAI_FreeRunInfo(rinfo); + return REDISMODULE_ERR; +} + int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, RunCommand command, bool ro_dag) { @@ -199,7 +359,7 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar break; case CMD_SCRIPTRUN: rinfo->single_op_dag = 1; - status = DAG_CommandParser(ctx, argv, argc, ro_dag, &rinfo); + status = ParseScriptRunCommand(rinfo, ctx, argv, argc); break; case CMD_DAGRUN: status = DAG_CommandParser(ctx, argv, argc, ro_dag, &rinfo); diff --git a/src/model.c b/src/model.c index 481fb59c2..7236679b5 100644 --- a/src/model.c +++ b/src/model.c @@ -261,6 +261,7 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, Re return REDISMODULE_ERR; } *model = RedisModule_ModuleTypeGetValue(*key); + RedisModule_CloseKey(*key); return REDISMODULE_OK; } diff --git a/src/modelRun_ctx.c b/src/modelRun_ctx.c index 37d8e697f..3b94b259c 100644 --- a/src/modelRun_ctx.c +++ b/src/modelRun_ctx.c @@ -1,5 +1,6 @@ #include "modelRun_ctx.h" +#include "util/string_utils.h" static int _Model_RunCtxAddParam(RAI_ModelCtxParam **paramArr, const char *name, RAI_Tensor *tensor) { @@ -96,11 +97,7 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString * is_input = 1; outputs_flag_count = 1; } else { - if (RMAPI_FUNC_SUPPORTED(RedisModule_HoldString)) { - RedisModule_HoldString(NULL, argv[argpos]); - } else { - RedisModule_RetainString(NULL, argv[argpos]); - } + RAI_HoldString(NULL, argv[argpos]); if (is_input == 0) { *inkeys = array_append(*inkeys, argv[argpos]); ninputs++; diff --git a/src/redisai.c b/src/redisai.c index f24999173..295323038 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -427,18 +427,15 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, } if (!meta && !blob) { - RedisModule_CloseKey(key); return RedisModule_ReplyWithError(ctx, "ERR no META or BLOB specified"); } RAI_Error err = {0}; - char *buffer = NULL; size_t len = 0; if (blob) { RAI_ModelSerialize(mto, &buffer, &len, &err); - if (err.code != RAI_OK) { #ifdef RAI_PRINT_BACKEND_ERRORS printf("ERR: %s\n", err.detail); @@ -455,12 +452,10 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, if (!meta && blob) { RAI_ReplyWithChunks(ctx, buffer, len); RedisModule_Free(buffer); - RedisModule_CloseKey(key); return REDISMODULE_OK; } const int outentries = blob ? 16 : 14; - RedisModule_ReplyWithArray(ctx, outentries); RedisModule_ReplyWithCString(ctx, "backend"); @@ -503,8 +498,6 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_Free(buffer); } - RedisModule_CloseKey(key); - return REDISMODULE_OK; } @@ -523,6 +516,7 @@ int RedisAI_ModelDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, return REDISMODULE_ERR; } + key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE); RedisModule_DeleteKey(key); RedisModule_CloseKey(key); RedisModule_ReplicateVerbatim(ctx); @@ -617,13 +611,11 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv } if (!meta && !source) { - RedisModule_CloseKey(key); return RedisModule_ReplyWithError(ctx, "ERR no META or SOURCE specified"); } if (!meta && source) { RedisModule_ReplyWithCString(ctx, sto->scriptdef); - RedisModule_CloseKey(key); return REDISMODULE_OK; } @@ -638,7 +630,6 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv RedisModule_ReplyWithCString(ctx, "source"); RedisModule_ReplyWithCString(ctx, sto->scriptdef); } - RedisModule_CloseKey(key); return REDISMODULE_OK; } @@ -655,12 +646,11 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv if (status == REDISMODULE_ERR) { return REDISMODULE_ERR; } - + key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE); RedisModule_DeleteKey(key); RedisModule_CloseKey(key); RedisModule_ReplicateVerbatim(ctx); - return RedisModule_ReplyWithSimpleString(ctx, "OK"); } diff --git a/src/run_info.c b/src/run_info.c index 579d956cf..5c6217f58 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -167,7 +167,7 @@ void RAI_FreeDagOp(RAI_DagOp *dagOp) { RAI_ModelRunCtxFree(dagOp->mctx); } if (dagOp->sctx) { - RAI_ScriptRunCtxFree(dagOp->sctx, true); + RAI_ScriptRunCtxFree(dagOp->sctx); } if (dagOp->inkeys) { diff --git a/src/script.c b/src/script.c index ec75bdfdd..6cb595618 100644 --- a/src/script.c +++ b/src/script.c @@ -156,8 +156,7 @@ RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname) return sctx; } -static int Script_RunCtxAddParam(RAI_ScriptRunCtx *sctx, RAI_ScriptCtxParam **paramArr, - RAI_Tensor *tensor) { +static int _Script_RunCtxAddParam(RAI_ScriptCtxParam **paramArr, RAI_Tensor *tensor) { RAI_ScriptCtxParam param = { .tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL, }; @@ -165,9 +164,9 @@ static int Script_RunCtxAddParam(RAI_ScriptRunCtx *sctx, RAI_ScriptCtxParam **pa return 1; } -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *err) { +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor) { // Even if variadic is set, we still allow to add inputs in the LLAPI - return Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensor); + return _Script_RunCtxAddParam(&sctx->inputs, inputTensor); } int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, @@ -182,7 +181,7 @@ int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTenso int res; for (size_t i = 0; i < len; i++) { - res = Script_RunCtxAddParam(sctx, &sctx->inputs, inputTensors[i]); + res = _Script_RunCtxAddParam(&sctx->inputs, inputTensors[i]); if (res != 1) return res; } @@ -190,7 +189,7 @@ int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTenso } int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx *sctx) { - return Script_RunCtxAddParam(sctx, &sctx->outputs, NULL); + return _Script_RunCtxAddParam(&sctx->outputs, NULL); } size_t RAI_ScriptRunCtxNumOutputs(RAI_ScriptRunCtx *sctx) { return array_len(sctx->outputs); } @@ -200,16 +199,15 @@ RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index) { return sctx->outputs[index].tensor; } -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx, int freeTensors) { - if (freeTensors) { - for (size_t i = 0; i < array_len(sctx->inputs); ++i) { - RAI_TensorFree(sctx->inputs[i].tensor); - } +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx) { - for (size_t i = 0; i < array_len(sctx->outputs); ++i) { - if (sctx->outputs[i].tensor) { - RAI_TensorFree(sctx->outputs[i].tensor); - } + for (size_t i = 0; i < array_len(sctx->inputs); ++i) { + RAI_TensorFree(sctx->inputs[i].tensor); + } + + for (size_t i = 0; i < array_len(sctx->outputs); ++i) { + if (sctx->outputs[i].tensor) { + RAI_TensorFree(sctx->outputs[i].tensor); } } @@ -261,6 +259,7 @@ int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, R return REDISMODULE_ERR; } *script = RedisModule_ModuleTypeGetValue(*key); + RedisModule_CloseKey(*key); return REDISMODULE_OK; } diff --git a/src/script.h b/src/script.h index ec43e42a4..de9241db6 100644 --- a/src/script.h +++ b/src/script.h @@ -66,11 +66,9 @@ RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname) * * @param sctx input RAI_ScriptRunCtx to add the input tensor * @param inputTensor input tensor structure - * @param err error data structure to store error message in the case of - * failures * @return returns 1 on success, 0 in case of error. */ -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *err); +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor); /** * For each Allocates a RAI_ScriptCtxParam data structure, and enforces a @@ -80,12 +78,10 @@ int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RA * @param sctx input RAI_ScriptRunCtx to add the input tensor * @param inputTensors input tensors array * @param len input tensors array len - * @param err error data structure to store error message in the case of - * failures * @return returns 1 on success, 0 in case of error. */ int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx *sctx, RAI_Tensor **inputTensors, size_t len, - RAI_Error *err); + RAI_Error *error); /** * Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference @@ -119,9 +115,8 @@ RAI_Tensor *RAI_ScriptRunCtxOutputTensor(RAI_ScriptRunCtx *sctx, size_t index); * work * * @param sctx - * @param freeTensors free input and output tensors or leave them allocated */ -void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx, int freeTensors); +void RAI_ScriptRunCtxFree(RAI_ScriptRunCtx *sctx); /** * Given the input script context, run associated script From 34d153bace9e6f20f0f5b9260ee9d4100952061f Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 28 Dec 2020 16:26:06 +0200 Subject: [PATCH 2/3] Support Async script run via LLAPI --- src/DAG/dag.c | 8 +-- src/command_parser.c | 3 +- src/redisai.c | 4 +- src/redisai.h | 5 ++ src/run_info.c | 15 ++++++ src/run_info.h | 8 +++ src/script.c | 28 +++++++++- src/script.h | 15 +++++- tests/flow/tests_llapi.py | 51 +++++++++++++++--- tests/module/LLAPI.c | 109 +++++++++++++++++++++++++++++++++++++- 10 files changed, 228 insertions(+), 18 deletions(-) diff --git a/src/DAG/dag.c b/src/DAG/dag.c index aa24b8e50..962ce3bc9 100644 --- a/src/DAG/dag.c +++ b/src/DAG/dag.c @@ -257,7 +257,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur RAI_ContextUnlock(rinfo); for (uint i = 0; i < n_inkeys; i++) { - RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i]); + RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err); } for (uint i = 0; i < n_outkeys; i++) { RAI_ScriptRunCtxAddOutput(currentOp->sctx); @@ -593,16 +593,16 @@ static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) { "ERR specified persistent key that was not used in DAG"); rinfo->dagReplyLength++; RedisModule_Log(ctx, "warning", - "on DAGRUN's PERSIST pecified persistent key (%s) that " + "on DAGRUN's PERSIST specified persistent key (%s) that " "was not used on DAG. Logging all local context keys", - persist_key_name); + RedisModule_StringPtrLen(persist_key_name, NULL)); AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); AI_dictEntry *local_entry = AI_dictNext(local_iter); while (local_entry) { RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry); RedisModule_Log(ctx, "warning", "DAG's local context key (%s)", - localcontext_key_name); + RedisModule_StringPtrLen(localcontext_key_name, NULL)); local_entry = AI_dictNext(local_iter); } AI_dictReleaseIterator(local_iter); diff --git a/src/command_parser.c b/src/command_parser.c index 22e2d5a18..e35880269 100644 --- a/src/command_parser.c +++ b/src/command_parser.c @@ -263,6 +263,7 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke RAI_Tensor *t; RedisModuleKey *key; + RAI_Error *err; size_t ninputs = array_len(inkeys), noutputs = array_len(outkeys); for (size_t i = 0; i < ninputs; i++) { const int status = RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ); @@ -271,7 +272,7 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke RedisModule_StringPtrLen(inkeys[i], NULL)); return REDISMODULE_ERR; } - RAI_ScriptRunCtxAddInput(sctx, t); + RAI_ScriptRunCtxAddInput(sctx, t, err); } for (size_t i = 0; i < noutputs; i++) { RAI_ScriptRunCtxAddOutput(sctx); diff --git a/src/redisai.c b/src/redisai.c index 295323038..3c4b9cd19 100644 --- a/src/redisai.c +++ b/src/redisai.c @@ -969,7 +969,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) { REGISTER_API(ModelGetShallowCopy, ctx); REGISTER_API(ModelRedisType, ctx); REGISTER_API(ModelRunAsync, ctx); - REGISTER_API(GetAsModelRunCtx, ctx) + REGISTER_API(GetAsModelRunCtx, ctx); REGISTER_API(ScriptCreate, ctx); REGISTER_API(ScriptFree, ctx); @@ -983,6 +983,8 @@ static int RedisAI_RegisterApi(RedisModuleCtx *ctx) { REGISTER_API(ScriptRun, ctx); REGISTER_API(ScriptGetShallowCopy, ctx); REGISTER_API(ScriptRedisType, ctx); + REGISTER_API(ScriptRunAsync, ctx); + REGISTER_API(GetAsScriptRunCtx, ctx); return REDISMODULE_OK; } diff --git a/src/redisai.h b/src/redisai.h index 44c6614d8..2a59bc331 100644 --- a/src/redisai.h +++ b/src/redisai.h @@ -121,6 +121,9 @@ void MODULE_API_FUNC(RedisAI_ScriptRunCtxFree)(RAI_ScriptRunCtx *sctx); int MODULE_API_FUNC(RedisAI_ScriptRun)(RAI_ScriptRunCtx *sctx, RAI_Error *err); RAI_Script *MODULE_API_FUNC(RedisAI_ScriptGetShallowCopy)(RAI_Script *script); RedisModuleType *MODULE_API_FUNC(RedisAI_ScriptRedisType)(void); +int MODULE_API_FUNC(RedisAI_ScriptRunAsync)(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB DAGAsyncFinish, + void *private_data); +RAI_ScriptRunCtx *MODULE_API_FUNC(RedisAI_GetAsScriptRunCtx)(RAI_OnFinishCtx *ctx, RAI_Error *err); int MODULE_API_FUNC(RedisAI_GetLLAPIVersion)(); @@ -204,6 +207,8 @@ static int RedisAI_Initialize(RedisModuleCtx *ctx) { REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRun); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptGetShallowCopy); REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRedisType); + REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunAsync); + REDISAI_MODULE_INIT_FUNCTION(ctx, GetAsScriptRunCtx); if (RedisAI_GetLLAPIVersion() < REDISAI_LLAPI_VERSION) { return REDISMODULE_ERR; diff --git a/src/run_info.c b/src/run_info.c index 5c6217f58..df939e30b 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -329,6 +329,7 @@ int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2) { return 1; } + RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) { RAI_DagOp *op = rinfo->dagOps[0]; @@ -342,3 +343,17 @@ RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) { RAI_FreeRunInfo(rinfo); return mctx; } + +RAI_ScriptRunCtx *RAI_GetAsScriptRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) { + + RAI_DagOp *op = rinfo->dagOps[0]; + if (!rinfo->single_op_dag || !op->sctx) { + RAI_SetError(err, RedisAI_ErrorCode_EFINISHCTX, "Finish ctx is not a script run ctx"); + return NULL; + } + RAI_SetError(err, RAI_GetErrorCode(op->err), RAI_GetError(op->err)); + RAI_ScriptRunCtx *sctx = op->sctx; + rinfo->dagOps[0]->sctx = NULL; + RAI_FreeRunInfo(rinfo); + return sctx; +} diff --git a/src/run_info.h b/src/run_info.h index a53a40b1a..e692a51e2 100644 --- a/src/run_info.h +++ b/src/run_info.h @@ -188,6 +188,14 @@ int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2); */ RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err); +/** + * Retreive the ScriptRunCtx of a DAG runInfo that contains a single op of type + * SCRIPTRUN. + * @param DAG runInfo. + * @return Pointer to the ScriptRunCtx in DAG's single op. + */ +RAI_ScriptRunCtx *RAI_GetAsScriptRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/script.c b/src/script.c index 6cb595618..6532769a4 100644 --- a/src/script.c +++ b/src/script.c @@ -7,7 +7,8 @@ */ #include "script.h" - +#include "run_info.h" +#include "DAG/dag.h" #include "backends.h" #include "rmutil/alloc.h" #include "script_struct.h" @@ -164,7 +165,7 @@ static int _Script_RunCtxAddParam(RAI_ScriptCtxParam **paramArr, RAI_Tensor *ten return 1; } -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor) { +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error) { // Even if variadic is set, we still allow to add inputs in the LLAPI return _Script_RunCtxAddParam(&sctx->inputs, inputTensor); } @@ -352,3 +353,26 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString } RedisModuleType *RAI_ScriptRedisType(void) { return RedisAI_ScriptType; } + +int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, + void *private_data) { + + RedisAI_RunInfo *rinfo = NULL; + if (RAI_InitRunInfo(&rinfo) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + rinfo->single_op_dag = 1; + rinfo->OnFinish = (RedisAI_OnFinishCB)ScriptAsyncFinish; + rinfo->private_data = private_data; + + RAI_DagOp *op; + if (RAI_InitDagOp(&op) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + op->commandType = REDISAI_DAG_CMD_SCRIPTRUN; + Dag_PopulateOp(op, sctx, NULL, NULL, NULL); + + rinfo->dagOps = array_append(rinfo->dagOps, op); + rinfo->dagOpCount = 1; + return DAG_InsertDAGToQueue(rinfo); +} diff --git a/src/script.h b/src/script.h index de9241db6..33da8ef87 100644 --- a/src/script.h +++ b/src/script.h @@ -14,6 +14,7 @@ #include "redismodule.h" #include "script_struct.h" #include "tensor.h" +#include "run_info.h" extern RedisModuleType *RedisAI_ScriptType; @@ -68,7 +69,7 @@ RAI_ScriptRunCtx *RAI_ScriptRunCtxCreate(RAI_Script *script, const char *fnname) * @param inputTensor input tensor structure * @return returns 1 on success, 0 in case of error. */ -int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor); +int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx *sctx, RAI_Tensor *inputTensor, RAI_Error *error); /** * For each Allocates a RAI_ScriptCtxParam data structure, and enforces a @@ -212,4 +213,16 @@ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCod */ RedisModuleType *RAI_ScriptRedisType(void); +/** + * Insert the ScriptRunCtx to the run queues so it will run asynchronously. + * + * @param sctx SodelRunCtx to execute + * @param ScriptAsyncFinish A callback that will be called when the execution is finished. + * @param private_data This is going to be sent to to the ScriptAsyncFinish. + * @return REDISMODULE_OK if the sctx was insert to the queues successfully, REDISMODULE_ERR + * otherwise. + */ +int RAI_ScriptRunAsync(RAI_ScriptRunCtx *sctx, RAI_OnFinishCB ScriptAsyncFinish, + void *private_data); + #endif /* SRC_SCRIPT_H_ */ diff --git a/tests/flow/tests_llapi.py b/tests/flow/tests_llapi.py index fe443e8e3..60319bff5 100644 --- a/tests/flow/tests_llapi.py +++ b/tests/flow/tests_llapi.py @@ -2,30 +2,45 @@ from includes import * import os +from functools import wraps ''' python -m RLTest --test tests_llapi.py --module path/to/redisai.so ''' -goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so") -TEST_MODULE_PATH = os.path.abspath(goal_dir) +def ensure_test_module_loaded(f): + @wraps(f) + def wrapper(env, *args, **kwargs): + goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so") + TEST_MODULE_PATH = os.path.abspath(goal_dir) + con = env.getConnection() + modules = con.execute_command("MODULE", "LIST") + if b'RAI_llapi' in [module[1] for module in modules]: + return f(env, *args, **kwargs) + try: + ret = con.execute_command('MODULE', 'LOAD', TEST_MODULE_PATH) + env.assertEqual(ret, b'OK') + return f(env, *args, **kwargs) + except Exception as e: + env.assertFalse(True) + env.debugPrint(str(e), force=True) + return + return wrapper + +@ensure_test_module_loaded def test_basic_check(env): con = env.getConnection() - ret = con.execute_command("MODULE", "LOAD", TEST_MODULE_PATH) - env.assertEqual(ret, b'OK') ret = con.execute_command("RAI_llapi.basic_check") env.assertEqual(ret, b'OK') +@ensure_test_module_loaded def test_model_run_async(env): con = env.getConnection() - ret = con.execute_command("MODULE", "LOAD", TEST_MODULE_PATH) - env.assertEqual(ret, b'OK') - test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') model_filename = os.path.join(test_data_path, 'graph.pb') @@ -39,3 +54,25 @@ def test_model_run_async(env): con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) ret = con.execute_command("RAI_llapi.modelRun") env.assertEqual(ret, b'Async run success') + + +@ensure_test_module_loaded +def test_script_run_async(env): + + con = env.getConnection() + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') + script_filename = os.path.join(test_data_path, 'script.txt') + + with open(script_filename, 'rb') as f: + script = f.read() + + ret = con.execute_command('AI.SCRIPTSET', 'myscript{1}', DEVICE, 'TAG', 'version1', 'SOURCE', script) + env.assertEqual(ret, b'OK') + + ret = con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + ret = con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) + env.assertEqual(ret, b'OK') + + ret = con.execute_command("RAI_llapi.scriptRun") + env.assertEqual(ret, b'Async run success') diff --git a/tests/module/LLAPI.c b/tests/module/LLAPI.c index 7e38f9e0e..b4c179b39 100644 --- a/tests/module/LLAPI.c +++ b/tests/module/LLAPI.c @@ -32,7 +32,40 @@ int RAI_llapi_basic_check(RedisModuleCtx *ctx, RedisModuleString **argv, int arg return RedisModule_ReplyWithError(ctx, "ERROR"); } -static void ModelFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) { +static void _ScriptFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) { + + RAI_Error *err; + if (RedisAI_InitError(&err) != REDISMODULE_OK) goto finish; + RAI_ScriptRunCtx* sctx = RedisAI_GetAsScriptRunCtx(onFinishCtx, err); + if(RedisAI_GetErrorCode(err) != RedisAI_ErrorCode_OK) { + *(int *) private_data = LLAPI_RUN_ERROR; + goto finish; + } + if(RedisAI_ScriptRunCtxNumOutputs(sctx) != 1) { + *(int *) private_data = LLAPI_NUM_OUTPUTS_ERROR; + goto finish; + } + RAI_Tensor *tensor = RedisAI_ScriptRunCtxOutputTensor(sctx, 0); + double expceted[4] = {4, 6, 4, 6}; + double val[4]; + + // Verify that we received the expected tensor at the end of the run. + for (long long i = 0; i < 4; i++) { + if(RedisAI_TensorGetValueAsDouble(tensor, i, &val[i]) != 0) { + goto finish; + } + if (expceted[i] != val[i]) { + goto finish; + } + } + *(int *)private_data = LLAPI_RUN_SUCCESS; + + finish: + RedisAI_FreeError(err); + pthread_cond_signal(&global_cond); +} + +static void _ModelFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) { RAI_Error *err; if (RedisAI_InitError(&err) != REDISMODULE_OK) goto finish; @@ -68,7 +101,7 @@ static void ModelFinishFunc(RAI_OnFinishCtx *onFinishCtx, void *private_data) { static int _ExecuteModelRunAsync(RedisModuleCtx *ctx, RAI_ModelRunCtx* mctx) { LLAPI_status status = LLAPI_RUN_NONE; pthread_mutex_lock(&global_lock); - if (RedisAI_ModelRunAsync(mctx, ModelFinishFunc, &status) != REDISMODULE_OK) { + if (RedisAI_ModelRunAsync(mctx, _ModelFinishFunc, &status) != REDISMODULE_OK) { pthread_mutex_unlock(&global_lock); RedisAI_ModelRunCtxFree(mctx); RedisModule_ReplyWithError(ctx, "Async run could not start"); @@ -82,6 +115,23 @@ static int _ExecuteModelRunAsync(RedisModuleCtx *ctx, RAI_ModelRunCtx* mctx) { return status; } +static int _ExecuteScriptRunAsync(RedisModuleCtx *ctx, RAI_ScriptRunCtx* sctx) { + LLAPI_status status = LLAPI_RUN_NONE; + pthread_mutex_lock(&global_lock); + if (RedisAI_ScriptRunAsync(sctx, _ScriptFinishFunc, &status) != REDISMODULE_OK) { + pthread_mutex_unlock(&global_lock); + RedisAI_ScriptRunCtxFree(sctx); + RedisModule_ReplyWithError(ctx, "Async run could not start"); + return LLAPI_RUN_NONE; + } + + // Wait until the onFinish callback returns. + pthread_cond_wait(&global_cond, &global_lock); + pthread_mutex_unlock(&global_lock); + RedisAI_ScriptRunCtxFree(sctx); + return status; +} + int RAI_llapi_modelRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { REDISMODULE_NOT_USED(argv); @@ -132,6 +182,57 @@ int RAI_llapi_modelRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return RedisModule_ReplyWithSimpleString(ctx, "Async run success"); } +int RAI_llapi_scriptRun(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + REDISMODULE_NOT_USED(argv); + + if (argc>1) { + RedisModule_WrongArity(ctx); + return REDISMODULE_OK; + } + // The script 'myscript{1}' should exist in key space. + const char *keyNameStr = "myscript{1}"; + RedisModuleString *keyRedisStr = RedisModule_CreateString(ctx, keyNameStr, strlen(keyNameStr)); + RedisModuleKey *key = RedisModule_OpenKey(ctx, keyRedisStr, REDISMODULE_READ); + RAI_Script *script = RedisModule_ModuleTypeGetValue(key); + RAI_ScriptRunCtx* sctx = RedisAI_ScriptRunCtxCreate(script, "bad_func"); + RedisModule_FreeString(ctx, keyRedisStr); + RedisModule_CloseKey(key); + + // Test the case of a failure in the script run execution (func name does not exist in script). + if(_ExecuteScriptRunAsync(ctx, sctx) != LLAPI_RUN_ERROR) { + return RedisModule_ReplyWithSimpleString(ctx, "Async run should end with an error"); + } + + sctx = RedisAI_ScriptRunCtxCreate(script, "bar"); + RAI_Error *err; + // The tensors a{1} and b{1} should exist in key space. + // Load the tensors a{1} and b{1} and add them as inputs for the script. + keyNameStr = "a{1}"; + keyRedisStr = RedisModule_CreateString(ctx, keyNameStr, + strlen(keyNameStr)); + key = RedisModule_OpenKey(ctx, keyRedisStr, REDISMODULE_READ); + RAI_Tensor *input1 = RedisModule_ModuleTypeGetValue(key); + RedisAI_ScriptRunCtxAddInput(sctx, input1, err); + RedisModule_FreeString(ctx, keyRedisStr); + RedisModule_CloseKey(key); + + keyNameStr = "b{1}"; + keyRedisStr = RedisModule_CreateString(ctx, keyNameStr, + strlen(keyNameStr)); + key = RedisModule_OpenKey(ctx, keyRedisStr, REDISMODULE_READ); + RAI_Tensor *input2 = RedisModule_ModuleTypeGetValue(key); + RedisAI_ScriptRunCtxAddInput(sctx, input2, err); + RedisModule_FreeString(ctx, keyRedisStr); + RedisModule_CloseKey(key); + + // Add the expected output tensor. + RedisAI_ScriptRunCtxAddOutput(sctx); + + if (_ExecuteScriptRunAsync(ctx, sctx) != LLAPI_RUN_SUCCESS) + return RedisModule_ReplyWithSimpleString(ctx, "Async run failed"); + return RedisModule_ReplyWithSimpleString(ctx, "Async run success"); +} + int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { REDISMODULE_NOT_USED(argv); REDISMODULE_NOT_USED(argc); @@ -151,5 +252,9 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if(RedisModule_CreateCommand(ctx, "RAI_llapi.modelRun", RAI_llapi_modelRun, "", 0, 0, 0) == REDISMODULE_ERR) return REDISMODULE_ERR; + + if(RedisModule_CreateCommand(ctx, "RAI_llapi.scriptRun", RAI_llapi_scriptRun, "", + 0, 0, 0) == REDISMODULE_ERR) + return REDISMODULE_ERR; return REDISMODULE_OK; } From 9c83cbd0424374d4fa96edc9de6a85600c602f2a Mon Sep 17 00:00:00 2001 From: alonre24 Date: Mon, 28 Dec 2020 18:19:31 +0200 Subject: [PATCH 3/3] Fix the RedisModule_CloseKey problem - remove the call for freeing the runkey in DagOp free function. Closing the model key no longer leeds to a crash. --- src/DAG/dag_parser.c | 2 -- src/run_info.c | 3 --- 2 files changed, 5 deletions(-) diff --git a/src/DAG/dag_parser.c b/src/DAG/dag_parser.c index 13a0f3412..e0ff9113e 100644 --- a/src/DAG/dag_parser.c +++ b/src/DAG/dag_parser.c @@ -223,7 +223,6 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b RedisModuleKey *modelKey; const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos + 1], &modelKey, &mto, REDISMODULE_READ); - RedisModule_OpenKey(ctx, argv[arg_pos + 1], REDISMODULE_READ); if (status == REDISMODULE_ERR) { RAI_FreeRunInfo(rinfo); RedisModule_ReplyWithError(ctx, "ERR Model not found"); @@ -244,7 +243,6 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b RedisModuleKey *scriptKey; const int status = RAI_GetScriptFromKeyspace(ctx, argv[arg_pos + 1], &scriptKey, &sto, REDISMODULE_READ); - RedisModule_OpenKey(ctx, argv[arg_pos + 1], REDISMODULE_READ); if (status == REDISMODULE_ERR) { RAI_FreeRunInfo(rinfo); return REDISMODULE_ERR; diff --git a/src/run_info.c b/src/run_info.c index df939e30b..b338f29ee 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -145,9 +145,6 @@ int RAI_ShallowCopyDagRunInfo(RedisAI_RunInfo **result, RedisAI_RunInfo *src) { void RAI_FreeDagOp(RAI_DagOp *dagOp) { if (dagOp) { RAI_FreeError(dagOp->err); - if (dagOp->runkey) { - RedisModule_FreeString(NULL, dagOp->runkey); - } if (dagOp->argv) { for (size_t i = 0; i < array_len(dagOp->argv); i++) { RedisModule_FreeString(NULL, dagOp->argv[i]);