
Conversation

Jialin (Collaborator) commented Sep 27, 2025

Purpose

Currently, GPUModelRunner._bookkeeping_sync interleaves numpy updates with Python logic, which is inefficient: the scattered tensor and numpy array updates consume a significant amount of time.

In this change, we vectorize the tensor and numpy updates (see the sketch below):

  • compute update indices and values in a Python for loop
  • apply bulk updates to the tensors and numpy arrays for vectorization

Update
We only batchify the 1D array updates in this PR; the 2D array/tensor updates will be done as a follow-up.
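
To make the batching pattern concrete, here is a minimal, self-contained sketch of the before/after shape of the updates (illustrative names and sizes; not the actual GPUModelRunner code):

import numpy as np

num_tokens = np.zeros(1024, dtype=np.int32)

def update_scattered(reqs: list[tuple[int, int]]) -> None:
    # Before: one numpy __setitem__ per request, scattered through the loop.
    for req_idx, value in reqs:
        num_tokens[req_idx] = value

def update_batched(reqs: list[tuple[int, int]]) -> None:
    # After: collect indices and values in plain Python lists in the loop...
    indices: list[int] = []
    values: list[int] = []
    for req_idx, value in reqs:
        indices.append(req_idx)
        values.append(value)
    # ...then apply one bulk, vectorized update at the end.
    num_tokens[indices] = values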

[Screenshot: 2025-09-26 at 11:40 AM]

Test Plan & Test Result

Benchmark
We introduced a benchmark script to measure the win. It shows a clear performance boost when all rows are updated (5x throughput), but a regression when fewer than 10% of rows are updated. In real-world scenarios, however, we believe most of the rows are updated, so it should still be a consistent improvement to the system. A toy version of the comparison is sketched below.
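
For intuition, a toy stand-in for that comparison could look like this (a hypothetical benchmark, not the PR's script):

import time
import numpy as np

N = 8192
arr = np.zeros(N, dtype=np.int64)

for frac in (1.0, 0.5, 0.1, 0.01):
    rows = range(int(N * frac))

    t0 = time.perf_counter()
    for _ in range(100):
        for r in rows:              # scattered: one write per row
            arr[r] = r
    scattered = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(100):
        idx: list[int] = []
        vals: list[int] = []
        for r in rows:              # batched: accumulate, then one bulk write
            idx.append(r)
            vals.append(r)
        arr[idx] = vals
    batched = time.perf_counter() - t0

    print(f"frac={frac:.2f}: scattered={scattered:.4f}s, batched={batched:.4f}s")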

[Screenshot: 2025-10-09 at 2:20 PM]

Correctness

VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py  --method ngram  --model-dir meta-llama/Llama-3.1-8B-Instruct  --prompt_lookup_min 2  --prompt_lookup_max 5  --num_spec_tokens 5  --dataset-name hf  --dataset-path philschmid/mt-bench  --num-prompts 80  --print-output

Output is exactly the same before and after the change
--------------------------------------------------
total_num_output_tokens: 17069
num_drafts: 2548
num_draft_tokens: 12711
num_accepted_tokens: 2587
mean acceptance length: 2.02
--------------------------------------------------
acceptance at token 0: 0.43
acceptance at token 1: 0.25
acceptance at token 2: 0.15
acceptance at token 3: 0.10
acceptance at token 4: 0.07

Optimization
~3x speedup with the change, per the trace below:
[Screenshot: 2025-09-26 at 2:36 PM]

Per gpt-oss AIME 2025 eval runs:

  • bookkeeping total elapsed time reduced by 60%+
  • bookkeeping elapsed time distribution is less skewed
[Screenshots: 2025-09-26 at 11:36 PM and 11:35 PM]

@Jialin Jialin marked this pull request as ready for review September 27, 2025 06:38
@mergify mergify bot added the v1 label Sep 27, 2025
Jialin (Collaborator, Author) commented Sep 28, 2025

CC @yeqcharlotte @houseroad

Jialin (Collaborator, Author) commented Oct 3, 2025

CC @njhill @WoosukKwon for awareness

I'm wondering, once the async scheduler lands, whether the bookkeeping costs would also be hidden in a separate process.

mergify bot commented Oct 4, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Jialin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 4, 2025
# - self.input_batch.token_ids_cpu[req_idx,
# start_idx:end_idx] = sampled_ids
base_idx = req_idx * token_ids_cpu_column_cnt
token_ids_cpu_flatten_indices.extend(
houseroad (Collaborator) commented Oct 9, 2025:

I am a bit concerned if end_idx - start_idx is a large number.

Jialin (Collaborator, Author) replied:

Ah, I think it's a valid concern. IIUC, we only assign output tokens here, so most of the time end_idx - start_idx is 1, and it is at most num_spec_tokens.

But to be honest, I also need to confirm that prompt tokens are not appended here (if they are, this range could be huge).
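
For context, a self-contained sketch of the flattened-index technique the excerpt above uses (illustrative setup; not the PR's exact code): a 2D row-slice write is deferred into flat indices, and a single fancy-indexed write is applied at the end.

import numpy as np

token_ids_cpu = np.zeros((8, 16), dtype=np.int64)
token_ids_cpu_column_cnt = token_ids_cpu.shape[1]

token_ids_cpu_flatten_indices: list[int] = []
flatten_values: list[int] = []

def queue_row_update(req_idx, start_idx, end_idx, sampled_ids):
    # Equivalent to: token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    base_idx = req_idx * token_ids_cpu_column_cnt
    token_ids_cpu_flatten_indices.extend(
        range(base_idx + start_idx, base_idx + end_idx))
    flatten_values.extend(sampled_ids)

queue_row_update(2, 3, 6, [11, 12, 13])
queue_row_update(5, 0, 1, [42])
# One bulk write into the flat view of the 2D array:
token_ids_cpu.ravel()[token_ids_cpu_flatten_indices] = flatten_values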

houseroad (Collaborator) left a comment:

We can write a custom CPU op to take care of this case, e.g. batch_assign_2d(indices), where indices is a 1D tensor whose entries come in triples: the dim-0 index, and the dim-1 start and end.
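
A hedged numpy sketch of the semantics such an op might have (the real version would be a compiled CPU kernel; this only pins down the proposed interface):

import numpy as np

def batch_assign_2d(dst: np.ndarray, indices: np.ndarray, values: np.ndarray):
    # indices is a flat 1D array read in triples (row, start, end);
    # values holds the concatenated payloads, consumed in order.
    offset = 0
    for row, start, end in indices.reshape(-1, 3):
        n = end - start
        dst[row, start:end] = values[offset:offset + n]
        offset += n

dst = np.zeros((4, 8), dtype=np.int64)
batch_assign_2d(dst,
                np.array([0, 2, 5, 3, 0, 2]),  # (row 0, cols 2:5), (row 3, cols 0:2)
                np.array([7, 8, 9, 1, 2]))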

Jialin (Collaborator, Author) commented Oct 9, 2025

Per offline discussion, we will split the PR a bit:

  • [This PR] Batchify the 1D updates using the existing numpy / tensor index_select operators
  • [Followup PR 1] Batchify the 2D updates to the numpy array token_ids_cpu with a numba JIT function (rough sketch below)
  • [Followup PR 2] Batchify the 2D updates to the pytorch tensor is_token_ids with a custom operator
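
For Followup PR 1, the numba approach might look roughly like this (assumes numba is installed; an illustrative sketch, not code from this PR):

import numpy as np
from numba import njit

@njit(cache=True)
def batch_update_2d(token_ids_cpu, req_indices, start_indices, end_indices,
                    flat_values):
    # One JIT-compiled loop replaces many per-row numpy slice assignments.
    offset = 0
    for i in range(req_indices.shape[0]):
        row, start, end = req_indices[i], start_indices[i], end_indices[i]
        for j in range(start, end):
            token_ids_cpu[row, j] = flat_values[offset]
            offset += 1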

Jialin (Collaborator, Author) commented Oct 9, 2025

Update on the plan above: we just found that with a micro batch, individual 1D array updates are actually faster than batched updates (most likely due to slow list.append in the batch-update preparation). We will move this PR back to draft until we have more promising numbers (potentially on the 2D arrays).

After updating the benchmark script (which better reflects the actual bookkeeping usage), the change is shown to be a clear win. PTAL @houseroad

@Jialin Jialin marked this pull request as draft October 9, 2025 18:12
@mergify mergify bot added the performance Performance-related issues label Oct 9, 2025
@Jialin Jialin marked this pull request as ready for review October 9, 2025 21:12
@Jialin Jialin changed the title [Core] Bookkeeping optimization: Vectorize updates [Core] Bookkeeping optimization: Batchify updates 1D numpy arrays (e.g. num_tokens, num_tokens_no_spec) Oct 9, 2025
@mergify mergify bot removed the needs-rebase label Oct 10, 2025
Jialin (Collaborator, Author) commented Oct 13, 2025

Gentle nudge @houseroad for the review :P


# Collect updates in the for loop and apply a batch update at the end
# to vectorize updates to tensors and numpy arrays.
start_indices = self.input_batch.num_tokens_no_spec.tolist()
A collaborator commented:

no need to convert to list for start_indices?

Jialin (Collaborator, Author) replied:

I did a microbenchmark internally and found that if we need to read all the indices, calling tolist() and then reading from the Python list is more efficient than reading directly from the numpy array.
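
Roughly the kind of comparison involved (an assumed reconstruction, not the internal microbench):

import timeit
import numpy as np

a = np.arange(4096, dtype=np.int64)

direct = timeit.timeit(lambda: [a[i] for i in range(a.shape[0])], number=1000)
via_tolist = timeit.timeit(lambda: [x for x in a.tolist()], number=1000)
print(f"direct numpy reads: {direct:.3f}s, tolist() then read: {via_tolist:.3f}s")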

[Screenshot: 2025-10-14 at 11:41 AM]

mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Jialin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025
# to vectorize updates to tensors and numpy arrays.
start_indices = self.input_batch.num_tokens_no_spec.tolist()
# Indices and values to update for num_tokens and num_tokens_no_spec
num_tokens_indices_to_update: list[int] = []
A collaborator commented:

why not use an np array here? It may be lighter.

Jialin (Collaborator, Author) replied:

We don't know the total number of elements to update ahead of time. Even if we did, updating np array elements one at a time within a for loop might not be efficient.

Please let me know if I missed anything.

Signed-off-by: Jialin Ouyang <[email protected]>

Labels: performance (Performance-related issues), v1