diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index df5477d0d643..f4b4e42c8b19 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -133,7 +133,7 @@ def save_model_card(
         diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 """
-        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+        diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
 pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
@@ -145,8 +145,7 @@ def save_model_card(
 to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """
 
-    yaml = f"""
----
+    yaml = f"""---
 tags:
 - stable-diffusion-xl
 - stable-diffusion-xl-diffusers
@@ -159,7 +158,7 @@ def save_model_card(
 instance_prompt: {instance_prompt}
 license: openrail++
 ---
-    """
+"""
 
     model_card = f"""
 # SDXL LoRA DreamBooth - {repo_id}
@@ -170,14 +169,6 @@ def save_model_card(
 
 ### These are {repo_id} LoRA adaption weights for {base_model}.
 
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
-
-LoRA for the text encoder was enabled: {train_text_encoder}.
-
-Pivotal tuning was enabled: {train_text_encoder_ti}.
-
-Special VAE used for training: {vae_path}.
-
 ## Trigger words
 
 {trigger_str}
@@ -196,11 +187,24 @@ def save_model_card(
 
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
 
-## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
+## Download model
+
+### Use it with UIs such as AUTOMATIC1111, ComfyUI, SD.Next, Invoke
+
+- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it in your LoRA folder.
+- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it in your embeddings folder.
+
+All [Files & versions](/{repo_id}/tree/main).
 
-Weights for this model are available in Safetensors format.
+## Details
 
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
+The weights were trained using the [🧨 diffusers Advanced DreamBooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
+Special VAE used for training: {vae_path}.
 
""" with open(os.path.join(repo_folder, "README.md"), "w") as f: @@ -667,6 +671,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -1170,6 +1180,7 @@ def main(args): revision=args.revision, variant=args.variant, ) + vae_scaling_factor = vae.config.scaling_factor unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) @@ -1600,6 +1611,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) print("validation prompt:", args.validation_prompt) + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=torch.float32 + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1715,9 +1740,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): unet.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] - # print(prompts) # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if freeze_text_encoder: @@ -1729,9 +1752,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens) tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens) - # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + + model_input = model_input * vae_scaling_factor if args.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype)