@@ -604,8 +604,9 @@ def __init__(
 
         rope_theta = getattr(config, "rope_theta", 10000)
 
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -861,8 +862,9 @@ def layer_fn(prefix):
             cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
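
The two-step lookup matters because `getattr` only applies its default when the attribute is *missing*; a config that explicitly sets `head_dim = None` would previously leave `head_dim` as `None` instead of falling back to `hidden_size // num_attention_heads`. A minimal standalone sketch of the difference, using a hypothetical config object (not the actual model config class):

```python
# Minimal sketch: why an explicit head_dim = None needs the extra check.
from types import SimpleNamespace

# Hypothetical config whose head_dim attribute exists but is None.
config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=None)

# Old pattern: getattr falls back only when the attribute is absent,
# so the explicit None is returned and head_dim stays None.
old_head_dim = getattr(config, "head_dim",
                       config.hidden_size // config.num_attention_heads)
assert old_head_dim is None

# New pattern: treat an explicit None the same as a missing attribute.
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
    head_dim = config.hidden_size // config.num_attention_heads
assert head_dim == 128
```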