Commit 3d26b3f

resolve merge conflict

2 parents: a67570e + 73c53dc

1 file changed: allgather.cu (+66 additions, -34 deletions)
@@ -8,16 +8,20 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <mpi.h>
-
 #ifdef USE_CUDA
-#include <cuda_runtime.h>
 #include <cuda_bf16.h>
+#define bfloat16 nv_bfloat16
+#elif USE_ROCM
+#include <hip/hip_bfloat16.h>
+#include <hip/hip_runtime.h>
+#include <hip/hip_runtime_api.h>
+#define bfloat16 hip_bfloat16
 #endif
 
 #ifdef USE_NCCL
 #include "nccl.h"
-#elif defined(USE_RCCL)
-#include "rccl.h"
+#elif USE_RCCL
+#include <rccl/rccl.h>
 #endif
 
 #define NUM_WARMUP_ITERATIONS 5
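The include block now funnels both vendors through a single bfloat16 alias, so the rest of the file never names nv_bfloat16 or hip_bfloat16 directly. Note that `#elif USE_ROCM` tests the macro's value rather than its existence, so it relies on the build passing -DUSE_ROCM (which the preprocessor defines as 1). A minimal guard one could add on top of this pattern (a sketch, not part of the commit):

// Hypothetical compile-time guard (not in the commit): fail fast if the
// build defines neither backend, instead of leaving bfloat16 undefined.
#if !defined(USE_CUDA) && !defined(USE_ROCM)
#error "Define USE_CUDA or USE_ROCM so the bfloat16 alias exists"
#endif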
@@ -40,6 +44,16 @@
   } \
 } while(0)
 
+#define HIP_CHECK(cmd) do { \
+  hipError_t e = cmd; \
+  if(e != hipSuccess) { \
+    printf("HIP error %s:%d: %s\n", \
+      __FILE__, __LINE__, hipGetErrorString(e)); \
+    exit(EXIT_FAILURE); \
+  } \
+} while(0)
+
+// NCCL_CHECK is used to validate RCCL functions as well
 #define NCCL_CHECK(cmd) do { \
   ncclResult_t e = cmd; \
   if (e != ncclSuccess) { \
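HIP_CHECK mirrors the pre-existing CUDA_CHECK: every HIP runtime call returns a hipError_t, and the macro turns any failure into a fatal, file- and line-tagged message. A usage sketch (illustrative, with a hypothetical buffer):

// Wrap each HIP runtime call; on failure this prints e.g.
// "HIP error allgather.cu:123: out of memory" and exits.
float *tmp;
HIP_CHECK(hipMalloc(&tmp, 1024 * sizeof(float)));
HIP_CHECK(hipMemset(tmp, 0, 1024 * sizeof(float)));
HIP_CHECK(hipFree(tmp));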
@@ -49,9 +63,14 @@
   } \
 } while(0)
 
-void initializeData(nv_bfloat16 *data, int size) {
-  for (int i = 0; i < (size / sizeof(nv_bfloat16)); ++i) {
+void initializeData(bfloat16 *data, int size) {
+  for (int i = 0; i < (size / sizeof(bfloat16)); ++i) {
+#ifdef USE_CUDA
     data[i] = __float2bfloat16((float)i);
+#elif USE_ROCM
+    // ROCm doesn't have a float2bfloat16 method
+    data[i] = (bfloat16) ((float) i);
+#endif
   }
 }
 
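The CUDA branch uses the __float2bfloat16 intrinsic, while hip_bfloat16 is a class with a float constructor, hence the plain cast. Either way, bfloat16 carries only 8 significand bits, so the loop's integer values stop being exactly representable above 256; a small illustration of the rounding involved (harmless for a bandwidth benchmark):

// Illustrative only: 257 sits halfway between the representable
// bfloat16 values 256 and 258, so round-to-nearest-even yields 256.
bfloat16 b = (bfloat16)257.0f;
float back = (float)b;   // 256.0f, not 257.0f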
@@ -86,33 +105,44 @@ int main(int argc, char *argv[]) {
   }
 
   // Initialize GPU context
+#if USE_CUDA
   cudaGetDeviceCount(&num_gpus_per_node);
   cudaSetDevice((my_rank % num_gpus_per_node));
+#elif USE_ROCM
+  hipGetDeviceCount(&num_gpus_per_node);
+  hipSetDevice((my_rank % num_gpus_per_node));
+#endif
 
   int local_data_size = max_msg_size; // Size of local data
   int global_data_size = local_data_size * num_gpus; // Size of global data
 
-  nv_bfloat16 *local_data = (nv_bfloat16*)malloc(local_data_size);
-  nv_bfloat16 *global_data = (nv_bfloat16*)malloc(global_data_size);
+  bfloat16 *local_data = (bfloat16*)malloc(local_data_size);
+  bfloat16 *global_data = (bfloat16*)malloc(global_data_size);
 
   // Initialize local data
   initializeData(local_data, local_data_size);
 
   // Allocate memory on GPU
-  nv_bfloat16 *d_local_data, *d_global_data;
+  bfloat16 *d_local_data, *d_global_data;
+#ifdef USE_CUDA
   CUDA_CHECK(cudaMalloc(&d_local_data, local_data_size));
   CUDA_CHECK(cudaMalloc(&d_global_data, global_data_size));
-
   // Copy local data to GPU
   CUDA_CHECK(cudaMemcpy(d_local_data, local_data, local_data_size, cudaMemcpyHostToDevice));
 
+#elif USE_ROCM
+  HIP_CHECK(hipMalloc(&d_local_data, local_data_size));
+  HIP_CHECK(hipMalloc(&d_global_data, global_data_size));
+  HIP_CHECK(hipMemcpy(d_local_data, local_data, local_data_size, hipMemcpyHostToDevice));
+#endif
+
 #ifdef USE_MPI
   // create 2-byte datatype (send raw, un-interpreted bytes)
   MPI_Datatype mpi_type_bfloat16;
   MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_bfloat16);
   MPI_Type_commit(&mpi_type_bfloat16);
 
-#elif USE_NCCL
+#elif defined(USE_NCCL) || defined(USE_RCCL)
   ncclUniqueId nccl_comm_id;
   ncclComm_t nccl_comm;
 
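Since MPI has no built-in bfloat16 datatype, the code builds a 2-byte contiguous type over MPI_BYTE and counts in elements. An equivalent formulation, shown only for clarity, counts raw bytes directly:

// Equivalent alternative (illustrative): no derived type, counts in bytes.
// sizeof(bfloat16) == 2, so the wire traffic is identical.
MPI_CHECK(MPI_Iallgather(d_local_data, msg_count * 2, MPI_BYTE,
                         d_global_data, msg_count * 2, MPI_BYTE,
                         MPI_COMM_WORLD, &request));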
@@ -125,13 +155,8 @@ int main(int argc, char *argv[]) {
   MPI_CHECK(MPI_Bcast((void *)&nccl_comm_id, sizeof(nccl_comm_id), MPI_BYTE,
       0, MPI_COMM_WORLD));
 
-  /* Create a new NCCL communicator */
+  /* Create a new NCCL/RCCL communicator */
   NCCL_CHECK(ncclCommInitRank(&nccl_comm, num_pes, nccl_comm_id, my_rank));
-
-#elif defined(USE_RCCL)
-  // TODO: fix later
-  rcclComm_t rccl_comm;
-  rcclCommInitRank(&comm, num_gpus, 0, rccl_root);
 #endif
 
   // Perform MPI_Iallgather, NCCL allgather, or RCCL allgather
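Because RCCL is API-compatible with NCCL (same symbols, same ncclCommInitRank signature), a single init path now serves both, replacing the broken rccl-specific stub. The full bootstrap pattern this hunk relies on, sketched with the surrounding code's variable names (num_pes and my_rank are assumed from context):

// Sketch of the id-broadcast bootstrap: rank 0 mints the unique id,
// MPI distributes it, then every rank joins the same communicator.
ncclUniqueId id;
if (my_rank == 0) NCCL_CHECK(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
ncclComm_t comm;
NCCL_CHECK(ncclCommInitRank(&comm, num_pes, id, my_rank));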
@@ -148,20 +173,22 @@ int main(int argc, char *argv[]) {
   fflush(NULL);
 
   for (int msg_size = min_msg_size; msg_size <= max_msg_size; msg_size *= 2) {
-    msg_count = msg_size / sizeof(nv_bfloat16);
+    msg_count = msg_size / sizeof(bfloat16);
     // warmup iterations
     for (int i = 0; i < NUM_WARMUP_ITERATIONS; ++i) {
 #ifdef USE_MPI
       MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
           d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
 
       MPI_CHECK(MPI_Wait(&request, &status));
-#elif defined(USE_NCCL)
+#elif defined(USE_NCCL) || defined(USE_RCCL)
       NCCL_CHECK(ncclAllGather((const void*)d_local_data, (void*)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
-      cudaDeviceSynchronize();
-#elif defined(USE_RCCL)
-      // TODO: fix later
-      rcclAllReduce((const void*)d_local_data, (void*)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
+#endif
+
+#ifdef USE_CUDA
+      cudaDeviceSynchronize();
+#elif USE_ROCM
+      hipDeviceSynchronize();
 #endif
     }
 
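Passing NULL as the stream argument enqueues the collective on the default stream, and the follow-up cudaDeviceSynchronize()/hipDeviceSynchronize() blocks until it drains. A stream-scoped variant, shown as an alternative sketch, would synchronize only the work actually enqueued:

// Alternative (illustrative): use an explicit stream and synchronize it,
// rather than the whole device.
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
NCCL_CHECK(ncclAllGather(d_local_data, d_global_data, msg_count,
                         ncclBfloat16, nccl_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaStreamDestroy(stream));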
@@ -172,16 +199,18 @@ int main(int argc, char *argv[]) {
     start_time = MPI_Wtime();
     for (int i = 0; i < iterations; ++i) {
 #ifdef USE_MPI
-      MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
-          d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
-
+      MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
+          d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
+
       MPI_CHECK(MPI_Wait(&request, &status));
-#elif defined(USE_NCCL)
+#elif defined(USE_NCCL) || defined(USE_RCCL)
       NCCL_CHECK(ncclAllGather((const void*)d_local_data, (void*)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
-      cudaDeviceSynchronize();
-#elif defined(USE_RCCL)
-      // TODO: fix later
-      rcclAllReduce((const void*)d_local_data, (void*)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
+#endif
+
+#ifdef USE_CUDA
+      cudaDeviceSynchronize();
+#elif USE_ROCM
+      hipDeviceSynchronize();
 #endif
     }
     MPI_Barrier(MPI_COMM_WORLD);
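With start_time taken before the timed loop and an MPI_Barrier after it, per-iteration latency and bandwidth follow directly. A sketch of the arithmetic, assuming an end_time = MPI_Wtime() capture that this hunk does not show:

// Assumed: end_time = MPI_Wtime() taken after the barrier.
double avg_time_s = (end_time - start_time) / iterations;
// Each rank receives msg_size bytes from every other rank per allgather.
double algo_bw_gbs = ((double)msg_size * (num_gpus - 1)) / avg_time_s / 1e9;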
@@ -193,13 +222,16 @@ int main(int argc, char *argv[]) {
   // Cleanup
   free(local_data);
   free(global_data);
+#ifdef USE_CUDA
   CUDA_CHECK(cudaFree(d_local_data));
   CUDA_CHECK(cudaFree(d_global_data));
+#elif USE_ROCM
+  HIP_CHECK(hipFree(d_local_data));
+  HIP_CHECK(hipFree(d_global_data));
+#endif
 
-#ifdef USE_NCCL
+#if defined(USE_NCCL) || defined(USE_RCCL)
   ncclCommDestroy(nccl_comm);
-#elif defined(USE_RCCL)
-  rcclCommDestroy(rccl_comm);
 #endif
 
   MPI_Finalize();
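The benchmark frees everything without checking the gathered result. A hedged validation sketch one could run before cleanup (not part of the commit; every rank fills local_data with the same initializeData values, so each rank's slot in the gathered buffer should match local_data):

// Validation sketch: pull d_global_data back and compare per-rank slots.
bfloat16 *check = (bfloat16*)malloc(global_data_size);
#ifdef USE_CUDA
CUDA_CHECK(cudaMemcpy(check, d_global_data, global_data_size, cudaMemcpyDeviceToHost));
#elif USE_ROCM
HIP_CHECK(hipMemcpy(check, d_global_data, global_data_size, hipMemcpyDeviceToHost));
#endif
int elems = local_data_size / sizeof(bfloat16);
for (int r = 0; r < num_gpus; ++r)
  for (int i = 0; i < elems; ++i)
    if ((float)check[r * elems + i] != (float)local_data[i])
      printf("mismatch: rank slot %d, index %d\n", r, i);
free(check);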
