Skip to content

Commit 63cdd92

Browse files
authored
use universal workspace pointer in bwd-weight (#286)
1 parent c7a96ed commit 63cdd92

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -900,9 +900,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
900900
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
901901
c_grid_desc_m_n_ = descs[I2];
902902

903-
// init work space
904-
p_c_workspace_grid_ = nullptr;
905-
906903
block_2_ctile_map_ =
907904
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
908905

@@ -939,9 +936,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
939936
std::vector<index_t> input_left_pads_;
940937
std::vector<index_t> input_right_pads_;
941938
index_t k_batch_;
942-
943-
// external work space
944-
void* p_c_workspace_grid_;
945939
};
946940

947941
// Invoker
@@ -1017,7 +1011,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10171011
// run kernel for bf16 with splitk
10181012
const auto run_bf16_splitk = [&](const auto& kernel) {
10191013
hipGetErrorString(hipMemset(
1020-
arg.p_c_workspace_grid_,
1014+
arg.p_workspace_,
10211015
0,
10221016
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
10231017
sizeof(AccDataType)));
@@ -1030,7 +1024,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10301024
0,
10311025
arg.p_a_grid_,
10321026
arg.p_b_grid_,
1033-
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
1027+
static_cast<AccDataType*>(arg.p_workspace_),
10341028
arg.a_grid_desc_kbatch_k0_m_k1_,
10351029
arg.b_grid_desc_kbatch_k0_n_k1_,
10361030
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -1072,7 +1066,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10721066
dim3(type_convert_grid_size),
10731067
dim3(256),
10741068
0,
1075-
static_cast<AccDataType*>(arg.p_c_workspace_grid_),
1069+
static_cast<AccDataType*>(arg.p_workspace_),
10761070
p_c_grid_tmp_bf16_,
10771071
a_grid_desc_m0_,
10781072
b_grid_desc_m0_,
@@ -1448,11 +1442,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
14481442
{
14491443
return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast<const Argument*>(p_arg));
14501444
}
1451-
1452-
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
1453-
{
1454-
dynamic_cast<Argument*>(p_arg)->p_c_workspace_grid_ = workspace_ptr;
1455-
}
14561445
};
14571446

14581447
} // namespace device

0 commit comments

Comments
 (0)