@@ -387,6 +387,73 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def _get_wsd_scheduler_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    num_cycles: float,
+    min_lr_ratio: float,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    if current_step < num_warmup_steps + num_stable_steps:
+        return 1.0
+    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
+        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
+        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+        return (1.0 - min_lr_ratio) * value + min_lr_ratio
+    return min_lr_ratio
+
+
+def get_wsd_schedule(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    min_lr_ratio: float = 0,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
419+ """
420+ Create a schedule with a learning rate that has three stages:
421+ 1. linear increase from 0 to initial lr.
422+ 2. constant lr (equal to initial lr).
423+ 3. decrease following the values of the cosine function between the initial lr set in the optimizer to
424+ a fraction of initial lr.
425+
426+ Args:
427+ optimizer ([`~torch.optim.Optimizer`]):
428+ The optimizer for which to schedule the learning rate.
429+ num_warmup_steps (`int`):
430+ The number of steps for the warmup phase.
431+ num_stable_steps (`int`):
432+ The number of steps for the stable phase.
433+ num_decay_steps (`int`):
434+ The number of steps for the cosine annealing phase.
435+ min_lr_ratio (`float`, *optional*, defaults to 0):
436+ The minimum learning rate as a ratio of the initial learning rate.
437+ num_cycles (`float`, *optional*, defaults to 0.5):
438+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
439+ following a half-cosine).
440+ last_epoch (`int`, *optional*, defaults to -1):
441+ The index of the last epoch when resuming training.
442+
443+ Return:
444+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
445+ """
+    lr_lambda = partial(
+        _get_wsd_scheduler_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_stable_steps=num_stable_steps,
+        num_decay_steps=num_decay_steps,
+        min_lr_ratio=min_lr_ratio,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
 TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.LINEAR: get_linear_schedule_with_warmup,
     SchedulerType.COSINE: get_cosine_schedule_with_warmup,
@@ -397,6 +464,7 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
     SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
     SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
+    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
 }
 
 
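For reviewers who want to poke at the new schedule locally, the following is a minimal sanity-check sketch, not part of the diff. It assumes the function is importable as transformers.optimization.get_wsd_schedule once this lands; the toy model, optimizer, and phase lengths below are illustrative assumptions rather than anything taken from the PR.

    import torch
    from transformers.optimization import get_wsd_schedule

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # 10 warmup steps, 100 stable steps, 40 decay steps, floor at 10% of the peak lr.
    scheduler = get_wsd_schedule(
        optimizer,
        num_warmup_steps=10,
        num_stable_steps=100,
        num_decay_steps=40,
        min_lr_ratio=0.1,
    )

    lrs = []
    for _ in range(160):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])

    assert abs(lrs[9] - 1e-3) < 1e-9    # warmup reaches the peak lr at step 10
    assert abs(lrs[100] - 1e-3) < 1e-9  # stable phase holds the peak lr
    assert abs(lrs[-1] - 1e-4) < 1e-9   # after decay, lr sits at min_lr_ratio * peak

With the default num_cycles=0.5, the decay phase is a single half-cosine from the peak lr down to min_lr_ratio times the peak, after which the multiplier stays at min_lr_ratio indefinitely.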