|
| 1 | +# pytorch_all2all.py |
| 2 | +import os |
| 3 | +import torch |
| 4 | +import torch.distributed as dist |
| 5 | +import dataclasses |
| 6 | +from task import input_t, output_t |
| 7 | + |
| 8 | + |
| 9 | +# ---------------- MoE config ---------------- |
@dataclasses.dataclass
class MoEConfig:
    """Shape and dtype settings shared by the MoE all-to-all reference code."""

    # Total number of experts; split evenly across ranks by PyTorchAllToAll.
    num_experts: int
    # How many experts every token is routed to (the "top-k").
    experts_per_token: int
    # Width of a single token vector.
    hidden_dim: int
    # Bound used both to draw the per-rank token count and to size buffers.
    max_num_tokens: int
    # Element type of the tokens handed to dispatch.
    in_dtype: torch.dtype = torch.float16
    # Element type of the expert outputs handed to combine.
    out_dtype: torch.dtype = torch.float16
| 18 | + |
| 19 | + |
| 20 | +# ---------------- data per dp rank ---------------- |
class RankTestData:
    """Random MoE inputs for one data-parallel rank, created on ``cuda:{rank}``.

    Attributes:
        num_tokens: token count drawn with ``torch.randint(1, cfg.max_num_tokens)``;
            the upper bound is exclusive, so the value lies in
            ``[1, cfg.max_num_tokens)``.
        indices: int32 tensor ``(num_tokens, experts_per_token)``; row ``i`` is the
            leading slice of a fresh random permutation of the expert ids.
        weights: float32 tensor ``(num_tokens, experts_per_token)`` from ``torch.rand``.
        x: ``cfg.in_dtype`` tensor ``(num_tokens, hidden_dim)`` from ``torch.randn``;
            this is the input of dispatch.

    All draws consume ``rng`` in a fixed order (token count, one permutation
    per token, weights, tokens), so a given seed reproduces the same data.
    """

    def __init__(self, cfg: MoEConfig, rng: torch.Generator, rank: int):
        dev = torch.device(f"cuda:{rank}")
        top_k = cfg.experts_per_token

        # How many tokens this rank owns.
        drawn = torch.randint(1, cfg.max_num_tokens, [1], generator=rng, device=dev)
        self.num_tokens = int(drawn.item())

        # Token -> expert routing table, filled one row per token.
        self.indices = torch.empty(
            self.num_tokens, top_k, dtype=torch.int32, device=dev
        )
        for row in self.indices.unbind(0):
            shuffled = torch.randperm(cfg.num_experts, generator=rng, device=dev)
            # ``row`` is a view, so this writes straight into ``self.indices``.
            row.copy_(shuffled[:top_k])

        # Top-k mixing weights.
        self.weights = torch.rand(
            self.num_tokens, top_k, dtype=torch.float32, generator=rng, device=dev
        )

        # The tokens themselves.
        self.x = torch.randn(
            self.num_tokens,
            cfg.hidden_dim,
            dtype=cfg.in_dtype,
            generator=rng,
            device=dev,
        )
| 52 | + |
| 53 | + |
| 54 | +# ---------------- All2All pytorch impl ---------------- |
class PyTorchAllToAll:
    """Reference MoE all-to-all (dispatch + combine) on ``torch.distributed``.

    ``dispatch`` ships every (token, top-k slot) pair to the rank that owns the
    selected expert and groups the received rows per local expert.
    ``combine`` ships the expert outputs back to the rank each token came from
    and accumulates them, scaled by the token's top-k weight.

    Every routed row travels with a ``META_DIM``-wide int32 record:
    ``[global expert id, source rank, source token index, top-k slot, padding]``.

    Both methods issue collectives, so every rank of the default process group
    must call them in the same order.
    """

    META_DIM = 5  # global_exp, src_rank, src_token, src_k, pad

    def __init__(self, cfg: MoEConfig, rank: int, world_size: int):
        """Store the configuration and derive the per-rank expert layout.

        Args:
            cfg: MoE shape/dtype configuration.
            rank: rank of the calling process.
            world_size: number of ranks taking part in the exchange.
        """
        self.cfg = cfg
        self.rank = rank
        self.world_size = world_size
        # Experts hosted by each rank (integer division).
        # NOTE(review): routing uses ``e // num_local_experts`` as the target
        # rank, so this presumably requires num_experts % world_size == 0 and
        # world_size <= num_experts - confirm with the callers.
        self.num_local_experts = cfg.num_experts // world_size
        # Row capacity of each local expert's buffer.
        # NOTE(review): combine()'s original signature comment describes this
        # dimension as ``max_num_tokens * num_dp``; confirm that
        # ``max_num_tokens * experts_per_token`` is a sufficient bound when
        # world_size > experts_per_token.
        self.max_recv = cfg.max_num_tokens * cfg.experts_per_token

    # ---------- shared exchange ----------
    def _exchange(self, send_buf, send_meta, send_counts, recv_dtype, device):
        """Swap row counts, then payload rows and their metadata, between ranks.

        Args:
            send_buf: ``(sum(send_counts), hidden_dim)`` rows grouped by
                destination rank.
            send_meta: ``(sum(send_counts), META_DIM)`` int32 records, same order.
            send_counts: list with the number of rows destined for each rank.
            recv_dtype: dtype of the receive buffer for the payload rows.
            device: device on which the count and receive tensors are created.

        Returns:
            ``(recv_buf, recv_meta)`` holding the rows received from all ranks.
        """
        send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device)
        recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device)
        # First learn how many rows every peer is going to send us.
        dist.all_to_all_single(recv_counts_t, send_counts_t)
        recv_counts = recv_counts_t.tolist()
        total_recv = sum(recv_counts)

        recv_buf = torch.empty(
            total_recv, self.cfg.hidden_dim, dtype=recv_dtype, device=device
        )
        recv_meta = torch.empty(
            total_recv, self.META_DIM, dtype=torch.int32, device=device
        )
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
        )
        # Metadata is exchanged flattened, hence the split sizes in elements.
        dist.all_to_all_single(
            recv_meta.view(-1),
            send_meta.view(-1),
            output_split_sizes=[c * self.META_DIM for c in recv_counts],
            input_split_sizes=[c * self.META_DIM for c in send_counts],
        )
        return recv_buf, recv_meta

    # ---------- dispatch ----------
    def dispatch(self, dp_x: torch.Tensor, indices: torch.Tensor):
        """Send each token to the ranks owning its selected experts.

        Args:
            dp_x: tokens of this rank, indexed by token along dim 0.
            indices: per-token global expert ids, one row per token.

        Returns:
            ``(expert_num_tokens, expert_x, expert_meta)`` where
            ``expert_num_tokens`` is int32 ``(num_local_experts,)``,
            ``expert_x`` is ``(num_local_experts, max_recv, hidden_dim)`` and
            ``expert_meta`` is int32 ``(num_local_experts, max_recv, META_DIM)``.
            Only the first ``expert_num_tokens[e]`` rows of expert ``e`` are
            written; the remaining rows are uninitialised.

        Raises:
            IndexError: if an expert id maps to a rank outside ``world_size`` or
                a local expert receives more than ``max_recv`` rows.
        """
        device = dp_x.device
        cfg = self.cfg

        # 1. Bucket every (token, k) routing choice by destination rank.
        token_map = [[] for _ in range(self.world_size)]
        meta_map = [[] for _ in range(self.world_size)]
        for t, expert_list in enumerate(indices.tolist()):
            for k, e in enumerate(expert_list):
                dst_rank = e // self.num_local_experts
                token_map[dst_rank].append(t)
                # global expert, source rank, source token, top-k slot, pad
                meta_map[dst_rank].extend([e, self.rank, t, k, 0])
        send_counts = [len(tokens) for tokens in token_map]

        # 2. Gather the rows to send, grouped by destination rank.
        send_order = [t for tokens in token_map for t in tokens]
        send_buf = dp_x[torch.tensor(send_order, dtype=torch.long, device=device)]
        send_meta = torch.tensor(
            [v for sub in meta_map for v in sub], dtype=torch.int32, device=device
        ).view(-1, self.META_DIM)

        # 3. Exchange counts, tokens and metadata.
        recv_buf, recv_meta = self._exchange(
            send_buf, send_meta, send_counts, cfg.in_dtype, device
        )

        # 4. Scatter the received rows into per-expert buffers, arrival order.
        expert_x = torch.empty(
            (self.num_local_experts, self.max_recv, cfg.hidden_dim),
            dtype=cfg.in_dtype,
            device=device,
        )
        expert_meta = torch.empty(
            (self.num_local_experts, self.max_recv, self.META_DIM),
            dtype=torch.int32,
            device=device,
        )
        # Expert ids are fetched in one transfer and slots are counted on the
        # host, so the loop does not synchronise with the device per token.
        filled = [0] * self.num_local_experts
        for i, global_eid in enumerate(recv_meta[:, 0].tolist()):
            local_eid = global_eid % self.num_local_experts
            slot = filled[local_eid]
            expert_x[local_eid, slot] = recv_buf[i]
            expert_meta[local_eid, slot] = recv_meta[i]
            filled[local_eid] = slot + 1
        expert_num_tokens = torch.tensor(filled, dtype=torch.int32, device=device)
        return expert_num_tokens, expert_x, expert_meta

    # ---------- combine ----------
    def combine(
        self,
        out_tokens: torch.Tensor,  # output, (max num tokens, token dim)
        weights: torch.Tensor,  # topk weight
        expert_meta: torch.Tensor,  # input
        expert_y: torch.Tensor,  # input, (num_local_experts, rows per expert, token_dim)
        expert_num_tokens: torch.Tensor,  # input
    ):
        """Return expert outputs to their source ranks and accumulate them.

        Args:
            out_tokens: accumulator indexed by source token; updated in place.
            weights: per-token top-k weights, indexed ``[token, k]``.
            expert_meta: metadata records as produced by ``dispatch``.
            expert_y: expert outputs laid out like dispatch's ``expert_x``.
            expert_num_tokens: number of valid rows per local expert.

        Returns:
            ``out_tokens`` (the same tensor object) after adding, for every
            received row, ``row.float() * weights[src_token, src_k]``.

        Raises:
            IndexError: if a count exceeds the rows available in ``expert_meta``
                or a record names a source rank outside ``world_size``.
        """
        device = out_tokens.device
        cfg = self.cfg

        # 1. Group every valid expert row by the rank it has to go back to.
        src_expert = [[] for _ in range(self.world_size)]
        src_slot = [[] for _ in range(self.world_size)]
        meta_map = [[] for _ in range(self.world_size)]
        counts = expert_num_tokens.tolist()
        for local_eid in range(self.num_local_experts):
            cnt = counts[local_eid]
            if cnt <= 0:
                continue
            # One transfer per expert instead of one per token.
            rows = expert_meta[local_eid, :cnt].tolist()
            if len(rows) < cnt:
                raise IndexError(
                    f"expert {local_eid} reports {cnt} tokens but only "
                    f"{len(rows)} metadata rows exist"
                )
            for j, meta in enumerate(rows):
                dst_rank = meta[1]
                src_expert[dst_rank].append(local_eid)
                src_slot[dst_rank].append(j)
                meta_map[dst_rank].extend(meta)
        send_counts = [len(slots) for slots in src_slot]

        # 2. Gather the rows to send back, grouped by destination rank.
        flat_expert = [e for sub in src_expert for e in sub]
        flat_slot = [j for sub in src_slot for j in sub]
        if flat_expert:
            send_buf = expert_y[
                torch.tensor(flat_expert, dtype=torch.long, device=expert_y.device),
                torch.tensor(flat_slot, dtype=torch.long, device=expert_y.device),
            ]
        else:
            send_buf = torch.empty(
                (0, cfg.hidden_dim), dtype=cfg.out_dtype, device=device
            )
        send_meta = torch.tensor(
            [v for sub in meta_map for v in sub], dtype=torch.int32, device=device
        ).view(-1, self.META_DIM)

        # 3. Exchange counts, outputs and metadata.
        recv_buf, recv_meta = self._exchange(
            send_buf, send_meta, send_counts, cfg.out_dtype, device
        )

        # 4. Weighted accumulation into the source token rows, arrival order.
        for i, (_, _, src_token, src_k, _) in enumerate(recv_meta.tolist()):
            w = weights[src_token, src_k].to(torch.float32)
            out_tokens[src_token] += recv_buf[i].to(torch.float32) * w

        return out_tokens
| 235 | + |
| 236 | + |
def generate_input(
    num_experts, experts_per_token, hidden_dim, max_num_tokens, seed, rank, world_size
):
    """Build the configuration and the seeded random inputs for one rank.

    Returns:
        ``(cfg, rank_data, rank, world_size)`` - the tuple consumed by
        ``ref_kernel``. ``rank_data`` is created on ``cuda:{rank}`` from a
        generator seeded with ``seed``.
    """
    moe_cfg = MoEConfig(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        max_num_tokens=max_num_tokens,
        in_dtype=torch.float16,
        out_dtype=torch.float16,
    )

    # Device-local generator so the draws happen on this rank's GPU.
    rng = torch.Generator(device=torch.device(f"cuda:{rank}"))
    rng.manual_seed(seed)

    return moe_cfg, RankTestData(moe_cfg, rng, rank), rank, world_size
| 254 | + |
| 255 | + |
def ref_kernel(data: input_t) -> output_t:
    """Run the reference dispatch -> (expert stand-in) -> combine pipeline.

    The expert computation is replaced by a multiplication by 2 in
    ``cfg.out_dtype``. Returns the combined tokens of this rank, i.e. the first
    ``rank_data.num_tokens`` rows of the accumulator.
    """
    cfg, rank_data, rank, world_size = data
    all2all = PyTorchAllToAll(cfg, rank, world_size)

    tokens_per_expert, expert_in, meta = all2all.dispatch(
        rank_data.x, rank_data.indices
    )

    # Stand-in for the real expert MLP.
    expert_out = expert_in.to(cfg.out_dtype) * 2

    # combine() accumulates in place into this zero-initialised buffer.
    combined = torch.zeros(
        (cfg.max_num_tokens, cfg.hidden_dim),
        dtype=cfg.out_dtype,
        device=rank_data.x.device,
    )
    all2all.combine(combined, rank_data.weights, meta, expert_out, tokens_per_expert)

    return combined[: rank_data.num_tokens]
| 273 | + |
| 274 | + |
def check_implementation(data: input_t, output: output_t):
    """Compare ``output`` against the reference result for ``data``.

    Returns:
        ``(ok, message)``. ``message`` is empty when ``ok`` is True and
        otherwise names the first mismatch found (device, shape, dtype or
        values).

    Device, shape and dtype are checked before the values so that a
    malformed output is reported as a failure instead of making
    ``torch.allclose`` raise or silently broadcast.
    """
    expected = ref_kernel(data)
    if output.device != expected.device:
        return False, f"Output device mismatch: {output.device} != {expected.device}"
    if output.shape != expected.shape:
        return (
            False,
            f"Output shape mismatch: {tuple(output.shape)} != {tuple(expected.shape)}",
        )
    if output.dtype != expected.dtype:
        return False, f"Output dtype mismatch: {output.dtype} != {expected.dtype}"
    # NOTE(review): default allclose tolerances (rtol=1e-5, atol=1e-8) are very
    # tight for float16 results - confirm this strictness is intended.
    if torch.allclose(output, expected):
        return True, ""
    # Only format the (potentially large) tensors when there is a mismatch.
    return False, f"Output mismatch: {output} != {expected}"
0 commit comments