@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
     # Next only keep the first 2 draft tokens same as the zero temperature
     # tokens. For the remaining 3 choose some other tokens. In the
     # response we will expect the first 2 tokens to be the same as the
-    # draft tokens and the rest as -1
+    # draft tokens, the third to be the recovered token, and the rest as -1
     draft_token_ids_to_replace = get_draft_token_ids(
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
+    assert torch.all(
+        output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
     assert torch.all(output_token_ids[:, -3:] == -1)
 
 
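The new assert checks that the first rejected position (index 2) now carries the recovered token rather than -1. A minimal standalone sketch of the output layout these asserts verify, using made-up shapes (batch_size=4, k=5, a small vocab) in place of the test's fixtures:

```python
import torch

# Hypothetical shapes standing in for the test's fixtures.
batch_size, k, vocab_size = 4, 5, 100
draft_token_ids = torch.randint(vocab_size, (batch_size, k))
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size)

# With the first 2 draft tokens accepted, position 2 holds the recovered
# token (the argmax of the target distribution at that position) and the
# remaining 3 of the k + 1 output slots are padded with -1.
recovered = target_with_bonus_probs.argmax(-1)[:, 2]
expected_layout = torch.cat(
    [
        draft_token_ids[:, :2],                        # accepted draft tokens
        recovered.unsqueeze(1),                        # recovered token
        -torch.ones(batch_size, 3, dtype=torch.long),  # rejected tail
    ],
    dim=1,
)
assert expected_layout.shape == (batch_size, k + 1)
```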
@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
 @pytest.mark.parametrize("seed", list(range(10)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_replacement_token_ids(seed: int, device: str):
+def test_get_recovered_token_ids(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's method for generating
     replacement token IDs.
 
-    This test verifies that the `_replacement_token_ids` method of the
+    This test verifies that the `_get_recovered_token_ids` method of the
     TypicalAcceptanceSampler correctly identifies the token IDs to be used
-    as replacements based on the target probability distribution.
+    as recovered token IDs based on the target probability distribution.
     Specifically, it ensures that the method correctly identifies the
     tokens with the highest probability for each sequence in the batch.
     """
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    expected_replacement_tokens = -torch.ones(
-        (batch_size, k), dtype=torch.long)
-    expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
-                                                     dim=1)
+    expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
     actual_replacement_tokens = (
         typical_acceptance_sampler._replacement_token_ids(target_probs))
-        typical_acceptance_sampler._replacement_token_ids(target_probs))
+        typical_acceptance_sampler._get_recovered_token_ids(target_probs))
     assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
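For reference, `_get_recovered_token_ids` is a private helper of vLLM's TypicalAcceptanceSampler; judging from this test, its contract is an argmax over the target distribution at every position. A hedged standalone sketch of that contract (the function name and body below are illustrative, not the library's implementation):

```python
import torch

def recovered_token_ids_sketch(target_probs: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for the private helper: for each position,
    pick the target token with the highest probability.

    target_probs: [batch_size, k, vocab_size] -> [batch_size, k]
    """
    return target_probs.argmax(dim=-1)

target_probs = torch.rand(4, 5, 100)
assert recovered_token_ids_sketch(target_probs).shape == (4, 5)
```

The removed lines show the old behavior: only the first position received a real replacement and the rest were padded with -1. The renamed helper returns a recovered token for every position, which the partial-acceptance asserts earlier in this diff rely on.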