@@ -3098,7 +3098,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
30983098 "--huber_schedule" ,
30993099 type = str ,
31003100 default = "exponential" ,
3101- choices = ["constant" , "exponential" , "snr" ], #TODO: add snr
3101+ choices = ["constant" , "exponential" , "snr" ],
31023102 help = "The type of loss to use and whether it's scheduled based on the timestep"
31033103 )
31043104 parser .add_argument (
@@ -4611,7 +4611,7 @@ def save_sd_model_on_train_end_common(
46114611 if args .huggingface_repo_id is not None :
46124612 huggingface_util .upload (args , out_dir , "/" + model_name , force_sync_upload = True )
46134613
4614- def get_timesteps_and_huber_c (args , min_timestep , max_timestep , num_train_timesteps , b_size , device ):
4614+ def get_timesteps_and_huber_c (args , min_timestep , max_timestep , noise_scheduler , b_size , device ):
46154615
46164616 #TODO: if a huber loss is selected, it will use constant timesteps for each batch.
46174617 # In the future there may be a smarter way
@@ -4623,12 +4623,12 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
46234623 timestep = timesteps .item ()
46244624
46254625 if args .huber_schedule == "exponential" :
4626- alpha = - math .log (args .huber_c ) / num_train_timesteps
4626+ alpha = - math .log (args .huber_c ) / noise_scheduler . config . num_train_timesteps
46274627 huber_c = math .exp (- alpha * timestep )
46284628 elif args .huber_schedule == "snr" :
4629- # TODO
4630- huber_c = args . huber_c # Placeholder
4631- pass
4629+ alphas_cumprod = noise_scheduler . alphas_cumprod [ timestep ]
4630+ sigmas = (( 1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
4631+ huber_c = ( 1 - args . huber_c ) / ( 1 + sigmas ) ** 2 + args . huber_c
46324632 elif args .huber_schedule == "constant" :
46334633 huber_c = args .huber_c
46344634 else :
@@ -4659,7 +4659,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
46594659 min_timestep = 0 if args .min_timestep is None else args .min_timestep
46604660 max_timestep = noise_scheduler .config .num_train_timesteps if args .max_timestep is None else args .max_timestep
46614661
4662- timesteps , huber_c = get_timesteps_and_huber_c (args , min_timestep , max_timestep , noise_scheduler . config . num_train_timesteps , b_size , latents .device )
4662+ timesteps , huber_c = get_timesteps_and_huber_c (args , min_timestep , max_timestep , noise_scheduler , b_size , latents .device )
46634663
46644664 # Add noise to the latents according to the noise magnitude at each timestep
46654665 # (this is the forward diffusion process)
0 commit comments