qwen3-vl ViT module: enable SP and MRoPE fusion op #4165
@@ -0,0 +1,110 @@
import torch
import torch.distributed as dist
import torch_npu  # noqa


def all_to_all_4d(input_tensor: torch.Tensor,
                  is_seq_to_head: bool,
                  group=None,
                  use_sync: bool = False) -> torch.Tensor:
    seq_world_size = dist.get_world_size(group)
    if is_seq_to_head:
        # Transfer shape (bs, seqlen/sp, hc, hs) to (bs, seqlen, hc/sp, hs)
        bs, shard_seqlen, hc, hs = input_tensor.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        input_t = (input_tensor.reshape(bs, shard_seqlen, seq_world_size,
                                        shard_hc,
                                        hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            if use_sync:
                torch.npu.synchronize()
        else:
            output = input_t

        output = output.reshape(seqlen, bs, shard_hc,
                                hs).transpose(0, 1).contiguous()
        return output
    else:
        # Transfer shape (bs, seqlen, hc/sp, hs) to (bs, seqlen/sp, hc, hs)
        bs, seqlen, shard_hc, hs = input_tensor.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        input_t = (input_tensor.reshape(
            bs, seq_world_size, shard_seqlen, shard_hc,
            hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
                seq_world_size, shard_hc, shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            if use_sync:
                torch.npu.synchronize()
        else:
            output = input_t

        output = output.reshape(hc, shard_seqlen, bs,
                                hs).transpose(0, 2).contiguous()
        return output.reshape(bs, shard_seqlen, hc, hs)


def all_to_all_3d(input_tensor: torch.Tensor,
                  is_seq_to_head: bool,
                  group=None,
                  use_sync: bool = False) -> torch.Tensor:
    seq_world_size = dist.get_world_size(group)

    if is_seq_to_head:
        # Transfer shape (seqlen/sp, hc, hs) to (seqlen, hc/sp, hs)
        shard_seqlen, hc, hs = input_tensor.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        input_t = (input_tensor.reshape(shard_seqlen, seq_world_size, shard_hc,
                                        hs).transpose(0, 1).contiguous())

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            if use_sync:
                torch.npu.synchronize()
        else:
            output = input_t
        output = output.reshape(seqlen, shard_hc, hs)
        return output
    else:
        # Transfer shape (seqlen, hc/sp, hs) to (seqlen/sp, hc, hs)
        seqlen, shard_hc, hs = input_tensor.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        input_t = (input_tensor.reshape(seq_world_size, shard_seqlen, shard_hc,
                                        hs).transpose(1, 2).contiguous())

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            if use_sync:
                torch.npu.synchronize()
        else:
            output = input_t

        output = output.reshape(hc, shard_seqlen,
                                hs).transpose(0, 1).contiguous()
        return output
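For orientation (not part of the diff), here is a minimal sketch of how these seq↔head all-to-all helpers are typically wired around attention under sequence parallelism. The `sp_group` handle, the `sp_attention_forward` name, and the use of `scaled_dot_product_attention` are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only: assumes the helpers above are importable and that
# `sp_group` is the sequence-parallel process group (an assumption, not from the PR).
import torch.nn.functional as F


def sp_attention_forward(q, k, v, sp_group):
    # q, k, v: (bs, seqlen/sp, hc, hs), sharded along the sequence dimension.
    # 1) seq -> head: each rank ends up with the full sequence for hc/sp heads.
    q = all_to_all_4d(q, is_seq_to_head=True, group=sp_group)
    k = all_to_all_4d(k, is_seq_to_head=True, group=sp_group)
    v = all_to_all_4d(v, is_seq_to_head=True, group=sp_group)

    # 2) Attention over the full sequence with the local subset of heads:
    #    (bs, seqlen, hc/sp, hs) -> (bs, hc/sp, seqlen, hs) for SDPA, then back.
    out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                         k.transpose(1, 2),
                                         v.transpose(1, 2))
    out = out.transpose(1, 2).contiguous()

    # 3) head -> seq: restore the (bs, seqlen/sp, hc, hs) sequence sharding.
    return all_to_all_4d(out, is_seq_to_head=False, group=sp_group)
```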
Comment on lines +95 to +97

Contributor: Similar to the issue in …
def all_gather_2d(input_tensor: torch.Tensor,
                  world_size: int,
                  group=None) -> torch.Tensor:
    # Gather sequence-sharded (s, d) tensors from all ranks into (world_size * s, d).
    s, d = input_tensor.shape
    input_gather = torch.zeros(world_size * s,
                               d,
                               dtype=input_tensor.dtype,
                               device=input_tensor.device)
    dist.all_gather_into_tensor(input_gather, input_tensor, group=group)

    return input_gather
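Similarly, a small hedged usage sketch for `all_gather_2d`; `local_feats` and `sp_group` are assumed names for illustration, not identifiers from this PR.

```python
import torch.distributed as dist

# Assumed setup: `local_feats` is this rank's (seqlen/sp, hidden) shard and
# `sp_group` is the sequence-parallel process group.
sp_size = dist.get_world_size(sp_group)
full_feats = all_gather_2d(local_feats, world_size=sp_size, group=sp_group)
# full_feats: (seqlen, hidden), with chunks ordered by rank along dim 0.
```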
Review comment (on the `else` branch of `all_to_all_4d`): The `reshape` operation on the `output` tensor is incorrect. The tensor has a shape of `(seq_world_size, shard_hc, shard_seqlen, bs, hs)`, and the `reshape` attempts to merge the first two dimensions (`seq_world_size` and `shard_hc`). However, these dimensions are not contiguous in memory after the preceding transpose operations; a `transpose(0, 1)` is required to make them adjacent before reshaping. Failure to do so will result in a tensor with scrambled data. Additionally, the `reshape` in the `return` statement is redundant, as the tensor already has the correct shape after the preceding operations.
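To make the ordering question concrete, here is a small, self-contained toy example (not from the PR) showing how inserting a `transpose(0, 1)` before the merge changes which heads end up adjacent; whether the extra transpose is actually needed depends on how the chunks received from `all_to_all_single` are meant to be ordered along the head dimension.

```python
import torch

# Toy sizes: 2 SP ranks, 2 local heads per rank; the trailing dim stands in
# for the flattened (shard_seqlen, bs, hs) part.
sp, shard_hc, rest = 2, 2, 3

# t[i, j] encodes "chunk received from rank i, local head j" as the value 2*i + j.
t = torch.arange(sp * shard_hc).reshape(sp, shard_hc, 1).repeat(1, 1, rest)

merged_direct = t.reshape(sp * shard_hc, rest)
merged_swapped = t.transpose(0, 1).contiguous().reshape(sp * shard_hc, rest)

print(merged_direct[:, 0].tolist())   # [0, 1, 2, 3] -> heads grouped by source rank
print(merged_swapped[:, 0].tolist())  # [0, 2, 1, 3] -> heads interleaved across ranks
```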