@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 
 
 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
@@ -150,6 +159,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
         else:
             self.bias = None
 
+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """
@@ -180,6 +196,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
         return out
 
     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)
@@ -245,9 +265,17 @@ def __init__(
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",
@@ -271,6 +299,10 @@ def load_hook(
         return state_dict
 
     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                             device=module.weight.device,
                             dtype=module.weight.dtype,
                             online_quant=(quantization_config.quantization_mode == "online"),
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         if quantization_config.quantization_mode == "offline":
                             model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                             bias=module.bias is not None,
                             device=module.weight.device,
                             dtype=module.weight.dtype,
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         model._modules[name].requires_grad_(False)
                     has_been_replaced = True
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
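For readers skimming the hunks above: the new `use_rms_norm` flag makes both `BitLinear` and `AutoBitLinear` run a `LlamaRMSNorm` over the incoming activations before they are quantized, and `_replace_with_bitnet_linear` simply forwards `quantization_config.use_rms_norm` / `quantization_config.rms_norm_eps` into every replaced layer. The snippet below is a minimal, self-contained sketch of that ordering, not the library code: `SimpleRMSNorm`, `activation_quant`, and `quantize_activations` are illustrative stand-ins, and the int8 absmax quantization body is an assumption about what `activation_quant` does (its body is not shown in this diff).

```python
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Illustrative stand-in for the LlamaRMSNorm created when use_rms_norm=True."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)


def activation_quant(x: torch.Tensor, num_bits: int = 8):
    """Per-token absmax quantization to the signed int8 range (assumed behavior of activation_quant)."""
    qn, qp = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    scale = qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(qn, qp).to(torch.int8), scale


def quantize_activations(x: torch.Tensor, use_rms_norm: bool = False, rms_norm_eps: float = 1e-6):
    """Same ordering as the patched forward(): optional RMSNorm first, then activation quantization."""
    if use_rms_norm:
        norm = SimpleRMSNorm(x.shape[-1], eps=rms_norm_eps)
        x = norm(x)
    return activation_quant(x)


with torch.no_grad():
    x = torch.randn(2, 4, 16)  # (batch, seq_len, hidden_size), example shapes only
    q, scale = quantize_activations(x, use_rms_norm=True)
print(q.dtype, tuple(q.shape), tuple(scale.shape))  # torch.int8 (2, 4, 16) (2, 4, 1)
```

In the library itself the flag is expected to be driven by the quantization config, since the replacement helper reads `quantization_config.use_rms_norm` and `quantization_config.rms_norm_eps` as shown in the last three hunks.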