- 
                Notifications
    You must be signed in to change notification settings 
- Fork 87
feat(transformers): Transformers 4.54 base #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
            SamitHuang
  merged 99 commits into
  mindspore-lab:master
from
wcrzlh:transformers_4.54_base
  
      
      
   
  Oct 27, 2025 
      
    
  
     Merged
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            99 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      a8c5c54
              
                upgrade activation_func to transformers v4.54
              
              
                wcrzlh c05a707
              
                feat(transformers): upgrade attn_mask/rope to 4.54
              
              
                wcrzlh 84b3ece
              
                feat(transformers): upgrade modeling_layers to 4.54
              
              
                wcrzlh 6b94c2d
              
                feat(transformers): upgrade cache_utils to 4.54
              
              
                wcrzlh 2f70121
              
                feat(transformers): upgrade modeling_utils to v4.54
              
              
                wcrzlh 44ad424
              
                feat(transformers): upgrade generation/utils to v4.54
              
              
                wcrzlh 17da5c7
              
                feat(transformers): add ernie4.5 for validation
              
              
                wcrzlh fd769b1
              
                fix get_type_hints problem
              
              
                wcrzlh 37fe594
              
                fix get_type_hints problem
              
              
                wcrzlh 0874e77
              
                fix get_type_hints problem
              
              
                wcrzlh 94fb78b
              
                fix metadata.get keyerror
              
              
                wcrzlh bf69ef9
              
                fix masking_utils alignment
              
              
                wcrzlh a1be89c
              
                fix generation/utils logic
              
              
                wcrzlh b3334ac
              
                fix get_output_embedding override bug
              
              
                wcrzlh 833419c
              
                fix __init_subclass__ bug
              
              
                wcrzlh c532a9a
              
                suplement checkpoint_conversion_mapping
              
              
                wcrzlh 1ac2f72
              
                feat(transformers): upgrade beam search to v4.54
              
              
                wcrzlh 375e6ab
              
                feat(transformers): upgrade candidate_generator to v4.54
              
              
                wcrzlh 25033c1
              
                feat(transformers): upgrade logits_process/stopping_criteria to v4.54
              
              
                wcrzlh 252f4aa
              
                pre-commit
              
              
                wcrzlh 02834b0
              
                pre-commit
              
              
                wcrzlh 65e8256
              
                update backbone_utils
              
              
                wtomin 913cd3c
              
                update generic
              
              
                wtomin a0c9dc9
              
                remove add_model_info_to_auto_map & update feature_extraction_utils.py
              
              
                wtomin 2a517d8
              
                remove add_model_info_to_auto_map & update image_processing_base.py
              
              
                wtomin 8d29722
              
                remove add_model_info_to_auto_map & update processing_utils.py
              
              
                wtomin 8a90ca6
              
                remove add_model_info_to_auto_map & update video_utils.py
              
              
                wtomin dcb98ac
              
                tokenization_utils.py update
              
              
                wtomin 5120977
              
                add_model_info_to_custom_pipelines
              
              
                wtomin 9a18655
              
                update tokenization_utils_base.py
              
              
                wtomin 509a308
              
                update image_transforms.py
              
              
                wtomin 33ed2be
              
                update video_utils.py and image_utils.py
              
              
                wtomin 0d8142c
              
                update image_utils.py & image_processing_utils_fast.py
              
              
                wtomin 991a783
              
                update integration sdpa_attention.py
              
              
                wtomin ccb0897
              
                update mask_utils.py
              
              
                wtomin 00f2ba3
              
                update modeling_flash_attention_utils.py
              
              
                wtomin d0b34fb
              
                update modeling_outputs.py
              
              
                wtomin f75b06a
              
                fix pre-commit errors
              
              
                wtomin f9ea8ce
              
                fix pre-commit errors
              
              
                wtomin 03635c5
              
                rebase
              
              
                wcrzlh 7bed7a1
              
                add modeling_layers.py from cui yushi
              
              
                wtomin 61b4f5c
              
                fix import in transformers
              
              
                wtomin 9205ca2
              
                Merge branch 'transformers_4.54_base' into transformer-v4.54.1
              
              
                wtomin 3e3f452
              
                rm tokenization_utils.py and  tokenization_utils_base.py
              
              
                wtomin 91609b9
              
                resize stacked images one by one
              
              
                wtomin ffd3377
              
                remove torchvision decoders
              
              
                wtomin b38bf63
              
                fix get_default_dtype bug
              
              
                wcrzlh f32b7cb
              
                load module dynamically from mindone/transformers
              
              
                wtomin 2cb578b
              
                not support FA
              
              
                wtomin 7ad706d
              
                Merge pull request #2 from wtomin/transformer-v4.54.1
              
              
                wcrzlh 9457ebc
              
                add video_processing_utils
              
              
                wcrzlh 32031d0
              
                fix import error/add audio_utils/fix processor bug/attn_implementatio…
              
              
                wtomin 294d153
              
                fix attn_implementation configuration bug
              
              
                wcrzlh a44b0f6
              
                Fix attn_implementation
              
              
                wtomin ba674bc
              
                fix fa bug/key_renaming_mapping bug
              
              
                wcrzlh 3ab17b0
              
                pre-commit
              
              
                wcrzlh ee91d87
              
                upgrade modeling_utils/save_pretrained to transformersv4.54
              
              
                wcrzlh ff82ffb
              
                refactor fa part
              
              
                wcrzlh 58e07d6
              
                Fix some model's UT
              
              
                wtomin ab125b4
              
                revert _support_dynamic_input to _support_jit
              
              
                wcrzlh 226bd0e
              
                fix class name mismatch in generation/utils
              
              
                wcrzlh d156ca6
              
                fix pa error/delete unused fa part
              
              
                wcrzlh fe3304b
              
                remove unused part
              
              
                wcrzlh 934520f
              
                generation/utils ops-->mint
              
              
                wcrzlh 4aab9fa
              
                copyright/pre-commit
              
              
                wcrzlh d104c56
              
                fix bugs
              
              
                wcrzlh ba0a8eb
              
                supplement activation api
              
              
                wcrzlh 9e36ba8
              
                reformat
              
              
                wcrzlh 738d9bb
              
                remove losskwargs
              
              
                wtomin c80e2fd
              
                fix disable_grouping bug in image processing
              
              
                wcrzlh 10ec00b
              
                fix attn_implementation setting in modeling_utils/from_pretrained
              
              
                wcrzlh a813cf9
              
                fix attn_implementation setting in modeling_utils/from_pretrained
              
              
                wcrzlh cdebac0
              
                fix modeling_utils/from_config mindspore_dtype setting, generation/ut…
              
              
                wcrzlh 7a20fe1
              
                feat(transformers): add qwen3_vl/qwen3_vl_moe model
              
              
                wcrzlh 4079e6f
              
                fix moe precision bug
              
              
                wcrzlh c1cde3a
              
                fix qwen3_vl moe memory bugs
              
              
                wcrzlh 721d0a3
              
                supplement zero3 model weight shard for moe part
              
              
                wcrzlh e43f3dd
              
                fix qwen3_vl_moe precision bug
              
              
                wcrzlh 51515b9
              
                fix qwen3_vl_moe precision bug
              
              
                wcrzlh 25c8110
              
                fix moe part shard bug
              
              
                wcrzlh 9650f4f
              
                pre-commit
              
              
                wcrzlh 3771434
              
                reformat
              
              
                wcrzlh 2d5f9e7
              
                Merge pull request #1310 from wcrzlh/qwen3_vl
              
              
                vigo999 fed7ffc
              
                fix(transformers): fix typos in qwen3_vl docs
              
              
                wcrzlh f2b56bf
              
                Merge pull request #1311 from wcrzlh/qwen3_vl
              
              
                vigo999 3c81df8
              
                feat(transformers): add processor for qwen3_vl (#1326)
              
              
                wcrzlh 6e6361a
              
                fix(transformers): supplement condition of taking model as processor
              
              
                wcrzlh e72e032
              
                fix(transformers): reformat generation/utils
              
              
                wcrzlh dcdec6c
              
                fix(transformers): supplement candidate generator
              
              
                wcrzlh 47ca032
              
                fix(transformers): supplement logits processor
              
              
                wcrzlh c6df7fd
              
                feat(transformers): add assisted_generation/dola_generation/contrasiv…
              
              
                wcrzlh 11ba44c
              
                rebase
              
              
                wcrzlh 18d35f6
              
                reformat
              
              
                wcrzlh 8b75291
              
                fix import bug
              
              
                wcrzlh 9fd221f
              
                fix ut bug
              
              
                wcrzlh 46639e0
              
                update pyproject.toml
              
              
                wcrzlh aa0e7b0
              
                pre-commit
              
              
                wcrzlh b11c421
              
                reformat
              
              
                wcrzlh bd76d4c
              
                update loss_type
              
              
                wcrzlh File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # Qwen3-VL series | ||
|  | ||
| ## Introduction | ||
| [Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series, encompassing both dense and MoE variants, as well as Instruct and Thinking versions. Building upon its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong pure text capabilities. Key architectural advancements include: enhanced MRope with interleaved layout for better spatial-temporal modeling, DeepStack integration to effectively leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment—evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. These innovations collectively enable Qwen3-VL to achieve superior performance in complex multimodal tasks. | ||
|  | ||
| # Get Started | ||
|  | ||
| ## Requirements: | ||
| | mindspore | ascend driver | firmware | cann tookit/kernel | | ||
| |-----------|----------------|----------------|--------------------| | ||
| | 2.6.0 | 24.1.RC3.b080 | 7.5.T11.0.B088 | 8.1.RC1 | | ||
|  | ||
| ### Installation: | ||
| ``` | ||
| git clone https://github.com/mindspore-lab/mindone.git -b hf-transformers-4.54 | ||
| cd mindone | ||
| pip install -e . | ||
| cd .. | ||
|  | ||
| # compile newest transformers whl because qwen3-vl(transformers v4.57.dev.0) haven't released | ||
| git clone https://github.com/huggingface/transformers.git | ||
| cd transformers | ||
| git reset --hard d0af4269ec260b9c4aeeda24c346a469e44799e1 | ||
| pip install -e . | ||
| cd .. | ||
|  | ||
| cd mindone/examples/transformers/qwen3_vl | ||
| ``` | ||
|  | ||
| ## Quick Start | ||
|  | ||
| Here is a usage example of Qwen3-VL-4B-Instruct. you can use the following command: | ||
|  | ||
| ```bash | ||
| # for Qwen3-VL-4B-Instruct inference | ||
| python generate_qwen3_vl.py | ||
| --model_name "Qwen/Qwen3-VL-4B-Instruct" | ||
| --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" | ||
| --prompt "Describe this image." | ||
| ``` | ||
|  | ||
| ```bash | ||
| # for Qwen3-VL-30B-A3B-Instruct inference | ||
| msrun --worker_num=2 --local_worker_num=2 --master_port=8118 \ | ||
| --log_dir=msrun_log --join=True --cluster_time_out=300 \ | ||
| generate_qwen3_vl_moe.py \ | ||
| --model_name "Qwen/Qwen3-VL-30B-A3B-Instruct" \ | ||
| --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \ | ||
| --prompt "Describe this image." \ | ||
| ``` | ||
|  | ||
| Image: | ||
|  | ||
|  | ||
| Prompt: Describe this image. | ||
|  | ||
| Qwen3-VL-4B Outputs: | ||
| ``` | ||
| ['Of course, here is detailed description of the image provided.\n\n | ||
| This is a close-up photograph of a Pallas\'s cat ($Felis$, $manul$), | ||
| an endangered wild feline species native to Central Aisa. | ||
| ... | ||
| **Appearance:** It has a stocky and robust build with short legs | ||
| and a large head relative to its body size. Its fur is thick and dense, | ||
| appearing somewhat fluffy or "matted,", which is characteristic'] | ||
| ``` | ||
|  | ||
| Qwen3-VL-30B Outputs: | ||
| ``` | ||
| ['Of course, here is detailed description of the image provided.\n\n | ||
| This is a dynamic and charming photograph of a Palla's cat (also known as a manul) in a snowy enviroment. | ||
| ... | ||
| "Appearance:" The cat has a very distinctive apperance, characterized by its stocky, low-slung body and exceptionally | ||
| thick, dense fur. This coat is a mix of brownish"] | ||
| ``` | ||
|  | ||
| `model_name` and `image` could be replaced with your local path. Give it a try with various images and prompts🤗🤗. | ||
|  | ||
| ## Inference Speed | ||
| | model name | mindspore version | precision* | cards | attention type | tokens/s | | ||
| |:------------------------------:|:-----------------:|:----------:|:-----:|:--------------:|:----------:| | ||
| | Qwen/Qwen3-VL-4B-Instruct | 2.6.0 | bf16 | 1 | flash_attn | 1.35 | | ||
| | Qwen/Qwen3-VL-30B-A3B-Instruct | 2.6.0 | bf16 | 2 | flash_attn | 0.5 | | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| import argparse | ||
|  | ||
| import numpy as np | ||
|  | ||
| import mindspore as ms | ||
|  | ||
| from mindone.transformers import AutoProcessor, Qwen3VLForConditionalGeneration | ||
|  | ||
|  | ||
| def generate(args): | ||
| model = Qwen3VLForConditionalGeneration.from_pretrained( | ||
| args.model_name, | ||
| mindspore_dtype=ms.bfloat16, | ||
| attn_implementation=args.attn_implementation, | ||
| ) | ||
|  | ||
| processor = AutoProcessor.from_pretrained( | ||
| args.model_name, | ||
| use_fast=False, | ||
| ) | ||
|  | ||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| { | ||
| "type": "image", | ||
| "url": args.image, | ||
| }, | ||
| { | ||
| "type": "text", | ||
| "text": args.prompt, | ||
| }, | ||
| ], | ||
| } | ||
| ] | ||
|  | ||
| inputs = processor.apply_chat_template( | ||
| messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np" | ||
| ) | ||
|  | ||
| # convert input to Tensor | ||
| for key, value in inputs.items(): | ||
| if isinstance(value, np.ndarray): | ||
| inputs[key] = ms.tensor(value) | ||
| elif isinstance(value, list): | ||
| inputs[key] = ms.Tensor(value) | ||
|  | ||
| generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False) | ||
| generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] | ||
| output_text = processor.batch_decode( | ||
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
| ) | ||
| print(output_text) | ||
|  | ||
|  | ||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Qwen3VL demo.") | ||
|  | ||
| parser.add_argument("--prompt", type=str, default="Describe this image.") | ||
| parser.add_argument( | ||
| "--image", | ||
| type=str, | ||
| default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_name", type=str, default="Qwen/Qwen3-VL-4B-Instruct", help="Path to the pre-trained model." | ||
| ) | ||
| parser.add_argument( | ||
| "--attn_implementation", | ||
| type=str, | ||
| default="flash_attention_2", | ||
| choices=["flash_attention_2", "eager"], | ||
| ) | ||
|  | ||
| # Parse the arguments | ||
| args = parser.parse_args() | ||
|  | ||
| generate(args) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| import argparse | ||
| from functools import partial | ||
|  | ||
| import numpy as np | ||
|  | ||
| import mindspore as ms | ||
| import mindspore.mint.distributed as dist | ||
| from mindspore.communication import GlobalComm | ||
|  | ||
| from mindone.trainers.zero import prepare_network | ||
| from mindone.transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration | ||
|  | ||
|  | ||
| def generate(args): | ||
| model = Qwen3VLMoeForConditionalGeneration.from_pretrained( | ||
| args.model_name, | ||
| mindspore_dtype=ms.bfloat16, | ||
| attn_implementation=args.attn_implementation, | ||
| ) | ||
|  | ||
| # use zero3 parallel | ||
| shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP) | ||
| model = shard_fn(model) | ||
|  | ||
| processor = AutoProcessor.from_pretrained( | ||
| args.model_name, | ||
| use_fast=False, | ||
| ) | ||
|  | ||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| { | ||
| "type": "image", | ||
| "url": args.image, | ||
| }, | ||
| { | ||
| "type": "text", | ||
| "text": args.prompt, | ||
| }, | ||
| ], | ||
| } | ||
| ] | ||
|  | ||
| inputs = processor.apply_chat_template( | ||
| messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np" | ||
| ) | ||
|  | ||
| # convert input to Tensor | ||
| for key, value in inputs.items(): | ||
| if isinstance(value, np.ndarray): | ||
| inputs[key] = ms.tensor(value) | ||
| elif isinstance(value, list): | ||
| inputs[key] = ms.Tensor(value) | ||
|  | ||
| generated_ids = model.generate(**inputs, max_new_tokens=128) | ||
| generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] | ||
| output_text = processor.batch_decode( | ||
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
| ) | ||
| print(output_text) | ||
|  | ||
|  | ||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Qwen3VLMoE demo.") | ||
|  | ||
| parser.add_argument("--prompt", type=str, default="Describe this image.") | ||
| parser.add_argument( | ||
| "--image", | ||
| type=str, | ||
| default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="Path to the pre-trained model." | ||
| ) | ||
| parser.add_argument( | ||
| "--attn_implementation", | ||
| type=str, | ||
| default="flash_attention_2", | ||
| choices=["flash_attention_2", "eager"], | ||
| ) | ||
|  | ||
| # Parse the arguments | ||
| args = parser.parse_args() | ||
|  | ||
| # set up card communication | ||
| dist.init_process_group(backend="hccl") | ||
| ms.set_auto_parallel_context(parallel_mode="data_parallel") | ||
|  | ||
| generate(args) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from typing import Literal, Optional | ||
|  | ||
| from mindspore import Tensor | ||
| from mindspore import dtype as mstype | ||
| from mindspore import mint, nn | ||
| from mindspore.communication import get_group_size, get_rank | ||
| from mindspore.communication.management import GlobalComm | ||
| from mindspore.context import ParallelMode | ||
| from mindspore.parallel._utils import _get_parallel_mode | ||
|  | ||
| from .param_wrapper import ZeroParamWrapper | ||
|  | ||
|  | ||
| class MoeTextExperts(nn.Cell): | ||
| def __init__( | ||
| self, | ||
| net: nn.Cell, | ||
| zero_stage: Literal[0, 1, 2, 3] = 0, | ||
| optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, | ||
| cell_type: Optional[mstype.Type] = None, | ||
| ): | ||
| super().__init__(auto_prefix=False) | ||
| self.net = net | ||
| self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type) | ||
|  | ||
| def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None): | ||
| self.param_wrapper_gate_up_proj = nn.Identity() | ||
| self.param_wrapper_down_proj = nn.Identity() | ||
| if zero_stage == 3: | ||
| # Init parallel settings | ||
| is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL | ||
| op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 | ||
| op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 | ||
| self.op_group_size = op_group_size | ||
| self.op_rank_id = op_rank_id | ||
| self.param_wrapper_gate_up_proj = ZeroParamWrapper( | ||
| self.net.gate_up_proj, zero_stage, optimizer_parallel_group, cell_type | ||
| ) | ||
| if self.param_wrapper_gate_up_proj.need_rewrite: | ||
| self.net.gate_up_proj.assign_value( | ||
| Tensor.from_numpy( | ||
| self.net.gate_up_proj.numpy().reshape(op_group_size, -1, *self.net.gate_up_proj.shape[1:])[ | ||
| op_rank_id | ||
| ] | ||
| ) | ||
| ) | ||
| self.param_wrapper_down_proj = ZeroParamWrapper( | ||
| self.net.down_proj, zero_stage, optimizer_parallel_group, cell_type | ||
| ) | ||
| if self.param_wrapper_down_proj.need_rewrite: | ||
| self.net.down_proj.assign_value( | ||
| Tensor.from_numpy( | ||
| self.net.down_proj.numpy().reshape(op_group_size, -1, *self.net.down_proj.shape[1:])[op_rank_id] | ||
| ) | ||
| ) | ||
|  | ||
| def construct(self, hidden_states, routing_weights, router_indices): | ||
| batch_size = hidden_states.shape[0] | ||
| hidden_states = hidden_states.reshape(-1, self.net.hidden_size) # (num_tokens, hidden_size) | ||
|  | ||
| hidden_states = hidden_states.repeat(self.net.num_experts, 1) | ||
| hidden_states = hidden_states.view(self.net.num_experts, -1, self.net.hidden_size) | ||
|  | ||
| gate_up = mint.bmm(hidden_states, self.param_wrapper_gate_up_proj(self.net.gate_up_proj)) | ||
| gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors | ||
| next_states = mint.bmm((up * self.net.act_fn(gate)), self.param_wrapper_down_proj(self.net.down_proj)) | ||
| next_states = next_states.reshape(self.net.num_experts, batch_size, -1, self.net.hidden_size) | ||
| next_states = next_states * routing_weights.swapaxes(0, 1).view(self.net.num_experts, batch_size, -1)[..., None] | ||
| next_states = next_states.sum(dim=0) | ||
| return next_states | 
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.