8989 GGML_METAL_DECL_KERNEL (get_rows_q6_K);
9090 GGML_METAL_DECL_KERNEL (get_rows_i32);
9191 GGML_METAL_DECL_KERNEL (get_rows_iq2_xxs);
92+ GGML_METAL_DECL_KERNEL (get_rows_iq2_xs);
9293 GGML_METAL_DECL_KERNEL (rms_norm);
9394 GGML_METAL_DECL_KERNEL (group_norm);
9495 GGML_METAL_DECL_KERNEL (norm);
108109 GGML_METAL_DECL_KERNEL (mul_mv_q5_K_f32);
109110 GGML_METAL_DECL_KERNEL (mul_mv_q6_K_f32);
110111 GGML_METAL_DECL_KERNEL (mul_mv_iq2_xxs_f32);
112+ GGML_METAL_DECL_KERNEL (mul_mv_iq2_xs_f32);
111113 GGML_METAL_DECL_KERNEL (mul_mv_id_f32_f32);
112114 // GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
113115 GGML_METAL_DECL_KERNEL (mul_mv_id_f16_f32);
124126 GGML_METAL_DECL_KERNEL (mul_mv_id_q5_K_f32);
125127 GGML_METAL_DECL_KERNEL (mul_mv_id_q6_K_f32);
126128 GGML_METAL_DECL_KERNEL (mul_mv_id_iq2_xxs_f32);
129+ GGML_METAL_DECL_KERNEL (mul_mv_id_iq2_xs_f32);
127130 GGML_METAL_DECL_KERNEL (mul_mm_f32_f32);
128131 GGML_METAL_DECL_KERNEL (mul_mm_f16_f32);
129132 GGML_METAL_DECL_KERNEL (mul_mm_q4_0_f32);
137140 GGML_METAL_DECL_KERNEL (mul_mm_q5_K_f32);
138141 GGML_METAL_DECL_KERNEL (mul_mm_q6_K_f32);
139142 GGML_METAL_DECL_KERNEL (mul_mm_iq2_xxs_f32);
143+ GGML_METAL_DECL_KERNEL (mul_mm_iq2_xs_f32);
140144 GGML_METAL_DECL_KERNEL (mul_mm_id_f32_f32);
141145 GGML_METAL_DECL_KERNEL (mul_mm_id_f16_f32);
142146 GGML_METAL_DECL_KERNEL (mul_mm_id_q4_0_f32);
150154 GGML_METAL_DECL_KERNEL (mul_mm_id_q5_K_f32);
151155 GGML_METAL_DECL_KERNEL (mul_mm_id_q6_K_f32);
152156 GGML_METAL_DECL_KERNEL (mul_mm_id_iq2_xxs_f32);
157+ GGML_METAL_DECL_KERNEL (mul_mm_id_iq2_xs_f32);
153158 GGML_METAL_DECL_KERNEL (rope_f32);
154159 GGML_METAL_DECL_KERNEL (rope_f16);
155160 GGML_METAL_DECL_KERNEL (alibi_f32);
@@ -385,6 +390,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
385390 GGML_METAL_ADD_KERNEL (get_rows_q6_K);
386391 GGML_METAL_ADD_KERNEL (get_rows_i32);
387392 GGML_METAL_ADD_KERNEL (get_rows_iq2_xxs);
393+ GGML_METAL_ADD_KERNEL (get_rows_iq2_xs);
388394 GGML_METAL_ADD_KERNEL (rms_norm);
389395 GGML_METAL_ADD_KERNEL (group_norm);
390396 GGML_METAL_ADD_KERNEL (norm);
@@ -404,6 +410,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
404410 GGML_METAL_ADD_KERNEL (mul_mv_q5_K_f32);
405411 GGML_METAL_ADD_KERNEL (mul_mv_q6_K_f32);
406412 GGML_METAL_ADD_KERNEL (mul_mv_iq2_xxs_f32);
413+ GGML_METAL_ADD_KERNEL (mul_mv_iq2_xs_f32);
407414 GGML_METAL_ADD_KERNEL (mul_mv_id_f32_f32);
408415 // GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
409416 GGML_METAL_ADD_KERNEL (mul_mv_id_f16_f32);
@@ -420,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
420427 GGML_METAL_ADD_KERNEL (mul_mv_id_q5_K_f32);
421428 GGML_METAL_ADD_KERNEL (mul_mv_id_q6_K_f32);
422429 GGML_METAL_ADD_KERNEL (mul_mv_id_iq2_xxs_f32);
430+ GGML_METAL_ADD_KERNEL (mul_mv_id_iq2_xs_f32);
423431 if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
424432 GGML_METAL_ADD_KERNEL (mul_mm_f32_f32);
425433 GGML_METAL_ADD_KERNEL (mul_mm_f16_f32);
@@ -434,6 +442,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
434442 GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
435443 GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
436444 GGML_METAL_ADD_KERNEL (mul_mm_iq2_xxs_f32);
445+ GGML_METAL_ADD_KERNEL (mul_mm_iq2_xs_f32);
437446 GGML_METAL_ADD_KERNEL (mul_mm_id_f32_f32);
438447 GGML_METAL_ADD_KERNEL (mul_mm_id_f16_f32);
439448 GGML_METAL_ADD_KERNEL (mul_mm_id_q4_0_f32);
@@ -447,6 +456,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
447456 GGML_METAL_ADD_KERNEL (mul_mm_id_q5_K_f32);
448457 GGML_METAL_ADD_KERNEL (mul_mm_id_q6_K_f32);
449458 GGML_METAL_ADD_KERNEL (mul_mm_id_iq2_xxs_f32);
459+ GGML_METAL_ADD_KERNEL (mul_mm_id_iq2_xs_f32);
450460 }
451461 GGML_METAL_ADD_KERNEL (rope_f32);
452462 GGML_METAL_ADD_KERNEL (rope_f16);
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
513523 GGML_METAL_DEL_KERNEL (get_rows_q6_K);
514524 GGML_METAL_DEL_KERNEL (get_rows_i32);
515525 GGML_METAL_DEL_KERNEL (get_rows_iq2_xxs);
526+ GGML_METAL_DEL_KERNEL (get_rows_iq2_xs);
516527 GGML_METAL_DEL_KERNEL (rms_norm);
517528 GGML_METAL_DEL_KERNEL (group_norm);
518529 GGML_METAL_DEL_KERNEL (norm);
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
532543 GGML_METAL_DEL_KERNEL (mul_mv_q5_K_f32);
533544 GGML_METAL_DEL_KERNEL (mul_mv_q6_K_f32);
534545 GGML_METAL_DEL_KERNEL (mul_mv_iq2_xxs_f32);
546+ GGML_METAL_DEL_KERNEL (mul_mv_iq2_xs_f32);
535547 GGML_METAL_DEL_KERNEL (mul_mv_id_f32_f32);
536548 // GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
537549 GGML_METAL_DEL_KERNEL (mul_mv_id_f16_f32);
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
548560 GGML_METAL_DEL_KERNEL (mul_mv_id_q5_K_f32);
549561 GGML_METAL_DEL_KERNEL (mul_mv_id_q6_K_f32);
550562 GGML_METAL_DEL_KERNEL (mul_mv_id_iq2_xxs_f32);
563+ GGML_METAL_DEL_KERNEL (mul_mv_id_iq2_xs_f32);
551564 if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
552565 GGML_METAL_DEL_KERNEL (mul_mm_f32_f32);
553566 GGML_METAL_DEL_KERNEL (mul_mm_f16_f32);
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
562575 GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
563576 GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
564577 GGML_METAL_DEL_KERNEL (mul_mm_iq2_xxs_f32);
578+ GGML_METAL_DEL_KERNEL (mul_mm_iq2_xs_f32);
565579 GGML_METAL_DEL_KERNEL (mul_mm_id_f32_f32);
566580 GGML_METAL_DEL_KERNEL (mul_mm_id_f16_f32);
567581 GGML_METAL_DEL_KERNEL (mul_mm_id_q4_0_f32);
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
575589 GGML_METAL_DEL_KERNEL (mul_mm_id_q5_K_f32);
576590 GGML_METAL_DEL_KERNEL (mul_mm_id_q6_K_f32);
577591 GGML_METAL_DEL_KERNEL (mul_mm_id_iq2_xxs_f32);
592+ GGML_METAL_DEL_KERNEL (mul_mm_id_iq2_xs_f32);
578593 }
579594 GGML_METAL_DEL_KERNEL (rope_f32);
580595 GGML_METAL_DEL_KERNEL (rope_f16);
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
15611576 case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q5_K_f32]; break ;
15621577 case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
15631578 case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_mul_mm_iq2_xxs_f32]; break ;
1579+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_mul_mm_iq2_xs_f32]; break ;
15641580 default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
15651581 }
15661582 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
16791695 nth1 = 16 ;
16801696 [encoder setComputePipelineState: ctx->pipeline_mul_mv_iq2_xxs_f32];
16811697 } break ;
1698+ case GGML_TYPE_IQ2_XS:
1699+ {
1700+ nth0 = 4 ;
1701+ nth1 = 16 ;
1702+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_iq2_xs_f32];
1703+ } break ;
16821704 default :
16831705 {
16841706 GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src0t);
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
17121734
17131735 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
17141736 src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1715- // src0t == GGML_TYPE_IQ2_XXS ||
17161737 src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
17171738 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
17181739 }
1719- else if (src0t == GGML_TYPE_IQ2_XXS) {
1720- [encoder setThreadgroupMemoryLength: (256 *8 +128 ) atIndex: 0 ];
1740+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256 *8 +128 : 512 *8 +128 ;
1742+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
17211743 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
17221744 }
17231745 else if (src0t == GGML_TYPE_Q4_K) {
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
18101832 case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q5_K_f32]; break ;
18111833 case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q6_K_f32]; break ;
18121834 case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break ;
1835+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_iq2_xs_f32]; break ;
18131836 default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
18141837 }
18151838 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
19311954 nth1 = 16 ;
19321955 [encoder setComputePipelineState: ctx->pipeline_mul_mv_id_iq2_xxs_f32];
19331956 } break ;
1957+ case GGML_TYPE_IQ2_XS:
1958+ {
1959+ nth0 = 4 ;
1960+ nth1 = 16 ;
1961+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962+ } break ;
19341963 default :
19351964 {
19361965 GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src2t);
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
19802009
19812010 if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
19822011 src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1983- // src2t == GGML_TYPE_IQ2_XXS ||
19842012 src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
19852013 [encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
19862014 }
1987- else if (src2t == GGML_TYPE_IQ2_XXS) {
1988- [encoder setThreadgroupMemoryLength: (256 *8 +128 ) atIndex: 0 ];
2015+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256 *8 +128 : 512 *8 +128 ;
2017+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
19892018 [encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
19902019 }
19912020 else if (src2t == GGML_TYPE_Q4_K) {
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
20262055 case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_get_rows_q6_K]; break ;
20272056 case GGML_TYPE_I32: [encoder setComputePipelineState: ctx->pipeline_get_rows_i32]; break ;
20282057 case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_get_rows_iq2_xxs]; break ;
2058+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_get_rows_iq2_xs]; break ;
20292059 default : GGML_ASSERT (false && " not implemented" );
20302060 }
20312061
0 commit comments