src/diffusers/utils/testing_utils.py (+12, -0)
@@ -291,6 +291,18 @@ def decorator(test_case):
return decorator


def require_torch_version_greater(torch_version):
"""Decorator marking a test that requires torch with a specific version greater."""

def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
return unittest.skipUnless(
correct_torch_version, f"test requires torch with the version greater than {torch_version}"
)(test_case)

return decorator


def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
tests/quantization/bnb/test_4bit.py (+27, -0)
@@ -30,6 +30,7 @@
FluxTransformer2DModel,
SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -44,11 +45,14 @@
require_peft_backend,
require_torch,
require_torch_accelerator,
require_torch_version_greater,
require_transformers_version_greater,
slow,
torch_device,
)

from ..test_torch_compile_utils import QuantCompileMiscTests


def get_some_linear_layer(model):
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ def test_fp4_double_unsafe(self):

def test_fp4_double_safe(self):
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_version_greater("2.7.1")
class Bnb4BitCompileTests(QuantCompileMiscTests):
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["transformer", "text_encoder_2"],
)

def test_torch_compile(self):
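# allow dynamo to trace ops with data-dependent output shapes instead of graph-breaking on them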
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
tests/quantization/bnb/test_mixed_int8.py (+28, -0)
@@ -28,6 +28,7 @@
SD3Transformer2DModel,
logging,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
@@ -42,11 +43,14 @@
require_peft_version_greater,
require_torch,
require_torch_accelerator,
require_torch_version_greater_equal,
require_transformers_version_greater,
slow,
torch_device,
)

from ..test_torch_compile_utils import QuantCompileMiscTests


def get_some_linear_layer(model):
if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -773,3 +777,27 @@ def test_serialization_sharded(self):
out_0 = self.model_0(**inputs)[0]
out_1 = model_1(**inputs)[0]
self.assertTrue(torch.equal(out_0, out_1))


@require_torch_version_greater_equal("2.6.0")
class Bnb8BitCompileTests(QuantCompileMiscTests):
quantization_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_kwargs={"load_in_8bit": True},
components_to_quantize=["transformer", "text_encoder_2"],
)

def test_torch_compile(self):
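# allow dynamo to trace ops with data-dependent output shapes instead of graph-breaking on them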
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)

@pytest.mark.xfail(reason="Fails because of an offloading problem in Accelerate that confuses its hooks.")
def test_torch_compile_with_group_offload(self):
super()._test_torch_compile_with_group_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
tests/quantization/test_torch_compile_utils.py (+87, -0)
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileMiscTests(unittest.TestCase):
quantization_config = None

def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.compiler.reset()

def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.compiler.reset()

def _init_pipeline(self, quantization_config, torch_dtype):
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
quantization_config=quantization_config,
torch_dtype=torch_dtype,
)
return pipe

def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
# important: compile with fullgraph=True so any graph break surfaces as an error
pipe.transformer.compile(fullgraph=True)

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
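# compiling while model CPU offload is active exercises accelerate's offload hooks under torch.compile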
pipe.transformer.compile()

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
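# raise dynamo's recompile limit generously; offloading tends to trigger many recompilations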
torch._dynamo.config.cache_size_limit = 10000

pipe = self._init_pipeline(quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()
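# group offloading only manages the transformer; move every other module to the GPU manually
# since the pipeline itself is never moved with .to()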
for name, component in pipe.components.items():
if name != "transformer" and isinstance(component, torch.nn.Module):
if torch.device(component.device).type == "cpu":
component.to("cuda")

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)