diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index fc3951cfd3..5aa122618e 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -691,6 +691,7 @@ from .utils import logging if version.parse(transformers.__version__) >= version.parse("4.51.0"): + from .models.deepseek_v3 import DeepseekV3ForCausalLM, DeepseekV3Model, DeepseekV3PreTrainedModel from .models.qwen3 import Qwen3ForCausalLM, Qwen3Model, Qwen3PreTrainedModel if version.parse(transformers.__version__) >= version.parse("4.51.3"): diff --git a/mindone/transformers/modeling_layers.py b/mindone/transformers/modeling_layers.py new file mode 100644 index 0000000000..af67b088a9 --- /dev/null +++ b/mindone/transformers/modeling_layers.py @@ -0,0 +1,285 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC +from functools import partial +from typing import Optional + +from transformers.utils import logging + +import mindspore as ms +import mindspore.nn as nn +from mindspore import mint + +from .cache_utils import Cache +from .modeling_outputs import ( + BaseModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from .models.auto import AutoModel +from .processing_utils import Unpack +from .utils import TransformersKwargs + +logger = logging.get_logger(__name__) + + +class GradientCheckpointingLayer(nn.Cell): + """Base class for layers with gradient checkpointing. + + This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled + (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is + enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`. + + Important: + + When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states) + must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients. + + Example: + + ```python + >>> # Correct - hidden_states passed as positional arg + >>> out = self.layer(hidden_states, attention_mask=attention_mask) + + >>> # Incorrect - hidden_states passed as keyword arg + >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask) + ``` + """ + + gradient_checkpointing = False + + def __call__(self, *args, **kwargs): + if self.gradient_checkpointing and self.training: + do_warn = False + layer_name = self.__class__.__name__ + message = f"Caching is incompatible with gradient checkpointing in {layer_name}. 
Setting" + + if "use_cache" in kwargs and kwargs["use_cache"]: + kwargs["use_cache"] = False + message += " `use_cache=False`," + do_warn = True + + # different names for the same thing in different layers + if "past_key_value" in kwargs and kwargs["past_key_value"] is not None: + kwargs["past_key_value"] = None + message += " `past_key_value=None`," + do_warn = True + + if "past_key_values" in kwargs and kwargs["past_key_values"] is not None: + kwargs["past_key_values"] = None + message += " `past_key_values=None`," + do_warn = True + + if "layer_past" in kwargs and kwargs["layer_past"] is not None: + kwargs["layer_past"] = None + message += " `layer_past=None`," + do_warn = True + + # warn if anything was changed + if do_warn: + message = message.rstrip(",") + "." + logger.warning(message) + + return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args) + return super().__call__(*args, **kwargs) + + +class GenericForSequenceClassification(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + self.score = mint.nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> SequenceClassifierOutputWithPast: + transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, ms.int32) + token_indices = mint.arange(input_ids.shape[-1], dtype=ms.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[mint.arange(batch_size), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +class GenericForQuestionAnswering(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + self.qa_outputs = mint.nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return getattr(self, self.base_model_prefix).embed_tokens + + def set_input_embeddings(self, value): + getattr(self, self.base_model_prefix).embed_tokens = value + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + start_positions: Optional[ms.Tensor] = None, + end_positions: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> QuestionAnsweringModelOutput: + outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GenericForTokenClassification(ABC): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class + setattr(self, self.base_model_prefix, AutoModel.from_config(config)) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = mint.nn.Dropout(classifier_dropout) + self.score = mint.nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: 
Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> TokenClassifierOutput: + outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index e9db4e4185..3091545440 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -103,7 +103,7 @@ ) if version.parse(transformers.__version__) >= version.parse("4.51.0"): - from . import qwen3 + from . import deepseek_v3, qwen3 if version.parse(transformers.__version__) >= version.parse("4.51.3"): from . import glm4 diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index a7c3983d43..c60eff2dcd 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -296,7 +296,9 @@ if version.parse(transformers.__version__) >= version.parse("4.51.0"): CONFIG_MAPPING_NAMES.update({"qwen3": "Qwen3Config"}) + CONFIG_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Config"}) MODEL_NAMES_MAPPING.update({"qwen3": "Qwen3Model"}) + MODEL_NAMES_MAPPING.update({"deepseek_v3": "DeepSeek-V3"}) if version.parse(transformers.__version__) >= version.parse("4.51.3"): CONFIG_MAPPING_NAMES.update({"glm4": "Glm4Config"}) diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index e9f0b9833f..c8be950ce8 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -674,6 +674,8 @@ if version.parse(transformers.__version__) >= version.parse("4.51.0"): + MODEL_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3Model"}), + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"deepseek_v3": "DeepseekV3ForCausalLM"}), MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"qwen3": "Qwen3Model"}) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"qwen3": "Qwen3ForCausalLM"}) MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update({"qwen3": "Qwen3ForSequenceClassification"}) diff --git a/mindone/transformers/models/deepseek_v3/__init__.py b/mindone/transformers/models/deepseek_v3/__init__.py new file mode 100644 index 0000000000..976aae484d --- /dev/null +++ b/mindone/transformers/models/deepseek_v3/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_deepseek_v3 import * diff --git a/mindone/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/mindone/transformers/models/deepseek_v3/modeling_deepseek_v3.py new file mode 100644 index 0000000000..e097f4cc0d --- /dev/null +++ b/mindone/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -0,0 +1,674 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_v3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# with modifications to run transformers on mindspore. +# +import math +from typing import Callable, Optional, Union + +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + +import mindspore + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs + + +class DeepseekV3RMSNorm(mindspore.nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = mindspore.Parameter(mindspore.mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mindspore.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mindspore.mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV3RotaryEmbedding(mindspore.nn.Cell): + def __init__(self, config: DeepseekV3Config): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand((position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = mindspore.mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DeepseekV3MLP(mindspore.nn.Cell): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = mindspore.mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = mindspore.mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = mindspore.mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def construct(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DeepseekV3TopkRouter(mindspore.nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = mindspore.Parameter(mindspore.mint.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", mindspore.mint.zeros(self.n_routed_experts)) + + def get_topk_indices(self, scores): + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = mindspore.mint.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = mindspore.mint.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand((-1, self.n_group, self.n_routed_experts // self.n_group)) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = mindspore.mint.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def construct(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = mindspore.mint.nn.functional.linear( + hidden_states.type(mindspore.float32), self.weight.type(mindspore.float32) + ) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class DeepseekV3MoE(mindspore.nn.Cell): + """ + A mixed expert module 
containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = mindspore.nn.CellList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = DeepseekV3TopkRouter(config) + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + + def moe(self, hidden_states: mindspore.Tensor, topk_indices: mindspore.Tensor, topk_weights: mindspore.Tensor): + r""" + CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused + to not have to do a loop here (deepseek has 256 experts soooo yeah). + """ + final_hidden_states = mindspore.mint.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = mindspore.mint.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) + + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = mindspore.mint.where(mask) + + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + + # in original deepseek, the output of the experts are gathered once we leave this module + # thus the moe module is itelsf an IsolatedParallel module + # and all expert are "local" meaning we shard but we don't gather + return final_hidden_states.type(hidden_states.dtype) + + def construct(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + topk_indices, topk_weights = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mindspore.mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`mindspore.Tensor`): The query tensor. + k (`mindspore.Tensor`): The key tensor. + cos (`mindspore.Tensor`): The cosine part of the rotary embedding. + sin (`mindspore.Tensor`): The sine part of the rotary embedding. + position_ids (`mindspore.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor: + """ + This is the equivalent of mint.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: mindspore.nn.Cell, + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mindspore.mint.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mindspore.mint.nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query.dtype) + attn_weights = mindspore.mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mindspore.mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`mindspore.Tensor`): The query tensor. + k (`mindspore.Tensor`): The key tensor. + cos (`mindspore.Tensor`): The cosine part of the rotary embedding. + sin (`mindspore.Tensor`): The sine part of the rotary embedding. + position_ids (`mindspore.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3Attention(mindspore.nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = mindspore.mint.nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = mindspore.mint.nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = mindspore.mint.nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = mindspore.mint.nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = mindspore.mint.nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = mindspore.mint.nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def construct( + self, + hidden_states: mindspore.Tensor, + position_embeddings: tuple[mindspore.Tensor, mindspore.Tensor], + attention_mask: Optional[mindspore.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[tuple[mindspore.Tensor]]]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = mindspore.mint.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = 
mindspore.mint.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = mindspore.mint.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand((*k_pass.shape[:-1], -1)) + + query_states = mindspore.mint.cat((q_pass, q_rot), dim=-1) + key_states = mindspore.mint.cat((k_pass, k_rot), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = mindspore.mint.nn.functional.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DeepseekV3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def construct( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[mindspore.Tensor] = None, + position_embeddings: Optional[ + tuple[mindspore.Tensor, mindspore.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[mindspore.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected 
+ residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_dynamic_input = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekV3DecoderLayer, + "attentions": DeepseekV3Attention, + } + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, DeepseekV3TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +class DeepseekV3Model(DeepseekV3PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"] + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mindspore.mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = mindspore.nn.CellList( + [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + input_ids: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + cache_position: Optional[mindspore.Tensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: mindspore.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: mindspore.Tensor = mindspore.mint.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + 
+class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = mindspore.mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def construct( + self, + input_ids: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + labels: Optional[mindspore.Tensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + logits_to_keep: Union[int, mindspore.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import DeepseekV3ForCausalLM + >>> import mindspore as ms + + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="np") + + >>> # Generate + >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["DeepseekV3PreTrainedModel", "DeepseekV3Model", "DeepseekV3ForCausalLM"] diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py index be560c7988..c7c8020b7f 100644 --- a/mindone/transformers/utils/generic.py +++ b/mindone/transformers/utils/generic.py @@ -577,3 +577,37 @@ class LossKwargs(TypedDict, total=False): """ num_items_in_batch: Optional[Tensor] + + +class TransformersKwargs(TypedDict, total=False): + """ + Keyword arguments to be passed to the loss function + + Attributes: + num_items_in_batch (`Optional[mindspore.Tensor]`, *optional*): + Number of items in the batch. It is recommended to pass it when + you are doing gradient accumulation. 
+ output_hidden_states (`Optional[bool]`, *optional*): + Most of the models support outputting all hidden states computed during the forward pass. + output_attentions (`Optional[bool]`, *optional*): + Turn this on to return the intermediary attention scores. + output_router_logits (`Optional[bool]`, *optional*): + For MoE models, this allows returning the router logits to compute the loss. + cumulative_seqlens_q (`mindspore.Tensor`, *optional*): + Gets cumulative sequence length for query state. + cumulative_seqlens_k (`mindspore.Tensor`, *optional*): + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + num_items_in_batch: Optional["Tensor"] + output_hidden_states: Optional[bool] + output_attentions: Optional[bool] + output_router_logits: Optional[bool] + cumulative_seqlens_q: Optional["Tensor"] + cumulative_seqlens_k: Optional["Tensor"] + max_length_q: Optional[int] + max_length_k: Optional[int] diff --git a/tests/transformers_tests/models/deepseek_v3/__init__.py b/tests/transformers_tests/models/deepseek_v3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/transformers_tests/models/deepseek_v3/test_modeling_deepseek_v3.py new file mode 100644 index 0000000000..f299bd05cb --- /dev/null +++ b/tests/transformers_tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -0,0 +1,284 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+"""Testing suite for the Mindspore DeepseekV3 model.""" + +import inspect + +import numpy as np +import pytest +import torch +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import ids_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class DeepseekV3ModelTester: + def __init__( + self, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + intermediate_size=37, + moe_intermediate_size=12, + num_hidden_layers=5, + num_attention_heads=4, + num_key_value_heads=4, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=2.5, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=16, + v_head_dim=32, + qk_nope_head_dim=32, + n_group=2, + topk_group=1, + num_experts_per_tok=8, + first_k_dense_replace=2, + norm_topk_prob=True, + aux_loss_alpha=0.001, + hidden_act="silu", + max_position_embeddings=512, + initializer_range=0.02, + attention_probs_dropout_prob=0.1, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.aux_loss_alpha = aux_loss_alpha + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = np.tril(np.ones_like(input_ids)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size) + token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels) + choice_labels = 
ids_numpy([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return DeepseekV3Config( + attn_implementation="eager", + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + moe_intermediate_size=self.moe_intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + n_shared_experts=self.n_shared_experts, + n_routed_experts=self.n_routed_experts, + routed_scaling_factor=self.routed_scaling_factor, + kv_lora_rank=self.kv_lora_rank, + q_lora_rank=self.q_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + qk_nope_head_dim=self.qk_nope_head_dim, + n_group=self.n_group, + topk_group=self.topk_group, + num_experts_per_tok=self.num_experts_per_tok, + first_k_dense_replace=self.first_k_dense_replace, + norm_topk_prob=self.norm_topk_prob, + aux_loss_alpha=self.aux_loss_alpha, + hidden_act=self.hidden_act, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + use_cache=True, + pad_token_id=self.pad_token_id, + attention_dropout=self.attention_probs_dropout_prob, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +model_tester = DeepseekV3ModelTester() +( + config, + inputs_dict, +) = model_tester.prepare_config_and_inputs_for_common() + + +DEEPSEEKV3_CASES = [ + [ + "DeepseekV3Model", + "transformers.DeepseekV3Model", + "mindone.transformers.DeepseekV3Model", + (config,), + {}, + (), + {"input_ids": inputs_dict["input_ids"], "attention_mask": inputs_dict["attention_mask"]}, + { + "last_hidden_state": 0, + }, + ], +] + + +# transformers need >= 4.41.2 +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in DEEPSEEKV3_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in 
outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )
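
Reviewer note (not part of the patch): for a quick local check of the port without downloading checkpoints, the following minimal sketch builds a tiny `DeepseekV3Config` with the same values as `DeepseekV3ModelTester` above and runs one forward pass through `DeepseekV3ForCausalLM`. The random input and the printed shape check are illustrative assumptions, not output from this PR.

```python
# Smoke-test sketch, assuming the tiny config values from DeepseekV3ModelTester above.
import numpy as np
import mindspore as ms
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from mindone.transformers import DeepseekV3ForCausalLM

ms.set_context(mode=ms.PYNATIVE_MODE)  # the test suite also runs in PyNative mode (MODES = [1])

config = DeepseekV3Config(
    attn_implementation="eager",
    vocab_size=99,
    hidden_size=32,
    intermediate_size=37,
    moe_intermediate_size=12,
    num_hidden_layers=5,
    num_attention_heads=4,
    num_key_value_heads=4,
    n_shared_experts=1,
    n_routed_experts=8,
    routed_scaling_factor=2.5,
    kv_lora_rank=16,
    q_lora_rank=32,
    qk_rope_head_dim=16,
    v_head_dim=32,
    qk_nope_head_dim=32,
    n_group=2,
    topk_group=1,
    num_experts_per_tok=8,
    first_k_dense_replace=2,
    norm_topk_prob=True,
    hidden_act="silu",
    max_position_embeddings=512,
    pad_token_id=0,
)

model = DeepseekV3ForCausalLM(config)  # random weights are enough for a shape check
model.set_train(False)

# Random token ids standing in for a tokenized prompt (illustrative only).
input_ids = ms.Tensor(np.random.randint(0, config.vocab_size, (1, 7)), ms.int32)
out = model(input_ids=input_ids, use_cache=False)
print(out.logits.shape)  # expected (1, 7, 99) with the default logits_to_keep=0
```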
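The `Generic*` heads added in `modeling_layers.py` are designed as mixins rather than standalone models. Below is a hypothetical sketch, not part of this patch, of how a model-specific classification head could be composed once the backbone is registered with `AutoModel` (as `deepseek_v3` is in `modeling_auto.py` above); the class name is an assumption.

```python
# Hypothetical sketch (not in this patch): composing the new generic head mixin with a
# concrete backbone. GenericForSequenceClassification.__init__ builds the backbone via
# AutoModel.from_config(config) under the `base_model_prefix` attribute ("model") and
# adds a bias-free `score` projection, so the subclass needs no body of its own.
from mindone.transformers.modeling_layers import GenericForSequenceClassification
from mindone.transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3PreTrainedModel


class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, DeepseekV3PreTrainedModel):
    # construct() is inherited: it pools the logits at the last non-pad token before scoring.
    pass
```

This appears to mirror recent upstream transformers, where per-model sequence-classification, QA, and token-classification heads are thin subclasses of these generic implementations.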