@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
41794179    uint3   tgpig[[threadgroup_position_in_grid]],
41804180    uint3    tgpg[[threadgroups_per_grid]]);
41814181
4182+ 
// Function type for a 2D transposed-convolution kernel taking float weights
// (src0), float input (src1) and a byte-addressed destination.
//
// NOTE(review): this parameter list (tgpig + tgpg, float-only src0) does not
// match kernel_conv_transpose_2d, which additionally takes a threadgroup
// scratch buffer plus tpitg/ntg and is templated on the src0 element type.
// The signature is left as declared here; confirm whether anything still
// depends on this typedef before reconciling the two.
typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
4190+ 
// 2D transposed convolution.
//
// Work decomposition (as established by the indexing below):
//   - one threadgroup per output element: tgpig = (out_x, out_y, out_c)
//   - one thread per kernel tap:          tpitg = (kw, kh)
// Each thread accumulates its tap's contribution over all input channels,
// the partial sums are staged in threadgroup memory, and thread 0 adds them
// up and writes the single output value.
//
// Parameters:
//   args       - problem sizes (IC, IH, IW, KH, KW, OC), stride s0 (used for
//                both axes) and destination byte strides nb0/nb1/nb2
//   src0       - kernel weights, indexed as [in_c][out_c][kh][kw]
//   src1       - input, indexed as [in_c][in_y][in_x]
//   dst        - output, addressed in bytes through nb0/nb1/nb2, written as float
//   shared_sum - threadgroup scratch; must hold at least ntg.x*ntg.y floats
//
// NOTE(review): kw/kh are never checked against args.KW/args.KH, and the flat
// thread id ignores tpitg.z, so this presumably expects a (KW, KH, 1)
// threadgroup - confirm against the host-side dispatch.
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    // Whether this tap maps onto a real input sample depends only on the
    // output position and the tap, not on the input channel, so it is decided
    // once here instead of on every iteration of the channel loop.
    const int64_t off_y = out_y - kh;
    const int64_t off_x = out_x - kw;

    // The tap contributes only if the offset is non-negative and lands exactly
    // on a multiple of the stride.
    const bool aligned = off_y >= 0 && off_y % args.s0 == 0 &&
                         off_x >= 0 && off_x % args.s0 == 0;

    const int64_t in_y = aligned ? off_y / args.s0 : 0;
    const int64_t in_x = aligned ? off_x / args.s0 : 0;

    const bool contributes = aligned && in_y < args.IH && in_x < args.IW;

    float v = 0.0f;

    // Deliberately a guarded block rather than an early return: every thread
    // must still reach the threadgroup barrier below.
    if (contributes) {
        // Widen before multiplying so the products are formed in 64-bit
        // regardless of the (not visible here) field types of args.
        const int64_t input_plane   = (int64_t) args.IW * args.IH;
        const int64_t kernel_plane  = (int64_t) args.KH * args.KW;
        const int64_t kernel_volume = kernel_plane * args.OC;

        const int64_t input_base  = (int64_t) args.IW * in_y + in_x;
        const int64_t kernel_base = kernel_plane * out_c + (int64_t) args.KW * kh + kw;

        for (int64_t in_c = 0; in_c < args.IC; in_c++) {
            const int64_t input_idx  = input_plane   * in_c + input_base;
            const int64_t kernel_idx = kernel_volume * in_c + kernel_base;

            v += (float) src0[kernel_idx] * src1[input_idx];
        }
    }

    // Flat index of this thread within the (x, y) plane of the threadgroup.
    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    // Make every thread's partial sum visible before thread 0 reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}
4250+ 
// Explicit instantiation of kernel_conv_transpose_2d for float kernel weights
// (src0) and float input (src1), exported to the host under the name
// "kernel_conv_transpose_2d_f32_f32".
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
    constant ggml_metal_kargs_conv_transpose_2d & args,
    device const float * src0,
    device const float * src1,
    device        char * dst,
    threadgroup float * shared_sum [[threadgroup(0)]],
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3   tpitg[[thread_position_in_threadgroup]],
    uint3     ntg[[threads_per_threadgroup]]);
4261+ 
// Explicit instantiation of kernel_conv_transpose_2d for half-precision kernel
// weights (src0) and float input (src1), exported to the host under the name
// "kernel_conv_transpose_2d_f16_f32".
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
    constant ggml_metal_kargs_conv_transpose_2d & args,
    device const half  * src0,
    device const float * src1,
    device        char * dst,
    threadgroup float * shared_sum [[threadgroup(0)]],
    uint3   tgpig[[threadgroup_position_in_grid]],
    uint3   tpitg[[thread_position_in_threadgroup]],
    uint3     ntg[[threads_per_threadgroup]]);
4272+ 
41824273kernel void  kernel_upscale_f32 (
41834274    constant ggml_metal_kargs_upscale & args,
41844275    device  const  char  * src0,
0 commit comments