@@ -1545,14 +1545,26 @@ namespace chatllm
// Embedding::forward, shown as a diff: a line whose number column ends in '-' belongs to the
// removed version, '+' to the added version, and a doubled number is unchanged context.
// Removed version: required an I32 input and always did the row lookup.
// Added version: dispatches on the input's element type (I32 -> row lookup, otherwise -> mul_mat).
15451545 ggml::tensor *Embedding::forward (ComputeContext *ctx, ggml::tensor *input)
15461546 {
15471547 ggml::tensor *output;
1548- CHATLLM_CHECK (ggml::type::GGML_TYPE_I32 == ggml::type_of (input));
1549- CHATLLM_CHECK (ggml::n_dims (input) <= 3 );
15501548
1551- ggml::tensor *flattend = ggml::flatten (ctx, input);
1552- output = ggml::get_rows (ctx, weight, flattend);
1553- output = ggml::reshape (ctx, output, ggml::get_dim (output, 0 ),
1554- ggml::get_dim (input, 0 ), ggml::get_dim (input, 1 ), ggml::get_dim (input, 2 ));
// I32 input: same lookup as the removed lines, but the type test is now a branch
// condition instead of a hard CHATLLM_CHECK.
1549+ if (ggml::type::GGML_TYPE_I32 == ggml::type_of (input))
1550+ {
1551+ CHATLLM_CHECK (ggml::n_dims (input) <= 3 );
15551552
// Flatten the input, fetch the matching rows of `weight`, then reshape so that dim 0 of the
// fetched rows is followed by the input's own dims 0..2.
1553+ ggml::tensor *flattend = ggml::flatten (ctx, input);
1554+ output = ggml::get_rows (ctx, weight, flattend);
1555+ output = ggml::reshape (ctx, output, ggml::get_dim (output, 0 ),
1556+ ggml::get_dim (input, 0 ), ggml::get_dim (input, 1 ), ggml::get_dim (input, 2 ));
1557+ }
1558+ else
1559+ {
1560+ // for tied lm head
// Any non-I32 input lands here and is multiplied with the embedding weight.
// NOTE(review): the input's type and rank are not checked on this path - confirm callers.
1561+ ggml::tensor *w = weight;
1562+ if (num_padded_embeddings > 0 )
1563+ {
// 2-D view starting at byte offset 0, sized get_dim(weight, 0) x num_embeddings, with a
// stride of one full weight row (argument order assumed to follow ggml_view_2d - TODO confirm).
// Presumably this drops the padded rows so that only num_embeddings rows take part - verify.
1564+ w = ggml::view_2d (ctx, weight, ggml::get_dim (weight, 0 ), num_embeddings, ggml::row_size (ggml::type_of (weight), ggml::get_dim (weight, 0 )), 0 );
1565+ }
1566+ output = ggml::mul_mat (ctx, w, input);
1567+ }
15561568 return output;
15571569 }
15581570
0 commit comments