Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions docs/commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,49 @@ OK
!!! note "The `AI.MODELDEL` vis a vis the `DEL` command"
The `AI.MODELDEL` is equivalent to the [Redis `DEL` command](https://redis.io/commands/del) and should be used in its stead. This ensures compatibility with all deployment options (i.e., stand-alone vs. cluster, OSS vs. Enterprise).


## AI.MODELEXECUTE
The **`AI.MODELRUN`** command runs a model stored as a key's value using its specified backend and device. It accepts one or more input tensors and store output tensors.

The run request is put in a queue and is executed asynchronously by a worker thread. The client that had issued the run request is blocked until the model run is completed. When needed, tensors data is automatically copied to the device prior to execution.

A `TIMEOUT t` argument can be specified to cause a request to be removed from the queue after it sits there `t` milliseconds, meaning that the client won't be interested in the result being computed after that time (`TIMEDOUT` is returned in that case).

!!! warning "Intermediate memory overhead"
The execution of models will generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis.

**Redis API**

```
AI.MODELEXECUTE <key> INPUTS <input_count> <input> [input ...] OUTPUTS <output_count> <output> [output ...] [TIMEOUT t]
```

_Arguments_

* **key**: the model's key name
* **INPUTS**: denotes the beginning of the input tensors keys' list, followed by the number of inputs and one or more key names
* **input_count**: A positive number that indicates the number of following input keys.
* **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by the number of outputs one or more key names
* **output_count**: A positive number that indicates the number of output keys to follow.
* **TIMEOUT**: the time (in ms) after which the client is unblocked and a `TIMEDOUT` string is returned

_Return_

A simple 'OK' string, a simple `TIMEDOUT` string, or an error.

**Examples**

Assuming that running the model that's stored at 'mymodel' with the tensor 'mytensor' as input outputs two tensors - 'classes' and 'predictions', the following command does that:

```
redis> AI.MODELEXECUTE mymodel INPUTS 1 mytensor OUTPUTS 2 classes predictions
OK
```

## AI.MODELRUN

_This command is deprecated and will not be available in future versions. consider using AI.MODELEXECUTE command instead._

The **`AI.MODELRUN`** command runs a model stored as a key's value using its specified backend and device. It accepts one or more input tensors and store output tensors.

The run request is put in a queue and is executed asynchronously by a worker thread. The client that had issued the run request is blocked until the model run is completed. When needed, tensors data is automatically copied to the device prior to execution.
Expand Down
4 changes: 3 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ file (GLOB BACKEND_COMMON_SRC
util/dict.c
redis_ai_objects/tensor.c
util/string_utils.c
execution/utils.c
serialization/ai_datatypes.c)

ADD_LIBRARY(redisai_obj OBJECT
Expand All @@ -25,8 +26,10 @@ ADD_LIBRARY(redisai_obj OBJECT
util/string_utils.c
redisai.c
execution/command_parser.c
execution/deprecated.c
execution/run_info.c
execution/background_workers.c
execution/utils.c
config/config.c
execution/DAG/dag.c
execution/DAG/dag_parser.c
Expand All @@ -43,7 +46,6 @@ ADD_LIBRARY(redisai_obj OBJECT
rmutil/alloc.c
rmutil/sds.c
rmutil/args.c
execution/run_info.c
redis_ai_types/model_type.c
redis_ai_types/tensor_type.c
redis_ai_types/script_type.c
Expand Down
2 changes: 2 additions & 0 deletions src/execution/DAG/dag_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "dag.h"
#include "dag_parser.h"
#include "dag_execute.h"
#include "execution/deprecated.h"
#include "execution/utils.h"

/**
* DAGRUN Building Block to parse [LOAD <nkeys> key1 key2... ]
Expand Down
190 changes: 121 additions & 69 deletions src/execution/command_parser.c
Original file line number Diff line number Diff line change
@@ -1,98 +1,129 @@

#include "redismodule.h"
#include "run_info.h"
#include "command_parser.h"
#include "DAG/dag.h"
#include "DAG/dag_parser.h"
#include "util/string_utils.h"
#include "execution/modelRun_ctx.h"
#include "deprecated.h"
#include "utils.h"

static int _parseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) {

const int retval = RedisModule_StringToLongLong(timeout_arg, timeout);
if (retval != REDISMODULE_OK || timeout <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT");
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
}

static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
RAI_Model **model, RAI_Error *error,
RedisModuleString ***inkeys, RedisModuleString ***outkeys,
RedisModuleString **runkey, long long *timeout) {
static int _ModelExecuteCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModuleString **argv,
RAI_Model **model, RAI_Error *error,
RedisModuleString ***inkeys, RedisModuleString ***outkeys,
RedisModuleString **runkey, long long *timeout) {

if (argc < 6) {
if (argc < 8) {
RAI_SetError(error, RAI_EMODELRUN,
"ERR wrong number of arguments for 'AI.MODELRUN' command");
"ERR wrong number of arguments for 'AI.MODELEXECUTE' command");
return REDISMODULE_ERR;
}
size_t argpos = 1;
const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos], model, REDISMODULE_READ, error);
size_t arg_pos = 1;
const int status = RAI_GetModelFromKeyspace(ctx, argv[arg_pos], model, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
RAI_HoldString(NULL, argv[argpos]);
*runkey = argv[argpos];
const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
*runkey = RAI_HoldString(NULL, argv[arg_pos++]);
const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);

// Parse timeout arg if given and store it in timeout
if (!strcasecmp(arg_string, "TIMEOUT")) {
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
return REDISMODULE_ERR;
arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
}
if (strcasecmp(arg_string, "INPUTS") != 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR INPUTS not specified");
return REDISMODULE_ERR;
}

bool is_input = true, is_output = false;
size_t ninputs = 0, noutputs = 0;

while (++argpos < argc) {
arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
if (!strcasecmp(arg_string, "OUTPUTS") && !is_output) {
is_input = false;
is_output = true;
} else {
RAI_HoldString(NULL, argv[argpos]);
if (is_input) {
ninputs++;
*inkeys = array_append(*inkeys, argv[argpos]);
} else {
noutputs++;
*outkeys = array_append(*outkeys, argv[argpos]);
}
}
long long ninputs = 0, noutputs = 0;
if (RedisModule_StringToLongLong(argv[arg_pos++], &ninputs) != REDISMODULE_OK) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for input_count");
return REDISMODULE_ERR;
}
if (ninputs <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer");
return REDISMODULE_ERR;
}
if ((*model)->ninputs != ninputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of keys given as INPUTS here does not match model definition");
return REDISMODULE_ERR;
}
// arg_pos = 4 at this point, as we always have 4 args before the input keys.
for (; arg_pos < ninputs + 4 && arg_pos < argc; arg_pos++) {
*inkeys = array_append(*inkeys, RAI_HoldString(NULL, argv[arg_pos]));
}
if (arg_pos != ninputs + 4) {
RAI_SetError(
error, RAI_EMODELRUN,
"ERR number of input keys to AI.MODELEXECUTE command does not match the number of "
"given arguments");
return REDISMODULE_ERR;
}

if (argc == arg_pos ||
strcasecmp(RedisModule_StringPtrLen(argv[arg_pos++], NULL), "OUTPUTS") != 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR OUTPUTS not specified");
return REDISMODULE_ERR;
}
if (argc == arg_pos ||
RedisModule_StringToLongLong(argv[arg_pos++], &noutputs) != REDISMODULE_OK) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid argument for output_count");
}
if (noutputs <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Input count must be a positive integer");
return REDISMODULE_ERR;
}
if ((*model)->noutputs != noutputs) {
RAI_SetError(error, RAI_EMODELRUN,
"Number of keys given as OUTPUTS here does not match model definition");
return REDISMODULE_ERR;
}
// arg_pos = ninputs+6, the argument that we already parsed are:
// AI.MODELEXECUTE <model_key> INPUTS <input_count> <input> ... OUTPUTS <output_count>
for (; arg_pos < noutputs + ninputs + 6 && arg_pos < argc; arg_pos++) {
*outkeys = array_append(*outkeys, RAI_HoldString(NULL, argv[arg_pos]));
}
if (arg_pos != noutputs + ninputs + 6) {
RAI_SetError(
error, RAI_EMODELRUN,
"ERR number of output keys to AI.MODELEXECUTE command does not match the number of "
"given arguments");
return REDISMODULE_ERR;
}

if (arg_pos == argc) {
return REDISMODULE_OK;
}

// Parse timeout arg if given and store it in timeout.
char *error_str;
arg_string = RedisModule_StringPtrLen(argv[arg_pos++], NULL);
if (!strcasecmp(arg_string, "TIMEOUT")) {
if (arg_pos == argc) {
RAI_SetError(error, RAI_EMODELRUN, "ERR No value provided for TIMEOUT");
return REDISMODULE_ERR;
}
if (ParseTimeout(argv[arg_pos++], error, timeout) == REDISMODULE_ERR)
return REDISMODULE_ERR;
} else {
error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
sprintf(error_str, "Invalid argument: %s", arg_string);
RAI_SetError(error, RAI_EMODELRUN, error_str);
RedisModule_Free(error_str);
return REDISMODULE_ERR;
}

// There are no more valid args to be processed.
if (arg_pos != argc) {
arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);
error_str = RedisModule_Alloc(strlen("Invalid argument: ") + strlen(arg_string) + 1);
sprintf(error_str, "Invalid argument: %s", arg_string);
RAI_SetError(error, RAI_EMODELRUN, error_str);
RedisModule_Free(error_str);
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
}

/**
* Extract the params for the ModelCtxRun object from AI.MODELRUN arguments.
*
* @param ctx Context in which Redis modules operate
* @param inkeys Model input tensors keys, as an array of strings
* @param outkeys Model output tensors keys, as an array of strings
* @param mctx Destination Model context to store the parsed data
* @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise
*/

static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
RedisModuleString **outkeys, RAI_ModelRunCtx *mctx,
RAI_Error *err) {
int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
RedisModuleString **outkeys, RAI_ModelRunCtx *mctx, RAI_Error *err) {

RAI_Model *model = mctx->model;
RAI_Tensor *t;
Expand All @@ -103,8 +134,6 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey
const int status =
RAI_GetTensorFromKeyspace(ctx, inkeys[i], &key, &t, REDISMODULE_READ, err);
if (status == REDISMODULE_ERR) {
RedisModule_Log(ctx, "warning", "could not load input tensor %s from keyspace",
RedisModule_StringPtrLen(inkeys[i], NULL));
return REDISMODULE_ERR;
}
if (model->inputs)
Expand All @@ -113,24 +142,30 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey
}

for (size_t i = 0; i < noutputs; i++) {
if (model->outputs)
if (model->outputs) {
opname = model->outputs[i];
}
if (!VerifyKeyInThisShard(ctx, outkeys[i])) { // Relevant for enterprise cluster.
RAI_SetError(err, RAI_EMODELRUN,
"ERR CROSSSLOT Keys in request don't hash to the same slot");
return REDISMODULE_ERR;
Comment on lines +149 to +152

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this to the actual parsing phase. fail as soon as possible when you can

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't because the parsing function is used from dag command as well (and in DAG we do not take the keys from key space...)
But ModelRunCtx_SetParams is called only from "pure" AI.MODELEXECUTE.

}
RAI_ModelRunCtxAddOutput(mctx, opname);
}
return REDISMODULE_OK;
}

int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc) {
int ParseModelExecuteCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc) {

int res = REDISMODULE_ERR;
// Build a ModelRunCtx from command.
RedisModuleCtx *ctx = RedisModule_GetThreadSafeContext(NULL);
RAI_Model *model;
long long timeout = 0;
if (_ModelRunCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
&currentOp->outkeys, &currentOp->runkey,
&timeout) == REDISMODULE_ERR) {
if (_ModelExecuteCommand_ParseArgs(ctx, argc, argv, &model, rinfo->err, &currentOp->inkeys,
&currentOp->outkeys, &currentOp->runkey,
&timeout) == REDISMODULE_ERR) {
goto cleanup;
}

Expand All @@ -147,7 +182,7 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModu
if (rinfo->single_op_dag) {
rinfo->timeout = timeout;
// Set params in ModelRunCtx, bring inputs from key space.
if (_ModelRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, mctx, rinfo->err) ==
if (ModelRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, mctx, rinfo->err) ==
REDISMODULE_ERR)
goto cleanup;
}
Expand Down Expand Up @@ -198,7 +233,7 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **

// Parse timeout arg if given and store it in timeout
if (!strcasecmp(arg_string, "TIMEOUT") && !timeout_set) {
if (_parseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
if (ParseTimeout(argv[++argpos], error, timeout) == REDISMODULE_ERR)
return REDISMODULE_ERR;
timeout_set = true;
continue;
Expand Down Expand Up @@ -296,6 +331,16 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke
return REDISMODULE_OK;
}

int ParseTimeout(RedisModuleString *timeout_arg, RAI_Error *error, long long *timeout) {

const int retval = RedisModule_StringToLongLong(timeout_arg, timeout);
if (retval != REDISMODULE_OK || *timeout <= 0) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid value for TIMEOUT");
return REDISMODULE_ERR;
}
return REDISMODULE_OK;
}

int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisModuleString **argv,
int argc) {

Expand Down Expand Up @@ -372,6 +417,13 @@ int RedisAI_ExecuteCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
case CMD_DAGRUN:
status = ParseDAGRunCommand(rinfo, ctx, argv, argc, ro_dag);
break;
case CMD_MODELEXECUTE:
rinfo->single_op_dag = 1;
RAI_DagOp *modelExecuteOp;
RAI_InitDagOp(&modelExecuteOp);
rinfo->dagOps = array_append(rinfo->dagOps, modelExecuteOp);
status = ParseModelExecuteCommand(rinfo, modelExecuteOp, argv, argc);
break;
default:
break;
}
Expand Down
Loading