Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ struct sam_state {
struct ggml_tensor * embd_img;
struct ggml_tensor * embd_prompt_sparse;
struct ggml_tensor * embd_prompt_dense;
struct ggml_tensor * pe_img_dense;

struct ggml_context * ctx;

Expand Down Expand Up @@ -532,7 +533,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {

// key + value memory
{
const auto & hparams = model.hparams;
// const auto & hparams = model.hparams;

// TODO
}
Expand Down Expand Up @@ -630,6 +631,88 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
return true;
}

bool sam_fill_dense_pe(
const sam_model & model,
sam_state & state,
int n_threads) {
const auto & hparams = model.hparams;
const auto & enc = model.enc_prompt;

const int32_t n_img_embd = hparams.n_img_embd();
const float n_img_embd_inv = 1.0f / n_img_embd;

static size_t buf_size = 256u*1024*1024;
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};

struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);

{
float * data = (float *) ggml_get_data(xy_embed_stacked);
for (int i = 0; i < n_img_embd; ++i) {
const int row = 2*i*n_img_embd;
const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1;
for (int j = 0; j < n_img_embd; ++j) {
const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1;
data[row + 2*j + 0] = x_val;
data[row + 2*j + 1] = y_val;
}
}
}

struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked);

cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, 2.0f*M_PI));

// concat
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
{
struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);

cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);

ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));
}

cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));

// TODO: avoid copy
cur = ggml_cpy(ctx0, cur, state.pe_img_dense);

// run the computation
ggml_set_name(cur, "check");
ggml_build_forward_expand(&gf, cur);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

// auto * t = ggml_get_tensor(ctx0, "check");
// auto print_t_f32 = [&](struct ggml_tensor * t) {
// float * data = (float *)t->data;
// printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
// for (int i = 0; i < 256; i++) {
// printf("%f ", data[256*64*63 + 63*256 + i]);
// }
// printf("\n");
// double sum = 0.0;
// for (int i = 0; i < ggml_nelements(t); i++) {
// sum += data[i];
// }
// printf("sum: %f\n", sum);
// };
// print_t_f32(t);

return true;
}

bool sam_encode_image(
const sam_model & model,
sam_state & state,
Expand Down Expand Up @@ -1254,6 +1337,14 @@ int main(int argc, char ** argv) {

state.embd_prompt_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);

state.pe_img_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
}

if (!sam_fill_dense_pe(model, state, params.n_threads)) {
fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__);
return 1;
}

if (!sam_encode_image(model, state, img1, params.n_threads)) {
Expand Down