Commit 0bd133d

Fixed f16 inference, x2 t/s

Parent: 0d5b895
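
Summary: the commit deletes the dedicated ggml_cuda_mul_mat_f16 path (which converted src1 to fp16 on the CPU through a wdata scratch buffer), fixes an indexing bug in the convert_f16 device function, and routes F16 matrices through the regular dequantize_mul_mat_vec / cuBLAS paths instead. With the scratch buffer gone, the wdata/wsize parameters disappear from the CUDA entry points, and CUBLAS_CHECK gains a readable status string on CUDA 12 and newer.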

File tree: 2 files changed, +23 −148 lines

ggml-cuda.cu

Lines changed: 23 additions & 147 deletions
@@ -23,18 +23,30 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         } \
     } while (0)
 
+#if CUDART_VERSION >= 12
 #define CUBLAS_CHECK(err) \
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
+                    err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
             exit(1); \
         } \
     } while (0)
+#else
+#define CUBLAS_CHECK(err) \
+    do { \
+        cublasStatus_t err_ = (err); \
+        if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+#endif // CUDART_VERSION >= 12
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize);
+typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, uint64_t i0_low, uint64_t i0_high, int i1, cudaStream_t & cudaStream_main);
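
For readers unfamiliar with the macro, a minimal usage sketch, assuming the CUBLAS_CHECK definition above is in scope (the handle/stream names here are hypothetical, not from this commit):

```cuda
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Hypothetical usage sketch: every cuBLAS call returns a cublasStatus_t,
// so each one is wrapped in CUBLAS_CHECK to abort with file/line info
// (and, on CUDA 12+, a readable string from cublasGetStatusString()).
static cublasHandle_t create_handle_on(cudaStream_t stream) {
    cublasHandle_t handle;
    CUBLAS_CHECK(cublasCreate(&handle));
    CUBLAS_CHECK(cublasSetStream(handle, stream));
    return handle;
}
```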
@@ -190,8 +202,8 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
 static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const half * x = (const half *) vx;
 
-    v0 = __half2float(x[ib + 0]);
-    v1 = __half2float(x[ib + 1]);
+    v0 = __half2float(x[ib + iqs + 0]);
+    v1 = __half2float(x[ib + iqs + 1]);
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
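
Why the extra iqs matters: the dequantize_kernel_t contract (see the typedefs above) passes a block index ib plus an offset iqs within the block, and the kernels that call it (e.g. the dequantize_mul_mat_vec path referenced further down) evidently pass nonzero offsets. The old body ignored iqs, so those callers kept re-reading the same two half values, which is why the F16 mul-mat-vec path was fenced off with a FIXME (removed later in this diff). A standalone probe, using toy data and hypothetical names, that distinguishes the two behaviors:

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Standalone sketch (not from the commit): same body as the fixed
// convert_f16; the comments show what the pre-fix code read instead.
static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
                                   float & v0, float & v1) {
    const half * x = (const half *) vx;
    v0 = __half2float(x[ib + iqs + 0]); // pre-fix: x[ib + 0]
    v1 = __half2float(x[ib + iqs + 1]); // pre-fix: x[ib + 1]
}

static __global__ void probe(const half * x, float * out) {
    float v0, v1;
    convert_f16(x, /*ib=*/0, /*iqs=*/2, v0, v1);
    out[0] = v0; // fixed: x[2]; pre-fix code returned x[0]
    out[1] = v1; // fixed: x[3]; pre-fix code returned x[1]
}

int main() {
    half h_x[4];
    for (int i = 0; i < 4; ++i) h_x[i] = __float2half((float) i);
    half * d_x; float * d_out; float h_out[2];
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    probe<<<1, 1>>>(d_x, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("v0=%f v1=%f\n", h_out[0], h_out[1]); // expect 2.0 and 3.0
    cudaFree(d_x); cudaFree(d_out);
    return 0;
}
```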
@@ -555,114 +567,6 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(
     }
 }
 
-static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
-    GGML_ASSERT(g_device_count == 1);
-    cudaSetDevice(0);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
-
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-
-    bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-
-    size_t x_size, y_size, d_size;
-    half * d_X;
-    if (src0_on_device) {
-        d_X = (half *) src0_extra->data_device[0];
-    } else {
-        d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
-    }
-
-    half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-
-    bool src1_cont_rows = nb10 == sizeof(float);
-    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams_main[0][i % GGML_CUDA_MAX_STREAMS];
-
-            half * c_X = d_X + i * x_ne;
-            half * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-
-            // copy src0 to device if necessary
-            if (!src0_on_device) {
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, 0, ne01, cudaStream));
-            }
-
-            // convert src1 to fp16
-            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
-            if (src1_cont_rows) {
-                if (src1_cont_cols) {
-                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
-                }
-                else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
-                    }
-                }
-            }
-            else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
-                        // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
-                    }
-                }
-            }
-
-            // copy src1 to device
-            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
-
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, CUDA_R_16F, ne00,
-                                c_Y, CUDA_R_16F, ne10,
-                        &beta,  c_D, CUDA_R_32F, ne01,
-                        CUBLAS_COMPUTE_32F_FAST_16F,
-                        CUBLAS_GEMM_DEFAULT));
-
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-    if (!src0_on_device) {
-        ggml_cuda_pool_free(d_X, x_size);
-    }
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-}
-
 inline void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, uint64_t i0_low, uint64_t i0_high, int i1,
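
For reference, the heart of the deleted path above was a single mixed-precision GEMM. A minimal sketch of that cublasGemmEx call pattern (fp16 inputs, fp32 accumulation and output; the function name, dimensions m/n/k, and device pointers are assumptions for illustration):

```cuda
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch of the removed path's GEMM (error checks elided): computes
// D (m x n, fp32) = X^T * Y with X (k x m) and Y (k x n) in fp16,
// accumulating in fp32 via CUBLAS_COMPUTE_32F_FAST_16F (tensor cores).
static cublasStatus_t gemm_f16_in_f32_out(cublasHandle_t handle, int m, int n, int k,
                                          const half * d_X, const half * d_Y, float * d_D) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                        m, n, k,
                        &alpha, d_X, CUDA_R_16F, k,
                                d_Y, CUDA_R_16F, k,
                        &beta,  d_D, CUDA_R_32F, m,
                        CUBLAS_COMPUTE_32F_FAST_16F,
                        CUBLAS_GEMM_DEFAULT);
}
```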
@@ -917,7 +821,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
 
             if (i0 - i0_offset_low > 0) {
-                src0_ddq_i -= (row_low % ne01)*ne00*src0_ts / src0_bs;
+                src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
                 src0_ddf_i -= (row_low % ne01)*ne00;
             }
             if (i0 - i0_offset_low > 0) {
@@ -996,11 +900,9 @@ bool ggml_cuda_can_mul(const struct ggml_tensor * src0, const struct ggml_tensor
     return src1->backend == GGML_BACKEND_GPU;
 }
 
-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
+void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
-    (void) wdata;
-    (void) wsize;
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -1021,48 +923,22 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }
 
-bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
-    size_t src0_sz = ggml_nbytes(src0);
-    size_t src1_sz = ggml_nbytes(src1);
-
-    // mul_mat_q: src0 is converted to fp32 on device
-    size_t mul_mat_q_transfer = src0_sz + src1_sz;
-
-    // mul_mat_f16: src1 is converted to fp16 on cpu
-    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
-
-    // choose the smaller one to transfer to the device
-    // TODO: this is not always the best choice due to the overhead of converting to fp16
-    return mul_mat_f16_transfer < mul_mat_q_transfer;
-}
-
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
 
     if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
-    }
-    else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->type != GGML_TYPE_F16) { // FIXME fp16 mul mat vec
+    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
+        if (src1->ne[1] == 1) {
            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
        } else {
            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
        }
-    }
-    else {
+    } else {
         GGML_ASSERT(false);
     }
 }
 
-size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
-    }
-    else {
-        return 0;
-    }
-}
-
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
     FILE * fp = fopen(fname, "rb");
     int nrows = ggml_nrows(tensor);
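
Note the dispatch change above: previously, F16 src0 was excluded from the mat-vec kernel by the FIXME guard and fell through to cuBLAS even for single-column src1. With convert_f16 fixed, F16 now takes ggml_cuda_op_dequantize_mul_mat_vec whenever src1->ne[1] == 1, the shape used during token generation, which is presumably where the roughly 2x tokens/s in the commit title comes from.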
@@ -1179,6 +1055,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return true;
     }
-    func(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
+    func(tensor->src0, tensor->src1, tensor);
     return true;
 }

ggml.c

Lines changed: 0 additions & 1 deletion
@@ -14130,7 +14130,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                             node->n_tasks = 1; // TODO: this actually is doing nothing
                                                //       the threads are still spinning
-                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                         }
                         else
 #elif defined(GGML_USE_CLBLAST)
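
With ggml_cuda_mul_mat_get_wsize gone, the graph scheduler no longer reserves a CPU-side scratch buffer for CUDA matmuls; the matching wdata/wsize plumbing was removed from ggml_cuda_func_t and ggml_cuda_compute_forward in ggml-cuda.cu above.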
