
Commit 87f87a7

sayakpaul and dg845 committed
propagate comments from #6145
Co-authored-by: dg845 <[email protected]>
1 parent 7a1d6c9 commit 87f87a7

File tree

1 file changed (+88 −44 lines)


examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 88 additions & 44 deletions
@@ -199,19 +199,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
 
 
 # Compare LCMScheduler.step, Step 4
-def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
     if prediction_type == "epsilon":
-        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
-        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "sample":
+        pred_x_0 = model_output
     elif prediction_type == "v_prediction":
-        pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+        pred_x_0 = alphas * sample - sigmas * model_output
     else:
-        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
 
     return pred_x_0
 
 
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_epsilon = model_output
+    elif prediction_type == "sample":
+        pred_epsilon = (sample - alphas * model_output) / sigmas
+    elif prediction_type == "v_prediction":
+        pred_epsilon = alphas * model_output + sigmas * sample
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
+
+    return pred_epsilon
+
+
 def extract_into_tensor(a, t, x_shape):
     b, *_ = t.shape
     out = a.gather(-1, t)
@@ -676,24 +700,25 @@ def main(args):
         args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
     )
 
-    # The scheduler calculates the alpha and sigma schedule for us
+    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
     alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
     sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+    # Initialize the DDIM ODE solver for distillation.
     solver = DDIMSolver(
         noise_scheduler.alphas_cumprod.numpy(),
         timesteps=noise_scheduler.config.num_train_timesteps,
         ddim_timesteps=args.num_ddim_timesteps,
     )
 
-    # 2. Load tokenizers from SD-XL checkpoint.
+    # 2. Load tokenizers from SDXL checkpoint.
     tokenizer_one = AutoTokenizer.from_pretrained(
         args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
         args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
     )
 
-    # 3. Load text encoders from SD-XL checkpoint.
+    # 3. Load text encoders from SDXL checkpoint.
     # import correct text encoder classes
     text_encoder_cls_one = import_model_class_from_model_name_or_path(
         args.pretrained_teacher_model, args.teacher_revision
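The alpha/sigma schedules are the square roots of the cumulative alpha products of the teacher's DDPM scheduler, and the DDIM solver discretizes the T training timesteps into N solver steps. A minimal sketch of this construction, using an assumed linear beta schedule and stand-in values for `num_train_timesteps` and `num_ddim_timesteps` rather than the teacher's actual config:

```python
import torch

num_train_timesteps, num_ddim_timesteps = 1000, 50
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)     # assumed linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

alpha_schedule = torch.sqrt(alphas_cumprod)                  # alpha_t
sigma_schedule = torch.sqrt(1 - alphas_cumprod)              # sigma_t

# DDIM solver grid with step size k = T // N; DDIMSolver presumably indexes into a grid like this.
k = num_train_timesteps // num_ddim_timesteps
ddim_timesteps = torch.arange(1, num_ddim_timesteps + 1) * k - 1
print(ddim_timesteps[:3].tolist(), ddim_timesteps[-1].item())  # [19, 39, 59] 999
```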
args.pretrained_teacher_model, args.teacher_revision
@@ -709,7 +734,7 @@ def main(args):
         args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
     )
 
-    # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+    # 4. Load VAE from SDXL checkpoint (or more stable VAE)
     vae_path = (
         args.pretrained_teacher_model
         if args.pretrained_vae_model_name_or_path is None
@@ -726,7 +751,7 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
-    # 7. Create online (`unet`) student U-Net.
+    # 7. Create online student U-Net.
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
     )
@@ -743,7 +768,7 @@ def main(args):
             f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
         )
 
-    # 9. Handle mixed precision and device placement
+    # 8. Handle mixed precision and device placement
     # For mixed precision training we cast all non-trainable weigths to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -762,7 +787,7 @@ def main(args):
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
-    # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+    # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
     lora_config = LoraConfig(
         r=args.lora_rank,
         lora_alpha=args.lora_rank,
@@ -1007,7 +1032,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
 
-    # 14. LR Scheduler creation
+    # 15. LR Scheduler creation
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1027,7 +1052,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
-    # 15. Prepare for training
+    # 16. Prepare for training
     # Prepare everything with our `accelerator`.
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
@@ -1046,7 +1071,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         tracker_config = dict(vars(args))
         accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
 
-    # 16. Train!
+    # 17. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
@@ -1098,6 +1123,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
+                # 1. Load and process the image and text conditioning
                 pixel_values, text, orig_size, crop_coords = (
                     batch["pixel_values"],
                     batch["captions"],
@@ -1118,44 +1144,43 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 if args.pretrained_vae_model_name_or_path is None:
                     latents = latents.to(weight_dtype)
 
-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
+                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
                 bsz = latents.shape[0]
-
-                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
                 topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
                 index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
                 start_timesteps = solver.ddim_timesteps[index]
                 timesteps = start_timesteps - topk
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
 
-                # Get boundary scalings for start_timesteps and (end) timesteps.
+                # 3. Get boundary scalings for start_timesteps and (end) timesteps.
                 c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
                 c_skip, c_out = scalings_for_boundary_conditions(timesteps)
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
 
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                noise = torch.randn_like(latents)
                 noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
 
-                # Sample a random guidance scale w from U[w_min, w_max] and embed it
+                # 5. Sample a random guidance scale w from U[w_min, w_max]
+                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
                 w = w.reshape(bsz, 1, 1, 1)
                 w = w.to(device=latents.device, dtype=latents.dtype)
 
-                # Prepare prompt embeds and unet_added_conditions
+                # 6. Prepare prompt embeds and unet_added_conditions
                 prompt_embeds = encoded_text.pop("prompt_embeds")
 
-                # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+                # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
                 noise_pred = unet(
                     noisy_model_input,
                     start_timesteps,
                     encoder_hidden_states=prompt_embeds,
                     added_cond_kwargs=encoded_text,
                 ).sample
-
-                pred_x_0 = predicted_origin(
+                pred_x_0 = get_predicted_original_sample(
                     noise_pred,
                     start_timesteps,
                     noisy_model_input,
@@ -1165,20 +1190,28 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 )
                 model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
 
-                # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
-                # noisy_latents with both the conditioning embedding c and unconditional embedding 0
-                # Get teacher model prediction on noisy_latents and conditional embedding
-                # Notice that we're disabling the adapter layers within the `unet` and then it becomes a
-                # regular teacher. This way, we don't have to separately initialize a teacher UNet.
+                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+                # solver timestep.
                 unet.disable_adapters()
                 with torch.no_grad():
+                    # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = unet(
                         noisy_model_input,
                         start_timesteps,
                         encoder_hidden_states=prompt_embeds,
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
                     ).sample
-                    cond_pred_x0 = predicted_origin(
+                    cond_pred_x0 = get_predicted_original_sample(
+                        cond_teacher_output,
+                        start_timesteps,
+                        noisy_model_input,
+                        noise_scheduler.config.prediction_type,
+                        alpha_schedule,
+                        sigma_schedule,
+                    )
+                    cond_pred_noise = get_predicted_noise(
                         cond_teacher_output,
                         start_timesteps,
                         noisy_model_input,
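The removed comments explained why the adapters are toggled: with the LoRA layers disabled, the same `unet` falls back to its frozen base weights and acts as the teacher, so no separate teacher UNet is needed. A minimal sketch of that pattern (assuming a diffusers UNet with a PEFT LoRA adapter attached; the helper name is hypothetical):

```python
import torch

def teacher_forward(unet, *args, **kwargs):
    unet.disable_adapters()            # LoRA deltas off -> base (teacher) weights only
    try:
        with torch.no_grad():          # teacher predictions never receive gradients
            out = unet(*args, **kwargs).sample
    finally:
        unet.enable_adapters()         # LoRA deltas back on for the student update
    return out
```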
@@ -1187,19 +1220,26 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         sigma_schedule,
                     )
 
-                    # Create uncond embeds for classifier free guidance
+                    # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
                     uncond_prompt_embeds = torch.zeros_like(prompt_embeds)
                     uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"])
                     uncond_added_conditions = copy.deepcopy(encoded_text)
-                    # Get teacher model prediction on noisy_latents and unconditional embedding
                     uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
                     uncond_teacher_output = unet(
                         noisy_model_input,
                         start_timesteps,
                         encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
                     ).sample
-                    uncond_pred_x0 = predicted_origin(
+                    uncond_pred_x0 = get_predicted_original_sample(
+                        uncond_teacher_output,
+                        start_timesteps,
+                        noisy_model_input,
+                        noise_scheduler.config.prediction_type,
+                        alpha_schedule,
+                        sigma_schedule,
+                    )
+                    uncond_pred_noise = get_predicted_noise(
                         uncond_teacher_output,
                         start_timesteps,
                         noisy_model_input,
@@ -1208,24 +1248,28 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         sigma_schedule,
                     )
 
-                    # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+                    # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+                    # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
                     pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
-                    pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
-                    x_prev = solver.ddim_step(pred_x0, pred_noise, index)
-                    x_prev = x_prev.to(unet.dtype)
+                    pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+                    # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+                    # augmented PF-ODE trajectory (solving backward in time)
+                    # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+                    x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
 
                 # re-enable unet adapters
                 unet.enable_adapters()
 
-                # Get target LCM prediction on x_prev, w, c, t_n
+                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+                # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
                     target_noise_pred = unet(
                         x_prev,
                         timesteps,
                         encoder_hidden_states=prompt_embeds,
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
                     ).sample
-                    pred_x_0 = predicted_origin(
+                    pred_x_0 = get_predicted_original_sample(
                         target_noise_pred,
                         timesteps,
                         x_prev,
@@ -1235,15 +1279,15 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     )
                     target = c_skip * x_prev + c_out * pred_x_0
 
-                # Calculate loss
+                # 10. Calculate loss
                 if args.loss_type == "l2":
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 elif args.loss_type == "huber":
                     loss = torch.mean(
                         torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
                     )
 
-                # Backpropagate on the online student model (`unet`)
+                # 11. Backpropagate on the online student model (`unet`) (only LoRA)
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
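A side-by-side sketch of the two loss options in step 10 with dummy predictions; `huber_c=0.001` is an assumed value standing in for `args.huber_c`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model_pred, target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
huber_c = 0.001

l2_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Pseudo-Huber: roughly quadratic for residuals much smaller than huber_c and roughly |residual|
# for large ones, so the distillation target is less sensitive to outliers than plain L2.
huber_loss = torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c**2) - huber_c)
print(l2_loss.item(), huber_loss.item())
```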
