Skip to content

Commit 08bfc9d

Browse files
committed
fix regression in embedding block
1 parent bc991e3 commit 08bfc9d

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/layers.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,14 +1545,26 @@ namespace chatllm
15451545
ggml::tensor *Embedding::forward(ComputeContext *ctx, ggml::tensor *input)
15461546
{
15471547
ggml::tensor *output;
1548-
CHATLLM_CHECK(ggml::type::GGML_TYPE_I32 == ggml::type_of(input));
1549-
CHATLLM_CHECK(ggml::n_dims(input) <= 3);
15501548

1551-
ggml::tensor *flattend = ggml::flatten(ctx, input);
1552-
output = ggml::get_rows(ctx, weight, flattend);
1553-
output = ggml::reshape(ctx, output, ggml::get_dim(output, 0),
1554-
ggml::get_dim(input, 0), ggml::get_dim(input, 1), ggml::get_dim(input, 2));
1549+
if (ggml::type::GGML_TYPE_I32 == ggml::type_of(input))
1550+
{
1551+
CHATLLM_CHECK(ggml::n_dims(input) <= 3);
15551552

1553+
ggml::tensor *flattend = ggml::flatten(ctx, input);
1554+
output = ggml::get_rows(ctx, weight, flattend);
1555+
output = ggml::reshape(ctx, output, ggml::get_dim(output, 0),
1556+
ggml::get_dim(input, 0), ggml::get_dim(input, 1), ggml::get_dim(input, 2));
1557+
}
1558+
else
1559+
{
1560+
// fot tie lm head
1561+
ggml::tensor *w = weight;
1562+
if (num_padded_embeddings > 0)
1563+
{
1564+
w = ggml::view_2d(ctx, weight, ggml::get_dim(weight, 0), num_embeddings, ggml::row_size(ggml::type_of(weight), ggml::get_dim(weight, 0)), 0);
1565+
}
1566+
output = ggml::mul_mat(ctx, w, input);
1567+
}
15561568
return output;
15571569
}
15581570

0 commit comments

Comments
 (0)