Commit cfc2603

junchao-loongson authored and taronaeo committed
ggml : fix LoongArch compile error with 128-bit SIMD (ggml-org#11701)
1 parent b4b2214 commit cfc2603
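This commit makes the 128-bit (LSX) SIMD path compile on its own: the lsx_* helper functions previously sat inside the 256-bit `#if defined(__loongarch_asx)` block, so a build with LSX but without LASX could not see them, and a few float-vector variables were initialized from integer-vector intrinsics without a cast. The diff below moves the helpers under their own guard and adds explicit (__m128) casts. A minimal sketch of the resulting layout (my reading of the diff, not text from the commit):

#if defined(__loongarch_sx)
// 128-bit helpers: lsx_packs_w/h, lsx_packus_h, lsx_maddubs_h, lsx_madd_h,
// lsx_set_w, lsx_shuffle_b, lsx_hadd_h/w/s, hsum_float_4x4
#endif

#if defined(__loongarch_asx)
// 256-bit lasx_* helpers only
#endif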


ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 91 additions & 78 deletions
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
+#if defined(__loongarch_sx)
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = lsx_hadd_s(a, b);
+    __m128 res_1 = lsx_hadd_s(c, d);
+    __m128 res = lsx_hadd_s(res_0, res_1);
+    res = lsx_hadd_s(res, res);
+    res = lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
 #if defined(__loongarch_asx)
 
 #ifdef __clang__
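As a quick sanity check on the relocated reduction helper, a hedged usage sketch (my illustration, not part of the commit; it assumes <lsxintrin.h> and the lsx_* helpers above are in scope, and uses the v4f32 vector typedef the file itself uses). hsum_float_4x4() horizontally adds all 16 float lanes of its four arguments, which is how the q4_0 x q8_0 dot product further down combines its four accumulators:

#include <lsxintrin.h>

// Hypothetical standalone check, not code from the commit.
static float demo_hsum(void) {
    v4f32 va = {1.0f, 2.0f, 3.0f, 4.0f};  // lanes sum to 10
    v4f32 vb = {5.0f, 6.0f, 7.0f, 8.0f};  // lanes sum to 26
    __m128 zero = (__m128)__lsx_vldi(0);  // explicit cast, as in the fix below
    // 10 + 26 + 0 + 0 = 36
    return hsum_float_4x4((__m128)va, (__m128)vb, zero, zero);
}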
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
     return (__m256i)__ret;
 }
 
-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
 static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
     v4i64 __ret = {d, c, b, a};
     return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
     return lasx_set_q(x, y);
 }
 
-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
 static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
     __m256i mask_f, zero, tmp0, tmp2, mask;
     int f = 0x8f;
@@ -467,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
     return ret;
 }
 
-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
 static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
     __m256i tmp1, tmp2;
     tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -514,42 +562,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }
 
-static __m128i lsx_packs_w(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_w(a, 15);
-    tmp1 = __lsx_vsat_w(b, 15);
-    return __lsx_vpickev_h(tmp1, tmp);
-}
-
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-
-static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_h_b(a, b);
-    tmp2 = __lsx_vmulwod_h_b(a, b);
-    return __lsx_vsadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_madd_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_w_h(a, b);
-    tmp2 = __lsx_vmulwod_w_h(a, b);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
     // Get absolute values of x vectors
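Several of the moved helpers (lsx_madd_h, lsx_maddubs_h) use the same even/odd widening-multiply idiom, since LSX has no single pmaddwd-style instruction: __lsx_vmulwev_w_h multiplies the even int16 lanes into int32, __lsx_vmulwod_w_h the odd lanes, and their sum matches x86's _mm_madd_epi16. A hedged scalar reference (my illustration, not code from the commit; the function name is hypothetical):

#include <stdint.h>

// Scalar model of lsx_madd_h: each 32-bit output lane is the sum of two
// adjacent 16-bit products.
static void madd_h_ref(const int16_t a[8], const int16_t b[8], int32_t out[4]) {
    for (int i = 0; i < 4; ++i) {
        out[i] = (int32_t)a[2*i] * b[2*i] + (int32_t)a[2*i + 1] * b[2*i + 1];
    }
}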
@@ -2281,21 +2293,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
     const __m128i off = __lsx_vreplgr2vr_b(8);
 
     // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);
 
     for (; ib + 1 < nb; ib += 2) {
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
@@ -2313,7 +2326,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
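The (__m128) casts added above are the compile fix itself: __lsx_vldi() and __lsx_vreplgr2vr_w() return an integer vector (__m128i), and assigning that to a float vector (__m128) is rejected as an incompatible vector conversion. A minimal standalone reproduction (my illustration, not code from the commit):

#include <lsxintrin.h>

void repro(void) {
    // __m128 bad = __lsx_vldi(0);        // rejected: __m128i does not convert to __m128
    __m128 good = (__m128)__lsx_vldi(0);  // explicit reinterpret between vector types
    (void)good;
}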
