@@ -220,13 +220,50 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
+def get_inverse_sqrt_schedule(
+    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
+):
+    """
+    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+            Time scale of the decay; after warmup, the lr is multiplied by `1 / sqrt((step + timescale - num_warmup_steps) / timescale)`.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    # Note: this implementation is adapted from
+    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+    if timescale is None:
+        timescale = num_warmup_steps
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        shift = timescale - num_warmup_steps
+        decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+        return decay
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
 TYPE_TO_SCHEDULER_FUNCTION = {
     SchedulerType.LINEAR: get_linear_schedule_with_warmup,
     SchedulerType.COSINE: get_cosine_schedule_with_warmup,
     SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
 }
 
 
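A quick numeric check of the schedule this hunk adds (a standalone sketch, not part of the patch; `num_warmup_steps` and the probed steps are arbitrary):

```python
import math

num_warmup_steps = 500
timescale = num_warmup_steps  # the default when timescale is None
shift = timescale - num_warmup_steps  # 0 under the default

for step in (250, 500, 2000, 4500):
    if step < num_warmup_steps:
        mult = step / max(1, num_warmup_steps)  # linear warmup branch
    else:
        mult = 1.0 / math.sqrt((step + shift) / timescale)  # inverse-sqrt branch
    print(step, round(mult, 3))
# 250  -> 0.5   (halfway through warmup)
# 500  -> 1.0   (warmup hands off at exactly the base lr)
# 2000 -> 0.5   (1 / sqrt(4))
# 4500 -> 0.333 (1 / sqrt(9))
```

With the default `timescale`, the shift is zero, so warmup ends at a multiplier of exactly 1.0 and the decay continues smoothly from there; a larger `timescale` stretches the decay out.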
@@ -263,6 +300,9 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
 
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
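For context, a minimal usage sketch of the new scheduler (the toy model, base lr, and step counts are illustrative; it assumes `get_inverse_sqrt_schedule` is re-exported from the top-level `transformers` package like the other schedulers):

```python
import torch
from torch.optim import AdamW

from transformers import get_inverse_sqrt_schedule

# Toy setup: a single linear layer with an arbitrary base lr.
model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=1e-3)

# Warm up linearly for 100 steps, then decay as 1/sqrt;
# timescale defaults to num_warmup_steps.
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=100)

for _ in range(300):
    optimizer.step()
    scheduler.step()
```

Equivalently, `get_scheduler` can dispatch to it through `SchedulerType.INVERSE_SQRT`, as the second hunk shows; like the constant-with-warmup schedule, it needs only `num_warmup_steps`, not `num_training_steps`.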