Commit 1e98a21

Author: Qirui Yang

Add context parallel support for xformers attention backend

1 parent: 3716b20

2 files changed: +392 −2 lines changed

vllm/v1/attention/backends/cp_utils.py

Lines changed: 46 additions & 1 deletion
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import numpy as np
+import torch
+from torch import distributed as dist
 
 from vllm.distributed.parallel_state import (get_context_parallel_rank,
-                                             get_context_parallel_world_size)
+                                             get_context_parallel_world_size,
+                                             get_cp_group)
 from vllm.logger import init_logger
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
@@ -241,3 +244,45 @@ def prepare_inputs_for_cp(
         total_num_local_scheduled_tokens += num_scheduled_tokens_local[idx]
 
     return num_scheduled_tokens_local, num_computed_tokens_local, q_seqlens_sharded
+
+
+def cp_get_neighbor_ranks() -> tuple[int, int]:
+    return (get_cp_group().prev_rank, get_cp_group().next_rank)
+
+
+def cp_pass_around(
+    tensors: list[torch.Tensor], to_rank: int, from_rank: int
+) -> tuple[list[torch.Tensor], list[torch.distributed.Work]]:
+    """
+    Passes a list of tensors to the designated to_rank and receives the same
+    number of tensors, with the same shapes, from the designated from_rank.
+    Note that to_rank and from_rank are ranks in the default PG, not the
+    context-parallel PG. All ranks in a CP group must call this function
+    together, which passes the tensors around the CP ring in a circle.
+
+    Args:
+        tensors: list of tensors to be passed around in the CP group
+        to_rank: rank to pass this rank's tensors to
+        from_rank: rank to receive tensors from
+
+    Returns:
+        dests: list of tensors received from from_rank
+        reqs: list of async work handles to wait on before reading dests
+    """
+    dests = []
+    p2p_ops = []
+    for x in tensors:
+        x = x.contiguous()
+        dest = torch.empty_like(x)
+        dests.append(dest)
+        p2p_ops += [
+            dist.P2POp(dist.isend,
+                       x,
+                       to_rank,
+                       group=get_cp_group().device_group),
+            dist.P2POp(dist.irecv,
+                       dest,
+                       from_rank,
+                       group=get_cp_group().device_group),
+        ]
+    return dests, dist.batch_isend_irecv(p2p_ops)
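
These two helpers are the building blocks of a ring-style exchange: each CP rank forwards its tensors to next_rank while receiving prev_rank's, so after world_size - 1 rounds every rank has seen every shard. Below is a minimal, hypothetical usage sketch, not part of this commit: ring_attention_sketch and attend_partial are illustrative names, and the final combination of the per-shard partial outputs (which a real ring-attention kernel does with log-sum-exp rescaling) is left out.

# Hypothetical sketch -- not part of this diff. Rotates KV shards around
# the CP ring, overlapping the exchange for the next shard with attention
# over the shard currently held.
import torch
import torch.nn.functional as F

from vllm.distributed.parallel_state import get_context_parallel_world_size
from vllm.v1.attention.backends.cp_utils import (cp_get_neighbor_ranks,
                                                 cp_pass_around)


def attend_partial(q: torch.Tensor, k: torch.Tensor,
                   v: torch.Tensor) -> torch.Tensor:
    # Stand-in for the backend's attention over one KV shard.
    return F.scaled_dot_product_attention(q, k, v)


def ring_attention_sketch(q: torch.Tensor, k: torch.Tensor,
                          v: torch.Tensor) -> list[torch.Tensor]:
    prev_rank, next_rank = cp_get_neighbor_ranks()
    cp_ws = get_context_parallel_world_size()
    partials = []
    for step in range(cp_ws):
        if step < cp_ws - 1:
            # Post the async send/recv for the next shard first ...
            (k_next, v_next), reqs = cp_pass_around(
                [k, v], to_rank=next_rank, from_rank=prev_rank)
        # ... then compute on the shard we already hold, so communication
        # and compute overlap.
        partials.append(attend_partial(q, k, v))
        if step < cp_ws - 1:
            for req in reqs:
                req.wait()  # k_next/v_next are valid only after this
            k, v = k_next, v_next
    return partials

Note that cp_pass_around only posts the batched send/recv pair; it is the caller's job to wait on the returned work handles before touching the received buffers, which is exactly what makes the overlap in the sketch possible.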
