Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 67 additions & 65 deletions src/dag.c

Large diffs are not rendered by default.

20 changes: 14 additions & 6 deletions src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
RAI_Backend backend = RedisModule_LoadUnsigned(io);
const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);

const char *tag = RedisModule_LoadStringBuffer(io, NULL);
RedisModuleString *tag = RedisModule_LoadString(io);

const size_t batchsize = RedisModule_LoadUnsigned(io);
const size_t minbatchsize = RedisModule_LoadUnsigned(io);
Expand Down Expand Up @@ -113,7 +113,10 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
RedisModuleString *stats_keystr =
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
const char *stats_devicestr = RedisModule_Strdup(devicestr);
const char *stats_tag = RedisModule_Strdup(tag);
if (tag) {
RedisModule_RetainString(NULL, tag);
}
RedisModuleString *stats_tag = tag;

model->infokey =
RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, stats_devicestr, stats_tag);
Expand Down Expand Up @@ -143,7 +146,7 @@ static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) {

RedisModule_SaveUnsigned(io, model->backend);
RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1);
RedisModule_SaveStringBuffer(io, model->tag, strlen(model->tag) + 1);
RedisModule_SaveString(io, model->tag);
RedisModule_SaveUnsigned(io, model->opts.batchsize);
RedisModule_SaveUnsigned(io, model->opts.minbatchsize);
RedisModule_SaveUnsigned(io, model->ninputs);
Expand Down Expand Up @@ -221,7 +224,7 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi

const char *backendstr = RAI_BackendName(model->backend);

RedisModule_EmitAOF(aof, "AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr,
RedisModule_EmitAOF(aof, "AI.MODELSET", "sccsclclcvcvcv", key, backendstr, model->devicestr,
model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE",
model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS",
outputs_, model->noutputs, "BLOB", buffers_, n_chunks);
Expand Down Expand Up @@ -285,7 +288,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx) {
return RedisAI_ModelType != NULL;
}

RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
const char **outputs, const char *modeldef, size_t modellen,
RAI_Error *err) {
Expand Down Expand Up @@ -321,7 +324,12 @@ RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const cha
}

if (model) {
model->tag = RedisModule_Strdup(tag);
if (tag) {
RedisModule_RetainString(NULL, tag);
model->tag = tag;
} else {
model->tag = RedisModule_CreateString(NULL, "", 0);
}
}

return model;
Expand Down
2 changes: 1 addition & 1 deletion src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ int RAI_ModelInit(RedisModuleCtx *ctx);
* failures
* @return RAI_Model model structure on success, or NULL if failed
*/
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, const char *tag,
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
const char **outputs, const char *modeldef, size_t modellen,
RAI_Error *err);
Expand Down
2 changes: 1 addition & 1 deletion src/model_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ typedef struct RAI_Model {
void *session;
RAI_Backend backend;
char *devicestr;
char *tag;
RedisModuleString *tag;
RAI_ModelOpts opts;
char **inputs;
size_t ninputs;
Expand Down
32 changes: 19 additions & 13 deletions src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
return RedisModule_ReplyWithError(ctx, "ERR Invalid DEVICE");
}

const char *tag = "";
RedisModuleString *tag = NULL;
if (AC_AdvanceIfMatch(&ac, "TAG")) {
AC_GetString(&ac, &tag, NULL, 0);
AC_GetRString(&ac, &tag, 0);
}

unsigned long long batchsize = 0;
Expand Down Expand Up @@ -470,7 +470,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
RedisModule_ReplyWithCString(ctx, mto->devicestr);

RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, mto->tag ? mto->tag : "");
RedisModuleString *empty_tag = RedisModule_CreateString(ctx, "", 0);
RedisModule_ReplyWithString(ctx, mto->tag ? mto->tag : empty_tag);

RedisModule_ReplyWithCString(ctx, "batchsize");
RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.batchsize);
Expand Down Expand Up @@ -539,15 +540,15 @@ int RedisAI_ModelScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv

long long nkeys;
RedisModuleString **keys;
const char **tags;
RedisModuleString **tags;
RAI_ListStatsEntries(RAI_MODEL, &nkeys, &keys, &tags);

RedisModule_ReplyWithArray(ctx, nkeys);

for (long long i = 0; i < nkeys; i++) {
RedisModule_ReplyWithArray(ctx, 2);
RedisModule_ReplyWithString(ctx, keys[i]);
RedisModule_ReplyWithCString(ctx, tags[i]);
RedisModule_ReplyWithString(ctx, tags[i]);
}

RedisModule_Free(keys);
Expand Down Expand Up @@ -633,7 +634,7 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
RedisModule_ReplyWithCString(ctx, "device");
RedisModule_ReplyWithCString(ctx, sto->devicestr);
RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, sto->tag);
RedisModule_ReplyWithString(ctx, sto->tag);
if (source) {
RedisModule_ReplyWithCString(ctx, "source");
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
Expand Down Expand Up @@ -682,9 +683,9 @@ int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
const char *devicestr;
AC_GetString(&ac, &devicestr, NULL, 0);

const char *tag = "";
RedisModuleString *tag = NULL;
if (AC_AdvanceIfMatch(&ac, "TAG")) {
AC_GetString(&ac, &tag, NULL, 0);
AC_GetRString(&ac, &tag, 0);
}

if (AC_IsAtEnd(&ac)) {
Expand Down Expand Up @@ -780,15 +781,15 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg

long long nkeys;
RedisModuleString **keys;
const char **tags;
RedisModuleString **tags;
RAI_ListStatsEntries(RAI_SCRIPT, &nkeys, &keys, &tags);

RedisModule_ReplyWithArray(ctx, nkeys);

for (long long i = 0; i < nkeys; i++) {
RedisModule_ReplyWithArray(ctx, 2);
RedisModule_ReplyWithString(ctx, keys[i]);
RedisModule_ReplyWithCString(ctx, tags[i]);
RedisModule_ReplyWithString(ctx, tags[i]);
}

RedisModule_Free(keys);
Expand All @@ -803,7 +804,7 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
if (argc != 2 && argc != 3)
return RedisModule_WrongArity(ctx);
const char *runkey = RedisModule_StringPtrLen(argv[1], NULL);
RedisModuleString *runkey = argv[1];
struct RedisAI_RunStats *rstats = NULL;
if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) {
return RedisModule_ReplyWithError(ctx, "ERR cannot find run info for key");
Expand Down Expand Up @@ -833,7 +834,11 @@ int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
RedisModule_ReplyWithCString(ctx, "device");
RedisModule_ReplyWithCString(ctx, rstats->devicestr);
RedisModule_ReplyWithCString(ctx, "tag");
RedisModule_ReplyWithCString(ctx, rstats->tag);
if (rstats->tag) {
RedisModule_ReplyWithString(ctx, rstats->tag);
} else {
RedisModule_ReplyWithCString(ctx, "");
}
RedisModule_ReplyWithCString(ctx, "duration");
RedisModule_ReplyWithLongLong(ctx, rstats->duration_us);
RedisModule_ReplyWithCString(ctx, "samples");
Expand Down Expand Up @@ -1209,9 +1214,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
}

run_stats = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
run_stats = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);

return REDISMODULE_OK;
}

extern AI_dictType AI_dictTypeHeapStrings;
extern AI_dictType AI_dictTypeHeapRStrings;
21 changes: 13 additions & 8 deletions src/run_info.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,26 @@
#include "tensor.h"
#include "util/arr_rm_alloc.h"
#include "util/dict.h"
#include <pthread.h>

static uint64_t RAI_TensorDictKeyHashFunction(const void *key) {
return AI_dictGenHashFunction(key, strlen((char *)key));
size_t len;
const char *buffer = RedisModule_StringPtrLen((RedisModuleString *)key, &len);
return AI_dictGenHashFunction(buffer, len);
}

static int RAI_TensorDictKeyStrcmp(void *privdata, const void *key1, const void *key2) {
const char *strKey1 = key1;
const char *strKey2 = key2;
return strcmp(strKey1, strKey2) == 0;
RedisModuleString *strKey1 = (RedisModuleString *)key1;
RedisModuleString *strKey2 = (RedisModuleString *)key2;
return RedisModule_StringCompare(strKey1, strKey2) == 0;
}

static void RAI_TensorDictKeyFree(void *privdata, void *key) { RedisModule_Free(key); }
static void RAI_TensorDictKeyFree(void *privdata, void *key) {
RedisModule_FreeString(NULL, (RedisModuleString *)key);
}

static void *RAI_TensorDictKeyDup(void *privdata, const void *key) {
return RedisModule_Strdup((char *)key);
return RedisModule_CreateStringFromString(NULL, (RedisModuleString *)key);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RAI_HoldString ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's now the same code as above, so same consideration.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicated logic of the RSring dictionaries. Any chance to consolidate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

static void RAI_TensorDictValFree(void *privdata, void *obj) {
Expand Down Expand Up @@ -105,11 +110,11 @@ int RAI_InitRunInfo(RedisAI_RunInfo **result) {
if (!(rinfo->dagTensorsContext)) {
return REDISMODULE_ERR;
}
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
rinfo->dagTensorsLoadedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
if (!(rinfo->dagTensorsLoadedContext)) {
return REDISMODULE_ERR;
}
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapStrings, NULL);
rinfo->dagTensorsPersistedContext = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
if (!(rinfo->dagTensorsPersistedContext)) {
return REDISMODULE_ERR;
}
Expand Down
26 changes: 17 additions & 9 deletions src/script.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
RAI_Error err = {0};

const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);
const char *tag = RedisModule_LoadStringBuffer(io, NULL);
RedisModuleString *tag = RedisModule_LoadString(io);

size_t len;
char *scriptdef = RedisModule_LoadStringBuffer(io, &len);
Expand Down Expand Up @@ -58,12 +58,15 @@ static void *RAI_Script_RdbLoad(struct RedisModuleIO *io, int encver) {
RedisModuleString *stats_keystr =
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
const char *stats_devicestr = RedisModule_Strdup(devicestr);
const char *stats_tag = RedisModule_Strdup(tag);

if (tag) {
RedisModule_RetainString(NULL, tag);
}

script->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH,
stats_devicestr, stats_tag);
stats_devicestr, tag);

RedisModule_Free(stats_keystr);
RedisModule_FreeString(NULL, stats_keystr);

return script;
}
Expand All @@ -74,14 +77,14 @@ static void RAI_Script_RdbSave(RedisModuleIO *io, void *value) {
size_t len = strlen(script->scriptdef) + 1;

RedisModule_SaveStringBuffer(io, script->devicestr, strlen(script->devicestr) + 1);
RedisModule_SaveStringBuffer(io, script->tag, strlen(script->tag) + 1);
RedisModule_SaveString(io, script->tag);
RedisModule_SaveStringBuffer(io, script->scriptdef, len);
}

static void RAI_Script_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) {
RAI_Script *script = (RAI_Script *)value;

RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scccc", key, script->devicestr, script->tag, "SOURCE",
RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scscc", key, script->devicestr, script->tag, "SOURCE",
script->scriptdef);
}

Expand All @@ -107,7 +110,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx) {
return RedisAI_ScriptType != NULL;
}

RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
RAI_Error *err) {
if (!RAI_backends.torch.script_create) {
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
Expand All @@ -116,7 +119,12 @@ RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char
RAI_Script *script = RAI_backends.torch.script_create(devicestr, scriptdef, err);

if (script) {
script->tag = RedisModule_Strdup(tag);
if (tag) {
RedisModule_RetainString(NULL, tag);
script->tag = tag;
} else {
script->tag = RedisModule_CreateString(NULL, "", 0);
}
}

return script;
Expand All @@ -132,7 +140,7 @@ void RAI_ScriptFree(RAI_Script *script, RAI_Error *err) {
return;
}

RedisModule_Free(script->tag);
RedisModule_FreeString(NULL, script->tag);

RAI_RemoveStatsEntry(script->infokey);

Expand Down
2 changes: 1 addition & 1 deletion src/script.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int RAI_ScriptInit(RedisModuleCtx *ctx);
* failures
* @return RAI_Script script structure on success, or NULL if failed
*/
RAI_Script *RAI_ScriptCreate(const char *devicestr, const char *tag, const char *scriptdef,
RAI_Script *RAI_ScriptCreate(const char *devicestr, RedisModuleString *tag, const char *scriptdef,
RAI_Error *err);

/**
Expand Down
2 changes: 1 addition & 1 deletion src/script_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ typedef struct RAI_Script {
// We keep it here at the moment, until we have a
// CUDA allocator for dlpack
char *devicestr;
char *tag;
RedisModuleString *tag;
long long refCount;
void *infokey;
} RAI_Script;
Expand Down
21 changes: 11 additions & 10 deletions src/stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,32 @@ long long ustime(void) {
mstime_t mstime(void) { return ustime() / 1000; }

void *RAI_AddStatsEntry(RedisModuleCtx *ctx, RedisModuleString *key, RAI_RunType runtype,
RAI_Backend backend, const char *devicestr, const char *tag) {
const char *infokey = RedisModule_StringPtrLen(key, NULL);

RAI_Backend backend, const char *devicestr, RedisModuleString *tag) {
struct RedisAI_RunStats *rstats = NULL;
rstats = RedisModule_Calloc(1, sizeof(struct RedisAI_RunStats));
RedisModule_RetainString(ctx, key);
rstats->key = key;
rstats->type = runtype;
rstats->backend = backend;
rstats->devicestr = RedisModule_Strdup(devicestr);
rstats->tag = RedisModule_Strdup(tag);
if (tag) {
RedisModule_RetainString(ctx, tag);
}
rstats->tag = tag;

AI_dictAdd(run_stats, (void *)infokey, (void *)rstats);
AI_dictAdd(run_stats, (void *)key, (void *)rstats);

return (void *)infokey;
return (void *)key;
}

void RAI_ListStatsEntries(RAI_RunType type, long long *nkeys, RedisModuleString ***keys,
const char ***tags) {
RedisModuleString ***tags) {
AI_dictIterator *stats_iter = AI_dictGetSafeIterator(run_stats);

long long stats_size = AI_dictSize(run_stats);

*keys = RedisModule_Calloc(stats_size, sizeof(RedisModuleString *));
*tags = RedisModule_Calloc(stats_size, sizeof(const char *));
*tags = RedisModule_Calloc(stats_size, sizeof(RedisModuleString *));

*nkeys = 0;

Expand Down Expand Up @@ -109,13 +110,13 @@ void RAI_FreeRunStats(struct RedisAI_RunStats *rstats) {
RedisModule_Free(rstats->devicestr);
}
if (rstats->tag) {
RedisModule_Free(rstats->tag);
RedisModule_FreeString(NULL, rstats->tag);
}
RedisModule_Free(rstats);
}
}

int RAI_GetRunStats(const char *runkey, struct RedisAI_RunStats **rstats) {
int RAI_GetRunStats(RedisModuleString *runkey, struct RedisAI_RunStats **rstats) {
int result = 1;
if (run_stats == NULL) {
return result;
Expand Down
Loading