@@ -21,29 +21,18 @@ DECLARE_bool(enable_chunked_prefill);
 namespace xllm {
 namespace layer {
 
-#if defined(USE_NPU)
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           bool is_prefill,
-                                           const torch::Tensor& attn_mask) {
+AttentionMetadata AttentionMetadata::build(
+    const ModelInputParams& params,
+    bool is_prefill,
+    const std::optional<torch::Tensor>& attn_mask) {
   return AttentionMetadata::build(params, "float", is_prefill, attn_mask);
 }
-#else
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           bool is_prefill) {
-  return AttentionMetadata::build(params, "float", is_prefill);
-}
-#endif
 
-#if defined(USE_NPU)
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           const std::string& compute_dtype,
-                                           bool is_prefill,
-                                           const torch::Tensor& attn_mask) {
-#else
-AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
-                                           const std::string& compute_dtype,
-                                           bool is_prefill) {
-#endif
+AttentionMetadata AttentionMetadata::build(
+    const ModelInputParams& params,
+    const std::string& compute_dtype,
+    bool is_prefill,
+    const std::optional<torch::Tensor>& attn_mask) {
   AttentionMetadata attn_metadata;
   attn_metadata.query_start_loc = params.q_seq_lens;
   attn_metadata.seq_start_loc = params.kv_seq_lens;
@@ -52,10 +41,11 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
   attn_metadata.slot_mapping = params.new_cache_slots;
   attn_metadata.compute_dtype = compute_dtype;
 
-#if defined(USE_NPU)
-  attn_metadata.attn_mask = attn_mask;
-  attn_metadata.seq_lens = params.kv_seq_lens.to(torch::kCPU);
-#endif
+  // NPU-only: populate the mask-dependent fields when a mask is provided.
+  if (attn_mask.has_value()) {
+    attn_metadata.attn_mask = attn_mask.value();
+    attn_metadata.seq_lens = params.kv_seq_lens.to(torch::kCPU);
+  }
 
   bool is_start_loc_match = (params.q_seq_lens_vec == params.kv_seq_lens_vec);
   attn_metadata.is_chunked_prefill = is_prefill && !is_start_loc_match;
@@ -123,8 +113,6 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
   attention_params.seq_start_loc = attn_metadata.seq_start_loc;
   attention_params.max_query_len = attn_metadata.max_query_len;
 #if defined(USE_NPU)
-  attention_params.num_heads = num_heads_;
-  attention_params.num_kv_heads = num_kv_heads_;
   attention_params.attn_mask = attn_metadata.attn_mask;
   attention_params.seq_lens = attn_metadata.seq_lens;
 #endif
@@ -139,15 +127,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
 
     xllm::kernel::batch_prefill(attention_params);
   } else {
-#if defined(USE_NPU)
-    query = query.view({-1, num_heads_, head_size_});
-    output = output.view({-1, num_heads_, head_size_});
-    attention_params.num_heads = num_heads_;
-    attention_params.num_kv_heads = num_kv_heads_;
-    attention_params.seq_lens = attn_metadata.seq_lens;
-#else
     query = query.view({-1, 1, num_heads_, head_size_});
     output = output.view({-1, 1, num_heads_, head_size_});
+#if defined(USE_NPU)
+    attention_params.seq_lens = attn_metadata.seq_lens;
 #endif
 
     attention_params.query = query;
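For reference, a minimal sketch of how call sites look after this change. The variable names `params` and `mask_tensor` are assumptions for illustration, and since this diff does not show whether the header declares a default for `attn_mask`, the optional is passed explicitly:

    #include <optional>

    // No-mask path (e.g. non-NPU builds): the NPU-only fields stay unset.
    auto meta = xllm::layer::AttentionMetadata::build(
        params, /*is_prefill=*/false, /*attn_mask=*/std::nullopt);

    // NPU path: the mask flows through, so build() also fills
    // attn_metadata.seq_lens from params.kv_seq_lens on the CPU.
    auto npu_meta = xllm::layer::AttentionMetadata::build(
        params, /*is_prefill=*/true, /*attn_mask=*/mask_tensor);

Replacing the `#if defined(USE_NPU)` overload split with a single `std::optional<torch::Tensor>` parameter means both backends share one signature, and the NPU-specific behavior is selected at runtime by whether a mask is present rather than at compile time.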