
Conversation

@ssweens
Contributor

@ssweens ssweens commented Sep 17, 2025

Following up on part of the work to resolve #15974.

llama-bench was missing the --devices option, which was recently enhanced in the main apps such as llama-server. This PR brings that configuration option to benchmarking, implemented as closely as possible to how the server and the other tools do it.

  • Provided the --device option, so that one or more devices can be selected, including device combinations used with tensor splits.
  • Provided --list-devices for convenience, rather than having to switch to llama-server (a quick usage sketch follows this list).
  • Merged with the RPC device handling. RPC handling was heavily intertwined with local device handling and already carried a FIXME comment in the code, so the shortest path was to keep them together and expose RPC servers as equivalent device peers rather than pulling the two apart. This expanded the PR a bit, but within reason, and the other untangling options looked riskier.
  • Doc update, including a quick cleanup to ensure the recently added n-cpu-moe option is captured in the doc while I was in there.
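
A minimal usage sketch of the new flags (the model path is a placeholder; device names depend on the backends in the local build):

# list the devices this build can see, without switching to llama-server
llama-bench --list-devices

# comma-separated devices: benchmark each device on its own, as in the first run below
llama-bench -m model.gguf -ngl 999 -dev ROCm0,Vulkan0

# slash-separated devices: combine them into a single run, optionally with a tensor split, as in the second run below
llama-bench -m model.gguf -ngl 999 -dev ROCm0/Vulkan0 -ts 1/2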

Tested on Mac and Linux

Using different devices, including RPC, and benchmarking each one separately:

llama-bench -ngl 999 -m ~/models/gguf/Qwen3-4B-128K-Q4_K_M.gguf --rpc 192.168.1.59:50052 -dev ROCm0,Vulkan0,'RPC[192.168.1.59:50052]'
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
load_backend: loaded CUDA backend from /base/llama.cpp/build/bin/libggml-cuda.so
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
load_backend: loaded ROCm backend from /base/llama.cpp/build/bin/libggml-hip.so
load_backend: loaded RPC backend from /base/llama.cpp/build/bin/libggml-rpc.so
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Radeon 8060S Graphics (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from /base/llama.cpp/build/bin/libggml-vulkan.so
load_backend: loaded CPU backend from /base/llama.cpp/build/bin/libggml-cpu.so
| model                          |       size |     params | backend    | ngl | dev          |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------ | --------------: | -------------------: |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | ROCm0        |           pp512 |       1608.23 ± 9.44 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | ROCm0        |           tg128 |         63.41 ± 0.01 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | Vulkan0      |           pp512 |       1199.15 ± 1.34 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | Vulkan0      |           tg128 |         77.48 ± 0.03 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | RPC[192.168.1.59:50052] |           pp512 |        619.81 ± 0.21 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | RPC[192.168.1.59:50052] |           tg128 |         28.77 ± 1.18 |

Combining all devices and benchmarking with a tensor split:

llama-bench -ngl 999 -m ~/models/gguf/Qwen3-4B-128K-Q4_K_M.gguf --rpc 192.168.1.59:50052 -dev ROCm0/Vulkan0/'RPC[192.168.1.59:50052]' -ts 1/2/1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
load_backend: loaded CUDA backend from /base/llama.cpp/build/bin/libggml-cuda.so
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
load_backend: loaded ROCm backend from /base/llama.cpp/build/bin/libggml-hip.so
load_backend: loaded RPC backend from /base/llama.cpp/build/bin/libggml-rpc.so
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Radeon 8060S Graphics (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from /base/llama.cpp/build/bin/libggml-vulkan.so
load_backend: loaded CPU backend from /base/llama.cpp/build/bin/libggml-cpu.so
| model                          |       size |     params | backend    | ngl | dev          | ts           |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------ | ------------ | --------------: | -------------------: |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | ROCm0/Vulkan0/RPC[192.168.1.59:50052] | 1.00/2.00/1.00 |           pp512 |       984.98 ± 10.09 |
| qwen3 4B Q4_K - Medium         |   2.32 GiB |     4.02 B | CUDA,ROCm,RPC,Vulkan | 999 | ROCm0/Vulkan0/RPC[192.168.1.59:50052] | 1.00/2.00/1.00 |           tg128 |         38.44 ± 0.40 |

Hopefully this helps bring llama-bench closer to parity with the capabilities of the core tools.

- Support --devices same as llama-server
- Provide for benchmarking different device combinations
- Include --list-devices like llama-server for convenience
- aimed to mimic the server as much as possible
- handle dup device listing with RPC
- added the recently added n-cpu-moe option to the docs while in there
@ssweens
Contributor Author

ssweens commented Sep 18, 2025

Fixed the examples so the commands aren't hidden anymore.

Member

@slaren slaren left a comment

I think this would be a better way to deal with the RPC devices:

  • Change the -rpc option to just register the devices at startup
  • Remove the list of rpc servers from cmd_params_instance entirely
  • If the user wants to test different subsets of RPC devices, they can use -dev

* rpc servers unify with other devices earlier, simplifying code
* --list-devices made stateless and simpler
* various cleanup
@ssweens
Contributor Author

ssweens commented Sep 19, 2025

> I think this would be a better way to deal with the RPC devices:
>
> * Change the `-rpc` option to just register the devices at startup
> * Remove the list of rpc servers from `cmd_params_instance` entirely
> * If the user wants to test different subsets of RPC devices, they can use `-dev`

The RPC handling suggestion is nice and tightens things up a good deal IMO.
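
For reference, a rough sketch of the resulting flow, reusing the endpoint from the runs above (the model path is a placeholder):

# --rpc now only registers the remote devices at startup, so they should show up alongside the local ones
llama-bench --rpc 192.168.1.59:50052 --list-devices

# subsets of local and RPC devices are then selected with -dev as usual
llama-bench -m model.gguf -ngl 999 --rpc 192.168.1.59:50052 -dev ROCm0,'RPC[192.168.1.59:50052]'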

@ssweens ssweens requested a review from slaren September 19, 2025 04:15
@slaren slaren merged commit be79d9f into ggml-org:master Sep 19, 2025
47 of 50 checks passed
@slaren
Member

slaren commented Sep 19, 2025

Thank you!

@ssweens
Contributor Author

ssweens commented Sep 19, 2025

> Thank you!

Thank you all for all the good work. Happy to help.

@ssweens ssweens deleted the llama-bench-add-devices branch September 19, 2025 23:54
@JohannesGaessler
Collaborator

Thank you, this is a very useful PR. I'm noticing that on a multi-GPU system the "GPU info" field does not take the --devices argument into account.

Example log
./bench --model models/opt/${model_name}-${quantization}.gguf -dev CUDA1 -o json                                                                               [23:20:20]
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: Tesla P40, compute capability 6.1, VMM: yes
  Device 1: Tesla P40, compute capability 6.1, VMM: yes
load_backend: loaded CUDA backend from /home/johannesg/Projects/llama.cpp/build/bin/libggml-cuda.so
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 ROCm devices:
  Device 0: AMD Radeon RX 6800, gfx1030 (0x1030), VMM: no, Wave Size: 32
  Device 1: AMD Instinct MI60 / MI50, gfx906:sramecc+:xnack- (0x906), VMM: no, Wave Size: 64
load_backend: loaded ROCm backend from /home/johannesg/Projects/llama.cpp/build/bin/libggml-hip.so
load_backend: loaded CPU backend from /home/johannesg/Projects/llama.cpp/build/bin/libggml-cpu-haswell.so
[
  {
    "build_commit": "28baac9c9",
    "build_number": 6530,
    "cpu_info": "Intel(R) Xeon(R) CPU E5-2683 v4 @ 2.10GHz",
    "gpu_info": "Tesla P40, Tesla P40, AMD Radeon RX 6800, AMD Instinct MI60 / MI50",
    "backends": "CUDA,ROCm",
    "model_filename": "models/opt/llama_3-8b-q4_0.gguf",
    "model_type": "llama 8B Q4_0",
    "model_size": 4653375488,
    "model_n_params": 8030261248,
    "n_batch": 2048,
    "n_ubatch": 512,
    "n_threads": 16,
    "cpu_mask": "0x0",
    "cpu_strict": false,
    "poll": 50,
    "type_k": "f16",
    "type_v": "f16",
    "n_gpu_layers": 99,
    "n_cpu_moe": 0,
    "split_mode": "layer",
    "main_gpu": 0,
    "no_kv_offload": false,
    "flash_attn": false,
    "devices": "CUDA1",
    "tensor_split": "0.00",
    "tensor_buft_overrides": "none",
    "use_mmap": true,
    "embeddings": false,
    "no_op_offload": 0,
    "n_prompt": 512,
    "n_gen": 0,
    "n_depth": 0,
    "test_time": "2025-09-22T21:20:43Z",
    "avg_ns": 526051936,
    "stddev_ns": 671573,
    "avg_ts": 973.289197,
    "stddev_ts": 1.241021,
    "samples_ns": [ 526317712, 527105069, 525545811, 525509902, 525781187 ],
    "samples_ts": [ 972.796, 971.343, 974.225, 974.292, 973.789 ]
  },
  {
    "build_commit": "28baac9c9",
    "build_number": 6530,
    "cpu_info": "Intel(R) Xeon(R) CPU E5-2683 v4 @ 2.10GHz",
    "gpu_info": "Tesla P40, Tesla P40, AMD Radeon RX 6800, AMD Instinct MI60 / MI50",
    "backends": "CUDA,ROCm",
    "model_filename": "models/opt/llama_3-8b-q4_0.gguf",
    "model_type": "llama 8B Q4_0",
    "model_size": 4653375488,
    "model_n_params": 8030261248,
    "n_batch": 2048,
    "n_ubatch": 512,
    "n_threads": 16,
    "cpu_mask": "0x0",
    "cpu_strict": false,
    "poll": 50,
    "type_k": "f16",
    "type_v": "f16",
    "n_gpu_layers": 99,
    "n_cpu_moe": 0,
    "split_mode": "layer",
    "main_gpu": 0,
    "no_kv_offload": false,
    "flash_attn": false,
    "devices": "CUDA1",
    "tensor_split": "0.00",
    "tensor_buft_overrides": "none",
    "use_mmap": true,
    "embeddings": false,
    "no_op_offload": 0,
    "n_prompt": 0,
    "n_gen": 128,
    "n_depth": 0,
    "test_time": "2025-09-22T21:20:47Z",
    "avg_ns": 2378770702,
    "stddev_ns": 575513,
    "avg_ts": 53.809308,
    "stddev_ts": 0.012949,
    "samples_ns": [ 2379264250, 2378155745, 2378179546, 2379333728, 2378920244 ],
    "samples_ts": [ 53.7981, 53.8232, 53.8227, 53.7966, 53.8059 ]
  }
]

This doesn't matter if you just want to print a markdown table to the console, but for SQL/JSON/CSV output it means you cannot (automatically) associate the benchmark runs with the GPUs that were used. I think the correct behavior would be to populate GPU info only with the devices that were actually used.
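
For example, pulling the two fields out of the JSON with something like the following (a hypothetical jq invocation; the model path is a placeholder) makes the mismatch visible: gpu_info still lists every GPU in the machine, while devices lists only the selected one.

./bench --model model.gguf -dev CUDA1 -o json | jq '.[] | {gpu_info, devices}'
#   "gpu_info": "Tesla P40, Tesla P40, AMD Radeon RX 6800, AMD Instinct MI60 / MI50"
#   "devices": "CUDA1"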

@JohannesGaessler
Collaborator

Or maybe it would make more sense to add a new property like device_info that explicitly lists the names of the used devices?

@slaren
Member

slaren commented Sep 24, 2025

Yes, the backend and GPU info fields should be updated to reflect the devices actually being used. This might require adding a new API to llama.cpp to obtain the list of devices used for a model, to be able to tell which devices are used by default when --device is not given.

struct pushed a commit to struct/llama.cpp that referenced this pull request Sep 26, 2025
* * llama-bench: add --devices support
- Support --devices same as llama-server
- Provide for benchmarking different device combinations
- Include --list-devices like llama-server for convenience

* fix: field display ordering restored

* fix: integrated the rpc devices
- aimed to mimic the server as much as possible

* cleanup: defaults for list-devices
- handle dup device listing with RPC

* cleanup: remove dup device load calls

* docs: update llama-bench
- added the recently added n-cpu-moe option to the docs while in there

* llama-bench: rpc device simplification
* rpc servers unify with other devices earlier, simplifying code
* --list-devices made stateless and simpler
* various cleanup
yael-works pushed a commit to yael-works/llama.cpp that referenced this pull request Oct 15, 2025
pwilkin pushed a commit to pwilkin/llama.cpp that referenced this pull request Oct 23, 2025