
Commit 4609478

Author: Zoey Sun (committed)

add signoff and fix shared experts error

Signed-off-by: Zoey Sun <[email protected]>

1 parent 69f0640 · commit 4609478

File tree: 7 files changed (+510, −21 lines)


vllm/envs.py

Lines changed: 9 additions & 0 deletions
@@ -218,6 +218,8 @@
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+    VLLM_USE_META_SHUFFLING_MOE: bool = False
+    VLLM_META_SHUFFLING_GEMM_BACKEND: Literal["cutlass", "triton"] = "cutlass"


 def get_default_cache_root():
@@ -1408,6 +1410,12 @@ def get_vllm_port() -> int | None:
     "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
         "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
     ),
+    "VLLM_USE_META_SHUFFLING_MOE": lambda: bool(
+        int(os.getenv("VLLM_USE_META_SHUFFLING_MOE", "0"))
+    ),
+    "VLLM_META_SHUFFLING_GEMM_BACKEND": env_with_choices(
+        "VLLM_META_SHUFFLING_GEMM_BACKEND", "cutlass", ["cutlass", "triton"]
+    ),
 }

 # --8<-- [end:env-vars-definition]
@@ -1534,6 +1542,7 @@ def compute_hash() -> str:
         "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
         "VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK",
         "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
+        "VLLM_META_SHUFFLING_GEMM_BACKEND",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
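
The two additions above introduce an opt-in flag for the MetaShuffling MoE path and a validated choice of grouped-GEMM backend; the backend variable is also appended to the hashed list, so it participates in vLLM's compilation cache key. A minimal usage sketch follows, assuming the variables are set in the process environment before vLLM reads them; the parsing shown in the comments mirrors the lambdas above, and nothing beyond that is implied about runtime behavior.

# Illustrative only: enable the new MoE path via the environment variables
# added above ("bool(int(...))" for the flag, env_with_choices for the backend).
import os

os.environ["VLLM_USE_META_SHUFFLING_MOE"] = "1"            # "0" (default) or "1"
os.environ["VLLM_META_SHUFFLING_GEMM_BACKEND"] = "triton"  # "cutlass" (default) or "triton"

import vllm.envs as envs

print(envs.VLLM_USE_META_SHUFFLING_MOE)       # True
print(envs.VLLM_META_SHUFFLING_GEMM_BACKEND)  # "triton"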

vllm/model_executor/layers/meta_shuffling_moe/__init__.py (new file; path inferred from the imports below)

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.meta_shuffling_moe.meta_shuffling_moe import (
    MetaShufflingMoE,
)

__all__ = ["MetaShufflingMoE"]
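
Assuming the file above is the package `__init__.py` (its absolute import suggests the package path `vllm/model_executor/layers/meta_shuffling_moe/`), the re-export would let callers import the layer from the package root:

# Hypothetical caller-side import implied by the re-export above.
from vllm.model_executor.layers.meta_shuffling_moe import MetaShufflingMoE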

vllm/model_executor/layers/meta_shuffling_moe/dispatch_combine.py (new file; path inferred from the relative import in meta_shuffling_moe.py below)

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_fbgemm_gpu_gen_ai

if current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai():
    from fbgemm_gpu.experimental.gen_ai.moe import (
        gather_scale_dense_tokens,
        scatter_add_dense_tokens,
    )


@dataclass
class RouteInfo:
    expert_indices: torch.Tensor
    token_counts: torch.Tensor
    token_indices: torch.Tensor
    num_routed_tokens: torch.Tensor
    num_recv_tokens: torch.Tensor | None = None
    recv_sizes_across_ranks: torch.Tensor | None = None
    recv_sizes_across_ranks_cpu: torch.Tensor | None = None
    send_sizes_across_ranks: torch.Tensor | None = None
    send_sizes_across_ranks_cpu: torch.Tensor | None = None


# Skeleton code to prepare for enabling EP.
# In TP only case, dispatch/combine are almost no-ops.
class MetaShufflingDispatchAndCombine:
    """
    Dispatch/Combine using Meta Shuffling kernels.
    """

    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
            cls.instance._initialized = False
        return cls.instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self.world_size = 1
        assert current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai()
        self._initialized: bool = True

    def dispatch(
        self,
        tokens: torch.Tensor,  # tokens
        route_info: RouteInfo,
        scores: torch.Tensor,  # scores,
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if apply_router_weight_on_input:
            tokens = gather_scale_dense_tokens(
                tokens,
                route_info.token_indices.flatten(),
                route_info.expert_indices.flatten(),
                scores,
                valid_token_count=route_info.num_routed_tokens,
            )
        assert self.world_size == 1
        return tokens, route_info.token_counts

    def combine(
        self,
        routed_out: torch.Tensor,
        route_info: RouteInfo,
        scores: torch.Tensor,
        shared_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert self.world_size == 1
        if envs.VLLM_META_SHUFFLING_GEMM_BACKEND == "cutlass":
            scatter_add_dense_tokens(
                out_tokens=shared_out,
                in_tokens=routed_out,
                token_indices=route_info.token_indices,
                valid_token_count=route_info.num_routed_tokens,
            )
            return shared_out
        # Assume in TP only case, we have already produced
        # fused output from routed and shared by calling
        # grouped_gemm with shared output when using triton grouped_gemm.
        else:
            return routed_out
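
In the TP-only path above, `dispatch` optionally pre-scales the gathered tokens by their routing weights, and `combine` scatter-adds the routed-expert outputs onto the shared-expert output (cutlass backend) or returns the already-fused result (triton backend). The fbgemm kernels are assumed to behave roughly like the following pure-PyTorch reference; this is a minimal sketch of the assumed semantics, not the fbgemm implementation, and it ignores `valid_token_count` and any padding the real kernels handle.

# Reference sketch (assumed semantics) of the two fbgemm kernels used above.
import torch


def gather_scale_dense_tokens_ref(
    tokens: torch.Tensor,          # [num_tokens, hidden]
    token_indices: torch.Tensor,   # [num_routed] source token row per routed slot
    expert_indices: torch.Tensor,  # [num_routed] expert chosen for that slot
    scores: torch.Tensor,          # [num_tokens, num_experts] router scores
) -> torch.Tensor:
    # Gather each routed token and pre-scale it by its routing weight
    # (the apply_router_weight_on_input=True path in dispatch()).
    routed = tokens[token_indices]
    weights = scores[token_indices, expert_indices].to(tokens.dtype)
    return routed * weights.unsqueeze(-1)


def scatter_add_dense_tokens_ref(
    out_tokens: torch.Tensor,      # [num_tokens, hidden], e.g. shared-expert output
    in_tokens: torch.Tensor,       # [num_routed, hidden] routed-expert outputs
    token_indices: torch.Tensor,   # [num_routed]
) -> torch.Tensor:
    # Accumulate routed outputs back into their source rows, fusing the
    # shared-expert and routed-expert contributions (the combine() path).
    out_tokens.index_add_(0, token_indices, in_tokens)
    return out_tokens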

vllm/model_executor/layers/meta_shuffling_moe/meta_shuffling_moe.py (new file; path inferred from the package __init__ import above)

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_fbgemm_gpu_gen_ai

from .dispatch_combine import MetaShufflingDispatchAndCombine, RouteInfo
from .routed_experts import MetaShufflingMoERoutedExperts

if current_platform.is_cuda_alike() and has_fbgemm_gpu_gen_ai():
    from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling

from vllm.logger import init_logger

logger = init_logger(__name__)


# We only need the weight loader from unquantized fused moe method.
class MetaShufflingMoEMethod(UnquantizedFusedMoEMethod):
    def __init__(
        self,
        moe: FusedMoEConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    # Override to no ops.
    def init_prepare_finalize(self, layer: torch.nn.Module):
        assert self.moe is not None


class MetaShufflingMoE(FusedMoE):
    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
        shared_experts: torch.nn.Module | None = None,
        scoring_func: str = "softmax",
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        is_sequence_parallel: bool = False,
        **kwargs,
    ):
        CustomOp.__init__(self)

        logger.info_once("Initialized with MetaShufflingMoE")

        assert current_platform.is_cuda_alike(), (
            "MetaShufflingMoE only supports CUDA and AMD for now."
        )
        assert has_fbgemm_gpu_gen_ai(), (
            "MetaShufflingMoE requires fbgemm_gpu_gen_ai. \
            Run pip install fbgemm-gpu-genai"
        )

        params_dtype = kwargs.get("params_dtype", torch.get_default_dtype())
        tp_size_ = kwargs.get("tp_size", get_tensor_model_parallel_world_size())
        dp_size_ = kwargs.get("dp_size", get_dp_group().world_size)
        assert not is_sequence_parallel, "Sequence parallel is not supported yet."
        # Parallelism
        vllm_config = get_current_vllm_config()
        self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
            tp_size_=tp_size_,
            dp_size_=dp_size_,
            vllm_parallel_config=vllm_config.parallel_config,
        )
        etp_size_ = 1 if self.use_ep else tp_size_
        assert not self.use_ep, "Ep is not supported yet."
        self.tp2ep_size = tp_size_ // etp_size_
        self.dp2ep = self.ep_size // self.tp2ep_size
        assert self.dp2ep == dp_size_, "Doesn't support dp > dp2ep yet"

        # Determine expert maps
        assert num_experts % self.ep_size == 0, (
            "Does not support duplicate experts for now."
        )
        self.global_num_experts = num_experts
        self.local_num_experts = self.global_num_experts
        self.group_expert_start = 0
        self.group_expert_end = self.global_num_experts
        self.experts_mask = torch.arange(
            self.group_expert_start, self.group_expert_end, device="cuda"
        ).view(-1, 1, 1)
        self.local_num_experts, self.expert_map, self.expert_mask = (
            self.global_num_experts,
            None,
            None,
        )

        # Layer setup
        # TODO: Most of the weights loading logic is
        # similar to base fused_moe. We should probably refactor
        # the code so that common shared logic can be shared.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

        assert intermediate_size % self.tp_size == 0
        self.hidden_size = hidden_size
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.scoring_func = scoring_func
        self.apply_router_weight_on_input = apply_router_weight_on_input
        assert self.apply_router_weight_on_input, (
            "Only support apply_router_weight_on_input=True for now."
        )
        self.activation = activation
        self.top_k = top_k

        if vllm_config.model_config is not None:
            model_dtype = vllm_config.model_config.dtype
        else:
            # TODO (bnell): This is a hack to get test_mixtral_moe to work
            # since model_config is not set in the pytest test.
            model_dtype = params_dtype

        moe = FusedMoEConfig(
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
            hidden_dim=hidden_size,
            num_local_experts=self.local_num_experts,
            moe_parallel_config=self.moe_parallel_config,
            in_dtype=model_dtype,
            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
            has_bias=False,
        )
        self.moe_config = moe

        self.is_routed_fp8_rowwise: bool = False
        assert quant_config is None, "Quantization is not supported yet."
        self.quant_config = quant_config

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
        self.quant_method = MetaShufflingMoEMethod(moe, quant_config=quant_config)

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition": self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if self.quant_method.__class__.__name__ in (
            "GPTQMarlinMoEMethod",
            "CompressedTensorsWNA16MarlinMoEMethod",
            "CompressedTensorsWNA16MoEMethod",
        ):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

        self._shared_experts = shared_experts
        self.dispatch_and_combine = MetaShufflingDispatchAndCombine()
        self.routed_experts = MetaShufflingMoERoutedExperts(
            quant_config=self.quant_config
        )

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return self._shared_experts

    def route(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, RouteInfo]:
        assert self.scoring_func == "sigmoid", (
            "only support sigmoid scoring function for now "
        )
        if self.scoring_func == "sigmoid":
            scores = torch.sigmoid(router_logits.to(torch.float32))
        top_k = self.moe_config.experts_per_token
        if top_k in {1, 2, 4} and self.global_num_experts in {16, 128}:
            token_counts, expert_indices, token_indices = index_shuffling(
                scores,  # num_tokens
                self.group_expert_start,
                self.group_expert_end,
                top_k=top_k,
            )
            num_routed_tokens = token_counts[-1]
            token_counts = token_counts[self.group_expert_start : self.group_expert_end]
        else:
            # Slow route using torch topk.
            _, global_selected_indices = torch.topk(scores, top_k, dim=1)
            expert_indices, token_indices = torch.sort(
                global_selected_indices.flatten(), dim=0, stable=True
            )
            token_indices = token_indices // top_k
            mask = self.experts_mask == expert_indices
            token_counts = (mask).sum(dim=2, dtype=torch.int32).flatten()
            num_routed_tokens = token_counts.sum().view(
                -1,
            )
        return scores, RouteInfo(
            expert_indices=expert_indices,
            token_indices=token_indices,
            token_counts=token_counts,
            num_routed_tokens=num_routed_tokens,
        )

    def forward_impl(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        scores, route_info = self.route(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        shuffled_recv_tokens, recv_token_counts = self.dispatch_and_combine.dispatch(
            tokens=hidden_states,
            scores=scores,
            route_info=route_info,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
        )
        # TODO: add using separate streams for shared experts when there's comms.
        if self._shared_experts is not None:
            shared_out = self._shared_experts(hidden_states)
        else:
            # This is so that we can call scatter_add_dense_tokens
            # without shared_experts.
            shared_out = torch.zeros_like(hidden_states)

        routed_out = self.routed_experts.run(
            x=shuffled_recv_tokens,
            token_counts=recv_token_counts,
            w1=self.w13_weight.data,
            w2=self.w2_weight.data,
            activation=self.activation,
            scores=scores,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            num_valid_tokens=route_info.num_recv_tokens,
            shared_out=shared_out if not self.use_ep else None,
            token_indices=route_info.token_indices if not self.use_ep else None,
        )

        output = self.dispatch_and_combine.combine(
            routed_out=routed_out,
            shared_out=shared_out,
            route_info=route_info,
            scores=scores,
        )
        output = output.view(hidden_states.shape)
        if shared_out is None:
            return output
        else:
            # create a fake shared_output as moe_forward_shared expect to return a tuple
            return torch.empty_like(output), output
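
To make the `RouteInfo` fields concrete, the following standalone snippet reproduces the slow (`torch.topk`) routing branch above on toy data. It is an illustration of the same logic outside the class, not vLLM API, and it uses CPU tensors instead of the CUDA `experts_mask` buffer.

# Standalone sketch of the slow routing path: topk, then a stable sort that
# groups routed slots by expert, for a toy 4-token, 4-expert, top-2 case.
import torch

num_tokens, num_experts, top_k = 4, 4, 2
torch.manual_seed(0)
scores = torch.sigmoid(torch.randn(num_tokens, num_experts))

_, selected = torch.topk(scores, top_k, dim=1)  # [num_tokens, top_k] expert ids
expert_indices, token_indices = torch.sort(selected.flatten(), dim=0, stable=True)
token_indices = token_indices // top_k          # back to source token ids

experts_mask = torch.arange(num_experts).view(-1, 1, 1)
token_counts = (experts_mask == expert_indices).sum(dim=2, dtype=torch.int32).flatten()
num_routed_tokens = token_counts.sum().view(-1)

print(expert_indices)   # routed slots grouped (sorted) by expert id
print(token_indices)    # source token id feeding each routed slot
print(token_counts)     # tokens per expert; sums to num_tokens * top_k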
