@@ -23,9 +23,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -78,10 +76,8 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = _get_padded_number(
-            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
-        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
-                                               NUM_QUERIES_PER_BLOCK)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
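For context on the removed padding: `_get_padded_number` rounds a value up to the next multiple of a block constant, so `max_num_tokens` and `max_num_reqs` used to be inflated to multiples of `NUM_QUERIES_PER_BLOCK` before being used. A minimal sketch of that rounding, assuming the usual round-up-to-multiple implementation for the helper (the constant value below is illustrative, not the backend's real one):

```python
def _get_padded_number(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`
    # (assumed behavior of the helper named in the diff).
    return ((x + multiple - 1) // multiple) * multiple

NUM_QUERIES_PER_BLOCK = 16  # illustrative value, not the real backend constant

# Before this change: the scheduler limits were padded up front.
print(_get_padded_number(1000, NUM_QUERIES_PER_BLOCK))  # -> 1008

# After this change: scheduler_config.max_num_batched_tokens and
# scheduler_config.max_num_seqs are stored unpadded.
```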
@@ -142,16 +138,8 @@ def __init__(
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        # self.input_batch.block_table has a shape of [max_num_reqs,
-        # max_num_blocks_per_req]. To reduce the number of recompilation,
-        # we want the block_table.shape[0] to be num_tokens.
-        # To make the block_table to be compatible with the paged attention
-        # kernel, we want the block_table[1] to be multiple of
-        # NUM_KV_PAGES_PER_BLOCK.
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
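The deleted comment documented the old sizing rule: the block table's second dimension was padded to a multiple of `NUM_KV_PAGES_PER_BLOCK` for the Pallas paged attention kernel, while its first dimension was sized by `max_num_tokens` to limit recompilation. A small sketch of the shape arithmetic after this change, assuming illustrative config values and an int32 block table (in the real code the dtype comes from `self.input_batch.block_table.get_cpu_tensor().dtype`):

```python
import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the semantics of vllm.utils.cdiv.
    return -(a // -b)

# Illustrative values; the real ones come from the model/scheduler configs.
max_model_len = 2048
block_size = 16
max_num_tokens = 1024

# One row per token slot, one column per KV-cache block that the longest
# possible request can touch -- no NUM_KV_PAGES_PER_BLOCK padding anymore.
max_num_blocks_per_req = cdiv(max_model_len, block_size)  # 128
block_table_cpu = torch.zeros((max_num_tokens, max_num_blocks_per_req),
                              dtype=torch.int32,  # assumed dtype for the sketch
                              device="cpu")
print(block_table_cpu.shape)  # torch.Size([1024, 128])
```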