@@ -8,16 +8,20 @@
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>
-
#ifdef USE_CUDA
- #include <cuda_runtime.h>
#include <cuda_bf16.h>
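+ // alias the vendor bf16 type to a common name so the rest of the file stays backend-agnostic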
+ #define bfloat16 nv_bfloat16
+ #elif USE_ROCM
+ #include <hip/hip_bfloat16.h>
+ #include <hip/hip_runtime.h>
+ #include <hip/hip_runtime_api.h>
+ #define bfloat16 hip_bfloat16
#endif

#ifdef USE_NCCL
#include "nccl.h"
- #elif defined(USE_RCCL)
- #include "rccl.h"
+ #elif USE_RCCL
+ #include <rccl/rccl.h>
#endif

#define NUM_WARMUP_ITERATIONS 5
@@ -40,18 +44,33 @@
  } \
} while (0)

+ #define HIP_CHECK(cmd) do { \
+   hipError_t e = cmd; \
+   if (e != hipSuccess) { \
+     printf("HIP error %s:%d: %s\n", \
+            __FILE__, __LINE__, hipGetErrorString(e)); \
+     exit(EXIT_FAILURE); \
+   } \
+ } while (0)
+
+ // NCCL_CHECK also validates RCCL calls: RCCL implements the NCCL API and returns ncclResult_t
#define NCCL_CHECK(cmd) do { \
  ncclResult_t e = cmd; \
  if (e != ncclSuccess) { \
    printf("NCCL error %s:%d: %s\n", \
           __FILE__, __LINE__, ncclGetErrorString(e)); \
    exit(EXIT_FAILURE); \
  } \
} while (0)

- void initializeData(nv_bfloat16 *data, int size) {
-   for (int i = 0; i < (size / sizeof(nv_bfloat16)); ++i) {
+ void initializeData(bfloat16 *data, int size) {
+   for (int i = 0; i < (size / sizeof(bfloat16)); ++i) {
+ #ifdef USE_CUDA
    data[i] = __float2bfloat16((float)i);
+ #elif USE_ROCM
+   // ROCm has no __float2bfloat16 helper; hip_bfloat16 converts from float directly
+   data[i] = (bfloat16)((float)i);
+ #endif
  }
}

@@ -86,33 +105,44 @@ int main(int argc, char *argv[]) {
  }

  // Initialize GPU context
+ #if USE_CUDA
  cudaGetDeviceCount(&num_gpus_per_node);
  cudaSetDevice((my_rank % num_gpus_per_node));
+ #elif USE_ROCM
+ hipGetDeviceCount(&num_gpus_per_node);
+ hipSetDevice((my_rank % num_gpus_per_node));
+ #endif

  int local_data_size = max_msg_size; // size of local data, in bytes
  int global_data_size = local_data_size * num_gpus; // size of global data, in bytes

- nv_bfloat16 *local_data = (nv_bfloat16 *)malloc(local_data_size);
- nv_bfloat16 *global_data = (nv_bfloat16 *)malloc(global_data_size);
+ bfloat16 *local_data = (bfloat16 *)malloc(local_data_size);
+ bfloat16 *global_data = (bfloat16 *)malloc(global_data_size);

  // Initialize local data
  initializeData(local_data, local_data_size);

  // Allocate memory on GPU
- nv_bfloat16 *d_local_data, *d_global_data;
+ bfloat16 *d_local_data, *d_global_data;
+ #ifdef USE_CUDA
  CUDA_CHECK(cudaMalloc(&d_local_data, local_data_size));
  CUDA_CHECK(cudaMalloc(&d_global_data, global_data_size));
-
  // Copy local data to GPU
  CUDA_CHECK(cudaMemcpy(d_local_data, local_data, local_data_size, cudaMemcpyHostToDevice));

+ #elif USE_ROCM
+ HIP_CHECK(hipMalloc(&d_local_data, local_data_size));
+ HIP_CHECK(hipMalloc(&d_global_data, global_data_size));
+ HIP_CHECK(hipMemcpy(d_local_data, local_data, local_data_size, hipMemcpyHostToDevice));
+ #endif
+
#ifdef USE_MPI
  // create 2-byte datatype (send raw, un-interpreted bytes)
  MPI_Datatype mpi_type_bfloat16;
  MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_bfloat16);
  MPI_Type_commit(&mpi_type_bfloat16);
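  // (MPI predefines no bfloat16 datatype; allgather only moves data, so opaque 2-byte elements suffice)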

- #elif USE_NCCL
+ #elif defined(USE_NCCL) || defined(USE_RCCL)
  ncclUniqueId nccl_comm_id;
  ncclComm_t nccl_comm;

@@ -125,13 +155,8 @@ int main(int argc, char *argv[]) {
  MPI_CHECK(MPI_Bcast((void *)&nccl_comm_id, sizeof(nccl_comm_id), MPI_BYTE,
                      0, MPI_COMM_WORLD));

- /* Create a new NCCL communicator */
+ /* Create a new NCCL/RCCL communicator */
  NCCL_CHECK(ncclCommInitRank(&nccl_comm, num_pes, nccl_comm_id, my_rank));
-
- #elif defined(USE_RCCL)
- // TODO: fix later
- rcclComm_t rccl_comm;
- rcclCommInitRank(&comm, num_gpus, 0, rccl_root);
#endif

  // Perform MPI_Iallgather, NCCL allgather, or RCCL allgather
@@ -148,20 +173,22 @@ int main(int argc, char *argv[]) {
  fflush(NULL);

  for (int msg_size = min_msg_size; msg_size <= max_msg_size; msg_size *= 2) {
-   msg_count = msg_size / sizeof(nv_bfloat16);
+   msg_count = msg_size / sizeof(bfloat16);
    // warmup iterations
    for (int i = 0; i < NUM_WARMUP_ITERATIONS; ++i) {
#ifdef USE_MPI
      MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
                d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));

      MPI_CHECK(MPI_Wait(&request, &status));
- #elif defined(USE_NCCL)
+ #elif defined(USE_NCCL) || defined(USE_RCCL)
      NCCL_CHECK(ncclAllGather((const void *)d_local_data, (void *)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
-     cudaDeviceSynchronize();
- #elif defined(USE_RCCL)
-     // TODO: fix later
-     rcclAllReduce((const void *)d_local_data, (void *)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
+ #endif
+
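+     // ncclAllGather is asynchronous on the NULL (default) stream; block until the collective completes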
+ #ifdef USE_CUDA
+     cudaDeviceSynchronize();
+ #elif USE_ROCM
+     hipDeviceSynchronize();
#endif
    }

@@ -172,16 +199,18 @@ int main(int argc, char *argv[]) {
    start_time = MPI_Wtime();
    for (int i = 0; i < iterations; ++i) {
#ifdef USE_MPI
-     MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
-       d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
-
+     MPI_CHECK(MPI_Iallgather(d_local_data, msg_count, mpi_type_bfloat16,
+               d_global_data, msg_count, mpi_type_bfloat16, MPI_COMM_WORLD, &request));
+
      MPI_CHECK(MPI_Wait(&request, &status));
- #elif defined(USE_NCCL)
+ #elif defined(USE_NCCL) || defined(USE_RCCL)
      NCCL_CHECK(ncclAllGather((const void *)d_local_data, (void *)d_global_data, msg_count, ncclBfloat16, nccl_comm, NULL));
-     cudaDeviceSynchronize();
- #elif defined(USE_RCCL)
-     // TODO: fix later
-     rcclAllReduce((const void *)d_local_data, (void *)d_global_data, global_data_size, rcclInt, rcclSum, comm, NULL);
+ #endif
+
+ #ifdef USE_CUDA
+     cudaDeviceSynchronize();
+ #elif USE_ROCM
+     hipDeviceSynchronize();
#endif
    }
    MPI_Barrier(MPI_COMM_WORLD);
@@ -193,13 +222,16 @@ int main(int argc, char *argv[]) {
  // Cleanup
  free(local_data);
  free(global_data);
+ #ifdef USE_CUDA
  CUDA_CHECK(cudaFree(d_local_data));
  CUDA_CHECK(cudaFree(d_global_data));
+ #elif USE_ROCM
+ HIP_CHECK(hipFree(d_local_data));
+ HIP_CHECK(hipFree(d_global_data));
+ #endif

- #ifdef USE_NCCL
+ #if defined(USE_NCCL) || defined(USE_RCCL)
  ncclCommDestroy(nccl_comm);
- #elif defined(USE_RCCL)
- rcclCommDestroy(rccl_comm);
#endif

  MPI_Finalize();