@@ -133,7 +133,7 @@ def save_model_card(
 diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 """
-diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
+diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
 state_dict = load_file(embedding_path)
 pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
 pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
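For readers who want to run the generated snippet end to end, here is a minimal, self-contained sketch. Note that the snippet above refers to both `pipeline` and `pipe`; a single name is used below. The base model and repo ID are placeholder assumptions, not part of the diff:

```python
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Assumed base model; the generated card fills in {base_model} here.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Download the pivotal-tuning embeddings and register them with both SDXL text encoders.
embedding_path = hf_hub_download(
    repo_id="your-user/your-lora",  # placeholder repo ID
    filename="embeddings.safetensors",
    repo_type="model",
)
state_dict = load_file(embedding_path)
pipe.load_textual_inversion(
    state_dict["clip_l"],
    token=["<s0>", "<s1>"],
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
)
pipe.load_textual_inversion(
    state_dict["clip_g"],
    token=["<s0>", "<s1>"],
    text_encoder=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer_2,
)
```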
@@ -145,8 +145,7 @@ def save_model_card(
 to trigger concept `{key}` → use `{tokens}` in your prompt \n
 """
 
-yaml = f"""
----
+yaml = f"""---
 tags:
 - stable-diffusion-xl
 - stable-diffusion-xl-diffusers
@@ -159,7 +158,7 @@ def save_model_card(
 instance_prompt: {instance_prompt}
 license: openrail++
 ---
-"""
+"""
 
 model_card = f"""
 # SDXL LoRA DreamBooth - {repo_id}
@@ -170,14 +169,6 @@ def save_model_card(
 
 ### These are {repo_id} LoRA adaption weights for {base_model}.
 
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
-
-LoRA for the text encoder was enabled: {train_text_encoder}.
-
-Pivotal tuning was enabled: {train_text_encoder_ti}.
-
-Special VAE used for training: {vae_path}.
-
 ## Trigger words
 
 {trigger_str}
@@ -196,11 +187,24 @@ def save_model_card(
 
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
 
-## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
+## Download model
 
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+
+- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it in your LoRA folder.
+- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it in your embeddings folder.
+
+All [Files & versions](/{repo_id}/tree/main).
+
-Weights for this model are available in Safetensors format.
+## Details
 
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
+The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
+Special VAE used for training: {vae_path}.
 
 """
 with open(os.path.join(repo_folder, "README.md"), "w") as f:
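The manual download steps above can also be scripted; a small sketch using `huggingface_hub` (the repo ID is a placeholder):

```python
from huggingface_hub import hf_hub_download

# Fetch the two files the card links to; the filenames match what this script uploads.
lora_path = hf_hub_download("your-user/your-lora", "pytorch_lora_weights.safetensors")
emb_path = hf_hub_download("your-user/your-lora", "embeddings.safetensors")
print(lora_path)  # copy/rename into your UI's LoRA folder
print(emb_path)   # copy/rename into your UI's embeddings folder
```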
@@ -667,6 +671,12 @@ def parse_args(input_args=None):
 default=4,
 help=("The dimension of the LoRA update matrices."),
 )
+parser.add_argument(
+    "--cache_latents",
+    action="store_true",
+    default=False,
+    help="Cache the VAE latents",
+)
 
 if input_args is not None:
 args = parser.parse_args(input_args)
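Since `action="store_true"` already implies a default of `False`, the explicit `default=False` is redundant but harmless. A minimal sketch of the flag's behavior in isolation (mirroring just this argument, not the full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_latents", action="store_true", default=False, help="Cache the VAE latents")

assert parser.parse_args([]).cache_latents is False          # flag omitted
assert parser.parse_args(["--cache_latents"]).cache_latents  # flag present
```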
Expand Down Expand Up @@ -1170,6 +1180,7 @@ def main(args):
revision=args.revision,
variant=args.variant,
)
vae_scaling_factor = vae.config.scaling_factor
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
Expand Down Expand Up @@ -1600,6 +1611,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)

if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=torch.float32
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

if args.validation_prompt is None:
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
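Note that the cache stores each batch's `latent_dist` (a diagonal Gaussian posterior), not a sampled tensor, so every epoch can still draw a fresh latent via `.sample()`; only the deterministic VAE encoder pass is amortized. The VAE is then freed once no validation run needs it. A hypothetical standalone sketch of the same pattern, assuming a generic `vae` and dataloader; indexing the cache by step assumes batch order is stable between the caching pass and training:

```python
import torch

@torch.no_grad()
def cache_latent_dists(vae, dataloader, device):
    # Run the VAE encoder once per batch and keep the posterior distributions.
    cache = []
    for batch in dataloader:
        pixels = batch["pixel_values"].to(device, dtype=torch.float32)
        cache.append(vae.encode(pixels).latent_dist)  # distribution, not a sample
    return cache

# Per training step, a fresh stochastic sample each time:
# model_input = cache[step].sample() * vae.config.scaling_factor
```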
Expand Down Expand Up @@ -1715,9 +1740,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
# print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1729,9 +1752,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
 tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)
 
-# Convert images to latent space
-model_input = vae.encode(pixel_values).latent_dist.sample()
-model_input = model_input * vae.config.scaling_factor
+if args.cache_latents:
+    model_input = latents_cache[step].sample()
+else:
+    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+    model_input = vae.encode(pixel_values).latent_dist.sample()
+
+model_input = model_input * vae_scaling_factor
 if args.pretrained_vae_model_name_or_path is None:
     model_input = model_input.to(weight_dtype)
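A design detail worth noting in the hunk above: the multiplier switched from `vae.config.scaling_factor` to the `vae_scaling_factor` captured right after the VAE was loaded, because with `--cache_latents` and no validation prompt the `vae` object has already been deleted by this point, and touching it here would raise a `NameError`. For the same reason, the `pixel_values` conversion now lives only in the `else` branch, where the VAE is still available to encode them.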