Commit a3034c7

Add inverse sqrt learning rate scheduler (#21495)

* Added inverse sqrt lr scheduler
* Updated get_scheduler in src/transformers/optimization.py
* Updated src/transformers/__init__.py
* Added inverse sqrt lr scheduler test
* Updated docs/source/en/main_classes/optimizer_schedules.mdx
* Ran style and quality scripts
* Fix get_inverse_sqrt_schedule docstring
* Comment implementation URL

1 parent b9af152 commit a3034c7

File tree: 6 files changed, +54 -0 lines changed

docs/source/en/main_classes/optimizer_schedules.mdx

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ The `.optimization` module provides:
 
 [[autodoc]] get_polynomial_decay_schedule_with_warmup
 
+[[autodoc]] get_inverse_sqrt_schedule
+
 ### Warmup (TensorFlow)
 
 [[autodoc]] WarmUp

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -2588,6 +2588,7 @@
         "get_constant_schedule_with_warmup",
         "get_cosine_schedule_with_warmup",
         "get_cosine_with_hard_restarts_schedule_with_warmup",
+        "get_inverse_sqrt_schedule",
         "get_linear_schedule_with_warmup",
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
@@ -5659,6 +5660,7 @@
         get_constant_schedule_with_warmup,
         get_cosine_schedule_with_warmup,
         get_cosine_with_hard_restarts_schedule_with_warmup,
+        get_inverse_sqrt_schedule,
         get_linear_schedule_with_warmup,
         get_polynomial_decay_schedule_with_warmup,
         get_scheduler,

src/transformers/optimization.py

Lines changed: 40 additions & 0 deletions

@@ -220,13 +220,50 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def get_inverse_sqrt_schedule(
+    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
+):
+    """
+    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+    warmup period during which the lr increases linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+            Time scale.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    # Note: this implementation is adapted from
+    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+    if timescale is None:
+        timescale = num_warmup_steps
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        shift = timescale - num_warmup_steps
+        decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+        return decay
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
 TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.LINEAR: get_linear_schedule_with_warmup,
     SchedulerType.COSINE: get_cosine_schedule_with_warmup,
     SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
 }
 
 
@@ -263,6 +300,9 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
 
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
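
For reference, a minimal sketch of how the new scheduler behaves once merged. The toy model and base lr of 10.0 are illustrative choices made to match the test vector added below, not anything prescribed by the diff:

    import torch
    from transformers import get_inverse_sqrt_schedule

    # Any parameters will do for illustration; lr=10.0 mirrors the test below.
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=10.0)
    scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=2)

    lrs = []
    for _ in range(10):
        lrs.append(round(scheduler.get_last_lr()[0], 3))
        optimizer.step()
        scheduler.step()

    print(lrs)
    # Expected: [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714]

Note that with the default timescale = num_warmup_steps, the post-warmup factor 1 / sqrt((step + timescale - num_warmup_steps) / timescale) equals exactly 1.0 at step == num_warmup_steps, so the schedule is continuous at the warmup boundary.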

src/transformers/trainer_utils.py

Lines changed: 1 addition & 0 deletions

@@ -363,6 +363,7 @@ class SchedulerType(ExplicitEnum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
 
 
 class TrainerMemoryTracker:
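
Because the enum now maps the string "inverse_sqrt" to SchedulerType.INVERSE_SQRT, and get_scheduler special-cases it alongside the constant schedules (only num_warmup_steps is required, not num_training_steps), the scheduler can also be requested by name. A sketch, assuming get_scheduler accepts the enum's string value as it does for the other scheduler types:

    import torch
    from transformers import get_scheduler

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # "inverse_sqrt" resolves to SchedulerType.INVERSE_SQRT; like the constant
    # schedules, it needs num_warmup_steps but no num_training_steps.
    scheduler = get_scheduler("inverse_sqrt", optimizer, num_warmup_steps=500)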

src/transformers/utils/dummy_pt_objects.py

Lines changed: 4 additions & 0 deletions

@@ -7019,6 +7019,10 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
     requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
 
 
+def get_inverse_sqrt_schedule(*args, **kwargs):
+    requires_backends(get_inverse_sqrt_schedule, ["torch"])
+
+
 def get_linear_schedule_with_warmup(*args, **kwargs):
     requires_backends(get_linear_schedule_with_warmup, ["torch"])

tests/optimization/test_optimization.py

Lines changed: 5 additions & 0 deletions

@@ -33,6 +33,7 @@
     get_constant_schedule_with_warmup,
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_inverse_sqrt_schedule,
     get_linear_schedule_with_warmup,
     get_polynomial_decay_schedule_with_warmup,
 )
@@ -145,6 +146,10 @@ def test_schedulers(self):
                 {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                 [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
             ),
+            get_inverse_sqrt_schedule: (
+                {"num_warmup_steps": 2},
+                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
+            ),
         }
 
         for scheduler_func, data in scheds.items():
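
The expected values in the new test case follow directly from the formula in lr_lambda. A standalone re-derivation (not library code) for base lr 10.0 and num_warmup_steps=2:

    import math

    def expected_lr(step, base_lr=10.0, num_warmup_steps=2, timescale=None):
        # Mirrors lr_lambda from get_inverse_sqrt_schedule, scaled by base_lr.
        timescale = timescale if timescale is not None else num_warmup_steps
        if step < num_warmup_steps:
            return base_lr * step / max(1, num_warmup_steps)
        shift = timescale - num_warmup_steps
        return base_lr / math.sqrt((step + shift) / timescale)

    print([round(expected_lr(step), 3) for step in range(10)])
    # [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714]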
