@@ -114,6 +114,7 @@ struct sam_state {
114114 struct ggml_tensor * embd_img;
115115 struct ggml_tensor * embd_prompt_sparse;
116116 struct ggml_tensor * embd_prompt_dense;
117+ struct ggml_tensor * pe_img_dense;
117118
118119 struct ggml_context * ctx;
119120
@@ -532,7 +533,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
532533
533534 // key + value memory
534535 {
535- const auto & hparams = model.hparams ;
536+ // const auto & hparams = model.hparams;
536537
537538 // TODO
538539 }
@@ -630,6 +631,88 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
630631 return true ;
631632}
632633
634+ bool sam_fill_dense_pe (
635+ const sam_model & model,
636+ sam_state & state,
637+ int n_threads) {
638+ const auto & hparams = model.hparams ;
639+ const auto & enc = model.enc_prompt ;
640+
641+ const int32_t n_img_embd = hparams.n_img_embd ();
642+ const float n_img_embd_inv = 1 .0f / n_img_embd;
643+
644+ static size_t buf_size = 256u *1024 *1024 ;
645+ static void * buf = malloc (buf_size);
646+
647+ struct ggml_init_params params = {
648+ /* .mem_size =*/ buf_size,
649+ /* .mem_buffer =*/ buf,
650+ /* .no_alloc =*/ false ,
651+ };
652+
653+ struct ggml_context * ctx0 = ggml_init (params);
654+ struct ggml_cgraph gf = {};
655+
656+ struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, 2 , n_img_embd, n_img_embd);
657+
658+ {
659+ float * data = (float *) ggml_get_data (xy_embed_stacked);
660+ for (int i = 0 ; i < n_img_embd; ++i) {
661+ const int row = 2 *i*n_img_embd;
662+ const float y_val = 2 * (i + 0 .5f ) * n_img_embd_inv - 1 ;
663+ for (int j = 0 ; j < n_img_embd; ++j) {
664+ const float x_val = 2 * (j + 0 .5f ) * n_img_embd_inv - 1 ;
665+ data[row + 2 *j + 0 ] = x_val;
666+ data[row + 2 *j + 1 ] = y_val;
667+ }
668+ }
669+ }
670+
671+ struct ggml_tensor * cur = ggml_mul_mat (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, enc.pe )), xy_embed_stacked);
672+
673+ cur = ggml_scale (ctx0, cur, ggml_new_f32 (ctx0, 2 .0f *M_PI));
674+
675+ // concat
676+ // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
677+ {
678+ struct ggml_tensor * t_sin = ggml_map_unary_f32 (ctx0, cur, ggml_sam_sin);
679+ struct ggml_tensor * t_cos = ggml_map_unary_f32 (ctx0, cur, ggml_sam_cos);
680+
681+ cur = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, t_sin->ne [0 ] + t_cos->ne [0 ], cur->ne [1 ], cur->ne [2 ]);
682+
683+ ggml_build_forward_expand (&gf, ggml_cpy (ctx0, t_sin, ggml_view_3d (ctx0, cur, t_sin->ne [0 ], t_sin->ne [1 ], t_sin->ne [2 ], cur->nb [1 ], cur->nb [2 ], 0 )));
684+ ggml_build_forward_expand (&gf, ggml_cpy (ctx0, t_cos, ggml_view_3d (ctx0, cur, t_sin->ne [0 ], t_sin->ne [1 ], t_sin->ne [2 ], cur->nb [1 ], cur->nb [2 ], t_sin->nb [1 ])));
685+ }
686+
687+ cur = ggml_cont (ctx0, ggml_permute (ctx0, cur, 2 , 0 , 1 , 3 ));
688+
689+ // TODO: avoid copy
690+ cur = ggml_cpy (ctx0, cur, state.pe_img_dense );
691+
692+ // run the computation
693+ ggml_set_name (cur, " check" );
694+ ggml_build_forward_expand (&gf, cur);
695+ ggml_graph_compute_with_ctx (ctx0, &gf, n_threads);
696+
697+ // auto * t = ggml_get_tensor(ctx0, "check");
698+ // auto print_t_f32 = [&](struct ggml_tensor * t) {
699+ // float * data = (float *)t->data;
700+ // printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
701+ // for (int i = 0; i < 256; i++) {
702+ // printf("%f ", data[256*64*63 + 63*256 + i]);
703+ // }
704+ // printf("\n");
705+ // double sum = 0.0;
706+ // for (int i = 0; i < ggml_nelements(t); i++) {
707+ // sum += data[i];
708+ // }
709+ // printf("sum: %f\n", sum);
710+ // };
711+ // print_t_f32(t);
712+
713+ return true ;
714+ }
715+
633716bool sam_encode_image (
634717 const sam_model & model,
635718 sam_state & state,
@@ -1254,6 +1337,14 @@ int main(int argc, char ** argv) {
12541337
12551338 state.embd_prompt_dense = ggml_new_tensor_3d (state.ctx , GGML_TYPE_F32,
12561339 model.hparams .n_img_embd (), model.hparams .n_img_embd (), model.hparams .n_enc_out_chans );
1340+
1341+ state.pe_img_dense = ggml_new_tensor_3d (state.ctx , GGML_TYPE_F32,
1342+ model.hparams .n_img_embd (), model.hparams .n_img_embd (), model.hparams .n_enc_out_chans );
1343+ }
1344+
1345+ if (!sam_fill_dense_pe (model, state, params.n_threads )) {
1346+ fprintf (stderr, " %s: failed to get dense positional encoding\n " , __func__);
1347+ return 1 ;
12571348 }
12581349
12591350 if (!sam_encode_image (model, state, img1, params.n_threads )) {
0 commit comments