Commit a38f4e8

Merge pull request #51 from danielhua23/a2a
Add all2all initial impl
2 parents a353809 + e5db664 commit a38f4e8

6 files changed: 1162 additions & 0 deletions
problems/amd_distributed.yaml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
name: AMD Developer Challenge 2025 - Distributed Edition
# when does this end (individual problems might close earlier)
deadline: "2025-10-14"
# A description for this particular competition
description: "AMD Developer Challenge 2025: Distributed Edition"
# the list of problems
problems:
  - directory: amd_distributed/all2all
    name: all2all
    deadline: "2025-10-14"
    gpus:
      - MI300x8
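
This manifest is plain YAML. A minimal sketch of how it could be read with PyYAML (the loader and path here are assumptions, not part of this commit):

# parse the competition definition (hypothetical snippet)
import yaml

with open("problems/amd_distributed.yaml") as f:
    comp = yaml.safe_load(f)

print(comp["deadline"])          # "2025-10-14"
for p in comp["problems"]:
    print(p["directory"], p["name"], p["gpus"])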
pytorch_all2all.py

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
# pytorch_all2all.py
import os
import torch
import torch.distributed as dist
import dataclasses
from task import input_t, output_t


# ---------------- MoE config ----------------
@dataclasses.dataclass
class MoEConfig:
    num_experts: int
    experts_per_token: int
    hidden_dim: int
    max_num_tokens: int
    in_dtype: torch.dtype = torch.float16
    out_dtype: torch.dtype = torch.float16


# ---------------- data per dp rank ----------------
class RankTestData:
    def __init__(self, cfg: MoEConfig, rng: torch.Generator, rank: int):
        device = torch.device(f"cuda:{rank}")
        self.num_tokens = int(
            torch.randint(
                1, cfg.max_num_tokens, [1], generator=rng, device=device
            ).item()
        )
        # token expert map
        self.indices = torch.empty(
            self.num_tokens, cfg.experts_per_token, dtype=torch.int32, device=device
        )
        for i in range(self.num_tokens):
            perm = torch.randperm(cfg.num_experts, generator=rng, device=device)
            self.indices[i] = perm[: cfg.experts_per_token]
        # topk weights
        self.weights = torch.rand(
            self.num_tokens,
            cfg.experts_per_token,
            dtype=torch.float32,
            generator=rng,
            device=device,
        )
        # dp tokens, input of dispatch
        self.x = torch.randn(
            self.num_tokens,
            cfg.hidden_dim,
            dtype=cfg.in_dtype,
            generator=rng,
            device=device,
        )

# ---------------- All2All pytorch impl ----------------
class PyTorchAllToAll:
    META_DIM = 5  # global_exp, src_rank, src_token, src_k, pad

    def __init__(self, cfg: MoEConfig, rank: int, world_size: int):
        self.cfg = cfg
        self.rank = rank
        self.world_size = world_size
        # num experts per rank
        self.num_local_experts = cfg.num_experts // world_size
        # max recv tokens per rank
        self.max_recv = cfg.max_num_tokens * cfg.experts_per_token
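        # sizing note: each of up to max_num_tokens tokens fans out to
        # experts_per_token distinct experts, so one rank never emits more
        # than max_num_tokens * experts_per_token entries in a dispatch;
        # e.g. max_num_tokens=128, experts_per_token=8 -> 1024 slots.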

    # ---------- dispatch ----------
    def dispatch(self, dp_x: torch.Tensor, indices: torch.Tensor):
        device = dp_x.device
        cfg = self.cfg

        # ---------1. get counts of send and recv for each rank -----------
        # 1.1 token nums to send to each rank
        send_counts = [0] * self.world_size
        # 1.2 token ids to send to each rank
        token_map = [[] for _ in range(self.world_size)]
        # 1.3 token meta data, needed later by combine
        meta_map = [[] for _ in range(self.world_size)]
        for t, expert_list in enumerate(indices.tolist()):
            for k, e in enumerate(expert_list):
                dst_rank = e // self.num_local_experts
                send_counts[dst_rank] += 1
                token_map[dst_rank].append(t)
                meta_map[dst_rank].extend(
                    [e, self.rank, t, k, 0]
                )  # srcGlobalExpert, srcRank, srcToken, srcK, pad

        send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device)
        # 1.4 token nums to recv from each rank
        recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device)
        dist.all_to_all_single(recv_counts_t, send_counts_t)
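        # e.g. with world_size=2: if rank 0 sends counts [2, 3] and rank 1
        # sends [4, 1], rank 0 receives [2, 4] and rank 1 receives [3, 1]
        # (entry i of every rank's send vector lands on rank i).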
        # ---------2. send and recv buffer, order by tokens on each rank ----------
        send_buf = torch.cat([dp_x[idx_list] for idx_list in token_map], dim=0)
        total_recv = int(recv_counts_t.sum().item())
        recv_buf = torch.empty(
            total_recv, cfg.hidden_dim, dtype=cfg.in_dtype, device=device
        )

        # 2.1 meta buf for send and recv
        send_meta = torch.tensor(
            [v for sub in meta_map for v in sub], dtype=torch.int32, device=device
        ).view(-1, self.META_DIM)
        recv_meta = torch.empty(
            total_recv, self.META_DIM, dtype=torch.int32, device=device
        )
        # ---------3. dispatch send_buf to recv_buf by recv and send counts--------------
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts_t.tolist(),
            input_split_sizes=send_counts_t.tolist(),
        )

        # ---------4. dispatch send_meta to recv_meta by recv and send counts------
        dist.all_to_all_single(
            recv_meta.view(-1),
            send_meta.view(-1),
            output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()],
            input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()],
        )
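        # the meta rows travel as one flat int32 stream, so every split size is
        # scaled by META_DIM: the same per-rank token boundaries as the payload
        # exchange, five ints per token.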
        recv_meta = recv_meta.view(-1, self.META_DIM)
        # ---------5. define output tensors of dispatch ------------
        # 5.1 num tokens per expert
        expert_num_tokens = torch.zeros(
            self.num_local_experts, dtype=torch.int32, device=device
        )
        # 5.2 token tensor on each expert
        expert_x = torch.empty(
            (self.num_local_experts, self.max_recv, cfg.hidden_dim),
            dtype=cfg.in_dtype,
            device=device,
        )
        expert_meta = torch.empty(
            (self.num_local_experts, self.max_recv, self.META_DIM),
            dtype=torch.int32,
            device=device,
        )
        # ---------6. write tokens to each expert on each rank ------
        # 6.1 fetch the local expert id of the corresponding token i
        for i in range(total_recv):
            global_eid = int(recv_meta[i, 0].item())
            local_eid = global_eid % self.num_local_experts
            # output: store token buf, token meta, and token count per expert
            expert_x[local_eid, expert_num_tokens[local_eid]] = recv_buf[i]
            expert_meta[local_eid, expert_num_tokens[local_eid]] = recv_meta[i]
            expert_num_tokens[local_eid] += 1
        # 6.2 after dispatch: per-expert token counts, tokens, and their meta
        return expert_num_tokens, expert_x, expert_meta

    # ---------- combine ----------
    def combine(
        self,
        out_tokens: torch.Tensor,  # output, (max num tokens, token dim)
        weights: torch.Tensor,  # topk weights
        expert_meta: torch.Tensor,  # input
        expert_y: torch.Tensor,  # input, (num_local_experts, max_num_tokens * num_dp, token_dim)
        expert_num_tokens: torch.Tensor,  # input
    ):
        device = out_tokens.device
        cfg = self.cfg

        # 1. count send-back tokens on cur rank
        send_counts = [0] * self.world_size
        # 1.1 tokens that will be sent back
        y_map = [[] for _ in range(self.world_size)]
        # 1.2 meta info of each token sent back to its src rank
        meta_map = [[] for _ in range(self.world_size)]

        # 2. traverse each token of each local expert, filling send_counts, y_map and meta_map
        for local_eid in range(self.num_local_experts):
            cnt = int(expert_num_tokens[local_eid].item())
            for j in range(cnt):
                # meta info of token j of local expert local_eid
                meta = expert_meta[local_eid, j]
                dst_rank = int(meta[1].item())
                send_counts[dst_rank] += 1
                # token j and its meta, sent back to its src rank
                y_map[dst_rank].append(expert_y[local_eid, j].unsqueeze(0))
                meta_map[dst_rank].extend(meta.tolist())
        # token nums that cur rank plans to send to the other ranks
        send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device)
        # token nums that will be recv'd from the other ranks
        recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device)
        # exchange send counts for recv counts across ranks via all2all
        dist.all_to_all_single(recv_counts_t, send_counts_t)
        # 3. send buffer of cur rank, i.e. the tokens at its experts
        y_map_tensors = []
        for sub_list in y_map:
            if sub_list:
                y_map_tensors.append(torch.cat(sub_list, dim=0))
            else:
                y_map_tensors.append(
                    torch.empty((0, cfg.hidden_dim), dtype=cfg.out_dtype, device=device)
                )
        send_buf = torch.cat(y_map_tensors, dim=0)
        # 4. flatten send meta by tokens
        send_meta = torch.tensor(
            [v for sub in meta_map for v in sub], dtype=torch.int32, device=device
        ).view(-1, self.META_DIM)
        # 5. total recv tokens of cur rank
        total_recv = int(recv_counts_t.sum().item())
        # 6. recv buffers of cur rank
        recv_buf = torch.empty(
            total_recv, cfg.hidden_dim, dtype=cfg.out_dtype, device=device
        )
        recv_meta = torch.empty(
            total_recv, self.META_DIM, dtype=torch.int32, device=device
        )
        # 7. call all2all to send and recv tokens for each rank
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts_t.tolist(),
            input_split_sizes=send_counts_t.tolist(),
        )
        # 8. call all2all to send and recv meta for each rank
        dist.all_to_all_single(
            recv_meta.view(-1),
            send_meta.view(-1),
            output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()],
            input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()],
        )
        # 9. restore recv meta
        recv_meta = recv_meta.view(-1, self.META_DIM)

        # 10. write back tokens from recv buf per meta info, and do the weighted sum
        for i in range(total_recv):
            src_token = int(recv_meta[i, 2].item())
            src_k = int(recv_meta[i, 3].item())
            src_rank = int(recv_meta[i, 1].item())
            w = weights[src_token, src_k].to(torch.float32)
            out_tokens[src_token] += recv_buf[i].to(torch.float32) * w
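            # i.e. out[t] += weights[t, k] * y for each of token t's top-k
            # experts; the accumulation runs in fp32 and is cast back to the
            # fp16 out_tokens buffer by the in-place add.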

        return out_tokens


def generate_input(
    num_experts, experts_per_token, hidden_dim, max_num_tokens, seed, rank, world_size
):
    device = torch.device(f"cuda:{rank}")
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)

    cfg = MoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        max_num_tokens=max_num_tokens,
        in_dtype=torch.float16,
        out_dtype=torch.float16,
    )
    rank_data = RankTestData(cfg, gen, rank)
    return cfg, rank_data, rank, world_size

def ref_kernel(data: input_t) -> output_t:
    cfg, rank_data, rank, world_size = data

    ata = PyTorchAllToAll(cfg, rank, world_size)

    expert_num, expert_x, expert_meta = ata.dispatch(rank_data.x, rank_data.indices)
    # stand-in for the real expert computation: just scale each token by 2
    expert_y = expert_x.to(cfg.out_dtype) * 2
    y = torch.zeros(
        cfg.max_num_tokens,
        cfg.hidden_dim,
        dtype=cfg.out_dtype,
        device=rank_data.x.device,
    )

    ata.combine(y, rank_data.weights, expert_meta, expert_y, expert_num)

    return y[: rank_data.num_tokens]


def check_implementation(data: input_t, output: output_t):
    expected = ref_kernel(data)
    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"
    return torch.allclose(output, expected), f"Output mismatch: {output} != {expected}"
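
For local experimentation, the reference can be driven like any torch.distributed program. The sketch below is a hypothetical launcher, not part of this commit: it assumes one GPU per rank under torchrun, that the harness `task` module (providing `input_t`/`output_t`) is importable, and it picks small MoE parameters purely for illustration.

# run_ref.py -- hypothetical driver, launch with:
#   torchrun --nproc-per-node=8 run_ref.py
import os
import torch
import torch.distributed as dist
from pytorch_all2all import generate_input, ref_kernel, check_implementation

def main():
    rank = int(os.environ["RANK"])            # set by torchrun
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # illustrative sizes only; the graded shapes come from the problem yaml
    data = generate_input(
        num_experts=32, experts_per_token=4, hidden_dim=128,
        max_num_tokens=64, seed=42, rank=rank, world_size=world_size,
    )
    out = ref_kernel(data)                    # dispatch -> expert -> combine
    ok, msg = check_implementation(data, out)
    print(f"rank {rank}: ok={ok}" + ("" if ok else f" ({msg})"))
    dist.destroy_process_group()

if __name__ == "__main__":
    main()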
