diff --git a/get_deps.sh b/get_deps.sh index 89ea86369..727f3a5c9 100755 --- a/get_deps.sh +++ b/get_deps.sh @@ -253,7 +253,7 @@ if [[ $WITH_PT != 0 ]]; then echo "Done." else - echo "librotch is in place." + echo "libtorch is in place." fi else echo "SKipping libtorch." diff --git a/src/dag.c b/src/dag.c index c818560d1..0f80a184e 100644 --- a/src/dag.c +++ b/src/dag.c @@ -447,7 +447,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); rinfo->dagReplyLength++; } else { - if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, tensor) != + if (RedisModule_ModuleTypeSetValue(key, RedisAI_TensorType, RAI_TensorGetShallowCopy(tensor)) != REDISMODULE_OK) { RedisModule_ReplyWithError(ctx, "ERR could not save tensor"); rinfo->dagReplyLength++; @@ -473,6 +473,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, localcontext_key_name); local_entry = AI_dictNext(local_iter); } + AI_dictReleaseIterator(local_iter); for (size_t opN = 0; opN < array_len(rinfo->dagOps); opN++) { RedisModule_Log( @@ -532,7 +533,7 @@ int RAI_parseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, RedisModule_CloseKey(key); char *dictKey = (char*) RedisModule_Alloc((strlen(arg_string) + 5)*sizeof(char)); sprintf(dictKey, "%s%04d", arg_string, 1); - AI_dictAdd(*localContextDict, (void*)dictKey, (void *)t); + AI_dictAdd(*localContextDict, (void*)dictKey, (void *)RAI_TensorGetShallowCopy(t)); AI_dictAdd(*loadedContextDict, (void*)dictKey, (void *)1); RedisModule_Free(dictKey); number_loaded_keys++; @@ -796,6 +797,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv, const char* key = RedisModule_StringPtrLen(currentOp->inkeys[j], NULL); AI_dictEntry *entry = AI_dictFind(mangled_tensors, key); if (!entry) { + AI_dictRelease(mangled_tensors); return RedisModule_ReplyWithError(ctx, "ERR INPUT key cannot be found in DAG"); } @@ -837,6 +839,8 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv, char *key = (char *)AI_dictGetKey(entry); AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key); if (!mangled_entry) { + AI_dictRelease(mangled_tensors); + AI_dictRelease(mangled_persisted); return RedisModule_ReplyWithError(ctx, "ERR PERSIST key cannot be found in DAG"); } @@ -849,6 +853,7 @@ int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv, AI_dictReleaseIterator(iter); } + AI_dictRelease(rinfo->dagTensorsPersistedContext); rinfo->dagTensorsPersistedContext = mangled_persisted; { diff --git a/src/run_info.c b/src/run_info.c index b1d44d1c4..e1d89b30b 100644 --- a/src/run_info.c +++ b/src/run_info.c @@ -246,36 +246,9 @@ void RAI_FreeRunInfo(RedisModuleCtx *ctx, struct RedisAI_RunInfo *rinfo) { } if (rinfo->dagTensorsContext) { - AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext); - AI_dictEntry *entry = AI_dictNext(iter); - RAI_Tensor *tensor = NULL; - - while (entry) { - tensor = AI_dictGetVal(entry); - char *key = (char *)AI_dictGetKey(entry); - - if (tensor && key != NULL) { - // if the key is persisted then we should not delete it - AI_dictEntry *persisted_entry = - AI_dictFind(rinfo->dagTensorsPersistedContext, key); - // if the key was loaded from the keyspace then we should not delete it - AI_dictEntry *loaded_entry = - AI_dictFind(rinfo->dagTensorsLoadedContext, key); - - if (persisted_entry == NULL && loaded_entry == NULL) { - AI_dictDelete(rinfo->dagTensorsContext, key); - } - - if (persisted_entry) { - AI_dictDelete(rinfo->dagTensorsPersistedContext, key); - } - if (loaded_entry) { - AI_dictDelete(rinfo->dagTensorsLoadedContext, key); - } - } - entry = AI_dictNext(iter); - } - AI_dictReleaseIterator(iter); + AI_dictRelease(rinfo->dagTensorsContext); + AI_dictRelease(rinfo->dagTensorsLoadedContext); + AI_dictRelease(rinfo->dagTensorsPersistedContext); } if (rinfo->dagOps) { diff --git a/src/tensor.c b/src/tensor.c index a73083ac9..e76ebdc77 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -980,6 +980,7 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar int meta = 0; int blob = 0; int values = 0; + int fmt_error = 0; for (int i=2; i