diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 1f6319d3f7..40b9b07a01 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -15,6 +15,8 @@ struct BaseArgument BaseArgument& operator=(const BaseArgument&) = default; virtual ~BaseArgument() {} + + void* p_workspace_ = nullptr; }; struct BaseInvoker @@ -42,7 +44,11 @@ struct BaseOperator virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } - virtual void SetWorkSpacePointer(BaseArgument*, void*) const {} + virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const final + { + assert(p_arg); + p_arg->p_workspace_ = p_workspace; + } virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index 0617b4fcb7..6dfa448fa8 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl { grid_size_ = 0; - gemm_descs_args_workspace_ = nullptr; + p_workspace_ = nullptr; group_count_ = ck::type_convert(gemm_shapes.size()); @@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl std::vector gemm_desc_kernel_arg_; - void* gemm_descs_args_workspace_; - index_t grid_size_; }; @@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl } hipGetErrorString( - hipMemcpy(arg.gemm_descs_args_workspace_, + hipMemcpy(arg.p_workspace_, arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), hipMemcpyHostToDevice)); @@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl CElementwiseOperation, true>; - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } else { @@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl CElementwiseOperation, false>; - ave_time = launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } return ave_time; @@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl { return dynamic_cast(p_arg)->group_count_ * sizeof(GemmDescKernelArg); } - - void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override - { - dynamic_cast(p_arg)->gemm_descs_args_workspace_ = workspace_ptr; - } }; } // namespace device