Skip to content

Commit ae82799

Browse files
committed
amend
1 parent 0797f43 commit ae82799

File tree

4 files changed

+30
-17
lines changed

4 files changed

+30
-17
lines changed

include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,14 +486,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
486486
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
487487
c_grid_desc_m_n_container_.push_back(descs[I2]);
488488

489-
block_2_ctile_map_container_.push_back(
490-
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01));
489+
auto block_2_ctile_map =
490+
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01);
491491

492492
if(GridwiseGemm::CheckValidity(
493-
descs[I0], descs[I1], descs[I2], block_2_ctile_map_container_.back()))
493+
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
494494
{
495495
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
496496
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
497+
498+
block_2_ctile_map_container_.push_back(block_2_ctile_map);
497499
}
498500
}
499501
}

include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,14 +1073,15 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
10731073
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
10741074
c_grid_desc_m_n_container_.push_back(descs[I2]);
10751075

1076-
block_2_ctile_map_container_.push_back(
1077-
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
1076+
auto block_2_ctile_map =
1077+
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
10781078

1079-
if(GridwiseGemm::CheckValidity(
1080-
descs[I0], descs[I1], descs[I2], block_2_ctile_map_container_.back()))
1079+
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map))
10811080
{
10821081
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
10831082
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
1083+
1084+
block_2_ctile_map_container_.push_back(block_2_ctile_map);
10841085
}
10851086
}
10861087
}
@@ -1130,14 +1131,16 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
11301131
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
11311132
c_grid_desc_m_n_container_.push_back(descs[I2]);
11321133

1133-
block_2_ctile_map_container_.push_back(
1134-
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
1134+
auto block_2_ctile_map =
1135+
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
11351136

11361137
if(GridwiseGemm::CheckValidity(
1137-
descs[I0], descs[I1], descs[I2], block_2_ctile_map_container_.back()))
1138+
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
11381139
{
11391140
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
11401141
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
1142+
1143+
block_2_ctile_map_container_.push_back(block_2_ctile_map);
11411144
}
11421145
}
11431146
}
@@ -1196,17 +1199,17 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
11961199
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
11971200
c_grid_desc_m_n_container_.push_back(descs[I2]);
11981201

1199-
block_2_ctile_map_container_.push_back(
1200-
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
1202+
auto block_2_ctile_map =
1203+
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_);
12011204

1202-
if(GridwiseGemm::CheckValidity(descs[I0],
1203-
descs[I1],
1204-
descs[I2],
1205-
block_2_ctile_map_container_.back()))
1205+
if(GridwiseGemm::CheckValidity(
1206+
descs[I0], descs[I1], descs[I2], block_2_ctile_map))
12061207
{
12071208
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
12081209
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
12091210
descs[I2]));
1211+
1212+
block_2_ctile_map_container_.push_back(block_2_ctile_map);
12101213
}
12111214
}
12121215
}

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
300300
const auto block_work_idx =
301301
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
302302

303+
// if(!block_2_ctile_map.ValidCTileIndex(
304+
// make_tuple(block_work_idx[I1], block_work_idx[I2]),
305+
// make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
306+
// c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
307+
// {
308+
// return;
309+
// }
310+
303311
const index_t k_batch_id = block_work_idx[I0];
304312

305313
// HACK: this force m/n_block_data_idx_on_grid into SGPR

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
290290
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
291291

292292
if(!c_block_cluster_adaptor.ValidCTileIndex(
293-
block_work_idx,
293+
make_tuple(block_work_idx[I1], block_work_idx[I2]),
294294
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
295295
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
296296
{

0 commit comments

Comments
 (0)