@@ -199,19 +199,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
 
 
 # Compare LCMScheduler.step, Step 4
-def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
     if prediction_type == "epsilon":
-        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
-        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "sample":
+        pred_x_0 = model_output
     elif prediction_type == "v_prediction":
-        pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+        pred_x_0 = alphas * sample - sigmas * model_output
     else:
-        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
 
     return pred_x_0
 
 
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_epsilon = model_output
+    elif prediction_type == "sample":
+        pred_epsilon = (sample - alphas * model_output) / sigmas
+    elif prediction_type == "v_prediction":
+        pred_epsilon = alphas * model_output + sigmas * sample
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+            f" are supported."
+        )
+
+    return pred_epsilon
+
+
 def extract_into_tensor(a, t, x_shape):
     b, *_ = t.shape
     out = a.gather(-1, t)
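The two helpers above are inverse views of the same forward relation x_t = alpha_t * x_0 + sigma_t * eps, with alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]). A minimal standalone sanity check of those identities (not part of the commit; the schedule values and shapes below are illustrative stand-ins):

import torch

# Stand-in schedules: alpha_t = sqrt(alphas_cumprod), sigma_t = sqrt(1 - alphas_cumprod)
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
alphas = alphas_cumprod.sqrt()
sigmas = (1 - alphas_cumprod).sqrt()

t = torch.randint(0, 1000, (4,))
x_0 = torch.randn(4, 4, 64, 64)
eps = torch.randn_like(x_0)

a = alphas[t].view(-1, 1, 1, 1)
s = sigmas[t].view(-1, 1, 1, 1)
x_t = a * x_0 + s * eps  # forward diffusion q(x_t | x_0)

# "epsilon" branch of get_predicted_original_sample recovers x_0 from a noise prediction
assert torch.allclose((x_t - s * eps) / a, x_0, atol=1e-4)
# "sample" branch of get_predicted_noise recovers eps from an x_0 prediction
assert torch.allclose((x_t - a * x_0) / s, eps, atol=1e-4)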
@@ -676,24 +700,25 @@ def main(args):
         args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
     )
 
-    # The scheduler calculates the alpha and sigma schedule for us
+    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
     alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
     sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+    # Initialize the DDIM ODE solver for distillation.
     solver = DDIMSolver(
         noise_scheduler.alphas_cumprod.numpy(),
         timesteps=noise_scheduler.config.num_train_timesteps,
         ddim_timesteps=args.num_ddim_timesteps,
     )
 
-    # 2. Load tokenizers from SD-XL checkpoint.
+    # 2. Load tokenizers from SDXL checkpoint.
     tokenizer_one = AutoTokenizer.from_pretrained(
         args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
         args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
     )
 
-    # 3. Load text encoders from SD-XL checkpoint.
+    # 3. Load text encoders from SDXL checkpoint.
     # import correct text encoder classes
     text_encoder_cls_one = import_model_class_from_model_name_or_path(
         args.pretrained_teacher_model, args.teacher_revision
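For intuition on the solver constructed here (its class definition is not shown in this diff): with num_train_timesteps = 1000 and, for example, args.num_ddim_timesteps = 50, the DDIM grid is an evenly spaced subset of the teacher's training timesteps, spaced k = 1000 // 50 = 20 apart and ending at T - 1 = 999. A sketch of building such a grid, assuming plain NumPy (it may differ in detail from DDIMSolver's internals):

import numpy as np

num_train_timesteps = 1000  # teacher's training timesteps T
num_ddim_timesteps = 50     # assumed value of args.num_ddim_timesteps
k = num_train_timesteps // num_ddim_timesteps  # spacing between solver steps

# Evenly spaced DDIM grid: [k - 1, 2k - 1, ..., T - 1] == [19, 39, ..., 999]
ddim_timesteps = np.arange(1, num_ddim_timesteps + 1) * k - 1
print(ddim_timesteps[:3], ddim_timesteps[-1])  # [19 39 59] 999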
@@ -709,7 +734,7 @@ def main(args):
         args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
     )
 
-    # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+    # 4. Load VAE from SDXL checkpoint (or more stable VAE)
     vae_path = (
         args.pretrained_teacher_model
         if args.pretrained_vae_model_name_or_path is None
@@ -726,7 +751,7 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
-    # 7. Create online (`unet`) student U-Net.
+    # 7. Create online student U-Net.
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
     )
@@ -743,7 +768,7 @@ def main(args):
             f"U-Net loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
         )
 
-    # 9. Handle mixed precision and device placement
+    # 8. Handle mixed precision and device placement
     # For mixed precision training we cast all non-trainable weights to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
@@ -762,7 +787,7 @@ def main(args):
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
-    # 8. Add LoRA to the student U-Net, only the LoRA projection matrices will be updated by the optimizer.
+    # 9. Add LoRA to the student U-Net, only the LoRA projection matrices will be updated by the optimizer.
     lora_config = LoraConfig(
         r=args.lora_rank,
         lora_alpha=args.lora_rank,
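Setting lora_alpha equal to r keeps the effective LoRA scaling lora_alpha / r at 1, so the adapter contribution is added without extra rescaling. A hedged standalone sketch of a comparable configuration (the target module names below are illustrative; the script's actual target_modules list is not visible in this hunk):

from peft import LoraConfig

lora_rank = 64  # e.g. args.lora_rank
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank,                    # scaling = lora_alpha / r = 1
    target_modules=["to_q", "to_k", "to_v"],  # illustrative attention projections
)
# The adapter is then attached to the frozen student U-Net so that only the
# LoRA matrices are trainable while the base weights stay fixed.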
@@ -1007,7 +1032,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
     compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
 
-    # 14. LR Scheduler creation
+    # 15. LR Scheduler creation
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1027,7 +1052,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         num_training_steps=args.max_train_steps * accelerator.num_processes,
     )
 
-    # 15. Prepare for training
+    # 16. Prepare for training
     # Prepare everything with our `accelerator`.
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
@@ -1046,7 +1071,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         tracker_config = dict(vars(args))
         accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
 
-    # 16. Train!
+    # 17. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
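total_batch_size here is the effective optimization batch size, i.e. the number of samples that contribute to one optimizer step. With illustrative values (not taken from the diff):

# Effective batch size = per-device batch * number of processes * gradient accumulation steps
train_batch_size = 8               # per device (illustrative)
num_processes = 4                  # e.g. 4 GPUs under accelerate
gradient_accumulation_steps = 2
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
assert total_batch_size == 64      # one optimizer step "sees" 64 samples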
@@ -1098,6 +1123,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
     for epoch in range(first_epoch, args.num_train_epochs):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
+                # 1. Load and process the image and text conditioning
                 pixel_values, text, orig_size, crop_coords = (
                     batch["pixel_values"],
                     batch["captions"],
@@ -1118,44 +1144,43 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 if args.pretrained_vae_model_name_or_path is None:
                     latents = latents.to(weight_dtype)
 
-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
+                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
                 bsz = latents.shape[0]
-
-                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
                 topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
                 index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
                 start_timesteps = solver.ddim_timesteps[index]
                 timesteps = start_timesteps - topk
                 timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
 
-                # Get boundary scalings for start_timesteps and (end) timesteps.
+                # 3. Get boundary scalings for start_timesteps and (end) timesteps.
                 c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
                 c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
                 c_skip, c_out = scalings_for_boundary_conditions(timesteps)
                 c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
 
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+                noise = torch.randn_like(latents)
                 noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
 
-                # Sample a random guidance scale w from U[w_min, w_max] and embed it
+                # 5. Sample a random guidance scale w from U[w_min, w_max]
+                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
                 w = w.reshape(bsz, 1, 1, 1)
                 w = w.to(device=latents.device, dtype=latents.dtype)
 
-                # Prepare prompt embeds and unet_added_conditions
+                # 6. Prepare prompt embeds and unet_added_conditions
                 prompt_embeds = encoded_text.pop("prompt_embeds")
 
-                # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+                # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
                 noise_pred = unet(
                     noisy_model_input,
                     start_timesteps,
                     encoder_hidden_states=prompt_embeds,
                     added_cond_kwargs=encoded_text,
                 ).sample
-
-                pred_x_0 = predicted_origin(
+                pred_x_0 = get_predicted_original_sample(
                     noise_pred,
                     start_timesteps,
                     noisy_model_input,
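The pairing of start_timesteps (t_{n+k}) with timesteps (t_n = t_{n+k} - k, clamped at 0), together with the c_skip / c_out scalings, enforces the consistency-model boundary condition f(x, 0) = x. A sketch of scalings with that behavior, assuming the sigma_data = 0.5 default visible in the hunk header above and a timestep_scaling of 10.0; it mirrors, but is not copied from, scalings_for_boundary_conditions:

import torch

def boundary_scalings(timestep, sigma_data=0.5, timestep_scaling=10.0):
    # c_skip -> 1 and c_out -> 0 as timestep -> 0, so the student reduces to the identity at t = 0
    scaled_t = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_t**2 + sigma_data**2)
    c_out = scaled_t / (scaled_t**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

c_skip, c_out = boundary_scalings(torch.tensor([0.0, 999.0]))
# c_skip is approximately [1.0, 0.0], c_out approximately [0.0, 1.0]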
@@ -1165,20 +1190,28 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 )
                 model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
 
-                # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
-                # noisy_latents with both the conditioning embedding c and unconditional embedding 0
-                # Get teacher model prediction on noisy_latents and conditional embedding
-                # Notice that we're disabling the adapter layers within the `unet` and then it becomes a
-                # regular teacher. This way, we don't have to separately initialize a teacher UNet.
+                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+                # solver timestep.
                 unet.disable_adapters()
                 with torch.no_grad():
+                    # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = unet(
                         noisy_model_input,
                         start_timesteps,
                         encoder_hidden_states=prompt_embeds,
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
                     ).sample
-                    cond_pred_x0 = predicted_origin(
+                    cond_pred_x0 = get_predicted_original_sample(
+                        cond_teacher_output,
+                        start_timesteps,
+                        noisy_model_input,
+                        noise_scheduler.config.prediction_type,
+                        alpha_schedule,
+                        sigma_schedule,
+                    )
+                    cond_pred_noise = get_predicted_noise(
                         cond_teacher_output,
                         start_timesteps,
                         noisy_model_input,
@@ -1187,19 +1220,26 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         sigma_schedule,
                     )
 
-                    # Create uncond embeds for classifier free guidance
+                    # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
                     uncond_prompt_embeds = torch.zeros_like(prompt_embeds)
                     uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"])
                     uncond_added_conditions = copy.deepcopy(encoded_text)
-                    # Get teacher model prediction on noisy_latents and unconditional embedding
                     uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
                     uncond_teacher_output = unet(
                         noisy_model_input,
                         start_timesteps,
                         encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
                     ).sample
-                    uncond_pred_x0 = predicted_origin(
+                    uncond_pred_x0 = get_predicted_original_sample(
+                        uncond_teacher_output,
+                        start_timesteps,
+                        noisy_model_input,
+                        noise_scheduler.config.prediction_type,
+                        alpha_schedule,
+                        sigma_schedule,
+                    )
+                    uncond_pred_noise = get_predicted_noise(
                         uncond_teacher_output,
                         start_timesteps,
                         noisy_model_input,
@@ -1208,24 +1248,28 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         sigma_schedule,
                     )
 
-                    # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+                    # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+                    # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
                     pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
-                    pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
-                    x_prev = solver.ddim_step(pred_x0, pred_noise, index)
-                    x_prev = x_prev.to(unet.dtype)
+                    pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+                    # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+                    # augmented PF-ODE trajectory (solving backward in time)
+                    # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+                    x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
 
                 # re-enable unet adapters
                 unet.enable_adapters()
 
-                # Get target LCM prediction on x_prev, w, c, t_n
+                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+                # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
                     target_noise_pred = unet(
                         x_prev,
                         timesteps,
                         encoder_hidden_states=prompt_embeds,
                         added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
                     ).sample
-                    pred_x_0 = predicted_origin(
+                    pred_x_0 = get_predicted_original_sample(
                         target_noise_pred,
                         timesteps,
                         x_prev,
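With get_predicted_noise available, the CFG combination w-weights both the x_0 and the noise estimates, and the deterministic DDIM update then reads x_prev = sqrt(alpha_bar_prev) * x_0 + sqrt(1 - alpha_bar_prev) * eps. A minimal sketch of that update (assumed to mirror what DDIMSolver.ddim_step computes; the solver's code is not shown in this diff):

import torch

def ddim_step_sketch(pred_x0, pred_noise, alpha_cumprod_prev):
    # Deterministic DDIM (eta = 0): move to the previous solver timestep using the
    # CFG-combined estimates of x_0 and eps_0.
    return alpha_cumprod_prev.sqrt() * pred_x0 + (1 - alpha_cumprod_prev).sqrt() * pred_noise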
@@ -1235,15 +1279,15 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     )
                     target = c_skip * x_prev + c_out * pred_x_0
 
-                # Calculate loss
+                # 10. Calculate loss
                 if args.loss_type == "l2":
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 elif args.loss_type == "huber":
                     loss = torch.mean(
                         torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
                     )
 
-                # Backpropagate on the online student model (`unet`)
+                # 11. Backpropagate on the online student model (`unet`) (only LoRA)
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
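The "huber" option is the pseudo-Huber loss sqrt(d^2 + c^2) - c, which behaves like a scaled L2 loss for residuals much smaller than c and like L1 for large residuals, damping outliers relative to plain MSE. A standalone sketch (c plays the role of args.huber_c; the value below is illustrative):

import torch

def pseudo_huber(pred, target, c=0.001):
    # sqrt(d^2 + c^2) - c: roughly d^2 / (2c) for |d| << c, roughly |d| for |d| >> c,
    # so large residuals contribute linearly rather than quadratically.
    return torch.mean(torch.sqrt((pred - target) ** 2 + c**2) - c)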