diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 66e2e3312337..80a5a610c8ac 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -222,6 +222,7 @@ steps:
   - pytest -v -s v1/test_serial_utils.py
   - pytest -v -s v1/test_utils.py
   - pytest -v -s v1/test_oracle.py
+  - pytest -v -s v1/test_metrics_reader.py
   # TODO: accuracy does not match, whether setting
   # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
   - pytest -v -s v1/e2e
diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 3dd9e5464641..606ce7799a88 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Vector
 
 
 def load_prompts(dataset_path, num_prompts):
@@ -105,30 +106,33 @@ def main():
         print(f"generated text: {output.outputs[0].text}")
         print("-" * 50)
 
-    if not hasattr(outputs, "metrics") or outputs.metrics is None:
+    try:
+        metrics = llm.get_metrics()
+    except AssertionError:
+        print("Metrics are not supported in the V0 engine.")
         return
 
-    # calculate the average number of accepted tokens per forward pass, +1 is
-    # to account for the token from the target model that's always going to be
-    # accepted
-    acceptance_counts = [0] * (args.num_spec_tokens + 1)
-    for output in outputs:
-        for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-            acceptance_counts[step] += count
+    num_drafts = num_accepted = 0
+    acceptance_counts = [0] * args.num_spec_tokens
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
+            assert isinstance(metric, Counter)
+            num_accepted += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                acceptance_counts[pos] += metric.values[pos]
 
     print("-" * 50)
-    print(
-        f"mean acceptance length (including bonus tokens): \
-        {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}"
-    )
+    print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}")
     print("-" * 50)
 
     # print acceptance at each token position
     for i in range(len(acceptance_counts)):
-        print(
-            f"acceptance at token {i}:"
-            f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}"
-        )
+        print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}")
 
 
 if __name__ == "__main__":
diff --git a/examples/offline_inference/metrics.py b/examples/offline_inference/metrics.py
new file mode 100644
index 000000000000..7927f758cb57
--- /dev/null
+++ b/examples/offline_inference/metrics.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+
+def main():
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
+
+    # Generate texts from the prompts.
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    print("-" * 50)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+    # Dump all metrics
+    for metric in llm.get_metrics():
+        if isinstance(metric, Gauge):
+            print(f"{metric.name} (gauge) = {metric.value}")
+        elif isinstance(metric, Counter):
+            print(f"{metric.name} (counter) = {metric.value}")
+        elif isinstance(metric, Vector):
+            print(f"{metric.name} (vector) = {metric.values}")
+        elif isinstance(metric, Histogram):
+            print(f"{metric.name} (histogram)")
+            print(f"    sum = {metric.sum}")
+            print(f"    count = {metric.count}")
+            for bucket_le, value in metric.buckets.items():
+                print(f"    {bucket_le} = {value}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py
index cefb89eb652b..e77916f95823 100644
--- a/tests/v1/engine/test_llm_engine.py
+++ b/tests/v1/engine/test_llm_engine.py
@@ -6,6 +6,7 @@
 import pytest
 
 from vllm import LLM, SamplingParams
+from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
 
 MODEL = "facebook/opt-125m"
 DTYPE = "half"
@@ -97,3 +98,67 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
         raise AssertionError(
             f"{len(completion_counts)} unique completions; expected"
             f" {n}. Repeats: {repeats}")
+
+
+def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
+    max_tokens = 100
+    # Use spec decoding to test num_accepted_tokens_per_pos
+    speculative_config = {
+        "method": "ngram",
+        "prompt_lookup_max": 5,
+        "prompt_lookup_min": 3,
+        "num_speculative_tokens": 5,
+    }
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    with vllm_runner(
+            MODEL,
+            speculative_config=speculative_config,
+            disable_log_stats=False,
+    ) as vllm_model:
+        model: LLM = vllm_model.model
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens)
+        outputs = model.generate(example_prompts, sampling_params)
+
+        n_prompts = len(example_prompts)
+        assert len(outputs) == n_prompts
+
+        total_tokens = 0
+        for out in outputs:
+            assert len(out.outputs) == 1
+            total_tokens += len(out.outputs[0].token_ids)
+        assert total_tokens == max_tokens * n_prompts
+
+        metrics = model.get_metrics()
+
+        def find_metric(name) -> list[Metric]:
+            found = []
+            for metric in metrics:
+                if metric.name == name:
+                    found.append(metric)
+            return found
+
+        num_requests_running = find_metric("vllm:num_requests_running")
+        assert len(num_requests_running) == 1
+        assert isinstance(num_requests_running[0], Gauge)
+        assert num_requests_running[0].value == .0
+
+        generation_tokens = find_metric("vllm:generation_tokens")
+        assert len(generation_tokens) == 1
+        assert isinstance(generation_tokens[0], Counter)
+        assert generation_tokens[0].value == total_tokens
+
+        request_generation_tokens = find_metric(
+            "vllm:request_generation_tokens")
+        assert len(request_generation_tokens) == 1
+        assert isinstance(request_generation_tokens[0], Histogram)
+        assert "+Inf" in request_generation_tokens[0].buckets
+        assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
+        assert request_generation_tokens[0].count == n_prompts
+        assert request_generation_tokens[0].sum == total_tokens
+
+        num_accepted_tokens_per_pos = find_metric(
+            "vllm:spec_decode_num_accepted_tokens_per_pos")
+        assert len(num_accepted_tokens_per_pos) == 1
+        assert isinstance(num_accepted_tokens_per_pos[0], Vector)
+        assert len(num_accepted_tokens_per_pos[0].values) == 5
diff --git a/tests/v1/test_metrics_reader.py b/tests/v1/test_metrics_reader.py
new file mode 100644
index 000000000000..68539c80b59c
--- /dev/null
+++ b/tests/v1/test_metrics_reader.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import prometheus_client
+import pytest
+
+from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector,
+                                    get_metrics_snapshot)
+
+
+@pytest.fixture(autouse=True)
+def test_registry(monkeypatch):
+    # Use a custom registry for tests
+    test_registry = prometheus_client.CollectorRegistry(auto_describe=True)
+    monkeypatch.setattr("vllm.v1.metrics.reader.REGISTRY", test_registry)
+    return test_registry
+
+
+@pytest.mark.parametrize("num_engines", [1, 4])
+def test_gauge_metric(test_registry, num_engines):
+    g = prometheus_client.Gauge("vllm:test_gauge",
+                                "Test gauge metric",
+                                labelnames=["model", "engine_index"],
+                                registry=test_registry)
+    for i in range(num_engines):
+        g.labels(model="foo", engine_index=str(i)).set(98.5)
+
+    metrics = get_metrics_snapshot()
+    assert len(metrics) == num_engines
+    engine_labels = [str(i) for i in range(num_engines)]
+    for m in metrics:
+        assert isinstance(m, Gauge)
+        assert m.name == "vllm:test_gauge"
+        assert m.value == 98.5
+        assert m.labels["model"] == "foo"
+        assert m.labels["engine_index"] in engine_labels
+        engine_labels.remove(m.labels["engine_index"])
+
+
+@pytest.mark.parametrize("num_engines", [1, 4])
+def test_counter_metric(test_registry, num_engines):
+    c = prometheus_client.Counter("vllm:test_counter",
+                                  "Test counter metric",
+                                  labelnames=["model", "engine_index"],
+                                  registry=test_registry)
+    for i in range(num_engines):
+        c.labels(model="bar", engine_index=str(i)).inc(19)
+
+    metrics = get_metrics_snapshot()
+    assert len(metrics) == num_engines
+    engine_labels = [str(i) for i in range(num_engines)]
+    for m in metrics:
+        assert isinstance(m, Counter)
+        assert m.name == "vllm:test_counter"
+        assert m.value == 19
+        assert m.labels["model"] == "bar"
+        assert m.labels["engine_index"] in engine_labels
+        engine_labels.remove(m.labels["engine_index"])
+
+
+@pytest.mark.parametrize("num_engines", [1, 4])
+def test_histogram_metric(test_registry, num_engines):
+    h = prometheus_client.Histogram("vllm:test_histogram",
+                                    "Test histogram metric",
+                                    labelnames=["model", "engine_index"],
+                                    buckets=[10, 20, 30, 40, 50],
+                                    registry=test_registry)
+    for i in range(num_engines):
+        hist = h.labels(model="blaa", engine_index=str(i))
+        hist.observe(42)
+        hist.observe(21)
+        hist.observe(7)
+
+    metrics = get_metrics_snapshot()
+    assert len(metrics) == num_engines
+    engine_labels = [str(i) for i in range(num_engines)]
+    for m in metrics:
+        assert isinstance(m, Histogram)
+        assert m.name == "vllm:test_histogram"
+        assert m.count == 3
+        assert m.sum == 70
+        assert m.buckets["10.0"] == 1
+        assert m.buckets["20.0"] == 1
+        assert m.buckets["30.0"] == 2
+        assert m.buckets["40.0"] == 2
+        assert m.buckets["50.0"] == 3
+        assert m.labels["model"] == "blaa"
+        assert m.labels["engine_index"] in engine_labels
+        engine_labels.remove(m.labels["engine_index"])
+
+
+@pytest.mark.parametrize("num_engines", [1, 4])
+def test_vector_metric(test_registry, num_engines):
+    c = prometheus_client.Counter(
+        "vllm:spec_decode_num_accepted_tokens_per_pos",
+        "Vector-like counter metric",
+        labelnames=["position", "model", "engine_index"],
+        registry=test_registry)
+    for i in range(num_engines):
+        c.labels(position="0", model="llama", engine_index=str(i)).inc(10)
+        c.labels(position="1", model="llama", engine_index=str(i)).inc(5)
+        c.labels(position="2", model="llama", engine_index=str(i)).inc(1)
+
+    metrics = get_metrics_snapshot()
+    assert len(metrics) == num_engines
+    engine_labels = [str(i) for i in range(num_engines)]
+    for m in metrics:
+        assert isinstance(m, Vector)
+        assert m.name == "vllm:spec_decode_num_accepted_tokens_per_pos"
+        assert m.values == [10, 5, 1]
+        assert m.labels["model"] == "llama"
+        assert m.labels["engine_index"] in engine_labels
+        engine_labels.remove(m.labels["engine_index"])
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index f818e1737975..5e7d442a74ca 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -4,7 +4,8 @@
 import warnings
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
+                    cast, overload)
 
 import cloudpickle
 import torch.nn as nn
@@ -47,6 +48,9 @@
 from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                         is_list_of)
 
+if TYPE_CHECKING:
+    from vllm.v1.metrics.reader import Metric
+
 logger = init_logger(__name__)
 
 _R = TypeVar("_R", default=Any)
@@ -1295,6 +1299,20 @@ def wake_up(self, tags: Optional[list[str]] = None):
         """
         self.llm_engine.wake_up(tags)
 
+    def get_metrics(self) -> list["Metric"]:
+        """Return a snapshot of aggregated metrics from Prometheus.
+
+        Returns:
+            A list of ``Metric`` objects capturing the current state
+            of all aggregated metrics from Prometheus.
+
+        Note:
+            This method is only available with the V1 LLM engine.
+        """
+        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+        assert isinstance(self.llm_engine, V1LLMEngine)
+        return self.llm_engine.get_metrics()
+
     # LEGACY
     def _convert_v1_inputs(
         self,
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 112896d6c767..c856e2645a2c 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -27,7 +27,10 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import StatLoggerFactory
+from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
+                                     StatLoggerFactory)
+from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
+from vllm.v1.metrics.stats import IterationStats
 
 logger = init_logger(__name__)
 
@@ -64,6 +67,11 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
+        self.log_stats = log_stats
+        self.stat_logger: Optional[StatLoggerBase] = None
+        if self.log_stats:
+            self.stat_logger = PrometheusStatLogger(vllm_config)
+
         # important: init dp group before init the engine_core
         # In the decoupled engine case this is handled in EngineCoreProc.
         parallel_config = vllm_config.parallel_config
@@ -86,7 +94,7 @@ def __init__(
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
-                                                log_stats=False)
+                                                log_stats=self.log_stats)
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
@@ -94,7 +102,7 @@ def __init__(
             asyncio_mode=False,
             vllm_config=vllm_config,
             executor_class=executor_class,
-            log_stats=False,  # FIXME: implement
+            log_stats=self.log_stats,
         )
 
         if not multiprocess_mode:
@@ -223,12 +231,21 @@ def step(self) -> list[RequestOutput]:
         outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
+        iteration_stats = IterationStats() if self.log_stats else None
         processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs.outputs,
+            engine_core_timestamp=outputs.timestamp,
+            iteration_stats=iteration_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
+        # 4) Record stats
+        if self.stat_logger is not None:
+            assert outputs.scheduler_stats is not None
+            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
+                                    iteration_stats=iteration_stats)
+
         return processed_outputs.request_outputs
 
     def get_vllm_config(self):
@@ -260,6 +277,10 @@ def wake_up(self, tags: Optional[list[str]] = None):
     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()
 
+    def get_metrics(self) -> list[Metric]:
+        assert self.log_stats, "Stat logging disabled"
+        return get_metrics_snapshot()
+
     def get_tokenizer_group(self) -> TokenizerGroup:
         if self.tokenizer is None:
             raise ValueError("Unable to get tokenizer because "
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 2b75a3a2ecbd..3dc2f77444f6 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -200,24 +200,24 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         #
         # Counters
         #
         self.counter_num_preempted_reqs = self._counter_cls(
-            name="vllm:num_preemptions_total",
+            name="vllm:num_preemptions",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_prompt_tokens = self._counter_cls(
-            name="vllm:prompt_tokens_total",
+            name="vllm:prompt_tokens",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_generation_tokens = self._counter_cls(
-            name="vllm:generation_tokens_total",
+            name="vllm:generation_tokens",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames).labels(*labelvalues)
 
         self.counter_request_success: dict[FinishReason,
                                            prometheus_client.Counter] = {}
         counter_request_success_base = self._counter_cls(
-            name="vllm:request_success_total",
+            name="vllm:request_success",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + ["finished_reason"])
         for reason in FinishReason:
diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py
new file mode 100644
index 000000000000..5ab78129a009
--- /dev/null
+++ b/vllm/v1/metrics/reader.py
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Optional
+
+from prometheus_client import REGISTRY
+from prometheus_client import Metric as PromMetric
+from prometheus_client.samples import Sample
+
+
+@dataclass
+class Metric:
+    """A base class for prometheus metrics.
+
+    Each metric may be associated with key=value labels, and
+    in some cases a single vLLM instance may have multiple
+    metrics with the same name but different sets of labels.
+    """
+    name: str
+    labels: dict[str, str]
+
+
+@dataclass
+class Counter(Metric):
+    """A monotonically increasing integer counter."""
+    value: int
+
+
+@dataclass
+class Vector(Metric):
+    """An ordered array of integer counters.
+
+    This type - which doesn't exist in Prometheus - models one very
+    specific metric, vllm:spec_decode_num_accepted_tokens_per_pos.
+    """
+    values: list[int]
+
+
+@dataclass
+class Gauge(Metric):
+    """A numerical value that can go up or down."""
+    value: float
+
+
+@dataclass
+class Histogram(Metric):
+    """Observations recorded in configurable buckets.
+
+    Buckets are represented by a dictionary. The key is
+    the upper limit of the bucket, and the value is the
+    observed count in that bucket. A '+Inf' key always
+    exists.
+
+    The count property is the total count across all
+    buckets, identical to the count of the '+Inf' bucket.
+
+    The sum property is the total sum of all observed
+    values.
+    """
+    count: int
+    sum: float
+    buckets: dict[str, int]
+
+
+def get_metrics_snapshot() -> list[Metric]:
+    """An API for accessing in-memory Prometheus metrics.
+
+    Example:
+        >>> for metric in llm.get_metrics():
+        ...     if isinstance(metric, Counter):
+        ...         print(f"{metric} = {metric.value}")
+        ...     elif isinstance(metric, Gauge):
+        ...         print(f"{metric} = {metric.value}")
+        ...     elif isinstance(metric, Histogram):
+        ...         print(f"{metric}")
+        ...         print(f"    sum = {metric.sum}")
+        ...         print(f"    count = {metric.count}")
+        ...         for bucket_le, value in metric.buckets.items():
+        ...             print(f"    {bucket_le} = {value}")
+    """
+    collected: list[Metric] = []
+    for metric in REGISTRY.collect():
+        if not metric.name.startswith("vllm:"):
+            continue
+        if metric.type == "gauge":
+            samples = _get_samples(metric)
+            for s in samples:
+                collected.append(
+                    Gauge(name=metric.name, labels=s.labels, value=s.value))
+        elif metric.type == "counter":
+            samples = _get_samples(metric, "_total")
+            if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+                #
+                # Ugly special case for the
+                # vllm:spec_decode_num_accepted_tokens_per_pos metric.
+                #
+                # This metric is a vector of counters - for each spec
+                # decoding token position, we observe the number of
+                # accepted tokens using a Counter labeled with 'position'.
+                # We convert these into a vector of integer values.
+                #
+                for labels, values in _digest_num_accepted_by_pos_samples(
+                        samples):
+                    collected.append(
+                        Vector(name=metric.name, labels=labels, values=values))
+            else:
+                for s in samples:
+                    collected.append(
+                        Counter(name=metric.name,
+                                labels=s.labels,
+                                value=int(s.value)))
+
+        elif metric.type == "histogram":
+            #
+            # A histogram has a number of '_bucket' samples where
+            # the 'le' label represents the upper limit of the bucket.
+            # We convert these bucketized values into a dict of values
+            # indexed by the value of the 'le' label. The 'le=+Inf'
+            # label is a special case, catching all values observed.
+            #
+            bucket_samples = _get_samples(metric, "_bucket")
+            count_samples = _get_samples(metric, "_count")
+            sum_samples = _get_samples(metric, "_sum")
+            for labels, buckets, count_value, sum_value in _digest_histogram(
+                    bucket_samples, count_samples, sum_samples):
+                collected.append(
+                    Histogram(name=metric.name,
+                              labels=labels,
+                              buckets=buckets,
+                              count=count_value,
+                              sum=sum_value))
+        else:
+            raise AssertionError(f"Unknown metric type {metric.type}")
+
+    return collected
+
+
+def _get_samples(metric: PromMetric,
+                 suffix: Optional[str] = None) -> list[Sample]:
+    name = (metric.name + suffix) if suffix is not None else metric.name
+    return [s for s in metric.samples if s.name == name]
+
+
+def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]:
+    labels_copy = labels.copy()
+    labels_copy.pop(key_to_remove)
+    return labels_copy
+
+
+def _digest_histogram(
+    bucket_samples: list[Sample], count_samples: list[Sample],
+    sum_samples: list[Sample]
+) -> list[tuple[dict[str, str], dict[str, int], int, float]]:
+    #
+    # In the case of DP, we have an indigestible
+    # per-bucket-per-engine count as a list of labelled
+    # samples, along with count and sum samples
+    #
+    # bucket_samples (in):
+    #   labels = {bucket: 100, idx: 0}, value = 2
+    #   labels = {bucket: 200, idx: 0}, value = 4
+    #   labels = {bucket: Inf, idx: 0}, value = 10
+    #   labels = {bucket: 100, idx: 1}, value = 1
+    #   labels = {bucket: 200, idx: 1}, value = 5
+    #   labels = {bucket: Inf, idx: 1}, value = 7
+    # count_samples (in):
+    #   labels = {idx: 0}, value = 10
+    #   labels = {idx: 1}, value = 7
+    # sum_samples (in):
+    #   labels = {idx: 0}, value = 2000
+    #   labels = {idx: 1}, value = 1200
+    #
+    # output: [
+    #   {idx: 0}, {"100": 2, "200": 4, "Inf": 10}, 10, 2000
+    #   {idx: 1}, {"100": 1, "200": 5, "Inf": 7}, 7, 1200
+    # ]
+    buckets_by_labels: dict[frozenset[tuple[str, str]], dict[str, int]] = {}
+    for s in bucket_samples:
+        bucket = s.labels["le"]
+        labels_key = frozenset(_strip_label(s.labels, "le").items())
+        if labels_key not in buckets_by_labels:
+            buckets_by_labels[labels_key] = {}
+        buckets_by_labels[labels_key][bucket] = int(s.value)
+
+    counts_by_labels: dict[frozenset[tuple[str, str]], int] = {}
+    for s in count_samples:
+        labels_key = frozenset(s.labels.items())
+        counts_by_labels[labels_key] = int(s.value)
+
+    sums_by_labels: dict[frozenset[tuple[str, str]], float] = {}
+    for s in sum_samples:
+        labels_key = frozenset(s.labels.items())
+        sums_by_labels[labels_key] = s.value
+
+    assert set(buckets_by_labels.keys()) == set(
+        counts_by_labels.keys()) == set(sums_by_labels.keys())
+
+    output = []
+    label_keys = list(buckets_by_labels.keys())
+    for k in label_keys:
+        labels = dict(k)
+        output.append((labels, buckets_by_labels[k], counts_by_labels[k],
+                       sums_by_labels[k]))
+    return output
+
+
+def _digest_num_accepted_by_pos_samples(
+        samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]:
+    #
+    # In the case of DP, we have an indigestible
+    # per-position-per-engine count as a list of
+    # labelled samples
+    #
+    # samples (in):
+    #   labels = {pos: 0, idx: 0}, value = 10
+    #   labels = {pos: 1, idx: 0}, value = 7
+    #   labels = {pos: 2, idx: 0}, value = 2
+    #   labels = {pos: 0, idx: 1}, value = 5
+    #   labels = {pos: 1, idx: 1}, value = 3
+    #   labels = {pos: 2, idx: 1}, value = 1
+    #
+    # output: [
+    #   {idx: 0}, [10, 7, 2]
+    #   {idx: 1}, [5, 3, 1]
+    # ]
+    #
+    max_pos = 0
+    values_by_labels: dict[frozenset[tuple[str, str]], dict[int, int]] = {}
+
+    for s in samples:
+        position = int(s.labels["position"])
+        max_pos = max(max_pos, position)
+
+        labels_key = frozenset(_strip_label(s.labels, "position").items())
+        if labels_key not in values_by_labels:
+            values_by_labels[labels_key] = {}
+        values_by_labels[labels_key][position] = int(s.value)
+
+    output = []
+    for labels_key, values_by_position in values_by_labels.items():
+        labels = dict(labels_key)
+        values = [0] * (max_pos + 1)
+        for pos, val in values_by_position.items():
+            values[pos] = val
+        output.append((labels, values))
+    return output
diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py
index 899aa9200e85..36091bef2895 100644
--- a/vllm/v1/spec_decode/metrics.py
+++ b/vllm/v1/spec_decode/metrics.py
@@ -134,17 +134,17 @@ def __init__(
 
         self.counter_spec_decode_num_drafts = \
             self._counter_cls(
-                name="vllm:spec_decode_num_drafts_total",
+                name="vllm:spec_decode_num_drafts",
                 documentation="Number of spec decoding drafts.",
                 labelnames=labelnames).labels(*labelvalues)
 
         self.counter_spec_decode_num_draft_tokens = \
             self._counter_cls(
-                name="vllm:spec_decode_num_draft_tokens_total",
+                name="vllm:spec_decode_num_draft_tokens",
                 documentation="Number of draft tokens.",
                 labelnames=labelnames,).labels(*labelvalues)
 
         self.counter_spec_decode_num_accepted_tokens = \
             self._counter_cls(
-                name="vllm:spec_decode_num_accepted_tokens_total",
+                name="vllm:spec_decode_num_accepted_tokens",
                 documentation="Number of accepted tokens.",
                 labelnames=labelnames).labels(*labelvalues)
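
Usage sketch (not part of the patch): the renamed spec-decode counters can be read back through the new LLM.get_metrics() API to recompute the acceptance statistics that examples/offline_inference/eagle.py prints. This requires the V1 engine with stat logging enabled; the model and speculative_config below are borrowed from tests/v1/engine/test_llm_engine.py for illustration.

# Sketch only: reading the renamed spec-decode counters via LLM.get_metrics().
from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Vector

llm = LLM(
    model="facebook/opt-125m",
    speculative_config={
        "method": "ngram",
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 3,
        "num_speculative_tokens": 5,
    },
    disable_log_stats=False,  # stat logging must stay enabled for get_metrics()
)
llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))

num_drafts = num_accepted = 0
per_position: list[int] = []
for metric in llm.get_metrics():
    if metric.name == "vllm:spec_decode_num_drafts":
        assert isinstance(metric, Counter)
        num_drafts += metric.value
    elif metric.name == "vllm:spec_decode_num_accepted_tokens":
        assert isinstance(metric, Counter)
        num_accepted += metric.value
    elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
        assert isinstance(metric, Vector)
        per_position = metric.values

if num_drafts:
    # The +1 accounts for the bonus token from the target model on every draft.
    print(f"mean acceptance length: {1 + num_accepted / num_drafts:.2f}")
    print(f"accepted tokens by position: {per_position}")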