11"""DetokenizerManager is a process that detokenizes the token ids."""
22
33import asyncio
4+ import dataclasses
45import inspect
6+ from typing import List
57
68import uvloop
79import zmq
1618asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
1719
1820
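+# Per-request incremental detokenization state:
+#   decoded_text: text emitted so far; decode_ids: all output token ids;
+#   surr_offset: start of the surrogate (context) window;
+#   read_offset: index of the first token not yet folded into decoded_text.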
+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,31 +45,71 @@ def __init__(
             trust_remote_code=server_args.trust_remote_code,
         )

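+        # Map: request id (rid) -> DecodeStatus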
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
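+            # decode_ids[surr_offset:read_offset] is the surrogate prefix that
+            # anchors decoding; decode_ids[surr_offset:] is the read window.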
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
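+            # The two windows are batch-decoded below; slicing off the decoded
+            # surrogate prefix leaves only the new text. The prefix is needed
+            # because a token's text can depend on the tokens before it.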

             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                    output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
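+                    # A trailing "�" (U+FFFD) marks an incomplete byte sequence,
+                    # so the chunk is held back until more tokens arrive.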
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
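
The added `DecodeStatus` logic generalizes beyond this manager. Below is a minimal, self-contained sketch of the same surrogate/read-window trick, assuming a Hugging Face tokenizer that renders incomplete byte sequences as "�" (as byte-level BPE tokenizers like GPT-2's do); the helper name `stream_decode` and the "gpt2" checkpoint are illustrative choices, not part of this commit.

```python
from dataclasses import dataclass, field
from typing import List

from transformers import AutoTokenizer


@dataclass
class DecodeStatus:
    decoded_text: str = ""
    decode_ids: List[int] = field(default_factory=list)
    surr_offset: int = 0
    read_offset: int = 0


def stream_decode(tokenizer, s: DecodeStatus, token_id: int) -> str:
    """Feed one token id; return any text that is now safe to emit."""
    s.decode_ids.append(token_id)
    # Decode the surrogate prefix and the full read window, then take the
    # suffix: a token's text can depend on the tokens before it.
    surr_text = tokenizer.decode(s.decode_ids[s.surr_offset : s.read_offset])
    read_text = tokenizer.decode(s.decode_ids[s.surr_offset :])
    new_text = read_text[len(surr_text) :]
    if new_text and not new_text.endswith("�"):
        # Complete characters only: emit them and advance both offsets.
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        return new_text
    return ""  # Incomplete byte sequence at the tail; wait for more tokens.


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    s = DecodeStatus()
    out = "".join(stream_decode(tok, s, t) for t in tok.encode("naïve café"))
    assert out == s.decoded_text
    print(out)
```

Unlike this single-request sketch, the manager batches the two `batch_decode` calls across all requests in the incoming `BatchTokenIDOut`, which is why the per-request windows are first collected into `read_ids` and `surr_ids`.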