Commit 6c56f05

v-prediction training support (#1455)

* add get_velocity
* add v prediction for training
* fix saving
* add revision arg
* fix saving
* save checkpoints dreambooth
* fix saving embeds
* add instruction in readme
* quality
* noise_pred -> model_pred

1 parent 77fc197 commit 6c56f05

8 files changed: +157 −38 lines changed

examples/dreambooth/README.md

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
 
 And launch the training using
 
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export INSTANCE_DIR="path-to-instance-images"

examples/dreambooth/train_dreambooth.py

Lines changed: 27 additions & 7 deletions

@@ -124,6 +124,7 @@ def parse_args(input_args=None):
         default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,

@@ -603,23 +604,31 @@ def collate_fn(examples):
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 if args.with_prior_preservation:
-                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:

@@ -638,6 +647,17 @@ def collate_fn(examples):
                 progress_bar.update(1)
                 global_step += 1
 
+                if global_step % args.save_steps == 0:
+                    if accelerator.is_main_process:
+                        pipeline = StableDiffusionPipeline.from_pretrained(
+                            args.pretrained_model_name_or_path,
+                            unet=accelerator.unwrap_model(unet),
+                            text_encoder=accelerator.unwrap_model(text_encoder),
+                            revision=args.revision,
+                        )
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        pipeline.save_pretrained(save_path)
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
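The new `--save_steps` checkpoints are written with `pipeline.save_pretrained`, so each `checkpoint-{global_step}` folder is a complete Stable Diffusion pipeline. A minimal sketch of loading one back for inference (the output path and prompt are placeholder examples, not part of this commit):

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical folder written by train_dreambooth.py at global step 500.
pipe = StableDiffusionPipeline.from_pretrained(
    "path-to-save-model/checkpoint-500", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```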

examples/text_to_image/README.md

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ If you have already cloned the repo, then you won't need to go through these steps.
 #### Hardware
 With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
 
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export dataset_name="lambdalabs/pokemon-blip-captions"

examples/text_to_image/train_text_to_image.py

Lines changed: 40 additions & 14 deletions

@@ -15,13 +15,12 @@
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from datasets import load_dataset
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 
 
 logger = get_logger(__name__)

@@ -36,6 +35,13 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,

@@ -335,10 +341,24 @@ def main():
         os.makedirs(args.output_dir, exist_ok=True)
 
     # Load models and create wrapper for stable diffusion
-    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    tokenizer = CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
+    )
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)

@@ -562,9 +582,17 @@ def collate_fn(examples):
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
                 # Predict the noise residual and compute loss
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()

@@ -600,14 +628,12 @@ def collate_fn(examples):
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
 
-        pipeline = StableDiffusionPipeline(
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
             text_encoder=text_encoder,
             vae=vae,
             unet=unet,
-            tokenizer=tokenizer,
-            scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
-            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            revision=args.revision,
         )
         pipeline.save_pretrained(args.output_dir)
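The new `--revision` argument is forwarded to every `from_pretrained` call, so the script can train from a specific branch, tag, or commit of the Hub repository. A rough standalone sketch of the equivalent loading call; the `fp16` branch name is only an assumed example and requires that the model repo actually publishes such a revision:

```python
import torch
from diffusers import StableDiffusionPipeline

# "fp16" is a hypothetical branch used for illustration; any git revision
# (branch, tag, or commit hash) of the model repo can be passed the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
)
```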

examples/textual_inversion/README.md

Lines changed: 2 additions & 0 deletions

@@ -47,6 +47,8 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c
 
 And launch the training using
 
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
 ```bash
 export MODEL_NAME="runwayml/stable-diffusion-v1-5"
 export DATA_DIR="path-to-dir-containing-images"

examples/textual_inversion/textual_inversion.py

Lines changed: 44 additions & 17 deletions

@@ -16,17 +16,16 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 
 
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):

@@ -51,11 +50,11 @@
 logger = get_logger(__name__)
 
 
-def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+    torch.save(learned_embeds_dict, save_path)
 
 
 def parse_args():

@@ -73,6 +72,13 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,

@@ -405,9 +411,21 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        revision=args.revision,
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
+    )
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))

@@ -532,9 +550,17 @@ def main():
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
                 # Zero out the gradients for all token embeddings except the newly added

@@ -556,7 +582,8 @@ def main():
                 progress_bar.update(1)
                 global_step += 1
                 if global_step % args.save_steps == 0:
-                    save_progress(text_encoder, placeholder_token_id, accelerator, args)
+                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

@@ -569,18 +596,18 @@ def main():
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
-        pipeline = StableDiffusionPipeline(
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
             text_encoder=accelerator.unwrap_model(text_encoder),
+            tokenizer=tokenizer,
             vae=vae,
             unet=unet,
-            tokenizer=tokenizer,
-            scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
-            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            revision=args.revision,
         )
         pipeline.save_pretrained(args.output_dir)
         # Also save the newly trained embeddings
-        save_progress(text_encoder, placeholder_token_id, accelerator, args)
+        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
     if args.push_to_hub:
         repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
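With the `save_progress` change, both the periodic `learned_embeds-steps-{global_step}.bin` files and the final `learned_embeds.bin` are plain `torch.save` dictionaries mapping the placeholder token to its trained embedding. A minimal sketch of loading one into a fresh pipeline; paths are placeholders, and this mirrors the usual textual-inversion loading recipe rather than code in this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# The file contains {placeholder_token: embedding_tensor}.
learned_embeds = torch.load("path-to-output-dir/learned_embeds.bin")
token, embedding = next(iter(learned_embeds.items()))

# Register the placeholder token and copy the trained embedding into the text encoder.
pipe.tokenizer.add_tokens(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
```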

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 20 additions & 0 deletions

@@ -355,5 +355,25 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
     def __len__(self):
         return self.config.num_train_timesteps
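In the scheduler's notation, where $\bar{\alpha}_t$ is `alphas_cumprod[t]`, `sample` is $x_0$, and `noise` is $\epsilon$, `add_noise` and the new `get_velocity` compute

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad v_t = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,x_0,$$

so a model trained with `prediction_type="v_prediction"` regresses $v_t$ instead of the noise $\epsilon$.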

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 20 additions & 0 deletions

@@ -345,5 +345,25 @@ def add_noise(
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
     def __len__(self):
         return self.config.num_train_timesteps
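A minimal standalone sketch of how the training scripts above combine `add_noise` with the new `get_velocity` to build the loss target. It assumes a diffusers version that includes this commit; shapes and the `prediction_type` value are arbitrary illustrations:

```python
import torch
from diffusers import DDPMScheduler

# Configure a scheduler for v-prediction; "epsilon" would select the old target.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="v_prediction")

latents = torch.randn(2, 4, 64, 64)  # stand-in for VAE latents (x_0)
noise = torch.randn_like(latents)    # epsilon
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (2,), dtype=torch.long)

noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# Same branch the training scripts use to pick the regression target.
if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

# A UNet's prediction would then be compared against `target` with an MSE loss.
print(target.shape)  # torch.Size([2, 4, 64, 64])
```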
