Skip to content

Commit dae6e61

Browse files
authored
Merge pull request #25 from emlearn/cnn-dequant
CNN: Fix wrong output values due to missing dequant
2 parents 3d771b7 + b48370a commit dae6e61

File tree

8 files changed

+301
-256
lines changed

8 files changed

+301
-256
lines changed

examples/mnist_cnn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ mpremote mip install https://emlearn.github.io/emlearn-micropython/builds/master
5555

5656
```console
5757
mpremote cp mnist_cnn.tmdl :
58-
mpremote cp -r data/ :
58+
mpremote cp -r test_data/ :
5959
mpremote run mnist_cnn_run.py
6060
```
6161

examples/mnist_cnn/mnist_cnn.h

Lines changed: 226 additions & 226 deletions
Large diffs are not rendered by default.

examples/mnist_cnn/mnist_cnn.h5

0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

examples/mnist_cnn/mnist_cnn.tmdl

0 Bytes
Binary file not shown.

examples/mnist_cnn/mnist_cnn_run.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11

2+
import os
23
import array
3-
import emlearn_cnn
44
import time
55
import gc
66

7+
import emlearn_cnn
8+
79
MODEL = 'mnist_cnn.tmdl'
8-
TEST_DATA_DIR = 'data/'
10+
TEST_DATA_DIR = 'test_data'
911

1012
def argmax(arr):
1113
idx_max = 0
@@ -30,6 +32,27 @@ def print_2d_buffer(arr, rowstride):
3032
gc.collect()
3133
print('\n')
3234

35+
def load_images_from_directory(path):
36+
sep = '/'
37+
38+
for filename in os.listdir(path):
39+
# TODO: support standard image formats, like .bmp/.png/.jpeg
40+
if not filename.endswith('.bin'):
41+
continue
42+
43+
# Find the label (if any). The last part, X_label.format
44+
label = None
45+
basename = filename.split('.')[0]
46+
tok = basename.split('_')
47+
if len(tok) > 2:
48+
label = tok[-1]
49+
50+
data_path = path + sep + filename
51+
with open(data_path, 'rb') as f:
52+
img = array.array('B', f.read())
53+
54+
yield img, label
55+
3356
def test_cnn_mnist():
3457

3558
# load model
@@ -42,22 +65,28 @@ def test_cnn_mnist():
4265
probabilities = array.array('f', (-1 for _ in range(out_length)))
4366

4467
# run on some test data
45-
for class_no in range(0, 10):
46-
data_path = TEST_DATA_DIR + 'mnist_example_{0:d}.bin'.format(class_no)
47-
#print('open', data_path)
48-
with open(data_path, 'rb') as f:
49-
img = array.array('B', f.read())
50-
51-
print_2d_buffer(img, 28)
52-
53-
run_start = time.ticks_us()
54-
model.run(img, probabilities)
55-
out = argmax(probabilities)
56-
run_duration = time.ticks_diff(time.ticks_us(), run_start) / 1000.0 # ms
57-
58-
print('mnist-example-check', class_no, out, class_no == out, run_duration)
68+
n_correct = 0
69+
n_total = 0
70+
for img, label in load_images_from_directory(TEST_DATA_DIR):
71+
class_no = int(label) # mnist class labels are digits
72+
73+
#print_2d_buffer(img, 28)
74+
75+
run_start = time.ticks_us()
76+
model.run(img, probabilities)
77+
out = argmax(probabilities)
78+
run_duration = time.ticks_diff(time.ticks_us(), run_start) / 1000.0 # ms
79+
correct = class_no == out
80+
n_total += 1
81+
if correct:
82+
n_correct += 1
83+
84+
print('mnist-example-check', class_no, '=', out, correct, round(run_duration, 3))
5985

6086
gc.collect()
6187

88+
accuracy = n_correct / n_total
89+
print('mnist-example-done', n_correct, '/', n_total, round(accuracy*100, ), '%')
90+
6291
if __name__ == '__main__':
6392
test_cnn_mnist()

examples/mnist_cnn/mnist_train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def train_mnist(h5_file, epochs=10):
4040
(x_orig_train, y_orig_train), (x_orig_test, y_orig_test) = mnist.load_data()
4141
num_classes = 10
4242

43-
generate_test_files('test_data', x_orig_test, y_orig_test)
43+
TEST_DATA_DIR = 'test_data'
44+
generate_test_files(TEST_DATA_DIR, x_orig_test, y_orig_test)
45+
print('Wrote test data to', TEST_DATA_DIR)
4446

4547
x_train = x_orig_train
4648
x_test = x_orig_test
@@ -58,7 +60,7 @@ def train_mnist(h5_file, epochs=10):
5860

5961
model.save(h5_file)
6062

61-
def generate_test_files(out_dir, x, y):
63+
def generate_test_files(out_dir, x, y, samples_per_class=5):
6264

6365
if not os.path.exists(out_dir):
6466
os.makedirs(out_dir)
@@ -71,15 +73,12 @@ def generate_test_files(out_dir, x, y):
7173
# select one per class
7274
for class_no in classes:
7375
matches = (Y_classes == class_no)
74-
print('mm', matches.shape)
7576
x_matches = X_series[matches]
7677

77-
selected = x_matches.sample(n=1, random_state=1)
78-
for s in selected:
79-
print('ss', s.shape, s.dtype)
80-
print(s)
81-
out = os.path.join(out_dir, f'mnist_example_{class_no}.bin')
82-
data = s.tobytes(order='C')
78+
selected = x_matches.sample(n=samples_per_class, random_state=1)
79+
for i, sample in enumerate(selected):
80+
out = os.path.join(out_dir, f'mnist_example_{i}_{class_no}.bin')
81+
data = sample.tobytes(order='C')
8382

8483
assert len(data) == expect_bytes, (len(data), expect_bytes)
8584
with open(out, 'wb') as f:
@@ -94,9 +93,10 @@ def generate_tinymaix_model(h5_file,
9493
precision='fp32',
9594
quantize_data=None,
9695
quantize_type='0to1',
97-
output_dequantize=False,
9896
):
9997

98+
output_dequantize = quantize_data is not None
99+
100100
# Convert .h5 to .tflite file
101101
assert h5_file.endswith('.h5'), 'Keras model HDF5 file must end with .h5'
102102
tflite_file = h5_file.replace('.h5', '.tflite')

src/tinymaix_cnn/mod_cnn.c

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
#include <tinymaix.h>
88

99
#include "tm_layers.c"
10+
//#include "tm_layers_O1.c"
1011
#include "tm_model.c"
1112
//#include "tm_stat.c"
1213

1314
#include <string.h>
1415

16+
#define DEBUG (1)
17+
1518

1619
// memset is used by some standard C constructs
1720
#if !defined(__linux__)
@@ -58,10 +61,13 @@ int TM_WEAK tm_get_outputs(tm_mdl_t* mdl, tm_mat_t* out, int out_length)
5861

5962
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
6063
{
64+
#if DEBUG
65+
mp_printf(&mp_plat_print, "cnn-layer-cb type=%d \n", lh->type);
66+
#endif
67+
6168
return TM_OK;
6269
}
6370

64-
#define DEBUG (1)
6571

6672
// MicroPython type
6773
typedef struct _mp_obj_mod_cnn_t {
@@ -209,12 +215,22 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj, mp_obj_t outp
209215
mp_raise_ValueError(MP_ERROR_TEXT("run error"));
210216
}
211217

212-
// Copy output into
218+
// Copy output
213219
tm_mat_t out = outs[0];
220+
221+
#if DEBUG
222+
mp_printf(&mp_plat_print, "cnn-run out.dims=(%d,%d,%d,%d) out.length=%d expect_length=%d \n",
223+
out.dims, out.h, out.w, out.c, expect_out_length
224+
);
225+
#endif
226+
227+
if (!((out.dims == 1) && (out.h == 1) && (out.w == 1) && out.c == expect_out_length)) {
228+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("unexpected output dims"));
229+
}
230+
214231
for(int i=0; i<expect_out_length; i++){
215232
output_buffer[i] = out.dataf[i];
216233
}
217-
218234
return mp_const_none;
219235
}
220236
static MP_DEFINE_CONST_FUN_OBJ_3(mod_cnn_run_obj, mod_cnn_run);

0 commit comments

Comments
 (0)