### Describe the bug

There appears to be a set of issues related to the `unet_time_cond_proj_dim` argument when running the LCM training scripts in `examples/consistency_distillation/`.

**Running the script as-is**

First: the argument is not defined in `parse_args`, so running the script with the example arguments raises an error:
```
Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1148, in main
    w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
AttributeError: 'Namespace' object has no attribute 'unet_time_cond_proj_dim'
```
**Adding `unet_time_cond_proj_dim` with a default value of `None`**

Second: if you add the argument with a default value of `None`, the script crashes:
```python
parser.add_argument(
    "--unet_time_cond_proj_dim",
    type=int,
    default=None,
    help="unet_time_cond_proj_dim",
)
```
```
Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1148, in main
    w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 312, in guidance_scale_embedding
    half_dim = embedding_dim // 2
TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'
```
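For context, `guidance_scale_embedding` in the script builds a sinusoidal embedding of the guidance scale, so it needs a concrete integer dimension. Roughly paraphrased (my simplified reading of the script, not a verbatim copy):

```python
import torch

def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Sinusoidal embedding of the guidance scale w, shape (batch_size,).
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2  # line 312: fails when embedding_dim is None
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
```

So `None` can never be a valid value here; the argument has to be a positive integer, and presumably one that matches the UNet's `time_cond_proj_dim`.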
**Setting `unet_time_cond_proj_dim` to 32**

Third: if you add the argument with an integer value, the script also crashes:
```python
parser.add_argument(
    "--unet_time_cond_proj_dim",
    type=int,
    default=32,
    help="unet_time_cond_proj_dim",
)
```
```
Traceback (most recent call last):
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1309, in <module>
    main(args)
  File "diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py", line 1158, in main
    noise_pred = unet(
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File ".env/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "diffusers/src/diffusers/models/unet_2d_condition.py", line 932, in forward
    emb = self.time_embedding(t_emb, timestep_cond)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "diffusers/src/diffusers/models/embeddings.py", line 225, in forward
    sample = sample + self.cond_proj(condition)
TypeError: 'NoneType' object is not callable
```
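Looking at `diffusers/src/diffusers/models/embeddings.py`, my understanding is that `TimestepEmbedding` only creates its `cond_proj` layer when the UNet config sets `time_cond_proj_dim`, which is `null` for the `runwayml/stable-diffusion-v1-5` teacher. Simplified paraphrase (omitting the activation and second linear layer):

```python
import torch.nn as nn

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, time_embed_dim, cond_proj_dim=None):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        # cond_proj only exists when the UNet config sets time_cond_proj_dim
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

    def forward(self, sample, condition=None):
        if condition is not None:
            # crashes when a condition is passed but cond_proj is None
            sample = sample + self.cond_proj(condition)
        return self.linear_1(sample)
```

So passing `timestep_cond` to a student UNet that was never configured with a `time_cond_proj_dim` hits the `'NoneType' object is not callable` error above.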
What is the intended value for this argument, and what fixes are needed to run the script without error?
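For what it's worth, here is a minimal sketch of the kind of fix I would expect: create the student UNet from the teacher config with `time_cond_proj_dim` overridden by the new CLI argument. The exact wiring is my assumption, not the script's current code:

```python
from diffusers import UNet2DConditionModel

teacher_unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_teacher_model, subfolder="unet"
)
# Prefer the teacher's own time_cond_proj_dim, falling back to the CLI value
# when the teacher config has None (as SD v1-5 does).
time_cond_proj_dim = (
    teacher_unet.config.time_cond_proj_dim
    if teacher_unet.config.time_cond_proj_dim is not None
    else args.unet_time_cond_proj_dim
)
unet = UNet2DConditionModel.from_config(
    teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim
)
# strict=False: the teacher checkpoint has no weights for the new cond_proj layer.
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
```

With that, `args.unet_time_cond_proj_dim` and the UNet's `cond_proj` layer would agree on the embedding dimension, so `guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)` produces a tensor that `cond_proj` can consume.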
### Reproduction

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
  --pretrained_teacher_model=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision=fp16 \
  --resolution=512 \
  --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
  --max_train_steps=1000 \
  --max_train_samples=4000000 \
  --dataloader_num_workers=8 \
  --train_shards_path_or_url='/home/myuser/laionart/chunk_00000/{00000..00153}.tar' \
  --validation_steps=200 \
  --checkpointing_steps=200 --checkpoints_total_limit=10 \
  --train_batch_size=12 \
  --gradient_checkpointing --enable_xformers_memory_efficient_attention \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --resume_from_checkpoint=latest \
  --report_to=wandb \
  --seed=453645634 \
  --push_to_hub
```
### Logs

_No response_
### System Info

- `diffusers` version: 0.24.0.dev0
- Platform: Linux-5.15.0-1044-gcp-x86_64-with-glibc2.31
- Python version: 3.10.12
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Huggingface_hub version: 0.19.3
- Transformers version: 4.35.2
- Accelerate version: 0.24.1
- xFormers version: not installed
- Using GPU in script?: A100 80GB
- Using distributed or parallel set-up in script?: No
### Who can help?