Skip to content

Conversation

@jiqing-feng
Copy link
Contributor

@jiqing-feng jiqing-feng commented Sep 10, 2024

The real bonus token should be the first unmatched token. For example, the draft_token_ids is [1, 2, 3, 5, 7], and the target_token_ids is [1, 2, 3, 4, 6, 8]. Then, the matched token should be [1, 2, 3], and the bonus token should be [4] because the target model will output [4] based on the input [1, 2, 3] in the next round generation. So we can select the bonus token [4] in this round without any precision regression.

It will increase the performance from 89 tokens/s to 110 tokens/s in typical_acceptance_sampler in A100 single card with (num_speculative_tokens=2, max_num_seqs=1, model="meta-llama/Llama-2-7b-chat-hf", speculative_model="Felladrin/Llama-68M-Chat-v1"). The outputs are exactly the same before and after my changes.

Do you mind review this PR? @cadedaniel
cc @LiuXiaoxuanPKU

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@jiqing-feng jiqing-feng changed the title fix verify tokens with the correct bonus token Fix verify tokens with the correct bonus token Sep 10, 2024
@LiuXiaoxuanPKU
Copy link
Collaborator

Hey, thanks for the interest! I want to align some definition here:

  1. In vllm, bonus_token_ids is defined as
   The "bonus" token ids that are accepted iff all speculative tokens in a sequence are accepted.

That being said, if the draft_token_ids is [1, 2, 3, 5, 7], and the target_token_ids is [1, 2, 3, 5, 7, 8]. Then the bonus_token_ids for this record is [8]. Notice here, all proposed tokens 1, 2, 3, 5, 7 are accepted.
2. For the example you gave, we say 4 is a recovered_token. To get the recovered token, instead of getting it from the target model directly. vllm samples the recovered token from a new distribution as shown here. (Minor thing, for greedy decoding, yeah the sampled token is the same as the target token, but it might be different for standard sampling.) We implemented it to strictly follow the paper.
Screenshot 2024-09-10 at 9 44 24 AM

Please let me know if there is any confusion here! Sorry for different terms here, we might simplify it in the future.

@LiuXiaoxuanPKU
Copy link
Collaborator

LiuXiaoxuanPKU commented Sep 10, 2024

After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get.
I also suggestion trying the flashinfer backend and you should also see good speedup.

@yao-matrix
Copy link

After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get. I also suggestion trying the flashinfer backend and you should also see good speedup.

Thx Xiaoxuan, so do you think this optimization idea is qualified to merge to vllm. We treat this as a platform independent optimization(orthogonal w/ backend optimizations like flashinfer), which can benefit other device backends like CPU/XPU, and we see similar performance issue.

@jiqing-feng
Copy link
Contributor Author

This change is to align the verify token function to the transformers speculative sampling algorithm, it always selectes the next sample token from the target model.

@LiuXiaoxuanPKU
Copy link
Collaborator

This change is to align the verify token function to the transformers speculative sampling algorithm, it always selectes the next sample token from the target model.

From reading the code, it seems it adjusts the distribution and resamples (line 4195 - 4203)?

@LiuXiaoxuanPKU
Copy link
Collaborator

After second thoughts, I feel this is a good optimization. We can actually skip rejection sampling in the greedy decoding case, which explains the speedup you get. I also suggestion trying the flashinfer backend and you should also see good speedup.

Thx Xiaoxuan, so do you think this optimization idea is qualified to merge to vllm. We treat this as a platform independent optimization(orthogonal w/ backend optimizations like flashinfer), which can benefit other device backends like CPU/XPU, and we see similar performance issue.

Could you double check the correctness here? If the optimization can pass the rejection sampling tests, yeah happy to review and get it in.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Sep 11, 2024

From reading the code, it seems it adjusts the distribution and resamples (line 4195 - 4203)?

Yes, but the point is the output will always contain the first unmatched token p_n_plus_1 (which is the next sample token from the target model).

Could you double check the correctness here? If the optimization can pass the rejection sampling tests, yeah happy to review and get it in.

The error comes from sampling. We can not guarantee the output will be all matched even if the target model is the same as the draft model because sampling will introduce random factors. So I disabled sampling by setting temperature=0 so we can make sure all tokens will be matched.

@comaniac
Copy link
Collaborator

The error comes from sampling. We can not guarantee the output will be all matched even if the target model is the same as the draft model because sampling will introduce random factors. So I disabled sampling by setting temperature=0 so we can make sure all tokens will be matched.

This is the intention because you cannot force users to set temperature=0. That's why @LiuXiaoxuanPKU suggested we could bypass rejected sampling when temperature=0, but we cannot remove rejected sampling.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Sep 11, 2024

This is the intention because you cannot force users to set temperature=0. That's why @LiuXiaoxuanPKU suggested we could bypass rejected sampling when temperature=0, but we cannot remove rejected sampling.

We cannot expect the output length of speculative decoding to be a fixed number if sampling is applied, it could be in [1, num_speculative_tokens] in a single step. How about changing the test to verify the output token number in a reasonable interval instead of finishing after 2 steps?

@jiqing-feng
Copy link
Contributor Author

OK, I removed temperature=0 but changed the max_tokens and num_speculative_tokens to avoid tests failed by sampling. Please let me know your opinion about this test.

@LiuXiaoxuanPKU
Copy link
Collaborator

I'm confused. It seems this PR removes rejection sampling. Then how do you do speculative decoding with temperature != 0?

@jiqing-feng
Copy link
Contributor Author

I'm confused. It seems this PR removes rejection sampling. Then how do you do speculative decoding with temperature != 0?

I didn't remove rejection sampling, just removed recovered_token since it is not needed. The newly defined bonus_token covers the recovered_token case. Besides, I kept the temperature in default value in the test.

@LiuXiaoxuanPKU
Copy link
Collaborator

The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Sep 11, 2024

The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.

I see your point, the recovered_token is selected from a new distribution, but I just select it from target_token. It is more convenient and also makes sense.

I know adjusting the distribution of target_prob is from the original paper, but I didn't see any advantages compared to just selecting from target_token_ids, and it also introduces a customized multinomial function and some overheads.

Please let me know if you want me to revert it to keep the original recovered_token from the adjusted target probs, thanks.

@yao-matrix
Copy link

The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.

I see your point, the recovered_token is selected from a new distribution, but I just select it from target_token. It is more convenient and also makes sense.

I know adjusting the distribution of target_prob is from the original paper, but I didn't see any advantages compared to just selecting from target_token_ids, and it also introduces a customized multinomial function and some overheads.

Please let me know if you want me to revert it to keep the original recovered_token from the adjusted target probs, thanks.

I don't think change original paper algorithm is a good idea without data proving, and I don't think change behavior is the target of this PR. This PR's target is performance optimization. @jiqing-feng, pls only optimize performance for temperature == 0, and don't change the logic of others.
You can submit another PR or issue if you wanna discuss it, keeping the PR only for one thing is better for cognition burden and make things move fast forward.

@jiqing-feng
Copy link
Contributor Author

I have reverted unnecessary changes. Now, the rejection sampling exactly follows the paper and is the same as transformers integration.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Sep 13, 2024

The recovered token should not be from the target token ids if temperature != 0. It should be sampled from a new distribution. Correct me if I misunderstand anything.

Yes, you were right. I have fixed the recovered token by selecting it from the new distribution (torch.clamp(target_prob-draft_prob), min=0). Please take a review. Thx.

@jiqing-feng
Copy link
Contributor Author

jiqing-feng commented Sep 18, 2024

Hi @LiuXiaoxuanPKU . I checked the rejection sampler codes in detail and found there is no need to change it because you can get the correct recovered token ids. Only 1 thing:
The _multinomial function just selects the tokens with the largest probability which is different with torch.multinomial

I am okay with the little difference btw vllm and the original paper because I cannot get speed-up on rejection sampling with this PR, but it could get significant speed-up on typical acceptance sampler. So I opened another PR to only change typical acceptance sampler, see #8562 . Please take a review, thx.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants