@@ -34,39 +34,31 @@ void NORETURN abort() {
3434}
3535#endif
3636
37-
38- static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
39- {
40- #if 0
41- //dump middle result
42- int h = lh -> out_dims [1 ];
43- int w = lh -> out_dims [2 ];
44- int ch = lh -> out_dims [3 ];
45- mtype_t * output = TML_GET_OUTPUT (mdl , lh );
46- return TM_OK ;
47- TM_PRINTF ("Layer %d callback ========\n" , mdl -> layer_i );
48- #if 1
49- for (int y = 0 ; y < h ; y ++ ){
50- TM_PRINTF ("[" );
51- for (int x = 0 ; x < w ; x ++ ){
52- TM_PRINTF ("[" );
53- for (int c = 0 ; c < ch ; c ++ ){
54- #if TM_MDL_TYPE == TM_MDL_FP32
55- TM_PRINTF ("%.3f," , output [(y * w + x )* ch + c ]);
56- #else
57- TM_PRINTF ("%.3f," , TML_DEQUANT (lh ,output [(y * w + x )* ch + c ]));
58- #endif
37+ // get model output shapes
38+ //mdl: model handle; in: input mat; out: output mat
39+ int TM_WEAK tm_get_outputs (tm_mdl_t * mdl , tm_mat_t * out , int out_length )
40+ {
41+ // NOTE: based on tm_run, but without actually executing
42+ int out_idx = 0 ;
43+ mdl -> layer_body = mdl -> b -> layers_body ;
44+ for (mdl -> layer_i = 0 ; mdl -> layer_i < mdl -> b -> layer_cnt ; mdl -> layer_i ++ ){
45+ tml_head_t * h = (tml_head_t * )(mdl -> layer_body );
46+ if (h -> is_out ) {
47+ if (out_idx < out_length ) {
48+ memcpy ((void * )(& out [out_idx ]), (void * )(& (h -> out_dims )), sizeof (uint16_t )* 4 );
49+ out_idx += 1 ;
50+ } else {
51+ return -1 ;
5952 }
60- TM_PRINTF ("]," );
6153 }
62- TM_PRINTF ( "],\n" );
54+ mdl -> layer_body += ( h -> size );
6355 }
64- TM_PRINTF ("\n" );
65- #endif
66- return TM_OK ;
67- #else
56+ return out_idx ;
57+ }
58+
59+ static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
60+ {
6861 return TM_OK ;
69- #endif
7062}
7163
7264#define DEBUG (1)
@@ -79,6 +71,7 @@ typedef struct _mp_obj_mod_cnn_t {
7971 tm_mat_t input ;
8072 uint8_t * model_buffer ;
8173 uint8_t * data_buffer ;
74+ uint16_t out_dims [4 ];
8275} mp_obj_mod_cnn_t ;
8376
8477mp_obj_full_type_t mod_cnn_type ;
@@ -121,6 +114,25 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
121114 mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("tm_load error" ));
122115 }
123116
117+ // find model output shape
118+ o -> out_dims [0 ] = 0 ;
119+ tm_mat_t outs [1 ];
120+ const int outputs = tm_get_outputs (model , outs , 1 );
121+ if (outputs != 1 ) {
122+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("only 1 output supported" ));
123+ }
124+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
125+
126+ if ((o -> out_dims [0 ] != 1 )) {
127+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("output must be 1d" ));
128+ }
129+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
130+
131+ #if DEBUG
132+ mp_printf (& mp_plat_print , "cnn-new-done outs=%d out.dims=(%d,%d,%d,%d) \n" ,
133+ outputs , o -> out_dims [0 ], o -> out_dims [1 ], o -> out_dims [2 ], o -> out_dims [3 ]);
134+ #endif
135+
124136 return MP_OBJ_FROM_PTR (o );
125137}
126138static MP_DEFINE_CONST_FUN_OBJ_1 (mod_cnn_new_obj , mod_cnn_new ) ;
@@ -141,15 +153,15 @@ static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_del_obj, mod_cnn_del);
141153
142154
143155// Add a node to the tree
144- static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj ) {
156+ static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj , mp_obj_t output_obj ) {
145157
146158 mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
147159
148160 // Extract input
149161 mp_buffer_info_t bufinfo ;
150162 mp_get_buffer_raise (input_obj , & bufinfo , MP_BUFFER_RW );
151163 if (bufinfo .typecode != 'B' ) {
152- mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
164+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting byte array" ));
153165 }
154166 uint8_t * input_buffer = bufinfo .buf ;
155167 const int input_length = bufinfo .len / sizeof (* input_buffer );
@@ -160,6 +172,21 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
160172 mp_raise_ValueError (MP_ERROR_TEXT ("wrong input size" ));
161173 }
162174
175+ // Extract output
176+ mp_get_buffer_raise (output_obj , & bufinfo , MP_BUFFER_RW );
177+ if (bufinfo .typecode != 'f' ) {
178+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
179+ }
180+ float * output_buffer = bufinfo .buf ;
181+ const int output_length = bufinfo .len / sizeof (* output_buffer );
182+
183+
184+ // check buffer size wrt input
185+ const int expect_out_length = o -> out_dims [1 ]* o -> out_dims [2 ]* o -> out_dims [3 ];
186+ if (output_length != expect_out_length ) {
187+ mp_raise_ValueError (MP_ERROR_TEXT ("wrong output size" ));
188+ }
189+
163190 // Preprocess data
164191 tm_mat_t in_uint8 = o -> input ;
165192 in_uint8 .data = (mtype_t * )input_buffer ;
@@ -181,27 +208,38 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
181208 mp_raise_ValueError (MP_ERROR_TEXT ("run error" ));
182209 }
183210
211+ // Copy output into
184212 tm_mat_t out = outs [0 ];
185- float * data = out .dataf ;
186- float maxp = 0 ;
187- int maxi = -1 ;
188-
189- // TODO: pass the entire output vector out to Python
190- // FIXME: unhardcode output handling
191- for (int i = 0 ; i < 10 ; i ++ ){
192- //printf("%d: %.3f\n", i, data[i]);
193- if (data [i ] > maxp ) {
194- maxi = i ;
195- maxp = data [i ];
196- }
213+ for (int i = 0 ; i < expect_out_length ; i ++ ){
214+ output_buffer [i ] = out .dataf [i ];
215+ }
216+
217+ return mp_const_none ;
218+ }
219+ static MP_DEFINE_CONST_FUN_OBJ_3 (mod_cnn_run_obj , mod_cnn_run ) ;
220+
221+
222+ // Return the shape of the output
223+ static mp_obj_t mod_cnn_output_dimensions (mp_obj_t self_obj ) {
224+
225+ mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
226+ const int dimensions = o -> out_dims [0 ];
227+ mp_obj_tuple_t * tuple = MP_OBJ_TO_PTR (mp_obj_new_tuple (dimensions , NULL ));
228+
229+ // A regular output should have C channels, and 1 for everything else
230+ // TODO: support other shapes?
231+ //dims==1, 11c
232+ if (!(o -> out_dims [0 ] == 1 && o -> out_dims [1 ] == 1 && o -> out_dims [2 ] == 1 )) {
233+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("wrong output shape" ));
197234 }
198235
199- return mp_obj_new_int (maxi );
236+ tuple -> items [0 ] = mp_obj_new_int (o -> out_dims [3 ]);
237+ return tuple ;
200238}
201- static MP_DEFINE_CONST_FUN_OBJ_2 ( mod_cnn_run_obj , mod_cnn_run ) ;
239+ static MP_DEFINE_CONST_FUN_OBJ_1 ( mod_cnn_output_dimensions_obj , mod_cnn_output_dimensions ) ;
202240
203241
204- mp_map_elem_t mod_locals_dict_table [2 ];
242+ mp_map_elem_t mod_locals_dict_table [3 ];
205243static MP_DEFINE_CONST_DICT (mod_locals_dict , mod_locals_dict_table ) ;
206244
207245// This is the entry point and is called when the module is imported
@@ -217,6 +255,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
217255 // methods
218256 mod_locals_dict_table [0 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_run ), MP_OBJ_FROM_PTR (& mod_cnn_run_obj ) };
219257 mod_locals_dict_table [1 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& mod_cnn_del_obj ) };
258+ mod_locals_dict_table [2 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_output_dimensions ), MP_OBJ_FROM_PTR (& mod_cnn_output_dimensions_obj ) };
220259
221260 MP_OBJ_TYPE_SET_SLOT (& mod_cnn_type , locals_dict , (void * )& mod_locals_dict , 2 );
222261
0 commit comments