     UNet2DConditionModel,
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig


 def shave_segments(path, n_shave_prefix_segments=1):
@@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model


+def convert_paint_by_example_checkpoint(checkpoint):
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    model = PaintByExampleImageEncoder(config)
+
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+    # load clip vision
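+    # (despite the `text_model_dict` name, these are the weights of the CLIP *image* encoder:
+    # Paint-by-Example conditions on a reference image rather than a text prompt)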
+    model.model.load_state_dict(text_model_dict)
+
+    # load mapper
+    keys_mapper = {
+        k[len("cond_stage_model.mapper.res") :]: v
+        for k, v in checkpoint.items()
+        if k.startswith("cond_stage_model.mapper")
+    }
+
+    MAPPING = {
+        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
+        "attn.c_proj": ["attn1.to_out.0"],
+        "ln_1": ["norm1"],
+        "ln_2": ["norm3"],
+        "mlp.c_fc": ["ff.net.0.proj"],
+        "mlp.c_proj": ["ff.net.2"],
+    }
+
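+    # mapper keys look like "blocks.<i>.<name>.<weight|bias>"; MAPPING translates each <name>
+    # to its diffusers counterpart(s). Fused c_qkv tensors are split along dim 0 into equal
+    # chunks for the separate to_q / to_k / to_v projections.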
+    mapped_weights = {}
+    for key, value in keys_mapper.items():
+        prefix = key[: len("blocks.i")]
+        suffix = key.split(prefix)[-1].split(".")[-1]
+        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
+        mapped_names = MAPPING[name]
+
+        num_splits = len(mapped_names)
+        for i, mapped_name in enumerate(mapped_names):
+            new_name = ".".join([prefix, mapped_name, suffix])
+            shape = value.shape[0] // num_splits
+            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
+
+    model.mapper.load_state_dict(mapped_weights)
+
+    # load final layer norm
+    model.final_layer_norm.load_state_dict(
+        {
+            "bias": checkpoint["cond_stage_model.final_ln.bias"],
+            "weight": checkpoint["cond_stage_model.final_ln.weight"],
+        }
+    )
+
+    # load final proj
+    model.proj_out.load_state_dict(
+        {
+            "bias": checkpoint["proj_out.bias"],
+            "weight": checkpoint["proj_out.weight"],
+        }
+    )
+
+    # load uncond vector
+    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
+    return model
+
+
 def convert_open_clip_checkpoint(checkpoint):
     text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

@@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint):
         type=str,
         help="The YAML config file corresponding to the original architecture.",
     )
+    parser.add_argument(
+        "--num_in_channels",
+        default=None,
+        type=int,
+        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
+    )
     parser.add_argument(
         "--scheduler_type",
         default="pndm",
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
     )
+    parser.add_argument(
+        "--pipeline_type",
+        default=None,
+        type=str,
+        help="The pipeline type. If `None`, the pipeline type will be automatically inferred.",
+    )
     parser.add_argument(
         "--image_size",
         default=None,
@@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint):

     original_config = OmegaConf.load(args.original_config_file)

+    if args.num_in_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
+
     if (
         "parameterization" in original_config["model"]["params"]
         and original_config["model"]["params"]["parameterization"] == "v"
@@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint):
     vae.load_state_dict(converted_vae_checkpoint)

     # Convert the text model.
-    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenOpenCLIPEmbedder":
+    model_type = args.pipeline_type
+    if model_type is None:
+        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+
+    if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
         pipe = StableDiffusionPipeline(
@@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint):
             feature_extractor=None,
             requires_safety_checker=False,
         )
-    elif text_model_type == "FrozenCLIPEmbedder":
+    elif model_type == "PaintByExample":
+        vision_model = convert_paint_by_example_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        pipe = PaintByExamplePipeline(
+            vae=vae,
+            image_encoder=vision_model,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=feature_extractor,
+        )
+    elif model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
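
For reference, a minimal sketch of how a checkpoint converted through the new `PaintByExample` branch might be used after the script saves it to a local folder. The folder name and image files below are placeholders (not part of this commit), and the call follows the `PaintByExamplePipeline` API exposed by diffusers:

import PIL.Image
import torch

from diffusers import PaintByExamplePipeline

# load the converted pipeline from the directory the conversion script dumped it to (placeholder path)
pipe = PaintByExamplePipeline.from_pretrained("./paint-by-example-converted", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# source image, binary mask of the region to replace, and the reference (example) image
init_image = PIL.Image.open("scene.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("mask.png").convert("RGB").resize((512, 512))
example_image = PIL.Image.open("reference.png").convert("RGB").resize((512, 512))

# the masked region of `init_image` is re-painted to resemble `example_image`
result = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
result.save("output.png")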