Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,7 +2107,7 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
pytest.skip("Test only applies to UNet.")

# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
Expand All @@ -2117,7 +2117,7 @@ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
pytest.skip("Test only applies to UNet.")

# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
Expand Down