2 files changed in vllm/model_executor/models (+10 −2 lines).

@@ -111,6 +111,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -136,7 +137,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -200,6 +203,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
The same change is applied in the second file:

@@ -165,6 +165,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -190,7 +191,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -252,6 +255,7 @@ def __init__(
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
+            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
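For readers skimming the diff, here is a minimal, self-contained sketch of the getattr fallback both hunks rely on. TinyConfig and resolve_head_dim are hypothetical stand-ins, not part of vLLM or transformers; the patched code reads head_dim from the transformers MixtralConfig when the attribute is present and otherwise derives it from the hidden size and head count, as before.

from typing import Optional


class TinyConfig:
    """Hypothetical stand-in for transformers' MixtralConfig in this sketch."""

    def __init__(self, hidden_size: int, num_attention_heads: int,
                 head_dim: Optional[int] = None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if head_dim is not None:
            # Only set the attribute when given, so getattr can find it missing.
            self.head_dim = head_dim


def resolve_head_dim(config: TinyConfig) -> int:
    # Same pattern as the patch: prefer an explicit head_dim, otherwise fall
    # back to the conventional hidden_size // num_attention_heads.
    return getattr(config, "head_dim",
                   config.hidden_size // config.num_attention_heads)


# Without head_dim the fallback applies: 4096 // 32 == 128 (Mixtral-8x7B-like sizes).
assert resolve_head_dim(TinyConfig(4096, 32)) == 128
# An explicit head_dim takes precedence, even when it differs from the derived value.
assert resolve_head_dim(TinyConfig(4096, 32, head_dim=96)) == 96

Because getattr only falls back when the attribute is absent, a config that defines head_dim explicitly always overrides the derived value, which is the behavior the added comment in the diff describes.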