diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index 06123c46..5c83efc6 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit 06123c464559f063ec90f575405d850fc43eb41f +Subproject commit 5c83efc62b6b0013d1291f3037749b8c925d1aab diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala index c3d71ef7..1a4b302f 100644 --- a/src/main/scala/gemmini/GemminiISA.scala +++ b/src/main/scala/gemmini/GemminiISA.scala @@ -27,7 +27,7 @@ object GemminiISA { val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad - val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols + val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols, out_channels_stride val LOOP_CONV_WS_CONFIG_5 = 20.U // *weights | *output val LOOP_CONV_WS_CONFIG_6 = 21.U // *bias, *input diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala index 4efd815c..f04cd739 100644 --- a/src/main/scala/gemmini/LoopConv.scala +++ b/src/main/scala/gemmini/LoopConv.scala @@ -13,6 +13,7 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b val in_dim = UInt(small_iterator_bitwidth.W) val in_channels = UInt(large_iterator_bitwidth.W) val out_channels = UInt(large_iterator_bitwidth.W) + val out_channels_stride = UInt(large_iterator_bitwidth.W) val out_dim = UInt(small_iterator_bitwidth.W) val pool_out_dim = UInt(small_iterator_bitwidth.W) val stride = UInt(tiny_iterator_bitwidth.W) @@ -693,10 +694,10 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: val och = Reg(UInt(large_iterator_bitwidth.W)) // Addresses - val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels + och) * (input_w/8).U + val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels_stride + och) * (input_w/8).U val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol - val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels + och) * (input_w/8).U + val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels_stride + och) * (input_w/8).U val pool_spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols // Sizes @@ -730,13 +731,13 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: pre_pool_config_cmd.rs1 := (ocols << 56) | (orows << 48) | (pocols << 40) | (porows << 32) | (pool_out_dim << 24) | (plpad << 10) | (pupad << 8) | (pool_size << 6) | (pool_stride << 4) | // TODO magic numbers CONFIG_STORE - pre_pool_config_cmd.rs2 := out_channels * (input_w / 8).U + pre_pool_config_cmd.rs2 := out_channels_stride * (input_w / 8).U val post_pool_config_cmd = Wire(new RoCCCommand) post_pool_config_cmd := DontCare post_pool_config_cmd.inst.funct := CONFIG_CMD post_pool_config_cmd.rs1 := CONFIG_STORE - post_pool_config_cmd.rs2 := out_channels * (input_w / 8).U + post_pool_config_cmd.rs2 := out_channels_stride * (input_w / 8).U val pool_cmd = Wire(new RoCCCommand) pool_cmd := DontCare @@ -1035,6 +1036,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I loop_being_configured.inner_bounds.pupad := cmd.bits.rs1(31, 16) loop_being_configured.inner_bounds.pdpad := cmd.bits.rs1(15, 0) + loop_being_configured.outer_bounds.out_channels_stride := cmd.bits.rs2(31, 16) loop_being_configured.inner_bounds.ocols := cmd.bits.rs2(15, 0) }