Dp balancer #991

Open · wants to merge 7 commits into base: main
9 changes: 8 additions & 1 deletion lightllm/server/api_cli.py
@@ -112,7 +112,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="tool call parser type",
)
    parser.add_argument(
-        "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
+        "--running_max_req_size", type=int, default=2048, help="the max size for forward requests in the same time"
    )
parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes")
parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node")
@@ -137,6 +137,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
do not set it and keep the default value as 1.""",
)
+    parser.add_argument(
+        "--dp_balancer",
+        type=str,
+        default="bs_balancer",
+        choices=["round_robin", "bs_balancer"],
+        help="the dp balancer type, default is bs_balancer",
+    )
parser.add_argument(
"--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
)
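For reference, a minimal sketch of how the new flag would be consumed, assuming the remaining arguments of `make_argument_parser` keep their defaults:

```python
from lightllm.server.api_cli import make_argument_parser

# Parse only the balancer flag; everything else falls back to its default.
args = make_argument_parser().parse_args(["--dp_balancer", "round_robin"])
assert args.dp_balancer == "round_robin"
assert args.running_max_req_size == 2048  # new default from this PR
```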
9 changes: 9 additions & 0 deletions lightllm/server/router/batch.py
@@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int):
req_list.append(req)
return req_list

+    def get_all_dp_req_num(self) -> List[int]:
+        if self.dp_size_in_node == 1:
+            return [len(self.reqs)]
+
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        for req in self.reqs:
+            all_dp_req_num[req.sample_params.suggested_dp_index] += 1
+        return all_dp_req_num

def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
unfinished_req_ids = []
for req in self.reqs:
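The counting above reduces to a per-rank histogram. A self-contained sketch, with `SimpleNamespace` standing in for real request objects:

```python
from types import SimpleNamespace

def get_all_dp_req_num(reqs, dp_size_in_node):
    # One counter per dp rank; each request increments its assigned rank.
    if dp_size_in_node == 1:
        return [len(reqs)]
    counts = [0] * dp_size_in_node
    for req in reqs:
        counts[req.sample_params.suggested_dp_index] += 1
    return counts

reqs = [SimpleNamespace(sample_params=SimpleNamespace(suggested_dp_index=i % 2))
        for i in range(5)]
print(get_all_dp_req_num(reqs, 2))  # [3, 2]
```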
10 changes: 2 additions & 8 deletions lightllm/server/router/manager.py
@@ -197,10 +197,6 @@ async def wait_to_model_ready(self):
return

    def _get_schedule_time_interval(self):
-        if self.running_batch is None:
-            # with no running batch, trigger request scheduling every 10 ms
-            return 0.01
-
        # in dp mode, a longer schedule interval gives better balancing, since more requests can be collected per pass
        return self.schedule_time_interval

@@ -370,9 +366,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):

    def _generate_new_batch(self):
        # scheduling must account for the currently running batch and for requests that were scheduled but are not yet running inference.
-        new_batch = self.req_queue.generate_new_batch(
-            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
-        )
+        new_batch = self.req_queue.generate_new_batch(self.schedule_new_batch)
        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
return

Expand Down Expand Up @@ -469,7 +463,7 @@ async def _recv_new_reqs_and_schedule(self):
if self.is_multinode_tp:
self._multinode_tp_generate_new_batch()
else:
-            if self._get_paused_req_num() == 0:
+            if self._get_paused_req_num() == 0 and self.shm_reqs_io_buffer.is_empty():
self._generate_new_batch()
return

@@ -75,6 +75,7 @@ def init_model(self, kvargs):
self.chunked_prefill_size = self.args.chunked_prefill_size
self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs
self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
+        self.batch_max_tokens = self.args.batch_max_tokens
self.eos_id: List[int] = kvargs.get("eos_id", [2])
self.disable_cudagraph = self.args.disable_cudagraph
self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1
@@ -391,6 +392,7 @@ def _get_classed_reqs(
        # ... requests; the logic is not suitable for them.
pause_max_req_num = 2
wait_pause_count = 0
+        prefill_tokens = 0

        # because the count information of the radix cache and mem_manager is used here,
        # a lock is required for protection.
@@ -439,6 +441,11 @@
wait_pause_count += 1
else:
token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
+                    if prefill_tokens + token_num > self.batch_max_tokens:
+                        # skip and wait for the next prefill round, to avoid OOM
+                        prefill_tokens = 0
+                        break
+                    prefill_tokens += token_num
if token_num <= can_alloc_token_num:
prefill_reqs.append(req_obj)
can_alloc_token_num -= token_num
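The added guard caps how many prefill tokens a single scheduling pass may admit. A minimal standalone sketch of the same budget loop, with made-up token counts in place of real request objects:

```python
def select_prefill_reqs(token_needs, batch_max_tokens, can_alloc_token_num):
    # Admit prefill requests until the per-pass token budget would overflow;
    # the remainder waits for the next pass, avoiding OOM from oversized batches.
    prefill_tokens = 0
    admitted = []
    for token_num in token_needs:
        if prefill_tokens + token_num > batch_max_tokens:
            break
        if token_num <= can_alloc_token_num:
            prefill_tokens += token_num
            can_alloc_token_num -= token_num
            admitted.append(token_num)
    return admitted

# Budget 8192: the first two 4096-token prefills fit, the 8192-token one waits.
print(select_prefill_reqs([4096, 4096, 8192], 8192, 100000))  # [4096, 4096]
```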
1 change: 1 addition & 0 deletions lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -69,6 +69,7 @@ def generate_new_batch(self, current_batch: Batch):
new_batch_first_router_need_tokens = (
0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index]
)
print(f"new_batch_first_router_need_tokens: {new_batch_first_router_need_tokens}")

self._init_cache_list(current_batch, is_busy)
can_run_list = []
13 changes: 13 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/__init__.py
@@ -0,0 +1,13 @@
from .dp_base_balancer import RoundRobinDpBalancer
from typing import List
from lightllm.server.router.req_queue.base_queue import BaseQueue
from .dp_bs_balancer import DpBsBalancer


def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
if args.dp_balancer == "round_robin":
return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
elif args.dp_balancer == "bs_balancer":
return DpBsBalancer(dp_size_in_node, inner_queues)
else:
raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
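A quick smoke-test sketch of the factory above; the `Namespace` stands in for parsed CLI args, and `None` placeholders stand in for the `BaseQueue` instances the server would pass (the constructors only store them):

```python
from argparse import Namespace
from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer

args = Namespace(dp_balancer="bs_balancer")
balancer = get_dp_balancer(args, dp_size_in_node=2, inner_queues=[None, None])
print(type(balancer).__name__)  # DpBsBalancer
```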
65 changes: 65 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py
@@ -0,0 +1,65 @@
import random
from abc import ABC, abstractmethod
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class DpBalancer(ABC):
"""
DP负载均衡器基类
定义了负载均衡策略的接口,子类可以实现不同的负载均衡算法
"""

def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
self.dp_size_in_node = dp_size_in_node
self.inner_queues = inner_queues
self.pre_select_dp_index = self.dp_size_in_node - 1

@abstractmethod
def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
pass


class RoundRobinDpBalancer(DpBalancer):
"""
轮询负载均衡器
在队列长度最小的DP中进行轮询选择
"""

def get_suggest_dp_index(
self,
) -> int:
min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
select_dp_indexes = [
i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
]

        # if there is no candidate index, pick one at random
if not select_dp_indexes:
self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
return self.pre_select_dp_index

        # round-robin selection
for i in range(self.dp_size_in_node):
next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
if next_dp_index in select_dp_indexes:
self.pre_select_dp_index = next_dp_index
return self.pre_select_dp_index

self.pre_select_dp_index = random.choice(select_dp_indexes)
return self.pre_select_dp_index

def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
for req_group in reqs_waiting_for_dp_index:
suggested_dp_index = self.get_suggest_dp_index()
if not isinstance(req_group, list):
req_group = [req_group]
for req in req_group:
req.sample_params.suggested_dp_index = suggested_dp_index
self.inner_queues[suggested_dp_index].append(req)
reqs_waiting_for_dp_index.clear()
return
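To make the selection rule concrete, here is a small standalone rendition of `get_suggest_dp_index`, with plain integer lists standing in for the waiting queues (illustrative only):

```python
import random

def suggest_dp_index(queue_lengths, pre_select):
    # Candidates are the ranks whose waiting queue is shortest; among them,
    # continue round-robin from the previously chosen rank.
    min_length = min(queue_lengths)
    candidates = [i for i, n in enumerate(queue_lengths) if n == min_length]
    for i in range(len(queue_lengths)):
        nxt = (pre_select + i + 1) % len(queue_lengths)
        if nxt in candidates:
            return nxt
    # unreachable when candidates is non-empty; kept as a fallback, as in the original
    return random.choice(candidates)

print(suggest_dp_index([2, 0, 0, 2], pre_select=1))  # 2: next shortest rank after rank 1
```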
63 changes: 63 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py
@@ -0,0 +1,63 @@
import random
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger
from .dp_base_balancer import DpBalancer

logger = init_logger(__name__)


class DpBsBalancer(DpBalancer):
"""
This balancer is main to balance the batch size of each dp rank.
Because, for dp mode, if it exists a dp rank without any request, it will
padding a request and cause the waste of GPU compute resource.
"""

def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
super().__init__(dp_size_in_node, inner_queues)

def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
if len(reqs_waiting_for_dp_index) == 0:
return
# calculate the total load of each dp rank
if current_batch is not None:
all_dp_req_num = current_batch.get_all_dp_req_num()
total_load_per_dp = [
all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
]
else:
total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)]
for req_group in reqs_waiting_for_dp_index:
# calculate the length of this request group
if isinstance(req_group, list):
req_length = len(req_group)
else:
req_length = 1

# find the dp rank with minimum load
min_load = min(total_load_per_dp)
select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]

# select the dp rank with the minimum load
if len(select_dp_indexes) == 1:
suggested_dp_index = select_dp_indexes[0]
            else:
                # if multiple dp ranks share the minimum load, pick one at random
                suggested_dp_index = random.choice(select_dp_indexes)

# assign the request to the dp rank and update the load count
if not isinstance(req_group, list):
req_group = [req_group]

for req in req_group:
req.sample_params.suggested_dp_index = suggested_dp_index
self.inner_queues[suggested_dp_index].append(req)

# update the load count for this dp rank
total_load_per_dp[suggested_dp_index] += req_length

reqs_waiting_for_dp_index.clear()
return
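A worked trace of this policy under assumed loads: three dp ranks start at `[3, 1, 1]` running-plus-waiting requests, and three groups of sizes 2, 1, 1 arrive. Whichever tie-break the RNG takes, the loads end balanced:

```python
import random

def assign(loads, group_sizes):
    # loads: per-rank request counts (running batch + waiting queue)
    assignments = []
    for size in group_sizes:
        min_load = min(loads)
        candidates = [i for i, load in enumerate(loads) if load == min_load]
        idx = candidates[0] if len(candidates) == 1 else random.choice(candidates)
        loads[idx] += size  # update immediately so later groups see the new load
        assignments.append(idx)
    return assignments

print(assign([3, 1, 1], [2, 1, 1]))  # [1, 2, 2] or [2, 1, 1]; loads end at [3, 3, 3]
```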
61 changes: 21 additions & 40 deletions lightllm/server/router/req_queue/dp_base_queue.py
@@ -2,6 +2,7 @@
from typing import List
from ..batch import Batch, Req
from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer
from lightllm.common.basemodel.infer_lock import g_router_lock
from lightllm.utils.log_utils import init_logger

@@ -12,14 +13,14 @@ class DpQueue:
def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
self.dp_size_in_node = dp_size_in_node
self.base_queue_class = base_queue_class
-        self.pre_select_dp_index = self.dp_size_in_node - 1
from lightllm.server.router.manager import RouterManager

self.router: RouterManager = router
self.inner_queues: List[BaseQueue] = [
base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
]

+        self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
+        self.reqs_waiting_for_dp_index = []
return

def get_dp_queue(self, dp_index: int):
@@ -31,10 +32,16 @@ def get_wait_req_num(self):

# @calculate_time(show=True, min_cost_ms=10)
def generate_new_batch(self, current_batch: Batch):
-        batches = [
-            self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
-        ]
-        return self._merge_batch(batches)
+        try:
+            self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index)
+            batches = [
+                self.inner_queues[dp_index].generate_new_batch(current_batch)
+                for dp_index in range(self.dp_size_in_node)
+            ]
+            return self._merge_batch(batches)
+        except Exception as e:
+            logger.error(f"generate new batch failed: {e}")
+            raise e
Review comment on lines +43 to +44 (medium):

When an exception occurs during batch generation, using `raise e` can obscure the original stack trace. Using a bare `raise` will preserve the original traceback, making debugging easier.

Suggested change:
-            logger.error(f"generate new batch failed: {e}")
-            raise e
+            logger.error(f"generate new batch failed: {e}")
+            raise

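The reviewer's point is easy to check in isolation; this sketch (independent of the PR) prints the traceback produced by a bare `raise`, which still ends at the original failing line rather than at the re-raise statement:

```python
import traceback

def boom():
    return 1 / 0  # original failure site

try:
    try:
        boom()
    except Exception:
        raise  # bare raise: the traceback is re-raised unchanged
except Exception:
    traceback.print_exc()  # last frame shown is the `1 / 0` line, not the raise
```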

def _merge_batch(self, dp_batches: List[Batch]):
merged_batch: Batch = None
@@ -48,28 +55,20 @@
def append(self, req: Req):
suggested_dp_index = req.sample_params.suggested_dp_index
if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
suggested_dp_index = self._get_suggest_dp_index()
self.pre_select_dp_index = suggested_dp_index
req.sample_params.suggested_dp_index = suggested_dp_index
self.inner_queues[suggested_dp_index].append(req)
# 在调度时,统一分配请求id
self.reqs_waiting_for_dp_index.append(req)
else:
self.inner_queues[suggested_dp_index].append(req)
return

def extend(self, req_group: List[Req]):
-        # requests in the same group should be assigned to the same dp rank; that is most efficient
-        index = self._get_suggest_dp_index()
-        for req in req_group:
-            suggested_dp_index = req.sample_params.suggested_dp_index
-            if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-                self.pre_select_dp_index = index
-                req.sample_params.suggested_dp_index = index
-                self.inner_queues[index].append(req)
-            else:
+        suggested_dp_index = req_group[0].sample_params.suggested_dp_index
+        if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
+            # requests in the same group must be assigned to the same dp rank
+            self.reqs_waiting_for_dp_index.append(req_group)
+        else:
+            for req in req_group:
                self.inner_queues[suggested_dp_index].append(req)

return

def is_busy(self):
@@ -87,21 +86,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
return

-    def _get_suggest_dp_index(self):
-        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
-        select_dp_indexes = [
-            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
-        ]
-
-        # kept for multi-thread safety
-        if not select_dp_indexes:
-            return random.randint(0, self.dp_size_in_node - 1)
-
-        # round-robin select
-        for i in range(self.dp_size_in_node):
-            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
-            if next_dp_index in select_dp_indexes:
-                return next_dp_index
-
-        return random.choice(select_dp_indexes)
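Putting the DpQueue changes together: requests arriving without a valid dp index are parked in `reqs_waiting_for_dp_index` and only receive an index when the balancer runs at schedule time. A condensed, self-contained rendition of that flow (plain lists stand in for the real queues and requests):

```python
from types import SimpleNamespace

waiting = []            # reqs_waiting_for_dp_index
queues = [[], [], []]   # one waiting list per dp rank

def append(req, dp_size=3):
    # Invalid index -> defer to the balancer; valid index -> route directly.
    if not (0 <= req.sample_params.suggested_dp_index < dp_size):
        waiting.append(req)
    else:
        queues[req.sample_params.suggested_dp_index].append(req)

def schedule():
    # Balancer pass: hand every parked request to the currently shortest queue.
    for req in waiting:
        idx = min(range(len(queues)), key=lambda i: len(queues[i]))
        req.sample_params.suggested_dp_index = idx
        queues[idx].append(req)
    waiting.clear()

for _ in range(4):
    append(SimpleNamespace(sample_params=SimpleNamespace(suggested_dp_index=-1)))
schedule()
print([len(q) for q in queues])  # [2, 1, 1]
```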