Skip to content

Commit e0bbb3e

Browse files
patrickvonplaten authored and sliard committed
[Dreambooth] Make compatible with alt diffusion (huggingface#1470)
* [Dreambooth] Make compatible with alt diffusion * make style * add example
1 parent c5e712f commit e0bbb3e

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

examples/dreambooth/README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,17 @@ accelerate launch train_dreambooth.py \
195195
--max_train_steps=800
196196
```
197197

198+
### Using DreamBooth for other pipelines than Stable Diffusion
199+
200+
AltDiffusion also supports DreamBooth now. The running command is basically the same as above; all you need to do is replace the `MODEL_NAME` like this:
201+
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
202+
203+
```
204+
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
205+
or
206+
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
207+
```
208+
198209
### Inference
199210

200211
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. sks in the above example) in your prompt.

examples/dreambooth/train_dreambooth.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,38 @@
1414
from accelerate import Accelerator
1515
from accelerate.logging import get_logger
1616
from accelerate.utils import set_seed
17-
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
17+
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
1818
from diffusers.optimization import get_scheduler
1919
from huggingface_hub import HfFolder, Repository, whoami
2020
from PIL import Image
2121
from torchvision import transforms
2222
from tqdm.auto import tqdm
23-
from transformers import CLIPTextModel, CLIPTokenizer
23+
from transformers import AutoTokenizer, PretrainedConfig
2424

2525

2626
# Module-level logger created with accelerate's `get_logger`, named after this module.
logger = get_logger(__name__)
2727

2828

29+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision=None):
    """Return the text-encoder class that matches a pretrained checkpoint.

    Only the configuration stored in the checkpoint's ``text_encoder``
    subfolder is loaded. The first entry of its ``architectures`` list is
    mapped to the corresponding model class, which is imported lazily so that
    only the needed implementation is pulled in.

    Args:
        pretrained_model_name_or_path: Forwarded unchanged to
            ``PretrainedConfig.from_pretrained``.
        revision: Optional revision forwarded to
            ``PretrainedConfig.from_pretrained``. When left as ``None``, the
            ``revision`` attribute of a module-level ``args`` object is used
            if such an object exists (this keeps the previous behaviour of
            reading the global ``args``); otherwise ``None`` is passed.

    Returns:
        The ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation``
        class (not an instance).

    Raises:
        ValueError: If the configuration declares no architecture, or the
            declared architecture is not one of the supported classes.
    """
    if revision is None:
        # Backward compatibility: earlier callers relied on a global `args`.
        # Look it up defensively so that calling this helper from another
        # module (where no such global exists) does not raise NameError.
        revision = getattr(globals().get("args"), "revision", None)

    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    architectures = getattr(text_encoder_config, "architectures", None)
    if not architectures:
        # Indexing a missing/empty list would otherwise fail with an opaque
        # TypeError or IndexError.
        raise ValueError(
            f"The text encoder config of {pretrained_model_name_or_path} does not declare any architecture."
        )
    model_class = architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")
47+
48+
2949
def parse_args(input_args=None):
3050
parser = argparse.ArgumentParser(description="Simple example of a training script.")
3151
parser.add_argument(
@@ -357,7 +377,7 @@ def main(args):
357377

358378
if cur_class_images < args.num_class_images:
359379
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
360-
pipeline = StableDiffusionPipeline.from_pretrained(
380+
pipeline = DiffusionPipeline.from_pretrained(
361381
args.pretrained_model_name_or_path,
362382
torch_dtype=torch_dtype,
363383
safety_checker=None,
@@ -407,19 +427,24 @@ def main(args):
407427

408428
# Load the tokenizer
409429
if args.tokenizer_name:
410-
tokenizer = CLIPTokenizer.from_pretrained(
430+
tokenizer = AutoTokenizer.from_pretrained(
411431
args.tokenizer_name,
412432
revision=args.revision,
433+
use_fast=False,
413434
)
414435
elif args.pretrained_model_name_or_path:
415-
tokenizer = CLIPTokenizer.from_pretrained(
436+
tokenizer = AutoTokenizer.from_pretrained(
416437
args.pretrained_model_name_or_path,
417438
subfolder="tokenizer",
418439
revision=args.revision,
440+
use_fast=False,
419441
)
420442

443+
# import correct text encoder class
444+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
445+
421446
# Load models and create wrapper for stable diffusion
422-
text_encoder = CLIPTextModel.from_pretrained(
447+
text_encoder = text_encoder_cls.from_pretrained(
423448
args.pretrained_model_name_or_path,
424449
subfolder="text_encoder",
425450
revision=args.revision,
@@ -649,7 +674,7 @@ def collate_fn(examples):
649674

650675
if global_step % args.save_steps == 0:
651676
if accelerator.is_main_process:
652-
pipeline = StableDiffusionPipeline.from_pretrained(
677+
pipeline = DiffusionPipeline.from_pretrained(
653678
args.pretrained_model_name_or_path,
654679
unet=accelerator.unwrap_model(unet),
655680
text_encoder=accelerator.unwrap_model(text_encoder),
@@ -669,7 +694,7 @@ def collate_fn(examples):
669694

670695
# Create the pipeline using the trained modules and save it.
671696
if accelerator.is_main_process:
672-
pipeline = StableDiffusionPipeline.from_pretrained(
697+
pipeline = DiffusionPipeline.from_pretrained(
673698
args.pretrained_model_name_or_path,
674699
unet=accelerator.unwrap_model(unet),
675700
text_encoder=accelerator.unwrap_model(text_encoder),

0 commit comments

Comments
 (0)