Skip to content

Commit 8f1dc00

Browse files
author
EC2 Default User
committed
Created more generic pipeline for text-to-image task
1 parent 9923001 commit 8f1dc00

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,22 @@ def is_diffusers_available():
2929
if is_diffusers_available():
3030
import torch
3131

32-
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline
32+
from diffusers import DiffusionPipeline
3333

3434

35-
class SMAutoPipelineForText2Image:
35+
class DiffusionPipelineForText2Image:
36+
3637
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
38+
self.pipeline = None
3739
dtype = torch.float32
3840
if device == "cuda":
3941
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
40-
device_map = "auto" if device == "cuda" else None
42+
if torch.cuda.device_count() > 1:
43+
device_map = "balanced"
44+
self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
4145

42-
self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
43-
# try to use DPMSolverMultistepScheduler
44-
if isinstance(self.pipeline, StableDiffusionPipeline):
45-
try:
46-
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
47-
except Exception:
48-
pass
49-
self.pipeline.to(device)
46+
if not self.pipeline:
47+
self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device)
5048

5149
def __call__(
5250
self,
@@ -64,7 +62,7 @@ def __call__(
6462

6563

6664
DIFFUSERS_TASKS = {
67-
"text-to-image": SMAutoPipelineForText2Image,
65+
"text-to-image": DiffusionPipelineForText2Image,
6866
}
6967

7068

0 commit comments

Comments
 (0)