@@ -729,7 +729,24 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        loaded_weights = [(name, loaded_weight)
-                          for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
+        # Add a fake all-zeros bias for k_proj to the state_dict.
+        weights = _create_fake_bias_for_k_proj(weights)
+        return loader.load_weights(weights, mapper=mapper)
+
+
+def _create_fake_bias_for_k_proj(
+        weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Iterable[Tuple[str, torch.Tensor]]:
+    """
+    Create an all-zeros bias for each k_proj weight in the self-attention
+    layers so that the k_proj slice of the fused qkv_proj bias is zero.
+    """
+    for name, weight in weights:
+        if ".self_attn.k_proj.weight" in name:
+            # The bias length matches the weight's output dimension.
+            bias = torch.zeros(weight.size(0))
+            bias_name = name.replace("weight", "bias")
+            yield from [(name, weight), (bias_name, bias)]
+        else:
+            yield name, weight
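
For reference, a minimal standalone sketch of what this generator produces, using illustrative layer names and toy shapes (not taken from a real checkpoint): each k_proj weight is followed by a synthesized zero bias, and every other entry passes through unchanged.

from typing import Iterable, Tuple

import torch


def _create_fake_bias_for_k_proj(
        weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Same logic as the hunk above: emit a zeros bias right after
    # each k_proj weight; pass every other entry through as-is.
    for name, weight in weights:
        if ".self_attn.k_proj.weight" in name:
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        else:
            yield name, weight


# Toy "checkpoint" with one k_proj weight and one unrelated weight.
toy_weights = [
    ("model.decoder.layers.0.self_attn.k_proj.weight", torch.randn(4, 4)),
    ("model.decoder.layers.0.self_attn.v_proj.weight", torch.randn(4, 4)),
]
for name, tensor in _create_fake_bias_for_k_proj(toy_weights):
    print(name, tuple(tensor.shape))
# model.decoder.layers.0.self_attn.k_proj.weight (4, 4)
# model.decoder.layers.0.self_attn.k_proj.bias (4,)
# model.decoder.layers.0.self_attn.v_proj.weight (4, 4)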