1414from accelerate import Accelerator
1515from accelerate .logging import get_logger
1616from accelerate .utils import set_seed
17- from diffusers import AutoencoderKL , DDPMScheduler , StableDiffusionPipeline , UNet2DConditionModel
17+ from diffusers import AutoencoderKL , DDPMScheduler , DiffusionPipeline , UNet2DConditionModel
1818from diffusers .optimization import get_scheduler
1919from huggingface_hub import HfFolder , Repository , whoami
2020from PIL import Image
2121from torchvision import transforms
2222from tqdm .auto import tqdm
23- from transformers import CLIPTextModel , CLIPTokenizer
23+ from transformers import AutoTokenizer , PretrainedConfig
2424
2525
2626logger = get_logger (__name__ )
2727
2828
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: "str | None" = None):
    """Return the text-encoder class declared by a pretrained model's config.

    The configuration stored in the ``text_encoder`` subfolder of
    ``pretrained_model_name_or_path`` is loaded and the first entry of its
    ``architectures`` list selects which class to import and return.

    Args:
        pretrained_model_name_or_path: Model identifier or local path that is
            forwarded to ``PretrainedConfig.from_pretrained``.
        revision: Revision forwarded to ``PretrainedConfig.from_pretrained``.
            When omitted, the ``revision`` attribute of a module-level ``args``
            object is used if one exists (the behavior existing one-argument
            callers rely on); otherwise ``None`` is passed.

    Returns:
        ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation`` (the
        class itself, not an instance).

    Raises:
        ValueError: If the config declares no architecture, or declares one
            that is not handled here.
    """
    if revision is None:
        # Earlier versions read `args.revision` straight from module globals,
        # which raised NameError whenever the script's `args` global was not
        # set (e.g. `main` driven from another module). Keep that lookup as a
        # fallback so one-argument callers behave as before.
        revision = getattr(globals().get("args"), "revision", None)

    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    # `architectures` may be missing or empty; fail with a clear message
    # instead of an opaque TypeError/IndexError from indexing it.
    architectures = getattr(text_encoder_config, "architectures", None)
    if not architectures:
        raise ValueError(
            f"The text encoder config of {pretrained_model_name_or_path} does not declare any architecture."
        )
    model_class = architectures[0]

    # The imports stay local so only the selected encoder's module is loaded.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")
47+
48+
2949def parse_args (input_args = None ):
3050 parser = argparse .ArgumentParser (description = "Simple example of a training script." )
3151 parser .add_argument (
@@ -357,7 +377,7 @@ def main(args):
357377
358378 if cur_class_images < args .num_class_images :
359379 torch_dtype = torch .float16 if accelerator .device .type == "cuda" else torch .float32
360- pipeline = StableDiffusionPipeline .from_pretrained (
380+ pipeline = DiffusionPipeline .from_pretrained (
361381 args .pretrained_model_name_or_path ,
362382 torch_dtype = torch_dtype ,
363383 safety_checker = None ,
@@ -407,19 +427,24 @@ def main(args):
407427
408428 # Load the tokenizer
409429 if args .tokenizer_name :
410- tokenizer = CLIPTokenizer .from_pretrained (
430+ tokenizer = AutoTokenizer .from_pretrained (
411431 args .tokenizer_name ,
412432 revision = args .revision ,
433+ use_fast = False ,
413434 )
414435 elif args .pretrained_model_name_or_path :
415- tokenizer = CLIPTokenizer .from_pretrained (
436+ tokenizer = AutoTokenizer .from_pretrained (
416437 args .pretrained_model_name_or_path ,
417438 subfolder = "tokenizer" ,
418439 revision = args .revision ,
440+ use_fast = False ,
419441 )
420442
443+ # import correct text encoder class
444+ text_encoder_cls = import_model_class_from_model_name_or_path (args .pretrained_model_name_or_path )
445+
421446 # Load models and create wrapper for stable diffusion
422- text_encoder = CLIPTextModel .from_pretrained (
447+ text_encoder = text_encoder_cls .from_pretrained (
423448 args .pretrained_model_name_or_path ,
424449 subfolder = "text_encoder" ,
425450 revision = args .revision ,
@@ -649,7 +674,7 @@ def collate_fn(examples):
649674
650675 if global_step % args .save_steps == 0 :
651676 if accelerator .is_main_process :
652- pipeline = StableDiffusionPipeline .from_pretrained (
677+ pipeline = DiffusionPipeline .from_pretrained (
653678 args .pretrained_model_name_or_path ,
654679 unet = accelerator .unwrap_model (unet ),
655680 text_encoder = accelerator .unwrap_model (text_encoder ),
@@ -669,7 +694,7 @@ def collate_fn(examples):
669694
670695 # Create the pipeline using the trained modules and save it.
671696 if accelerator .is_main_process :
672- pipeline = StableDiffusionPipeline .from_pretrained (
697+ pipeline = DiffusionPipeline .from_pretrained (
673698 args .pretrained_model_name_or_path ,
674699 unet = accelerator .unwrap_model (unet ),
675700 text_encoder = accelerator .unwrap_model (text_encoder ),
0 commit comments