LCM train scripts crash due to missing unet_time_cond_proj_dim argument #5829

@justindujardin

Describe the bug

There appears to be a set of issues related to the unet_time_cond_proj_dim argument when running the LCM training scripts in examples/consistency_distillation/.

Running the script as-is
First: the argument is never defined in parse_args, so running the script with the example arguments raises an AttributeError:

Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1148, in main
    w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
AttributeError: 'Namespace' object has no attribute 'unet_time_cond_proj_dim'

Adding unet_time_cond_proj_dim with default value None

Second: if you add the argument with a default value of None, the script still crashes, this time in guidance_scale_embedding, which cannot compute embedding_dim // 2 on None:

    parser.add_argument(
        "--unet_time_cond_proj_dim",
        type=int,
        default=None,
        help="unet_time_cond_proj_dim",
    )
Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1148, in main
    w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 312, in guidance_scale_embedding
    half_dim = embedding_dim // 2
TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'
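
For reference, here is roughly what guidance_scale_embedding computes (a sketch based on the LCM paper's guidance-scale embedding; the script's version may differ in details). The first operation is embedding_dim // 2, so None can never be a valid value for this argument:

    import math

    import torch
    import torch.nn.functional as F


    def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
        # Sinusoidal embedding of the guidance scale w (shape: [batch]),
        # analogous to a timestep embedding.
        w = w * 1000.0
        half_dim = embedding_dim // 2  # fails immediately if embedding_dim is None
        emb = math.log(10000.0) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # pad odd dims with a trailing zero
            emb = F.pad(emb, (0, 1))
        return emb  # shape: [batch, embedding_dim]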

Setting unet_time_cond_proj_dim to 32

If you add the argument with an integer value, the script also crashes, this time inside the UNet forward pass, presumably because the teacher UNet was not built with time_cond_proj_dim set, so its TimestepEmbedding has no cond_proj layer:

    parser.add_argument(
        "--unet_time_cond_proj_dim",
        type=int,
        default=32,
        help="unet_time_cond_proj_dim",
    )
Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1158, in main
    noise_pred = unet(
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File ".env/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "diffusers/src/diffusers/models/unet_2d_condition.py", line 932, in forward
    emb = self.time_embedding(t_emb, timestep_cond)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/src/diffusers/models/embeddings.py", line 225, in forward
    sample = sample + self.cond_proj(condition)
TypeError: 'NoneType' object is not callable

What is the intended value for this argument, and what fixes are needed to run the script without error?
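
For context, my best guess at the intended wiring (a sketch, not a confirmed fix) is that the student UNet must be instantiated with time_cond_proj_dim set, so that its TimestepEmbedding is created with a cond_proj layer instead of None. The value 256 below is an assumption taken from the published LCM checkpoints:

    from diffusers import UNet2DConditionModel

    pretrained_teacher_model = "runwayml/stable-diffusion-v1-5"  # as in the repro
    unet_time_cond_proj_dim = 256  # assumption: matches released LCM checkpoints

    teacher_unet = UNet2DConditionModel.from_pretrained(
        pretrained_teacher_model, subfolder="unet"
    )
    # Build the student from the teacher's config, overriding time_cond_proj_dim
    # so TimestepEmbedding gets a cond_proj nn.Linear; otherwise cond_proj is
    # None, which is what raises "'NoneType' object is not callable" above.
    unet = UNet2DConditionModel.from_config(
        teacher_unet.config, time_cond_proj_dim=unet_time_cond_proj_dim
    )
    # strict=False: the teacher checkpoint has no weights for the new cond_proj
    unet.load_state_dict(teacher_unet.state_dict(), strict=False)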

Reproduction

export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url='/home/myuser/laionart/chunk_00000/{00000..00153}.tar' \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub

Logs

No response

System Info

  • diffusers version: 0.24.0.dev0
  • Platform: Linux-5.15.0-1044-gcp-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Huggingface_hub version: 0.19.3
  • Transformers version: 4.35.2
  • Accelerate version: 0.24.1
  • xFormers version: not installed
  • Using GPU in script?: A100 80GB
  • Using distributed or parallel set-up in script?: No

Who can help?

@patrickvonplaten @sayakpaul
