diff --git a/examples/har_trees/har_live.py b/examples/har_trees/har_live.py index 69dea35..d8fb326 100644 --- a/examples/har_trees/har_live.py +++ b/examples/har_trees/har_live.py @@ -15,6 +15,16 @@ def mean(arr): m = sum(arr) / float(len(arr)) return m +def argmax(arr): + idx_max = 0 + value_max = arr[0] + for i in range(1, len(arr)): + if arr[i] > value_max: + value_max = arr[i] + idx_max = i + + return idx_max + def copy_array_into(source, target): assert len(source) == len(target) for i in range(len(target)): @@ -63,6 +73,7 @@ def main(): features_typecode = timebased.DATA_TYPECODE n_features = timebased.N_FEATURES features = array.array(features_typecode, (0 for _ in range(n_features))) + out = array.array('f', range(model.outputs())) while True: @@ -87,7 +98,8 @@ def main(): # Cun classifier #print(features) - result = model.predict(features) + model.predict(features, out) + result = argmax(out) activity = class_index_to_name[result] d = time.ticks_diff(time.ticks_ms(), start) diff --git a/examples/har_trees/har_run.py b/examples/har_trees/har_run.py index 8bf7f62..efdaa70 100644 --- a/examples/har_trees/har_run.py +++ b/examples/har_trees/har_run.py @@ -6,6 +6,16 @@ import emlearn_trees import timebased +def argmax(arr): + idx_max = 0 + value_max = arr[0] + for i in range(1, len(arr)): + if arr[i] > value_max: + value_max = arr[i] + idx_max = i + + return idx_max + def har_load_test_data(path, skip_samples=0, limit_samples=None): @@ -63,6 +73,7 @@ def main(): with open(model_path, 'r') as f: emlearn_trees.load_model(model, f) + out = array.array('f', range(model.outputs())) errors = 0 total = 0 @@ -72,7 +83,8 @@ def main(): assert len(labels) == 1 label = labels[0] - result = model.predict(features) + model.predict(features, out) + result = argmax(out) if result != label: errors += 1 total += 1 diff --git a/src/emlearn_trees/trees.c b/src/emlearn_trees/trees.c index 6c7460b..882440f 100644 --- a/src/emlearn_trees/trees.c +++ b/src/emlearn_trees/trees.c @@ -197,8 +197,25 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) { static MP_DEFINE_CONST_FUN_OBJ_2(builder_addleaf_obj, builder_addleaf); +// Return the shape of the output +static mp_obj_t builder_get_outputs(mp_obj_t self_obj) { + + mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj); + EmlTreesBuilder *self = &o->builder; + + const int n_classes = self->trees.n_classes; + if (n_classes == 0) { + mp_raise_ValueError(MP_ERROR_TEXT("model not loaded")); + } + + return mp_obj_new_int(n_classes); +} +static MP_DEFINE_CONST_FUN_OBJ_1(builder_get_outputs_obj, builder_get_outputs); + + + // Takes a array of input data -static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) { +static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t output_obj) { mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj); EmlTreesBuilder *self = &o->builder; @@ -212,28 +229,45 @@ static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) { const int16_t *features = bufinfo.buf; const int n_features = bufinfo.len / sizeof(*features); + const int n_outputs = self->trees.n_classes; #if EMLEARN_MICROPYTHON_DEBUG mp_printf(&mp_plat_print, - "emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d \n", + "emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d outputs=%d \n", self->trees.n_features, self->trees.n_classes, self->trees.n_leaves, self->trees.n_nodes, self->trees.n_trees, n_features ); #endif + if (n_features == 0 || n_outputs == 0) { + mp_raise_ValueError(MP_ERROR_TEXT("model not loaded")); + } + + // Extract output + mp_get_buffer_raise(output_obj, &bufinfo, MP_BUFFER_RW); + if (bufinfo.typecode != 'f') { + mp_raise_ValueError(MP_ERROR_TEXT("expecting float output array")); + } + float *output_buffer = bufinfo.buf; + const int output_length = bufinfo.len / sizeof(*output_buffer); + + // call model - const int result = eml_trees_predict(&self->trees, features, n_features); - if (result < 0) { - mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict error")); + // NOTE: also handles checking of input and output lengths + const EmlError err = \ + eml_trees_predict_proba(&self->trees, features, n_features, output_buffer, output_length); + + if (err != EmlOk) { + mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict_proba error")); } - return mp_obj_new_int(result); + return mp_const_none; } -static MP_DEFINE_CONST_FUN_OBJ_2(builder_predict_obj, builder_predict); +static MP_DEFINE_CONST_FUN_OBJ_3(builder_predict_obj, builder_predict); -mp_map_elem_t trees_locals_dict_table[6]; +mp_map_elem_t trees_locals_dict_table[7]; static MP_DEFINE_CONST_DICT(trees_locals_dict, trees_locals_dict_table); // This is the entry point and is called when the module is imported @@ -253,8 +287,9 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a trees_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_addleaf), MP_OBJ_FROM_PTR(&builder_addleaf_obj) }; trees_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&builder_del_obj) }; trees_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_setdata), MP_OBJ_FROM_PTR(&builder_setdata_obj) }; + trees_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_outputs), MP_OBJ_FROM_PTR(&builder_get_outputs_obj) }; - MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 6); + MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 7); // This must be last, it restores the globals dict MP_DYNRUNTIME_INIT_EXIT diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 8cfb36b..b1b3aa0 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -73,7 +73,7 @@ def test_cnn_mnist(): if out == class_no: correct += 1 - assert correct >= 6, correct + assert correct >= 9, correct test_cnn_create() diff --git a/tests/test_trees.py b/tests/test_trees.py index cca418a..a968f64 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -4,6 +4,16 @@ import array import gc +def argmax(arr): + idx_max = 0 + value_max = arr[0] + for i in range(1, len(arr)): + if arr[i] > value_max: + value_max = arr[i] + idx_max = i + + return idx_max + def test_trees_del(): """ Deleting the model should free all the memory @@ -45,9 +55,12 @@ def test_trees_xor(): ( [1*s, 0], 1 ), ] + out = array.array('f', range(model.outputs())) + for (ex, expect) in examples: f = array.array('h', ex) - result = model.predict(f) + model.predict(f, out) + result = argmax(out) assert result == expect, (ex, expect, result) test_trees_del()