Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .utils import (
is_accelerate_available,
is_flax_available,
is_inflect_available,
is_onnx_available,
Expand All @@ -17,13 +16,6 @@
from .utils import logging


# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
if is_torch_available() and not is_accelerate_available():
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
raise ImportError(error_msg)


if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
Expand Down
34 changes: 30 additions & 4 deletions src/diffusers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@
import torch
from torch import Tensor, device

import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from . import __version__
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
is_accelerate_available,
is_torch_version,
logging,
)


logger = logging.get_logger(__name__)
Expand All @@ -41,6 +46,12 @@
_LOW_CPU_MEM_USAGE_DEFAULT = False


if is_accelerate_available():
import accelerate
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version


def get_parameter_device(parameter: torch.nn.Module):
try:
return next(parameter.parameters()).device
Expand Down Expand Up @@ -319,6 +330,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)

if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)

# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import diffusers
import PIL
from accelerate.utils.versions import is_torch_version
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
Expand All @@ -43,6 +42,8 @@
WEIGHTS_NAME,
BaseOutput,
deprecate,
is_accelerate_available,
is_torch_version,
is_transformers_available,
logging,
)
Expand Down Expand Up @@ -397,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)

if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
is_scipy_available,
is_tf_available,
is_torch_available,
is_torch_version,
is_transformers_available,
is_unidecode_available,
requires_backends,
Expand Down
15 changes: 0 additions & 15 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading