Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 61 additions & 51 deletions src/DAG/dag.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,53 +237,52 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
uint n_inkeys = array_len(currentOp->inkeys);
uint n_outkeys = array_len(currentOp->outkeys);

RAI_ContextReadLock(rinfo);
if (!rinfo->single_op_dag) {

RAI_Tensor *inputTensors[n_inkeys];
for (uint i = 0; i < n_inkeys; i++) {
RAI_Tensor *inputTensor;
const int get_result = RAI_getTensorFromLocalContext(
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
if (get_result == REDISMODULE_ERR) {
// We check for this outside the function
// this check cannot be covered by tests
currentOp->result = REDISMODULE_ERR;
RAI_ContextUnlock(rinfo);
return;
RAI_ContextReadLock(rinfo);
RAI_Tensor *inputTensors[n_inkeys];
for (uint i = 0; i < n_inkeys; i++) {
RAI_Tensor *inputTensor;
const int get_result = RAI_getTensorFromLocalContext(
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
if (get_result == REDISMODULE_ERR) {
// We check for this outside the function
// this check cannot be covered by tests
currentOp->result = REDISMODULE_ERR;
RAI_ContextUnlock(rinfo);
return;
}
inputTensors[i] = inputTensor;
}
inputTensors[i] = inputTensor;
}

RAI_ContextUnlock(rinfo);

for (uint i = 0; i < n_inkeys; i++) {
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err);
}
RAI_ContextUnlock(rinfo);

for (uint i = 0; i < n_outkeys; i++) {
RAI_ScriptRunCtxAddOutput(currentOp->sctx);
for (uint i = 0; i < n_inkeys; i++) {
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err);
}
for (uint i = 0; i < n_outkeys; i++) {
RAI_ScriptRunCtxAddOutput(currentOp->sctx);
}
}

const long long start = ustime();
int result = RAI_ScriptRun(currentOp->sctx, currentOp->err);
const long long end = ustime();

RAI_ContextWriteLock(rinfo);

const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
RedisModuleString *key_string = currentOp->outkeys[outputNumber];
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
}

currentOp->result = result;
currentOp->duration_us = end - start;

RAI_ContextUnlock(rinfo);
if (!rinfo->single_op_dag) {

return;
RAI_ContextWriteLock(rinfo);
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
RedisModuleString *key_string = currentOp->outkeys[outputNumber];
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
}
RAI_ContextUnlock(rinfo);
}
}

size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
Expand Down Expand Up @@ -572,17 +571,16 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
return ret;
}

static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {

AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
AI_dictEntry *persist_entry = AI_dictNext(persist_iter);

while (persist_entry) {
RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry);

AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name);

if (tensor_entry) {
RAI_Tensor *tensor = AI_dictGetVal(tensor_entry);

if (tensor == NULL) {
persist_entry = AI_dictNext(persist_iter);
continue;
Expand All @@ -594,17 +592,17 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
RedisModule_ReplyWithError(ctx,
"ERR specified persistent key that was not used in DAG");
rinfo->dagReplyLength++;

RedisModule_Log(ctx, "warning",
"on DAGRUN's PERSIST pecified persistent key (%s) that "
"on DAGRUN's PERSIST specified persistent key (%s) that "
"was not used on DAG. Logging all local context keys",
persist_key_name);
RedisModule_StringPtrLen(persist_key_name, NULL));
AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
AI_dictEntry *local_entry = AI_dictNext(local_iter);

while (local_entry) {
RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry);
RedisModule_Log(ctx, "warning", "DAG's local context key (%s)",
localcontext_key_name);
RedisModule_StringPtrLen(localcontext_key_name, NULL));
local_entry = AI_dictNext(local_iter);
}
AI_dictReleaseIterator(local_iter);
Expand All @@ -619,7 +617,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
AI_dictReleaseIterator(persist_iter);
}

static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx);
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber);
Expand All @@ -629,6 +627,16 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
}
}

static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx);
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber);
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
if (tensor)
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
}
}

int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
REDISMODULE_NOT_USED(argv);
REDISMODULE_NOT_USED(argc);
Expand All @@ -650,7 +658,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
return REDISMODULE_OK;
}

if (rinfo->single_op_dag == 0) {
if (!rinfo->single_op_dag) {
RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN);
}

Expand Down Expand Up @@ -745,18 +753,20 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
return REDISMODULE_ERR;
}

// TODO: Take care of script single op
if (rinfo->single_op_dag == 0 || rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) {
if (!rinfo->single_op_dag) {
// Save the required tensors in redis key space.
PersistTensors(ctx, rinfo);
if (rinfo->single_op_dag == 0)
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
_PersistTensors(ctx, rinfo);
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
} else {
ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
if (rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_MODELRUN) {
_ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
} else {
RedisModule_Assert(rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN);
_ScriptSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
}
}

RAI_FreeRunInfo(rinfo);

return REDISMODULE_OK;
}

Expand Down
42 changes: 2 additions & 40 deletions src/DAG/dag_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,9 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
int chainingOpCount = 0;
bool load_complete = false;
bool persist_complete = false;
int arg_pos = 1;

// If we're parsing a AI.SCRIPTRUN command, we don't expect there to be a chaining |> operator
if (!strcasecmp(RedisModule_StringPtrLen(argv[0], NULL), "AI.SCRIPTRUN")) {
arg_pos = 0;
chainingOpCount++;
rinfo->single_op_dag = 1;
rinfo->single_device_dag = 1;
}

// The first arg is "AI.DAGRUN", so we go over from the next arg.
for (; arg_pos < argc; arg_pos++) {
for (int arg_pos = 1; arg_pos < argc; arg_pos++) {
const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);

if (!strcasecmp(arg_string, "LOAD") && !load_complete) {
Expand Down Expand Up @@ -303,35 +294,6 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
}
}

if (rinfo->single_op_dag && rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) {
RAI_DagOp *op = rinfo->dagOps[0];
RAI_Tensor *t;
RedisModuleKey *key;
for (size_t i = 0; i < array_len(op->inkeys); i++) {
RedisModuleString *inkey = op->inkeys[i];
const int status = RAI_GetTensorFromKeyspace(ctx, inkey, &key, &t, REDISMODULE_READ);
if (status == REDISMODULE_ERR) {
RedisModule_Log(ctx, "warning",
"on DAGRUN's LOAD could not load tensor %s from keyspace",
RedisModule_StringPtrLen(inkey, NULL));
return REDISMODULE_ERR;
}
char buf[16];
sprintf(buf, "%04d", 1);
RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, inkey);
RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf));
AI_dictAdd(rinfo->dagTensorsContext, (void *)dictKey,
(void *)RAI_TensorGetShallowCopy(t));
AI_dictAdd(rinfo->dagTensorsLoadedContext, (void *)dictKey, (void *)1);
RedisModule_Free(dictKey);
}

for (size_t i = 0; i < array_len(op->outkeys); i++) {
RedisModuleString *outkey = op->outkeys[i];
AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)outkey, (void *)1);
}
}

// At this point, we have built a sequence of DAG operations, each with its own
// input and output keys. The names of the keys will be used to look whether the
// inputs to a DAG operation have all been realized by previous operations (or if
Expand Down Expand Up @@ -462,4 +424,4 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
}
}
return REDISMODULE_OK;
}
}
Loading