@@ -3607,7 +3607,7 @@ struct test_flash_attn_ext : public test_case {
36073607
36083608 ggml_tensor * m = nullptr ;
36093609 if (mask) {
3610- m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[1 ], 1 );
3610+ m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[0 ], nr23[ 1 ] );
36113611 ggml_set_name (m, " m" );
36123612 }
36133613
@@ -4720,7 +4720,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
47204720 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {1 , 1 }, scale, max_bias));
47214721
47224722 if (ne0 <= 32 && ne1 <= 32 ) {
4723- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4723+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 3 }, mask, m_prec, {3 , 1 }, scale, max_bias));
47244724 test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {2 , 3 }, scale, max_bias));
47254725 }
47264726 }
0 commit comments