@@ -66,11 +66,13 @@ class CausalLM : public torch::nn::Module {
 
   virtual const torch::TensorOptions& options() const = 0;
 
-  virtual layer::LmHead get_lm_head() = 0;
-  virtual void set_lm_head(layer::LmHead& head) = 0;
+#if defined(USE_NPU)
+  virtual layer::NpuLmHead get_lm_head() = 0;
+  virtual void set_lm_head(layer::NpuLmHead& head) = 0;
   virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
   virtual void set_word_embedding(
       std::vector<layer::WordEmbedding>& embedding) = 0;
+#endif
 };
 
 template <typename Model>
@@ -104,10 +106,12 @@ class CausalLMImpl : public CausalLM {
   virtual void update_expert_weight(int32_t layer_id) {
     return model_->update_expert_weight(layer_id);
   }
+#if defined(USE_NPU)
+  layer::NpuLmHead get_lm_head() override { return model_->get_lm_head(); };
 
-  layer::LmHead get_lm_head() override { return model_->get_lm_head(); };
-
-  void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
+  void set_lm_head(layer::NpuLmHead& head) override {
+    model_->set_lm_head(head);
+  };
 
   std::vector<layer::WordEmbedding> get_word_embedding() override {
     return model_->get_word_embedding();
@@ -117,7 +121,7 @@ class CausalLMImpl : public CausalLM {
       std::vector<layer::WordEmbedding>& embedding) override {
     model_->set_word_embedding(embedding);
   };
-
+#endif
   torch::Device device() const override { return options_.device(); }
 
   const torch::TensorOptions& options() const override { return options_; }
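
Side note (not part of the diff): a minimal caller-side sketch, assuming a hypothetical replace_heads helper, to illustrate that after this change the LM-head and word-embedding accessors on CausalLM only exist in NPU builds, so any code that calls them needs the same USE_NPU guard:

// Hypothetical helper, only to illustrate the guarded accessors above.
#if defined(USE_NPU)
void replace_heads(CausalLM& model,
                   layer::NpuLmHead& new_head,
                   std::vector<layer::WordEmbedding>& new_embedding) {
  model.set_lm_head(new_head);              // NPU-only virtual after this change
  model.set_word_embedding(new_embedding);  // also compiled only when USE_NPU is defined
}
#endif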