
Commit e22ee1e (parent: e392d85)

[Kernel] GGUF MoE kernel (#14613)

Signed-off-by: SzymonOzog <[email protected]>

File tree

8 files changed: +1070, -25 lines


csrc/ops.h (8 additions, 0 deletions)

@@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
 torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
+torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
+                          torch::Tensor sorted_token_ids,
+                          torch::Tensor expert_ids,
+                          torch::Tensor num_tokens_post_padded, int64_t type,
+                          int64_t row, int64_t top_k, int64_t tokens);
+
+int64_t ggml_moe_get_block_size(int64_t type);
+
 #ifndef USE_ROCM
 void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                            torch::Tensor const& B, torch::Tensor const& A_sf,
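
The declarations alone don't show how the new ops reach Python; the bindings live in one of the other six changed files not rendered on this page. A plausible registration sketch in vLLM's usual torch_bindings.cpp style, with the schema strings inferred from the signatures above rather than copied from the commit:

// Hypothetical sketch, not from this diff: how the new ops would plausibly
// be registered following vLLM's torch_bindings.cpp pattern. Schema strings
// are assumptions inferred from the csrc/ops.h declarations above.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "ggml_moe_a8(Tensor X, Tensor W, Tensor sorted_token_ids, "
      "Tensor expert_ids, Tensor num_tokens_post_padded, int type, "
      "int row, int top_k, int tokens) -> Tensor");
  ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);

  // Host-only helper; no dispatch key needed.
  ops.def("ggml_moe_get_block_size(int type) -> int",
          &ggml_moe_get_block_size);
}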

csrc/quantization/gguf/gguf_kernel.cu (138 additions, 4 deletions)
@@ -12,6 +12,7 @@
 #include "dequantize.cuh"
 #include "mmvq.cuh"
 #include "mmq.cuh"
+#include "moe.cuh"
 
 // Q8 gemv
 template <typename scalar_t>
@@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
   const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
   const int block_num_x =
       (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-  const dim3 num_blocks(block_num_x, ky, 1);
-  const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-  quantize_q8_1<scalar_t>
-      <<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+  constexpr int MAX_BLOCK_SIZE = 65535;
+  for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
+    const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
+    const dim3 num_blocks(block_num_x, num_blocks_y, 1);
+    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
+        &x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
+  }
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
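
Two details of this hunk deserve a note. First, CUDA caps gridDim.y at 65535, so a single launch can no longer cover ky rows once batches get large; the loop splits the launch into y-chunks of at most 65535 rows. Second, the per-row output stride of kx_padded / 32 * 9 int32 values follows from the GGUF Q8_1 block layout: each block packs 32 int8 quants plus a half2 (scale, block-sum) header, 32 + 4 = 36 bytes, i.e. 9 int32 per 32 inputs. A minimal standalone sketch of the same chunked-launch pattern (the scale_rows kernel and sizes are illustrative, not from the commit):

// Standalone sketch of chunking a 2-D launch along grid.y to stay under
// the 65535 limit. scale_rows is a toy kernel for illustration only.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void scale_rows(float* data, int cols) {
  // blockIdx.y selects the row within this chunk; x walks the columns.
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < cols) data[(size_t)blockIdx.y * cols + c] *= 2.0f;
}

void launch_over_many_rows(float* data, int rows, int cols,
                           cudaStream_t stream) {
  constexpr int MAX_GRID_Y = 65535;  // CUDA limit on gridDim.y
  const int threads = 256;
  const int blocks_x = (cols + threads - 1) / threads;
  for (int off = 0; off < rows; off += MAX_GRID_Y) {
    // Clamp the final chunk so gridDim.y never exceeds the hardware limit,
    // and advance the data pointer to the first row of the chunk.
    const int chunk = std::min(rows, off + MAX_GRID_Y) - off;
    const dim3 grid(blocks_x, chunk, 1);
    scale_rows<<<grid, threads, 0, stream>>>(data + (size_t)off * cols, cols);
  }
}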
@@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
   });
   return Y;
 }
+
+torch::Tensor ggml_moe_a8(torch::Tensor X,  // input
+                          torch::Tensor W,  // expert weights
+                          torch::Tensor sorted_token_ids,
+                          torch::Tensor expert_ids,
+                          torch::Tensor num_tokens_post_padded, int64_t type,
+                          int64_t row, int64_t top_k, int64_t tokens) {
+  int col = X.sizes()[1];
+  int padded = (col + 512 - 1) / 512 * 512;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
+  at::Tensor Y = torch::empty({tokens * top_k, row}, options);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
+  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
+                           col, tokens, stream);
+    switch (type) {
+      case 2:
+        ggml_moe_q4_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 3:
+        ggml_moe_q4_1_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 6:
+        ggml_moe_q5_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 7:
+        ggml_moe_q5_1_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 8:
+        ggml_moe_q8_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 10:
+        ggml_moe_q2_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 11:
+        ggml_moe_q3_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 12:
+        ggml_moe_q4_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 13:
+        ggml_moe_q5_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 14:
+        ggml_moe_q6_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+    }
+  });
+  return Y;
+}
+
+int64_t ggml_moe_get_block_size(int64_t type) {
+  switch (type) {
+    case 2:
+      return MMQ_X_Q4_0;
+    case 3:
+      return MMQ_X_Q4_1;
+    case 6:
+      return MMQ_X_Q5_0;
+    case 7:
+      return MMQ_X_Q5_1;
+    case 8:
+      return MMQ_X_Q8_0;
+    case 10:
+      return MMQ_X_Q2_K;
+    case 11:
+      return MMQ_X_Q3_K;
+    case 12:
+      return MMQ_X_Q4_K;
+    case 13:
+      return MMQ_X_Q5_K;
+    case 14:
+      return MMQ_X_Q6_K;
+  }
+  return 0;
+}
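
How a caller is meant to drive the two ops together is not visible in this view (the Python-side integration is among the remaining changed files), but the signatures suggest the usual fused-MoE flow: pick the kernel's tile size via ggml_moe_get_block_size, align the routed token slots to that size, then hand the routing metadata to ggml_moe_a8. A hypothetical C++ driver under those assumptions:

// Hypothetical driver, not from the commit. Assumes the declarations from
// csrc/ops.h are in scope, X is [tokens, col], W is [num_experts, row, ...]
// GGUF-quantized data, and Y comes back as [tokens * top_k, row].
#include <torch/torch.h>

torch::Tensor run_gguf_moe(const torch::Tensor& X, const torch::Tensor& W,
                           const torch::Tensor& sorted_token_ids,
                           const torch::Tensor& expert_ids,
                           const torch::Tensor& num_tokens_post_padded,
                           int64_t type, int64_t top_k) {
  // Tile size each expert's token group must be padded to.
  const int64_t block = ggml_moe_get_block_size(type);
  TORCH_CHECK(block > 0, "unsupported GGUF quant type: ", type);
  // sorted_token_ids / expert_ids / num_tokens_post_padded are assumed to
  // come from a moe_align_block_size-style pass using `block` as the unit.
  return ggml_moe_a8(X, W, sorted_token_ids, expert_ids,
                     num_tokens_post_padded, type,
                     /*row=*/W.size(1),  // assumed output rows per expert
                     /*top_k=*/top_k,
                     /*tokens=*/X.size(0));
}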
