Commit 7a1d6c9

param upcasting
1 parent b7c0f95 commit 7a1d6c9

1 file changed (+8, -5)

examples/consistency_distillation/train_lcm_distill_lora_sdxl.py

Lines changed: 8 additions & 5 deletions
@@ -785,6 +785,13 @@ def main(args):
     )
     unet.add_adapter(lora_config)
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        for param in unet.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
+
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -855,11 +862,7 @@ def load_model_hook(models, input_dir):
     optimizer_class = torch.optim.AdamW
 
     # 12. Optimizer creation
-    params_to_optimize = []
-    for param in unet.parameters():
-        if param.requires_grad:
-            param.data = param.to(torch.float32)
-            params_to_optimize.append(param)
+    params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
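
The net effect: the upcast now happens right after unet.add_adapter(lora_config), so under fp16 mixed precision only the trainable LoRA parameters are promoted to float32 while the frozen base weights stay in fp16, and the optimizer section simply filters for trainable parameters instead of doing the upcast itself. Below is a minimal standalone sketch of the same pattern; the base and adapter layers are hypothetical stand-ins, not part of the training script.

import torch
import torch.nn as nn

# Hypothetical stand-ins: a frozen fp16 "base" layer and a trainable "adapter" layer.
base = nn.Linear(16, 16).to(dtype=torch.float16)
adapter = nn.Linear(16, 16).to(dtype=torch.float16)

# Freeze the base weights; only the adapter remains trainable.
for p in base.parameters():
    p.requires_grad_(False)

params = list(base.parameters()) + list(adapter.parameters())

# Only upcast trainable parameters into fp32; frozen base weights stay in fp16.
for param in params:
    if param.requires_grad:
        param.data = param.to(torch.float32)

# The optimizer only ever sees the (now fp32) trainable parameters.
params_to_optimize = filter(lambda p: p.requires_grad, params)
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)

Keeping the trainable weights and their AdamW state in float32 is the usual remedy for the update underflow and instability that pure fp16 adapter training tends to show.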
