|  | 
| 18 | 18 | from huggingface_hub.utils import validate_hf_hub_args | 
| 19 | 19 | 
 | 
| 20 | 20 | from ..configuration_utils import ConfigMixin | 
|  | 21 | +from ..models.controlnets import ControlNetUnionModel | 
| 21 | 22 | from ..utils import is_sentencepiece_available | 
| 22 | 23 | from .aura_flow import AuraFlowPipeline | 
| 23 | 24 | from .cogview3 import CogView3PlusPipeline | 
|  | 
| 28 | 29 |     StableDiffusionXLControlNetImg2ImgPipeline, | 
| 29 | 30 |     StableDiffusionXLControlNetInpaintPipeline, | 
| 30 | 31 |     StableDiffusionXLControlNetPipeline, | 
|  | 32 | +    StableDiffusionXLControlNetUnionImg2ImgPipeline, | 
|  | 33 | +    StableDiffusionXLControlNetUnionInpaintPipeline, | 
|  | 34 | +    StableDiffusionXLControlNetUnionPipeline, | 
| 31 | 35 | ) | 
| 32 | 36 | from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline | 
| 33 | 37 | from .flux import ( | 
|  | 
| 108 | 112 |         ("kandinsky3", Kandinsky3Pipeline), | 
| 109 | 113 |         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), | 
| 110 | 114 |         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), | 
|  | 115 | +        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline), | 
| 111 | 116 |         ("wuerstchen", WuerstchenCombinedPipeline), | 
| 112 | 117 |         ("cascade", StableCascadeCombinedPipeline), | 
| 113 | 118 |         ("lcm", LatentConsistencyModelPipeline), | 
|  | 
| 139 | 144 |         ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), | 
| 140 | 145 |         ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline), | 
| 141 | 146 |         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), | 
|  | 147 | +        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), | 
| 142 | 148 |         ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), | 
| 143 | 149 |         ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), | 
| 144 | 150 |         ("lcm", LatentConsistencyModelImg2ImgPipeline), | 
|  | 
| 158 | 164 |         ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), | 
| 159 | 165 |         ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline), | 
| 160 | 166 |         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), | 
|  | 167 | +        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline), | 
| 161 | 168 |         ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), | 
| 162 | 169 |         ("flux", FluxInpaintPipeline), | 
| 163 | 170 |         ("flux-controlnet", FluxControlNetInpaintPipeline), | 
| @@ -396,7 +403,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): | 
| 396 | 403 |         orig_class_name = config["_class_name"] | 
| 397 | 404 | 
 | 
| 398 | 405 |         if "controlnet" in kwargs: | 
| 399 |  | -            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") | 
|  | 406 | +            if isinstance(kwargs["controlnet"], ControlNetUnionModel): | 
|  | 407 | +                orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") | 
|  | 408 | +            else: | 
|  | 409 | +                orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") | 
| 400 | 410 |         if "enable_pag" in kwargs: | 
| 401 | 411 |             enable_pag = kwargs.pop("enable_pag") | 
| 402 | 412 |             if enable_pag: | 
| @@ -688,7 +698,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): | 
| 688 | 698 |         to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" | 
| 689 | 699 | 
 | 
| 690 | 700 |         if "controlnet" in kwargs: | 
| 691 |  | -            orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) | 
|  | 701 | +            if isinstance(kwargs["controlnet"], ControlNetUnionModel): | 
|  | 702 | +                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) | 
|  | 703 | +            else: | 
|  | 704 | +                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) | 
| 692 | 705 |         if "enable_pag" in kwargs: | 
| 693 | 706 |             enable_pag = kwargs.pop("enable_pag") | 
| 694 | 707 |             if enable_pag: | 
| @@ -985,7 +998,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): | 
| 985 | 998 |         to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" | 
| 986 | 999 | 
 | 
| 987 | 1000 |         if "controlnet" in kwargs: | 
| 988 |  | -            orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) | 
|  | 1001 | +            if isinstance(kwargs["controlnet"], ControlNetUnionModel): | 
|  | 1002 | +                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) | 
|  | 1003 | +            else: | 
|  | 1004 | +                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) | 
| 989 | 1005 |         if "enable_pag" in kwargs: | 
| 990 | 1006 |             enable_pag = kwargs.pop("enable_pag") | 
| 991 | 1007 |             if enable_pag: | 
|  | 
0 commit comments