@@ -29,24 +29,22 @@ def is_diffusers_available():
2929if is_diffusers_available ():
3030 import torch
3131
32- from diffusers import AutoPipelineForText2Image , DPMSolverMultistepScheduler , StableDiffusionPipeline
32+ from diffusers import DiffusionPipeline
3333
3434
35- class SMAutoPipelineForText2Image :
35+ class SMDiffusionPipelineForText2Image :
36+
3637 def __init__ (self , model_dir : str , device : str = None ): # needs "cuda" for GPU
38+ self .pipeline = None
3739 dtype = torch .float32
3840 if device == "cuda" :
3941 dtype = torch .bfloat16 if is_torch_bf16_gpu_available () else torch .float16
40- device_map = "auto" if device == "cuda" else None
42+ if torch .cuda .device_count () > 1 :
43+ device_map = "balanced"
44+ self .pipeline = DiffusionPipeline .from_pretrained (model_dir , torch_dtype = dtype , device_map = device_map )
4145
42- self .pipeline = AutoPipelineForText2Image .from_pretrained (model_dir , torch_dtype = dtype , device_map = device_map )
43- # try to use DPMSolverMultistepScheduler
44- if isinstance (self .pipeline , StableDiffusionPipeline ):
45- try :
46- self .pipeline .scheduler = DPMSolverMultistepScheduler .from_config (self .pipeline .scheduler .config )
47- except Exception :
48- pass
49- self .pipeline .to (device )
46+ if not self .pipeline :
47+ self .pipeline = DiffusionPipeline .from_pretrained (model_dir , torch_dtype = dtype ).to (device )
5048
5149 def __call__ (
5250 self ,
@@ -64,7 +62,7 @@ def __call__(
6462
6563
6664DIFFUSERS_TASKS = {
67- "text-to-image" : SMAutoPipelineForText2Image ,
65+ "text-to-image" : SMDiffusionPipelineForText2Image ,
6866}
6967
7068
0 commit comments