Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions benchmarks/benchmark_bookkeep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from random import randint, shuffle

import numpy as np
from benchmark_utils import TimeCollector, throughput_change
from tabulate import tabulate
from tqdm import trange

from vllm.utils import FlexibleArgumentParser


def update_one_by_one(
    num_tokens_no_spec_np: np.ndarray[np.int32, np.dtype[np.int32]],
    num_tokens_np: np.ndarray[np.int32, np.dtype[np.int32]],
    update_tags: list[bool],
    update_values: list[int],
) -> None:
    """Baseline: bump the counts of tagged rows one element at a time.

    For every index ``i`` where ``update_tags[i]`` is truthy, reads
    ``num_tokens_no_spec_np[i]``, adds ``update_values[i]`` and stores the
    sum into both arrays at index ``i``. Both arrays are modified in place.

    Every read and write is an individual scalar access on a numpy array;
    that access pattern is what this benchmark compares against
    ``update_by_batch``.
    """
    for row, delta in enumerate(update_values):
        if not update_tags[row]:
            continue
        # One scalar read from the numpy array ...
        new_count = num_tokens_no_spec_np[row] + delta
        # ... followed by one scalar write into each of the two arrays.
        num_tokens_no_spec_np[row] = new_count
        num_tokens_np[row] = new_count


def update_by_batch(
    num_tokens_no_spec_np: np.ndarray[np.int32, np.dtype[np.int32]],
    num_tokens_np: np.ndarray[np.int32, np.dtype[np.int32]],
    update_tags: list[bool],
    update_values: list[int],
) -> None:
    """Bump the counts of tagged rows with one batched write per array.

    Computes the same result as ``update_one_by_one``: for every index ``i``
    where ``update_tags[i]`` is truthy, the new value is
    ``num_tokens_no_spec_np[i] + update_values[i]``, written to both arrays.
    Both arrays are modified in place.

    The source array is converted to a Python list once, the pending
    updates are collected in plain lists, and each numpy array is then
    written a single time through integer-array indexing.
    """
    # Single conversion up front; the loop below only touches Python lists.
    current_counts = num_tokens_no_spec_np.tolist()
    pending_rows: list[int] = []
    pending_counts: list[int] = []
    for row, delta in enumerate(update_values):
        if not update_tags[row]:
            continue
        pending_rows.append(row)
        pending_counts.append(current_counts[row] + delta)

    # Nothing tagged: leave both arrays untouched.
    if not pending_rows:
        return

    # One fancy-indexed assignment per array instead of one write per row.
    num_tokens_no_spec_np[pending_rows] = pending_counts
    num_tokens_np[pending_rows] = pending_counts


def main(args) -> None:
    """Benchmark one-by-one vs. batched updates and print a comparison table.

    For each value ``n`` in ``args.num_elements`` a series of test sets
    ``(n, k)`` is built where ``k`` starts at ``n`` and is repeatedly
    integer-divided by 10 until it reaches 0 (the ``k == 0`` case is
    included). Each test set runs ``args.num_iteration`` iterations in which
    ``k`` randomly placed rows out of ``n`` are updated, timing
    ``update_one_by_one`` and ``update_by_batch`` separately.

    Args:
        args: Parsed arguments providing ``num_elements`` (iterable of int)
            and ``num_iteration`` (int).

    Raises:
        ValueError: If ``args.num_iteration`` is not positive; no average
            can be computed without at least one timed iteration.
    """
    if args.num_iteration <= 0:
        # Fail up front instead of hitting a ZeroDivisionError when the
        # averages are computed after the (empty) timing loop.
        raise ValueError(
            f"--num-iteration must be positive, got {args.num_iteration}"
        )

    testsets = []
    for num_element in args.num_elements:
        update_element = num_element
        while update_element > 0:
            testsets.append((num_element, update_element))
            update_element //= 10
        testsets.append((num_element, 0))
    testsets.sort()

    headers = [
        "Total\nElements",
        "Update\nElements",
        "One by One\nAvg (us)",
        "Batch\nAvg (us)",
        "Throughput\nChange",
    ]
    data_rows = []
    TIME_SCALE = TimeCollector.US

    for i in trange(len(testsets), desc="Testsets"):
        num_element, update_element = testsets[i]
        # Start from zeros rather than np.empty: uninitialized int32 values
        # would make runs non-reproducible and could sit close to the int32
        # limit before any increment is applied.
        num_tokens_no_spec_np = np.zeros((num_element,), dtype=np.int32)
        num_tokens_np = np.zeros((num_element,), dtype=np.int32)
        one_by_one_times = TimeCollector(TIME_SCALE)
        batch_times = TimeCollector(TIME_SCALE)
        # Only update_element of the num_element rows are updated in each
        # iteration; which rows is re-randomized by the shuffle below.
        update_tags = [True] * update_element + [False] * (
            num_element - update_element
        )
        # Collect garbage before timing to reduce GC noise in the readings.
        gc.collect()
        for _ in trange(args.num_iteration, desc="Iterations per testset"):
            shuffle(update_tags)
            update_values = [randint(0, 100) for _ in range(num_element)]
            with one_by_one_times:
                update_one_by_one(
                    num_tokens_no_spec_np, num_tokens_np, update_tags, update_values
                )
            with batch_times:
                update_by_batch(
                    num_tokens_no_spec_np, num_tokens_np, update_tags, update_values
                )

        one_by_one_avg = one_by_one_times.avg_v()
        batch_metric_avg = batch_times.avg_v()
        data_rows.append(
            [
                num_element,
                update_element,
                one_by_one_times.avg(),
                batch_times.avg(),
                throughput_change(batch_metric_avg, one_by_one_avg),
            ]
        )

    print(
        tabulate(
            data_rows,
            headers=headers,
            tablefmt="pipe",
            floatfmt=".3f",
            # One alignment entry per column; derived from the headers so it
            # does not depend on data_rows being non-empty.
            colalign=["right"] * len(headers),
        )
    )


def invoke_main() -> None:
    """Parse the command-line options and run the bookkeeping benchmark.

    Options:
        --num-iteration: number of timed iterations per test set.
        --num-elements: one or more total array sizes to benchmark.
    """
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of Bookkeeping "
        "(i.e. GPUModelRunner._bookkeeping_sync)"
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        # Dashed spelling matches --num-iteration; the underscore spelling is
        # kept as an alias so existing invocations keep working.
        "--num-elements",
        "--num_elements",
        dest="num_elements",
        type=int,
        nargs="+",
        default=[10, 100, 1000, 10000],
        help="Total number of elements in the benchmarked arrays; "
        "accepts one or more sizes",
    )
    main(parser.parse_args())


# Script entry point: run the benchmark only when executed directly, not
# when this module is imported.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
16 changes: 15 additions & 1 deletion benchmarks/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from types import TracebackType
from typing import Any

from multi_turn.bench_utils import Color


def convert_to_pytorch_benchmark_format(
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
Expand Down Expand Up @@ -75,6 +77,15 @@ def write_to_json(filename: str, records: list) -> None:
)


def throughput_change(test_elpased_time: float, baseline_elpased_time: float) -> str:
    """
    Generates throughput change between test and baseline elapsed time
    with proper colors.

    The change is ``baseline / test * 100 - 100`` percent, formatted with two
    decimals. It is colored green when the test elapsed time is strictly
    smaller than the baseline and red otherwise, and is followed by the ANSI
    reset sequence.

    Args:
        test_elpased_time: Elapsed time of the variant under test.
        baseline_elpased_time: Elapsed time of the baseline, in the same unit.

    Returns:
        The colored percentage string, or ``"N/A"`` when the test elapsed
        time is not positive and no ratio can be computed.
    """
    if test_elpased_time <= 0:
        # Avoid a ZeroDivisionError on an empty or zero measurement.
        return "N/A"
    color = Color.GREEN if test_elpased_time < baseline_elpased_time else Color.RED
    change = baseline_elpased_time / test_elpased_time * 100 - 100
    return f"{color}{change:.2f}%\033[0m"


# Collect time and generate time metrics
#
# Example Usage:
Expand Down Expand Up @@ -104,8 +115,11 @@ def collect(self, v: int) -> None:
else:
self._max = max(self._max, v)

def avg_v(self) -> float:
    """Return the mean of the collected values as a plain float.

    Computed as ``self._sum / self.cnt / self.scale``. There is no guard
    for ``self.cnt == 0`` here; ``avg()`` performs that check before
    delegating.
    """
    # Promote to float first so the divisions below are float divisions.
    total = self._sum * 1.0
    return total / self.cnt / self.scale

def avg(self) -> float | str:
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
return self.avg_v() if self.cnt > 0 else "N/A"

def max(self) -> float | str:
return self._max / self.scale if self._max else "N/A"
Expand Down
29 changes: 26 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,6 +2318,13 @@ def _bookkeeping_sync(
if i not in invalid_req_indices_set
}

# Collect updates in the for loop and apply a batch update at the end
# to vectorize updates to tensors and numpy arrays.
start_indices = self.input_batch.num_tokens_no_spec.tolist()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to convert to list for start_indices?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran a microbenchmark internally and found that when we need to read every index, calling `tolist()` once and then reading from the resulting Python list is more efficient than reading each element directly from the numpy array.

Screenshot 2025-10-14 at 11 41 10 AM

# Indices and values to update for num_tokens and num_tokens_no_spec
num_tokens_indices_to_update: list[int] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use np_array here? It may be lighter here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't know the total number of elements to update ahead of time. Even if we did, updating the np array elements one at a time inside a for loop might not be efficient.

Please let me know if I missed anything.

num_tokens_values_to_update: list[int] = []

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
Expand All @@ -2332,23 +2339,39 @@ def _bookkeeping_sync(
if not sampled_ids:
continue

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
start_idx = start_indices[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}"
)

# TODO(Jialin): batchify the update to token_ids_cpu and is_token_ids
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx

# Collect updates to num_tokens and num_tokens_no_spec,
# which is equivalent to
# - self.input_batch.num_tokens_no_spec[req_idx] = end_idx
# - self.input_batch.num_tokens[req_idx] = end_idx
num_tokens_indices_to_update.append(req_idx)
num_tokens_values_to_update.append(end_idx)

req_id = req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)

# Apply tensor / numpy array updates in batch
if num_tokens_indices_to_update:
# Batch update num_tokens arrays
self.input_batch.num_tokens[num_tokens_indices_to_update] = (
num_tokens_values_to_update
)
self.input_batch.num_tokens_no_spec[num_tokens_indices_to_update] = (
num_tokens_values_to_update
)

return (
num_nans_in_logits,
logprobs_lists,
Expand Down