@@ -986,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
986986 return _mm_div_ps (x , one_plus_exp_neg_x );
987987}
988988
989- #endif // __ARM_NEON / __AVX2__ / __SSE2__
989+ #elif defined(__riscv_v_intrinsic )
990+
991+ // adapted from arm limited optimized routine
992+ // the maximum error is 1.45358 plus 0.5 ulps
993+ // numbers above 88.38 will flush to infinity
994+ // numbers beneath -103.97 will flush to zero
995+ inline static vfloat32m2_t ggml_v_expf_m2 (vfloat32m2_t x , int vl ) {
996+ const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2 (0x1.8p23f , vl );
997+ #ifdef __riscv_xtheadvector
998+ // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
999+ vfloat32m2_t z = __riscv_vfadd_vf_f32m2 (r , 0.0f , vl );
1000+ z = __riscv_vfmacc_vf_f32m2 (z , 0x1.715476p+0f , x , vl );
1001+ #else
1002+ const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2 (r , 0x1.715476p+0f , x , vl );
1003+ #endif
1004+ const vfloat32m2_t n = __riscv_vfsub_vv_f32m2 (z , r , vl );
1005+ const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2 (__riscv_vfnmsac_vf_f32m2 (x , 0x1.62e4p-1f , n , vl ),
1006+ 0x1.7f7d1cp-20f , n , vl );
1007+ const vuint32m2_t e = __riscv_vsll_vx_u32m2 (__riscv_vreinterpret_v_f32m2_u32m2 (z ), 23 , vl );
1008+ const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vadd_vx_u32m2 (e , 0x3f800000 , vl )); // 1.0f
1009+ const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16 (__riscv_vfabs_v_f32m2 (n , vl ), 126.0f , vl );
1010+ const vfloat32m2_t u = __riscv_vfmul_vv_f32m2 (b , b , vl );
1011+ const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2 (
1012+ __riscv_vfmul_vf_f32m2 (b , 0x1.ffffecp-1f , vl ),
1013+ __riscv_vfmacc_vv_f32m2 (
1014+ __riscv_vfmacc_vf_f32m2 (__riscv_vfmv_v_f_f32m2 (0x1.fffdb6p-2f , vl ), 0x1.555e66p-3f , b , vl ),
1015+ __riscv_vfmacc_vf_f32m2 (__riscv_vfmv_v_f_f32m2 (0x1.573e2ep-5f , vl ), 0x1.0e4020p-7f , b , vl ),
1016+ u , vl ), u , vl );
1017+ if (!__riscv_vcpop_m_b16 (c , vl ))
1018+ return __riscv_vfmacc_vv_f32m2 (k , j , k , vl );
1019+ const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16 (n , 0.0f , vl );
1020+ const vuint32m2_t d = __riscv_vmerge_vxm_u32m2 (__riscv_vmv_v_x_u32m2 (0 , vl ), 0x82000000 , dm , vl );
1021+ const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vadd_vx_u32m2 (d , 0x7f000000 , vl ));
1022+ const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vsub_vv_u32m2 (e , d , vl ));
1023+ const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2 (
1024+ __riscv_vfmacc_vv_f32m2 (k , k , j , vl ),
1025+ __riscv_vfmul_vv_f32m2 (__riscv_vfmacc_vv_f32m2 (s2 , s2 , j , vl ), s1 , vl ),
1026+ c , vl );
1027+ return __riscv_vmerge_vvm_f32m2 (
1028+ r1 , __riscv_vfmul_vv_f32m2 (s1 , s1 , vl ),
1029+ __riscv_vmfgt_vf_f32m2_b16 (__riscv_vfabs_v_f32m2 (n , vl ), 192.0f , vl ),
1030+ vl );
1031+ }
1032+
1033+ #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
9901034
9911035inline static void ggml_vec_silu_f16 (const int n , ggml_fp16_t * y , const ggml_fp16_t * x ) {
9921036 for (int i = 0 ; i < n ; ++ i ) {
0 commit comments