@@ -505,13 +505,13 @@ def __init__(
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.q_proj")
+
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -576,21 +576,27 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_mla=True,
-                # MLA Args
+                use_sparse=False,
+                indexer=None,
+                # SFA Args
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
                 if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
+                if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
                 o_proj=self.o_proj,
+                decoder_layer=decoder_layer,
             )
         else:
             self.mla_attn = MLAAttention(
@@ -608,9 +614,12 @@ def __init__(
                 indexer=None,
                 # MLA Args
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
+                if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
                 if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
@@ -630,22 +639,14 @@ def forward(
                                   and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
-            maybe_npu_prefetch(self.fused_qkv_a_proj.weight,
+            maybe_npu_prefetch(self.q_a_proj.weight,
                                hidden_states,
                                enabled=enable_multistream_mla)
-            # ckq = self.fused_qkv_a_proj(hidden_states)[0]
-            # hidden_states_or_q_c = self.q_a_layernorm(ckq)
-            # forward_kwargs['ckq'] = ckq
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            q_c, kv_no_split = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-            hidden_states_or_q_c = self.q_a_layernorm(q_c)
-            forward_kwargs['ckq'] = q_c
+            ckq = self.q_a_proj(hidden_states)[0]
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            forward_kwargs['ckq'] = ckq
         else:
             hidden_states_or_q_c = hidden_states
-        kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
         if self.torchair_graph_enabled:
             output_shape = hidden_states.shape
             output = torch.empty(output_shape,
@@ -660,6 +661,7 @@ def forward(
             output = output.view(-1, output_shape[-1])
             return output
         else:
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
             if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
                 hidden_states_or_q_c = get_tp_group().all_gather(
                     hidden_states_or_q_c, 0)
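With this change, the forward path no longer computes a fused q/kv low-rank projection and splits it; the low-rank query comes from `q_a_proj`, and the compressed kv comes from a separate `kv_a_proj_with_mqa` call issued in the non-torchair branch. The following standalone sketch is not part of the patch: plain `nn.Linear` layers stand in for vLLM's `ReplicatedLinear` / `MergedColumnParallelLinear`, and the sizes are made up; it only shows that both layouts produce tensors of the same shapes.

```python
import torch
import torch.nn as nn

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 512, 128, 64, 32
x = torch.randn(4, hidden_size)

# Old layout: one fused down-projection, split into the q and kv low-rank parts.
fused_qkv_a_proj = nn.Linear(hidden_size,
                             q_lora_rank + kv_lora_rank + qk_rope_head_dim,
                             bias=False)
q_c, kv_no_split_old = fused_qkv_a_proj(x).split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

# New layout: two independent projections, as wired in the hunks above.
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim,
                               bias=False)
ckq = q_a_proj(x)                    # low-rank query, then q_a_layernorm(ckq)
kv_no_split = kv_a_proj_with_mqa(x)  # compressed kv + rope part

assert ckq.shape == q_c.shape
assert kv_no_split.shape == kv_no_split_old.shape
```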
@@ -731,6 +733,14 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_a_proj",
+                return_bias=False,
+            )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -751,6 +761,14 @@ def __init__(
                 return_bias=False,
             )
 
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
+            return_bias=False,
+        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -784,23 +802,6 @@ def __init__(
             return_bias=False,
         )
 
-        if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
-                self.hidden_size,
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True)
-            self.kv_a_proj_with_mqa = None
-        else:
-            self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
             self.rotary_emb = get_rope(qk_rope_head_dim,
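The removed branch above is where the old code chose between a fused `MergedColumnParallelLinear` and a plain `kv_a_proj_with_mqa`; since `q_a_proj` and `kv_a_proj_with_mqa` are now always created separately earlier in this `__init__`, the fused weight disappears from the module's parameter set. A rough sketch of the resulting state-dict difference, using `nn.Linear` stand-ins and made-up sizes:

```python
import torch.nn as nn

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 512, 128, 64, 32

class Before(nn.Module):
    def __init__(self):
        super().__init__()
        # One fused weight covering both the query and kv low-rank projections.
        self.fused_qkv_a_proj = nn.Linear(
            hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

class After(nn.Module):
    def __init__(self):
        super().__init__()
        # Two independent weights, matching the checkpoint's own tensor names.
        self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
        self.kv_a_proj_with_mqa = nn.Linear(
            hidden_size, kv_lora_rank + qk_rope_head_dim, bias=False)

print({k: tuple(v.shape) for k, v in Before().state_dict().items()})
# {'fused_qkv_a_proj.weight': (224, 512)}
print({k: tuple(v.shape) for k, v in After().state_dict().items()})
# {'q_a_proj.weight': (128, 512), 'kv_a_proj_with_mqa.weight': (96, 512)}
```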
@@ -840,21 +841,27 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_mla=True,
-                # MLA Args
+                use_sparse=True,
+                indexer=self.indexer,
+                # SFA Args
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
                 if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
+                if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
                 o_proj=self.o_proj,
+                decoder_layer=decoder_layer,
             )
         else:
             self.sfa_attn = MLAAttention(
@@ -872,9 +879,12 @@ def __init__(
                 indexer=self.indexer,
                 # MLA Args
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
+                if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
                 if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
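Across the four constructor call sites changed above, the attention module now receives the query projections through a uniform set of arguments: `q_a_proj` and `q_a_layernorm` are passed only when `q_lora_rank` is set, and `q_proj` carries either the full-rank projection (no low-rank path) or `q_b_proj` (low-rank path). A standalone sketch of the two query paths this implies, using `nn.Linear` and `torch.nn.RMSNorm` (PyTorch >= 2.4) as stand-ins for vLLM's layers, with made-up sizes:

```python
import torch
import torch.nn as nn

hidden_size, q_lora_rank, num_heads, qk_head_dim = 512, 128, 8, 64
x = torch.randn(2, hidden_size)

# q_lora_rank is not None: q_a_proj + q_a_layernorm are passed, and the
# `q_proj` argument carries q_b_proj.
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
q_a_layernorm = nn.RMSNorm(q_lora_rank)
q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim, bias=False)
q_lowrank = q_b_proj(q_a_layernorm(q_a_proj(x)))

# q_lora_rank is None: q_a_proj / q_a_layernorm are None and `q_proj` is the
# single full-rank projection.
q_proj = nn.Linear(hidden_size, num_heads * qk_head_dim, bias=False)
q_fullrank = q_proj(x)

assert q_lowrank.shape == q_fullrank.shape == (2, num_heads * qk_head_dim)
```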
@@ -1257,20 +1267,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         """"""
-        if self.config.q_lora_rank is not None:
-            stacked_params_mapping = [
-                # (param_name, shard_name, shard_id)
-                ("gate_up_proj", "gate_proj", 0),
-                ("gate_up_proj", "up_proj", 1),
-                ("fused_qkv_a_proj", "q_a_proj", 0),
-                ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
-            ]
-        else:
-            stacked_params_mapping = [
-                # (param_name, shard_name, shard_id)
-                ("gate_up_proj", "gate_proj", 0),
-                ("gate_up_proj", "up_proj", 1),
-            ]
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = TorchairAscendFusedMoE.make_expert_params_mapping(
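With the fused projection gone, `load_weights` no longer needs the branch that mapped `q_a_proj` / `kv_a_proj_with_mqa` checkpoint shards onto `fused_qkv_a_proj`; only the MLP gate/up fusion remains. The hedged sketch below is not the actual vLLM loader and uses illustrative checkpoint names; it only shows what such a mapping does: rename a per-shard checkpoint tensor onto its merged module parameter and tag it with a shard id, while unfused weights load under their own names.

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(checkpoint_name: str):
    """Map a checkpoint tensor name to (model parameter name, shard id or None)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, param_name), shard_id
    # Unfused weights such as q_a_proj / kv_a_proj_with_mqa load as-is.
    return checkpoint_name, None

print(resolve("model.layers.0.mlp.gate_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 0)
print(resolve("model.layers.0.self_attn.q_a_proj.weight"))
# ('model.layers.0.self_attn.q_a_proj.weight', None)
```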