@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
 }
 }  // namespace

-template <typename scalar_t>
+template <typename scalar_t, typename token_cnts_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                             int32_t* sorted_token_ids,
                                             int32_t* expert_ids,
@@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   const size_t start_idx = threadIdx.x * tokens_per_thread;

   extern __shared__ int32_t shared_mem[];
-
-  int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
-  int32_t* cumsum =
-      shared_mem +
-      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);

   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                           block_size) *
                       block_size;
     }
-    *total_tokens_post_pad = cumsum[num_experts];
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
   }

   __syncthreads();
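A note on the layout change above: the kernel now puts the int32 cumsum array first in dynamic shared memory and reinterprets the remainder as token_cnts_t, so the byte count requested at launch has to match that layout. The sketch below is not part of the diff; it just reproduces the size arithmetic the host code uses later, for both the int32 and uint16 count layouts (kWarpSize and the 256-expert figure are illustrative assumptions).

// Host-side sketch of the dynamic shared-memory sizing implied by the new
// layout: cumsum[num_experts + 1] as int32, followed by a
// (num_thread + 1) x num_experts table of per-thread token counts.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int32_t kWarpSize = 32;  // assumption: NVIDIA warp size

template <typename CntT>
int32_t shared_mem_bytes(int32_t num_experts) {
  const int32_t num_thread = std::max(num_experts, kWarpSize);
  return static_cast<int32_t>((num_experts + 1) * sizeof(int32_t) +
                              (num_thread + 1) * num_experts * sizeof(CntT));
}

int main() {
  // With 256 experts (a DeepSeek-V3-sized MoE), the int32 layout needs
  // ~258 KiB, past typical per-block opt-in limits, while the uint16 layout
  // needs ~130 KiB.
  for (int32_t experts : {8, 64, 256}) {
    std::printf("experts=%d  int32=%d B  uint16=%d B\n", experts,
                shared_mem_bytes<int32_t>(experts),
                shared_mem_bytes<uint16_t>(experts));
  }
  return 0;
}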
@@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  // If we have very large number of experts, we can no longer use shared
-  // memory.
-  // TODO(simon): the right solution should be calculating the exact right
-  // amount of shared memory and use that. The num_experts >= 256 is just a
-  // temporary solution to unblock Deepseek V3.
-  if (num_experts >= 256) {
+  int device_max_shared_mem;
+  auto dev = topk_ids.get_device();
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
+  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+  const int32_t shared_mem_i32 =
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+  const int32_t shared_mem_i16 =
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
+
+  bool use_global_memory = false;
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i16 > device_max_shared_mem) {
+    use_global_memory = true;
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
+    // When topk_ids.numel() is smaller than 65535 (the max value of uint16),
+    // every element of token_cnts is also smaller than 65535, so we can use
+    // uint16 as the dtype of token_cnts.
+    use_i16 = true;
+  }
+
+  if (use_global_memory) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

-          const int32_t mem_tokens_cnts =
-              ((num_experts + 1) * num_experts) * sizeof(int32_t);
-          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-          // allocate global memory
-          int32_t* tokens_cnts;
-          int32_t* cumsum;
-          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-          cudaMalloc(&cumsum, mem_cumsum);
+          auto options_int = torch::TensorOptions()
+                                 .dtype(torch::kInt)
+                                 .device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);

           auto kernel =
               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
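The host-side selection above is a three-way switch keyed off cudaDevAttrMaxSharedMemoryPerBlockOptin. A minimal standalone restatement follows; the enum and helper name are hypothetical, not vLLM API. It mirrors the diff's order: fall back to global-memory workspaces only when even the uint16 layout does not fit, and use uint16 only when the int32 layout does not fit and topk_ids.numel() stays within uint16 range.

#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical names, for illustration only.
enum class MoeAlignPath { kGlobalMemory, kSharedUint16, kSharedInt32 };

MoeAlignPath choose_path(int device, int32_t shared_mem_i32,
                         int32_t shared_mem_i16, int64_t topk_numel) {
  int device_max_shared_mem = 0;
  cudaDeviceGetAttribute(&device_max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  if (shared_mem_i16 > device_max_shared_mem) {
    // Even the compact uint16 layout exceeds opt-in shared memory: use the
    // global-memory path (torch::empty workspaces, which also avoids the
    // cudaMalloc/cudaFree pair the old code issued on every call).
    return MoeAlignPath::kGlobalMemory;
  }
  if (shared_mem_i32 > device_max_shared_mem && topk_numel <= 65535) {
    // Per-thread counts are bounded by the number of top-k ids, so uint16
    // is lossless in this range.
    return MoeAlignPath::kSharedUint16;
  }
  return MoeAlignPath::kSharedInt32;
}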
@@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), tokens_cnts, cumsum);
-          cudaFree(tokens_cnts);
-          cudaFree(cumsum);
+              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
+              cumsum_buffer.data_ptr<int32_t>());
         });
-  } else {
+  } else if (use_i16) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-          // tensors
-          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-          const int32_t shared_mem =
-              ((num_thread + 1) * num_experts + (num_experts + 1)) *
-              sizeof(int32_t);
-
           // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem_i16));
+          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-              (void*)kernel, shared_mem));
-          kernel<<<1, num_thread, shared_mem, stream>>>(
+              (void*)kernel, shared_mem_i32));
+          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
               topk_ids.data_ptr<scalar_t>(),
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
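Both shared-memory branches end with the same pattern: raise the kernel's dynamic shared-memory cap, then launch with that many bytes. A rough sketch of what the VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize call amounts to on the CUDA runtime is below; this is an assumption based on the name, and the real macro also covers other backends. Launches requesting more than the default 48 KB of dynamic shared memory fail unless the kernel opts in first.

#include <cuda_runtime.h>

// Sketch only: opt a kernel into a larger dynamic shared-memory allocation
// before launching it with that many bytes, e.g.
//   opt_in_to_large_smem((void*)kernel, shared_mem_i16);
cudaError_t opt_in_to_large_smem(const void* kernel, int dynamic_smem_bytes) {
  return cudaFuncSetAttribute(kernel,
                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                              dynamic_smem_bytes);
}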