@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
45884588struct  test_sum  : public  test_case  {
45894589    const  ggml_type type;
45904590    const  std::array<int64_t , 4 > ne;
4591+     const  std::array<int64_t , 4 > permute;
4592+     bool  _use_permute;
45914593
45924594    std::string vars () override  {
4593-         return  VARS_TO_STR2 (type, ne);
4595+         std::string v = VARS_TO_STR2 (type, ne);
4596+         if  (_use_permute) v += " ," VAR_TO_STR (permute);
4597+         return  v;
45944598    }
45954599
45964600    test_sum (ggml_type type = GGML_TYPE_F32,
4597-             std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 })
4598-         : type(type), ne(ne) {}
4601+             std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
4602+             std::array<int64_t , 4 > permute = {0 , 0 , 0 , 0 })
4603+         : type(type), ne(ne), permute(permute),
4604+             _use_permute (permute[0 ] + permute[1 ] + permute[2 ] + permute[3 ] > 0 ) {}
45994605
46004606    ggml_tensor * build_graph (ggml_context * ctx) override  {
46014607        ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
46024608        ggml_set_param (a);
46034609        ggml_set_name (a, " a" 
46044610
4611+         if  (_use_permute) {
4612+             a = ggml_permute (ctx, a, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
4613+             ggml_set_name (a, " a_permuted" 
4614+         }
4615+ 
46054616        ggml_tensor * out = ggml_sum (ctx, a);
46064617        ggml_set_name (out, " out" 
46074618
@@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
67246735
67256736    test_cases.emplace_back (new  test_sum ());
67266737    test_cases.emplace_back (new  test_sum_rows ());
6738+     test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 2 , 1 , 3 }));  //  row-contiguous but non-contiguous
6739+     test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 3 , 2 , 1 }));
6740+     test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, {11 , 5 , 6 , 3 }, {0 , 1 , 3 , 2 }));
67276741    test_cases.emplace_back (new  test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3  }, true , false ));
67286742    test_cases.emplace_back (new  test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3  }, false , true ));
67296743    test_cases.emplace_back (new  test_sum_rows (GGML_TYPE_F32, { 11 , 5 , 6 , 3  }, true , true ));
@@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
67346748    test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, { 33 , 1024 , 1 , 1  }));
67356749    test_cases.emplace_back (new  test_sum_rows (GGML_TYPE_F32, { 33 , 1024 , 1 , 1  }));
67366750    test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, { 33 , 256 , 1 , 1  }));
6751+     test_cases.emplace_back (new  test_sum (GGML_TYPE_F32, { 33 , 256 , 1 , 1  }, { 1 , 0 , 2 , 3  })); //  sum dst not-contiguous
67376752    test_cases.emplace_back (new  test_sum_rows (GGML_TYPE_F32, { 33 , 256 , 1 , 1  }));
67386753    test_cases.emplace_back (new  test_mean (GGML_TYPE_F32, { 33 , 256 , 1 , 1  }));
67396754    test_cases.emplace_back (new  test_mean (GGML_TYPE_F32, { 32769 , 1 , 1 , 1  }));
0 commit comments