Skip to content

Commit 9c98bab

Browse files
committed
add llm_graph_input_one
1 parent 08f7f2b commit 9c98bab

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

src/llama-graph.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
350350
}
351351
}
352352

353+
void llm_graph_input_one::set_input(const llama_ubatch *) {
354+
GGML_ASSERT(one && ggml_nelements(one) == 1);
355+
float f_one = 1.0f;
356+
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
357+
}
358+
353359
//
354360
// llm_graph_context
355361
//

src/llama-graph.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,17 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
329329
const llama_memory_hybrid_context * mctx;
330330
};
331331

332+
// TODO: remove this when ggml_scale_add is implemented
333+
class llm_graph_input_one : public llm_graph_input_i {
334+
public:
335+
llm_graph_input_one() {}
336+
virtual ~llm_graph_input_one() = default;
337+
338+
void set_input(const llama_ubatch *) override;
339+
340+
ggml_tensor * one = nullptr; // F32
341+
};
342+
332343
//
333344
// llm_graph_result
334345
//

src/llama-model.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9081,8 +9081,13 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
90819081
ggml_tensor * cur;
90829082
ggml_tensor * inpL;
90839083

9084+
// TODO: remove this when ggml_scale_add is implemented
90849085
one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
9085-
one = ggml_cos(ctx0, ggml_scale(ctx0, one, 0.0f)); // cos(0.0f) = 1.0f
9086+
{
9087+
auto inp = std::make_unique<llm_graph_input_one>();
9088+
inp->one = one;
9089+
res->add_input(std::move(inp));
9090+
}
90869091

90879092
inpL = build_inp_embd(model.tok_embd);
90889093

0 commit comments

Comments
 (0)