1-
21#include " ggml.h"
32#include " common.cuh"
43#include " convert.cuh"
@@ -8,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
87static __global__ void mul_mat_vec_f (
98 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
109 const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
11- const uint3 channel_ratio_fd , const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
12- const uint3 sample_ratio_fd , const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
10+ const uint3 channel_ratio , const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11+ const uint3 sample_ratio , const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1312 const int row = blockIdx .x ;
1413 const int channel_dst = blockIdx .y ;
15- const int channel_x = ids ? ids[channel_dst] : fastdiv ((uint32_t ) channel_dst, channel_ratio_fd );
14+ const int channel_x = ids ? ids[channel_dst] : fastdiv ((uint32_t ) channel_dst, channel_ratio );
1615 const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1716 const int sample_dst = blockIdx .z ;
18- const int sample_x = fastdiv ((uint32_t ) sample_dst, sample_ratio_fd );
17+ const int sample_x = fastdiv ((uint32_t ) sample_dst, sample_ratio );
1918 const int sample_y = sample_dst;
2019 const int tid = threadIdx .x ;
2120
@@ -89,16 +88,14 @@ static __global__ void mul_mat_vec_f(
8988#endif // FP16_AVAILABLE
9089 }
9190 } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
92- const int * x2 = (const int *) x;
91+ const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
9392 for (int col2 = tid; col2 < ncols2; col2 += block_size) {
94- const int tmpx = x2[col2];
93+ const nv_bfloat162 tmpx = x2[col2];
9594#pragma unroll
9695 for (int j = 0 ; j < ncols_dst; ++j) {
9796 const float2 tmpy = y2[j*stride_col_y2 + col2];
98- const float tmpx0 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]);
99- const float tmpx1 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]);
100- ggml_cuda_mad (sumf[j], tmpx0, tmpy.x );
101- ggml_cuda_mad (sumf[j], tmpx1, tmpy.y );
97+ ggml_cuda_mad (sumf[j], tmpx.x , tmpy.x );
98+ ggml_cuda_mad (sumf[j], tmpx.y , tmpy.y );
10299 }
103100 }
104101 } else {
@@ -143,7 +140,7 @@ static void launch_mul_mat_vec_f_cuda(
143140 GGML_ASSERT (stride_col_y % 2 == 0 );
144141 GGML_ASSERT (ids || nchannels_dst % nchannels_x == 0 );
145142 GGML_ASSERT ( nsamples_dst % nsamples_x == 0 );
146- const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
143+ const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
147144 const uint3 sample_ratio_fd = init_fastdiv_values (nsamples_dst / nsamples_x);
148145
149146 const int device = ggml_cuda_get_device ();
0 commit comments