Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion get_deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ if [[ $WITH_PT != 0 ]]; then

echo "Done."
else
echo "librotch is in place."
echo "libtorch is in place."
fi
else
echo "SKipping libtorch."
Expand Down
9 changes: 7 additions & 2 deletions src/dag.c
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
rinfo->dagReplyLength++;
} else {
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) !=
if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, RAI_TensorGetShallowCopy(tensor)) !=
REDISMODULE_OK) {
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
rinfo->dagReplyLength++;
Expand All @@ -473,6 +473,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
localcontext_key_name);
local_entry = AI_dictNext(local_iter);
}
AI_dictReleaseIterator(local_iter);

for (size_t opN = 0; opN < array_len(rinfo->dagOps); opN++) {
RedisModule_Log(
Expand Down Expand Up @@ -532,7 +533,7 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv,
RedisModule_CloseKey(key);
char *dictKey = (char*) RedisModule_Alloc((strlen(arg_string) + 5)*sizeof(char));
sprintf(dictKey, "%s%04d", arg_string, 1);
AI_dictAdd(*localContextDict, (void*)dictKey, (void *)t);
AI_dictAdd(*localContextDict, (void*)dictKey, (void *)RAI_TensorGetShallowCopy(t));
AI_dictAdd(*loadedContextDict, (void*)dictKey, (void *)1);
RedisModule_Free(dictKey);
number_loaded_keys++;
Expand Down Expand Up @@ -796,6 +797,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
const char* key = RedisModule_StringPtrLen(currentOp->inkeys[j], NULL);
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
if (!entry) {
AI_dictRelease(mangled_tensors);
return RedisModule_ReplyWithError(ctx,
"ERR INPUT key cannot be found in DAG");
}
Expand Down Expand Up @@ -837,6 +839,8 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
char *key = (char *)AI_dictGetKey(entry);
AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key);
if (!mangled_entry) {
AI_dictRelease(mangled_tensors);
AI_dictRelease(mangled_persisted);
return RedisModule_ReplyWithError(ctx,
"ERR PERSIST key cannot be found in DAG");
}
Expand All @@ -849,6 +853,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
AI_dictReleaseIterator(iter);
}

AI_dictRelease(rinfo->dagTensorsPersistedContext);
rinfo->dagTensorsPersistedContext = mangled_persisted;

{
Expand Down
33 changes: 3 additions & 30 deletions src/run_info.c
Original file line number Diff line number Diff line change
Expand Up @@ -246,36 +246,9 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) {
}

if (rinfo->dagTensorsContext) {
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
AI_dictEntry *entry = AI_dictNext(iter);
RAI_Tensor *tensor = NULL;

while (entry) {
tensor = AI_dictGetVal(entry);
char *key = (char *)AI_dictGetKey(entry);

if (tensor && key != NULL) {
// if the key is persisted then we should not delete it
AI_dictEntry *persisted_entry =
AI_dictFind(rinfo->dagTensorsPersistedContext, key);
// if the key was loaded from the keyspace then we should not delete it
AI_dictEntry *loaded_entry =
AI_dictFind(rinfo->dagTensorsLoadedContext, key);

if (persisted_entry == NULL && loaded_entry == NULL) {
AI_dictDelete(rinfo->dagTensorsContext, key);
}

if (persisted_entry) {
AI_dictDelete(rinfo->dagTensorsPersistedContext, key);
}
if (loaded_entry) {
AI_dictDelete(rinfo->dagTensorsLoadedContext, key);
}
}
entry = AI_dictNext(iter);
}
AI_dictReleaseIterator(iter);
AI_dictRelease(rinfo->dagTensorsContext);
AI_dictRelease(rinfo->dagTensorsLoadedContext);
AI_dictRelease(rinfo->dagTensorsPersistedContext);
}

if (rinfo->dagOps) {
Expand Down
14 changes: 10 additions & 4 deletions src/tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
int meta = 0;
int blob = 0;
int values = 0;
int fmt_error = 0;
for (int i=2; i<argc; i++) {
const char *fmtstr = RedisModule_StringPtrLen(argv[i], NULL);
if (!strcasecmp(fmtstr, "BLOB")) {
Expand All @@ -992,11 +993,15 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
meta = 1;
datafmt = REDISAI_DATA_NONE;
} else {
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
return -1;
fmt_error = 1;
}
}

if (fmt_error) {
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
return -1;
}

if (blob && values) {
RedisModule_ReplyWithError(ctx, "ERR both BLOB and VALUES specified");
return -1;
Expand Down Expand Up @@ -1033,14 +1038,15 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar

const long long ndims = RAI_TensorNumDims(t);

RedisModule_ReplyWithArray(ctx, resplen);

char *dtypestr = NULL;
const int dtypestr_result = Tensor_DataTypeStr(RAI_TensorDataType(t), &dtypestr);
if(dtypestr_result==REDISMODULE_ERR){
RedisModule_ReplyWithError(ctx, "ERR unsupported dtype");
return -1;
}

RedisModule_ReplyWithArray(ctx, resplen);

RedisModule_ReplyWithCString(ctx, "dtype");
RedisModule_ReplyWithCString(ctx, dtypestr);

Expand Down