|
6 | 6 | from transformers import AutoTokenizer |
7 | 7 |
|
8 | 8 | from vllm import LLM, SamplingParams |
| 9 | +from vllm.v1.metrics.reader import Counter, Vector |
9 | 10 |
|
10 | 11 |
|
11 | 12 | def load_prompts(dataset_path, num_prompts): |
@@ -105,30 +106,33 @@ def main(): |
105 | 106 | print(f"generated text: {output.outputs[0].text}") |
106 | 107 | print("-" * 50) |
107 | 108 |
|
108 | | - if not hasattr(outputs, "metrics") or outputs.metrics is None: |
| 109 | + try: |
| 110 | + metrics = llm.get_metrics() |
| 111 | + except AssertionError: |
| 112 | + print("Metrics are not supported in the V0 engine.") |
109 | 113 | return |
110 | 114 |
|
111 | | - # calculate the average number of accepted tokens per forward pass, +1 is |
112 | | - # to account for the token from the target model that's always going to be |
113 | | - # accepted |
114 | | - acceptance_counts = [0] * (args.num_spec_tokens + 1) |
115 | | - for output in outputs: |
116 | | - for step, count in enumerate(output.metrics.spec_token_acceptance_counts): |
117 | | - acceptance_counts[step] += count |
| 115 | + num_drafts = num_accepted = 0 |
| 116 | + acceptance_counts = [0] * args.num_spec_tokens |
| 117 | + for metric in metrics: |
| 118 | + if metric.name == "vllm:spec_decode_num_drafts": |
| 119 | + assert isinstance(metric, Counter) |
| 120 | + num_drafts += metric.value |
| 121 | + elif metric.name == "vllm:spec_decode_num_accepted_tokens": |
| 122 | + assert isinstance(metric, Counter) |
| 123 | + num_accepted += metric.value |
| 124 | + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": |
| 125 | + assert isinstance(metric, Vector) |
| 126 | + for pos in range(len(metric.values)): |
| 127 | + acceptance_counts[pos] += metric.values[pos] |
118 | 128 |
|
119 | 129 | print("-" * 50) |
120 | | - print( |
121 | | - f"mean acceptance length (including bonus tokens): \ |
122 | | - {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}" |
123 | | - ) |
| 130 | + print(f"mean acceptance length: {1 + (num_accepted / num_drafts):.2f}") |
124 | 131 | print("-" * 50) |
125 | 132 |
|
126 | 133 | # print acceptance at each token position |
127 | 134 | for i in range(len(acceptance_counts)): |
128 | | - print( |
129 | | - f"acceptance at token {i}:" |
130 | | - f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}" |
131 | | - ) |
| 135 | + print(f"acceptance at token {i}:{acceptance_counts[i] / num_drafts:.2f}") |
132 | 136 |
|
133 | 137 |
|
134 | 138 | if __name__ == "__main__": |
|
0 commit comments