@@ -507,8 +507,8 @@ def __init__(
507507 dtype : torch .dtype ,
508508 short_factor : List [float ],
509509 long_factor : List [float ],
510- short_mscale : float = 1.1 ,
511- long_mscale : float = 1.225 ,
510+ short_mscale : float = 1.0 ,
511+ long_mscale : float = 1.0 ,
512512 ):
513513 super ().__init__ ()
514514
@@ -530,6 +530,16 @@ def __init__(
530530 self .short_mscale = short_mscale
531531 self .long_mscale = long_mscale
532532
533+ scale = (self .max_position_embeddings /
534+ self .original_max_position_embeddings )
535+
536+ if scale <= 1.0 :
537+ self .scaling_factor = 1.0
538+ else :
539+ self .scaling_factor = math .sqrt (
540+ 1 + math .log (scale ) /
541+ math .log (self .original_max_position_embeddings ))
542+
533543 short_cache = self ._compute_cos_sin_cache (
534544 original_max_position_embeddings , short_factor , short_mscale )
535545 short_cache = short_cache .to (dtype )
@@ -565,8 +575,8 @@ def _compute_cos_sin_cache(
565575 inv_freq = self ._compute_inv_freq (rescale_factors )
566576 t = torch .arange (max_position_embeddings , dtype = torch .float )
567577 freqs = torch .einsum ("i,j -> ij" , t , inv_freq )
568- cos = freqs .cos () * mscale
569- sin = freqs .sin () * mscale
578+ cos = freqs .cos () * mscale * self . scaling_factor
579+ sin = freqs .sin () * mscale * self . scaling_factor
570580 cache = torch .cat ((cos , sin ), dim = - 1 )
571581 return cache
572582
0 commit comments