Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ ADD_LIBRARY(redisai_obj OBJECT
util/string_utils.c
redisai.c
execution/command_parser.c
execution/deprecated.c
execution/run_info.c
execution/background_workers.c
config/config.c
Expand Down
1 change: 1 addition & 0 deletions src/execution/DAG/dag_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "dag.h"
#include "dag_parser.h"
#include "dag_execute.h"
#include "execution/deprecated.h"

/**
* DAGRUN Building Block to parse [LOAD <nkeys> key1 key2... ]
Expand Down
170 changes: 128 additions & 42 deletions src/execution/command_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,82 +6,156 @@
#include "DAG/dag_parser.h"
#include "util/string_utils.h"
#include "execution/modelRun_ctx.h"
#include "deprecated.h"

extern int rlecMajorVersion;

static inline int IsEnterprise() { return rlecMajorVersion != -1; }

// Use this to check if a command is given a key whose hash slot is not on the current
// shard, when using enterprise cluster.
static void _AnalyzeKey(RedisModuleString *key_str, RAI_Error *err) {

if (IsEnterprise() && RAI_GetErrorCode(err) == RAI_EKEYEMPTY) {
int first_slot, last_slot;
RedisModule_ShardingGetSlotRange(&first_slot, &last_slot);
int key_slot = RedisModule_ShardingGetKeySlot(key_str);
if (key_slot < first_slot || key_slot > last_slot) {
RAI_ClearError(err);
RAI_SetError(err, RAI_EKEYEMPTY,
"ERR CROSSSLOT Keys in request don't hash to the same slot");
}
}
}

static int _parseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) {

const int retval = RedisModule_StringToLongLong(timeout_arg, timeout);
if (retval != REDISMODULE_OK || timeout <= 0) {
if (retval != REDISMODULE_OK || *timeout <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT");
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
}

static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
RAI_Model **model, RAI_Error *error,
RedisModuleString ***inkeys, RedisModuleString ***outkeys,
RedisModuleString **runkey, long long *timeout) {
static int _ModelExecuteCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
RAI_Model **model, RAI_Error *error,
RedisModuleString ***inkeys, RedisModuleString ***outkeys,
RedisModuleString **runkey, long long *timeout) {

if (argc < 6) {
if (argc < 8) {
RAI_SetError(error, RAI_EMODELRUN,
"ERR wrong number of arguments for 'AI.MODELRUN' command");
"ERR wrong number of arguments for 'AI.MODELEXECUTE' command");
return REDISMODULE_ERR;
}
size_t argpos = 1;
const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos], model, REDISMODULE_READ, error);
size_t arg_pos = 1;
const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos], model, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
// IFDEF LITE, call _AnalyzeKey() because model key should be located at this shard.
return REDISMODULE_ERR;
}
RAI_HoldString(NULL, argv[argpos]);
*runkey = argv[argpos];
const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
*runkey = RAI_HoldString(NULL, argv[arg_pos++]);
const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);

// Parse timeout arg if given and store it in timeout
if (!strcasecmp(arg_string, "TIMEOUT")) {
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
return REDISMODULE_ERR;
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
}
if (strcasecmp(arg_string, "INPUTS") != 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR INPUTS not specified");
return REDISMODULE_ERR;
}

bool is_input = true, is_output = false;
size_t ninputs = 0, noutputs = 0;

while (++argpos < argc) {
arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) {
is_input = false;
is_output = true;
} else {
RAI_HoldString(NULL, argv[argpos]);
if (is_input) {
ninputs++;
*inkeys = array_append(*inkeys, argv[argpos]);
} else {
noutputs++;
*outkeys = array_append(*outkeys, argv[argpos]);
}
}
long long ninputs = 0, noutputs = 0;
if (RedisModule_StringToLongLong(argv[arg_pos++], &ninputs) != REDISMODULE_OK) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for input_count");
return REDISMODULE_ERR;
}
if (ninputs <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer");
return REDISMODULE_ERR;
}
if ((*model)->ninputs != ninputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of keys given as INPUTS here does not match model definition");
return REDISMODULE_ERR;
}
for (; arg_pos < ninputs + 4 && arg_pos < argc; arg_pos++) {
*inkeys = array_append(*inkeys, RAI_HoldString(NULL, argv[arg_pos]));
}
if (arg_pos != ninputs + 4) {
RAI_SetError(
error, RAI_EMODELRUN,
"ERR number of input keys to AI.MODELEXECUTE command does not match the number of "
"given arguments");
return REDISMODULE_ERR;
}

// After inputs args, there must be at least 3 more args ("OUTPUT" "output_count"
// <first_output>)
if (argc < arg_pos + 2) {
RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified");
return REDISMODULE_ERR;
}
if (strcasecmp(RedisModule_StringPtrLen(argv[arg_pos++], NULL), "OUTPUTS") != 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified");
return REDISMODULE_ERR;
}
if (RedisModule_StringToLongLong(argv[arg_pos++], &noutputs) != REDISMODULE_OK) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for output_count");
}
if (noutputs <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer");
return REDISMODULE_ERR;
}
if ((*model)->noutputs != noutputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of keys given as OUTPUTS here does not match model definition");
return REDISMODULE_ERR;
}
for (; arg_pos < noutputs + ninputs + 6 && arg_pos < argc; arg_pos++) {
*outkeys = array_append(*outkeys, RAI_HoldString(NULL, argv[arg_pos]));
}
if (arg_pos != noutputs + ninputs + 6) {
RAI_SetError(
error, RAI_EMODELRUN,
"ERR number of output keys to AI.MODELEXECUTE command does not match the number of "
"given arguments");
return REDISMODULE_ERR;
}

if (arg_pos == argc) {
return REDISMODULE_OK;
}

// Parse timeout arg if given and store it in timeout.
char *error_str;
arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);
if (!strcasecmp(arg_string, "TIMEOUT")) {
if (arg_pos == argc) {
RAI_SetError(error, RAI_EMODELRUN, "ERR No value provided for TIMEOUT");
return REDISMODULE_ERR;
}
if (_parseTimeout(argv[arg_pos++], error, timeout) == REDISMODULE_ERR)
return REDISMODULE_ERR;
} else {
error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
sprintf(error_str, "Invalid argument: %s", arg_string);
RAI_SetError(error, RAI_EMODELRUN, error_str);
RedisModule_Free(error_str);
return REDISMODULE_ERR;
}

// There are no more valid args to be processed.
if (arg_pos != argc) {
arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);
error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
sprintf(error_str, "Invalid argument: %s", arg_string);
RAI_SetError(error, RAI_EMODELRUN, error_str);
RedisModule_Free(error_str);
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
}

/**
* Extract the params for the ModelCtxRun object from AI.MODELRUN arguments.
* Extract the params for the ModelCtxRun object from AI.MODELEXECUTE arguments.
*
* @param ctx Context in which Redis modules operate
* @param inkeys Model input tensors keys, as an array of strings
Expand All @@ -105,6 +179,7 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey
if (status == REDISMODULE_ERR) {
RedisModule_Log(ctx, "warning", "could not load input tensor %s from keyspace",
RedisModule_StringPtrLen(inkeys[i], NULL));
_AnalyzeKey(inkeys[i], err); // Relevant for enterprise cluster.
return REDISMODULE_ERR;
}
if (model->inputs)
Expand All @@ -115,22 +190,26 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey
for (size_t i = 0; i < noutputs; i++) {
if (model->outputs)
opname = model->outputs[i];
_AnalyzeKey(outkeys[i], err); // Relevant for enterprise cluster.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do it on parse time

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't because the parsing function is used from dag command as well (and in DAG we do not take the keys from key space...)
This is called only from "pure" AI.MODELEXECUTE.

if (RAI_GetErrorCode(err) != RAI_OK) {
return REDISMODULE_ERR;
}
RAI_ModelRunCtxAddOutput(mctx, opname);
}
return REDISMODULE_OK;
}

int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc) {
int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc) {

int res = REDISMODULE_ERR;
// Build a ModelRunCtx from command.
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
RAI_Model *model;
long long timeout = 0;
if (_ModelRunCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
&currentOp->outkeys, &currentOp->runkey,
&timeout) == REDISMODULE_ERR) {
if (_ModelExecuteCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
&currentOp->outkeys, &currentOp->runkey,
&timeout) == REDISMODULE_ERR) {
goto cleanup;
}

Expand Down Expand Up @@ -372,6 +451,13 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
case CMD_DAGRUN:
status = ParseDAGRunCommand(rinfo, ctx, argv, argc, ro_dag);
break;
case CMD_MODELEXECUTE:
rinfo->single_op_dag = 1;
RAI_DagOp *modelExecuteOp;
RAI_InitDagOp(&modelExecuteOp);
rinfo->dagOps = array_append(rinfo->dagOps, modelExecuteOp);
status = ParseModelExecuteCommand(rinfo, modelExecuteOp, argv, argc);
break;
default:
break;
}
Expand Down
13 changes: 9 additions & 4 deletions src/execution/command_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
#include "redismodule.h"
#include "run_info.h"

typedef enum RunCommand { CMD_MODELRUN = 0, CMD_SCRIPTRUN, CMD_DAGRUN } RunCommand;
typedef enum RunCommand {
CMD_MODELRUN = 0,
CMD_SCRIPTRUN,
CMD_DAGRUN,
CMD_MODELEXECUTE
} RunCommand;

/**
* @brief Parse and validate MODELRUN command: create a modelRunCtx based on the model obtained
* @brief Parse and validate MODELEXECUTE command: create a modelRunCtx based on the model obtained
* from the key space and save it in the op. The keys of the input and output tensors are stored in
* the op's inkeys and outkeys arrays, the model key is saved in op's runkey, and the given timeout
* is saved as well (if given, otherwise it is zero).
* @return Returns REDISMODULE_OK if the command is valid, REDISMODULE_ERR otherwise.
*/
int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc);
int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc);

/**
* @brief Parse and validate SCRIPTRUN command: create a scriptRunCtx based on the script obtained
Expand Down
Loading