
Conversation

@liuzijing2014
Collaborator

@liuzijing2014 liuzijing2014 commented Aug 12, 2025

Current Issue

Mitigation to avoid blocking copy operations across different CUDA streams. Details can be found in #22754.

Change

When we copy the sampled valid token IDs from device to host, avoid using tolist, which triggers a device-wide CUDA stream sync when the source tensor is on the GPU. We change it to a non-blocking copy into a pinned buffer followed by an explicit CUDA event synchronization.
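
For illustration, a minimal sketch of the pattern, assuming a preallocated pinned buffer and a reusable CUDA event (the class name and buffer sizing below are hypothetical; in vLLM the actual logic lives in _to_list in gpu_model_runner.py):

    import torch

    class SamplerOutputTransfer:
        def __init__(self, max_reqs: int, max_sampled: int):
            # Page-locked host buffer so the device-to-host copy can be async.
            self.pinned = torch.empty(
                (max_reqs, max_sampled), dtype=torch.int64, pin_memory=True)
            # Reused event; synchronizing on it waits only for this copy,
            # not for every CUDA stream on the device.
            self.transfer_event = torch.cuda.Event()

        def to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
            rows, cols = sampled_token_ids.shape
            pinned = self.pinned[:rows, :cols]
            pinned.copy_(sampled_token_ids, non_blocking=True)  # async D2H copy
            self.transfer_event.record()       # mark the copy on the current stream
            self.transfer_event.synchronize()  # wait only until the copy completes
            return pinned.tolist()             # now a CPU tensor; no device-wide sync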

Test for Non-Disagg

Bring up vLLM server

VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-14B-Instruct --disable-log-requests -tp 8 --max-num-seqs 64 --no-enable-prefix-caching --max_num_batched_tokens=8000

Load test the non-disagg setup and observe no perf difference:

Before

Maximum request concurrency: 64
============ Serving Benchmark Result ============
Successful requests:                     640       
Maximum request concurrency:             64        
Benchmark duration (s):                  29.60     
Total input tokens:                      1276994   
Total generated tokens:                  93700     
Request throughput (req/s):              21.62     
Output token throughput (tok/s):         3165.99   
Total Token throughput (tok/s):          46313.80  
---------------Time to First Token----------------
Mean TTFT (ms):                          327.88    
Median TTFT (ms):                        302.90    
P99 TTFT (ms):                           1566.85   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.16     
Median TPOT (ms):                        18.01     
P99 TPOT (ms):                           20.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.82     
Median ITL (ms):                         9.49      
P99 ITL (ms):                            102.45    
==================================================

After

Maximum request concurrency: 64
============ Serving Benchmark Result ============
Successful requests:                     640       
Maximum request concurrency:             64        
Benchmark duration (s):                  29.58     
Total input tokens:                      1276994   
Total generated tokens:                  93693     
Request throughput (req/s):              21.64     
Output token throughput (tok/s):         3167.38   
Total Token throughput (tok/s):          46337.38  
---------------Time to First Token----------------
Mean TTFT (ms):                          320.65    
Median TTFT (ms):                        302.65    
P99 TTFT (ms):                           1568.02   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          18.21     
Median TPOT (ms):                        18.17     
P99 TPOT (ms):                           20.09     
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.86     
Median ITL (ms):                         9.50      
P99 ITL (ms):                            102.48    
==================================================

Test for Disagg

Run with a vendor-internal disagg implementation (based on the KV connector interface) and observe TTIT wins.

Before

Ran 1029/1029 requests in 133.00s
Success rate:        100.00%
QPS:                 7.74
Avg latency:         4.841s
Avg TTFT (client):   137.37ms
P50 TTFT (client):   130.12ms
P99 TTFT (client):   364.42ms
Avg TTIT (client):   31.35ms
P50 TTIT (client):   31.12ms
P99 TTIT (client):   38.27ms
Avg TTFT (server):   134.39ms
Avg TTIT (server):   31.29ms
Avg prefill len:     2199.72 tokens
P50 prefill len:     2199.00 tokens
P99 prefill len:     2231.00 tokens
Avg decode len:      150.00 tokens
P50 decode len:      150.00 tokens
P99 decode len:      150.00 tokens

After

Ran 1026/1026 requests in 132.34s
Success rate:        100.00%
QPS:                 7.75
Avg latency:         4.550s
Avg TTFT (client):   128.50ms
P50 TTFT (client):   125.91ms
P99 TTFT (client):   205.68ms
Avg TTIT (client):   29.47ms
P50 TTIT (client):   29.59ms
P99 TTIT (client):   30.00ms
Avg TTFT (server):   124.88ms
Avg TTIT (server):   28.88ms
Avg prefill len:     2199.10 tokens
P50 prefill len:     2199.00 tokens
P99 prefill len:     2233.00 tokens
Avg decode len:      150.00 tokens
P50 decode len:      150.00 tokens
P99 decode len:      150.00 tokens

Also confirmed from the GPU trace that there is no blocking between the vLLM model-forward CUDA stream and other CUDA streams.

Accuracy

gsm8k 8-shot (disagg)

[2025-07-30 23:06:36,254] [rank 0] [INFO] Per prompt detailed info dumped to /tmp/eval_dump.gsm8k.8_shot.1_gen.20250730_230636.json
[2025-07-30 23:06:36,254] [rank 0] [INFO] Evaluation results on task gsm8k.8_shot.1_gen: em: 0.970000 | f1: 0.970000 | em_maj1@1: 0.970000 | f1_maj1@1: 0.970000
[2025-07-30 23:06:36,254] [rank 0] [INFO] Task gsm8k.8_shot.1_gen took 54.20 seconds

mmlu_pro (non-disagg)

vllm (pretrained=/data/users/zijingliu/cp/Llama-4-Maverick-17B-128E-Instruct-FP8,dtype=auto,max_model_len=8196,tensor_parallel_size=8,gpu_memory_utilization=0.9,max_gen_toks=2048,seed=0), gen_kwargs: (None), limit: 500.0, num_fewshot: None, batch_size: auto
|       Tasks       |Version|    Filter    |n-shot|  Metric   |   |Value |   |Stderr|
|-------------------|------:|--------------|-----:|-----------|---|-----:|---|-----:|
|mmlu_pro           |    2.0|custom-extract|      |exact_match|↑  |0.7994|±  |0.0048|
| - biology         |    2.1|custom-extract|     5|exact_match|↑  |0.8960|±  |0.0137|
| - business        |    2.1|custom-extract|     5|exact_match|↑  |0.8460|±  |0.0162|
| - chemistry       |    2.1|custom-extract|     5|exact_match|↑  |0.8540|±  |0.0158|
| - computer_science|    2.1|custom-extract|     5|exact_match|↑  |0.8390|±  |0.0182|
| - economics       |    2.1|custom-extract|     5|exact_match|↑  |0.8520|±  |0.0159|
| - engineering     |    2.1|custom-extract|     5|exact_match|↑  |0.7300|±  |0.0199|
| - health          |    2.1|custom-extract|     5|exact_match|↑  |0.7640|±  |0.0190|
| - history         |    2.1|custom-extract|     5|exact_match|↑  |0.7060|±  |0.0234|
| - law             |    2.1|custom-extract|     5|exact_match|↑  |0.5960|±  |0.0220|
| - math            |    2.1|custom-extract|     5|exact_match|↑  |0.9020|±  |0.0133|
| - other           |    2.1|custom-extract|     5|exact_match|↑  |0.7600|±  |0.0191|
| - philosophy      |    2.1|custom-extract|     5|exact_match|↑  |0.7375|±  |0.0197|
| - physics         |    2.1|custom-extract|     5|exact_match|↑  |0.8680|±  |0.0152|
| - psychology      |    2.1|custom-extract|     5|exact_match|↑  |0.8260|±  |0.0170|

| Groups |Version|    Filter    |n-shot|  Metric   |   |Value |   |Stderr|
|--------|------:|--------------|------|-----------|---|-----:|---|-----:|
|mmlu_pro|      2|custom-extract|      |exact_match|↑  |0.7994|±  |0.0048|

cc @WoosukKwon @lucas-tucker @simon-mo @houseroad

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request improves performance by replacing a blocking tolist() call on a CUDA tensor with a non-blocking copy and an explicit CUDA event synchronization. This avoids unnecessary device-wide stream synchronization. My review focuses on further optimizing this change by suggesting the reuse of allocated resources to minimize overhead in this critical performance path.

Member

@njhill njhill left a comment


Thanks @liuzijing2014!

It might be cleaner to put this in a separate to_list(self, tensor) method?

Comment on lines 1734 to 1736
Member


Would this work instead of using an event?

Suggested change

Current:

    pinned.copy_(sampled_token_ids, non_blocking=True)
    self.transfer_event.record()
    self.transfer_event.synchronize()

Suggested:

    pinned.copy_(sampled_token_ids, non_blocking=True)
    torch.cuda.current_stream().synchronize()
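
For context, a self-contained illustration (not part of the PR) of what each form waits on; src and pinned here are stand-ins for the tensors in the diff:

    import torch

    src = torch.randint(0, 1000, (64, 1), device="cuda")
    pinned = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)

    # Event sync (what the PR does): the host blocks until the work recorded
    # on the current stream before the event, i.e. the pinned copy, finishes.
    pinned.copy_(src, non_blocking=True)
    event = torch.cuda.Event()
    event.record()
    event.synchronize()

    # Stream sync (the suggestion above): the host blocks until everything
    # queued on the current stream so far finishes. Neither form touches other
    # CUDA streams, unlike the device-wide sync this PR attributes to calling
    # tolist() on a GPU tensor.
    pinned.copy_(src, non_blocking=True)
    torch.cuda.current_stream().synchronize()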

Collaborator Author


From our observations, I think the CUDA stream synchronization is what caused the issue for us. I added a few more details in the issue ticket: #22754

Member


OK I was just curious whether the problem was that all streams were being synchronized.

If my suggestion doesn't work or is inferior in some way then the event synchronization looks good to me!

Collaborator Author


OK I was just curious whether the problem was that all streams were being synchronized.

Yes, I think this is what we have observed.

@mergify

mergify bot commented Aug 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @liuzijing2014.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 13, 2025
@liuzijing2014
Collaborator Author

@njhill @WoosukKwon do you mind taking a quick look and seeing if this looks good to you? I would appreciate getting this PR merged sooner to unblock a potential production launch.

Member

@njhill njhill left a comment


Thanks @liuzijing2014! LGTM, but I would like to get blessing from @WoosukKwon too on this one.

Please also merge in latest main and resolve the conflicts.

@njhill
Member

njhill commented Aug 18, 2025

We may also want to consider doing the same for other tensors that are transferred in the non-default case, such as logprobs and spec-decode tokens.
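
A hypothetical sketch of what that generalization might look like (copy_to_host is not an existing vLLM helper, and a real implementation would reuse a preallocated pinned buffer rather than allocating one per call, in line with the resource-reuse note in the review above):

    import torch

    def copy_to_host(src: torch.Tensor, event: torch.cuda.Event) -> torch.Tensor:
        # Asynchronously copy a GPU tensor (e.g. logprobs or spec-decode token
        # ids) into pinned host memory, waiting only on this copy rather than
        # forcing a device-wide synchronization.
        out = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
        out.copy_(src, non_blocking=True)
        event.record()       # recorded after the copy on the current stream
        event.synchronize()  # host waits just for the copy to land
        return out

    # Reuse one event across calls, as the PR does for the sampled token ids.
    event = torch.cuda.Event()
    logprobs_cpu = copy_to_host(torch.randn(64, 8, device="cuda"), event)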

Collaborator

@WoosukKwon WoosukKwon left a comment


Looks ok to me. Sorry for the delay!

@njhill
Member

njhill commented Aug 18, 2025

@liuzijing2014 but wait until #23118 is merged before rebasing!

@liuzijing2014
Collaborator Author

Thanks for the help! Will rebase and merge soon.

@njhill
Member

njhill commented Aug 20, 2025

@liuzijing2014 let me know if you'd like me to rebase to get this over the line

@liuzijing2014
Collaborator Author

@liuzijing2014 let me know if you'd like me to rebase to get this over the line

@njhill sorry for the delay, merging today!

Signed-off-by: Zijing Liu <[email protected]>
@njhill
Member

njhill commented Aug 22, 2025

@liuzijing2014 sorry, I didn't get to it in time and there are more conflicts that need resolving now :-/

@njhill njhill enabled auto-merge (squash) August 25, 2025 17:16
@liuzijing2014
Collaborator Author

@njhill I don't think the failing CI checks are related to this PR; do you think we can still merge it directly?

@vllm-bot vllm-bot merged commit b395b3b into vllm-project:main Aug 26, 2025
33 of 36 checks passed
@liuzijing2014 liuzijing2014 deleted the event-sync branch August 26, 2025 21:15
tc-mb pushed a commit to tc-mb/vllm that referenced this pull request Aug 27, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: tc-mb <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
@tomasruizt
Contributor

@liuzijing2014 @njhill
Is this the way to transfer tensors from GPU to CPU asynchronously in vLLM? I'm curious about the spec decoding code path.
I'm a bit confused by the fact that there is still a .tolist() call. How come this doesn't block, but it did before?
And what about the event object? Perhaps each tolist() call should create its own CUDA event object?

@njhill
Member

njhill commented Aug 29, 2025

@liuzijing2014 @njhill Is this the way to transfer tensors from GPU to CPU asynchronously in vLLM? I'm curious about the spec decoding code path. I'm a bit confused by the fact that there is still a .tolist() call. How come this doesn't block, but it did before? And what about the event object? Perhaps each tolist() call should create its own CUDA event object?

@tomasruizt it still blocks, it's just that it can now happen concurrently with transfers that kv connectors might make in different cuda streams.

We are working on eliminating this sync point altogether, including for the spec decoding case: #23569

@tomasruizt
Contributor

Great! Thanks for the explanation @njhill I'll keep an eye on #23569

zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…oid unintentional copy ops blocking across different CUDA streams, improving disagg TTIT/TTFT (vllm-project#22760)

Signed-off-by: Zijing Liu <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
@dr75
Contributor

dr75 commented Nov 12, 2025

I am seeing crashes in _to_list() in v0.11.0 with H100 and Llama 3.3, 70k context window:

Instance 1:

(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]     ) = self._bookkeeping_sync(scheduler_output, sampler_output,
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 2144, in _bookkeeping_sync
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]     valid_sampled_token_ids = self._to_list(sampled_token_ids)
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 4157, in _to_list
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710]     pinned.copy_(sampled_token_ids, non_blocking=True)
(EngineCore_DP0 pid=14) ERROR 11-11 08:43:17 [core.py:710] torch.AcceleratorError: CUDA error: an illegal memory access was encountered

Instance 2:

(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]     ) = self._bookkeeping_sync(scheduler_output, sampler_output,
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 2144, in _bookkeeping_sync
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]     valid_sampled_token_ids = self._to_list(sampled_token_ids)
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 4157, in _to_list
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710]     pinned.copy_(sampled_token_ids, non_blocking=True)
(EngineCore_DP0 pid=14) ERROR 11-11 08:55:32 [core.py:710] torch.AcceleratorError: CUDA error: device-side assert triggered

I couldn't find a way to reproduce them yet, but are they likely caused by this change?

There is also this related fix: #28025

But that wrong buffer does not seem to be the cause of the crash, since a max_model_len of 70k is quite large.

@liuzijing2014, @njhill wdyt?

@dr75
Contributor

dr75 commented Nov 13, 2025

Is the issue that the copy operation is happening while the next iteration runs? I am not familiar with the details here, but to me the error looks like there are concurrent copy operations, and this change could make them overlap. Can iteration N-1 still be copying the output while iteration N also gets to that point? Would both then access the same pinned buffer?

@njhill
Member

njhill commented Nov 14, 2025

@dr75 I expect the root cause of the issue you are seeing is unrelated to this change. CUDA operates asynchronously and errors can manifest in arbitrary places where there are sync points.

You can try with CUDA_LAUNCH_BLOCKING=1, which may show more clearly where the error is happening (it can still be misleading, though, given the multiple CUDA streams).


Labels

ready, v1
