
Commit 29b5359

feat: enhance NPU kernel parameters and streamline conditional compilation.
1 parent 67a2090 commit 29b5359


15 files changed: +72 -97 lines changed


xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 0 additions & 2 deletions

@@ -18,8 +18,6 @@ limitations under the License.
 #include "mapping_npu.h"
 
 #if defined(USE_NPU)
-#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
-
 #include "npu_process_group.h"
 #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
 #include "xllm_kernels/core/include/atb_speed/utils/singleton.h"

xllm/core/framework/parallel_state/npu_process_group.h

Lines changed: 0 additions & 4 deletions

@@ -18,10 +18,6 @@ limitations under the License.
 #include "hccl/hccl.h"
 #include "process_group.h"
 
-namespace c10d_npu {
-class ProcessGroupHCCL;
-}
-
 namespace xllm {
 
 class ProcessGroupHCCL : public ProcessGroup {

xllm/core/framework/parallel_state/process_group.h

Lines changed: 3 additions & 0 deletions

@@ -66,6 +66,9 @@ class ProcessGroup {
 
 protected:
 #if defined(USE_NPU)
+  // Using ProcessGroupHCCL for NPU devices
+  // Note: torch_npu uses an older torch version where c10d::Backend lacks
+  // shutdown() method
   std::unique_ptr<c10d_npu::ProcessGroupHCCL> pg_{nullptr};
 #else
   std::unique_ptr<c10d::Backend> pg_{nullptr};

xllm/core/kernels/npu/active.cpp

Lines changed: 5 additions & 1 deletion

@@ -20,7 +20,11 @@ limitations under the License.
 
 namespace xllm::kernel::npu {
 
-torch::Tensor active(const torch::Tensor& input) {
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
+  if (act_mode != "silu" && act_mode != "swiglu") {
+    throw std::runtime_error(
+        "Only swiglu activation is supported in NPU active");
+  }
   return at_npu::native::custom_ops::npu_swiglu(input);
 }
 } // namespace xllm::kernel::npu
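
The kernel now takes the activation mode explicitly instead of assuming it. A minimal usage sketch, assuming an NPU build (USE_NPU) with the input already resident on the NPU device; the shape and the include path are illustrative assumptions, not part of this commit:

#include <torch/torch.h>

#include "npu_ops_api.h"  // assumed include path for xllm::kernel::npu

// Gate/up projection output; npu_swiglu splits the last dimension in half,
// so the illustrative shape is [num_tokens, 2 * intermediate_size].
torch::Tensor gate_up = torch::randn({16, 2 * 1024});
torch::Tensor out = xllm::kernel::npu::active(gate_up, "silu");
// "swiglu" is also accepted; any other act_mode throws std::runtime_error.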

xllm/core/kernels/npu/attention.cpp

Lines changed: 9 additions & 6 deletions

@@ -31,31 +31,34 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& mask,
                    const torch::Tensor& seq_len,
                    float scale,
-                   int num_heads,
-                   int num_kv_heads,
                    torch::Tensor& output) {
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = key.size(-2);
   atb::_npu_flash_attention(
       query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
 }
 
 void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& k_cache,
                   const torch::Tensor& v_cache,
-                  int num_kv_heads,
-                  int num_heads,
                   float scale,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
                   torch::Tensor& output) {
-  atb::_npu_paged_attention(query,
+  auto head_size = query.size(-1);
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = k_cache.size(-2);
+  auto q = query.view({-1, num_heads, head_size});
+  auto o = output.view({-1, num_heads, head_size});
+  atb::_npu_paged_attention(q,
                             k_cache,
                             v_cache,
                             num_kv_heads,
                             num_heads,
                             scale,
                             block_table,
                             seq_lens,
-                            output);
+                            o);
 }
 
 } // namespace xllm::kernel::npu
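
With num_heads and num_kv_heads dropped from the signatures, both kernels now infer them from tensor shapes, which implicitly assumes a [..., num_heads, head_size] query layout and a num_kv_heads dimension at -2 in the caches. A small sketch of that inference; the concrete sizes are illustrative assumptions:

// Assumed layout: query is [num_tokens, num_heads, head_size];
// k_cache keeps num_kv_heads in dim -2 and head_size in dim -1.
torch::Tensor query = torch::randn({32, 16, 128});
auto num_heads = query.size(-2);  // 16
auto head_size = query.size(-1);  // 128
// batch_decode flattens any leading dims before calling _npu_paged_attention:
auto q = query.view({-1, num_heads, head_size});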

xllm/core/kernels/npu/fused_layernorm.cpp

Lines changed: 6 additions & 1 deletion

@@ -21,7 +21,12 @@ namespace xllm::kernel::npu {
 
 torch::Tensor fused_layernorm(const torch::Tensor& input,
                               const torch::Tensor& weight,
-                              double eps) {
+                              double eps,
+                              const std::string& mode) {
+  if (mode != "rmsnorm") {
+    throw std::runtime_error(
+        "Only rmsnorm mode is supported in NPU fused_layernorm");
+  }
   std::tuple<at::Tensor, at::Tensor> result =
       at_npu::native::custom_ops::npu_rms_norm(input, weight, eps);
   auto normalized_input = std::get<0>(result);
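
A minimal usage sketch of the extended signature, assuming an NPU build; the tensors below are illustrative placeholders for the hidden states and the RMSNorm weight:

torch::Tensor hidden = torch::randn({16, 4096});
torch::Tensor weight = torch::ones({4096});
torch::Tensor normed = xllm::kernel::npu::fused_layernorm(
    hidden, weight, /*eps=*/1e-6, "rmsnorm");
// Any mode other than "rmsnorm" throws std::runtime_error.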

xllm/core/kernels/npu/npu_ops_api.h

Lines changed: 3 additions & 6 deletions

@@ -34,15 +34,11 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& mask,
                    const torch::Tensor& seq_len,
                    float scale,
-                   int num_heads,
-                   int num_kv_heads,
                    torch::Tensor& output);
 
 void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& k_cache,
                   const torch::Tensor& v_cache,
-                  int num_kv_heads,
-                  int num_heads,
                   float scale,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
@@ -52,11 +48,12 @@ torch::Tensor matmul(const torch::Tensor& a,
                      const torch::Tensor& b,
                      const std::optional<torch::Tensor>& bias);
 
-torch::Tensor active(const torch::Tensor& input);
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode);
 
 torch::Tensor fused_layernorm(const torch::Tensor& input,
                               const torch::Tensor& weight,
-                              double eps);
+                              double eps,
+                              const std::string& mode);
 
 void apply_rotary(torch::Tensor& q,
                   torch::Tensor& k,

xllm/core/kernels/ops_api.cpp

Lines changed: 3 additions & 6 deletions

@@ -54,7 +54,7 @@ void active(ActivationParams& params) {
 
 torch::Tensor active_tensor(ActivationParams& params) {
 #if defined(USE_NPU)
-  return npu::active(params.input);
+  return npu::active(params.input, params.act_mode);
 #else
   throw std::runtime_error("active not implemented");
 #endif
@@ -110,8 +110,6 @@ void batch_prefill(AttentionParams& params) {
                     params.attn_mask,
                     params.seq_lens,
                     params.scale,
-                    params.num_heads,
-                    params.num_kv_heads,
                     params.output);
 #else
   throw std::runtime_error("batch_prefill not implemented");
@@ -144,8 +142,6 @@ void batch_decode(AttentionParams& params) {
   npu::batch_decode(params.query,
                     params.k_cache,
                     params.v_cache,
-                    params.num_kv_heads,
-                    params.num_heads,
                     params.scale,
                     params.block_table.value(),
                     params.seq_lens,
@@ -179,7 +175,8 @@ void fused_layernorm(FusedLayerNormParams& params) {
 
 torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) {
 #if defined(USE_NPU)
-  return npu::fused_layernorm(params.input, params.weight, params.eps);
+  return npu::fused_layernorm(
+      params.input, params.weight, params.eps, params.mode);
 #else
   throw std::runtime_error("fused_layernorm not implemented");
 #endif
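
Callers reach these kernels through the device-agnostic params structs. A minimal sketch of the updated call path, using only the struct fields visible in this diff (ActivationParams::input / act_mode, FusedLayerNormParams::input / weight / eps / mode); everything else about the structs and the setup is assumed:

// Placeholders stand in for real NPU-resident tensors.
xllm::kernel::ActivationParams act_params;
act_params.input = torch::randn({16, 2 * 1024});
act_params.act_mode = "silu";
torch::Tensor activated = xllm::kernel::active_tensor(act_params);

xllm::kernel::FusedLayerNormParams norm_params;
norm_params.input = torch::randn({16, 4096});
norm_params.weight = torch::ones({4096});
norm_params.eps = 1e-6;
norm_params.mode = "rmsnorm";
torch::Tensor normed = xllm::kernel::fused_layernorm_tensor(norm_params);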

xllm/core/kernels/param.h

Lines changed: 1 addition & 2 deletions

@@ -82,9 +82,8 @@ struct AttentionParams {
   bool return_lse = false;
   // for npu
   torch::Tensor seq_lens;
-  int num_heads;
-  int num_kv_heads;
   torch::Tensor attn_mask;
+
   // for flashinfer
   torch::Tensor paged_kv_indptr;
   torch::Tensor paged_kv_indices;

xllm/core/layers/common/attention.cpp

Lines changed: 16 additions & 33 deletions

@@ -21,29 +21,18 @@ DECLARE_bool(enable_chunked_prefill);
 namespace xllm {
 namespace layer {
 
-#if defined(USE_NPU)
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           bool is_prefill,
-                                           const torch::Tensor& attn_mask) {
+AttentionMetadata AttentionMetadata::build(
+    const ModelInputParams& params,
+    bool is_prefill,
+    const std::optional<torch::Tensor>& attn_mask) {
   return AttentionMetadata::build(params, "float", is_prefill, attn_mask);
 }
-#else
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           bool is_prefill) {
-  return AttentionMetadata::build(params, "float", is_prefill);
-}
-#endif
 
-#if defined(USE_NPU)
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           const std::string& compute_dtype,
-                                           bool is_prefill,
-                                           const torch::Tensor& attn_mask) {
-#else
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           const std::string& compute_dtype,
-                                           bool is_prefill) {
-#endif
+AttentionMetadata AttentionMetadata::build(
+    const ModelInputParams& params,
+    const std::string& compute_dtype,
+    bool is_prefill,
+    const std::optional<torch::Tensor>& attn_mask) {
   AttentionMetadata attn_metadata;
   attn_metadata.query_start_loc = params.q_seq_lens;
   attn_metadata.seq_start_loc = params.kv_seq_lens;
@@ -52,10 +41,11 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
   attn_metadata.slot_mapping = params.new_cache_slots;
   attn_metadata.compute_dtype = compute_dtype;
 
-#if defined(USE_NPU)
-  attn_metadata.attn_mask = attn_mask;
-  attn_metadata.seq_lens = params.kv_seq_lens.to(torch::kCPU);
-#endif
+  // for npu
+  if (attn_mask.has_value()) {
+    attn_metadata.attn_mask = attn_mask.value();
+    attn_metadata.seq_lens = params.kv_seq_lens.to(torch::kCPU);
+  }
 
   bool is_start_loc_match = (params.q_seq_lens_vec == params.kv_seq_lens_vec);
   attn_metadata.is_chunked_prefill = is_prefill && !is_start_loc_match;
@@ -123,8 +113,6 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
   attention_params.seq_start_loc = attn_metadata.seq_start_loc;
   attention_params.max_query_len = attn_metadata.max_query_len;
 #if defined(USE_NPU)
-  attention_params.num_heads = num_heads_;
-  attention_params.num_kv_heads = num_kv_heads_;
   attention_params.attn_mask = attn_metadata.attn_mask;
   attention_params.seq_lens = attn_metadata.seq_lens;
 #endif
@@ -139,15 +127,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
 
     xllm::kernel::batch_prefill(attention_params);
   } else {
-#if defined(USE_NPU)
-    query = query.view({-1, num_heads_, head_size_});
-    output = output.view({-1, num_heads_, head_size_});
-    attention_params.num_heads = num_heads_;
-    attention_params.num_kv_heads = num_kv_heads_;
-    attention_params.seq_lens = attn_metadata.seq_lens;
-#else
     query = query.view({-1, 1, num_heads_, head_size_});
     output = output.view({-1, 1, num_heads_, head_size_});
+#if defined(USE_NPU)
+    attention_params.seq_lens = attn_metadata.seq_lens;
 #endif
 
     attention_params.query = query;
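
Because attn_mask is now a std::optional on every backend, the same build() overloads compile for NPU and non-NPU targets, and the NPU-specific fields are only populated when a mask is actually supplied. A minimal calling sketch; the ModelInputParams instance and the mask contents are assumptions:

// input_params: a ModelInputParams prepared by the caller (assumed).
// Generic path: no mask, so attn_mask/seq_lens stay unset.
auto meta_generic = AttentionMetadata::build(
    input_params, /*is_prefill=*/true, /*attn_mask=*/std::nullopt);

// NPU path: pass a mask and the NPU-specific metadata is filled in.
torch::Tensor mask = torch::zeros({1024, 1024});  // illustrative shape
auto meta_npu = AttentionMetadata::build(
    input_params, "float", /*is_prefill=*/true, mask);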
