@@ -252,6 +252,8 @@ def _test_reshape_and_cache_func(
         num_blocks: int,
         dtype: torch.dtype,
         seed: int,
+        key_is_contiguous: bool,
+        value_is_contiguous: bool,
     ) -> None:
         random.seed(seed)
         torch.random.manual_seed(seed)
@@ -264,6 +266,13 @@ def _test_reshape_and_cache_func(
 
         qkv = torch.randn(num_token, 3, num_head, head_size, dtype=dtype, device="cpu")
         _, key, value = qkv.unbind(dim=1)
+        if key.shape[0] != 1:
+            if not key_is_contiguous:
+                key = key.transpose(0, 1).contiguous()
+                key = key.transpose(0, 1)
+            if not value_is_contiguous:
+                value = value.transpose(0, 1).contiguous()
+                value = value.transpose(0, 1)
         # Create the KV caches.
         key_caches, value_caches = self.create_kv_caches(
             num_blocks, block_size, 1, num_head, head_size, dtype, seed
@@ -300,6 +309,8 @@ def test_reshape_and_cache(self):
         head_sizes = [64, 80, 128, 96, 112, 128, 256]
         block_sizes = [16, 32]
         dtypes = [torch.bfloat16, torch.float]
+        key_modes = [True, False]
+        value_modes = [True, False]
         if core.onednn_has_fp16_support():
             dtypes.append(torch.float16)
         seeds = [0]
@@ -310,16 +321,28 @@ def test_reshape_and_cache(self):
             block_size,
             dtype,
             seed,
+            key_is_contiguous,
+            value_is_contiguous,
         ) in product(
             num_tokens,
             num_kv_heads,
             head_sizes,
             block_sizes,
             dtypes,
             seeds,
+            key_modes,
+            value_modes,
         ):
             self._test_reshape_and_cache_func(
-                num_token, num_kv_head, head_size, block_size, num_blocks, dtype, seed
+                num_token,
+                num_kv_head,
+                head_size,
+                block_size,
+                num_blocks,
+                dtype,
+                seed,
+                key_is_contiguous,
+                value_is_contiguous,
             )
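For reference, the trick in the new branch works because transpose(0, 1).contiguous() materializes the tensor in transposed memory order, and transposing back restores the original logical shape while keeping the transposed strides, so the resulting tensor is non-contiguous. A minimal standalone sketch, assuming a (num_tokens, num_heads, head_size) layout; the shape values below are illustrative, not taken from the PR:

import torch

key = torch.randn(4, 8, 64)  # illustrative (num_tokens, num_heads, head_size)
print(key.is_contiguous())   # True

# Materialize the transposed layout, then transpose back: the logical
# shape is restored, but the strides still follow the transposed
# memory order, so the tensor is no longer contiguous.
key = key.transpose(0, 1).contiguous().transpose(0, 1)
print(key.shape)             # torch.Size([4, 8, 64])
print(key.is_contiguous())   # False

If dim 0 has size 1, the round trip leaves the tensor contiguous (PyTorch ignores the stride of a size-1 dim when checking contiguity), which is presumably why the diff guards the branch with key.shape[0] != 1.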