
Conversation

@Abatom (Contributor) commented Feb 22, 2025

When we attempted to deploy DeepSeek R1 671B on two 8-card H20 machines, vLLM crashed and reported an illegal memory access whenever the prompt length exceeded 32K. This PR fixes the bug.

@github-actions commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Abatom <[email protected]>

Co-authored-by: Apache9 <[email protected]>
Co-authored-by: kirbyzhou <[email protected]>
Co-authored-by: ch-tiger <[email protected]>
@Apache9 commented Feb 22, 2025

Since the problem is that we slice the cache to a smaller size while later operators may write beyond that limit, it does not always crash the program.

We also tested on two H800 machines and saw no crash, but I suspect it may still affect the quality of the output tokens.

We also tested SGLang, which has the same problem (see sgl-project/sglang#3779), on a single 8-GPU machine where each GPU has 140GB+ of memory; there the only symptom was that it became extremely slow when the input prompt was longer than 32K...
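A minimal PyTorch sketch of this failure mode (not vLLM's actual code; the buffer sizes and names here are hypothetical): a preallocated buffer gets sliced with an underestimated bound, and a later operator writes at positions derived from the full length.

```python
import torch

# Hypothetical sizes: the cache really needs `needed` slots,
# but a buggy bound slices off only `sliced` of them.
needed, sliced = 1024, 512

cache = torch.zeros(needed, 8)   # preallocated cache buffer
view = cache[:sliced]            # undersized slice: the oversight

# A later operator derives write positions from the full prompt
# length, so some indices fall past the end of the view. On CUDA
# such a write can surface as an illegal memory access, or it can
# silently corrupt neighboring memory and only degrade outputs;
# on CPU, PyTorch raises an index error instead.
positions = torch.arange(sliced, needed)
# view[positions] = 1.0   # out-of-bounds through the small view: the bug
cache[positions] = 1.0    # writing through the full buffer stays in bounds
```

This matches the symptom described above: whether it crashes or merely corrupts output depends on what happens to live past the undersized slice.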

@mgoin (Member) left a comment


Thank you for the bug report and patch. I could not reproduce this on H200, but the oversight makes sense.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 24, 2025
@mgoin mgoin enabled auto-merge (squash) February 24, 2025 14:34
@benchislett (Collaborator) commented

To reproduce on H200, try sending more than one concurrent request of length 50k+. I have seen this issue with DeepSeek-R1 on 8xH200, and it seems to be resolved with this diff.
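A hedged reproduction sketch along these lines, assuming a vLLM OpenAI-compatible server is already running locally (the URL, model name, filler prompt, and request count are placeholders):

```python
import asyncio

from openai import AsyncOpenAI

# Placeholder endpoint and model; adjust to your deployment.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def long_request(i: int) -> None:
    # Repeated filler to push the prompt to roughly 50k+ tokens.
    prompt = "hello " * 60_000
    resp = await client.completions.create(
        model="deepseek-ai/DeepSeek-R1",
        prompt=prompt,
        max_tokens=16,
    )
    print(f"request {i} ok:", resp.choices[0].text[:40])

async def main() -> None:
    # More than one concurrent long request, per the comment above.
    await asyncio.gather(*(long_request(i) for i in range(4)))

asyncio.run(main())
```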

Great catch!

@simon-mo simon-mo merged commit ccc0051 into vllm-project:main Feb 24, 2025
40 of 47 checks passed
russellb added a commit to russellb/vllm that referenced this pull request Feb 24, 2025
@DefTruth (Contributor) commented

Same error for 32K+ context length; it seems this PR fixes my problem.

@Louis-Zhu commented

After applying this PR, we still hit this error with 80k characters. We deployed DeepSeek R1 on three 8*H20 machines, with a 128k max_model_len.

@Apache9 commented Mar 3, 2025

> After applying this PR, we still hit this error with 80k characters. We deployed DeepSeek R1 on three 8*H20 machines, with a 128k max_model_len.

Could you please provide more detailed information about the crash?

We also still saw an illegal memory access after applying this PR, in our case inside cuBLAS. It finally turned out that our CUDA toolkit version and the Linux driver were not compatible; after upgrading the driver, the error disappeared.

Thanks.
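For anyone chasing a similar cuBLAS-level failure, a quick sanity check of the runtime/driver pairing could look like this (a hedged sketch assuming PyTorch and nvidia-smi are available):

```python
import subprocess

import torch

# CUDA runtime version that PyTorch was built against, e.g. "12.4".
print("CUDA runtime (torch):", torch.version.cuda)

# Driver version reported by the NVIDIA driver, e.g. "560.35.03".
driver = subprocess.check_output(
    ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
    text=True,
).strip()
print("NVIDIA driver:", driver)

# Each CUDA release documents a minimum required driver version; if
# the installed driver is older than that minimum, upgrade the driver.
```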

@cheferrari commented

Hi, we need this PR to fix the same issue with DeepSeek-R1 on H20. When will a release containing this PR be published? Is there a timeline?

@Louis-Zhu commented Mar 3, 2025

> > After applying this PR, we still hit this error with 80k characters. We deployed DeepSeek R1 on three 8*H20 machines, with a 128k max_model_len.
>
> Could you please provide more detailed information about the crash?
>
> We also still saw an illegal memory access after applying this PR, in our case inside cuBLAS. It finally turned out that our CUDA toolkit version and the Linux driver were not compatible; after upgrading the driver, the error disappeared.
>
> Thanks.

We built a Docker image on the Ray 2.40.0 base image, with the main branch of vLLM installed, and started a Ray cluster with three 8*H20 nodes; CUDA version V12.4.131 and NVIDIA driver version 560.35.03.

@Apache9 commented Mar 3, 2025

> > > After applying this PR, we still hit this error with 80k characters. We deployed DeepSeek R1 on three 8*H20 machines, with a 128k max_model_len.
> >
> > Could you please provide more detailed information about the crash?
> >
> > We also still saw an illegal memory access after applying this PR, in our case inside cuBLAS. It finally turned out that our CUDA toolkit version and the Linux driver were not compatible; after upgrading the driver, the error disappeared.
> >
> > Thanks.
>
> We built a Docker image on the Ray 2.40.0 base image, with the main branch of vLLM installed, and started a Ray cluster with three 8*H20 nodes; CUDA version V12.4.131 and NVIDIA driver version 560.35.03.

The driver version seems fine...

Then maybe there are still other bugs in your setup...

@joydchh commented Mar 9, 2025

> Since the problem is that we slice the cache to a smaller size while later operators may write beyond that limit, it does not always crash the program.
>
> We also tested on two H800 machines and saw no crash, but I suspect it may still affect the quality of the output tokens.
>
> We also tested SGLang, which has the same problem (see sgl-project/sglang#3779), on a single 8-GPU machine where each GPU has 140GB+ of memory; there the only symptom was that it became extremely slow when the input prompt was longer than 32K...

Do you have any insight into the slow decoding? We tested on H200; when the prompt is around 8k, the decoding speed is just 4.x tokens/s.

@Abatom Abatom deleted the moe branch June 25, 2025 02:18