@@ -900,9 +900,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
900900 b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
901901 c_grid_desc_m_n_ = descs[I2];
902902
903- // init work space
904- p_c_workspace_grid_ = nullptr ;
905-
906903 block_2_ctile_map_ =
907904 GridwiseGemm::MakeCBlockClusterAdaptor (c_grid_desc_m_n_, M01, N01, k_batch_);
908905
@@ -939,9 +936,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
939936 std::vector<index_t > input_left_pads_;
940937 std::vector<index_t > input_right_pads_;
941938 index_t k_batch_;
942-
943- // external work space
944- void * p_c_workspace_grid_;
945939 };
946940
947941 // Invoker
@@ -1017,7 +1011,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10171011 // run kernel for bf16 with splitk
10181012 const auto run_bf16_splitk = [&](const auto & kernel) {
10191013 hipGetErrorString (hipMemset (
1020- arg.p_c_workspace_grid_ ,
1014+ arg.p_workspace_ ,
10211015 0 ,
10221016 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_ .GetElementSpaceSize () *
10231017 sizeof (AccDataType)));
@@ -1030,7 +1024,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10301024 0 ,
10311025 arg.p_a_grid_ ,
10321026 arg.p_b_grid_ ,
1033- static_cast <AccDataType*>(arg.p_c_workspace_grid_ ),
1027+ static_cast <AccDataType*>(arg.p_workspace_ ),
10341028 arg.a_grid_desc_kbatch_k0_m_k1_ ,
10351029 arg.b_grid_desc_kbatch_k0_n_k1_ ,
10361030 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_ ,
@@ -1072,7 +1066,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
10721066 dim3 (type_convert_grid_size),
10731067 dim3 (256 ),
10741068 0 ,
1075- static_cast <AccDataType*>(arg.p_c_workspace_grid_ ),
1069+ static_cast <AccDataType*>(arg.p_workspace_ ),
10761070 p_c_grid_tmp_bf16_,
10771071 a_grid_desc_m0_,
10781072 b_grid_desc_m0_,
@@ -1448,11 +1442,6 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
14481442 {
14491443 return GetWorkSpaceSize<NumDimSpatial>(*dynamic_cast <const Argument*>(p_arg));
14501444 }
1451-
1452- void SetWorkSpacePointer (BaseArgument* p_arg, void * workspace_ptr) const override
1453- {
1454- dynamic_cast <Argument*>(p_arg)->p_c_workspace_grid_ = workspace_ptr;
1455- }
14561445};
14571446
14581447} // namespace device
0 commit comments