[BugFix][main] Fixed a Triton kernel bug in layer_norm_fwd_kernel for Qwen3-next #3549
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request fixes a Triton kernel bug in layer_norm_fwd_kernel by vendoring the kernel and modifying it to handle larger inputs. The changes correctly update the kernel launch grid and pass the necessary parameters. However, my review identified two critical bugs in the new kernel implementation related to undefined variables (mean and b) in certain code paths. These would cause the kernel to fail at runtime. I have provided a single, combined code suggestion to fix both issues and improve code clarity.
LGTM! It solves the problem with a bunch of requests.
### What this PR does / why we need it?

Fixes the Triton kernel **layer_norm_fwd_kernel**, described in #3548. The kernel is vendored and its launch grid is updated so it can handle larger inputs (a rough sketch of this kind of grid change follows the test results below).

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

The environment is the same as in issue #3548.

Start a vLLM server with:

```shell
vllm serve /home/model/Qwen3-Next-80B-A3B-Instruct --port 22 --host 0.0.0.0 --served-model-name qwen3_next_mtp_0 --tensor-parallel-size 4 --max-model-len 32000 --gpu-memory-utilization 0.7 --enforce-eager
```

Then we start an aisbench client like:

```shell
ais_bench --models vllm_api_general_chat --datasets ceval_gen_0_shot_cot_chat_prompt --dump-eval-details
```

whose config is:

```python
# a big batch_size and a large max_out_len
dict(
    abbr='vllm-api-general-chat',
    attr='service',
    batch_size=512,
    generation_kwargs=dict(temperature=0.7, top_k=20, top_p=0.8),
    host_ip='xxx.xxx.xxx.xxx',
    host_port=8881,
    max_out_len=30000,
    model='qwen3_next_mtp_0',
    path='',
    pred_postprocessor=dict(
        type='ais_bench.benchmark.utils.model_postprocessors.extract_non_reasoning_content'
    ),
    request_rate=0,
    retry=2,
    trust_remote_code=False,
    type='ais_bench.benchmark.models.VLLMCustomAPIChat'),
```

**Results:**

```text
(APIServer pid=615544) INFO 10-21 01:44:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 72.1 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 71.4 tokens/s, Running: 7 reqs, Waiting: 1 reqs, GPU KV cache usage: 100.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 49.6 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 86.1%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 59.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:44:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.2 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:45:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 61.8 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 88.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 01:45:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 62.4 tokens/s, Running: 6 reqs, Waiting: 2 reqs, GPU KV cache usage: 90.8%, Prefix cache hit rate: 0.0%
```

We can see that even when we send a bunch of requests and the **KV cache usage reaches 100.0%**, we no longer get a **coreDim=xxx can't be greater than UINT16_MAX.** exception.

```text
(APIServer pid=615544) INFO 10-21 02:17:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 98.3%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:17:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.3 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:17:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.6 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:18:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 30.9 tokens/s, Running: 3 reqs, Waiting: 5 reqs, GPU KV cache usage: 99.6%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 02:18:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 22.7 tokens/s, Running: 2 reqs, Waiting: 6 reqs, GPU KV cache usage: 81.9%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO: 141.61.39.105:48568 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=615544) INFO: 141.61.39.105:48580 - "POST /v1/chat/completions HTTP/1.1" 200 OK
```

And after a few minutes, these two requests have completed.

```text
(APIServer pid=615544) INFO 10-21 03:18:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.3 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:18:55 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:05 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 40.8%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:15 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 6.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 41.2%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO: 141.61.39.105:48712 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=615544) INFO 10-21 03:19:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=615544) INFO 10-21 03:19:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

Finally, all requests are done.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: drslark <[email protected]>
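The vendored kernel and launcher are not shown in this thread, so the following is only a rough, hypothetical sketch of the general technique behind the coreDim limit: instead of launching one Triton program per row (which can push a single grid dimension past UINT16_MAX once enough rows are queued), cap the grid and let each program loop over several rows. All names, signatures, and the pass-through body below are illustrative, not the actual fix.

```python
import triton
import triton.language as tl

# The limit quoted by the error message in #3548: a single grid dimension
# ("coreDim") must stay at or below UINT16_MAX on this backend.
MAX_GRID_DIM = 65535


@triton.jit
def _fwd_rows_kernel(X, Y, num_rows, rows_per_prog, N, BLOCK_N: tl.constexpr):
    # Each program handles `rows_per_prog` consecutive rows instead of exactly
    # one, so the launch never needs more than MAX_GRID_DIM programs.
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    col_mask = cols < N
    for i in range(rows_per_prog):
        row = pid * rows_per_prog + i
        row_mask = col_mask & (row < num_rows)
        x = tl.load(X + row * N + cols, mask=row_mask, other=0.0)
        # A real layer-norm kernel would normalize `x` here; this sketch just
        # copies the row to keep the example short.
        tl.store(Y + row * N + cols, x, mask=row_mask)


def launch_fwd(x, y):
    """Illustrative launcher that caps the 1-D grid at UINT16_MAX."""
    num_rows, N = x.shape
    rows_per_prog = triton.cdiv(num_rows, MAX_GRID_DIM)
    grid = (triton.cdiv(num_rows, rows_per_prog),)
    _fwd_rows_kernel[grid](
        x, y, num_rows, rows_per_prog, N,
        BLOCK_N=triton.next_power_of_2(N),
    )
```

Whether the actual vendored kernel caps a 1-D grid this way, switches to a 2-D grid, or changes the per-program work split differently is determined by the PR itself; the sketch only shows why bounding the number of launched programs removes the coreDim overflow.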