-
Notifications
You must be signed in to change notification settings - Fork 106
Create model execute command #680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
51b4c2d
b646b02
01cd46d
28f54bc
606bb89
d6673e8
7df134a
11649ea
c57e3ab
5cebe5a
d3efe52
928f7cc
a4cceeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,82 +6,156 @@ | |
| #include "DAG/dag_parser.h" | ||
| #include "util/string_utils.h" | ||
| #include "execution/modelRun_ctx.h" | ||
| #include "deprecated.h" | ||
|
|
||
| extern int rlecMajorVersion; | ||
|
|
||
| static inline int IsEnterprise() { return rlecMajorVersion != -1; } | ||
|
|
||
| // Use this to check if a command is given a key whose hash slot is not on the current | ||
| // shard, when using enterprise cluster. | ||
| static void _AnalyzeKey(RedisModuleString *key_str, RAI_Error *err) { | ||
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if (IsEnterprise() && RAI_GetErrorCode(err) == RAI_EKEYEMPTY) { | ||
| int first_slot, last_slot; | ||
| RedisModule_ShardingGetSlotRange(&first_slot, &last_slot); | ||
| int key_slot = RedisModule_ShardingGetKeySlot(key_str); | ||
| if (key_slot < first_slot || key_slot > last_slot) { | ||
| RAI_ClearError(err); | ||
| RAI_SetError(err, RAI_EKEYEMPTY, | ||
| "ERR CROSSSLOT Keys in request don't hash to the same slot"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| static int _parseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) { | ||
|
|
||
| const int retval = RedisModule_StringToLongLong(timeout_arg, timeout); | ||
| if (retval != REDISMODULE_OK || timeout <= 0) { | ||
| if (retval != REDISMODULE_OK || *timeout <= 0) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| return REDISMODULE_OK; | ||
| } | ||
|
|
||
| static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv, | ||
| RAI_Model **model, RAI_Error *error, | ||
| RedisModuleString ***inkeys, RedisModuleString ***outkeys, | ||
| RedisModuleString **runkey, long long *timeout) { | ||
| static int _ModelExecuteCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv, | ||
| RAI_Model **model, RAI_Error *error, | ||
| RedisModuleString ***inkeys, RedisModuleString ***outkeys, | ||
| RedisModuleString **runkey, long long *timeout) { | ||
|
|
||
| if (argc < 6) { | ||
| if (argc < 8) { | ||
| RAI_SetError(error, RAI_EMODELRUN, | ||
| "ERR wrong number of arguments for 'AI.MODELRUN' command"); | ||
| "ERR wrong number of arguments for 'AI.MODELEXECUTE' command"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| size_t argpos = 1; | ||
| const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos], model, REDISMODULE_READ, error); | ||
| size_t arg_pos = 1; | ||
| const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos], model, REDISMODULE_READ, error); | ||
| if (status == REDISMODULE_ERR) { | ||
| // IFDEF LITE, call _AnalyzeKey() because model key should be located at this shard. | ||
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return REDISMODULE_ERR; | ||
| } | ||
| RAI_HoldString(NULL, argv[argpos]); | ||
| *runkey = argv[argpos]; | ||
| const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); | ||
| *runkey = RAI_HoldString(NULL, argv[arg_pos++]); | ||
| const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL); | ||
|
|
||
| // Parse timeout arg if given and store it in timeout | ||
| if (!strcasecmp(arg_string, "TIMEOUT")) { | ||
| if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR) | ||
| return REDISMODULE_ERR; | ||
| arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL); | ||
| } | ||
| if (strcasecmp(arg_string, "INPUTS") != 0) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR INPUTS not specified"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
|
|
||
| bool is_input = true, is_output = false; | ||
| size_t ninputs = 0, noutputs = 0; | ||
|
|
||
| while (++argpos < argc) { | ||
| arg_string = RedisModule_StringPtrLen(argv[argpos], NULL); | ||
| if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) { | ||
| is_input = false; | ||
| is_output = true; | ||
| } else { | ||
| RAI_HoldString(NULL, argv[argpos]); | ||
| if (is_input) { | ||
| ninputs++; | ||
| *inkeys = array_append(*inkeys, argv[argpos]); | ||
| } else { | ||
| noutputs++; | ||
| *outkeys = array_append(*outkeys, argv[argpos]); | ||
| } | ||
| } | ||
| long long ninputs = 0, noutputs = 0; | ||
| if (RedisModule_StringToLongLong(argv[arg_pos++], &ninputs) != REDISMODULE_OK) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for input_count"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if (ninputs <= 0) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if ((*model)->ninputs != ninputs) { | ||
| RAI_SetError(error, RAI_EMODELRUN, | ||
| "Number of keys given as INPUTS here does not match model definition"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| for (; arg_pos < ninputs + 4 && arg_pos < argc; arg_pos++) { | ||
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| *inkeys = array_append(*inkeys, RAI_HoldString(NULL, argv[arg_pos])); | ||
| } | ||
| if (arg_pos != ninputs + 4) { | ||
| RAI_SetError( | ||
| error, RAI_EMODELRUN, | ||
| "ERR number of input keys to AI.MODELEXECUTE command does not match the number of " | ||
| "given arguments"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
|
|
||
| // After inputs args, there must be at least 3 more args ("OUTPUTS" "output_count" | ||
| // <first_output>) | ||
| if (argc < arg_pos + 2) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified"); | ||
| return REDISMODULE_ERR; | ||
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| if (strcasecmp(RedisModule_StringPtrLen(argv[arg_pos++], NULL), "OUTPUTS") != 0) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if (RedisModule_StringToLongLong(argv[arg_pos++], &noutputs) != REDISMODULE_OK) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for output_count"); | ||
| } | ||
| if (noutputs <= 0) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if ((*model)->noutputs != noutputs) { | ||
| RAI_SetError(error, RAI_EMODELRUN, | ||
| "Number of keys given as OUTPUTS here does not match model definition"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| for (; arg_pos < noutputs + ninputs + 6 && arg_pos < argc; arg_pos++) { | ||
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
alonre24 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| *outkeys = array_append(*outkeys, RAI_HoldString(NULL, argv[arg_pos])); | ||
| } | ||
| if (arg_pos != noutputs + ninputs + 6) { | ||
| RAI_SetError( | ||
| error, RAI_EMODELRUN, | ||
| "ERR number of output keys to AI.MODELEXECUTE command does not match the number of " | ||
| "given arguments"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
|
|
||
| if (arg_pos == argc) { | ||
| return REDISMODULE_OK; | ||
| } | ||
|
|
||
| // Parse timeout arg if given and store it in timeout. | ||
| char *error_str; | ||
| arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL); | ||
| if (!strcasecmp(arg_string, "TIMEOUT")) { | ||
| if (arg_pos == argc) { | ||
| RAI_SetError(error, RAI_EMODELRUN, "ERR No value provided for TIMEOUT"); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if (_parseTimeout(argv[arg_pos++], error, timeout) == REDISMODULE_ERR) | ||
| return REDISMODULE_ERR; | ||
| } else { | ||
| error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1); | ||
| sprintf(error_str, "Invalid argument: %s", arg_string); | ||
| RAI_SetError(error, RAI_EMODELRUN, error_str); | ||
| RedisModule_Free(error_str); | ||
| return REDISMODULE_ERR; | ||
| } | ||
|
|
||
| // There are no more valid args to be processed. | ||
| if (arg_pos != argc) { | ||
| arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL); | ||
| error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1); | ||
| sprintf(error_str, "Invalid argument: %s", arg_string); | ||
| RAI_SetError(error, RAI_EMODELRUN, error_str); | ||
| RedisModule_Free(error_str); | ||
| return REDISMODULE_ERR; | ||
| } | ||
| return REDISMODULE_OK; | ||
| } | ||
|
|
||
| /** | ||
| * Extract the params for the ModelCtxRun object from AI.MODELRUN arguments. | ||
| * Extract the params for the ModelCtxRun object from AI.MODELEXECUTE arguments. | ||
| * | ||
| * @param ctx Context in which Redis modules operate | ||
| * @param inkeys Model input tensors keys, as an array of strings | ||
|
|
@@ -105,6 +179,7 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey | |
| if (status == REDISMODULE_ERR) { | ||
| RedisModule_Log(ctx, "warning", "could not load input tensor %s from keyspace", | ||
| RedisModule_StringPtrLen(inkeys[i], NULL)); | ||
| _AnalyzeKey(inkeys[i], err); // Relevant for enterprise cluster. | ||
| return REDISMODULE_ERR; | ||
| } | ||
| if (model->inputs) | ||
|
|
@@ -115,22 +190,26 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey | |
| for (size_t i = 0; i < noutputs; i++) { | ||
| if (model->outputs) | ||
| opname = model->outputs[i]; | ||
| _AnalyzeKey(outkeys[i], err); // Relevant for enterprise cluster. | ||
|
||
| if (RAI_GetErrorCode(err) != RAI_OK) { | ||
| return REDISMODULE_ERR; | ||
| } | ||
| RAI_ModelRunCtxAddOutput(mctx, opname); | ||
| } | ||
| return REDISMODULE_OK; | ||
| } | ||
|
|
||
| int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, | ||
| int argc) { | ||
| int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv, | ||
| int argc) { | ||
|
|
||
| int res = REDISMODULE_ERR; | ||
| // Build a ModelRunCtx from command. | ||
| RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL); | ||
| RAI_Model *model; | ||
| long long timeout = 0; | ||
| if (_ModelRunCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys, | ||
| &currentOp->outkeys, &currentOp->runkey, | ||
| &timeout) == REDISMODULE_ERR) { | ||
| if (_ModelExecuteCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys, | ||
| &currentOp->outkeys, &currentOp->runkey, | ||
| &timeout) == REDISMODULE_ERR) { | ||
| goto cleanup; | ||
| } | ||
|
|
||
|
|
@@ -372,6 +451,13 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar | |
| case CMD_DAGRUN: | ||
| status = ParseDAGRunCommand(rinfo, ctx, argv, argc, ro_dag); | ||
| break; | ||
| case CMD_MODELEXECUTE: | ||
| rinfo->single_op_dag = 1; | ||
| RAI_DagOp *modelExecuteOp; | ||
| RAI_InitDagOp(&modelExecuteOp); | ||
| rinfo->dagOps = array_append(rinfo->dagOps, modelExecuteOp); | ||
| status = ParseModelExecuteCommand(rinfo, modelExecuteOp, argv, argc); | ||
| break; | ||
| default: | ||
| break; | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.