1313# limitations under the License.
1414
1515import gc
16- import itertools
1716
1817import pytest
1918import torch
2019from accelerate .utils .memory import release_memory
2120from datasets import load_dataset
22- from parameterized import parameterized
2321from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig
2422from transformers .testing_utils import (
2523 backend_empty_cache ,
@@ -61,7 +59,8 @@ def teardown_method(self):
6159 backend_empty_cache (torch_device )
6260 gc .collect ()
6361
64- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
62+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
63+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
6564 def test_sft_trainer_str (self , model_name , packing ):
6665 """
6766 Simply tests if passing a simple str to `SFTTrainer` loads and runs the trainer as expected.
@@ -85,7 +84,8 @@ def test_sft_trainer_str(self, model_name, packing):
8584
8685 trainer .train ()
8786
88- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
87+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
88+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
8989 def test_sft_trainer_transformers (self , model_name , packing ):
9090 """
9191 Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected.
@@ -115,7 +115,8 @@ def test_sft_trainer_transformers(self, model_name, packing):
115115
116116 release_memory (model , trainer )
117117
118- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
118+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
119+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
119120 @require_peft
120121 def test_sft_trainer_peft (self , model_name , packing ):
121122 """
@@ -151,7 +152,8 @@ def test_sft_trainer_peft(self, model_name, packing):
151152
152153 release_memory (model , trainer )
153154
154- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
155+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
156+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
155157 def test_sft_trainer_transformers_mp (self , model_name , packing ):
156158 """
157159 Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed
@@ -183,7 +185,9 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
183185
184186 release_memory (model , trainer )
185187
186- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS , GRADIENT_CHECKPOINTING_KWARGS )))
188+ @pytest .mark .parametrize ("gradient_checkpointing_kwargs" , GRADIENT_CHECKPOINTING_KWARGS )
189+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
190+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
187191 def test_sft_trainer_transformers_mp_gc (self , model_name , packing , gradient_checkpointing_kwargs ):
188192 """
189193 Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer as expected in mixed
@@ -217,7 +221,9 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
217221
218222 release_memory (model , trainer )
219223
220- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS , GRADIENT_CHECKPOINTING_KWARGS )))
224+ @pytest .mark .parametrize ("gradient_checkpointing_kwargs" , GRADIENT_CHECKPOINTING_KWARGS )
225+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
226+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
221227 @require_peft
222228 def test_sft_trainer_transformers_mp_gc_peft (self , model_name , packing , gradient_checkpointing_kwargs ):
223229 """
@@ -255,9 +261,10 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
255261
256262 release_memory (model , trainer )
257263
258- @parameterized .expand (
259- list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS , GRADIENT_CHECKPOINTING_KWARGS , DEVICE_MAP_OPTIONS ))
260- )
264+ @pytest .mark .parametrize ("device_map" , DEVICE_MAP_OPTIONS )
265+ @pytest .mark .parametrize ("gradient_checkpointing_kwargs" , GRADIENT_CHECKPOINTING_KWARGS )
266+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
267+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
261268 @require_torch_multi_accelerator
262269 def test_sft_trainer_transformers_mp_gc_device_map (
263270 self , model_name , packing , gradient_checkpointing_kwargs , device_map
@@ -294,7 +301,9 @@ def test_sft_trainer_transformers_mp_gc_device_map(
294301
295302 release_memory (model , trainer )
296303
297- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS , GRADIENT_CHECKPOINTING_KWARGS )))
304+ @pytest .mark .parametrize ("gradient_checkpointing_kwargs" , GRADIENT_CHECKPOINTING_KWARGS )
305+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
306+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
298307 @require_peft
299308 @require_bitsandbytes
300309 def test_sft_trainer_transformers_mp_gc_peft_qlora (self , model_name , packing , gradient_checkpointing_kwargs ):
@@ -335,7 +344,8 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
335344
336345 release_memory (model , trainer )
337346
338- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
347+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
348+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
339349 @require_peft
340350 @require_bitsandbytes
341351 def test_sft_trainer_with_chat_format_qlora (self , model_name , packing ):
@@ -375,7 +385,8 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
375385
376386 release_memory (model , trainer )
377387
378- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
388+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
389+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
379390 @require_liger_kernel
380391 def test_sft_trainer_with_liger (self , model_name , packing ):
381392 """
@@ -419,7 +430,8 @@ def cleanup_liger_patches(trainer):
419430 finally :
420431 cleanup_liger_patches (trainer )
421432
422- @parameterized .expand (list (itertools .product (MODELS_TO_TEST , PACKING_OPTIONS )))
433+ @pytest .mark .parametrize ("packing" , PACKING_OPTIONS )
434+ @pytest .mark .parametrize ("model_name" , MODELS_TO_TEST )
423435 @require_torch_accelerator
424436 def test_train_offloading (self , model_name , packing ):
425437 """Test that activation offloading works with SFTTrainer."""