Support Deepseek-V2 #4650
Changes from all commits
@@ -1,9 +1,10 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name)
+    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)

 __all__ = [
     "fused_moe",
     "fused_topk",
     "fused_experts",
     "get_config_file_name",
+    "grouped_topk",
 ]
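The newly exported grouped_topk is the group-limited expert routing used by DeepSeek-V2's MoE gate: experts are split into groups, only a few top-scoring groups may contribute experts, and the usual top-k is then taken inside that restricted set. The sketch below is only an illustration of that idea; the function name group_limited_topk, its signature, and the softmax-then-mask ordering are assumptions, not vLLM's actual implementation.

```python
import torch


def group_limited_topk(gating_logits: torch.Tensor, topk: int,
                       num_groups: int, topk_groups: int):
    """Sketch of group-limited expert routing (DeepSeek-V2 style MoE gate)."""
    num_tokens, num_experts = gating_logits.shape
    scores = torch.softmax(gating_logits, dim=-1)

    # Score each expert group by its best expert, keep the top groups.
    group_scores = scores.view(num_tokens, num_groups, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_groups, dim=-1).indices

    # Mask out every expert that belongs to a non-selected group.
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_groups, num_experts // num_groups).reshape(
            num_tokens, num_experts)
    masked_scores = scores.masked_fill(expert_mask == 0, 0.0)

    # Ordinary top-k within the surviving experts.
    return torch.topk(masked_scores, k=topk, dim=-1)


# Example: 16 experts in 4 groups; route each token to 4 experts drawn
# from the 2 best-scoring groups.
weights, ids = group_limited_topk(torch.randn(8, 16), topk=4,
                                  num_groups=4, topk_groups=2)
```

Restricting candidates to a few high-scoring groups is what DeepSeek-V2's group-routing hyperparameters control; the exported grouped_topk is the fused counterpart of this selection step.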
@@ -610,6 +610,119 @@ def forward(
        return query.flatten(-2), key.flatten(-2)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):

Review comment: This is extremely similar to ...

    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        print("Cache shape", cache.shape)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key
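As a quick sanity check of the mscale computed in __init__ above: the scaling reduces to a ratio of two yarn_get_mscale calls times attn_factor, so when mscale equals mscale_all_dim the ratio cancels to 1.0 and the cos/sin cache is left unscaled. The numbers below are placeholders for illustration, not necessarily DeepSeek-V2's shipped configuration.

```python
import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Same formula as in the diff above.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Illustrative values only.
scaling_factor, mscale, mscale_all_dim, attn_factor = 40.0, 0.707, 0.707, 1.0

final_mscale = (yarn_get_mscale(scaling_factor, mscale) /
                yarn_get_mscale(scaling_factor, mscale_all_dim) * attn_factor)
print(final_mscale)  # 1.0 -- equal mscale and mscale_all_dim cancel out
```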


class GemmaRotaryEmbedding(RotaryEmbedding):

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:

@@ -679,6 +792,19 @@ def get_rope(
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)

Review comment: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L369 , why is_neox_style=False?
Reply: DeepSeek-V3 uses half mode instead of interleave.
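For context on that exchange, the is_neox_style flag only changes how rotary dimensions are paired. The small standalone illustration below mirrors the _rotate_neox and _rotate_gptj helpers used in the forward() shown earlier; the example tensor is arbitrary and this is not the vLLM kernel itself.

```python
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # "Half" mode: pair dimension i with dimension i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # Interleaved mode: pair dimension 2i with dimension 2i + 1.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.arange(8, dtype=torch.float32)
print(rotate_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])
```

In half (neox) mode dimension i is rotated against dimension i + rotary_dim / 2; in interleaved (gptj) mode dimension 2i is rotated against dimension 2i + 1.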
        # The correct one should be "longrope" but keep "su" here
        # for backward compatible
        elif scaling_type == "su" or scaling_type == "longrope":
Review comment: Is this true? According to vllm/vllm/attention/backends/flash_attn.py (line 16 in 99eff67) ... deepseek_v2.py ... -- that should make it quite a bit simpler :)
Reply: Thanks, I will test it later with the latest flash attn.