# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_fbgemm_gpu_gen_ai

from .dispatch_combine import MetaShufflingDispatchAndCombine, RouteInfo
from .routed_experts import MetaShufflingMoERoutedExperts

if current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai():
    from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling

logger = init_logger(__name__)


# We only need the weight loader from the unquantized fused MoE method.
class MetaShufflingMoEMethod(UnquantizedFusedMoEMethod):
    def __init__(
        self,
        moe: FusedMoEConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    # Override to a no-op; MetaShufflingMoE uses its own dispatch and combine.
    def init_prepare_finalize(self, layer: torch.nn.Module):
        assert self.moe is not None


class MetaShufflingMoE(FusedMoE):
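    """Fused MoE layer built on Meta's shuffling-based MoE kernels.

    The forward pass routes tokens (fbgemm's `index_shuffling` fast path, or a
    `torch.topk` fallback), dispatches them to the routed experts, runs the
    optional shared experts, and combines the routed and shared outputs.
    Expert parallelism and quantization are not supported yet.
    """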
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
        shared_experts: torch.nn.Module | None = None,
        scoring_func: str = "softmax",
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        is_sequence_parallel: bool = False,
        **kwargs,
    ):
        CustomOp.__init__(self)

        logger.info_once("Initialized with MetaShufflingMoE")

        assert current_platform.is_cuda_alike(), (
            "MetaShufflingMoE only supports CUDA and ROCm for now."
        )
        assert has_fbgemm_gpu_gen_ai(), (
            "MetaShufflingMoE requires fbgemm-gpu-genai. "
            "Run `pip install fbgemm-gpu-genai`."
        )

        params_dtype = kwargs.get("params_dtype", torch.get_default_dtype())
        tp_size_ = kwargs.get("tp_size", get_tensor_model_parallel_world_size())
        dp_size_ = kwargs.get("dp_size", get_dp_group().world_size)
        assert not is_sequence_parallel, "Sequence parallelism is not supported yet."
        # Parallelism
        vllm_config = get_current_vllm_config()
        self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config,
        )
        etp_size_ = 1 if self.use_ep else tp_size_
        assert not self.use_ep, "EP is not supported yet."
        self.tp2ep_size = tp_size_ // etp_size_
        self.dp2ep = self.ep_size // self.tp2ep_size
        assert self.dp2ep == dp_size_, "dp > dp2ep is not supported yet."

        # Determine expert maps
        assert num_experts % self.ep_size == 0, (
            "Does not support duplicate experts for now."
        )
        self.global_num_experts = num_experts
        self.local_num_experts = self.global_num_experts
        self.group_expert_start = 0
        self.group_expert_end = self.global_num_experts
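        # Expert ids laid out as [num_local_experts, 1, 1]; the slow routing
        # path compares this against the sorted expert indices to count tokens
        # per expert.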
        self.experts_mask = torch.arange(
            self.group_expert_start, self.group_expert_end, device="cuda"
        ).view(-1, 1, 1)
        self.local_num_experts, self.expert_map, self.expert_mask = (
            self.global_num_experts,
            None,
            None,
        )

        # Layer setup
        # TODO: Most of the weight loading logic is similar to the base
        # FusedMoE layer. We should probably refactor so that the common
        # logic can be shared.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

        assert intermediate_size % self.tp_size == 0
        self.hidden_size = hidden_size
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.scoring_func = scoring_func
        self.apply_router_weight_on_input = apply_router_weight_on_input
        assert self.apply_router_weight_on_input, (
            "Only apply_router_weight_on_input=True is supported for now."
        )
        self.activation = activation
        self.top_k = top_k

        if vllm_config.model_config is not None:
            model_dtype = vllm_config.model_config.dtype
        else:
            # TODO (bnell): This is a hack to get test_mixtral_moe to work
            # since model_config is not set in the pytest test.
            model_dtype = params_dtype

        moe = FusedMoEConfig(
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
            hidden_dim=hidden_size,
            num_local_experts=self.local_num_experts,
            moe_parallel_config=self.moe_parallel_config,
            in_dtype=model_dtype,
            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
            has_bias=False,
        )
        self.moe_config = moe

        self.is_routed_fp8_rowwise: bool = False
        assert quant_config is None, "Quantization is not supported yet."
        self.quant_config = quant_config

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
        self.quant_method = MetaShufflingMoEMethod(moe, quant_config=quant_config)

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition": self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if self.quant_method.__class__.__name__ in (
            "GPTQMarlinMoEMethod",
            "CompressedTensorsWNA16MarlinMoEMethod",
            "CompressedTensorsWNA16MoEMethod",
        ):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self._shared_experts = shared_experts
        self.dispatch_and_combine = MetaShufflingDispatchAndCombine()
        self.routed_experts = MetaShufflingMoERoutedExperts(
            quant_config=self.quant_config
        )

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts

    def route(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, RouteInfo]:
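        """Compute routing scores and top-k expert assignments.

        Returns the sigmoid routing scores together with a RouteInfo holding
        the expert ids and token ids sorted by expert, the per-expert token
        counts, and the total number of routed tokens.
        """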
        assert self.scoring_func == "sigmoid", (
            "Only the sigmoid scoring function is supported for now."
        )
        scores = torch.sigmoid(router_logits.to(torch.float32))
        top_k = self.moe_config.experts_per_token
        if top_k in {1, 2, 4} and self.global_num_experts in {16, 128}:
            token_counts, expert_indices, token_indices = index_shuffling(
                scores,  # [num_tokens, num_experts]
                self.group_expert_start,
                self.group_expert_end,
                top_k=top_k,
            )
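            # As used below, token_counts carries the per-expert counts with
            # the total number of routed tokens in its last entry.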
            num_routed_tokens = token_counts[-1]
            token_counts = token_counts[
                self.group_expert_start : self.group_expert_end
            ]
        else:
            # Slow route using torch.topk.
            _, global_selected_indices = torch.topk(scores, top_k, dim=1)
            expert_indices, token_indices = torch.sort(
                global_selected_indices.flatten(), dim=0, stable=True
            )
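            # The stable sort groups the flattened (token, expert) picks by
            # expert id; integer-dividing the sorted positions by top_k
            # recovers each pick's original token index.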
            token_indices = token_indices // top_k
            mask = self.experts_mask == expert_indices
            token_counts = mask.sum(dim=2, dtype=torch.int32).flatten()
            num_routed_tokens = token_counts.sum().view(-1)
        return scores, RouteInfo(
            expert_indices=expert_indices,
            token_indices=token_indices,
            token_counts=token_counts,
            num_routed_tokens=num_routed_tokens,
        )

    def forward_impl(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
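        """Route tokens, dispatch, run the experts, and combine the outputs.

        Returns the combined output or, when shared experts are attached, a
        (shared_out, combined_out) tuple in which the shared slot is a
        placeholder because the shared-expert contribution is already folded
        into the combined output.
        """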
        scores, route_info = self.route(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        shuffled_recv_tokens, recv_token_counts = self.dispatch_and_combine.dispatch(
            tokens=hidden_states,
            scores=scores,
            route_info=route_info,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
        )
        # TODO: run the shared experts on a separate stream once there is
        # communication to overlap with.
        if self._shared_experts is not None:
            shared_out = self._shared_experts(hidden_states)
        else:
            # This is so that we can call scatter_add_dense_tokens
            # without shared_experts.
            shared_out = torch.zeros_like(hidden_states)

        routed_out = self.routed_experts.run(
            x=shuffled_recv_tokens,
            token_counts=recv_token_counts,
            w1=self.w13_weight.data,
            w2=self.w2_weight.data,
            activation=self.activation,
            scores=scores,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            num_valid_tokens=route_info.num_recv_tokens,
            shared_out=shared_out if not self.use_ep else None,
            token_indices=route_info.token_indices if not self.use_ep else None,
        )

        output = self.dispatch_and_combine.combine(
            routed_out=routed_out,
            shared_out=shared_out,
            route_info=route_info,
            scores=scores,
        )
        output = output.view(hidden_states.shape)
        if self._shared_experts is None:
            return output
        else:
            # Create a fake shared_output, as moe_forward_shared expects a
            # tuple; the shared-expert contribution is already folded into
            # `output` by combine().
            return torch.empty_like(output), output