@@ -747,17 +747,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
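
The arithmetic in the hunk above can be checked by hand. Below is a minimal, runnable sketch of the same calculation with made-up numbers (dataset size, batch size, process count, and warmup steps are illustrative assumptions, not values from this PR). Both totals are scaled by `accelerator.num_processes` because, under `accelerate`'s default settings, the prepared scheduler is stepped once per process for every optimizer step.

```python
import math

# Illustrative values only (not taken from this PR).
dataset_len = 1_000                # training samples
train_batch_size = 4               # per-device batch size
num_processes = 2                  # processes launched by `accelerate`
gradient_accumulation_steps = 2
num_train_epochs = 10
lr_warmup_steps = 500

# Length of the dataloader before `accelerator.prepare()` shards it.
len_train_dataloader = math.ceil(dataset_len / train_batch_size)                        # 250

# The scheduler is created before `prepare()`, so the sharded length is estimated here.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)   # 125
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)                                                                                        # 63

# Scheduler totals are scaled by num_processes, mirroring the new code above.
num_training_steps_for_scheduler = num_train_epochs * num_update_steps_per_epoch * num_processes  # 1260
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                                  # 1000

print(num_training_steps_for_scheduler, num_warmup_steps_for_scheduler)
```
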
@@ -782,8 +787,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
794+                 f"The length of the 'train_dataloader' after 'accelerator.prepare' ({ len (train_dataloader )}  
795+                 f"the expected length ({ len_train_dataloader_after_sharding }  
796+                 f"This inconsistency may result in the learning rate scheduler not functioning properly." 
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
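
The warning added in this second hunk fires when the dataloader length observed after `accelerator.prepare` no longer matches the pre-`prepare` estimate that was used to size the scheduler. A minimal sketch of that condition, again with made-up numbers (the 125-vs-124 batch counts are assumptions chosen to force a mismatch, e.g. from `drop_last` or an uneven shard):

```python
import math

# Illustrative values only (not taken from this PR).
num_processes = 2
gradient_accumulation_steps = 2
num_train_epochs = 10

# Estimate made before accelerator.prepare() (ceil-based sharding): 125 batches per process.
len_train_dataloader_after_sharding = 125
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)                                                                                            # 1260

# Length actually observed after prepare(); suppose each process got only 124 batches.
len_train_dataloader = 124
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 62
max_train_steps = num_train_epochs * num_update_steps_per_epoch                             # 620

# The new check: the scheduler was sized for 1260 steps, but training will
# only step it 620 * 2 = 1240 times, so the warning would be emitted.
if num_training_steps_for_scheduler != max_train_steps * num_processes:
    print(f"scheduler sized for {num_training_steps_for_scheduler} steps, "
          f"but will be stepped {max_train_steps * num_processes} times")
```
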