
Commit a9ef49c

Detokenize incrementally when streaming (#653)
1 parent 21ba3a8 commit a9ef49c

File tree

5 files changed, +101 -33 lines changed


python/sglang/srt/layers/radix_attention.py

Lines changed: 30 additions & 4 deletions
@@ -136,7 +136,33 @@ def forward(self, q, k, v, input_metadata: InputMetadata):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
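
For readers unfamiliar with the try/except above: wrapping the in-place KV-cache write in torch.library.custom_op (available only in recent PyTorch releases) registers the mutation as an opaque custom operator with a no-op fake kernel, which keeps it traceable, while on older PyTorch versions the except branch falls back to a plain Python function. Below is a minimal, self-contained sketch of the same pattern; the operator name `demo::write_cache`, the tensor shapes, and the narrowed `AttributeError` are illustrative choices, not the project's code.

```python
import torch

try:
    # torch.library.custom_op exists only in recent PyTorch releases;
    # on older versions the AttributeError sends us to the fallback below.
    @torch.library.custom_op("demo::write_cache", mutates_args={"cache"})
    def write_cache(
        k: torch.Tensor, v: torch.Tensor, cache: torch.Tensor, slots: torch.Tensor
    ) -> None:
        cache[slots, 0] = k  # scatter keys into the selected cache slots
        cache[slots, 1] = v  # scatter values into the selected cache slots

    @write_cache.register_fake
    def _(k, v, cache, slots):
        # Fake (meta) kernel for tracing: the op mutates in place and returns nothing.
        pass

except AttributeError:

    def write_cache(k, v, cache, slots):
        cache[slots, 0] = k
        cache[slots, 1] = v


cache = torch.zeros(8, 2, 4)  # [num_slots, k_or_v, head_dim]
write_cache(torch.ones(2, 4), torch.full((2, 4), 2.0), cache, torch.tensor([1, 5]))
print(cache[1, 0], cache[5, 1])  # key written to slot 1, value written to slot 5
```

The diff uses a bare `except:` so that any registration failure falls back to the plain function; the sketch narrows it purely for readability.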

python/sglang/srt/managers/controller/infer_batch.py

Lines changed: 15 additions & 15 deletions
@@ -82,6 +82,14 @@ def __init__(self, rid, origin_input_text, origin_input_ids):
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- | surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -132,7 +140,7 @@ def finished(self) -> bool:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -142,13 +150,11 @@ def init_detokenize_incrementally(self):
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-        return surr_ids, read_ids, len(all_ids)
-
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -162,13 +168,7 @@ def detokenize_incrementally(self, inplace: bool = True):
        )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
@@ -501,7 +501,7 @@ def check_for_jump_forward(self, model_runner):
                 cur_output_ids = req.output_ids
 
                 req.output_ids.extend(suffix_ids)
-                decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                decode_res, new_text = req.get_next_inc_detokenization()
                 if not decode_res:
                     req.output_ids = cur_output_ids
                     continue
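
The ASCII diagram added in the first hunk is the heart of the change: `surr_offset` marks the start of a short window of already-emitted tokens that is re-decoded together with the new tokens (so tokenizers that merge pieces or prefix spaces produce stable text), and `read_offset` marks where the previously returned text ended. Here is a rough toy sketch of that sliding-window logic (not the repository's code; a stand-in decode function replaces the real tokenizer):

```python
# Toy sliding-window detokenizer mirroring surr_offset / read_offset.
# The "tokenizer" is a stand-in that just concatenates fake token pieces.
VOCAB = {0: "Hel", 1: "lo", 2: " wor", 3: "ld", 4: "!"}


def decode(ids):
    return "".join(VOCAB[i] for i in ids)


class ToyReq:
    def __init__(self, prompt_ids):
        self.all_ids = list(prompt_ids)
        self.decoded_text = ""              # committed output text
        self.surr_offset = 0                # start of the surrounding window
        self.read_offset = len(prompt_ids)  # end of already-returned text

    def append_token(self, token_id):
        self.all_ids.append(token_id)
        surr_text = decode(self.all_ids[self.surr_offset : self.read_offset])
        new_text = decode(self.all_ids[self.surr_offset :])
        if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
            # Commit the delta and slide both offsets forward.
            self.decoded_text += new_text[len(surr_text):]
            self.surr_offset = self.read_offset
            self.read_offset = len(self.all_ids)
        return self.decoded_text


req = ToyReq([0, 1])            # "prompt" decodes to "Hello"
for t in (2, 3, 4):
    print(req.append_token(t))  # -> " wor", " world", " world!"
```

Each call commits new text only when the re-decoded window does not end in the replacement character "�", mirroring the check in get_next_inc_detokenization.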

python/sglang/srt/managers/controller/tp_worker.py

Lines changed: 7 additions & 7 deletions
@@ -590,8 +590,8 @@ def forward_decode_batch(self, batch: Batch):
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -615,9 +615,9 @@ def handle_finished_requests(self, batch: Batch):
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -654,8 +654,8 @@ def handle_finished_requests(self, batch: Batch):
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
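
The worker-side change is mostly plumbing: instead of shipping pre-split surrounding and read id lists, each request now contributes its id window since `surr_offset` plus a single read offset, and the detokenizer process derives both windows itself. A tiny illustrative snippet of that split, with made-up ids (not the project's API):

```python
# Hypothetical payload for two requests: each entry is the id window that
# starts at the request's surr_offset, plus the offset of the new tail.
decode_ids = [[5, 17, 42, 9], [8, 8, 3]]
read_offsets = [2, 1]

for ids, off in zip(decode_ids, read_offsets):
    surr_ids = ids[:off]  # tokens whose text was already emitted
    new_ids = ids[off:]   # tokens whose text is still pending
    print(surr_ids, new_ids)
```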

python/sglang/srt/managers/detokenizer_manager.py

Lines changed: 47 additions & 5 deletions
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,14 @@
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,31 +45,63 @@ def __init__(
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
            read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )
 
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
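
Putting the detokenizer side together, here is a condensed, illustrative version of the stateful loop above: one DecodeStatus per request id, completed text folded into decoded_text, and an undecodable tail held back. The `decode_fn` parameter and the `rstrip` stand-in for sglang's find_printable_text helper are simplifications, not the project's API.

```python
import dataclasses
from typing import Callable, Dict, List


@dataclasses.dataclass
class DecodeStatus:
    decoded_text: str
    decode_ids: List[int]
    surr_offset: int
    read_offset: int


decode_status: Dict[str, DecodeStatus] = {}


def step(rid: str, decode_ids: List[int], read_offset: int,
         decode_fn: Callable[[List[int]], str]) -> str:
    """Handle one streamed chunk for request `rid`; return the text so far."""
    if rid not in decode_status:
        s = decode_status[rid] = DecodeStatus("", decode_ids, 0, read_offset)
    else:
        s = decode_status[rid]
        s.decode_ids = decode_ids

    surr_text = decode_fn(s.decode_ids[s.surr_offset : s.read_offset])
    read_text = decode_fn(s.decode_ids[s.surr_offset :])
    new_text = read_text[len(surr_text):]

    if new_text and not new_text.endswith("\ufffd"):
        # Complete text: commit it and slide the decode window forward.
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        new_text = ""
    else:
        # Crude stand-in for find_printable_text: hold back the undecodable tail.
        new_text = new_text.rstrip("\ufffd")
    return s.decoded_text + new_text
```

Because the manager now owns this state, the tokenizer's batch_decode runs once per batch over short id windows instead of re-decoding each request's full output on every streamed token.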

python/sglang/srt/managers/io_struct.py

Lines changed: 2 additions & 2 deletions
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
