
Commit 47fb1a6

add snr huber scheduler
1 parent 19a834c commit 47fb1a6

File tree: 2 files changed (+8 −8 lines)

  library/train_util.py
  train_controlnet.py


library/train_util.py

Lines changed: 7 additions & 7 deletions
@@ -3098,7 +3098,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--huber_schedule",
         type=str,
         default="exponential",
-        choices=["constant", "exponential", "snr"], #TODO: add snr
+        choices=["constant", "exponential", "snr"],
         help="The type of loss to use and whether it's scheduled based on the timestep"
     )
     parser.add_argument(
@@ -4611,7 +4611,7 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
-def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timesteps, b_size, device):
+def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
 
     #TODO: if a huber loss is selected, it will use constant timesteps for each batch
     # as. In the future there may be a smarter way
@@ -4623,12 +4623,12 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
         timestep = timesteps.item()
 
         if args.huber_schedule == "exponential":
-            alpha = - math.log(args.huber_c) / num_train_timesteps
+            alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
             huber_c = math.exp(-alpha * timestep)
         elif args.huber_schedule == "snr":
-            # TODO
-            huber_c = args.huber_c # Placeholder
-            pass
+            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+            huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
         elif args.huber_schedule == "constant":
             huber_c = args.huber_c
         else:
@@ -4659,7 +4659,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
 
-    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
+    timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
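
For reference, the new "snr" branch scales huber_c with the noise level of the sampled timestep: at nearly clean timesteps (alphas_cumprod near 1, sigma near 0) it stays close to 1.0, and at heavily noised timesteps it decays toward args.huber_c. The snippet below is a minimal, self-contained sketch (not part of the commit) that mirrors this logic on an assumed standard linear-beta DDPM schedule; the helper names make_alphas_cumprod and huber_c_for_timestep are illustrative only.

    import math

    import torch


    def make_alphas_cumprod(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        # Linear beta schedule (assumed here; the common DDPM default).
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        return torch.cumprod(1.0 - betas, dim=0)


    def huber_c_for_timestep(timestep, huber_c, huber_schedule, alphas_cumprod):
        # Mirrors the schedules in get_timesteps_and_huber_c above (sketch only).
        num_train_timesteps = alphas_cumprod.shape[0]
        if huber_schedule == "exponential":
            # 1.0 at t=0, decaying exponentially to huber_c at the last timestep.
            alpha = -math.log(huber_c) / num_train_timesteps
            return math.exp(-alpha * timestep)
        elif huber_schedule == "snr":
            # sigma is the noise-to-signal ratio at this timestep: near-clean
            # timesteps give ~1.0, heavily noised timesteps approach huber_c.
            ac = alphas_cumprod[timestep].item()
            sigma = ((1.0 - ac) / ac) ** 0.5
            return (1 - huber_c) / (1 + sigma) ** 2 + huber_c
        elif huber_schedule == "constant":
            return huber_c
        raise NotImplementedError(f"Unknown Huber loss schedule {huber_schedule}!")


    alphas_cumprod = make_alphas_cumprod()
    for t in (0, 250, 500, 999):
        print(t, round(huber_c_for_timestep(t, 0.1, "snr", alphas_cumprod), 4))

Running the sketch with args.huber_c = 0.1 shows huber_c starting near 1.0 at early timesteps and falling toward 0.1 as noise dominates.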

train_controlnet.py

Lines changed: 1 addition & 1 deletion
@@ -419,7 +419,7 @@ def remove_model(old_ckpt_name):
                 )
 
                 # Sample a random timestep for each image
-                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
+                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
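
The call-site change passes the scheduler object through so the helper can read both config.num_train_timesteps and alphas_cumprod. As a purely illustrative usage sketch (this is an assumption, not how this repository wires its loss; tensor shapes and variable names are hypothetical), the per-batch huber_c the helper returns could drive a standard Huber loss via torch.nn.functional.huber_loss, whose delta argument is the quadratic-to-linear transition point:

    import torch
    import torch.nn.functional as F

    # Hypothetical batch of latent noise predictions vs. targets.
    noise_pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)
    huber_c = 0.35  # per-batch value as returned by get_timesteps_and_huber_c

    # A smaller huber_c switches to the linear regime sooner, which the "snr"
    # schedule above produces at heavily noised timesteps.
    loss = F.huber_loss(noise_pred, target, reduction="none", delta=huber_c)
    loss = loss.mean()
    print(loss.item())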
