
vulkan: Use larger workgroups for mul_mat_vec when M is small #15355





Merged
0cc4m merged 2 commits into ggml-org:master on Aug 17, 2025

Conversation

jeffbolznv
Collaborator

Use larger workgroups for mul_mat_vec when m is small. Also use subgroup instructions for (part of) the reduction when supported. Without this, the more expensive reductions would eat into the benefits of the larger workgroups.

Many models have some matrices with small m that don't launch enough work to fill the GPU (particularly for larger GPUs). Using larger workgroups helps. I think ggml-cuda already (mostly?) uses 128 threads per CTA.

As currently written, non-NVIDIA GPUs continue to use the same workgroup size, but I'm happy to change that based on perf testing. This change does make other GPUs use subgroupAdd when it's supported; hopefully that works everywhere and is not a slowdown, but it's easy enough to change if not.
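
To make the "not enough work" argument concrete, here is a small standalone sketch (not the actual ggml-vulkan dispatch code; the rows-per-workgroup value and the 32/128 workgroup sizes are illustrative assumptions). When m is small the workgroup count is fixed by m, so widening the workgroup is the remaining knob for raising the total thread count, with each thread covering a shorter slice of k before the reduction:

```cpp
// Illustrative only: not the real ggml-vulkan dispatch logic. The
// rows-per-workgroup value and the 32/128 workgroup sizes are assumptions.
#include <cstdio>

int main() {
    const unsigned rows_per_workgroup = 2;        // assumed rows handled per workgroup
    const unsigned small_wg = 32, large_wg = 128; // "small" vs "large" workgroup size

    // m values taken from the perf-logger buckets further down in this description.
    const unsigned ms[] = {512, 2880, 201088};
    for (unsigned m : ms) {
        const unsigned workgroups = (m + rows_per_workgroup - 1) / rows_per_workgroup;
        // The workgroup count is tied to m, so with small m it is too low to fill a
        // large GPU; widening the workgroup is the remaining way to add threads,
        // each of which then sums a shorter slice of k before the reduction.
        printf("m=%6u: %6u workgroups, %8u threads @32, %8u threads @128\n",
               m, workgroups, workgroups * small_wg, workgroups * large_wg);
    }
    return 0;
}
```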

5090 before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -n 128 -p 0 -r 10 --prio 1 -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf -m c:\models\DeepSeek-R1-Distill-Llama-8B-Q6_K.gguf -m c:\models\DeepSeek-R1-Distill-Qwen-14B-Q4_K_M.gguf -m c:\models\Llama-3.2-1B.Q2_K.gguf -m c:\models\Llama-3.2-1B.Q3_K_S.gguf -m c:\models\llama-3.2-3b-instruct-q5_k_m.gguf -m c:\models\Qwen_Qwen3-30B-A3B-Q2_K.gguf -m c:\models\Qwen2.5-7B-Instruct-1M-Q2_K.gguf  -m c:\models\\deepseek-v2-lite-safetensors\deepseek-v2-lite-Q4_K_M.gguf -m c:\models\gpt-oss-20b-mxfp4.gguf -m c:\models\Phi-3-mini-4k-instruct-q4.gguf -m c:\models\llama-2-7b.Q4_0.gguf -m c:\models\llama-3.2-3b-instruct-q8_0.gguf -m c:\models\Mistral-22B-v0.2-Q4_K_M.gguf -m c:\models\nvidia_Llama-3_3-Nemotron-Super-49B-v1_5-Q4_K_S.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        195.02 ± 1.76 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        169.29 ± 2.36 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        109.69 ± 1.62 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        554.21 ± 1.76 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        504.32 ± 3.70 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        293.63 ± 2.23 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        178.06 ± 1.47 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        198.03 ± 3.13 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        231.73 ± 2.87 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        210.33 ± 5.19 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        289.75 ± 2.30 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        220.28 ± 7.71 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        266.37 ± 3.65 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         79.17 ± 0.77 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         41.52 ± 0.20 |

5090 after:

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        197.43 ± 1.05 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |        172.04 ± 4.84 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |        109.92 ± 2.11 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        617.50 ± 5.25 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        599.14 ± 6.84 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        305.70 ± 1.89 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        184.74 ± 1.52 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        198.91 ± 2.30 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        241.35 ± 6.19 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        219.92 ± 2.89 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        303.41 ± 1.79 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        222.24 ± 1.10 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |       268.12 ± 10.28 |
| llama ?B Q4_K - Medium         |  12.42 GiB |    22.24 B | Vulkan     |  99 |  1 |           tg128 |         80.42 ± 0.25 |
| deci 70B Q4_K - Small          |  26.66 GiB |    49.87 B | Vulkan     |  99 |  1 |           tg128 |         44.51 ± 0.15 |

4070 before:

ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         87.03 ± 0.86 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         68.46 ± 1.11 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         47.21 ± 0.25 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        399.50 ± 2.45 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        371.05 ± 1.20 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        155.52 ± 0.26 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        132.58 ± 3.77 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        102.32 ± 1.82 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |       157.92 ± 12.35 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        119.43 ± 1.79 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        149.57 ± 0.97 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        101.02 ± 0.27 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        117.55 ± 0.31 |

4070 after:

ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         87.39 ± 0.60 |
| llama 8B Q6_K                  |   6.14 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         68.65 ± 0.71 |
| qwen2 14B Q4_K - Medium        |   8.37 GiB |    14.77 B | Vulkan     |  99 |  1 |           tg128 |         47.79 ± 0.05 |
| llama 1B Q2_K - Medium         | 546.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        400.11 ± 2.49 |
| llama 1B Q3_K - Small          | 604.50 MiB |     1.24 B | Vulkan     |  99 |  1 |           tg128 |        370.84 ± 2.06 |
| llama 3B Q5_K - Medium         |   2.16 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        156.91 ± 0.33 |
| qwen3moe 30B.A3B Q2_K - Medium |  10.15 GiB |    30.53 B | Vulkan     |  99 |  1 |           tg128 |        133.44 ± 4.52 |
| qwen2 7B Q2_K - Medium         |   2.80 GiB |     7.62 B | Vulkan     |  99 |  1 |           tg128 |        101.81 ± 1.36 |
| deepseek2 16B Q4_K - Medium    |   9.65 GiB |    15.71 B | Vulkan     |  99 |  1 |           tg128 |        160.82 ± 8.35 |
| gpt-oss ?B MXFP4 MoE           |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        120.24 ± 4.80 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        150.66 ± 1.05 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | Vulkan     |  99 |  1 |           tg128 |        102.45 ± 0.24 |
| llama 3B Q8_0                  |   3.18 GiB |     3.21 B | Vulkan     |  99 |  1 |           tg128 |        118.44 ± 0.40 |

Here's a comparison from GGML_VK_PERF_LOGGER running gpt-oss-20b-mxfp4.gguf; the buckets with small m are significantly cheaper. (A quick check of how the GFLOPS figures are derived follows the two logs.)

5090 before:
MUL_MAT_ID_VEC mxfp4 m=2880 n=4 k=2880: 72 x 19.088 us (3475.67 GFLOPS/s)
MUL_MAT_VEC f32 m=32 n=1 k=2880: 24 x 8.028 us (22.9557 GFLOPS/s)
MUL_MAT_VEC q8_0 m=201088 n=1 k=2880: 1 x 416.32 us (2781.67 GFLOPS/s)
MUL_MAT_VEC q8_0 m=2880 n=1 k=4096: 24 x 10.981 us (2148.2 GFLOPS/s)
MUL_MAT_VEC q8_0 m=4096 n=1 k=2880: 24 x 12.065 us (1955.09 GFLOPS/s)
MUL_MAT_VEC q8_0 m=512 n=1 k=2880: 48 x 6.165 us (478.256 GFLOPS/s)

5090 after:
MUL_MAT_ID_VEC mxfp4 m=2880 n=4 k=2880: 72 x 19.126 us (3468.73 GFLOPS/s)
MUL_MAT_VEC f32 m=32 n=1 k=2880: 24 x 5.533 us (33.3051 GFLOPS/s)
MUL_MAT_VEC q8_0 m=201088 n=1 k=2880: 1 x 421.696 us (2746.21 GFLOPS/s)
MUL_MAT_VEC q8_0 m=2880 n=1 k=4096: 24 x 10.769 us (2190.49 GFLOPS/s)
MUL_MAT_VEC q8_0 m=4096 n=1 k=2880: 24 x 11.722 us (2012.24 GFLOPS/s)
MUL_MAT_VEC q8_0 m=512 n=1 k=2880: 48 x 4.121 us (715.45 GFLOPS/s)
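
The GFLOPS/s figures above are consistent with the usual 2·m·n·k flops per matrix-vector product divided by the measured time. A quick check against the q8_0 m=512 bucket, as a sketch under that assumption rather than the logger's actual code:

```cpp
#include <cstdio>

int main() {
    // q8_0 m=512 n=1 k=2880 "after" bucket: 4.121 us per invocation.
    const double m = 512, n = 1, k = 2880;
    const double seconds = 4.121e-6;
    const double gflops = 2.0 * m * n * k / seconds / 1e9;
    printf("%.1f GFLOPS/s\n", gflops); // ~715.6, matching the 715.45 GFLOPS/s logged above
    return 0;
}
```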

@jeffbolznv requested a review from 0cc4m as a code owner on August 15, 2025 19:45
@github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Aug 15, 2025
@characharm
Contributor

characharm commented Aug 15, 2025

9070xt before:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | RPC,Vulkan | 99 | 1 | tg128 | 146.16 ± 2.89 |

9070xt after:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 99 | 1 | tg128 | 150.51 ± 1.56 |

a770 before:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | RPC,Vulkan | 99 | tg128 | 51.30 ± 0.12 |

a770 after:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 99 | tg128 | 51.00 ± 0.09 |

On Intel with FA, the performance is expectedly lower in both cases: 48.11 ± 0 (PR) and 48.32 ± 0 (master).
On AMD without FA: 152.64 ± 3.84 (PR) and 151.96 ± 5.44 (master).

@jeffbolznv
Collaborator Author

This was with the change as is, right? Not enabling larger workgroups?

@characharm
Contributor

Yes, I copied the benchmarking flags from your results. Should I run it any other way?

@jeffbolznv
Collaborator Author

This is a good data point, thanks. It means subgroupAdd is a speedup. The other thing to try would be to enable the code at https://github.com/ggml-org/llama.cpp/pull/15355/files#diff-35a5049d5eebe22eda1e0d661bd87639b31aafeba62deeaaaca9c13ec3e71d11R4433 unconditionally.

@0cc4m
Collaborator

0cc4m commented Aug 16, 2025

With this diff:

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e255f68c2..19459c737 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2766,7 +2766,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
         uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);

-        bool s = device->subgroup_add;
+        const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;

         for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1),  arr_dmmv_f32_f32_f32_len[s],  arr_dmmv_f32_f32_f32_data[s],  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
@@ -4406,7 +4406,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *

     // heuristic to choose workgroup size
     uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
-    if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+    if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
         // Prefer larger workgroups when M is small, to spread the work out more
         // and keep more SMs busy.
         // q6_k seems to prefer small workgroup size even for "medium" values of M.

Nvidia RTX 3090

Master:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 125.96 ± 8.82 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 103.80 ± 1.19 |

PR:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 123.32 ± 8.12 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 107.39 ± 1.86 |

Intel A770

Master:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 40.07 ± 0.59 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 21.35 ± 0.21 |

PR:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 43.51 ± 0.75 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 29.24 ± 0.41 |

AMD RX 6800 XT

Master:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 98.15 ± 0.06 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 88.22 ± 0.08 |

PR:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 98.09 ± 0.09 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 88.52 ± 0.07 |

AMD Radeon Pro VII

Master:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 76.58 ± 0.11 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 67.03 ± 0.30 |

PR:

| model | size | params | backend | ngl | test | t/s |
| ------ | ---: | ------: | ------- | --: | ---: | ---------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 99 | tg128 | 76.45 ± 0.19 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | tg128 | 67.12 ± 0.20 |

@characharm
Contributor

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -4433,15 +4433,13 @@
-    if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA) {
-        // Prefer larger workgroups when M is small, to spread the work out more
-        // and keep more SMs busy.
-        // q6_k seems to prefer small workgroup size even for "medium" values of M.
-        if (a_type == GGML_TYPE_Q6_K) {
-            if (m < 4096 && k >= 1024) {
-                dmmv_wg = DMMV_WG_SIZE_LARGE;
-            }
-        } else {
-            if (m <= 8192 && k >= 1024) {
-                dmmv_wg = DMMV_WG_SIZE_LARGE;
-            }
-        }
-    }
+    // Always apply NVIDIA heuristic
+    if (a_type == GGML_TYPE_Q6_K) {
+        if (m < 4096 && k >= 1024) {
+            dmmv_wg = DMMV_WG_SIZE_LARGE;
+        }
+    } else {
+        if (m <= 8192 && k >= 1024) {
+            dmmv_wg = DMMV_WG_SIZE_LARGE;
+        }
+    }

9070xt:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 99 | 1 | tg128 | 151.80 ± 0.71 |

a770:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 99 | 1 | tg128 | 47.93 ± 0.09 |

master:
9070xt:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | RPC,Vulkan | 99 | 1 | tg128 | 147.16 ± 1.57 |

a770:

| model | size | params | backend | ngl | fa | test | t/s |
| ------ | ---: | ------: | ------- | --: | -: | ---: | ---------: |
| gpt-oss ?B MXFP4 MoE | 11.27 GiB | 20.91 B | RPC,Vulkan | 99 | 1 | tg128 | 48.28 ± 0.08 |

@netrunnereve
Collaborator

On my 470 there's barely any difference with subgroup adds turned on or off, but the NVIDIA small-m workgroup heuristic makes my inference 10% slower when I forcibly enable it on my card.

@jeffbolznv
Collaborator Author

Thanks for all the testing. I've applied @0cc4m's diff.

@0cc4m
Collaborator

0cc4m commented Aug 17, 2025

Can you rebase this?

jeffbolznv and others added 2 commits on August 17, 2025 09:45:
Also use subgroup instructions for (part of) the reduction when supported.
Without this, the more expensive reductions would eat into the benefits of
the larger workgroups.
@jeffbolznv
Collaborator Author

Rebased.

@0cc4m merged commit 21c17b5 into ggml-org:master on Aug 17, 2025
46 of 47 checks passed
@LostRuins
Collaborator

Hi everyone, not sure if others experience the same issue too, but this commit raises the time it takes for me to compile ggml-vulkan.cpp from 3 minutes (at commit 19f4dec) to 14 minutes (at this commit).

@jeffbolznv
Collaborator Author

What is your build environment? It takes about 30s to build ggml-vulkan.cpp on MSVC (which I already consider too high) but I didn't see it get worse from this change.

I've been wanting to split ggml_vk_load_shaders into its own file because heavy macro use tends to be pretty expensive to compile. But it always seems like a bad time because of merge conflicts.

@LostRuins
Collaborator

LostRuins commented Aug 19, 2025

Hi @jeffbolznv, I think I've figured it out; TL;DR at the bottom.

First, here's a simple way to repro this. Note that it only demonstrates how to reproduce the issue; it's not a complete binary build.

I am using g++ from w64devkit and building on Windows 10.
g++ -v gives gcc version 12.2.0 (GCC)

Download and install the latest Vulkan SDK for Windows from https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe

Clone the llama.cpp project

git clone https://github.com/ggml-org/llama.cpp.git
cd llama.cpp

Build the Vulkan shader generator binary (no issue here)

g++ -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp -o vulkan-shaders-gen

Build the Vulkan shaders, everything enabled (no issues here)

vulkan-shaders-gen --glslc glslc --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp

Verify that the precompiled Vulkan shaders (ggml/src/ggml-vulkan-shaders.hpp and .cpp) were generated.

Now compile ggml-vulkan.cpp, measuring the time taken:

g++ -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT -I ggml/include -Iggml/src -IC:\VulkanSDK\1.4.321.1\Include -O3 -fno-finite-math-only -std=c++17 -c ggml/src/ggml-vulkan/ggml-vulkan.cpp -o ggml-vulkan.o

and...

real    11m 19.76s
user    0m 0.01s
sys     0m 0.01s

Absolutely awful. Now we make just one simple change, removing -O3:

g++ -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT -I ggml/include -Iggml/src -IC:\VulkanSDK\1.4.321.1\Include -fno-finite-math-only -std=c++17 -c ggml/src/ggml-vulkan/ggml-vulkan.cpp -o ggml-vulkan.o

and we get

real    0m 10.82s
user    0m 0.00s
sys     0m 0.01s

Wow, ten seconds: literally over 60 times faster build speed.

I tried again with -O2 and -O1 with the same terrible results; somehow anything after this commit is absolutely destroying compile times. Before this commit, the compiler optimization flags worked fine. But with -O0 I get the same blazing-fast compilation that you do.

Question: Are g++ compiler optimization flags really necessary for this target? I suppose I could just exclude them for normal builds.

TL;DR: This happens with any of the -O1 / -O2 / -O3 compiler optimization flags after this commit, which slows build times by roughly 60x; the optimization level doesn't matter.

@jeffbolznv
Collaborator Author

We need at least -O2.

Can you try removing all the +std::to_string(w)+"_"+std::to_string(i+1) in the new code? I wonder if the compiler is spending a bunch of time folding this.
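
For concreteness, a standalone sketch of the pattern in question versus the proposed experiment (the loop bounds and the surrounding functions are made up; only the shape of the name-building expression matches the real ggml_vk_create_pipeline calls):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Each pipeline-creation call site in ggml_vk_load_shaders builds a distinct
// name like this; at -O2/-O3 the optimizer inlines and folds the string
// machinery once per pipeline.
std::vector<std::string> names_concatenated() {
    std::vector<std::string> names;
    for (uint32_t w = 0; w < 2; ++w) {
        for (uint32_t i = 0; i < 8; ++i) {
            names.push_back("mul_mat_vec_f32_f32_f32_" + std::to_string(w) + "_" + std::to_string(i + 1));
        }
    }
    return names;
}

// The suggested experiment: drop the std::to_string suffixes and pass a plain literal.
std::vector<std::string> names_literal() {
    return std::vector<std::string>(16, "mul_mat_vec_f32_f32_f32");
}
```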

@LostRuins
Collaborator

LostRuins commented Aug 19, 2025

MUCH better. I replaced all of the +std::to_string(w)+"_"+std::to_string(i+1) with the empty string and reran with -O2:

D:\MainApplications\MinGW\w64devkit\bin\time.exe g++ -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT -I ggml/include -Iggml/src -IC:\VulkanSDK\1.4.321.1\Include -fno-finite-math-only -O2 -std=c++17 -c ggml/src/ggml-vulkan/ggml-vulkan.cpp -o ggml-vulkan.o

Result:

real    1m 11.16s
user    0m 0.00s
sys     0m 0.01s

So we have gone from 12 minutes to 1 minute. It's still slightly slower than the previous commit, but this is a massive improvement. I think you have cracked the case.

Edit: I also did -O3 again, and got

real    1m 59.24s
user    0m 0.00s
sys     0m 0.04s

So about 2 minutes is a fairer comparison.

@jeffbolznv
Collaborator Author

Cool, I'll make a change to simplify the strings. We used to require unique strings (names were used to look up pipelines) but now they're just informative.

@jeffbolznv
Collaborator Author

PR to shorten the strings: #15431
