Commit c663f31
Fix critical NWOR bugs causing 0% acceptance rate
Bug #1 (CRITICAL): Add missing begin() and stage() methods to KVWriteRouter
- Flash attention backend calls router.begin() and router.stage()
- KVWriteRouter only had write() and commit() methods
- Added begin() to store slot_mapping and initialize shadow buffer
- Added stage() to extract per-timestep slot and stage KV pairs
- Without these, no tokens were being staged → 0% acceptance rate

Bug #2 (MODERATE): Fix bonus token counting in accepted_lens
- valid_sampled_token_ids includes [accepted_draft_tokens..., bonus_token]
- Previous: len([bonus]) = 1, incorrectly counted as 1 accepted draft token
- Fixed: Use max(0, len(seq) - 1) to exclude bonus token from count
- Now correctly reports 0 accepted when only bonus token is present

Files modified:
- vllm/v1/kv_cache/write_router.py: Added begin() and stage() methods
- vllm/v1/worker/gpu_model_runner.py: Fixed accepted_lens calculation
1 parent a3c136b commit c663f31
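
The Bug #2 counting fix is easiest to see with concrete values. A minimal illustration (token ids invented; only the per-sequence layout follows the commit message):

import torch

# Each inner list is [accepted_draft_tokens..., bonus_token] per sequence.
valid_sampled_token_ids = [
    [11, 12, 13, 99],  # 3 draft tokens accepted + bonus -> 3
    [42],              # only the bonus token            -> 0 (previously miscounted as 1)
    [],                # nothing sampled                 -> 0 (max() clamps the -1)
]
accepted_lens = torch.tensor(
    [max(0, len(seq) - 1) for seq in valid_sampled_token_ids],
    dtype=torch.int32,
)
print(accepted_lens)  # tensor([3, 0, 0], dtype=torch.int32)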

2 files changed: +54 −2 lines

vllm/v1/kv_cache/write_router.py

Lines changed: 48 additions & 0 deletions
@@ -100,6 +100,7 @@ def __init__(self, persistent_writer: PersistentKVWriter):
         self._persistent = persistent_writer
         self._shadow = None
         self._mode = "immediate"  # or "defer"
+        self._slot_mapping = None  # Stored during begin() for use in stage()

     def immediate(self):
         """Switch to immediate write mode (normal operation)."""
@@ -116,6 +117,53 @@ def defer(self, shadow):
         self._mode = "defer"
         self._shadow = shadow

+    @torch.no_grad()
+    def begin(self, length_hint: int, slot_mapping: torch.Tensor, seg_lens: Optional[torch.Tensor] = None):
+        """
+        Begin staging for a verification window.
+        Called by flash_attn backend before staging tokens.
+
+        Args:
+            length_hint: Expected number of tokens to stage
+            slot_mapping: Slot mapping tensor for all tokens in this window
+            seg_lens: Segment lengths (optional, for context)
+        """
+        if self._mode == "defer" and self._shadow is not None:
+            # Store slot_mapping for use in stage() calls
+            self._slot_mapping = slot_mapping
+            # Initialize shadow buffer for this verification window
+            self._shadow.begin(length_hint)
+
+    @torch.no_grad()
+    def stage(self, layer_idx: int, t: int, k_slice: torch.Tensor, v_slice: torch.Tensor):
+        """
+        Stage a single timestep's KV during verification.
+        Called by flash_attn backend for each token being verified.
+
+        Args:
+            layer_idx: Transformer layer index
+            t: Position in the staging buffer (0-indexed)
+            k_slice: Key tensor [1, H, D]
+            v_slice: Value tensor [1, H, D]
+        """
+        if self._mode == "defer" and self._shadow is not None:
+            # Extract slot mapping for this specific timestep
+            if self._slot_mapping is not None:
+                slot_t = self._slot_mapping[t:t+1]
+            else:
+                # Fallback: create a dummy slot mapping
+                slot_t = torch.tensor([t], dtype=torch.int64, device=k_slice.device)
+
+            # Stage in shadow buffer
+            self._shadow.stage(layer_idx, t, k_slice, v_slice, slot_t)
+        elif self._mode == "immediate":
+            # In immediate mode, write directly to persistent cache
+            if self._slot_mapping is not None:
+                slot_t = self._slot_mapping[t:t+1]
+            else:
+                slot_t = torch.tensor([t], dtype=torch.int64, device=k_slice.device)
+            self._persistent.append_slice(layer_idx, k_slice, v_slice, slot_t)
+
     @torch.no_grad()
     def write(self,
               layer_idx: int,
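
The new methods assume a shadow buffer exposing begin(length_hint) and stage(layer_idx, t, k, v, slot); that class is not part of this diff. A minimal sketch of the per-window lifecycle, with ShadowKVBuffer as a hypothetical stand-in and invented tensor shapes:

import torch

class ShadowKVBuffer:
    """Hypothetical stand-in for the shadow buffer the router defers to;
    it only mirrors the begin()/stage() calls made by KVWriteRouter above."""

    def __init__(self):
        self.entries = []      # staged (layer_idx, t, k, v, slot) tuples
        self.length_hint = 0

    def begin(self, length_hint: int) -> None:
        # Reset staging state for a new verification window.
        self.entries.clear()
        self.length_hint = length_hint

    def stage(self, layer_idx, t, k_slice, v_slice, slot_t) -> None:
        # Record one timestep's KV for one layer; nothing is persisted
        # until the window is committed.
        self.entries.append((layer_idx, t, k_slice, v_slice, slot_t))

# Per-window lifecycle as the attention backend drives it (one layer,
# 4 draft tokens, H=8 heads, D=64 head dim -- all invented values):
shadow = ShadowKVBuffer()
slot_mapping = torch.arange(4, dtype=torch.int64)
shadow.begin(length_hint=4)
for t in range(4):
    k = torch.randn(1, 8, 64)
    v = torch.randn(1, 8, 64)
    shadow.stage(layer_idx=0, t=t, k_slice=k, v_slice=v, slot_t=slot_mapping[t:t + 1])
print(len(shadow.entries))  # 4 staged timesteps awaiting commit()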

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 2 deletions
@@ -2353,10 +2353,14 @@ def propose_draft_token_ids(sampled_token_ids):
         # NWOR: Commit accepted tokens and disarm router
         if self._router_token is not None:
             if isinstance(valid_sampled_token_ids, list):
-                accepted_lens = torch.tensor([len(seq) for seq in valid_sampled_token_ids],
+                # Compute actual draft tokens accepted (exclude bonus token)
+                # valid_sampled_token_ids includes: [accepted_draft_tokens..., bonus_token]
+                # So len(seq) - 1 gives the number of accepted draft tokens
+                accepted_lens = torch.tensor([max(0, len(seq) - 1) for seq in valid_sampled_token_ids],
                                              dtype=torch.int32, device=self.device)
             else:
-                accepted_lens = torch.tensor([1] * len(valid_sampled_token_ids),
+                # Non-list case: assume no draft tokens accepted (only bonus)
+                accepted_lens = torch.tensor([0] * len(valid_sampled_token_ids),
                                              dtype=torch.int32, device=self.device)

         self.drafter.kv_router.commit(accepted_lens)
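
commit() itself is not shown in this diff; a plausible reading of what it does with accepted_lens, sketched for a single sequence. FakePersistentWriter and commit_window are illustrative stand-ins (though append_slice matches the call the router makes above):

import torch

class FakePersistentWriter:
    """Illustrative stand-in for PersistentKVWriter; records writes."""

    def __init__(self):
        self.written = []

    def append_slice(self, layer_idx, k, v, slot) -> None:
        self.written.append((layer_idx, int(slot[0])))

def commit_window(staged, persistent, accepted_len: int) -> None:
    # Persist only timesteps below accepted_len (the accepted draft
    # tokens); discard the rest and reset for the next window.
    for layer_idx, t, k, v, slot in staged:
        if t < accepted_len:
            persistent.append_slice(layer_idx, k, v, slot)
    staged.clear()

persistent = FakePersistentWriter()
staged = [(0, t, torch.zeros(1, 8, 64), torch.zeros(1, 8, 64),
           torch.tensor([t], dtype=torch.int64)) for t in range(4)]
commit_window(staged, persistent, accepted_len=2)
print(persistent.written)  # [(0, 0), (0, 1)] -> only accepted timesteps persisted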
