Skip to content

Commit 7a311c6

Browse files
patrickvonplatenPrathik Rao
authored andcommitted
[DeviceMap] Make sure stable diffusion can be loaded from older trans… (huggingface#860)
[DeviceMap] Make sure stable diffusion can be loaded from older transformers versiosn
1 parent 0671472 commit 7a311c6

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import diffusers
2727
import PIL
2828
from huggingface_hub import snapshot_download
29+
from packaging import version
2930
from PIL import Image
3031
from tqdm.auto import tqdm
3132

@@ -45,6 +46,7 @@
4546

4647

4748
if is_transformers_available():
49+
import transformers
4850
from transformers import PreTrainedModel
4951

5052

@@ -508,11 +510,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
508510
loading_kwargs["provider"] = provider
509511
loading_kwargs["sess_options"] = sess_options
510512

511-
if (
512-
issubclass(class_obj, diffusers.ModelMixin)
513-
or is_transformers_available()
513+
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
514+
is_transformers_model = (
515+
is_transformers_available()
514516
and issubclass(class_obj, PreTrainedModel)
515-
):
517+
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
518+
)
519+
520+
if is_diffusers_model or is_transformers_model:
516521
loading_kwargs["device_map"] = device_map
517522

518523
# check if the module is in a subdirectory

0 commit comments

Comments
 (0)