@@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ * @brief Get or expand a cached tensor filled with a scalar value.
  *
- * This function manages cached device memory for float32 tensors. If the current
+ * This function manages cached device memory for tensors. If the current
  * cache size is insufficient for the requested tensor shape, the old memory will
- * be released and new memory will be allocated. The allocated buffer is then
- * initialized either with zeros (when @p value == 0.0f) or with the given scalar
- * value using CANN operations. Finally, an aclTensor object is created from the
- * cached memory and returned.
+ * be released and new memory will be allocated. The allocated buffer is
+ * initialized with the given scalar value using CANN operations.
+ * Finally, an aclTensor object is created from the cached memory and returned.
  *
  * @param ctx The CANN backend context that manages device memory.
  * @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,25 +909,27 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  *                      updated when the cache is expanded.
  * @param ne The tensor shape array (number of elements in each dimension).
  * @param nb The stride size for each dimension.
+ * @param dtype Data type of the cached tensor.
  * @param dims The number of tensor dimensions.
  * @param value The scalar value used to fill the tensor (supports zero
  *              initialization via memset or arbitrary values via fill_scalar).
  * @return An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_cache_acl_tensor(
+static aclTensor* get_cache_acl_tensor(
     ggml_backend_cann_context& ctx,
     void** buffer,
     int64_t &cache_element,
     int64_t* ne,
     size_t* nb,
+    ggml_type dtype,
     int64_t dims,
     float value) {
     // Calculate total number of elements
     int64_t n_element = 1;
     for (int i = 0; i < dims; i++) {
         n_element *= ne[i];
     }
-    size_t size = n_element * sizeof(float);
+    size_t size = n_element * ggml_type_size(dtype);
 
     // Allocate or expand cache if needed
     if (cache_element < n_element) {
@@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor(
         cache_element = n_element;
 
         // Initialize cache
-        if (value == 0.0f) {
-            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
-        } else {
-            int64_t pool_ne[1] = { n_element };
-            size_t pool_nb[1] = { sizeof(float) };
-            aclTensor* acl_value = ggml_cann_create_tensor(
-                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
-            aclnn_fill_scalar(ctx, 1, acl_value);
-            ggml_cann_release_resources(ctx, acl_value);
-        }
+        int64_t pool_ne[1] = { n_element };
+        size_t pool_nb[1] = { ggml_type_size(dtype) };
+        aclTensor* acl_value = ggml_cann_create_tensor(
+            *buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
+            pool_ne, pool_nb, 1);
+        aclnn_fill_scalar(ctx, value, acl_value);
+        ggml_cann_release_resources(ctx, acl_value);
     }
 
-    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
+                                   ggml_type_size(dtype), ne, nb, dims);
 }
 
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
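As an aside for reviewers, here is a minimal, hypothetical sketch of how the reworked helper is intended to be called. The helper name, parameter order, and the ggml/CANN identifiers come from the hunks above; `cache_buf` and `cache_elems` stand in for the cache fields kept in the backend context and are placeholders, not code from this change.

    // Obtain (or grow) a cached device tensor of 4096 F16 elements filled with 1.0f.
    // The new dtype parameter drives both the element size and the ACL data type;
    // a 0.0f fill no longer takes a separate memset path.
    int64_t ne[1] = { 4096 };                          // assumed shape, for illustration only
    size_t  nb[1] = { ggml_type_size(GGML_TYPE_F16) };
    aclTensor* ones = get_cache_acl_tensor(
        ctx,             // ggml_backend_cann_context&
        &cache_buf,      // void**   : cached device buffer, reallocated when it grows
        cache_elems,     // int64_t& : number of elements the cache currently holds
        ne, nb,
        GGML_TYPE_F16,   // dtype
        1,               // dims
        1.0f);           // fill value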
@@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    // build gamma, one...
+    // build gamma.
     size_t acl_gamma_nb[GGML_MAX_DIMS];
-    acl_gamma_nb[0] = sizeof(float);
+    // gamma's type is the same as dst's.
+    acl_gamma_nb[0] = ggml_type_size(dst->type);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+    aclTensor* acl_gamma = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_one_tensor_cache.cache,
         ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
+        dst->type,
         1,      // dims
         1.0f    // value
     );
 
-    // build rstd, zero...
+    // build rstd.
     int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
     size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
+    // rstd will always be F32.
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+    aclTensor* acl_rstd = get_cache_acl_tensor(
         ctx,
         &ctx.rms_norm_zero_tensor_cache.cache,
         ctx.rms_norm_zero_tensor_cache.size,
         acl_rstd_ne,
         acl_rstd_nb,
+        GGML_TYPE_F32,
         GGML_MAX_DIMS - 1,
         0.0f    // value
     );
@@ -1765,41 +1768,43 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // src
     ggml_tensor* src1 = dst->src[1];  // index
 
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+    // src0 type must be F32, F16, or a supported quantized format
     switch (src0->type) {
-        case GGML_TYPE_F32: {
-            aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
-                                  dst->data, dst->ne, dst->nb,
-                                  src1, dst->type);
-            break;
-        }
-        case GGML_TYPE_F16: {
-            aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-            void* src_trans_buffer = src_buffer_allocator.get();
-            size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = sizeof(float);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+            if (src0->type == dst->type) {
+                aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
+                                      dst->data, dst->ne, dst->nb,
+                                      src1, dst->type);
+            } else {
+                aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+                ggml_cann_pool_alloc src_buffer_allocator(
+                    ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
+                void* src_trans_buffer = src_buffer_allocator.get();
+                size_t src_trans_nb[GGML_MAX_DIMS];
+                src_trans_nb[0] = dst->nb[0];
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+                }
+                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                    src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+                    src0->ne, src_trans_nb, GGML_MAX_DIMS);
+                aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+                aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                      dst->data, dst->ne, dst->nb,
+                                      src1, dst->type);
+                ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             }
-            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
-                src0->ne, src_trans_nb, GGML_MAX_DIMS);
-            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-            aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
-                                  dst->data, dst->ne, dst->nb,
-                                  src1, dst->type);
-            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
-        }
         case GGML_TYPE_Q8_0: {
             // add 1 dim for bcast mul.
             size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
                    dequant_nb[GGML_MAX_DIMS + 1];
             int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
                     *dequant_ne;
             int64_t scale_offset = 0;
-
             // [3,4,5,64] -> [3,4,5,2,32]
             weight_ne[0] = QK8_0;
             weight_ne[1] = src0->ne[0] / QK8_0;
@@ -1809,7 +1814,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 weight_ne[i] = src0->ne[i - 1];
                 weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,1]
             scale_ne[0] = 1;
             scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,35 +1823,30 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 scale_ne[i] = src0->ne[i - 1];
                 scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
             }
-
             // [3,4,5,64] -> [3,4,5,2,32]
             dequant_ne = weight_ne;
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
             }
-
             scale_offset = ggml_nelements(src0) * sizeof(int8_t);
             ggml_cann_pool_alloc dequant_buffer_allocator(
-                ctx.pool(), ggml_nelements(src0) * sizeof(float));
-
+                ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
                 GGML_MAX_DIMS + 1);
             aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
                 src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
                 GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
             aclTensor* dequant_tensor = ggml_cann_create_tensor(
-                dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
+                dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
                 dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
-
             aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
-            dequant_nb[0] = sizeof(float);
+            dequant_nb[0] = ggml_type_size(dst->type);
             dequant_ne = src0->ne;
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }
-
             aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
                                   dequant_ne, dequant_nb,
                                   dst->data, dst->ne, dst->nb,
@@ -1859,6 +1858,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         default:
             GGML_ABORT("Unsupported tensor type for GGML_OP_GET_ROWS");
             break;
+
     }
 }
 
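For reference while reading the Q8_0 branch: every block of QK8_0 (32) int8 quants shares a single fp16 scale, and dequantization is just scale times quant, which is what the broadcast aclnn_mul of the [..,2,32] weight view with the [..,2,1] scale view computes. Below is an illustrative host-side sketch of the same arithmetic, assuming the layout implied by scale_offset above (all int8 quants first, then all fp16 scales); dequant_q8_0_ref is a hypothetical helper, not part of this change.

    // Illustrative only: dequantize n Q8_0-packed values on the host.
    //   qs         : n int8 quants
    //   scales_f16 : n / QK8_0 fp16 scales stored right after the quants
    static void dequant_q8_0_ref(const int8_t* qs, const ggml_fp16_t* scales_f16,
                                 float* out, int64_t n) {
        for (int64_t i = 0; i < n; i++) {
            const float d = ggml_fp16_to_fp32(scales_f16[i / QK8_0]); // per-block scale
            out[i] = d * (float) qs[i];                               // scale * quant
        }
    }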
@@ -1965,16 +1965,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // Only check env once.
     static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
     if (weight_to_nz && is_matmul_weight(weight)) {
-        int64_t acl_stride[2] = {1, transpose_ne[1]};
-
-        // Reverse ne.
-        std::reverse(transpose_ne, transpose_ne + n_dims);
-
-        std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
-
-        acl_weight_tensor = aclCreateTensor(
-            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
-            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+        acl_weight_tensor =
+            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
     } else {
         acl_weight_tensor =
             ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3170,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src0_f16_tensor = nullptr;
         aclTensor* acl_src1_f16_tensor = nullptr;
         aclTensor* acl_src2_f16_tensor = nullptr;
-        aclTensor* acl_dst_f16_tensor = nullptr;
 
         // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3207,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
                                                        src2_bsnd_nb, GGML_MAX_DIMS);
 
-        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        void* out_f16_buffer = out_f16_allocator.alloc(
-            ggml_nelements(dst) * faElemSize);
-
-        int64_t* out_f16_ne = src0_bsnd_ne;
-        size_t out_f16_nb[GGML_MAX_DIMS];
-        out_f16_nb[0] = faElemSize;
-        for (int i = 1; i < GGML_MAX_DIMS; ++i){
-            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
-        }
-
-        acl_dst_f16_tensor = ggml_cann_create_tensor(
-            out_f16_buffer, faDataType, faElemSize,
-            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
-        );
-
         // Step 3: create the PSEShift tensor if needed
         //         this tensor is considered as mask (f16) in the llama.cpp
         aclTensor* bcast_pse_tensor = nullptr;
@@ -3334,8 +3309,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;
 
-        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
-        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+        GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+        aclTensor * fa_dst_tensor = nullptr;
+        aclTensor * acl_dst_tensor = nullptr;
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        if (dst->type == GGML_TYPE_F32) {
+            void* out_f16_buffer = out_f16_allocator.alloc(
+                ggml_nelements(dst) * faElemSize);
+
+            int64_t* out_f16_ne = src0_bsnd_ne;
+            size_t out_f16_nb[GGML_MAX_DIMS];
+            out_f16_nb[0] = faElemSize;
+            for (int i = 1; i < GGML_MAX_DIMS; ++i){
+                out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+            }
+
+            fa_dst_tensor = ggml_cann_create_tensor(
+                out_f16_buffer, faDataType, faElemSize,
+                out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+            );
+        }
+        else {
+            fa_dst_tensor = ggml_cann_create_tensor(dst);
+        }
 
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3353,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             blockSize, antiquantMode,                           // blockSize, antiquantMode
             softmaxLseFlag,                                     // softmaxLseFlag
             keyAntiquantMode, valueAntiquantMode,               // keyAntiqMode, valueAntiqMode
-            acl_dst_f16_tensor,                                 // attentionOut
+            fa_dst_tensor,                                      // attentionOut
             nullptr                                             // softmaxLse
         );
 
-        // Step 6: post-processing, permute and cast to f32
-        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-        // TODO: when dst is fp16, don't need cast
-        aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
-                                         acl_src1_f16_tensor,
-                                         acl_src2_f16_tensor,
-                                         acl_dst_f16_tensor,
-                                         acl_dst_tensor);
-        if (src3 != nullptr){
-            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        if (dst->type == GGML_TYPE_F32) {
+            // Step 6: post-processing, permute and cast to f32
+            aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+            aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
         }
-    }else {
+
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                         acl_src1_f16_tensor,
+                                         acl_src2_f16_tensor,
+                                         fa_dst_tensor,
+                                         acl_dst_tensor,
+                                         bcast_pse_tensor);
+
+    } else {
         GGML_ABORT("Function is not implemented.");
     }
 }