 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import numpy as np
+import torch
+from torch import distributed as dist
 
 from vllm.distributed.parallel_state import (get_context_parallel_rank,
-                                             get_context_parallel_world_size)
+                                             get_context_parallel_world_size,
+                                             get_cp_group)
 from vllm.logger import init_logger
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
@@ -241,3 +244,45 @@ def prepare_inputs_for_cp( |
         total_num_local_scheduled_tokens += num_scheduled_tokens_local[idx]
 
     return num_scheduled_tokens_local, num_computed_tokens_local, q_seqlens_sharded
+
+
+def cp_get_neighbor_ranks() -> tuple[int, int]:
+    return (get_cp_group().prev_rank, get_cp_group().next_rank)
+
+
+def cp_pass_around(
+    tensors: list[torch.Tensor], to_rank: int, from_rank: int
+) -> tuple[list[torch.Tensor], list[torch.distributed.Work]]:
+    """
+    Sends a list of tensors to to_rank and receives the same number of
+    tensors, with matching shapes, from from_rank. Note that to_rank and
+    from_rank are ranks in the default process group, not in the context
+    parallel group. All ranks in a CP group must call this function
+    together, so the tensors circulate around the ranks of the CP group.
+
+    Args:
+        tensors: list of tensors to be passed around in the CP group
+        to_rank: rank to send my tensors to
+        from_rank: rank to receive tensors from
+
+    Returns:
+        dests: list of tensors received from from_rank
+        reqs: list of distributed Work handles to wait on before reading dests
+    """
+    dests = []
+    p2p_ops = []
+    for x in tensors:
+        x = x.contiguous()
+        dest = torch.empty_like(x)
+        dests.append(dest)
+        p2p_ops += [
+            dist.P2POp(dist.isend,
+                       x,
+                       to_rank,
+                       group=get_cp_group().device_group),
+            dist.P2POp(dist.irecv,
+                       dest,
+                       from_rank,
+                       group=get_cp_group().device_group),
+        ]
+    return dests, dist.batch_isend_irecv(p2p_ops)
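
For orientation, a minimal usage sketch (not part of this diff) showing how the two helpers compose into one ring-exchange step. The `ring_step` wrapper and the KV-shard framing are illustrative assumptions, and the sketch presumes vLLM's distributed environment (and thus the CP group) is already initialized, with the two new helpers in scope:

```python
import torch

# cp_get_neighbor_ranks and cp_pass_around are the helpers added in
# this diff; assume they are importable from the module being patched.


def ring_step(local_kv: torch.Tensor) -> torch.Tensor:
    """Illustrative only: rotate each rank's shard once around the CP ring."""
    # Neighbor ranks are expressed in the default process group,
    # as cp_pass_around expects.
    prev_rank, next_rank = cp_get_neighbor_ranks()
    # Every CP rank sends its shard to the next rank and receives the
    # previous rank's shard, so shards circulate by one position.
    (recv_kv,), reqs = cp_pass_around([local_kv],
                                      to_rank=next_rank,
                                      from_rank=prev_rank)
    # Receives are asynchronous; wait before reading the buffer.
    for req in reqs:
        req.wait()
    return recv_kv
```

Batching each isend with its matching irecv through `dist.batch_isend_irecv` lets the backend issue the whole exchange as one group, which is why `cp_pass_around` returns the Work handles rather than waiting itself: the caller can overlap the transfer with other work before calling `wait()`.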