 #include "dequantize.cuh"
 #include "mmvq.cuh"
 #include "mmq.cuh"
+#include "moe.cuh"
 
 // Q8 gemv
 template <typename scalar_t>
@@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
   const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
   const int block_num_x =
       (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-  const dim3 num_blocks(block_num_x, ky, 1);
-  const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-  quantize_q8_1<scalar_t>
-      <<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+  constexpr int MAX_BLOCK_SIZE = 65535;
+  for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
+    const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
+    const dim3 num_blocks(block_num_x, num_blocks_y, 1);
+    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
+        &x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
+  }
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
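The rewritten launch above splits the `ky` rows into chunks because CUDA caps `gridDim.y` at 65535; the per-chunk output offset `off * (kx_padded / 32 * 9)` follows from each 32-value q8_1 block occupying 36 bytes (32 int8 quants plus a packed half2 scale/sum), i.e. 9 int32 words. A minimal sketch of the same chunking pattern, with a hypothetical `row_kernel` standing in for the quantization kernel:

```cuda
// Sketch only: chunk a per-row kernel launch so gridDim.y never exceeds the
// CUDA limit of 65535. `row_kernel` is a placeholder, not code from this file.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void row_kernel(const float* in, float* out, int row_len) {
  const int row = blockIdx.y;  // one row of the current chunk per y-block
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col < row_len)
    out[(size_t)row * row_len + col] = 2.0f * in[(size_t)row * row_len + col];
}

static void launch_chunked(const float* in, float* out, int row_len, int rows,
                           cudaStream_t stream) {
  constexpr int MAX_GRID_Y = 65535;  // hardware limit on gridDim.y
  const int threads = 256;
  const int grid_x = (row_len + threads - 1) / threads;
  for (int off = 0; off < rows; off += MAX_GRID_Y) {
    const int chunk = std::min(rows, off + MAX_GRID_Y) - off;
    const dim3 grid(grid_x, chunk, 1);
    // Advance the pointers so blockIdx.y == 0 maps to row `off`.
    row_kernel<<<grid, threads, 0, stream>>>(
        in + (size_t)off * row_len, out + (size_t)off * row_len, row_len);
  }
}
```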
@@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
   });
   return Y;
 }
+
+torch::Tensor ggml_moe_a8(torch::Tensor X,  // input
+                          torch::Tensor W,  // expert weights
+                          torch::Tensor sorted_token_ids,
+                          torch::Tensor expert_ids,
+                          torch::Tensor num_tokens_post_padded, int64_t type,
+                          int64_t row, int64_t top_k, int64_t tokens) {
+  int col = X.sizes()[1];
+  int padded = (col + 512 - 1) / 512 * 512;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
+  at::Tensor Y = torch::empty({tokens * top_k, row}, options);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
+  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
+                           col, tokens, stream);
+    switch (type) {
+      case 2:
+        ggml_moe_q4_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 3:
+        ggml_moe_q4_1_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 6:
+        ggml_moe_q5_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 7:
+        ggml_moe_q5_1_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 8:
+        ggml_moe_q8_0_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 10:
+        ggml_moe_q2_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 11:
+        ggml_moe_q3_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 12:
+        ggml_moe_q4_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 13:
+        ggml_moe_q5_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+      case 14:
+        ggml_moe_q6_K_q8_1_cuda(
+            (void*)quant_X.data_ptr(), (void*)W.data_ptr(),
+            (scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
+            (int*)expert_ids.data_ptr(),
+            (int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
+            tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
+        break;
+    }
+  });
+  return Y;
+}
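As a usage note, here is a hedged sketch of how a caller might drive the new binding; the wrapper `run_moe_layer` and its checks are illustrative assumptions, not code from this PR. It only restates the contract visible above: `X` is `[tokens, col]`, the routing tensors come pre-padded, and the result is `[tokens * top_k, row]`.

```cpp
// Hypothetical caller (not part of this diff): illustrates the shape contract
// assumed by ggml_moe_a8. Only ggml_moe_a8 / ggml_moe_get_block_size exist in
// the file above; everything else here is made up for the example.
#include <torch/torch.h>

torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor expert_ids,
                          torch::Tensor num_tokens_post_padded, int64_t type,
                          int64_t row, int64_t top_k, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);

torch::Tensor run_moe_layer(torch::Tensor X,                       // [tokens, col], fp16/bf16
                            torch::Tensor W,                       // quantized experts; stride(0) = bytes per expert
                            torch::Tensor sorted_token_ids,        // int32, token order grouped by expert, padded
                            torch::Tensor expert_ids,              // int32, one expert id per tile
                            torch::Tensor num_tokens_post_padded,  // int32 scalar on device
                            int64_t type, int64_t row, int64_t top_k) {
  const int64_t tokens = X.size(0);
  // Routing metadata is assumed to be padded to a multiple of the kernel's
  // tile height for this quant type.
  TORCH_CHECK(sorted_token_ids.size(0) % ggml_moe_get_block_size(type) == 0);
  // One output row of size `row` per (token, expert) pair.
  torch::Tensor Y = ggml_moe_a8(X, W, sorted_token_ids, expert_ids,
                                num_tokens_post_padded, type, row, top_k, tokens);
  return Y.view({tokens, top_k, row});
}
```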
+
+int64_t ggml_moe_get_block_size(int64_t type) {
+  switch (type) {
+    case 2:
+      return MMQ_X_Q4_0;
+    case 3:
+      return MMQ_X_Q4_1;
+    case 6:
+      return MMQ_X_Q5_0;
+    case 7:
+      return MMQ_X_Q5_1;
+    case 8:
+      return MMQ_X_Q8_0;
+    case 10:
+      return MMQ_X_Q2_K;
+    case 11:
+      return MMQ_X_Q3_K;
+    case 12:
+      return MMQ_X_Q4_K;
+    case 13:
+      return MMQ_X_Q5_K;
+    case 14:
+      return MMQ_X_Q6_K;
+  }
+  return 0;
+}
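`ggml_moe_get_block_size` exposes the MMQ tile height for each quant type so the caller can pad its routing metadata to that granularity before invoking `ggml_moe_a8`. As a loose CPU-side illustration of that padding (an assumption about the expected layout, not code from this repository):

```cpp
// Hypothetical CPU reference for block-aligned MoE routing metadata.
// Tokens routed to each expert are padded up to a multiple of `block`
// (e.g. ggml_moe_get_block_size(type)); padding slots hold `pad_value`,
// and each tile of `block` rows records its expert id once.
#include <cstddef>
#include <cstdint>
#include <vector>

struct MoeRouting {
  std::vector<int32_t> sorted_token_ids;  // token indices grouped by expert, padded
  std::vector<int32_t> expert_ids;        // one expert id per tile of `block` rows
  int32_t num_tokens_post_padded = 0;     // total rows after padding
};

inline MoeRouting align_to_block(
    const std::vector<std::vector<int32_t>>& tokens_per_expert, int32_t block,
    int32_t pad_value) {
  MoeRouting r;
  for (std::size_t e = 0; e < tokens_per_expert.size(); ++e) {
    const auto& toks = tokens_per_expert[e];
    const int32_t n = static_cast<int32_t>(toks.size());
    const int32_t padded = (n + block - 1) / block * block;
    for (int32_t i = 0; i < padded; ++i) {
      if (i % block == 0) r.expert_ids.push_back(static_cast<int32_t>(e));
      r.sorted_token_ids.push_back(i < n ? toks[i] : pad_value);
    }
  }
  r.num_tokens_post_padded = static_cast<int32_t>(r.sorted_token_ids.size());
  return r;
}
```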