
Conversation

Contributor

@etasnadi etasnadi commented Sep 18, 2025

I am adding this because the current conv2d algorithm (#15635) seems to underutilize the GPU: the Vulkan version (#14316, #14933) is 8-10 times faster on my device. Additionally, the Tensor Cores extension (#15813) of the previous algorithm also seems to be slower than this one.

There is another CUDA conv2d proposal that could be related: #15805.

Furthermore, this version introduces a bank conflict reduction scheme that has not been added to Vulkan yet. It seems to be effective on large problems, so I expect this version to be even more efficient than the Vulkan backend.

I do not support f16 yet; a future contribution might add that. Currently this algorithm is used for f32 inputs, otherwise it falls back to the previous implementation. Setting GGML_CUDA_USE_LEGACY_CONV forces the use of the previous (probably slower) implementation.

Perf of previous on RTX 2060:

$ GGML_CUDA_USE_LEGACY_CONV=1 ./bin/test-backend-ops -o CONV_2D -b CUDA0 perf

CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      4 runs - 327359.25 us/run - 137.42 GFLOP/run - 419.79 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2992 runs -   357.38 us/run - 133.69 MFLOP/run - 374.09 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2948 runs -   362.97 us/run - 135.78 MFLOP/run - 374.08 GFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                139264 runs -     7.53 us/run - 642.82 kFLOP/run -  85.42 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                14358 runs -    92.52 us/run -  20.90 MFLOP/run - 225.87 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                24576 runs -    49.88 us/run -   2.78 MFLOP/run -  55.83 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   386.69 us/run -  22.28 MFLOP/run -  57.61 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3468 runs -   328.26 us/run - 115.40 MFLOP/run - 351.56 GFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 436 runs -  2535.99 us/run - 923.24 MFLOP/run - 364.05 GFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      220 runs -  4808.02 us/run -   1.85 GFLOP/run - 384.54 GFLOPS

Perf of proposed:

 
$ ./bin/test-backend-ops -o CONV_2D -b CUDA0 perf
CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     41 runs - 24946.49 us/run - 137.42 GFLOP/run -   5.51 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20944 runs -    49.06 us/run - 133.69 MFLOP/run -   2.73 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14740 runs -    69.12 us/run - 135.78 MFLOP/run -   1.96 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                204800 runs -     5.04 us/run - 642.82 kFLOP/run - 127.47 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                52646 runs -    20.26 us/run -  20.90 MFLOP/run -   1.03 TFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                81920 runs -    12.37 us/run -   2.78 MFLOP/run - 225.14 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    85.40 us/run -  22.28 MFLOP/run - 260.88 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               24276 runs -    42.68 us/run - 115.40 MFLOP/run -   2.70 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3270 runs -   311.18 us/run - 923.24 MFLOP/run -   2.97 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2200 runs -   462.72 us/run -   1.85 GFLOP/run -   4.00 TFLOPS

* Extra: reduces bank conflicts
@etasnadi etasnadi changed the title from "Vulkan direct conv ported to CUDA" to "ggml-cude: Vulkan direct conv ported to CUDA" Sep 18, 2025
@etasnadi etasnadi changed the title from "ggml-cude: Vulkan direct conv ported to CUDA" to "ggml-cuda: Vulkan direct conv 2D ported to CUDA" Sep 18, 2025
@etasnadi
Contributor Author

@Green-Sky Can you check it?

@github-actions github-actions bot added the Nvidia GPU and ggml labels Sep 18, 2025
@Green-Sky
Collaborator

@Green-Sky Can you check it?

sd.cpp relies exclusively on f16 kernels.

@bssrdf
Contributor

bssrdf commented Sep 18, 2025

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

@Green-Sky
Collaborator

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.

BTW, forgot to thank you @etasnadi for working on this :)

... even though we now have 3 competing prs, more or less.

@bssrdf
Contributor

bssrdf commented Sep 18, 2025

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.

BTW, forgot to thank you @etasnadi for working on this :)

... even though we now have 3 competing prs, more or less.

I'll close my PR. This one is way better:)

@mnehete32
Contributor

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.
BTW, forgot to thank you @etasnadi for working on this :)
... even though we now have 3 competing prs, more or less.

I'll close my PR. This one is way better:)

Same, closing mine too.

@etasnadi
Contributor Author

etasnadi commented Sep 18, 2025 via email

@mnehete32
Contributor

mnehete32 commented Sep 18, 2025

I'm new to CUDA, but I'd love to give this a shot, @Green-Sky @etasnadi. If fp16 isn't super urgent, I can take a crack at it in the next week or two.

Maybe you can add a parallel PR based on this for f16?


Collaborator

@JohannesGaessler JohannesGaessler left a comment


Can you give me a list of what parts of the code you changed relative to the Vulkan version, if any? Some things like fastdiv and how to retrieve the SM count have equivalents in the CUDA backend. But if this is just a copy-paste of the Vulkan code I would preferably change as little as possible.

@etasnadi
Contributor Author

Can you give me a list of what parts of the code you changed relative to the Vulkan version, if any? Some things like fastdiv and how to retrieve the SM count have equivalents in the CUDA backend. But if this is just a copy-paste of the Vulkan code I would preferably change as little as possible.

Can you give me a reference for doing fastdiv/sm_count() the proper ggml-cuda way? I will refactor accordingly.

Only the necessary things are changed compared to Vulkan, but they are significant:

  • Vulkan has specialization constants, which are missing in CUDA, so the kernel selection/initialization is different.
  • The unrolls in the kernel are also different.
  • The coopmat2 API we use in Vulkan is quite different from the mma APIs found in CUDA, so that part was completely removed before porting.
  • The shmem size check is not needed, etc.
  • The core algorithm is mostly the same, but it is augmented with different shmem indexing to minimize bank conflicts; this is not present in the Vulkan kernel yet (a rough illustration of the idea follows below).

IMO it already changes as little as possible compared to Vulkan.

@JohannesGaessler
Collaborator

JohannesGaessler commented Sep 18, 2025

For a fastdiv example, look at e.g. binbcast.cu; get the SM count via ggml_cuda_info().devices[ggml_cuda_get_device()].nsm.

@etasnadi
Contributor Author

etasnadi commented Sep 19, 2025

@bssrdf Do you want to contribute the ggml-cuda-conformant fastdiv as a patch to my branch or in a separate PR, so everyone gets authorship for their conv2d effort?

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

@bssrdf Do you want to contribute the ggml-cuda-conformant fastdiv as a patch to my branch or in a separate PR, so everyone gets authorship for their conv2d effort?

@etasnadi, I can give it a try. I will do a patch on your branch.
@Green-Sky, you can try etasnadi#1 in sd.cpp.

@Green-Sky
Collaborator

In the f16 PR by @bssrdf, we found that sd.cpp crashes with this PR. I double-checked by forcing sd.cpp to use f32 for the kernel without @bssrdf's PR.

[INFO ] stable-diffusion.cpp:2166 - generating 2 latent images completed, taking 84.68s
[INFO ] stable-diffusion.cpp:2169 - decoding 2 latents
[INFO ] ggml_extend.hpp:1648 - vae offload params ( 94.47 MB, 140 tensors) to runtime backend (CUDA0), taking 0.01s
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1928.64 MB(VRAM)
[ERROR] ggml_extend.hpp:71   - CUDA error: an illegal memory access was encountered
[ERROR] ggml_extend.hpp:71   -   current device: 0, in function ggml_backend_cuda_synchronize at /build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:2628
[ERROR] ggml_extend.hpp:71   -   cudaStreamSynchronize(cuda_ctx->stream())
/build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:88: CUDA error
#4  0x000000000057b955 in ggml_backend_cuda_synchronize (backend=<optimized out>) at ggml/src/ggml-cuda/ggml-cuda.cu:2628
#5  0x0000000000a40714 in ggml_backend_synchronize (backend=backend@entry=0x22775310) at ggml/src/ggml-backend.cpp:327
#6  0x0000000000a40a6b in ggml_backend_graph_compute (backend=0x22775310, cgraph=<optimized out>) at ggml/src/ggml-backend.cpp:353

$ result/bin/sd -m models/CyberRealistic_V9_FP16.safetensors --sampling-method dpm++2m --scheduler karras --cfg-scale 5 -W 768 -H 1024 --diffusion-fa --steps 20 -b 2 -v -n "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" -p "a lovely cat" --vae-conv-direct --offload-to-cpu

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

[INFO ] stable-diffusion.cpp:2166 - generating 2 latent images completed, taking 84.68s
[INFO ] stable-diffusion.cpp:2169 - decoding 2 latents
[INFO ] ggml_extend.hpp:1648 - vae offload params ( 94.47 MB, 140 tensors) to runtime backend (CUDA0), taking 0.01s
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1928.64 MB(VRAM)
[ERROR] ggml_extend.hpp:71   - CUDA error: an illegal memory access was encountered
[ERROR] ggml_extend.hpp:71   -   current device: 0, in function ggml_backend_cuda_synchronize at /build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:2628
[ERROR] ggml_extend.hpp:71   -   cudaStreamSynchronize(cuda_ctx->stream())
/build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:88: CUDA error
#4  0x000000000057b955 in ggml_backend_cuda_synchronize (backend=<optimized out>) at ggml/src/ggml-cuda/ggml-cuda.cu:2628
#5  0x0000000000a40714 in ggml_backend_synchronize (backend=backend@entry=0x22775310) at ggml/src/ggml-backend.cpp:327
#6  0x0000000000a40a6b in ggml_backend_graph_compute (backend=0x22775310, cgraph=<optimized out>) at ggml/src/ggml-backend.cpp:353

$ result/bin/sd -m models/CyberRealistic_V9_FP16.safetensors --sampling-method dpm++2m --scheduler karras --cfg-scale 5 -W 768 -H 1024 --diffusion-fa --steps 20 -b 2 -v -n "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" -p "a lovely cat" --vae-conv-direct --offload-to-cpu

@Green-Sky, it may be due to my changes. I'll investigate.

@Green-Sky
Collaborator

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@Green-Sky, without my change, it will fall back to using the slow direct version. Did it even fail there?

@Green-Sky
Collaborator

Green-Sky commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@Green-Sky, without my change, it will fall back to using the slow direct version. Did it even fail there?

I patched sd.cpp to cast the kernel to f32, so it would fall back.
(ggml_cast w at the direct callsite)

-    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    x = ggml_conv_2d_direct(ctx, ggml_cast(ctx, w, GGML_TYPE_F32), x, s0, s1, p0, p1, d0, d1);

I guess I should have made my report more wordy (:

edit: fun side note, it seems like with the current naive fallback and f32 kernel, sd.cpp vae decode is ever so slightly faster (~35s vs ~33s)
