@@ -49,7 +49,6 @@
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mla import MLAModules
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -70,7 +69,7 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
+from vllm_ascend.models.layers.sfa import Indexer
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
@@ -79,9 +78,9 @@
 from vllm_ascend.utils import dispose_tensor, oproj_tp_enable, vllm_version_is
 
 if vllm_version_is("0.11.0"):
-    from vllm.model_executor.layers.mla import MultiHeadLatentAttention
+    from vllm.attention import Attention
 else:
-    from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
+    from vllm.attention.layer import MLAAttention
 
 
 class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@@ -486,6 +485,11 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(self.hidden_size,
+                                             self.q_lora_rank,
+                                             bias=False,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.q_a_proj")
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(q_lora_rank,
@@ -501,6 +505,12 @@ def __init__(
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.q_proj")
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
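+        # Note: kv_a_proj_with_mqa is now constructed unconditionally here,
+        # replacing the earlier fused_qkv_a_proj / conditional path below.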
 
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
@@ -536,24 +546,6 @@ def __init__(
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.o_proj")
 
-        print(30 * "=", f"q_lora_rank: {q_lora_rank}")
-        if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
-                self.hidden_size,
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True)
-            self.kv_a_proj_with_mqa = None
-        else:
-            self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
             self.rotary_emb = get_rope(qk_rope_head_dim,
@@ -568,59 +560,61 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
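             # The yarn mscale factor is applied to both q and k, so the
             # softmax scale picks it up twice (hence mscale * mscale).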
 
-        mla_modules = MLAModules(
-            kv_a_layernorm=self.kv_a_layernorm,
-            kv_b_proj=self.kv_b_proj,
-            rotary_emb=self.rotary_emb,
-            o_proj=self.o_proj,
-            fused_qkv_a_proj=self.fused_qkv_a_proj
-            if self.q_lora_rank is not None else None,
-            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
-            if self.q_lora_rank is None else None,
-            q_a_layernorm=self.q_a_layernorm
-            if self.q_lora_rank is not None else None,
-            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
-            q_proj=self.q_proj if self.q_lora_rank is None else None,
-            indexer=None,
-            is_sparse=hasattr(config, "index_topk"),
-            topk_indices_buffer=None,
-        )
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires
        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
         # i.e.
         # kv_lora_rank + qk_rope_head_dim == head_size
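         # (Illustrative numbers, assuming DeepSeek-V2/V3-style dims: with
         # kv_lora_rank=512 and qk_rope_head_dim=64 the cache head_size is
         # 512 + 64 = 576; the actual values come from the model config.)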
-
         if vllm_version_is("0.11.0"):
-            self.mla_attn = MultiHeadLatentAttention(
-                hidden_size=self.hidden_size,
+            self.mla_attn = Attention(
                 num_heads=self.num_local_heads,
+                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                 scale=self.scaling,
+                num_kv_heads=1,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                use_mla=True,
+                # MLA Args
+                q_lora_rank=self.q_lora_rank,
+                kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
+                qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
-                q_lora_rank=self.q_lora_rank,
-                kv_lora_rank=self.kv_lora_rank,
-                mla_modules=mla_modules,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
             )
         else:
-            self.mla_attn = MultiHeadLatentAttentionWrapper(
-                hidden_size=self.kv_lora_rank + self.qk_rope_head_dim,
+            self.mla_attn = MLAAttention(
                 num_heads=self.num_local_heads,
                 scale=self.scaling,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 v_head_dim=self.v_head_dim,
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
-                mla_modules=mla_modules,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=prefix,
+                prefix=f"{prefix}.attn",
+                use_sparse=False,
+                indexer=None,
+                # MLA Args
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
             )
 
     def forward(
@@ -658,9 +652,11 @@ def forward(
                                  dtype=hidden_states_or_q_c.dtype,
                                  device=hidden_states_or_q_c.device)
             forward_kwargs['output'] = output
-            output = self.mla_attn.mla_attn.impl.forward(
-                self.mla_attn, hidden_states_or_q_c, hidden_states, None,
-                kv_cache, attn_metadata, **forward_kwargs)
+            output = self.mla_attn.impl.forward(self.mla_attn,
+                                                hidden_states_or_q_c,
+                                                hidden_states, None, kv_cache,
+                                                attn_metadata,
+                                                **forward_kwargs)
             output = output.view(-1, output_shape[-1])
             return output
         else:
@@ -834,51 +830,55 @@ def __init__(
             prefix=f"{prefix}.indexer",
         )
 
-        sfa_modules = AscendSFAModules(
-            q_a_layernorm=self.q_a_layernorm
-            if self.q_lora_rank is not None else None,
-            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
-            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
-            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
-            fused_qkv_a_proj=self.fused_qkv_a_proj
-            if self.q_lora_rank is not None else None,
-            kv_a_layernorm=self.kv_a_layernorm,
-            kv_b_proj=self.kv_b_proj,
-            o_proj=self.o_proj,
-            rotary_emb=self.rotary_emb,
-            indexer=self.indexer,
-            is_sparse=hasattr(config, "index_topk"),
-            topk_indices_buffer=None)
-
         if vllm_version_is("0.11.0"):
-            self.sfa_attn = MultiHeadLatentAttention(
-                hidden_size=self.hidden_size,
+            self.sfa_attn = Attention(
                 num_heads=self.num_local_heads,
+                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                 scale=self.scaling,
+                num_kv_heads=1,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                use_mla=True,
+                # MLA Args
+                q_lora_rank=self.q_lora_rank,
+                kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
+                qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
-                q_lora_rank=self.q_lora_rank,
-                kv_lora_rank=self.kv_lora_rank,
-                mla_modules=sfa_modules,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
             )
         else:
-            self.sfa_attn = MultiHeadLatentAttentionWrapper(
-                hidden_size=self.hidden_size,
+            self.sfa_attn = MLAAttention(
                 num_heads=self.num_local_heads,
                 scale=self.scaling,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 v_head_dim=self.v_head_dim,
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
-                mla_modules=sfa_modules,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=prefix,
+                prefix=f"{prefix}.attn",
+                use_sparse=True,
+                indexer=self.indexer,
+                # MLA Args
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
             )
 
     def forward(
@@ -917,9 +917,8 @@ def forward(
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
                              device=hidden_states.device)
-        self.sfa_attn.sfa_attn.impl.forward(hidden_states, kv_cache,
-                                            attn_metadata, need_gather_q_kv,
-                                            output)
+        self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
+                                   need_gather_q_kv, output)
         output = output.view(-1, output_shape[-1])
         return output
 