110 changes: 110 additions & 0 deletions vllm_ascend/distributed/context_parallel_utils.py
@@ -0,0 +1,110 @@
import torch
import torch.distributed as dist
import torch_npu # noqa


def all_to_all_4d(input_tensor: torch.Tensor,
                  is_seq_to_head: bool,
                  group=None,
                  use_sync: bool = False) -> torch.Tensor:
seq_world_size = dist.get_world_size(group)
if is_seq_to_head:
# Transfer shape (bs, seqlen/sp, hc, hs) to (bs, seqlen, hc/sp, hs)
bs, shard_seqlen, hc, hs = input_tensor.shape
seqlen = shard_seqlen * seq_world_size
shard_hc = hc // seq_world_size

input_t = (input_tensor.reshape(bs, shard_seqlen, seq_world_size,
shard_hc,
hs).transpose(0, 2).contiguous())

output = torch.empty_like(input_t)
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
if use_sync:
torch.npu.synchronize()
else:
output = input_t

output = output.reshape(seqlen, bs, shard_hc,
hs).transpose(0, 1).contiguous()
return output
else:
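        # Transfer shape (bs, seqlen, hc/sp, hs) to (bs, seqlen/sp, hc, hs)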
bs, seqlen, shard_hc, hs = input_tensor.shape
hc = shard_hc * seq_world_size
shard_seqlen = seqlen // seq_world_size

input_t = (input_tensor.reshape(
bs, seq_world_size, shard_seqlen, shard_hc,
hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
seq_world_size, shard_hc, shard_seqlen, bs, hs))

output = torch.empty_like(input_t)
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
if use_sync:
torch.npu.synchronize()
else:
output = input_t

output = output.reshape(hc, shard_seqlen, bs,
hs).transpose(0, 2).contiguous()
return output.reshape(bs, shard_seqlen, hc, hs)
Comment on lines +50 to +52
Contributor
critical

The reshape operation on the output tensor is incorrect. The tensor has a shape of (seq_world_size, shard_hc, shard_seqlen, bs, hs), and the reshape attempts to merge the first two dimensions (seq_world_size and shard_hc). However, these dimensions are not contiguous in memory after the preceding transpose operations. A transpose(0, 1) is required to make them adjacent before reshaping. Failure to do so will result in a tensor with scrambled data.

Additionally, the reshape in the return statement is redundant as the tensor already has the correct shape after the preceding operations.

Suggested change
-        output = output.reshape(hc, shard_seqlen, bs,
-                                hs).transpose(0, 2).contiguous()
-        return output.reshape(bs, shard_seqlen, hc, hs)
+        output = output.transpose(0, 1).contiguous().reshape(
+            hc, shard_seqlen, bs, hs).transpose(0, 2).contiguous()
+        return output
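
For reviewers who want to see the difference concretely, below is a minimal single-process sketch (not part of this PR). It builds a toy contiguous tensor shaped like the post-exchange output of all_to_all_4d, (seq_world_size, shard_hc, shard_seqlen, bs, hs), and compares the head ordering produced by the reshape in the PR with the one in the suggestion; which ordering is the intended one depends on how heads are assigned to ranks elsewhere in the attention path.

import torch

sp, shard_hc, shard_seqlen, bs, hs = 2, 3, 4, 1, 2
hc = sp * shard_hc

# Stand-in for the post-all_to_all output: contiguous and labelled by flat index,
# so the resulting head order is easy to inspect.
output = torch.arange(sp * shard_hc * shard_seqlen * bs * hs,
                      dtype=torch.float32).reshape(sp, shard_hc, shard_seqlen, bs, hs)

# Variant in the PR: merge (seq_world_size, shard_hc) directly,
# i.e. head = src_rank * shard_hc + local_head.
variant_a = output.reshape(hc, shard_seqlen, bs, hs).transpose(0, 2).contiguous()

# Variant from the suggestion: transpose first,
# i.e. head = local_head * seq_world_size + src_rank.
variant_b = output.transpose(0, 1).contiguous().reshape(
    hc, shard_seqlen, bs, hs).transpose(0, 2).contiguous()

print(torch.equal(variant_a, variant_b))  # False for seq_world_size > 1: the head orderings differ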



def all_to_all_3d(input_tensor: torch.Tensor,
                  is_seq_to_head: bool,
                  group=None,
                  use_sync: bool = False) -> torch.Tensor:
seq_world_size = dist.get_world_size(group)

if is_seq_to_head:
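        # Transfer shape (seqlen/sp, hc, hs) to (seqlen, hc/sp, hs)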
shard_seqlen, hc, hs = input_tensor.shape
seqlen = shard_seqlen * seq_world_size
shard_hc = hc // seq_world_size

input_t = (input_tensor.reshape(shard_seqlen, seq_world_size, shard_hc,
hs).transpose(0, 1).contiguous())

output = torch.empty_like(input_t)
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
if use_sync:
torch.npu.synchronize()
else:
output = input_t
output = output.reshape(seqlen, shard_hc, hs)
return output
else:
# Transfer shape (seqlen, hc/sp, hs) to (seqlen/sp, hc, hs)
seqlen, shard_hc, hs = input_tensor.shape
hc = shard_hc * seq_world_size
shard_seqlen = seqlen // seq_world_size

input_t = (input_tensor.reshape(seq_world_size, shard_seqlen, shard_hc,
hs).transpose(1, 2).contiguous())

output = torch.empty_like(input_t)
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
if use_sync:
torch.npu.synchronize()
else:
output = input_t

output = output.reshape(hc, shard_seqlen,
hs).transpose(0, 1).contiguous()
return output
Comment on lines +95 to +97
Contributor
critical

Similar to the issue in all_to_all_4d, the reshape operation on the output tensor here is incorrect. The output tensor has a shape of (seq_world_size, shard_hc, shard_seqlen, hs), and reshape(hc, ...) incorrectly attempts to merge the non-contiguous first two dimensions. This will lead to data corruption. You need to transpose the first two dimensions to make them contiguous before reshaping.

Suggested change
-        output = output.reshape(hc, shard_seqlen,
-                                hs).transpose(0, 1).contiguous()
-        return output
+        output = output.transpose(0, 1).contiguous().reshape(
+            hc, shard_seqlen, hs).transpose(0, 1).contiguous()
+        return output
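
To check which variant actually restores the original layout, here is a single-process simulation sketch (again, not part of this PR). It emulates the head-to-seq exchange of all_to_all_3d for every rank at once, under the assumption that heads are sharded contiguously by rank (the layout the seq_to_head branch in this file produces), and prints whether each post-exchange reshape recovers the corresponding slice of the reference tensor.

import torch

sp, shard_seqlen, shard_hc, hs = 2, 3, 2, 4
seqlen, hc = sp * shard_seqlen, sp * shard_hc

# Reference full tensor (seqlen, hc, hs); before the head->seq exchange, rank r holds
# the full sequence for its contiguous block of heads.
x = torch.arange(seqlen * hc * hs, dtype=torch.float32).reshape(seqlen, hc, hs)
inputs = [x[:, r * shard_hc:(r + 1) * shard_hc, :] for r in range(sp)]

# Pre-exchange reshape from the else branch of all_to_all_3d, applied per rank.
pre = [t.reshape(sp, shard_seqlen, shard_hc, hs).transpose(1, 2).contiguous()
       for t in inputs]

# Emulate all_to_all_single along dim 0: rank r receives chunk r from every source rank s.
outputs = [torch.stack([pre[s][r] for s in range(sp)]) for r in range(sp)]

for r in range(sp):
    ref = x[r * shard_seqlen:(r + 1) * shard_seqlen, :, :]
    # Variant in the PR: merge (source_rank, shard_hc) directly.
    a = outputs[r].reshape(hc, shard_seqlen, hs).transpose(0, 1).contiguous()
    # Variant from the suggestion: transpose the first two dims before merging.
    b = outputs[r].transpose(0, 1).contiguous().reshape(
        hc, shard_seqlen, hs).transpose(0, 1).contiguous()
    print(r, torch.equal(a, ref), torch.equal(b, ref))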



def all_gather_2d(input_tensor: torch.Tensor,
                  world_size: int,
                  group=None) -> torch.Tensor:
s, d = input_tensor.shape
input_gather = torch.zeros(world_size * s,
d,
dtype=input_tensor.dtype,
device=input_tensor.device)
dist.all_gather_into_tensor(input_gather, input_tensor, group=group)

return input_gather
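
For context, a hypothetical usage sketch (not part of this PR) of how these helpers are typically combined in a context-parallel attention path: q/k/v arrive sharded over the sequence dimension, are exchanged to a head-sharded layout for attention, and the attention output is exchanged back. The names context_parallel_attention, attn_fn and cp_group are placeholders, and a context-parallel process group is assumed to have been created by the caller.

import torch

from vllm_ascend.distributed.context_parallel_utils import all_to_all_4d


def context_parallel_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                               attn_fn, cp_group) -> torch.Tensor:
    # (bs, seqlen/sp, hc, hs) -> (bs, seqlen, hc/sp, hs): every rank now sees the
    # full sequence for its subset of heads.
    q = all_to_all_4d(q, is_seq_to_head=True, group=cp_group)
    k = all_to_all_4d(k, is_seq_to_head=True, group=cp_group)
    v = all_to_all_4d(v, is_seq_to_head=True, group=cp_group)

    out = attn_fn(q, k, v)

    # (bs, seqlen, hc/sp, hs) -> (bs, seqlen/sp, hc, hs): back to sequence sharding.
    return all_to_all_4d(out, is_seq_to_head=False, group=cp_group)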