Skip to content

vulkan: Optimize argsort #15354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ enum vk_conv_shapes {
CONV_SHAPE_COUNT,
};

static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);

struct vk_device_struct {
std::recursive_mutex mutex;

Expand Down Expand Up @@ -499,7 +502,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
Expand Down Expand Up @@ -856,7 +859,6 @@ struct vk_op_soft_max_push_constants {

struct vk_op_argsort_push_constants {
uint32_t ncols;
uint32_t ncols_pad;
int32_t order;
};

Expand Down Expand Up @@ -3103,7 +3105,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}

ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
}

ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

Expand Down Expand Up @@ -7145,7 +7149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
case GGML_OP_ARGSORT:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argsort_f32;
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
}
return nullptr;
case GGML_OP_SUM:
Expand Down Expand Up @@ -8369,16 +8374,8 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c

uint32_t ncols = src0->ne[0];

uint32_t ncols_pad = 1;
while (ncols_pad < ncols) {
ncols_pad *= 2;
}

GGML_ASSERT(ncols_pad <= 1024);

ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
ncols,
ncols_pad,
op_params[0],
}, dryrun);
}
Expand Down Expand Up @@ -11189,6 +11186,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
case GGML_OP_CONCAT:
Expand All @@ -11198,7 +11197,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGMAX:
Expand Down
68 changes: 39 additions & 29 deletions ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
Original file line number Diff line number Diff line change
@@ -1,69 +1,79 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable

#include "types.comp"

#define BLOCK_SIZE 1024
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};

layout (push_constant) uniform parameter {
uint ncols;
uint ncols_pad;
uint order;
} p;

shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];

void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}

void main() {
void argsort(bool needs_bounds_check) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;

const uint row_offset = row * p.ncols;

// initialize indices
if (col < p.ncols_pad) {
dst_row[col] = col;
}
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
barrier();

for (uint k = 2; k <= p.ncols_pad; k *= 2) {
for (uint j = k / 2; j > 0; j /= 2) {
const uint ixj = col ^ j;
if (col < p.ncols_pad && ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= p.ncols ||
(dst_row[ixj] < p.ncols && (p.order == ASC ?
data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
) {
swap(col, ixj);
}
} else {
if (dst_row[ixj] >= p.ncols ||
(dst_row[col] < p.ncols && (p.order == ASC ?
data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
) {
swap(col, ixj);
}
}
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);

int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;

int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;

if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
}

barrier();
}
}

if (col < p.ncols) {
data_d[row_offset + col] = dst_row[col];
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
}
}
}

void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
} else {
argsort(true);
}
}
1 change: 1 addition & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6024,6 +6024,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
}

for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
Expand Down
Loading