[Feature] Add MetaShufflingMoE as Optional MoE backend to Llama4 models #27891
Open: sunfish2010 wants to merge 1 commit into vllm-project:main from sunfish2010:meta_shuffling_integration
Changes from all commits
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.meta_shuffling_moe.meta_shuffling_moe import (
    MetaShufflingMoE,
)

__all__ = ["MetaShufflingMoE"]
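The hunk above appears to be the new package's __init__.py, although this view does not show the file path. Assuming that, downstream model code would import the layer from the package root; the line below is a hypothetical usage note, not part of the diff.

# Assumption: the hunk above is the __init__.py of
# vllm/model_executor/layers/meta_shuffling_moe/ (the path is not shown in this view).
from vllm.model_executor.layers.meta_shuffling_moe import MetaShufflingMoE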
vllm/model_executor/layers/meta_shuffling_moe/dispatch_combine.py (89 additions, 0 deletions)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_fbgemm_gpu_gen_ai

if current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai():
    from fbgemm_gpu.experimental.gen_ai.moe import (
        gather_scale_dense_tokens,
        scatter_add_dense_tokens,
    )


@dataclass
class RouteInfo:
    expert_indices: torch.Tensor
    token_counts: torch.Tensor
    token_indices: torch.Tensor
    num_routed_tokens: torch.Tensor
    num_recv_tokens: torch.Tensor | None = None
    recv_sizes_across_ranks: torch.Tensor | None = None
    recv_sizes_across_ranks_cpu: torch.Tensor | None = None
    send_sizes_across_ranks: torch.Tensor | None = None
    send_sizes_across_ranks_cpu: torch.Tensor | None = None


# Skeleton code to prepare for enabling EP.
# In TP only case, dispatch/combine are almost no-ops.
class MetaShufflingDispatchAndCombine:
    """
    Dispatch/Combine using Meta Shuffling kernels.
    """

    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
            cls.instance._initialized = False
        return cls.instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self.world_size = 1
        assert current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai()
        self._initialized: bool = True

    def dispatch(
        self,
        tokens: torch.Tensor,  # tokens
        route_info: RouteInfo,
        scores: torch.Tensor,  # scores,
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if apply_router_weight_on_input:
            tokens = gather_scale_dense_tokens(
                tokens,
                route_info.token_indices.flatten(),
                route_info.expert_indices.flatten(),
                scores,
                valid_token_count=route_info.num_routed_tokens,
            )
        assert self.world_size == 1
        return tokens, route_info.token_counts

    def combine(
        self,
        routed_out: torch.Tensor,
        route_info: RouteInfo,
        scores: torch.Tensor,
        shared_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.world_size == 1
        if envs.VLLM_META_SHUFFLING_GEMM_BACKEND == "cutlass":
            scatter_add_dense_tokens(
                out_tokens=shared_out,
                in_tokens=routed_out,
                token_indices=route_info.token_indices,
                valid_token_count=route_info.num_routed_tokens,
            )
            return shared_out
        # Assume in TP only case, we have already produced
        # fused output from routed and shared by calling
        # grouped_gemm with shared output when using triton grouped_gemm.
        else:
            return routed_out
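Two things are worth noting in this file. The __new__/__init__ pair makes MetaShufflingDispatchAndCombine a process-wide singleton, so every MoE layer shares one instance. And in the TP-only path with the cutlass GEMM backend, combine() folds the routed expert outputs back into the shared-expert output at the originating token positions. Below is a plain-torch sketch of that accumulation; the exact behavior of fbgemm's scatter_add_dense_tokens is an assumption here (roughly out_tokens[token_indices[i]] += in_tokens[i]), and the sizes are made up for illustration.

# Illustrative stand-in for the TP-only combine step; not part of the PR.
# Assumption: scatter_add_dense_tokens performs a per-row scatter-add,
#   out_tokens[token_indices[i]] += in_tokens[i] for each valid routed slot i.
import torch

num_tokens, hidden_size, top_k = 4, 8, 2
shared_out = torch.randn(num_tokens, hidden_size)          # shared-expert output
routed_out = torch.randn(num_tokens * top_k, hidden_size)  # one row per (token, expert) slot
token_indices = torch.arange(num_tokens).repeat_interleave(top_k)  # source token of each slot

combined = shared_out.clone()
combined.index_add_(0, token_indices, routed_out)  # add each routed row back to its token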
vllm/model_executor/layers/meta_shuffling_moe/meta_shuffling_moe.py (267 additions, 0 deletions)
@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_fbgemm_gpu_gen_ai

from .dispatch_combine import MetaShufflingDispatchAndCombine, RouteInfo
from .routed_experts import MetaShufflingMoERoutedExperts

if current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai():
    from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling

from vllm.logger import init_logger

logger = init_logger(__name__)


# We only need the weight loader from unquantized fused moe method.
class MetaShufflingMoEMethod(UnquantizedFusedMoEMethod):
    def __init__(
        self,
        moe: FusedMoEConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    # Override to no ops.
    def init_prepare_finalize(self, layer: torch.nn.Module):
        assert self.moe is not None


class MetaShufflingMoE(FusedMoE):
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
        shared_experts: torch.nn.Module | None = None,
        scoring_func: str = "softmax",
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        is_sequence_parallel: bool = False,
        **kwargs,
    ):
        CustomOp.__init__(self)

        logger.info_once("Initialized with MetaShufflingMoE")

        assert current_platform.is_cuda_alike(), (
            "MetaShufflingMoE only supports CUDA and AMD for now."
        )
        assert has_fbgemm_gpu_gen_ai(), (
            "MetaShufflingMoE requires fbgemm_gpu_gen_ai. \
            Run pip install fbgemm-gpu-genai"
        )

        params_dtype = kwargs.get("params_dtype", torch.get_default_dtype())
        tp_size_ = kwargs.get("tp_size", get_tensor_model_parallel_world_size())
        dp_size_ = kwargs.get("dp_size", get_dp_group().world_size)
        assert not is_sequence_parallel, "Sequence parallel is not supported yet."
        # Parallelism
        vllm_config = get_current_vllm_config()
        self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config,
        )
        etp_size_ = 1 if self.use_ep else tp_size_
        assert not self.use_ep, "Ep is not supported yet."
        self.tp2ep_size = tp_size_ // etp_size_
        self.dp2ep = self.ep_size // self.tp2ep_size
        assert self.dp2ep == dp_size_, "Doesn't support dp > dp2ep yet"

        # Determine expert maps
        assert num_experts % self.ep_size == 0, (
            "Does not support duplicate experts for now."
        )
        self.global_num_experts = num_experts
        self.local_num_experts = self.global_num_experts
        self.group_expert_start = 0
        self.group_expert_end = self.global_num_experts
        self.experts_mask = torch.arange(
            self.group_expert_start, self.group_expert_end, device="cuda"
        ).view(-1, 1, 1)
        self.local_num_experts, self.expert_map, self.expert_mask = (
            self.global_num_experts,
            None,
            None,
        )

        # Layer setup
        # TODO: Most of the weights loading logic is
        # similar to base fused_moe. We should probably refactor
        # the code so that common shared logic can be shared.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

        assert intermediate_size % self.tp_size == 0
        self.hidden_size = hidden_size
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.scoring_func = scoring_func
        self.apply_router_weight_on_input = apply_router_weight_on_input
        assert self.apply_router_weight_on_input, (
            "Only support apply_router_weight_on_input=True for now."
        )
        self.activation = activation
        self.top_k = top_k

        if vllm_config.model_config is not None:
            model_dtype = vllm_config.model_config.dtype
        else:
            # TODO (bnell): This is a hack to get test_mixtral_moe to work
            # since model_config is not set in the pytest test.
            model_dtype = params_dtype

        moe = FusedMoEConfig(
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
            hidden_dim=hidden_size,
            num_local_experts=self.local_num_experts,
            moe_parallel_config=self.moe_parallel_config,
            in_dtype=model_dtype,
            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
            has_bias=False,
        )
        self.moe_config = moe

        self.is_routed_fp8_rowwise: bool = False
        assert quant_config is None, "Quantization is not supported yet."
        self.quant_config = quant_config

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
        self.quant_method = MetaShufflingMoEMethod(moe, quant_config=quant_config)

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition": self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if self.quant_method.__class__.__name__ in (
            "GPTQMarlinMoEMethod",
            "CompressedTensorsWNA16MarlinMoEMethod",
            "CompressedTensorsWNA16MoEMethod",
        ):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self._shared_experts = shared_experts
        self.dispatch_and_combine = MetaShufflingDispatchAndCombine()
        self.routed_experts = MetaShufflingMoERoutedExperts(
            quant_config=self.quant_config
        )

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts

    def route(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, RouteInfo]:
        assert self.scoring_func == "sigmoid", (
            "only support sigmoid scoring function for now "
        )
        if self.scoring_func == "sigmoid":
            scores = torch.sigmoid(router_logits.to(torch.float32))
        top_k = self.moe_config.experts_per_token
        if top_k in {1, 2, 4} and self.global_num_experts in {16, 128}:
            token_counts, expert_indices, token_indices = index_shuffling(
                scores,  # num_tokens
                self.group_expert_start,
                self.group_expert_end,
                top_k=top_k,
            )
            num_routed_tokens = token_counts[-1]
            token_counts = token_counts[self.group_expert_start : self.group_expert_end]
        else:
            # Slow route using torch topk.
            _, global_selected_indices = torch.topk(scores, top_k, dim=1)
            expert_indices, token_indices = torch.sort(
                global_selected_indices.flatten(), dim=0, stable=True
            )
            token_indices = token_indices // top_k
            mask = self.experts_mask == expert_indices
            token_counts = (mask).sum(dim=2, dtype=torch.int32).flatten()
            num_routed_tokens = token_counts.sum().view(
                -1,
            )
        return scores, RouteInfo(
            expert_indices=expert_indices,
            token_indices=token_indices,
            token_counts=token_counts,
            num_routed_tokens=num_routed_tokens,
        )

    def forward_impl(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        scores, route_info = self.route(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        shuffled_recv_tokens, recv_token_counts = self.dispatch_and_combine.dispatch(
            tokens=hidden_states,
            scores=scores,
            route_info=route_info,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
        )
        # TODO: add using separate streams for shared experts when there's comms.
        if self._shared_experts is not None:
            shared_out = self._shared_experts(hidden_states)
        else:
            # This is so that we can call scatter_add_dense_tokens
            # without shared_experts.
            shared_out = torch.zeros_like(hidden_states)

        routed_out = self.routed_experts.run(
            x=shuffled_recv_tokens,
            token_counts=recv_token_counts,
            w1=self.w13_weight.data,
            w2=self.w2_weight.data,
            activation=self.activation,
            scores=scores,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            num_valid_tokens=route_info.num_recv_tokens,
            shared_out=shared_out if not self.use_ep else None,
            token_indices=route_info.token_indices if not self.use_ep else None,
        )

        output = self.dispatch_and_combine.combine(
            routed_out=routed_out,
            shared_out=shared_out,
            route_info=route_info,
            scores=scores,
        )
        output = output.view(hidden_states.shape)
        if shared_out is None:
            return output
        else:
            # create a fake shared_output as moe_forward_shared expect to return a tuple
            return torch.empty_like(output), output
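To make the routing outputs concrete, here is a self-contained CPU reproduction of the torch.topk fallback in route() above. It is a sketch with made-up sizes; the fast path uses fbgemm's index_shuffling instead, which this does not model.

# Illustrative CPU reproduction of the "slow route" in MetaShufflingMoE.route();
# sizes are made up, and the real layer builds experts_mask on the GPU in __init__.
import torch

num_tokens, num_experts, top_k = 5, 8, 2
scores = torch.sigmoid(torch.randn(num_tokens, num_experts))

_, selected = torch.topk(scores, top_k, dim=1)  # [num_tokens, top_k] expert ids
expert_indices, token_indices = torch.sort(selected.flatten(), dim=0, stable=True)
token_indices = token_indices // top_k          # map sorted slots back to token ids

experts_mask = torch.arange(num_experts).view(-1, 1, 1)
token_counts = (experts_mask == expert_indices).sum(dim=2, dtype=torch.int32).flatten()

# expert_indices: expert id per routed slot, grouped by expert (stable sort keeps
# token order within each expert); token_indices: source token for each slot;
# token_counts[e]: number of slots routed to expert e.
assert token_counts.sum().item() == num_tokens * top_k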