Commit f23817d

revert torchair sfa changes
Signed-off-by: MengqingCao <[email protected]>
1 parent a1cc794 commit f23817d

File tree: 3 files changed (+58 −58 lines)

vllm_ascend/torchair/models/torchair_deepseek_v2.py
vllm_ascend/torchair/torchair_mla.py
vllm_ascend/torchair/torchair_sfa.py


vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 56 additions & 54 deletions
@@ -505,13 +505,13 @@ def __init__(
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.q_proj")
+
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size,
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -576,21 +576,27 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_mla=True,
-                # MLA Args
+                use_sparse=False,
+                indexer=None,
+                # SFA Args
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
                 if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
+                if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
                 o_proj=self.o_proj,
+                decoder_layer=decoder_layer,
             )
         else:
             self.mla_attn = MLAAttention(
@@ -608,9 +614,12 @@ def __init__(
                 indexer=None,
                 # MLA Args
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
+                if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
                 if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
@@ -630,22 +639,14 @@ def forward(
                                       and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
-            maybe_npu_prefetch(self.fused_qkv_a_proj.weight,
+            maybe_npu_prefetch(self.q_a_proj.weight,
                                hidden_states,
                                enabled=enable_multistream_mla)
-            # ckq = self.fused_qkv_a_proj(hidden_states)[0]
-            # hidden_states_or_q_c = self.q_a_layernorm(ckq)
-            # forward_kwargs['ckq'] = ckq'
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            q_c, kv_no_split = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-            hidden_states_or_q_c = self.q_a_layernorm(q_c)
-            forward_kwargs['ckq'] = q_c
+            ckq = self.q_a_proj(hidden_states)[0]
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            forward_kwargs['ckq'] = ckq
         else:
             hidden_states_or_q_c = hidden_states
-            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
         if self.torchair_graph_enabled:
             output_shape = hidden_states.shape
             output = torch.empty(output_shape,
@@ -660,6 +661,7 @@ def forward(
             output = output.view(-1, output_shape[-1])
             return output
         else:
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
             if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
                 hidden_states_or_q_c = get_tp_group().all_gather(
                     hidden_states_or_q_c, 0)
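For orientation, below is a minimal standalone sketch (plain torch.nn modules, not the vLLM-Ascend layer classes, and illustrative sizes only) of the two query/KV down-projection layouts the forward() hunks above switch between: the reverted code calls separate q_a_proj and kv_a_proj_with_mqa modules, while the removed code ran a single fused_qkv_a_proj and split its output.

import torch
import torch.nn as nn

# Illustrative sizes; real values come from the DeepSeek model config.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 64, 24, 16, 8
hidden_states = torch.randn(2, hidden_size)

# Layout restored by this commit: two separate down-projections.
q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
kv_a_proj_with_mqa = nn.Linear(hidden_size,
                               kv_lora_rank + qk_rope_head_dim,
                               bias=False)
ckq = q_a_proj(hidden_states)                    # low-rank query latent
kv_no_split = kv_a_proj_with_mqa(hidden_states)  # KV latent + rope portion

# Layout removed by this commit: one fused projection, split afterwards.
fused_qkv_a_proj = nn.Linear(hidden_size,
                             q_lora_rank + kv_lora_rank + qk_rope_head_dim,
                             bias=False)
qkv_lora = fused_qkv_a_proj(hidden_states)
q_c, kv_fused = qkv_lora.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

# Both layouts yield tensors of the same shapes; only the number of matmul
# launches and the weight layout differ.
assert q_c.shape == ckq.shape and kv_fused.shape == kv_no_split.shape

(Note: the vLLM linear layers return an (output, bias) tuple, hence the [0] indexing in the diff; plain nn.Linear does not.)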
@@ -731,6 +733,14 @@ def __init__(
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

         if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_a_proj",
+                return_bias=False,
+            )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -751,6 +761,14 @@ def __init__(
                 return_bias=False,
             )

+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
+            return_bias=False,
+        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -784,23 +802,6 @@ def __init__(
             return_bias=False,
         )

-        if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
-                self.hidden_size,
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True)
-            self.kv_a_proj_with_mqa = None
-        else:
-            self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
@@ -840,21 +841,27 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_mla=True,
-                # MLA Args
+                use_sparse=True,
+                indexer=self.indexer,
+                # SFA Args
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
                 if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
+                if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
                 o_proj=self.o_proj,
+                decoder_layer=decoder_layer,
             )
         else:
             self.sfa_attn = MLAAttention(
@@ -872,9 +879,12 @@ def __init__(
                 indexer=self.indexer,
                 # MLA Args
                 rotary_emb=self.rotary_emb,
-                q_proj=self.q_proj if self.q_lora_rank is None else None,
-                q_b_proj=self.q_b_proj
+                q_a_proj=self.q_a_proj
+                if self.q_lora_rank is not None else None,
+                q_a_layernorm=self.q_a_layernorm
                 if self.q_lora_rank is not None else None,
+                q_proj=self.q_proj
+                if self.q_lora_rank is None else self.q_b_proj,
                 kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                 kv_a_layernorm=self.kv_a_layernorm,
                 kv_b_proj=self.kv_b_proj,
@@ -1257,20 +1267,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         """"""
-        if self.config.q_lora_rank is not None:
-            stacked_params_mapping = [
-                # (param_name, shard_name, shard_id)
-                ("gate_up_proj", "gate_proj", 0),
-                ("gate_up_proj", "up_proj", 1),
-                ("fused_qkv_a_proj", "q_a_proj", 0),
-                ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
-            ]
-        else:
-            stacked_params_mapping = [
-                # (param_name, shard_name, shard_id)
-                ("gate_up_proj", "gate_proj", 0),
-                ("gate_up_proj", "up_proj", 1),
-            ]
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = TorchairAscendFusedMoE.make_expert_params_mapping(
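Rough illustration of why the fused_qkv_a_proj entries disappear from stacked_params_mapping in the load_weights hunk above: in the vLLM-style loader, each (param_name, shard_name, shard_id) entry redirects a checkpoint weight whose name contains shard_name into the merged parameter param_name. With q_a_proj and kv_a_proj_with_mqa restored as standalone modules, their weights load under their own names, so only the gate_up_proj entries remain. The loop below is a simplified sketch of that name-rewriting convention, not the actual vLLM loader; the checkpoint names are made up.

# After this commit only the MLP gate/up fusion still needs remapping.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint_names = [
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.self_attn.q_a_proj.weight",
    "model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
]

for name in checkpoint_names:
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # Fused case: this shard loads into the merged parameter.
            print(f"{name} -> {name.replace(shard_name, param_name)} "
                  f"(shard {shard_id})")
            break
    else:
        # Unfused case (q_a_proj, kv_a_proj_with_mqa): name used as-is.
        print(f"{name} -> {name}")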

vllm_ascend/torchair/torchair_mla.py

Lines changed: 1 addition & 2 deletions
@@ -656,8 +656,7 @@ def __init__(
         self.qk_head_dim = kwargs['qk_head_dim']
         self.v_head_dim = kwargs['v_head_dim']
         self.rotary_emb = kwargs['rotary_emb']
-        self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
-            'q_b_proj']
+        self.q_proj = kwargs['q_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)

vllm_ascend/torchair/torchair_sfa.py

Lines changed: 1 addition & 2 deletions
@@ -719,8 +719,7 @@ def __init__(
         self.qk_head_dim = kwargs['qk_head_dim']
         self.v_head_dim = kwargs['v_head_dim']
         self.rotary_emb = kwargs['rotary_emb']
-        self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
-            'q_b_proj']
+        self.q_proj = kwargs['q_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
         self.indexer = kwargs['indexer']
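The torchair_mla.py and torchair_sfa.py hunks move the choice of query projection out of the attention implementations: the model layer now passes a single pre-selected module as kwargs['q_proj'] (q_proj when q_lora_rank is None, otherwise q_b_proj), so the impls no longer look up 'q_b_proj' themselves. A minimal sketch of that calling convention, with toy module names and sizes:

import torch.nn as nn

class ToyAttnImpl:
    # Stand-in for the torchair MLA/SFA impls: after this commit they simply
    # store whatever projection the model layer hands over as 'q_proj'.
    def __init__(self, **kwargs):
        self.q_proj = kwargs['q_proj']

q_lora_rank = 128                 # pretend the low-rank query path is enabled
q_proj = nn.Linear(64, 64)        # would be used only if q_lora_rank is None
q_b_proj = nn.Linear(q_lora_rank, 64)

# The selection now happens once, in the model layer, mirroring the diff:
impl = ToyAttnImpl(q_proj=q_proj if q_lora_rank is None else q_b_proj)
assert impl.q_proj is q_b_proj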
