diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index b0fcbbea2..727339992 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
             prompt, individual_sampling_params, multimodal_params, request=raw_request
         )
-        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+        return await _collect_generation_results(
+            generator, request, prompt_str, prompt_index, individual_sampling_params
+        )
 
     tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]
 
@@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
 
 
-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
     final_output = []
     count_output_tokens = 0
     finish_reason = None
@@ -516,9 +520,20 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
             finish_reason = finish_status.get_finish_reason()
             prompt_tokens = metadata["prompt_tokens"]
 
+    # Strip a matched stop sequence from the tail of the generated text
+    final_text = "".join(final_output)
+    if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
+        valid_stop_strings = sampling_params.stop_sequences.to_strings()
+        for stop_str in valid_stop_strings:
+            stop_index = final_text.rfind(stop_str, max(0, len(final_text) - len(stop_str) - 20), len(final_text))
+            if stop_index != -1:
+                logger.debug(f"removed stop sequence in tail: '{final_text[stop_index:]}'")
+                final_text = final_text[:stop_index]
+                break
+
     return {
         "index": prompt_index,
-        "text": "".join(final_output),
+        "text": final_text,
         "finish_reason": finish_reason,
         "prompt_tokens": prompt_tokens,
         "completion_tokens": count_output_tokens,
diff --git a/lightllm/server/core/objs/io_objs/__init__.py b/lightllm/server/core/objs/io_objs/__init__.py
index 80f4f0772..c9b806c47 100644
--- a/lightllm/server/core/objs/io_objs/__init__.py
+++ b/lightllm/server/core/objs/io_objs/__init__.py
@@ -1 +1 @@
-from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
+from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd
diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py
index d16dc4d06..dfcbdd256 100644
--- a/lightllm/server/core/objs/io_objs/group_req.py
+++ b/lightllm/server/core/objs/io_objs/group_req.py
@@ -31,3 +31,8 @@ def to_group_req_index(self):
 @dataclass
 class AbortedReqCmd:
     req_id: int
+
+
+@dataclass
+class StopStrMatchedReqCmd:
+    req_id: int
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 06a728925..195de4148 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -32,6 +32,9 @@ def get_status(self):
     def is_finished(self):
         return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH
 
+    def is_stopped(self):
+        return self.status == self.FINISHED_STOP
+
     def get_finish_reason(self):
         if self.status == self.FINISHED_STOP:
             return "stop"
@@ -74,10 +77,8 @@ class Req(ctypes.Structure):
         ("prompt_cache_len", ctypes.c_int),  # records the prompt cache hit length, used for statistics
         ("is_paused", ctypes.c_bool),  # marks that a Req was temporarily paused due to GPU memory resource management
         ("finish_status", FinishStatus),
+        # This flag is written by the http_server and read by other processes; it marks whether the request was aborted because the client connection dropped.
         ("is_aborted", ctypes.c_bool),
-        # After the router process reads the is_aborted flag, it marks the request here as already handled for abort
-        # and waiting on the inference process, so the router does not repeatedly send abort commands to the inference process.
-        ("router_aborted", ctypes.c_bool),
         # When FinishStatus is a normal finish state, finish_token_index marks the index
         # of the final token
         ("finish_token_index", ctypes.c_int),
@@ -97,6 +98,12 @@ class Req(ctypes.Structure):
         ("mtp_accepted_token_num", ctypes.c_int),
         # mtp_step stores a constant parameter used by mtp for fast access; it is not initialized from external input
         ("_mtp_step", ctypes.c_int),
+        # stop_str_matched marks whether a stop string has been matched; it is written by the detokenization process
+        # and read by the router, which then sends a stop command to the inference process so it stops generating
+        ("stop_str_matched", ctypes.c_bool),
+        # When stop_str_matched is set, this is the index of the last generated token.
+        # Written by the detokenization process and read by the http_server.
+        ("stop_str_matched_token_index", ctypes.c_int),
     ]
 
     def get_str(self):
@@ -124,7 +131,6 @@ def init(
         self.is_paused = False
         self.finish_status = FinishStatus()
         self.is_aborted = False
-        self.router_aborted = False
         self.shm_infer_released = False
         self.shm_cur_kv_len = 0
         self.shm_cur_output_len = 0
@@ -150,6 +156,8 @@ def init(
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
         self.mtp_accepted_token_num = 0
         self._mtp_step = get_env_start_args().mtp_step
+        self.stop_str_matched = False
+        self.stop_str_matched_token_index = -1
 
         self.post_init()
 
@@ -210,7 +218,9 @@ def can_release(self):
         if self.is_aborted and can_released_mark and ref_count_ok:
             return True
 
-        if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
+        ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched
+
+        if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
             return True
 
         return False
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 634dd208d..d1ee13cb6 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -1,6 +1,6 @@
 import os
 import ctypes
-from typing import List, Tuple, Union
+from typing import Optional, List, Tuple, Union
 from transformers import GenerationConfig
 from lightllm.server.req_id_generator import MAX_BEST_OF
 
@@ -10,6 +10,7 @@
 # Read maximum length limits from environment variables
 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
 
@@ -22,17 +23,30 @@ class StopSequence(ctypes.Structure):
     _fields_ = [
         ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
         ("size", ctypes.c_int),
+        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+        ("sequence_str_len", ctypes.c_int),
     ]
 
-    def initialize(self, sequence: List[int]):
+    def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]
 
-    def to_list(self):
+        if sequence_str is not None:
+            sequence_str_bytes = sequence_str.encode("utf-8")
+            assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+            self.sequence_str = sequence_str_bytes
+            self.sequence_str_len = len(sequence_str_bytes)
+        else:
+            self.sequence_str_len = 0
+
+    def to_list(self) -> List[int]:
         return list(self.sequence[0 : self.size])
 
+    def to_string(self) -> str:
+        return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+
 
 class StopSequenceGroups(ctypes.Structure):
     _pack_ = 4
@@ -41,40 +55,52 @@ class StopSequenceGroups(ctypes.Structure):
         ("size", ctypes.c_int),
     ]
 
-    def initialize(self, stop_sequences: Union[str, List], tokenizer):
+    def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
+        if stop_sequences is None:
+            stop_sequences = []
+        elif isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
+
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
-        for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx])
-
-    def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
-        if stop_sequences is None:
-            stop_sequences = []
-        else:
-            if isinstance(stop_sequences, str):
-                stop_sequences = [stop_sequences]
-
-            new_stop_sequences = []
-            for stop_info in stop_sequences:
-                if isinstance(stop_info, str):
-                    stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
-                    if stop_str_ids is not None and len(stop_str_ids) > 0:
-                        new_stop_sequences.append(stop_str_ids)
-                if isinstance(stop_info, list):
-                    if all(isinstance(x, int) for x in stop_info):
-                        if len(stop_info) > 0:
-                            new_stop_sequences.append(stop_info)
-            stop_sequences = new_stop_sequences
-        return stop_sequences
-
-    def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
+        for group_idx in range(self.size):
+            if isinstance(stop_sequences[group_idx], str):
+                self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
+            else:
+                self.groups[group_idx].initialize(groups[group_idx])
+
+    def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
+        new_stop_sequences = []
+        for stop_info in stop_sequences:
+            if isinstance(stop_info, str):
+                stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
+                if stop_str_ids is not None and len(stop_str_ids) > 0:
+                    new_stop_sequences.append(stop_str_ids)
+            if isinstance(stop_info, list):
+                if all(isinstance(x, int) for x in stop_info):
+                    if len(stop_info) > 0:
+                        new_stop_sequences.append(stop_info)
+                else:
+                    assert False, "stop_sequences item must be type List[int] when it is a list."
+        return new_stop_sequences
+
+    def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
         stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
         return stop_str_ids
 
-    def to_list(self):
+    def to_list(self) -> List[List[int]]:
         return [self.groups[i].to_list() for i in range(self.size)]
 
+    def to_strings(self) -> List[str]:
+        # Match in descending length order, so that when both "\n\n" and "\n" are present, "\n\n" is matched first
+        return sorted(
+            [self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
+            key=len,
+            reverse=True,
+        )
+
 
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py
index 9a85ea089..9aa3a8eff 100644
--- a/lightllm/server/detokenization/decode_req.py
+++ b/lightllm/server/detokenization/decode_req.py
@@ -1,6 +1,10 @@
 import os
 from typing import List, Dict
 from lightllm.server.core.objs import Req
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
 
 LIGHTLLM_DECODE_PREFIX_LENGTH = int(os.getenv("LIGHTLLM_DECODE_PREFIX_LENGTH", 5))
 
@@ -15,6 +19,7 @@ def __init__(
         self.group_req_id = req.group_req_id
         self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist()
         self.output_ids = []
+        self.output_strs = []
         self.prefix_offset = max(len(self.prompt_ids) - LIGHTLLM_DECODE_PREFIX_LENGTH, 0)
 
         if is_pd_decode_mode:
@@ -26,6 +31,9 @@ def __init__(
         self.req = req
         self.input_len = self.req.input_len
         self.prefix_str = ""
+        self.stop_strs: List[str] = self.req.sample_params.stop_sequences.to_strings()
+        # to_strings() already sorts in descending order, so the first element is the longest string
+        self.stop_str_max_len = len(self.stop_strs[0]) if self.stop_strs else 0
 
     def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], tokenizer):
         tokens = [token_id_to_token[token_id] for token_id in self.req.prefix_token_ids.get_token_ids()]
@@ -35,8 +43,30 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
             self.prefix_str = ""
             return
 
+    def stop_sequences_str_match(self) -> bool:
+        stop_strs = self.stop_strs
+        if not stop_strs or self.stop_str_max_len == 0:
+            return False
+
+        tail_token_len = self.stop_str_max_len + 10  # 10 for safety
+        tail_token_strs = self.output_strs[-tail_token_len:]
+        tail_str = "".join(tail_token_strs)
+
+        for stop_str in stop_strs:
+            if stop_str in tail_str:
+                logger.debug(
+                    f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str}', "
+                    f"tail_str='{tail_str}'"
+                )
+                return True
+        return False
+
     def need_detoken(self):
-        if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
+        if (
+            (not self.req.is_aborted)
+            and (not self.req.stop_str_matched)
+            and len(self.output_ids) < self.req.candetoken_out_len
+        ):
             return True
         return False
 
@@ -55,6 +85,8 @@ def get_decode_tokens(self):
     def can_set_release_mark(self):
         if self.req.is_aborted:
             return True
+        if self.req.stop_str_matched:
+            return True
         if (
             self.req.finish_status.is_finished()
             and self.req.candetoken_out_len == len(self.output_ids)
diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py
index f57eae333..57922ff62 100644
--- a/lightllm/server/detokenization/manager.py
+++ b/lightllm/server/detokenization/manager.py
@@ -105,6 +105,10 @@ def gen_token_out(self):
             exist_need_detoken = False
             exist_decode = False
             for decode_req in self.req_id_to_out.values():
+                # The stop-string condition has already been met, so skip any further generated tokens
+                if decode_req.req.stop_str_matched:
+                    continue
+
                 if decode_req.need_detoken() and not decode_req.out_queue_is_full():
                     new_token_id, src_index = decode_req.get_next_token_id_and_index()
                     decode_req.output_ids.append(new_token_id)
@@ -131,6 +135,14 @@ def gen_token_out(self):
                         logger.error(
                             f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
                         )
+
+                    decode_req.output_strs.append(new_text)
+
+                    # stop-string matching
+                    if not decode_req.req.finish_status.is_stopped() and decode_req.stop_sequences_str_match():
+                        decode_req.req.stop_str_matched_token_index = src_index
+                        decode_req.req.stop_str_matched = True
+
                     decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)
 
                     if decode_req.need_detoken():
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index de552d80c..b99ba6cce 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -679,10 +679,18 @@ async def handle_loop(self):
                             req.out_tokens_queue.pop_no_ret()
 
-                            if req.finish_token_index != src_index:
+                            finished_token_index = (
+                                req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
+                            )
+
+                            if finished_token_index != src_index:
                                 token_list.append((req_id, text, metadata, FinishStatus()))
                             else:
-                                finish_status = FinishStatus(req.finish_status.status)
+                                if req.stop_str_matched:
+                                    finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
+                                else:
+                                    finish_status = FinishStatus(req.finish_status.status)
+
                                 token_list.append((req_id, text, metadata, finish_status))
                         else:
                             break
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 24b8a9ddb..44820cbdc 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -15,7 +15,7 @@ from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
+from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd, StopStrMatchedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
@@ -279,6 +279,9 @@ async def _step(self):
         aborted_reqs = self._get_aborted_reqs_from_running_batch()
         if aborted_reqs:
             await self._aborted_reqs(aborted_reqs=aborted_reqs)
+        stop_str_matched_reqs = self._get_stop_str_reqs_from_running_batch()
+        if stop_str_matched_reqs:
+            await self._stop_str_matched_reqs(stop_str_matched_reqs=stop_str_matched_reqs)
         return
 
     async def _add_batch(self, batch: Batch):
@@ -301,6 +304,15 @@ async def _aborted_reqs(self, aborted_reqs: List[Req]):
         self.shm_reqs_io_buffer.set_ready()
         return
 
+    async def _stop_str_matched_reqs(self, stop_str_matched_reqs: List[Req]):
+        cmds = [StopStrMatchedReqCmd(req_id=r.request_id) for r in stop_str_matched_reqs]
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj(cmds)
+        self.shm_reqs_io_buffer.set_ready()
+        return
+
     def _add_new_batch_to_running_batch(self, new_batch: Batch):
         if self.running_batch is None:
             self.running_batch = new_batch
@@ -320,8 +332,22 @@ def _get_aborted_reqs_from_running_batch(self) -> List[Req]:
         if self.running_batch is None:
             return ans
         for req in self.running_batch.reqs:
-            if req.is_aborted and req.router_aborted is False:
-                req.router_aborted = True
+            if req.is_aborted and req._router_aborted is False:
+                req._router_aborted = True
                 ans.append(req)
         return ans
+
+    def _get_stop_str_reqs_from_running_batch(self) -> List[Req]:
+        # TODO: multi-node tp mode does not yet support early exit on stop-string match
+        if self.is_multinode_tp:
+            return []
+
+        ans = []
+        if self.running_batch is None:
+            return ans
+        for req in self.running_batch.reqs:
+            if req.stop_str_matched and req._router_stop_str_matched is False:
+                req._router_stop_str_matched = True
+                ans.append(req)
+        return ans
 
@@ -361,6 +387,11 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
             req = self.shm_req_manager.get_req_obj_by_index(req_index)
             req.multimodal_params = group_req_indexes.multimodal_params
             req.start_time = group_req_indexes.time_mark
+            # Attach a private flag that marks whether the router has already sent an abort command
+            # for this request to the inference process, to avoid sending the abort command repeatedly
+            req._router_aborted = False
+            # Serves a similar purpose to _router_aborted
+            req._router_stop_str_matched = False
             req_group.append(req)
             logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
 
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 67d69aa38..01ae6c9c5 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -392,9 +392,6 @@ def update_finish_status(self, eos_ids, output_len: int):
             self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
         return
 
-    def is_finished_or_aborted(self):
-        return self.finish_status.is_finished() or self.shm_req.router_aborted
-
     def _stop_sequences_matched(self, output_len: int):
         for stop_token_ids in self.stop_sequences:
             stop_len = len(stop_token_ids)
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index d7341bed5..928709888 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -19,7 +19,7 @@ from lightllm.utils.dist_utils import init_distributed_env
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.server.core.objs import ShmReqManager, StartArgs
-from lightllm.server.core.objs.io_objs import AbortedReqCmd
+from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
 from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
@@ -319,6 +319,12 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+        elif isinstance(cmds[0], StopStrMatchedReqCmd):
+            for obj in cmds:
+                obj: StopStrMatchedReqCmd = obj
+                if obj.req_id in g_infer_context.requests_mapping:
+                    req: InferReq = g_infer_context.requests_mapping[obj.req_id]
+                    req.infer_aborted = True
         else:
            self._init_reqs(reqs=cmds)
         return
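
For reference, the snippet below is a minimal standalone sketch (not part of the patch) of the tail-window matching and trimming that the diff implements in `DecodeReq.stop_sequences_str_match` and `_collect_generation_results`; the function names and sample inputs here are illustrative only.

```python
from typing import List, Optional


def match_stop_string(output_pieces: List[str], stop_strs: List[str]) -> Optional[str]:
    """Return the first stop string found in the tail of the generated text, if any."""
    if not stop_strs:
        return None
    # Longest-first ordering so "\n\n" wins over "\n" when both are configured.
    stop_strs = sorted(stop_strs, key=len, reverse=True)
    max_len = len(stop_strs[0])
    # Only the last few decoded pieces can contain the stop string (+10 pieces for safety).
    tail = "".join(output_pieces[-(max_len + 10):])
    for stop_str in stop_strs:
        if stop_str in tail:
            return stop_str
    return None


def trim_stop_string(text: str, stop_strs: List[str]) -> str:
    """Cut the text at the last occurrence of any stop string near the end."""
    for stop_str in sorted(stop_strs, key=len, reverse=True):
        idx = text.rfind(stop_str, max(0, len(text) - len(stop_str) - 20))
        if idx != -1:
            return text[:idx]
    return text


if __name__ == "__main__":
    pieces = ["Hello", " world", "!", "\n", "\n"]
    print(repr(match_stop_string(pieces, ["\n", "\n\n"])))      # '\n\n'
    print(repr(trim_stop_string("".join(pieces), ["\n\n"])))    # 'Hello world!'
```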