@@ -444,7 +444,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
// File-scope CUDA backend state. Every per-device array below is sized by
// GGML_CUDA_MAX_DEVICES.
//
// g_device_count starts at -1; presumably it is filled in during initialization
// (the assignment is not visible in this excerpt — TODO confirm).
444444static int g_device_count = -1 ;
// One float per device slot, zero-initialized. NOTE(review): the meaning of the
// values (e.g. how they divide work between devices) is not shown here — confirm.
445445static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0 };
446446
// Removed by this change: a single cuBLAS handle shared by all callers.
447- static cublasHandle_t g_cublasH = nullptr ;
// Added by this change: one cuBLAS handle slot per device, all starting as nullptr.
// ggml_init_cublas fills slot [id] with cublasCreate inside its per-device loop, and
// ggml_cuda_op_mul_mat_cublas selects the slot using the id returned by cudaGetDevice.
447+ static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = { nullptr } ;
448448
// Stream table with dimensions [GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS],
// nullptr-initialized; where the entries are created is not visible in this excerpt.
449449static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
450450
@@ -482,11 +482,11 @@ void ggml_init_cublas() {
482482 for (int i = 0 ; i < GGML_CUDA_MAX_EVENTS; ++i) {
483483 CUDA_CHECK (cudaEventCreateWithFlags (&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
484484 }
485- }
486485
487- // create cublas handle
488- CUBLAS_CHECK (cublasCreate (&g_cublasH));
489- CUBLAS_CHECK (cublasSetMathMode (g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
486+ // create cublas handle
487+ CUBLAS_CHECK (cublasCreate (&g_cublas_handles[id]));
488+ CUBLAS_CHECK (cublasSetMathMode (g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
489+ }
490490
491491 // configure logging to stdout
492492 // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
@@ -659,9 +659,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
659659
660660 const uint64_t i0_diff = i0_high - i0_low;
661661
662- CUBLAS_CHECK (cublasSetStream (g_cublasH, cudaStream_main));
662+ int id;
663+ CUDA_CHECK (cudaGetDevice (&id));
664+
665+ CUBLAS_CHECK (cublasSetStream (g_cublas_handles[id], cudaStream_main));
663666 CUBLAS_CHECK (
664- cublasSgemm (g_cublasH , CUBLAS_OP_T, CUBLAS_OP_N,
667+ cublasSgemm (g_cublas_handles[id] , CUBLAS_OP_T, CUBLAS_OP_N,
665668 i0_diff, ne11, ne10,
666669 &alpha, src0_ddf_i, ne00,
667670 src1_ddf_i, ne10,
0 commit comments