From 8ee4066cac3381a87adfaf96646f3110b40942c4 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 11 Sep 2025 22:57:13 +0800
Subject: [PATCH 01/12] CUDA: kernel for larger batch sizes for MoE

---
 ggml/src/ggml-cuda/mmf.cu  |  48 +++++-
 ggml/src/ggml-cuda/mmf.cuh | 324 +++++++++++++++++++++++++++++++++----
 ggml/src/ggml-cuda/mmq.cu  |  60 +++----
 tests/test-backend-ops.cpp |   2 +-
 4 files changed, 365 insertions(+), 69 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 599e085ee91b7..de6eddcdedfa5 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -1,6 +1,10 @@
 #include "ggml.h"
 #include "mmf.cuh"
 
+void ggml_cuda_launch_mmq_ids_helper(
+    const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
+    int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
+
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT( src1->type == GGML_TYPE_F32);
     GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
@@ -37,6 +41,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
     const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
+    mmf_ids_data ids_info{};
+    mmf_ids_data * ids_info_ptr = nullptr;
+    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst = ids ? ne2 : ne1;
     const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -54,6 +64,34 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         nchannels_y = ids->ne[0];
     }
 
+    if (ids && ncols_dst > 16) {
+        const int64_t n_expert_used = ids->ne[0];
+        const int64_t n_experts = ne02;
+        const int64_t n_tokens = ne12;
+        const int64_t ne_get_rows = n_tokens * n_expert_used;
+
+        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+        const int si1 = static_cast<int>(ids_s1);
+        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+        GGML_ASSERT(sis1 > 0);
+
+        ggml_cuda_launch_mmq_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
+        CUDA_CHECK(cudaGetLastError());
+
+        ids_info.ids_src_compact = ids_src_compact_dev.get();
+        ids_info.ids_dst_compact = ids_dst_compact_dev.get();
+        ids_info.expert_bounds_dev = expert_bounds_dev.get();
+        ids_info.n_experts = static_cast<int>(n_experts);
+        ids_info.sis1 = sis1;
+        ids_info.cols_per_tile = 16;
+        ids_info_ptr = &ids_info;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
@@ -61,7 +99,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
        case GGML_TYPE_F16: {
            const
half2 * src0_d = (const half2 *) src0->data; @@ -69,7 +107,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; @@ -77,7 +115,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr mul_mat_f_switch_cols_per_block( src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -98,10 +136,10 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } if (mul_mat_id) { - if (type == GGML_TYPE_F32 && src1_ncols > 32) { + if (type == GGML_TYPE_F32 && src1_ncols > 512) { return false; } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) { + if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 8192) { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index a6c3adfcf1704..b3da913a0c030 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,15 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +struct mmf_ids_data { + const int32_t * ids_src_compact = nullptr; + const int32_t * ids_dst_compact = nullptr; + const int32_t * expert_bounds_dev = nullptr; + int n_experts = 0; + int sis1 = 0; + int cols_per_tile = 0; +}; + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); @@ -224,6 +233,195 @@ static __global__ void mul_mat_f( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } + +//This kernel is for larger batch sizes of mul_mat_id +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded 
= warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = blockIdx.y; + const int expert_start = expert_bounds[expert_idx]; + const int expert_end = expert_bounds[expert_idx + 1]; + const int ncols_expert = expert_end - expert_start; + + if (ncols_expert <= 0) { + return; + } + + const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block; + const int tile_idx = blockIdx.z; + if (tile_idx >= tiles_for_expert) { + return; + } + + const int col_base = tile_idx * cols_per_block; + + GGML_UNUSED(channel_ratio); + + const int channel_x = expert_idx; + const int sample_dst = 0; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row; + y += int64_t(sample_y) *stride_sample_y; + dst += int64_t(sample_dst)*stride_sample_dst; + + const int32_t * ids_src_expert = ids_src_compact + expert_start; + const int32_t * ids_dst_expert = ids_dst_compact + expert_start; + + extern __shared__ char data_mmv[]; + char * compute_base = data_mmv; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + const int global_j = col_base + j; + float val = 0.0f; + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + val = y[channel*stride_channel_y + token*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = val; + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + const int global_j = col_base + j; + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps 
> 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + const int global_j = col_base + j; + if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) { + const int dst_entry = ids_dst_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd); + const int token = (int) qrm.x; + if (token < ncols_dst_total) { + const int slot = (int) qrm.y; + dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, @@ -232,13 +430,35 @@ static inline void mul_mat_f_switch_ids( const int64_t stride_col_id, const int64_t stride_row_id, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { - if (ids) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream, + const mmf_ids_data * ids_data) { + const bool has_ids_data = ids_data && ids_data->ids_src_compact; + + if (has_ids_data) { + const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); + if (max_tiles == 0) { + return; + } + GGML_ASSERT(ids_data->cols_per_tile == 0 || ids_data->cols_per_tile == cols_per_block); + + dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); + + const uint3 sis1_fd = ids_data->sis1 > 0 ? 
init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); + const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); + + mul_mat_f_ids<<>> + (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, + ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd); + } else if (ids) { const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; + mul_mat_f<<>> - (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { @@ -258,7 +478,7 @@ void mul_mat_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; @@ -288,9 +508,8 @@ void mul_mat_f_cuda( const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; - const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; - const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + const int nbytes_shared_total = nbytes_shared; + const int64_t grid_y = ids ? 
nchannels_x : nchannels_dst; const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); const dim3 block_dims(warp_size, nwarps_best, 1); @@ -300,49 +519,57 @@ void mul_mat_f_cuda( mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 2: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 3: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 4: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 5: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 6: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 7: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, 
stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; case 8: { mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); } break; default: { GGML_ABORT("fatal error"); @@ -361,92 +588,123 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + cudaStream_t stream, const mmf_ids_data * ids_data) { const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; GGML_ASSERT(ids || ncols_dst <= 16); + mmf_ids_data ids_case; + auto prepare_ids = [&](int cols_case) -> const mmf_ids_data * { + if (!ids_data || !ids_data->ids_src_compact) { + return nullptr; + } + + if (ids_data->cols_per_tile != 0 && ids_data->cols_per_tile != cols_case) { + return nullptr; + } + + ids_case = *ids_data; + ids_case.cols_per_tile = cols_case; + return &ids_case; + }; + switch (ncols_case) { case 1: { + const mmf_ids_data * ids_case_ptr = prepare_ids(1); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 2: { + const mmf_ids_data * ids_case_ptr = prepare_ids(2); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 3: { + const mmf_ids_data * ids_case_ptr = prepare_ids(3); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 4: { + const mmf_ids_data * ids_case_ptr = prepare_ids(4); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, 
ids_case_ptr); } break; case 5: { + const mmf_ids_data * ids_case_ptr = prepare_ids(5); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 6: { + const mmf_ids_data * ids_case_ptr = prepare_ids(6); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 7: { + const mmf_ids_data * ids_case_ptr = prepare_ids(7); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 8: { + const mmf_ids_data * ids_case_ptr = prepare_ids(8); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 9: { + const mmf_ids_data * ids_case_ptr = prepare_ids(9); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 10: { + const mmf_ids_data * ids_case_ptr = prepare_ids(10); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 11: { + const mmf_ids_data * ids_case_ptr = prepare_ids(11); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 12: { + const mmf_ids_data * ids_case_ptr = 
prepare_ids(12); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 13: { + const mmf_ids_data * ids_case_ptr = prepare_ids(13); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 14: { + const mmf_ids_data * ids_case_ptr = prepare_ids(14); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 15: { + const mmf_ids_data * ids_case_ptr = prepare_ids(15); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 16: { + const mmf_ids_data * ids_case_ptr = prepare_ids(16); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; default: { GGML_ABORT("fatal error"); @@ -462,7 +720,7 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ - cudaStream_t stream); + cudaStream_t stream, const mmf_ids_data * ids_data); #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 12bdc629bd6b2..57604b9333bda 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -137,6 +137,34 @@ static void launch_mmq_ids_helper( (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); } +void ggml_cuda_launch_mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ 
expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} + static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q4_0: @@ -293,36 +321,8 @@ void ggml_cuda_mul_mat_q( const int si1 = ids->nb[1] / ggml_element_size(ids); const int sis1 = nb12 / nb11; - switch (n_expert_used) { - case 2: - launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 4: - launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 6: - launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 8: - launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 16: - launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - case 32: - launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - default: - launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), - ne02, ne12, n_expert_used, ne11, si1, sis1, stream); - break; - } + ggml_cuda_launch_mmq_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); CUDA_CHECK(cudaGetLastError()); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fa16b497a6b7..3ed05632d3f2c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6911,7 +6911,7 @@ static std::vector> make_test_cases_perf() { } // qwen3-30b-a3b - for (int bs : {1, 4, 8, 32, 64, 128, 512}) { + for (int bs : {1, 4, 8, 32, 64, 128, 512, 1024, 2048, 4096, 8192}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, 
GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
                 test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));

From 996bed07d276f31b69c847b6ac88aa19937e93c4 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 10 Oct 2025 14:20:10 +0800
Subject: [PATCH 02/12] WIP

---
 ggml/src/ggml-cuda/mmf.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index de6eddcdedfa5..e31b75846fbcf 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -139,7 +139,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         if (type == GGML_TYPE_F32 && src1_ncols > 512) {
             return false;
         }
-        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 8192) {
+        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 1024) {
             return false;
         }
     } else {

From 4a2388ee59146baaa45fef0f1dc8b7a6a1a939ff Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 10 Oct 2025 14:37:12 +0800
Subject: [PATCH 03/12] WIP

---
 ggml/src/ggml-cuda/mmf.cuh | 20 ++++----------------
 ggml/src/ggml-cuda/mmq.cuh |  4 ++++
 2 files changed, 8 insertions(+), 16 deletions(-)

diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index b3da913a0c030..60227edb4e955 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -508,7 +508,8 @@ void mul_mat_f_cuda(
     const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
     const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
-    const int nbytes_shared_total = nbytes_shared;
+    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
+    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
    const int64_t grid_y = ids ?
nchannels_x : nchannels_dst; const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); @@ -609,93 +610,80 @@ static void mul_mat_f_switch_cols_per_block( return &ids_case; }; + mmf_ids_data * ids_case_ptr = nullptr; + switch (ncols_case) { case 1: { - const mmf_ids_data * ids_case_ptr = prepare_ids(1); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 2: { - const mmf_ids_data * ids_case_ptr = prepare_ids(2); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 3: { - const mmf_ids_data * ids_case_ptr = prepare_ids(3); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 4: { - const mmf_ids_data * ids_case_ptr = prepare_ids(4); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 5: { - const mmf_ids_data * ids_case_ptr = prepare_ids(5); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 6: { - const mmf_ids_data * ids_case_ptr = prepare_ids(6); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 7: { - const mmf_ids_data * ids_case_ptr = prepare_ids(7); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 8: { - const mmf_ids_data * ids_case_ptr = prepare_ids(8); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 9: { - const mmf_ids_data * ids_case_ptr = prepare_ids(9); 
mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 10: { - const mmf_ids_data * ids_case_ptr = prepare_ids(10); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 11: { - const mmf_ids_data * ids_case_ptr = prepare_ids(11); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 12: { - const mmf_ids_data * ids_case_ptr = prepare_ids(12); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 13: { - const mmf_ids_data * ids_case_ptr = prepare_ids(13); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 14: { - const mmf_ids_data * ids_case_ptr = prepare_ids(14); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); } break; case 15: { - const mmf_ids_data * ids_case_ptr = prepare_ids(15); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c9a07e82fedf2..2cb6a0c58f451 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3756,3 +3756,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); + +void ggml_cuda_launch_mmq_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); From 3120ddfa726f49fdefdab8ba32658eb9e1286b22 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 10 Oct 2025 18:06:50 +0800 Subject: [PATCH 04/12] WIP --- tests/test-backend-ops.cpp | 11 
++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3ed05632d3f2c..56417290da96d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6911,7 +6911,7 @@ static std::vector> make_test_cases_perf() { } // qwen3-30b-a3b - for (int bs : {1, 4, 8, 32, 64, 128, 512, 1024, 2048, 4096, 8192}) { + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); @@ -6919,6 +6919,15 @@ static std::vector> make_test_cases_perf() { } } + // liquid 1b-8b + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { + for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); + } + } + } + // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { From 237ade5d56315ef2c07baf6de464d7c3bc57eece Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 10 Oct 2025 19:36:04 +0800 Subject: [PATCH 05/12] WIP --- ggml/src/ggml-cuda/mmf.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 60227edb4e955..d529814808247 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -291,6 +291,8 @@ static __global__ void mul_mat_f_ids( extern __shared__ char data_mmv[]; char * compute_base = data_mmv; + const float2 * y2 = (const float2 *) y; + tile_C C[ntA][ntB]; T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); @@ -342,7 +344,7 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; const int channel = (int) qrm.y; if (token < ncols_dst_total) { - tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + tmp = y2[channel*stride_channel_y/2 + token*stride_col_y + col]; } } tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; From 43e0436b9cb4bdad31655b4d2208a11c46ebbe9c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 10 Oct 2025 22:16:27 +0800 Subject: [PATCH 06/12] WIP --- ggml/src/ggml-cuda/mmf.cu | 2 +- ggml/src/ggml-cuda/mmf.cuh | 95 ++++++++++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index e31b75846fbcf..362bf4847b276 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -139,7 +139,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const if (type == GGML_TYPE_F32 && src1_ncols > 512) { return false; } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 1024) { + if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 512) { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index d529814808247..e3dd49937b8cd 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -291,7 +291,7 @@ static __global__ void mul_mat_f_ids( extern __shared__ char data_mmv[]; char * compute_base = data_mmv; - const float2 * y2 = (const float2 *) y; + //const float2 * y2 = (const float2 *) y; tile_C C[ntA][ntB]; @@ 
-311,13 +311,12 @@ static __global__ void mul_mat_f_ids( } } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + float vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float *vals) { #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - + const int j = j0 + tile_idx_local*tile_B::I; const int global_j = col_base + j; float val = 0.0f; if (j < cols_per_block && global_j < ncols_expert) { @@ -329,13 +328,48 @@ static __global__ void mul_mat_f_ids( val = y[channel*stride_channel_y + token*stride_col_y + col]; } } - tile_xy[j0*tile_k_padded + threadIdx.x] = val; + vals[j0] = val; } - } else if constexpr (std::is_same_v || std::is_same_v) { + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); + } + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { #pragma unroll for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; + tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0]; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { + float2 vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float2 *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; const int global_j = col_base + j; float2 tmp = make_float2(0.0f, 0.0f); if (j < cols_per_block && global_j < ncols_expert) { @@ -344,23 +378,48 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; const int channel = (int) qrm.y; if (token < ncols_dst_total) { - tmp = y2[channel*stride_channel_y/2 + token*stride_col_y + col]; + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; } } - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + vals[j0] = tmp; } - } else { - static_assert(std::is_same_v, "unsupported type"); + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); } + + int curr_buf = 0; + int next_buf = 1; #pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); + for (int itB = 0; itB < ntB; ++itB) { #pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const float2 tmp = vals_buf[curr_buf][j0]; + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; } } + } else { + static_assert(std::is_same_v, "unsupported type"); } } From 56e73628ffd7f74b87611b9365905ec0e7379782 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 11 Oct 2025 15:02:33 +0800 Subject: [PATCH 07/12] WIP --- ggml/src/ggml-cuda/mmf.cu | 7 ++++--- ggml/src/ggml-cuda/mmf.cuh | 4 +++- tests/test-backend-ops.cpp | 9 
--------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 362bf4847b276..9cc874b4b9f87 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -136,10 +136,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } if (mul_mat_id) { - if (type == GGML_TYPE_F32 && src1_ncols > 512) { + if (src0_ne[1] <= 1024 && src1_ncols > 512) { return false; - } - if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 512) { + } else if(src0_ne[1] > 1024 && src1_ncols > 128) { + return false; + } else { return false; } } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e3dd49937b8cd..5e38036c51d35 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -495,7 +495,9 @@ static inline void mul_mat_f_switch_ids( const mmf_ids_data * ids_data) { const bool has_ids_data = ids_data && ids_data->ids_src_compact; - if (has_ids_data) { + // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16) + // we prefer the normal mul_mat_f path with has_ids=true. + if (has_ids_data && ncols_dst > 16) { const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); if (max_tiles == 0) { return; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 56417290da96d..9e20238475510 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6919,15 +6919,6 @@ static std::vector> make_test_cases_perf() { } } - // liquid 1b-8b - for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { - for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { - for (ggml_type type_b : {GGML_TYPE_F32}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); - } - } - } - // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { From 3183a8ef8bb4fc48eb27b095f7fd7c9a9c40ef55 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 11 Oct 2025 15:35:58 +0800 Subject: [PATCH 08/12] fixup --- ggml/src/ggml-cuda/mmf.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 9cc874b4b9f87..32ac9e49d8c19 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -140,8 +140,6 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } else if(src0_ne[1] > 1024 && src1_ncols > 128) { return false; - } else { - return false; } } else { if (src1_ncols > 16) { From 0728f6b7d88034623a3cbf306b5f15b16a88d481 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 13 Oct 2025 09:52:14 +0800 Subject: [PATCH 09/12] tests --- tests/test-backend-ops.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9e20238475510..bb31280631f7f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6919,6 +6919,15 @@ static std::vector> make_test_cases_perf() { } } + for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { + for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); + } + } + } + + // gpt-oss-20b for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a 
: {GGML_TYPE_MXFP4}) { From 37f8f84d0285cd07bd21d4e7a5fc57b49cf7025c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 12:27:33 +0800 Subject: [PATCH 10/12] Move mmq_ids_helper to mmid --- ggml/src/ggml-cuda/mmid.cu | 164 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmid.cuh | 5 ++ 2 files changed, 169 insertions(+) create mode 100644 ggml/src/ggml-cuda/mmid.cu create mode 100644 ggml/src/ggml-cuda/mmid.cuh diff --git a/ggml/src/ggml-cuda/mmid.cu b/ggml/src/ggml-cuda/mmid.cu new file mode 100644 index 0000000000000..3c61e4595a7b1 --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cu @@ -0,0 +1,164 @@ +#include "common.cuh" +#include "mmid.cuh" + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mm_ids_helper_store { + uint32_t data; + + __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mm_ids_helper[]; + mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mm_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? 
+ ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mm_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mm_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, 
sis1, stream); + break; + case 16: + launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh new file mode 100644 index 0000000000000..4ff60b2e90e8a --- /dev/null +++ b/ggml/src/ggml-cuda/mmid.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); \ No newline at end of file From c710560daf2a9e95c6fa42eafbeec56390375da4 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 12:41:01 +0800 Subject: [PATCH 11/12] cleanup --- ggml/src/ggml-cuda/mmf.cu | 7 +- ggml/src/ggml-cuda/mmf.cuh | 53 ++++-------- ggml/src/ggml-cuda/mmid.cuh | 4 +- ggml/src/ggml-cuda/mmq.cu | 167 +----------------------------------- ggml/src/ggml-cuda/mmq.cuh | 4 - 5 files changed, 22 insertions(+), 213 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 32ac9e49d8c19..9e2aaf52d6cce 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,9 +1,7 @@ #include "ggml.h" #include "mmf.cuh" +#include "mmid.cuh" -void ggml_cuda_launch_mmq_ids_helper( - const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, - int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -79,7 +77,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr GGML_ASSERT(sis1 > 0); - ggml_cuda_launch_mmq_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), + ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), static_cast(n_experts), static_cast(n_tokens), static_cast(n_expert_used), static_cast(ne11), si1, sis1, ctx.stream()); CUDA_CHECK(cudaGetLastError()); @@ -88,7 +86,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr ids_info.expert_bounds_dev = expert_bounds_dev.get(); ids_info.n_experts = static_cast(n_experts); ids_info.sis1 = sis1; - ids_info.cols_per_tile = 16; ids_info_ptr = &ids_info; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 5e38036c51d35..31184198c331e 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -13,7 +13,6 @@ struct mmf_ids_data { const int32_t * expert_bounds_dev = nullptr; int n_experts = 0; int sis1 = 0; - int cols_per_tile = 0; }; void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); @@ -502,8 +501,6 @@ static inline void mul_mat_f_switch_ids( if (max_tiles == 0) { return; } - GGML_ASSERT(ids_data->cols_per_tile == 0 || ids_data->cols_per_tile == cols_per_block); - dim3 
block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); @@ -658,104 +655,86 @@ static void mul_mat_f_switch_cols_per_block( GGML_ASSERT(ids || ncols_dst <= 16); - mmf_ids_data ids_case; - auto prepare_ids = [&](int cols_case) -> const mmf_ids_data * { - if (!ids_data || !ids_data->ids_src_compact) { - return nullptr; - } - - if (ids_data->cols_per_tile != 0 && ids_data->cols_per_tile != cols_case) { - return nullptr; - } - - ids_case = *ids_data; - ids_case.cols_per_tile = cols_case; - return &ids_case; - }; - - mmf_ids_data * ids_case_ptr = nullptr; - switch (ncols_case) { case 1: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, 
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, 
stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { - const mmf_ids_data * ids_case_ptr = prepare_ids(16); mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_case_ptr); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh index 4ff60b2e90e8a..ac090aea9ea1a 100644 --- a/ggml/src/ggml-cuda/mmid.cuh +++ b/ggml/src/ggml-cuda/mmid.cuh @@ -1,5 +1,5 @@ -#include "common.cuh" +#pragma once void ggml_cuda_launch_mm_ids_helper( const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, - int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); \ No newline at end of file + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 57604b9333bda..a2c8760abea93 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,169 +1,6 @@ #include "mmq.cuh" #include "quantize.cuh" - -#include - -// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. -struct mmq_ids_helper_store { - uint32_t data; - - __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { - data = (it & 0x003FFFFF) | (iex_used << 22); - } - - __device__ uint32_t it() const { - return data & 0x003FFFFF; - } - - __device__ uint32_t iex_used() const { - return data >> 22; - } -}; -static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); - -// Helper function for mul_mat_id, converts ids to a more convenient format. -// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. -// ids_dst describes the same mapping but for the dst tensor. -// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. -template -__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) -static __global__ void mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; - const int expert = blockIdx.x; - - extern __shared__ char data_mmq_ids_helper[]; - mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; - - int nex_prev = 0; // Number of columns for experts with a lower index. - int it_compact = 0; // Running index for the compact slice of this expert. 
- - if constexpr (n_expert_used_template == 0) { - // Generic implementation: - for (int it = 0; it < n_tokens; ++it) { - int iex_used = -1; // The index at which the expert is used, if any. - for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { - const int expert_used = ids[it*si1 + iex]; - nex_prev += expert_used < expert; - if (expert_used == expert) { - iex_used = iex; - } - } - - if (iex_used != -1) { - store[it_compact] = mmq_ids_helper_store(it, iex_used); - } - - if (warp_reduce_any(iex_used != -1)) { - it_compact++; - } - } - } else { - // Implementation optimized for specific numbers of experts used: - static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); - const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. - for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { - const int it = it0 + threadIdx.x / neu_padded; - - const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. - const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? - ids[it*si1 + iex] : INT_MAX; - const int iex_used = expert_used == expert ? iex : -1; - nex_prev += expert_used < expert; - - // Whether the threads at this token position have used the expert: - const int it_compact_add_self = warp_reduce_any(iex_used != -1); - - // Do a scan over threads at lower token positions in warp to get the correct index for writing data: - int it_compact_add_lower = 0; -#pragma unroll - for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { - const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); - if (threadIdx.x >= static_cast(offset)) { - it_compact_add_lower += tmp; - } - } - - if (iex_used != -1) { - store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); - } - - // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: - it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); - } - } - nex_prev = warp_reduce_sum(nex_prev); - - for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { - const mmq_ids_helper_store store_it = store[itc]; - const int it = store_it.it(); - const int iex_used = store_it.iex_used(); - ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; - ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; - } - - if (threadIdx.x != 0) { - return; - } - - expert_bounds[expert] = nex_prev; - - if (expert < static_cast(gridDim.x) - 1) { - return; - } - - expert_bounds[gridDim.x] = nex_prev + it_compact; -} - -template -static void launch_mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { - GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); - GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); - - const int id = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[id].warp_size; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); - - const dim3 num_blocks(n_experts, 1, 1); - const dim3 block_size(warp_size, 1, 1); - const size_t 
nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); - GGML_ASSERT(nbytes_shared <= smpbo); - mmq_ids_helper<<>> - (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); -} - -void ggml_cuda_launch_mmq_ids_helper( - const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, - const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { - switch (n_expert_used) { - case 2: - launch_mmq_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - case 4: - launch_mmq_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - case 6: - launch_mmq_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - case 8: - launch_mmq_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - case 16: - launch_mmq_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - case 32: - launch_mmq_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - default: - launch_mmq_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); - break; - } -} +#include "mmid.cuh" static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { @@ -321,7 +158,7 @@ void ggml_cuda_mul_mat_q( const int si1 = ids->nb[1] / ggml_element_size(ids); const int sis1 = nb12 / nb11; - ggml_cuda_launch_mmq_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), ne02, ne12, n_expert_used, ne11, si1, sis1, stream); CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 2cb6a0c58f451..c9a07e82fedf2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3756,7 +3756,3 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); - -void ggml_cuda_launch_mmq_ids_helper( - const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, - int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); From ddd3a990632fa9fd063d32daff5ac8d1aab67869 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 14 Oct 2025 17:41:38 +0800 Subject: [PATCH 12/12] Remove redundant checks --- ggml/src/ggml-cuda/mmf.cuh | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 31184198c331e..49d5295be0ea0 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -261,10 +261,6 @@ static __global__ void mul_mat_f_ids( const int expert_end = expert_bounds[expert_idx + 1]; const int ncols_expert = expert_end - expert_start; - if (ncols_expert <= 0) { - return; - } - const int tiles_for_expert = 
(ncols_expert + cols_per_block - 1) / cols_per_block; const int tile_idx = blockIdx.z; if (tile_idx >= tiles_for_expert) { @@ -331,9 +327,7 @@ static __global__ void mul_mat_f_ids( } }; - if (ntB > 0) { - gather_tile(0, vals_buf[0]); - } + gather_tile(0, vals_buf[0]); int curr_buf = 0; int next_buf = 1;
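
Reviewer-style note (not part of the patches above): assuming each token selects a given expert at most once, the compact layout produced by the new mm_ids_helper in mmid.cu can be summarized by the following host-side C++ sketch. The function name and the plain triple loop are hypothetical, for illustration and unit-testing only; the kernel itself does the same work warp-parallel per expert.

#include <cstdint>

// Host-side reference for the ids compaction: for every expert, list the tokens
// that route to it, in ascending token order, and record where each compact
// column comes from in src1 and where it goes in dst.
static void mm_ids_helper_ref(
        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1) {
    int offset = 0;
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = offset; // start of this expert's compact slice
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] == expert) {
                    ids_src1[offset] = it*sis1 + iex % nchannels_y; // flattened src1 column
                    ids_dst [offset] = it*n_expert_used + iex;      // flattened dst column
                    ++offset;
                    break; // at most one entry per (token, expert)
                }
            }
        }
    }
    expert_bounds[n_experts] = offset; // total number of compact columns
}

Each expert e then owns the compact columns [expert_bounds[e], expert_bounds[e+1]), which is what both the existing mmq.cu path and the new mmf path index into. This also explains the final patch: tiles_for_expert is 0 whenever ncols_expert is 0, so the tile_idx >= tiles_for_expert early-out already covers unused experts, which appears to be why the explicit ncols_expert <= 0 and ntB > 0 checks are dropped as redundant.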