Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/har_trees/har_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ def mean(arr):
m = sum(arr) / float(len(arr))
return m

def argmax(arr):
idx_max = 0
value_max = arr[0]
for i in range(1, len(arr)):
if arr[i] > value_max:
value_max = arr[i]
idx_max = i

return idx_max

def copy_array_into(source, target):
assert len(source) == len(target)
for i in range(len(target)):
Expand Down Expand Up @@ -63,6 +73,7 @@ def main():
features_typecode = timebased.DATA_TYPECODE
n_features = timebased.N_FEATURES
features = array.array(features_typecode, (0 for _ in range(n_features)))
out = array.array('f', range(model.outputs()))

while True:

Expand All @@ -87,7 +98,8 @@ def main():

# Cun classifier
#print(features)
result = model.predict(features)
model.predict(features, out)
result = argmax(out)
activity = class_index_to_name[result]

d = time.ticks_diff(time.ticks_ms(), start)
Expand Down
14 changes: 13 additions & 1 deletion examples/har_trees/har_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
import emlearn_trees
import timebased

def argmax(arr):
idx_max = 0
value_max = arr[0]
for i in range(1, len(arr)):
if arr[i] > value_max:
value_max = arr[i]
idx_max = i

return idx_max

def har_load_test_data(path,
skip_samples=0, limit_samples=None):

Expand Down Expand Up @@ -63,6 +73,7 @@ def main():
with open(model_path, 'r') as f:
emlearn_trees.load_model(model, f)

out = array.array('f', range(model.outputs()))

errors = 0
total = 0
Expand All @@ -72,7 +83,8 @@ def main():

assert len(labels) == 1
label = labels[0]
result = model.predict(features)
model.predict(features, out)
result = argmax(out)
if result != label:
errors += 1
total += 1
Expand Down
53 changes: 44 additions & 9 deletions src/emlearn_trees/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,25 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
static MP_DEFINE_CONST_FUN_OBJ_2(builder_addleaf_obj, builder_addleaf);


// Return the shape of the output
static mp_obj_t builder_get_outputs(mp_obj_t self_obj) {

mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
EmlTreesBuilder *self = &o->builder;

const int n_classes = self->trees.n_classes;
if (n_classes == 0) {
mp_raise_ValueError(MP_ERROR_TEXT("model not loaded"));
}

return mp_obj_new_int(n_classes);
}
static MP_DEFINE_CONST_FUN_OBJ_1(builder_get_outputs_obj, builder_get_outputs);



// Takes a array of input data
static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t output_obj) {

mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
EmlTreesBuilder *self = &o->builder;
Expand All @@ -212,28 +229,45 @@ static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {

const int16_t *features = bufinfo.buf;
const int n_features = bufinfo.len / sizeof(*features);
const int n_outputs = self->trees.n_classes;

#if EMLEARN_MICROPYTHON_DEBUG
mp_printf(&mp_plat_print,
"emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d \n",
"emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d outputs=%d \n",
self->trees.n_features, self->trees.n_classes,
self->trees.n_leaves, self->trees.n_nodes, self->trees.n_trees,
n_features
);
#endif

if (n_features == 0 || n_outputs == 0) {
mp_raise_ValueError(MP_ERROR_TEXT("model not loaded"));
}

// Extract output
mp_get_buffer_raise(output_obj, &bufinfo, MP_BUFFER_RW);
if (bufinfo.typecode != 'f') {
mp_raise_ValueError(MP_ERROR_TEXT("expecting float output array"));
}
float *output_buffer = bufinfo.buf;
const int output_length = bufinfo.len / sizeof(*output_buffer);


// call model
const int result = eml_trees_predict(&self->trees, features, n_features);
if (result < 0) {
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict error"));
// NOTE: also handles checking of input and output lengths
const EmlError err = \
eml_trees_predict_proba(&self->trees, features, n_features, output_buffer, output_length);

if (err != EmlOk) {
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict_proba error"));
}

return mp_obj_new_int(result);
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_2(builder_predict_obj, builder_predict);
static MP_DEFINE_CONST_FUN_OBJ_3(builder_predict_obj, builder_predict);


mp_map_elem_t trees_locals_dict_table[6];
mp_map_elem_t trees_locals_dict_table[7];
static MP_DEFINE_CONST_DICT(trees_locals_dict, trees_locals_dict_table);

// This is the entry point and is called when the module is imported
Expand All @@ -253,8 +287,9 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
trees_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_addleaf), MP_OBJ_FROM_PTR(&builder_addleaf_obj) };
trees_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&builder_del_obj) };
trees_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_setdata), MP_OBJ_FROM_PTR(&builder_setdata_obj) };
trees_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_outputs), MP_OBJ_FROM_PTR(&builder_get_outputs_obj) };

MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 6);
MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 7);

// This must be last, it restores the globals dict
MP_DYNRUNTIME_INIT_EXIT
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_cnn_mnist():
if out == class_no:
correct += 1

assert correct >= 6, correct
assert correct >= 9, correct


test_cnn_create()
Expand Down
15 changes: 14 additions & 1 deletion tests/test_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import array
import gc

def argmax(arr):
idx_max = 0
value_max = arr[0]
for i in range(1, len(arr)):
if arr[i] > value_max:
value_max = arr[i]
idx_max = i

return idx_max

def test_trees_del():
"""
Deleting the model should free all the memory
Expand Down Expand Up @@ -45,9 +55,12 @@ def test_trees_xor():
( [1*s, 0], 1 ),
]

out = array.array('f', range(model.outputs()))

for (ex, expect) in examples:
f = array.array('h', ex)
result = model.predict(f)
model.predict(f, out)
result = argmax(out)
assert result == expect, (ex, expect, result)

test_trees_del()
Expand Down