@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203 GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204 GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205214 GGML_METAL_KERNEL_TYPE_RMS_NORM,
206215 GGML_METAL_KERNEL_TYPE_L2_NORM,
207216 GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ @implementation GGMLMetalClass
11691178 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
11701179 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true );
11711180 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
1181+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
1182+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
1183+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1184+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true );
1185+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true );
1186+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true );
1187+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
1188+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1189+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
11721190 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11731191 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11741192 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16351653 const bool use_bfloat = ctx_dev->use_bfloat ;
16361654
16371655 if (!use_bfloat) {
1656+ if (op->type == GGML_TYPE_BF16) { // reached only when use_bfloat is false: an op whose own (result) type is BF16 is reported as unsupported
1657+ return false ;
1658+ }
1659+
16381660 for (size_t i = 0 , n = 3 ; i < n; ++i) {
16391661 if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
16401662 return false ;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18041826 {
18051827 return op->ne [3 ] == 1 ;
18061828 }
1829+ case GGML_OP_SET_ROWS: // supported only when src[0] is F32 and the op's own type is one of the types listed below
1830+ {
1831+ if (op->src [0 ]->type != GGML_TYPE_F32) { // any non-F32 first source is rejected
1832+ return false ;
1833+ }
1834+
1835+ switch (op->type ) { // op->type is the type of the op's result tensor; the accepted set matches the SET_ROWS_* kernel types
1836+ case GGML_TYPE_F32:
1837+ case GGML_TYPE_F16:
1838+ case GGML_TYPE_BF16:
1839+ case GGML_TYPE_Q8_0:
1840+ case GGML_TYPE_Q4_0:
1841+ case GGML_TYPE_Q4_1:
1842+ case GGML_TYPE_Q5_0:
1843+ case GGML_TYPE_Q5_1:
1844+ case GGML_TYPE_IQ4_NL:
1845+ return true ;
1846+ default : // every other result type is unsupported for SET_ROWS
1847+ return false ;
1848+ };
1849+ }
18071850 default :
18081851 return false ;
18091852 }
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
37773820 };
37783821
37793822 [encoder setComputePipelineState: pipeline];
3780- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3781- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3782- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3783- [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
3823+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3824+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3825+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3826+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
37843827
37853828 [encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
37863829 } break ;
3830+ case GGML_OP_SET_ROWS: // encode SET_ROWS: pick the kernel from dst->type, size the threadgroup from nk0, bind args and buffers, dispatch
3831+ {
3832+ id <MTLComputePipelineState > pipeline = nil ;
3833+
3834+ switch (dst->type ) { // one pipeline per destination type
3835+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline ; break ;
3836+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline ; break ;
3837+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline ; break ;
3838+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline ; break ;
3839+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline ; break ;
3840+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline ; break ;
3841+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline ; break ;
3842+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline ; break ;
3843+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline ; break ;
3844+ default : GGML_ABORT (" not implemented" ); // any dst type without a kernel above aborts
3845+ }
3846+
3847+ const int32_t nk0 = ne0/ggml_blck_size (dst->type ); // ne0 divided by the block size of dst->type (presumably blocks per dst row - TODO confirm where ne0 is declared)
3848+
3849+ int nth = 32 ; // SIMD width
3850+
3851+ while (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) { // double nth until it reaches nk0 or the pipeline's per-threadgroup thread limit
3852+ nth *= 2 ;
3853+ }
3854+
3855+ int nrptg = 1 ; // used below as the threadgroup's y extent and as the divisor of ne01 in the grid size
3856+ if (nth > nk0) { // more threads than nk0: shrink x to nk0 and spend the surplus on the y extent
3857+ nrptg = (nth + nk0 - 1 )/nk0; // ceil(nth/nk0), computed with the pre-shrink nth
3858+ nth = nk0;
3859+
3860+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) { // rounding up may exceed the limit; step nrptg back once
3861+ nrptg--;
3862+ }
3863+ }
3864+
3865+ nth = MIN (nth, nk0); // NOTE(review): appears redundant - after the loop and branch above nth is never greater than nk0
3866+
3867+ ggml_metal_kargs_set_rows args = { // kernel arguments, passed by value at buffer index 0 below
3868+ /* .nk0 =*/ nk0,
3869+ /* .ne01 =*/ ne01,
3870+ /* .nb01 =*/ nb01,
3871+ /* .nb02 =*/ nb02,
3872+ /* .nb03 =*/ nb03,
3873+ /* .ne11 =*/ ne11,
3874+ /* .ne12 =*/ ne12,
3875+ /* .nb10 =*/ nb10,
3876+ /* .nb11 =*/ nb11,
3877+ /* .nb12 =*/ nb12,
3878+ /* .nb1 =*/ nb1,
3879+ /* .nb2 =*/ nb2,
3880+ /* .nb3 =*/ nb3,
3881+ };
3882+
3883+ [encoder setComputePipelineState: pipeline];
3884+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ]; // binding order: 0 = args, 1 = src0, 2 = src1, 3 = dst
3885+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3886+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3887+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3888+
3889+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )]; // grid: ceil(ne01/nrptg) x ne02 x ne03 threadgroups, each nth x nrptg x 1 threads
3890+ } break ;
37873891 case GGML_OP_RMS_NORM:
37883892 {
37893893 GGML_ASSERT (ne00 % 4 == 0 );
0 commit comments