Commit c172a1a

Author: Jingchun Gao (committed)
1 parent: aedf9c4

rename and clean code

Signed-off-by: Jingchun Gao <[email protected]>

2 files changed: +53 -56 lines changed

vllm/utils/flashinfer.py

Lines changed: 2 additions & 1 deletion

@@ -265,7 +265,8 @@ def use_trtllm_attention(
     # Decode context parallel is not supported
     if dcp_world_size > 1:
         logger.warning_once(
-            "Trtllm not support lse, please use flash attention or FlashInfer backend."
+            "Trtllm does not support returning LSE and as a result "
+            "does not support DCP; reverting to FlashInfer."
        )
        return False
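Note the trailing space at the end of the first replacement literal: adjacent Python string literals concatenate with no implicit separator, so without it the logged message would read "...as a resultdoes not...". A minimal, self-contained illustration:

    # Adjacent string literals join with no whitespace between them.
    msg = (
        "Trtllm does not support returning LSE and as a result "
        "does not support DCP; reverting to FlashInfer."
    )
    assert "result does" in msg  # true only because of the trailing space above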

vllm/v1/attention/backends/flashinfer.py

Lines changed: 51 additions & 55 deletions
@@ -164,29 +164,6 @@ def trtllm_prefill_attn_kvfp8_dequant(
     return mock_kv_cache, mock_block_table


-@dataclass
-class BatchDCPPrefillPlanConfig:
-    """Parameters for BatchDCPPrefillWrapper.plan() method."""
-
-    qo_indptr_cpu: torch.Tensor
-    paged_kv_indptr_cpu: torch.Tensor
-    paged_kv_indices: torch.Tensor
-    paged_kv_last_page_len_cpu: torch.Tensor
-    prefill_start: int
-    page_size: int
-    num_qo_heads: int
-    dcp_world_size: int
-    num_kv_heads: int
-    head_dim: int
-    sm_scale: float
-    window_left: int
-    logits_soft_cap: float | None
-    q_data_type: torch.dtype
-    kv_cache_dtype: torch.dtype
-    prefill_fixed_split_size: int
-    disable_split_kv: bool
-
-
 class BatchDCPPrefillWrapper:
     def __init__(
         self,
@@ -199,38 +176,57 @@ def __init__(
             workspace_buffer, get_kv_cache_layout()
         )

-    def plan(self, cfg: BatchDCPPrefillPlanConfig):
+    def plan(
+        self,
+        qo_indptr_cpu: torch.Tensor,
+        paged_kv_indptr_cpu: torch.Tensor,
+        paged_kv_indices: torch.Tensor,
+        paged_kv_last_page_len_cpu: torch.Tensor,
+        prefill_start: int,
+        page_size: int,
+        num_qo_heads: int,
+        dcp_world_size: int,
+        num_kv_heads: int,
+        head_dim: int,
+        sm_scale: float,
+        window_left: int,
+        logits_soft_cap: float | None,
+        q_data_type: torch.dtype,
+        kv_cache_dtype: torch.dtype,
+        prefill_fixed_split_size: int,
+        disable_split_kv: bool,
+    ):
         """Plan the prefill operation with given parameters."""
         self._context.plan(
-            cfg.qo_indptr_cpu,
-            cfg.paged_kv_indptr_cpu,
-            cfg.paged_kv_indices,
-            cfg.paged_kv_last_page_len_cpu[cfg.prefill_start :],
-            cfg.num_qo_heads * cfg.dcp_world_size,
-            cfg.num_kv_heads,
-            cfg.head_dim,
-            cfg.page_size,
+            qo_indptr_cpu,
+            paged_kv_indptr_cpu,
+            paged_kv_indices,
+            paged_kv_last_page_len_cpu[prefill_start:],
+            num_qo_heads * dcp_world_size,
+            num_kv_heads,
+            head_dim,
+            page_size,
             causal=False,  # This is context run
-            sm_scale=cfg.sm_scale,
-            window_left=cfg.window_left,
-            logits_soft_cap=cfg.logits_soft_cap,
-            q_data_type=cfg.q_data_type,
-            kv_data_type=cfg.kv_cache_dtype,
-            fixed_split_size=cfg.prefill_fixed_split_size,
-            disable_split_kv=cfg.disable_split_kv,
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+            kv_data_type=kv_cache_dtype,
+            fixed_split_size=prefill_fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )
         self._new_tokens.plan(
-            qo_indptr=cfg.qo_indptr_cpu,
-            kv_indptr=cfg.qo_indptr_cpu,
-            num_qo_heads=cfg.num_qo_heads,
-            num_kv_heads=cfg.num_kv_heads,
-            head_dim_qk=cfg.head_dim,
-            head_dim_vo=cfg.head_dim,
+            qo_indptr=qo_indptr_cpu,
+            kv_indptr=qo_indptr_cpu,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim,
+            head_dim_vo=head_dim,
             causal=True,  # This is newtokens run
-            sm_scale=cfg.sm_scale,
-            window_left=cfg.window_left,
-            logits_soft_cap=cfg.logits_soft_cap,
-            q_data_type=cfg.q_data_type,
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
         )

     def run(
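The shape of this refactor, reduced to a runnable miniature (the class and field names below are illustrative, not vLLM's): a single-use config dataclass is folded into the method signature it existed to serve, removing one level of indirection and the cfg. prefix on every access.

    from dataclasses import dataclass

    # Before: callers build a config object only to unpack it inside plan().
    @dataclass
    class PlanConfig:
        page_size: int
        num_kv_heads: int

    class WrapperBefore:
        def plan(self, cfg: PlanConfig) -> int:
            return cfg.page_size * cfg.num_kv_heads

    # After: the same parameters are plain keyword arguments.
    class WrapperAfter:
        def plan(self, page_size: int, num_kv_heads: int) -> int:
            return page_size * num_kv_heads

    assert (
        WrapperBefore().plan(PlanConfig(page_size=16, num_kv_heads=8))
        == WrapperAfter().plan(page_size=16, num_kv_heads=8)
    )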
@@ -240,6 +236,7 @@ def run(
         kv_cache_permute: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        out: torch.Tensor,
     ):
         prefill_query_across_dcp = get_dcp_group().all_gather(
             prefill_query.contiguous(), dim=1
@@ -264,15 +261,14 @@ def run(
         )
         lse_query = lse_query.transpose(0, 1).contiguous()

-        output = torch.empty_like(prefill_query)
         merge_attn_states(
-            output,
+            out,
             output_context,
             lse_context,
             output_query,
             lse_query,
         )
-        return output
+        return out


 class FlashInferBackend(AttentionBackend):
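The new out parameter relies on PyTorch's view semantics: writing through a slice of a tensor mutates the parent storage in place, which is why the caller can pass output[num_decode_tokens:] directly and both the torch.empty_like() allocation and the copy-on-assignment it previously implied disappear. A small sketch of that mechanism (generic tensors, not the real attention buffers):

    import torch

    output = torch.zeros(6, 4)
    num_decode_tokens = 2

    # Slicing creates a view: no new storage is allocated.
    out_view = output[num_decode_tokens:]

    # Writing through the view lands directly in `output` rows 2..5.
    out_view.copy_(torch.ones(4, 4))

    assert output[:num_decode_tokens].eq(0).all()
    assert output[num_decode_tokens:].eq(1).all()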
@@ -847,7 +843,7 @@ def build(
             assert isinstance(
                 attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
             )
-            plan_cfgs = BatchDCPPrefillPlanConfig(
+            attn_metadata.prefill_wrapper.plan(
                 qo_indptr_cpu=qo_indptr_cpu,
                 paged_kv_indptr_cpu=paged_kv_indptr_cpu,
                 paged_kv_indices=paged_kv_indices,
@@ -866,7 +862,6 @@ def build(
                 prefill_fixed_split_size=self.prefill_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
-            attn_metadata.prefill_wrapper.plan(plan_cfgs)
         else:
             assert isinstance(
                 attn_metadata.prefill_wrapper,
@@ -1203,12 +1198,13 @@ def forward(
             assert prefill_wrapper._new_tokens._sm_scale == self.scale
             assert prefill_wrapper._new_tokens.causal

-            output[num_decode_tokens:] = prefill_wrapper.run(
+            prefill_wrapper.run(
                 layer,
                 prefill_query,
                 kv_cache_permute,
                 key[num_decode_tokens:],
                 value[num_decode_tokens:],
+                out=output[num_decode_tokens:],
             )
         else:
             assert isinstance(
