@@ -164,29 +164,6 @@ def trtllm_prefill_attn_kvfp8_dequant(
164164 return mock_kv_cache , mock_block_table
165165
166166
167- @dataclass
168- class BatchDCPPrefillPlanConfig :
169- """Parameters for BatchDCPPrefillWrapper.plan() method."""
170-
171- qo_indptr_cpu : torch .Tensor
172- paged_kv_indptr_cpu : torch .Tensor
173- paged_kv_indices : torch .Tensor
174- paged_kv_last_page_len_cpu : torch .Tensor
175- prefill_start : int
176- page_size : int
177- num_qo_heads : int
178- dcp_world_size : int
179- num_kv_heads : int
180- head_dim : int
181- sm_scale : float
182- window_left : int
183- logits_soft_cap : float | None
184- q_data_type : torch .dtype
185- kv_cache_dtype : torch .dtype
186- prefill_fixed_split_size : int
187- disable_split_kv : bool
188-
189-
190167class BatchDCPPrefillWrapper :
191168 def __init__ (
192169 self ,
@@ -199,38 +176,57 @@ def __init__(
199176 workspace_buffer , get_kv_cache_layout ()
200177 )
201178
202- def plan (self , cfg : BatchDCPPrefillPlanConfig ):
179+ def plan (
180+ self ,
181+ qo_indptr_cpu : torch .Tensor ,
182+ paged_kv_indptr_cpu : torch .Tensor ,
183+ paged_kv_indices : torch .Tensor ,
184+ paged_kv_last_page_len_cpu : torch .Tensor ,
185+ prefill_start : int ,
186+ page_size : int ,
187+ num_qo_heads : int ,
188+ dcp_world_size : int ,
189+ num_kv_heads : int ,
190+ head_dim : int ,
191+ sm_scale : float ,
192+ window_left : int ,
193+ logits_soft_cap : float | None ,
194+ q_data_type : torch .dtype ,
195+ kv_cache_dtype : torch .dtype ,
196+ prefill_fixed_split_size : int ,
197+ disable_split_kv : bool ,
198+ ):
203199 """Plan the prefill operation with given parameters."""
204200 self ._context .plan (
205- cfg . qo_indptr_cpu ,
206- cfg . paged_kv_indptr_cpu ,
207- cfg . paged_kv_indices ,
208- cfg . paged_kv_last_page_len_cpu [cfg . prefill_start :],
209- cfg . num_qo_heads * cfg . dcp_world_size ,
210- cfg . num_kv_heads ,
211- cfg . head_dim ,
212- cfg . page_size ,
201+ qo_indptr_cpu ,
202+ paged_kv_indptr_cpu ,
203+ paged_kv_indices ,
204+ paged_kv_last_page_len_cpu [prefill_start :],
205+ num_qo_heads * dcp_world_size ,
206+ num_kv_heads ,
207+ head_dim ,
208+ page_size ,
213209 causal = False , # This is context run
214- sm_scale = cfg . sm_scale ,
215- window_left = cfg . window_left ,
216- logits_soft_cap = cfg . logits_soft_cap ,
217- q_data_type = cfg . q_data_type ,
218- kv_data_type = cfg . kv_cache_dtype ,
219- fixed_split_size = cfg . prefill_fixed_split_size ,
220- disable_split_kv = cfg . disable_split_kv ,
210+ sm_scale = sm_scale ,
211+ window_left = window_left ,
212+ logits_soft_cap = logits_soft_cap ,
213+ q_data_type = q_data_type ,
214+ kv_data_type = kv_cache_dtype ,
215+ fixed_split_size = prefill_fixed_split_size ,
216+ disable_split_kv = disable_split_kv ,
221217 )
222218 self ._new_tokens .plan (
223- qo_indptr = cfg . qo_indptr_cpu ,
224- kv_indptr = cfg . qo_indptr_cpu ,
225- num_qo_heads = cfg . num_qo_heads ,
226- num_kv_heads = cfg . num_kv_heads ,
227- head_dim_qk = cfg . head_dim ,
228- head_dim_vo = cfg . head_dim ,
219+ qo_indptr = qo_indptr_cpu ,
220+ kv_indptr = qo_indptr_cpu ,
221+ num_qo_heads = num_qo_heads ,
222+ num_kv_heads = num_kv_heads ,
223+ head_dim_qk = head_dim ,
224+ head_dim_vo = head_dim ,
229225 causal = True , # This is newtokens run
230- sm_scale = cfg . sm_scale ,
231- window_left = cfg . window_left ,
232- logits_soft_cap = cfg . logits_soft_cap ,
233- q_data_type = cfg . q_data_type ,
226+ sm_scale = sm_scale ,
227+ window_left = window_left ,
228+ logits_soft_cap = logits_soft_cap ,
229+ q_data_type = q_data_type ,
234230 )
235231
236232 def run (
@@ -240,6 +236,7 @@ def run(
240236 kv_cache_permute : torch .Tensor ,
241237 key : torch .Tensor ,
242238 value : torch .Tensor ,
239+ out : torch .Tensor ,
243240 ):
244241 prefill_query_across_dcp = get_dcp_group ().all_gather (
245242 prefill_query .contiguous (), dim = 1
@@ -264,15 +261,14 @@ def run(
264261 )
265262 lse_query = lse_query .transpose (0 , 1 ).contiguous ()
266263
267- output = torch .empty_like (prefill_query )
268264 merge_attn_states (
269- output ,
265+ out ,
270266 output_context ,
271267 lse_context ,
272268 output_query ,
273269 lse_query ,
274270 )
275- return output
271+ return out
276272
277273
278274class FlashInferBackend (AttentionBackend ):
@@ -847,7 +843,7 @@ def build(
847843 assert isinstance (
848844 attn_metadata .prefill_wrapper , BatchDCPPrefillWrapper
849845 )
850- plan_cfgs = BatchDCPPrefillPlanConfig (
846+ attn_metadata . prefill_wrapper . plan (
851847 qo_indptr_cpu = qo_indptr_cpu ,
852848 paged_kv_indptr_cpu = paged_kv_indptr_cpu ,
853849 paged_kv_indices = paged_kv_indices ,
@@ -866,7 +862,6 @@ def build(
866862 prefill_fixed_split_size = self .prefill_fixed_split_size ,
867863 disable_split_kv = self .disable_split_kv ,
868864 )
869- attn_metadata .prefill_wrapper .plan (plan_cfgs )
870865 else :
871866 assert isinstance (
872867 attn_metadata .prefill_wrapper ,
@@ -1203,12 +1198,13 @@ def forward(
12031198 assert prefill_wrapper ._new_tokens ._sm_scale == self .scale
12041199 assert prefill_wrapper ._new_tokens .causal
12051200
1206- output [ num_decode_tokens :] = prefill_wrapper .run (
1201+ prefill_wrapper .run (
12071202 layer ,
12081203 prefill_query ,
12091204 kv_cache_permute ,
12101205 key [num_decode_tokens :],
12111206 value [num_decode_tokens :],
1207+ out = output [num_decode_tokens :],
12121208 )
12131209 else :
12141210 assert isinstance (