Commit 4966a97

refactor: share a single LmHead class across NPU and other hardware.

1 parent 1a5e2f0 commit 4966a97

17 files changed: +109 −126 lines
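The heart of the change is in `xllm/core/layers/lm_head.h`: the NPU-specific holder, previously also named `LmHead`, becomes `NpuLmHead`, and the `ColumnParallelLinearImpl`-backed `LmHead` becomes the single definition shared by NPU, MLU, and GPU builds. In outline, the header after this commit reads (bodies elided):

```cpp
// Outline of xllm/core/layers/lm_head.h after the commit.
#if defined(USE_NPU)
// NPU-specific holder, formerly named LmHead.
class NpuLmHead : public torch::nn::ModuleHolder<NpuLmHeadImpl> { /* ... */ };
#endif

// The one shared definition, now compiled on every backend.
class LmHead : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> { /* ... */ };
```

Call sites that still need the NPU-only head accessors are wrapped in `#if defined(USE_NPU)` throughout the rest of the diff.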

xllm/core/framework/model/causal_lm.h

Lines changed: 10 additions & 6 deletions
```diff
@@ -66,11 +66,13 @@ class CausalLM : public torch::nn::Module {
 
   virtual const torch::TensorOptions& options() const = 0;
 
-  virtual layer::LmHead get_lm_head() = 0;
-  virtual void set_lm_head(layer::LmHead& head) = 0;
+#if defined(USE_NPU)
+  virtual layer::NpuLmHead get_lm_head() = 0;
+  virtual void set_lm_head(layer::NpuLmHead& head) = 0;
 
   virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
   virtual void set_word_embedding(
       std::vector<layer::WordEmbedding>& embedding) = 0;
+#endif
 };
 
 template <typename Model>
@@ -104,10 +106,12 @@ class CausalLMImpl : public CausalLM {
   virtual void update_expert_weight(int32_t layer_id) {
     return model_->update_expert_weight(layer_id);
   }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() override { return model_->get_lm_head(); };
 
-  layer::LmHead get_lm_head() override { return model_->get_lm_head(); };
-
-  void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
+  void set_lm_head(layer::NpuLmHead& head) override {
+    model_->set_lm_head(head);
+  };
 
   std::vector<layer::WordEmbedding> get_word_embedding() override {
     return model_->get_word_embedding();
@@ -117,7 +121,7 @@ class CausalLMImpl : public CausalLM {
       std::vector<layer::WordEmbedding>& embedding) override {
     model_->set_word_embedding(embedding);
   };
-
+#endif
   torch::Device device() const override { return options_.device(); }
 
   const torch::TensorOptions& options() const override { return options_; }
```
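Since the accessors now exist only on NPU builds, hardware-agnostic callers must guard their use the same way the interface does. A minimal sketch (the `copy_lm_head` helper is hypothetical, not part of this commit):

```cpp
#include "xllm/core/framework/model/causal_lm.h"

// Hypothetical helper: transplant the LM head from one model to another,
// mirroring what the speculative worker does for Deepseek MTP below.
void copy_lm_head(xllm::CausalLM& from, xllm::CausalLM& to) {
#if defined(USE_NPU)
  // The virtual accessors are declared only under USE_NPU.
  xllm::layer::NpuLmHead head = from.get_lm_head();
  to.set_lm_head(head);
#else
  // On other backends each model owns its shared layer::LmHead directly,
  // so there is nothing to transplant.
  (void)from;
  (void)to;
#endif
}
```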

xllm/core/framework/model/causal_vlm.h

Lines changed: 6 additions & 4 deletions
```diff
@@ -63,10 +63,12 @@ class CausalVLMImpl : public CausalVLM {
   }
 
   virtual void update_expert_weight(int32_t layer_id) { return; }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() override { return model_->get_lm_head(); };
 
-  layer::LmHead get_lm_head() override { return model_->get_lm_head(); };
-
-  void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
+  void set_lm_head(layer::NpuLmHead& head) override {
+    model_->set_lm_head(head);
+  };
 
   std::vector<layer::WordEmbedding> get_word_embedding() override {
     return model_->get_word_embedding();
@@ -76,7 +78,7 @@ class CausalVLMImpl : public CausalVLM {
      std::vector<layer::WordEmbedding>& embedding) override {
     model_->set_word_embedding(embedding);
   };
-
+#endif
   torch::Device device() const override { return options_.device(); }
 
   const torch::TensorOptions& options() const override { return options_; }
```

xllm/core/layers/lm_head.h

Lines changed: 3 additions & 30 deletions
```diff
@@ -24,42 +24,16 @@ namespace xllm {
 namespace layer {
 
 #if defined(USE_NPU)
-class LmHead : public torch::nn::ModuleHolder<NpuLmHeadImpl> {
+class NpuLmHead : public torch::nn::ModuleHolder<NpuLmHeadImpl> {
  public:
   using torch::nn::ModuleHolder<NpuLmHeadImpl>::ModuleHolder;
   using Impl __attribute__((__unused__)) = NpuLmHeadImpl;
 
-  LmHead(const ModelContext& context)
+  NpuLmHead(const ModelContext& context)
       : ModuleHolder(std::make_shared<NpuLmHeadImpl>(context)) {}
 };
 
-/**
- * TODO: Rename the original LmHead definition to NpuLmHead,
- * and define the current one as LmHead to unify NPU's LmHead
- * related code with MLU and GPU
- */
-class LmHeadNative : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> {
- public:
-  using torch::nn::ModuleHolder<ColumnParallelLinearImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl;
-
-  LmHeadNative(int64_t in_features,
-               int64_t out_features,
-               bool bias,
-               bool gather_output,
-               const QuantArgs& quant_args,
-               const ParallelArgs& parallel_args,
-               const torch::TensorOptions& options)
-      : ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(in_features,
-                                                                out_features,
-                                                                bias,
-                                                                gather_output,
-                                                                quant_args,
-                                                                parallel_args,
-                                                                options)) {}
-};
-
-#else
+#endif
 class LmHead : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> {
  public:
   using torch::nn::ModuleHolder<ColumnParallelLinearImpl>::ModuleHolder;
@@ -80,7 +54,6 @@ class LmHead : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> {
                                                                 parallel_args,
                                                                 options)) {}
 };
-#endif
 
 }  // namespace layer
 }  // namespace xllm
```
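With the `#else` removed, `layer::LmHead` keeps the constructor that `LmHeadNative` used to provide. A minimal construction sketch; the dimension values are placeholders, not taken from any real model config:

```cpp
#include <torch/torch.h>
#include "xllm/core/layers/lm_head.h"

// Build the shared LmHead on any backend. The parameter list matches the
// ColumnParallelLinearImpl-based constructor retained in lm_head.h.
xllm::layer::LmHead make_lm_head(const xllm::QuantArgs& quant_args,
                                 const xllm::ParallelArgs& parallel_args,
                                 const torch::TensorOptions& options) {
  const int64_t hidden_size = 4096;   // placeholder hidden dimension
  const int64_t vocab_size = 152064;  // placeholder vocabulary size
  return xllm::layer::LmHead(hidden_size,
                             vocab_size,
                             /*bias=*/false,
                             /*gather_output=*/true,
                             quant_args,
                             parallel_args,
                             options);
}
```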

xllm/core/runtime/acl_graph_executor_test.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -234,12 +234,12 @@ class SimpleCausalLM : public CausalLM {
     // Simple implementation for testing
   }
 
-  layer::LmHead get_lm_head() override {
+  layer::NpuLmHead get_lm_head() override {
     // Simple implementation for testing
-    return layer::LmHead(nullptr);
+    return layer::NpuLmHead(nullptr);
   }
 
-  void set_lm_head(layer::LmHead& head) override {
+  void set_lm_head(layer::NpuLmHead& head) override {
     // Simple implementation for testing
   }
 
```

xllm/core/runtime/llm_worker_impl.h

Lines changed: 4 additions & 4 deletions
```diff
@@ -44,10 +44,10 @@ class LLMWorkerImpl : public WorkerImpl {
 
   std::optional<ForwardOutput> step(
       const BatchedForwardInputs& inputs) override;
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() { return model_->get_lm_head(); };
 
-  layer::LmHead get_lm_head() { return model_->get_lm_head(); };
-
-  void set_lm_head(layer::LmHead& head) { model_->set_lm_head(head); };
+  void set_lm_head(layer::NpuLmHead& head) { model_->set_lm_head(head); };
 
   std::vector<layer::WordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
@@ -56,7 +56,7 @@ class LLMWorkerImpl : public WorkerImpl {
   void set_word_embedding(std::vector<layer::WordEmbedding>& embedding) {
     model_->set_word_embedding(embedding);
   };
-
+#endif
  private:
   std::unique_ptr<BeamSearcher> beam_searcher_;
 };
```

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -108,14 +108,15 @@ bool SpeculativeWorkerImpl::init_model(const std::string& model_weights_path) {
     CHECK_EQ(draft_impl_->get_status(), WorkerImpl::Status::UNINITIALIZED);
     result = draft_impl_->WorkerImpl::init_model(model_weights_path);
   }
-
+#if defined(USE_NPU)
   if (draft_impl_->get_status() == WorkerImpl::Status::LOADED) {
     // Deepseek MTP
     auto head = impl_->get_lm_head();
     draft_impl_->set_lm_head(head);
     auto word_embedding = impl_->get_word_embedding();
     draft_impl_->set_word_embedding(word_embedding);
   }
+#endif
   return result;
 }
 
```
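The guard makes the Deepseek-MTP weight transplant an NPU-only code path; on other backends the block compiles away entirely. A self-contained sketch of the pattern with stand-in types (`FakeWorker` and `FakeHead` are illustrative, not xllm types):

```cpp
// Stand-ins exposing the same accessor pair the NPU interface declares.
struct FakeHead { int id = 0; };

struct FakeWorker {
  FakeHead head;
  FakeHead get_lm_head() { return head; }
  void set_lm_head(FakeHead& h) { head = h; }
};

// Mirrors SpeculativeWorkerImpl::init_model: share the target model's head
// with the draft (MTP) model, but only when the accessors exist.
void maybe_share_head(FakeWorker& target, FakeWorker& draft) {
#if defined(USE_NPU)
  auto head = target.get_lm_head();
  draft.set_lm_head(head);
#else
  (void)target;  // non-NPU builds: each model loads its own LmHead
  (void)draft;
#endif
}
```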

xllm/models/llm/deepseek_v2.h

Lines changed: 7 additions & 5 deletions
```diff
@@ -297,7 +297,7 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
  public:
   DeepseekV2ForCausalLMImpl(const ModelContext& context) {
     model_ = register_module("model", DeepseekV2Model(context));
-    lm_head_ = register_module("lm_head", layer::LmHead(context));
+    lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
     first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
   }
 
@@ -342,10 +342,10 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
   void update_expert_weight(int32_t layer_id) {
     model_->update_expert_weight(layer_id + first_k_dense_replace_);
   }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() { return lm_head_; }
 
-  layer::LmHead get_lm_head() { return lm_head_; }
-
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+  void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }
 
   std::vector<layer::WordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
@@ -355,9 +355,11 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
     model_->set_word_embedding(word_embedding);
   }
 
+ private:
+  layer::NpuLmHead lm_head_{nullptr};
+#endif
  private:
   DeepseekV2Model model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
   int32_t first_k_dense_replace_;
 };
 TORCH_MODULE(DeepseekV2ForCausalLM);
```
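Note the member layout this and the following model classes adopt: `lm_head_` moves into its own access section inside the guard, so non-NPU builds never declare it, leaving the class with two `private:` labels, which is well-formed C++. A minimal, self-contained illustration (all type names are placeholders):

```cpp
// Placeholder types so the layout compiles in isolation.
struct NpuHead { void* impl = nullptr; };
struct Backbone { void* impl = nullptr; };

class ExampleForCausalLM {
 public:
  // ... forward(), load_model(), etc. ...
#if defined(USE_NPU)
  NpuHead get_lm_head() { return lm_head_; }
  void set_lm_head(NpuHead& head) { lm_head_ = head; }

 private:
  NpuHead lm_head_{nullptr};  // declared only on NPU builds
#endif

 private:  // repeating an access specifier is legal C++
  Backbone model_{nullptr};
};
```

The same pattern repeats in `deepseek_v2_mtp.h`, `embedding_model_base.h` (with `protected:`), and `glm4_moe.h` below.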

xllm/models/llm/deepseek_v2_mtp.h

Lines changed: 6 additions & 3 deletions
```diff
@@ -295,9 +295,10 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module {
     return;
   }
   void update_expert_weight(int32_t layer_id) { return; }
-  layer::LmHead get_lm_head() { return lm_head_; }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() { return lm_head_; }
 
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+  void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }
 
   std::vector<layer::WordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
@@ -307,9 +308,11 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module {
     model_->set_word_embedding(word_embedding);
   }
 
+ private:
+  layer::NpuLmHead lm_head_{nullptr};
+#endif
  private:
   DeepseekV2MtpModel model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
 };
 TORCH_MODULE(DeepseekV2MtpForCausalLM);
 
```

xllm/models/llm/embedding_model_base.h

Lines changed: 6 additions & 4 deletions
```diff
@@ -73,10 +73,10 @@ class LlmForEmbeddingImplBase : public torch::nn::Module {
     return;
   }
   virtual void update_expert_weight(int32_t layer_id) { return; }
+#if defined(USE_NPU)
+  virtual layer::NpuLmHead get_lm_head() { return lm_head_; }
 
-  virtual layer::LmHead get_lm_head() { return lm_head_; }
-
-  virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+  virtual void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }
 
   virtual std::vector<layer::WordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
@@ -87,13 +87,15 @@ class LlmForEmbeddingImplBase : public torch::nn::Module {
     model_->set_word_embedding(word_embedding);
   }
 
+ protected:
+  layer::NpuLmHead lm_head_{nullptr};
+#endif
  protected:
   // parameter members, must be registered
   LlmModelType model_{nullptr};
   int device_id = 0;
   bool tie_word_embeddings{false};
   // test
-  layer::LmHead lm_head_{nullptr};
 };
 
 }  // namespace xllm
```

xllm/models/llm/glm4_moe.h

Lines changed: 7 additions & 5 deletions
```diff
@@ -254,7 +254,7 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module {
  public:
   Glm4MoeForCausalLMImpl(const ModelContext& context) {
     model_ = register_module("model", Glm4MoeModel(context));
-    lm_head_ = register_module("lm_head", layer::LmHead(context));
+    lm_head_ = register_module("lm_head", layer::NpuLmHead(context));
   }
 
   // tokens: [num_tokens]
@@ -296,10 +296,10 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module {
     return;
   }
   virtual void update_expert_weight(int32_t layer_id) { return; }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() { return lm_head_; }
 
-  layer::LmHead get_lm_head() { return lm_head_; }
-
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+  void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }
 
   std::vector<layer::WordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
@@ -309,9 +309,11 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module {
     model_->set_word_embedding(word_embedding);
   }
 
+ private:
+  layer::NpuLmHead lm_head_{nullptr};
+#endif
  private:
   Glm4MoeModel model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
 };
 TORCH_MODULE(Glm4MoeForCausalLM);
 
```
