Commit a1cc794

revert torchair refactor

Signed-off-by: MengqingCao <[email protected]>
1 parent d754a18

File tree

5 files changed (+129, −176 lines)

tests/ut/test_platform.py

Lines changed: 0 additions & 4 deletions

@@ -1,12 +1,8 @@
 import importlib
-import unittest
-from datetime import timedelta
 from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import PrefixStore
 from vllm.config.compilation import CUDAGraphMode
 from vllm.platforms import PlatformEnum
 
vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 4 deletions

@@ -40,7 +40,7 @@ class ACLGraphWrapper:
 
     The workflow of this wrapper in the aclgraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
-    VLLM_COMPILE).
+    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trust them
    for aclgraph dispatching.
@@ -126,7 +126,7 @@ def __call__(self, *args, **kwargs):
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
-                # VLLM_COMPILE mode.
+                # piecewise mode.
                logger.debug("Capturing a aclgraph on (%s,%s)",
                             self.runtime_mode.name, entry.batch_descriptor)
            # validate that aclgraph capturing is legal at this point.
@@ -140,7 +140,7 @@ def __call__(self, *args, **kwargs):
 
            with ExitStack() as stack:
                if self.aclgraph_options.gc_disable:
-                    # during every model forward for VLLM_COMPILE aclgraph
+                    # during every model forward for piecewise aclgraph
                    # mode, we will capture many pieces of aclgraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the aclgraph capture very slow.
@@ -159,7 +159,7 @@ def __call__(self, *args, **kwargs):
            # by converting it to weak ref,
            # the original `output` will immediately be released
            # to save memory. It is only safe to do this for
-            # the last graph in VLLM_COMPILE aclgraph mode, because
+            # the last graph in piecewise aclgraph mode, because
            # the output of the last graph will not be used by
            # any other acl graph.
            output = weak_ref_tensors(output)
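
The docstring touched by the first hunk describes the wrapper's dispatch workflow: a runtime mode (FULL or PIECEWISE) fixed at init, then per-call dispatch keyed by a batch descriptor taken from the forward context, capturing on first sight and replaying afterwards. The following is a minimal, dependency-free sketch of that workflow only; the names ToyGraphWrapper and CapturedEntry are invented for illustration and are not part of vllm-ascend, and real graph capture/replay is replaced by a plain dictionary lookup.

    from dataclasses import dataclass, field
    from typing import Callable, Dict, Hashable


    @dataclass
    class CapturedEntry:
        # Stands in for a captured ACL graph; here we just remember the callable.
        runnable: Callable


    @dataclass
    class ToyGraphWrapper:
        runnable: Callable
        runtime_mode: str  # "FULL" or "PIECEWISE", fixed at initialization
        entries: Dict[Hashable, CapturedEntry] = field(default_factory=dict)

        def __call__(self, runtime_mode, batch_descriptor, *args):
            # Mode mismatch: bypass graph logic and run eagerly.
            if runtime_mode != self.runtime_mode:
                return self.runnable(*args)
            # First time this key is seen: "capture"; afterwards: "replay".
            if batch_descriptor not in self.entries:
                self.entries[batch_descriptor] = CapturedEntry(self.runnable)
            return self.entries[batch_descriptor].runnable(*args)


    wrapper = ToyGraphWrapper(runnable=lambda x: x * 2, runtime_mode="PIECEWISE")
    print(wrapper("PIECEWISE", (8,), 3))  # first call with this key: "capture", prints 6
    print(wrapper("PIECEWISE", (8,), 5))  # same key again: "replay", prints 10
    print(wrapper("FULL", (8,), 7))       # mode mismatch: run eagerly, prints 14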

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 85 additions & 86 deletions

@@ -49,7 +49,6 @@
     RowParallelLinear,
     UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mla import MLAModules
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -70,7 +69,7 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
+from vllm_ascend.models.layers.sfa import Indexer
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
@@ -79,9 +78,9 @@
 from vllm_ascend.utils import dispose_tensor, oproj_tp_enable, vllm_version_is
 
 if vllm_version_is("0.11.0"):
-    from vllm.model_executor.layers.mla import MultiHeadLatentAttention
+    from vllm.attention import Attention
 else:
-    from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
+    from vllm.attention.layer import MLAAttention
 
 
 class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@@ -486,6 +485,11 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
         if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_a_proj")
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                 eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(q_lora_rank,
@@ -501,6 +505,12 @@ def __init__(
                 bias=False,
                 quant_config=quant_config,
                 prefix=f"{prefix}.q_proj")
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa")
 
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
             eps=config.rms_norm_eps)
@@ -536,24 +546,6 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj")
 
-        print(30 * "=", f"q_lora_rank: {q_lora_rank}")
-        if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedColumnParallelLinear(
-                self.hidden_size,
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj",
-                disable_tp=True)
-            self.kv_a_proj_with_mqa = None
-        else:
-            self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=f"{prefix}.kv_a_proj_with_mqa")
-
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
@@ -568,59 +560,61 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        mla_modules = MLAModules(
-            kv_a_layernorm=self.kv_a_layernorm,
-            kv_b_proj=self.kv_b_proj,
-            rotary_emb=self.rotary_emb,
-            o_proj=self.o_proj,
-            fused_qkv_a_proj=self.fused_qkv_a_proj
-            if self.q_lora_rank is not None else None,
-            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
-            if self.q_lora_rank is None else None,
-            q_a_layernorm=self.q_a_layernorm
-            if self.q_lora_rank is not None else None,
-            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
-            q_proj=self.q_proj if self.q_lora_rank is None else None,
-            indexer=None,
-            is_sparse=hasattr(config, "index_topk"),
-            topk_indices_buffer=None,
-        )
         # In the MLA backend, kv_cache includes both k_c and
         # pe (i.e. decoupled position embeddings). In particular,
         # the concat_and_cache_mla op requires
         # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
         # i.e.
         # kv_lora_rank + qk_rope_head_dim == head_size
-
         if vllm_version_is("0.11.0"):
-            self.mla_attn = MultiHeadLatentAttention(
-                hidden_size=self.hidden_size,
+            self.mla_attn = Attention(
                 num_heads=self.num_local_heads,
+                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                 scale=self.scaling,
+                num_kv_heads=1,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                use_mla=True,
+                # MLA Args
+                q_lora_rank=self.q_lora_rank,
+                kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
+                qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
-                q_lora_rank=self.q_lora_rank,
-                kv_lora_rank=self.kv_lora_rank,
-                mla_modules=mla_modules,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
            )
        else:
-            self.mla_attn = MultiHeadLatentAttentionWrapper(
-                hidden_size=self.kv_lora_rank + self.qk_rope_head_dim,
+            self.mla_attn = MLAAttention(
                 num_heads=self.num_local_heads,
                 scale=self.scaling,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 v_head_dim=self.v_head_dim,
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
-                mla_modules=mla_modules,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=prefix,
+                prefix=f"{prefix}.attn",
+                use_sparse=False,
+                indexer=None,
+                # MLA Args
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
            )
 
    def forward(
@@ -658,9 +652,11 @@ def forward(
                 dtype=hidden_states_or_q_c.dtype,
                 device=hidden_states_or_q_c.device)
            forward_kwargs['output'] = output
-            output = self.mla_attn.mla_attn.impl.forward(
-                self.mla_attn, hidden_states_or_q_c, hidden_states, None,
-                kv_cache, attn_metadata, **forward_kwargs)
+            output = self.mla_attn.impl.forward(self.mla_attn,
+                hidden_states_or_q_c,
+                hidden_states, None, kv_cache,
+                attn_metadata,
+                **forward_kwargs)
            output = output.view(-1, output_shape[-1])
            return output
        else:
@@ -834,51 +830,55 @@ def __init__(
            prefix=f"{prefix}.indexer",
        )
 
-        sfa_modules = AscendSFAModules(
-            q_a_layernorm=self.q_a_layernorm
-            if self.q_lora_rank is not None else None,
-            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
-            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
-            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
-            fused_qkv_a_proj=self.fused_qkv_a_proj
-            if self.q_lora_rank is not None else None,
-            kv_a_layernorm=self.kv_a_layernorm,
-            kv_b_proj=self.kv_b_proj,
-            o_proj=self.o_proj,
-            rotary_emb=self.rotary_emb,
-            indexer=self.indexer,
-            is_sparse=hasattr(config, "index_topk"),
-            topk_indices_buffer=None)
-
        if vllm_version_is("0.11.0"):
-            self.sfa_attn = MultiHeadLatentAttention(
-                hidden_size=self.hidden_size,
+            self.sfa_attn = Attention(
                 num_heads=self.num_local_heads,
+                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                 scale=self.scaling,
+                num_kv_heads=1,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                use_mla=True,
+                # MLA Args
+                q_lora_rank=self.q_lora_rank,
+                kv_lora_rank=self.kv_lora_rank,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
+                qk_head_dim=self.qk_head_dim,
                 v_head_dim=self.v_head_dim,
-                q_lora_rank=self.q_lora_rank,
-                kv_lora_rank=self.kv_lora_rank,
-                mla_modules=sfa_modules,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
            )
        else:
-            self.sfa_attn = MultiHeadLatentAttentionWrapper(
-                hidden_size=self.hidden_size,
+            self.sfa_attn = MLAAttention(
                 num_heads=self.num_local_heads,
                 scale=self.scaling,
                 qk_nope_head_dim=self.qk_nope_head_dim,
                 qk_rope_head_dim=self.qk_rope_head_dim,
                 v_head_dim=self.v_head_dim,
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
-                mla_modules=sfa_modules,
                 cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=prefix,
+                prefix=f"{prefix}.attn",
+                use_sparse=True,
+                indexer=self.indexer,
+                # MLA Args
+                rotary_emb=self.rotary_emb,
+                q_proj=self.q_proj if self.q_lora_rank is None else None,
+                q_b_proj=self.q_b_proj
+                if self.q_lora_rank is not None else None,
+                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+                kv_a_layernorm=self.kv_a_layernorm,
+                kv_b_proj=self.kv_b_proj,
+                o_proj=self.o_proj,
            )
 
    def forward(
@@ -917,9 +917,8 @@ def forward(
        output = torch.empty(output_shape,
            dtype=hidden_states.dtype,
            device=hidden_states.device)
-        self.sfa_attn.sfa_attn.impl.forward(hidden_states, kv_cache,
-            attn_metadata, need_gather_q_kv,
-            output)
+        self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
+            need_gather_q_kv, output)
        output = output.view(-1, output_shape[-1])
        return output
 
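
The constructor hunks above are the heart of the revert: the fused fused_qkv_a_proj (a MergedColumnParallelLinear whose output was split into the query latent and the compressed kv) is removed, the separate q_a_proj and kv_a_proj_with_mqa ReplicatedLinear layers come back, and the projection modules are passed directly to Attention / MLAAttention instead of being bundled into MLAModules. As a rough illustration of why the two layouts are interchangeable shape-wise, here is a minimal sketch using plain torch.nn.Linear stand-ins and made-up sizes; the real vLLM linear layers also return an (output, bias) tuple, which is why the diffs index with [0].

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real values come from the DeepseekV2 config.
    hidden_size, q_lora_rank = 32, 8
    kv_lora_rank, qk_rope_head_dim = 16, 4

    # After the revert: two separate down-projections, as in the hunks above.
    q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
    kv_a_proj_with_mqa = nn.Linear(hidden_size,
                                   kv_lora_rank + qk_rope_head_dim, bias=False)

    # Before the revert: one fused projection whose output was split afterwards.
    fused_qkv_a_proj = nn.Linear(
        hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

    x = torch.randn(2, hidden_size)

    q_c = q_a_proj(x)               # low-rank query latent
    kv = kv_a_proj_with_mqa(x)      # compressed kv + decoupled rope part
    kv_c, k_pe = kv.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

    q_c_fused, kv_fused = fused_qkv_a_proj(x).split(
        [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

    # Same output shapes either way; only the module layout (and weight
    # loading) differs between the refactor and this revert.
    assert q_c.shape == q_c_fused.shape and kv.shape == kv_fused.shape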

vllm_ascend/torchair/torchair_mla.py

Lines changed: 4 additions & 33 deletions

@@ -664,7 +664,6 @@ def __init__(
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -914,14 +913,7 @@ def exec_kv(
         B = hidden_states.shape[0]
         N = self.num_kv_heads
         S = 1
-        if self.fused_qkv_a_proj is not None:
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            _, kv = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-        else:
-            kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
@@ -950,14 +942,7 @@ def exec_kv_prefill(
         B = hidden_states.shape[0]
         N = self.num_kv_heads
         S = 1
-        if self.fused_qkv_a_proj is not None:
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            _, kv = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-        else:
-            kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
@@ -1120,23 +1105,9 @@ def forward(
         self.running_chunkprefilll_with_torchair = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if self.fused_qkv_a_proj is not None:
-                qkv_lora = self.fused_qkv_a_proj(
-                    hidden_states_or_kv_c_normed)[0]
-                _, kv = qkv_lora.split(
-                    [
-                        self.q_lora_rank,
-                        self.kv_lora_rank + self.qk_rope_head_dim
-                    ],
-                    dim=-1,
-                )
-                kv_c, k_pe = kv.split(
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
                     [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            else:
-                kv = self.kv_a_proj_with_mqa(hidden_states_or_kv_c_normed)[0]
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        else:
            kv_c_normed = hidden_states_or_kv_c_normed
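
With the fused projection gone from the torchair MLA backend, every path in the hunks above reduces to the same two steps: run kv_a_proj_with_mqa, then either split the result into the compressed kv and the decoupled rope part (forward) or reshape it to the [B, N, S, D] layout that the npu_kv_rmsnorm_rope_cache comment calls for (exec_kv / exec_kv_prefill). A minimal shape-only sketch with a plain torch.nn.Linear stand-in and illustrative sizes (the real ReplicatedLinear returns (output, bias), hence the [0] in the diff, and the real code applies kv_a_layernorm, an RMSNorm, to kv_c before caching):

    import torch
    import torch.nn as nn

    # Illustrative sizes; the real values come from the model config.
    hidden_size, kv_lora_rank, qk_rope_head_dim = 32, 16, 4
    num_kv_heads = 1  # MLA keeps a single latent kv head

    kv_a_proj_with_mqa = nn.Linear(hidden_size,
                                   kv_lora_rank + qk_rope_head_dim, bias=False)

    hidden_states = torch.randn(3, hidden_size)  # B = 3 tokens

    # forward(): split into compressed kv and rope positions.
    kv_c, k_pe = kv_a_proj_with_mqa(hidden_states).split(
        [kv_lora_rank, qk_rope_head_dim], dim=-1)

    # exec_kv() / exec_kv_prefill(): same projection, viewed as [B, N, S, D].
    B, S = hidden_states.shape[0], 1
    kv = kv_a_proj_with_mqa(hidden_states).view(
        B, num_kv_heads, S, kv_lora_rank + qk_rope_head_dim)

    print(kv_c.shape, k_pe.shape, kv.shape)
    # torch.Size([3, 16]) torch.Size([3, 4]) torch.Size([3, 1, 1, 20])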
