
Commit 9ad4f19

iliailmer and ggerganov authored
metal : add CONV_TRANSPOSE_2D (ggml-org#16542)
* initial: headers and metal-device.cpp updates
* adding conv_transpose_2d
* fix type
* fix type: int32->int64
* Update ggml/src/ggml-metal/ggml-metal.metal (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml/src/ggml-metal/ggml-metal.metal (Co-authored-by: Georgi Gerganov <[email protected]>)
* Update ggml/src/ggml-metal/ggml-metal.metal (Co-authored-by: Georgi Gerganov <[email protected]>)
* add checks for src[0] and src[1]; add type checks
* Update ggml-metal.metal (Co-authored-by: Georgi Gerganov <[email protected]>)
* add more tests, add optimization to threading
* add dynamic memory allocation in metal

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 79967ec commit 9ad4f19

File tree

8 files changed: +198 −0 lines changed
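For orientation, this is roughly how a graph reaches the new kernel through the public ggml API (a hypothetical host-side sketch, not part of this commit; ggml_conv_transpose_2d_p0 is the existing zero-padding op in ggml.h, and shapes follow ggml's [KW, KH, OC, IC] kernel / [IW, IH, IC, N] input conventions):

    // hypothetical usage sketch, assuming the public ggml API
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  3,  3,  16, 256); // [KW, KH, OC, IC]
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 256,   1); // [IW, IH, IC, N]

    // stride-2 transposed convolution, no padding;
    // output is [(64-1)*2 + 3, (64-1)*2 + 3, 16, 1] = [129, 129, 16, 1]
    struct ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, 2);

On Metal, computing a graph containing this node now dispatches the new kernel_conv_transpose_2d_f16_f32 pipeline instead of falling back to another backend.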

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 25 additions & 0 deletions
@@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
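The getter follows the library's usual lookup-then-compile pattern, and because base and name coincide here, only two pipeline variants can ever be cached. A sketch of the name construction (assuming ggml_type_name returns "f16"/"f32" for these types):

    char name[256];
    snprintf(name, sizeof(name), "kernel_conv_transpose_2d_%s_%s", "f16", "f32");
    // -> "kernel_conv_transpose_2d_f16_f32"; an F32 kernel tensor yields "..._f32_f32"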

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad                (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d     (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 5 additions & 0 deletions
@@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SCALE:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
+                (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                op->src[1]->type == GGML_TYPE_F32 &&
+                op->type == GGML_TYPE_F32;
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
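When this predicate returns false for a node (for example, a non-contiguous src[1] view or a quantized kernel tensor), the ggml scheduler can place that op on another backend, typically the CPU, so these checks gate Metal offload rather than correctness.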

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 13 additions & 0 deletions
@@ -514,6 +514,19 @@ typedef struct {
     uint64_t nb1;
 } ggml_metal_kargs_conv_transpose_1d;
 
+typedef struct {
+    int32_t  IC;
+    int32_t  IH;
+    int32_t  IW;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  OC;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_conv_transpose_2d;
+
 typedef struct {
     uint64_t ofs0;
     uint64_t ofs1;
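This struct is compiled into both the host code and the Metal shader, so the two sides must agree on its layout: seven int32_t fields followed by three uint64_t strides, which under natural alignment puts the 64-bit members at offset 32 (after 4 bytes of padding) and the total size at 56 bytes. A host-side guard one could add, as a sketch rather than anything in this commit:

    static_assert(sizeof(ggml_metal_kargs_conv_transpose_2d) == 56,
                  "kargs layout must match the Metal shader's view of the struct");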

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 60 additions & 0 deletions
@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
             } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op, ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb,  op, nb);
+
+    const int32_t s0 = ((const int32_t *)(op->op_params))[0];
+
+    const int32_t IC = op->src[1]->ne[2];
+    const int32_t IH = op->src[1]->ne[1];
+    const int32_t IW = op->src[1]->ne[0];
+
+    const int32_t KH = op->src[0]->ne[1];
+    const int32_t KW = op->src[0]->ne[0];
+
+    const int32_t OW = op->ne[0];
+    const int32_t OH = op->ne[1];
+    const int32_t OC = op->ne[2];
+
+    ggml_metal_kargs_conv_transpose_2d args = {
+        /*.IC  =*/ IC,
+        /*.IH  =*/ IH,
+        /*.IW  =*/ IW,
+        /*.KH  =*/ KH,
+        /*.KW  =*/ KW,
+        /*.OC  =*/ OC,
+        /*.s0  =*/ s0,
+        /*.nb0 =*/ nb0,
+        /*.nb1 =*/ nb1,
+        /*.nb2 =*/ nb2,
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
+
+    // Metal requires the threadgroup memory size to be a multiple of 16 bytes
+    const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
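The dispatch launches one threadgroup per output element (an OW × OH × OC grid) with KW × KH threads each, and the reduction scratch rounds up to the next 16-byte boundary. A worked instance of the smem computation above for the common 3×3 kernel:

    // GGML_PAD rounds up to the next multiple of 16
    const size_t smem = GGML_PAD(3 * 3 * sizeof(float), 16); // 36 bytes -> 48 bytes
    // grid: OW x OH x OC threadgroups, each of 3 x 3 x 1 threads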

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pad_reflect_1d    (ggml_metal_op_t ctx, int idx);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 91 additions & 0 deletions
@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3    tgpg[[threadgroups_per_grid]]);
 
+
+typedef void (conv_transpose_2d_t)(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_2d(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const T     * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t out_x = tgpig[0];
+    const int64_t out_y = tgpig[1];
+    const int64_t out_c = tgpig[2];
+
+    const int64_t kw = tpitg[0];
+    const int64_t kh = tpitg[1];
+
+    float v = 0.0f;
+
+    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
+        int64_t in_y = out_y - kh;
+
+        if (in_y < 0 || in_y % args.s0) continue;
+
+        in_y /= args.s0;
+
+        if (in_y >= args.IH) continue;
+
+        int64_t in_x = out_x - kw;
+
+        if (in_x < 0 || in_x % args.s0) continue;
+
+        in_x /= args.s0;
+
+        if (in_x >= args.IW) continue;
+
+        const int64_t input_idx  = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
+        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
+
+        v += (float)src0[kernel_idx] * src1[input_idx];
+    }
+
+    const uint tid = tpitg.y * ntg.x + tpitg.x;
+    shared_sum[tid] = v;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        float total = 0.0f;
+        const uint num_threads = ntg.x * ntg.y;
+        for (uint i = 0; i < num_threads; i++) {
+            total += shared_sum[i];
+        }
+
+        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2);
+        dst_ptr[0] = total;
+    }
+}
+
+template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
+kernel void kernel_conv_transpose_2d<float>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
+kernel void kernel_conv_transpose_2d<half>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const half  * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
 kernel void kernel_upscale_f32(
         constant ggml_metal_kargs_upscale & args,
         device  const char * src0,
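The kernel uses a gather formulation: each threadgroup owns one output pixel (out_x, out_y, out_c), each of its threads takes one kernel tap (kw, kh), and a tap contributes only when out − k is non-negative and divisible by the stride; thread 0 then reduces the per-tap partial sums through threadgroup memory. The same summation in plain C, as a reference sketch for intuition rather than code from this commit (buffers assumed dense in exactly the layouts the kernel indexes):

    // reference transposed 2D convolution matching the Metal kernel's indexing
    // src0: kernel [KW, KH, OC, IC], src1: input [IW, IH, IC], dst: output [OW, OH, OC]
    static void conv_transpose_2d_ref(const float * src0, const float * src1, float * dst,
                                      int IW, int IH, int IC, int KW, int KH, int OC, int s0) {
        const int OW = (IW - 1)*s0 + KW;
        const int OH = (IH - 1)*s0 + KH;
        for (int oc = 0; oc < OC; oc++) {
        for (int oy = 0; oy < OH; oy++) {
        for (int ox = 0; ox < OW; ox++) {
            float v = 0.0f;
            for (int ic = 0; ic < IC; ic++) {
            for (int kh = 0; kh < KH; kh++) {
            for (int kw = 0; kw < KW; kw++) {
                const int iy = oy - kh; // input position that would have produced this tap
                const int ix = ox - kw;
                if (iy < 0 || ix < 0 || iy % s0 || ix % s0) continue;
                if (iy/s0 >= IH || ix/s0 >= IW)             continue;
                v += src0[(KH*KW*OC)*ic + (KH*KW)*oc + KW*kh + kw]
                   * src1[(IW*IH)*ic + IW*(iy/s0) + ix/s0];
            }}}
            dst[(OW*OH)*oc + OW*oy + ox] = v;
        }}}
    }

One quirk worth noting: in the Metal kernel the in_x/in_y divisibility tests sit inside the loop over in_c even though they do not depend on it, so a rejected tap simply skips every input channel.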

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
@@ -6989,6 +6989,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
 
     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
+    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
+    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
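With no padding, each spatial extent grows as (I − 1)·s + K, so the new stride-2 case doubles as a quick sanity check of the shape math (assuming the zero-padding p0 variant of the op):

    // input {10, 10, 9, 1}, kernel {3, 3, 1, 9}, stride 2:
    // OW = (10 - 1)*2 + 3 = 21, OH = (10 - 1)*2 + 3 = 21 -> output {21, 21, 1, 1}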
