diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md
index 70b9821553de..aa1aa927c5e9 100644
--- a/examples/dreambooth/README_sdxl.md
+++ b/examples/dreambooth/README_sdxl.md
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checkin
 ## Running on a free-tier Colab Notebook
 
 Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
+
+## Conducting EDM-style training
+
+It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
+
+For the SDXL model, simply set:
+
+```diff
++ --do_edm_style_training \
+```
+
+Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
+
+```bash
+accelerate launch train_dreambooth_lora_sdxl.py \
+  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
+  --instance_data_dir="dog" \
+  --output_dir="dog-playground-lora" \
+  --mixed_precision="fp16" \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --learning_rate=1e-4 \
+  --use_8bit_adam \
+  --report_to="wandb" \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=500 \
+  --validation_prompt="A photo of sks dog in a bucket" \
+  --validation_epochs=25 \
+  --seed="0" \
+  --push_to_hub
+```
+
+> [!CAUTION]
+> Min-SNR gamma is not supported with EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any "variant".
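+
+After training, the LoRA can be loaded for inference. Below is a minimal sketch (the repository id is a placeholder for wherever `--push_to_hub` uploaded your weights):
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+# Load the base model the LoRA was trained on.
+pipeline = DiffusionPipeline.from_pretrained(
+    "playgroundai/playground-v2.5-1024px-aesthetic", torch_dtype=torch.float16
+).to("cuda")
+
+# Load the DreamBooth LoRA weights (placeholder repo id).
+pipeline.load_lora_weights("your-username/dog-playground-lora")
+
+image = pipeline("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
+image.save("sks_dog.png")
+```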
diff --git a/examples/dreambooth/test_dreambooth_lora_edm.py b/examples/dreambooth/test_dreambooth_lora_edm.py
new file mode 100644
index 000000000000..0f6b3674b8b3
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_edm.py
@@ -0,0 +1,99 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
+    def test_dreambooth_lora_sdxl_with_edm(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --do_edm_style_training
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
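+    # Note: `--do_edm_style_training` is deliberately not passed here. The training script reads the
+    # scheduler class from the model's `model_index.json` and enables EDM-style training automatically
+    # when it finds an EDM scheduler (see `determine_scheduler_type` in `train_dreambooth_lora_sdxl.py`).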
+    def test_dreambooth_lora_playground(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 849becae4a61..f4d0efc411fd 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -14,8 +14,10 @@
 # See the License for the specific language governing permissions and
 
 import argparse
+import contextlib
 import gc
 import itertools
+import json
 import logging
 import math
 import os
@@ -32,7 +34,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
 from peft import LoraConfig, set_peft_model_state_dict
@@ -50,6 +52,8 @@
     AutoencoderKL,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
+    EDMEulerScheduler,
+    EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
@@ -76,6 +80,20 @@
 logger = get_logger(__name__)
 
 
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+    model_index_filename = "model_index.json"
+    if os.path.isdir(pretrained_model_name_or_path):
+        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+    else:
+        model_index = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+        )
+
+    with open(model_index, "r") as f:
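+        # Each entry in `model_index.json` is a `["library", "ClassName"]` pair, so index 1 of the
+        # "scheduler" entry is the scheduler's class name (e.g. "EDMEulerScheduler").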
+        scheduler_type = json.load(f)["scheduler"][1]
+    return scheduler_type
+
+
 def save_model_card(
     repo_id: str,
     images=None,
@@ -95,7 +113,7 @@ def save_model_card(
     )
 
     model_description = f"""
-# SDXL LoRA DreamBooth - {repo_id}
+# {'SDXL' if 'playgroundai' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
 
@@ -119,11 +137,17 @@ def save_model_card(
 
 [Download]({repo_id}/tree/main) them in the Files & versions tab.
 
+"""
+    if "playgroundai" in args.pretrained_model_name_or_path:
+        model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
 """
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
-        license="openrail++",
+        license="openrail++" if "playgroundai" not in base_model else "playground-v2dot5-community",
         base_model=base_model,
         prompt=instance_prompt,
         model_description=model_description,
@@ -131,15 +155,17 @@ def save_model_card(
     )
     tags = [
         "text-to-image",
-        "stable-diffusion-xl",
-        "stable-diffusion-xl-diffusers",
         "text-to-image",
         "diffusers",
         "lora",
         "template:sd-lora",
     ]
-    model_card = populate_model_card(model_card, tags=tags)
+    if "playgroundai" in base_model:
+        tags.extend(["playground", "playground-diffusers"])
+    else:
+        tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
 
+    model_card = populate_model_card(model_card, tags=tags)
     model_card.save(os.path.join(repo_folder, "README.md"))
@@ -159,23 +185,29 @@ def log_validation(
     # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
     scheduler_args = {}
 
-    if "variance_type" in pipeline.scheduler.config:
-        variance_type = pipeline.scheduler.config.variance_type
+    if not args.do_edm_style_training:
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
 
-        if variance_type in ["learned", "learned_range"]:
-            variance_type = "fixed_small"
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
 
-        scheduler_args["variance_type"] = variance_type
+            scheduler_args["variance_type"] = variance_type
 
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+    inference_ctx = (
+        contextlib.nullcontext() if "playgroundai" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
+    )
 
-    with torch.cuda.amp.autocast():
+    with inference_ctx:
         images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
 
     for tracker in accelerator.trackers:
@@ -334,6 +366,12 @@ def parse_args(input_args=None):
             " `args.validation_prompt` multiple times: `args.num_validation_images`."
         ),
     )
+    parser.add_argument(
+        "--do_edm_style_training",
+        default=False,
+        action="store_true",
+        help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+    )
     parser.add_argument(
         "--with_prior_preservation",
        default=False,
@@ -905,6 +943,9 @@ def main(args):
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
 
+    if args.do_edm_style_training and args.snr_gamma is not None:
+        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1018,7 +1059,19 @@ def main(args):
     )
 
     # Load scheduler and models
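+    # NOTE: `EDMEulerScheduler` exposes `precondition_inputs`/`precondition_outputs` for the EDM
+    # formulation; with `EulerDiscreteScheduler` the equivalent preconditioning is applied manually
+    # in the training loop below.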
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+    if "EDM" in scheduler_type:
+        args.do_edm_style_training = True
+        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        logger.info("Performing EDM-style training!")
+    elif args.do_edm_style_training:
+        noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="scheduler"
+        )
+        logger.info("Performing EDM-style training!")
+    else:
+        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
@@ -1036,6 +1089,12 @@ def main(args):
         revision=args.revision,
         variant=args.variant,
     )
+    latents_mean = latents_std = None
+    if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+    if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
@@ -1433,7 +1492,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
+        tracker_name = (
+            "dreambooth-lora-sd-xl"
+            if "playgroundai" not in args.pretrained_model_name_or_path
+            else "dreambooth-lora-playground"
+        )
+        accelerator.init_trackers(tracker_name, config=vars(args))
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1485,6 +1549,18 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         disable=not accelerator.is_local_main_process,
     )
 
+    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+        timesteps = timesteps.to(accelerator.device)
+
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+        return sigma
+
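+    # For a batch of `bsz` timesteps, `get_sigmas(timesteps, n_dim=4)` returns a `(bsz, 1, 1, 1)`
+    # tensor of noise levels that broadcasts against the `(bsz, C, H, W)` latents.
+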
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
@@ -1512,22 +1588,46 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
-                if args.pretrained_vae_model_name_or_path is None:
-                    model_input = model_input.to(weight_dtype)
+
+                if latents_mean is None and latents_std is None:
+                    model_input = model_input * vae.config.scaling_factor
+                    if args.pretrained_vae_model_name_or_path is None:
+                        model_input = model_input.to(weight_dtype)
+                else:
+                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+                    model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
                 bsz = model_input.shape[0]
+
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                )
-                timesteps = timesteps.long()
+                if not args.do_edm_style_training:
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                    )
+                    timesteps = timesteps.long()
+                else:
+                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+                    # instead of discrete timesteps, so here we sample indices to get the noise levels
+                    # from `scheduler.timesteps`
+                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+                # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
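+                # Concretely, the input scaling is c_in(sigma) = 1 / sqrt(sigma^2 + sigma_data^2)
+                # (Table 1 of the paper); the manual branch below assumes sigma_data = 1.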
+                if args.do_edm_style_training:
+                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+                    if "EDM" in scheduler_type:
+                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+                    else:
+                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
 
                 # time ids
                 add_time_ids = torch.cat(
@@ -1551,7 +1651,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
@@ -1570,18 +1670,43 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
+                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
                         return_dict=False,
                     )[0]
 
+                weighting = None
+                if args.do_edm_style_training:
+                    # Similar to the input preconditioning, the model predictions are also preconditioned
+                    # on noised model inputs (before preconditioning) and the sigmas.
+                    # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                    if "EDM" in scheduler_type:
+                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+                    else:
+                        if noise_scheduler.config.prediction_type == "epsilon":
+                            model_pred = model_pred * (-sigmas) + noisy_model_input
+                        elif noise_scheduler.config.prediction_type == "v_prediction":
+                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+                                noisy_model_input / (sigmas**2 + 1)
+                            )
+                    # We are not doing weighting here (i.e., for the EDM scheduler) because it tends to
+                    # result in numerical problems.
+                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+                    # There might be other alternatives for weighting as well:
+                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+                    if "EDM" not in scheduler_type:
+                        weighting = (sigmas**-2.0).float()
+
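+                # Because the EDM branch above preconditions `model_pred` into an estimate of the
+                # clean latents, the EDM-style regression target below is `model_input` itself rather
+                # than the sampled noise (or velocity).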
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
-                    target = noise
+                    target = model_input if args.do_edm_style_training else noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    target = (
+                        model_input
+                        if args.do_edm_style_training
+                        else noise_scheduler.get_velocity(model_input, noise, timesteps)
+                    )
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
@@ -1591,10 +1716,28 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+                    if weighting is not None:
+                        prior_loss = torch.mean(
+                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+                                target_prior.shape[0], -1
+                            ),
+                            1,
+                        )
+                        prior_loss = prior_loss.mean()
+                    else:
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                 if args.snr_gamma is None:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                    if weighting is not None:
+                        loss = torch.mean(
+                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
+                                target.shape[0], -1
+                            ),
+                            1,
+                        )
+                        loss = loss.mean()
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 else:
                     # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
@@ -1696,7 +1839,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 variant=args.variant,
                 torch_dtype=weight_dtype,
             )
-
             pipeline_args = {"prompt": args.validation_prompt}
 
             images = log_validation(