Commit 7b1170b

Add WSD scheduler (#30231)

* Added WSD scheduler.
* Added tests.
* Fixed errors.
* Fix formatting.
* CI fixes.

1 parent 90cb55b

6 files changed: +82 -0 lines changed

docs/source/en/main_classes/optimizer_schedules.md

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ The `.optimization` module provides:
 
 [[autodoc]] get_inverse_sqrt_schedule
 
+[[autodoc]] get_wsd_schedule
+
 ### Warmup (TensorFlow)
 
 [[autodoc]] WarmUp

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3911,6 +3911,7 @@
         "get_linear_schedule_with_warmup",
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
+        "get_wsd_schedule",
     ]
     _import_structure["pytorch_utils"] = [
         "Conv1D",
@@ -8414,6 +8415,7 @@
         get_linear_schedule_with_warmup,
         get_polynomial_decay_schedule_with_warmup,
         get_scheduler,
+        get_wsd_schedule,
     )
     from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer

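With the export wiring above, the new factory is importable from the top-level package. A quick check, assuming a transformers build that includes this commit and an environment with torch installed (not part of the commit itself):

```python
# The name resolves through the lazy _import_structure mapping to
# transformers.optimization.get_wsd_schedule.
from transformers import get_wsd_schedule

print(get_wsd_schedule.__module__)  # expected: "transformers.optimization"
```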
src/transformers/optimization.py

Lines changed: 68 additions & 0 deletions
@@ -387,6 +387,73 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def _get_wsd_scheduler_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    num_cycles: float,
+    min_lr_ratio: float,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    if current_step < num_warmup_steps + num_stable_steps:
+        return 1.0
+    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
+        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
+        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+        return (1.0 - min_lr_ratio) * value + min_lr_ratio
+    return min_lr_ratio
+
+
+def get_wsd_schedule(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    min_lr_ratio: float = 0,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that has three stages:
+    1. linear increase from 0 to initial lr.
+    2. constant lr (equal to initial lr).
+    3. decrease following the values of the cosine function from the initial lr set in the optimizer to
+       a fraction of the initial lr.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_stable_steps (`int`):
+            The number of steps for the stable phase.
+        num_decay_steps (`int`):
+            The number of steps for the cosine annealing phase.
+        min_lr_ratio (`float`, *optional*, defaults to 0):
+            The minimum learning rate as a ratio of the initial learning rate.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    lr_lambda = partial(
+        _get_wsd_scheduler_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_stable_steps=num_stable_steps,
+        num_decay_steps=num_decay_steps,
+        min_lr_ratio=min_lr_ratio,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
 TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.LINEAR: get_linear_schedule_with_warmup,
     SchedulerType.COSINE: get_cosine_schedule_with_warmup,
@@ -397,6 +464,7 @@ def get_cosine_with_min_lr_schedule_with_warmup(
     SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
     SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
     SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
+    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
 }

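The schedule added above has three phases: a linear warmup, a constant (stable) plateau at the initial learning rate, and a cosine decay down to min_lr_ratio times the initial learning rate. A minimal usage sketch, assuming torch is installed; the model, optimizer, and step counts below are placeholders for illustration, not part of this commit:

```python
# Minimal usage sketch of get_wsd_schedule (placeholder model/hyperparameters).
import torch

from transformers import get_wsd_schedule

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)

scheduler = get_wsd_schedule(
    optimizer,
    num_warmup_steps=2,   # lr ramps linearly from 0 to 0.1
    num_stable_steps=4,   # lr held at 0.1
    num_decay_steps=4,    # half-cosine decay toward min_lr_ratio * 0.1
    min_lr_ratio=0.1,
)

for step in range(12):
    print(step, round(scheduler.get_last_lr()[0], 4))
    optimizer.step()
    scheduler.step()
```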
src/transformers/trainer_utils.py

Lines changed: 1 addition & 0 deletions
@@ -412,6 +412,7 @@ class SchedulerType(ExplicitEnum):
     INVERSE_SQRT = "inverse_sqrt"
     REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
     COSINE_WITH_MIN_LR = "cosine_with_min_lr"
+    WARMUP_STABLE_DECAY = "warmup_stable_decay"
 
 
 class TrainerMemoryTracker:

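The new string value is what a user would pass to select the scheduler by name (presumably via the Trainer's lr_scheduler_type argument; wiring the required step counts through the Trainer is outside this diff). A small sketch of how the value round-trips to the enum member and resolves to the factory registered in optimization.py:

```python
# Sketch: the string added here maps back to the enum member and, via the
# TYPE_TO_SCHEDULER_FUNCTION entry added in this commit, to get_wsd_schedule.
from transformers.optimization import TYPE_TO_SCHEDULER_FUNCTION, get_wsd_schedule
from transformers.trainer_utils import SchedulerType

name = SchedulerType("warmup_stable_decay")
assert name is SchedulerType.WARMUP_STABLE_DECAY
assert TYPE_TO_SCHEDULER_FUNCTION[name] is get_wsd_schedule
```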
src/transformers/utils/dummy_pt_objects.py

Lines changed: 4 additions & 0 deletions
@@ -10023,6 +10023,10 @@ def get_scheduler(*args, **kwargs):
     requires_backends(get_scheduler, ["torch"])
 
 
+def get_wsd_schedule(*args, **kwargs):
+    requires_backends(get_wsd_schedule, ["torch"])
+
+
 class Conv1D(metaclass=DummyObject):
     _backends = ["torch"]

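The dummy mirrors the real function for installs without PyTorch: the top-level import still succeeds, and calling the stub fails fast with a message naming the missing backend. An illustration only, assuming an environment without torch (which is exactly when this dummy module is used):

```python
# Illustration (torch NOT installed): the dummy stands in for the real function.
from transformers import get_wsd_schedule

try:
    get_wsd_schedule()
except ImportError as err:
    print(err)  # the message points at the missing "torch" backend
```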
tests/optimization/test_optimization.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,7 @@
     get_inverse_sqrt_schedule,
     get_linear_schedule_with_warmup,
     get_polynomial_decay_schedule_with_warmup,
+    get_wsd_schedule,
 )
 
 
@@ -150,6 +151,10 @@ def test_schedulers(self):
                 {"num_warmup_steps": 2},
                 [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
             ),
+            get_wsd_schedule: (
+                {"num_warmup_steps": 2, "num_stable_steps": 2, "num_decay_steps": 3, "min_lr_ratio": 0.1},
+                [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0],
+            ),
         }
 
         for scheduler_func, data in scheds.items():

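For reference, the expected values in the new test entry follow directly from the lambda added in optimization.py with a base learning rate of 10.0. A standalone sanity check of the arithmetic (re-deriving the schedule factors rather than importing the scheduler):

```python
# Re-derivation of the expected LR values used in the test
# (base lr 10.0, warmup=2, stable=2, decay=3, min_lr_ratio=0.1, num_cycles=0.5).
import math

def wsd_factor(step, warmup=2, stable=2, decay=3, num_cycles=0.5, min_lr_ratio=0.1):
    if step < warmup:
        return step / max(1, warmup)
    if step < warmup + stable:
        return 1.0
    if step < warmup + stable + decay:
        progress = (step - warmup - stable) / max(1, decay)
        value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))
        return (1.0 - min_lr_ratio) * value + min_lr_ratio
    return min_lr_ratio

print([round(10.0 * wsd_factor(s), 3) for s in range(10)])
# [0.0, 5.0, 10.0, 10.0, 10.0, 7.75, 3.25, 1.0, 1.0, 1.0]
```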