@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
649649 )
650650
651651
652- def _test_processing_cache_correctness (
652+ def _test_processing_correctness (
653653 model_id : str ,
654654 modalities : dict [str , bool ],
655655 hit_rate : float ,
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
691691 baseline_processor = factories .build_processor (ctx , cache = None )
692692 cached_processor = factories .build_processor (ctx , cache = cache )
693693 dummy_inputs = baseline_processor .dummy_inputs
694+ tokenizer = baseline_processor .info .get_tokenizer ()
694695
695696 rng = np .random .RandomState (0 )
696697
@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
747748 )
748749
749750 assert baseline_result == cached_result , (
750- f"Failed ({ batch_idx = } , { mm_data = } )" )
751+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
752+
753+ baseline_tokenized_result = baseline_processor .apply (
754+ tokenizer .encode (prompt ),
755+ mm_data = mm_data ,
756+ hf_processor_mm_kwargs = {},
757+ )
758+
759+ assert baseline_result == baseline_tokenized_result , (
760+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
761+
762+ cached_tokenized_result = cached_processor .apply (
763+ tokenizer .encode (prompt ),
764+ mm_data = mm_data ,
765+ hf_processor_mm_kwargs = {},
766+ )
767+
768+ assert cached_result == cached_tokenized_result , (
769+ f"Failed ({ batch_idx = } , { prompt = } , { mm_data = } )" )
751770
752771
753772# yapf: disable
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
771790@pytest .mark .parametrize ("num_batches" , [32 ])
772791@pytest .mark .parametrize ("simplify_rate" , [1.0 ])
773792# yapf: enable
774- def test_processing_cache_correctness (
793+ def test_processing_correctness (
775794 model_id : str ,
776795 modalities : dict [str , bool ],
777796 hit_rate : float ,
778797 num_batches : int ,
779798 simplify_rate : float ,
780799):
781- _test_processing_cache_correctness (
800+ _test_processing_correctness (
782801 model_id ,
783802 modalities ,
784803 hit_rate = hit_rate ,
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
795814@pytest .mark .parametrize ("num_batches" , [32 ])
796815@pytest .mark .parametrize ("simplify_rate" , [1.0 ])
797816# yapf: enable
798- def test_processing_cache_correctness_phi3v (
817+ def test_processing_correctness_phi3v (
799818 model_id : str ,
800819 modalities : dict [str , bool ],
801820 hit_rate : float ,
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
809828
810829 AutoImageProcessor .from_pretrained (model_id , trust_remote_code = True )
811830
812- _test_processing_cache_correctness (
831+ _test_processing_correctness (
813832 model_id ,
814833 modalities ,
815834 hit_rate = hit_rate ,
0 commit comments