feat: add stop string matching #969

Merged
merged 18 commits on Aug 20, 2025
21 changes: 18 additions & 3 deletions lightllm/server/api_openai.py
@@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
prompt, individual_sampling_params, multimodal_params, request=raw_request
)

return await _collect_generation_results(generator, request, prompt_str, prompt_index)
return await _collect_generation_results(
generator, request, prompt_str, prompt_index, individual_sampling_params
)

tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]

@@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
async def _collect_generation_results(
generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
):
final_output = []
count_output_tokens = 0
finish_reason = None
@@ -516,9 +520,20 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
finish_reason = finish_status.get_finish_reason()
prompt_tokens = metadata["prompt_tokens"]

# Strip any matched stop sequence from the tail of the output
final_text = "".join(final_output)
if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
valid_stop_strings = sampling_params.stop_sequences.to_strings()
for stop_str in valid_stop_strings:
stop_index = final_text.rfind(stop_str, max(0, len(final_text) - len(stop_str) - 20), len(final_text))
if stop_index != -1:
logger.debug(f"removed stop sequence in tail: '{final_text[stop_index:]}'")
final_text = final_text[:stop_index]
break

return {
"index": prompt_index,
"text": "".join(final_output),
"text": final_text,
"finish_reason": finish_reason,
"prompt_tokens": prompt_tokens,
"completion_tokens": count_output_tokens,
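Note: a minimal, self-contained sketch of the tail-trimming behavior added above. The helper name trim_stop_sequence and the window_pad parameter are illustrative, not part of the PR, which hard-codes the 20-character pad inline in _collect_generation_results.

from typing import List

def trim_stop_sequence(text: str, stop_strings: List[str], window_pad: int = 20) -> str:
    # stop_strings is expected in descending length order (as
    # StopSequenceGroups.to_strings() returns it), so "\n\n" wins over "\n".
    for stop_str in stop_strings:
        # Only scan the last len(stop_str) + window_pad characters,
        # mirroring the bounded rfind used in _collect_generation_results.
        start = max(0, len(text) - len(stop_str) - window_pad)
        stop_index = text.rfind(stop_str, start, len(text))
        if stop_index != -1:
            return text[:stop_index]
    return text

assert trim_stop_sequence("Hello\nWorld\n\n", ["\n\n", "\n"]) == "Hello\nWorld"
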
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/io_objs/__init__.py
@@ -1 +1 @@
from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd
5 changes: 5 additions & 0 deletions lightllm/server/core/objs/io_objs/group_req.py
@@ -31,3 +31,8 @@ def to_group_req_index(self):
@dataclass
class AbortedReqCmd:
req_id: int


@dataclass
class StopStrMatchedReqCmd:
req_id: int
20 changes: 15 additions & 5 deletions lightllm/server/core/objs/req.py
@@ -32,6 +32,9 @@ def get_status(self):
def is_finished(self):
return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH

def is_stopped(self):
return self.status == self.FINISHED_STOP

def get_finish_reason(self):
if self.status == self.FINISHED_STOP:
return "stop"
@@ -74,10 +77,8 @@ class Req(ctypes.Structure):
("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计
("is_paused", ctypes.c_bool), # 标记一个Req因为显存资源管理的原因被临时暂停了。
("finish_status", FinishStatus),
# This flag is written by the http_server and read by other processes; it marks whether the request was aborted because the connection dropped.
("is_aborted", ctypes.c_bool),
# This flag is set by the router process after it reads is_aborted, marking the request as already abort-handled
# and awaiting the inference process, so the router does not repeatedly send abort commands to the inference process.
("router_aborted", ctypes.c_bool),
# When FinishStatus is a normal finished state, finish_token_index records the index of the
# token at which generation finished
("finish_token_index", ctypes.c_int),
@@ -97,6 +98,12 @@ class Req(ctypes.Structure):
("mtp_accepted_token_num", ctypes.c_int),
# mtp_step stores a constant parameter used by mtp for fast access; it is not initialized from external input
("_mtp_step", ctypes.c_int),
# stop_str_matched marks whether a stop string has been matched; written by the detokenization process and read by the router process,
# which then sends a stop command to the inference process so it stops generating output
("stop_str_matched", ctypes.c_bool),
# When the stop_str_matched condition is met, this is the index of the last generated token.
# Written by the detokenization process and read by the http_server
("stop_str_matched_token_index", ctypes.c_int),
]

def get_str(self):
@@ -124,7 +131,6 @@ def init(
self.is_paused = False
self.finish_status = FinishStatus()
self.is_aborted = False
self.router_aborted = False
self.shm_infer_released = False
self.shm_cur_kv_len = 0
self.shm_cur_output_len = 0
@@ -150,6 +156,8 @@ def init(
self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
self.mtp_accepted_token_num = 0
self._mtp_step = get_env_start_args().mtp_step
self.stop_str_matched = False
self.stop_str_matched_token_index = -1

self.post_init()

@@ -210,7 +218,9 @@ def can_release(self):
if self.is_aborted and can_released_mark and ref_count_ok:
return True

if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched

if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
return True

return False
82 changes: 54 additions & 28 deletions lightllm/server/core/objs/sampling_params.py
@@ -1,6 +1,6 @@
import os
import ctypes
from typing import List, Tuple, Union
from typing import Optional, List, Tuple, Union
from transformers import GenerationConfig
from lightllm.server.req_id_generator import MAX_BEST_OF

@@ -10,6 +10,7 @@

# Maximum length limits read from environment variables
STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
@@ -22,17 +23,30 @@ class StopSequence(ctypes.Structure):
_fields_ = [
("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
("size", ctypes.c_int),
("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
("sequence_str_len", ctypes.c_int),
]

def initialize(self, sequence: List[int]):
def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
self.size = len(sequence)
assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
assert all(isinstance(e, int) for e in sequence), "all must be int"
self.sequence[: self.size] = sequence[:]

def to_list(self):
if sequence_str is not None:
sequence_str_bytes = sequence_str.encode("utf-8")
assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
self.sequence_str = sequence_str_bytes
self.sequence_str_len = len(sequence_str_bytes)
else:
self.sequence_str_len = 0

def to_list(self) -> List[int]:
return list(self.sequence[0 : self.size])

def to_string(self) -> str:
return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")


class StopSequenceGroups(ctypes.Structure):
_pack_ = 4
@@ -41,40 +55,52 @@ class StopSequenceGroups(ctypes.Structure):
("size", ctypes.c_int),
]

def initialize(self, stop_sequences: Union[str, List], tokenizer):
def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
if stop_sequences is None:
stop_sequences = []
elif isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]

groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
self.size = len(groups)
assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
for group_idx in range(self.size):
self.groups[group_idx].initialize(groups[group_idx])

def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
if stop_sequences is None:
stop_sequences = []
else:
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]

new_stop_sequences = []
for stop_info in stop_sequences:
if isinstance(stop_info, str):
stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
if stop_str_ids is not None and len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
if isinstance(stop_info, list):
if all(isinstance(x, int) for x in stop_info):
if len(stop_info) > 0:
new_stop_sequences.append(stop_info)
stop_sequences = new_stop_sequences
return stop_sequences

def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
for group_idx in range(self.size):
if isinstance(stop_sequences[group_idx], str):
self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
else:
self.groups[group_idx].initialize(groups[group_idx])

def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
new_stop_sequences = []
for stop_info in stop_sequences:
if isinstance(stop_info, str):
stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
if stop_str_ids is not None and len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
if isinstance(stop_info, list):
if all(isinstance(x, int) for x in stop_info):
if len(stop_info) > 0:
new_stop_sequences.append(stop_info)
else:
assert False, "stop_sequences item must be type List[int] when it is a list."
return new_stop_sequences

def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
return stop_str_ids

def to_list(self):
def to_list(self) -> List[List[int]]:
return [self.groups[i].to_list() for i in range(self.size)]

def to_strings(self) -> List[str]:
# Match in descending length order: when both "\n\n" and "\n" are present, "\n\n" is matched first
return sorted(
[self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
key=len,
reverse=True,
)


class RegularConstraint(ctypes.Structure):
_pack_ = 4
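Note: a usage sketch of the new string-carrying stop sequences, assuming a HuggingFace-style tokenizer with encode(text, add_special_tokens=False); the gpt2 checkpoint is only an example and not something this PR depends on.

from transformers import AutoTokenizer
from lightllm.server.core.objs.sampling_params import StopSequenceGroups

tokenizer = AutoTokenizer.from_pretrained("gpt2")

groups = StopSequenceGroups()
groups.initialize(["\n", "User:", "\n\n"], tokenizer)

print(groups.to_list())     # one token-id list per stop sequence
print(groups.to_strings())  # ['User:', '\n\n', '\n'] -- longest strings first, so "\n\n" is tried before "\n"
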
34 changes: 33 additions & 1 deletion lightllm/server/detokenization/decode_req.py
@@ -1,6 +1,10 @@
import os
from typing import List, Dict
from lightllm.server.core.objs import Req
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


LIGHTLLM_DECODE_PREFIX_LENGTH = int(os.getenv("LIGHTLLM_DECODE_PREFIX_LENGTH", 5))

@@ -15,6 +19,7 @@ def __init__(
self.group_req_id = req.group_req_id
self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist()
self.output_ids = []
self.output_strs = []
self.prefix_offset = max(len(self.prompt_ids) - LIGHTLLM_DECODE_PREFIX_LENGTH, 0)

if is_pd_decode_mode:
@@ -26,6 +31,9 @@ def __init__(
self.req = req
self.input_len = self.req.input_len
self.prefix_str = ""
self.stop_strs: List[str] = self.req.sample_params.stop_sequences.to_strings()
# to_strings() already sorts in descending length order, so the first element is the longest string
self.stop_str_max_len = len(self.stop_strs[0]) if self.stop_strs else 0

def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], tokenizer):
tokens = [token_id_to_token[token_id] for token_id in self.req.prefix_token_ids.get_token_ids()]
@@ -35,8 +43,30 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
self.prefix_str = ""
return

def stop_sequences_str_match(self) -> bool:
stop_strs = self.stop_strs
if not stop_strs or self.stop_str_max_len == 0:
return False

tail_token_len = self.stop_str_max_len + 10 # 10 for safety
tail_token_strs = self.output_strs[-tail_token_len:]
tail_str = "".join(tail_token_strs)

for stop_str in stop_strs:
if stop_str in tail_str:
logger.debug(
f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str}', "
f"tail_str='{tail_str}'"
)
return True
return False

def need_detoken(self):
if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
if (
(not self.req.is_aborted)
and (not self.req.stop_str_matched)
and len(self.output_ids) < self.req.candetoken_out_len
):
return True
return False

@@ -55,6 +85,8 @@ def get_decode_tokens(self):
def can_set_release_mark(self):
if self.req.is_aborted:
return True
if self.req.stop_str_matched:
return True
if (
self.req.finish_status.is_finished()
and self.req.candetoken_out_len == len(self.output_ids)
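Note: a minimal sketch of the tail-window check performed by stop_sequences_str_match. The window is counted in decoded token strings, and every token string is at least one character long, so stop_str_max_len + 10 tokens always cover the longest stop string even when it is split across token boundaries. The helper name and the slack parameter are illustrative; the PR hard-codes the extra 10 tokens.

from typing import List

def tail_matches_stop_string(output_strs: List[str], stop_strs: List[str], slack: int = 10) -> bool:
    if not stop_strs:
        return False
    max_stop_len = max(len(s) for s in stop_strs)
    # Join only the most recent token strings and test each stop string
    # against that tail, as DecodeReq.stop_sequences_str_match does.
    tail = "".join(output_strs[-(max_stop_len + slack):])
    return any(stop_str in tail for stop_str in stop_strs)

# "</" and "s>" only form the stop string "</s>" once joined across token boundaries.
assert tail_matches_stop_string(["Hello", " world", "</", "s>"], ["</s>"]) is True
assert tail_matches_stop_string(["Hello", " world"], ["</s>"]) is False
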
12 changes: 12 additions & 0 deletions lightllm/server/detokenization/manager.py
@@ -105,6 +105,10 @@ def gen_token_out(self):
exist_need_detoken = False
exist_decode = False
for decode_req in self.req_id_to_out.values():
# The stop-string condition has already been met, so do not process further generated tokens
if decode_req.req.stop_str_matched:
continue

if decode_req.need_detoken() and not decode_req.out_queue_is_full():
new_token_id, src_index = decode_req.get_next_token_id_and_index()
decode_req.output_ids.append(new_token_id)
@@ -131,6 +135,14 @@ def gen_token_out(self):
logger.error(
f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
)

decode_req.output_strs.append(new_text)

# Stop-string matching
if not decode_req.req.finish_status.is_stopped() and decode_req.stop_sequences_str_match():
decode_req.req.stop_str_matched_token_index = src_index
decode_req.req.stop_str_matched = True

decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)

if decode_req.need_detoken():
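Note: an illustrative stand-in (not the real Req class) showing why decoding halts once the flag is set: the extended need_detoken() guard returns False and the early continue in gen_token_out skips the request on later iterations.

class _FakeReq:
    is_aborted = False
    stop_str_matched = False
    candetoken_out_len = 100

def need_detoken(req, output_len):
    # mirrors DecodeReq.need_detoken() from this PR
    return (not req.is_aborted) and (not req.stop_str_matched) and output_len < req.candetoken_out_len

req = _FakeReq()
assert need_detoken(req, output_len=3) is True
req.stop_str_matched = True           # set by gen_token_out() once a stop string is seen
assert need_detoken(req, output_len=3) is False   # no further tokens are detokenized
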
12 changes: 10 additions & 2 deletions lightllm/server/httpserver/manager.py
@@ -679,10 +679,18 @@ async def handle_loop(self):

req.out_tokens_queue.pop_no_ret()

if req.finish_token_index != src_index:
finished_token_index = (
req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
)

if finished_token_index != src_index:
token_list.append((req_id, text, metadata, FinishStatus()))
else:
finish_status = FinishStatus(req.finish_status.status)
if req.stop_str_matched:
finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
else:
finish_status = FinishStatus(req.finish_status.status)

token_list.append((req_id, text, metadata, finish_status))
else:
break
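Note: an end-to-end usage sketch against the OpenAI-compatible completions route served by api_openai.py. The host, port, path, and model name are placeholders following the OpenAI completion schema; with this PR a matched stop string halts generation, is stripped from the tail of the returned text, and finish_reason is reported as "stop".

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "your-model",
        "prompt": "Q: What is 2 + 2?\nA:",
        "max_tokens": 64,
        "stop": ["\n\n", "Q:"],  # stop strings handled by this PR
    },
)
choice = resp.json()["choices"][0]
print(choice["text"], choice["finish_reason"])  # trimmed text, finish_reason == "stop"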