Skip to content

Commit bda86fa

Browse files
One cuBLAS handle per device
1 parent 0bd133d commit bda86fa

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

ggml-cuda.cu

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -444,7 +444,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
444444
static int g_device_count = -1;
445445
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
446446

447-
static cublasHandle_t g_cublasH = nullptr;
447+
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
448448

449449
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
450450

@@ -482,11 +482,11 @@ void ggml_init_cublas() {
482482
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
483483
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
484484
}
485-
}
486485

487-
// create cublas handle
488-
CUBLAS_CHECK(cublasCreate(&g_cublasH));
489-
CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
486+
// create cublas handle
487+
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
488+
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
489+
}
490490

491491
// configure logging to stdout
492492
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
@@ -659,9 +659,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
659659

660660
const uint64_t i0_diff = i0_high - i0_low;
661661

662-
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream_main));
662+
int id;
663+
CUDA_CHECK(cudaGetDevice(&id));
664+
665+
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
663666
CUBLAS_CHECK(
664-
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
667+
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
665668
i0_diff, ne11, ne10,
666669
&alpha, src0_ddf_i, ne00,
667670
src1_ddf_i, ne10,

0 commit comments

Comments
 (0)