Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/scala/gemmini/GemminiISA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object GemminiISA {
val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding
val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs
val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad
val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols
val LOOP_CONV_WS_CONFIG_4 = 19.U // prad, pupad, pdpad, orows | ocols, out_channels_stride
val LOOP_CONV_WS_CONFIG_5 = 20.U // *weights | *output
val LOOP_CONV_WS_CONFIG_6 = 21.U // *bias, *input

Expand Down
10 changes: 6 additions & 4 deletions src/main/scala/gemmini/LoopConv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b
val in_dim = UInt(small_iterator_bitwidth.W)
val in_channels = UInt(large_iterator_bitwidth.W)
val out_channels = UInt(large_iterator_bitwidth.W)
val out_channels_stride = UInt(large_iterator_bitwidth.W)
val out_dim = UInt(small_iterator_bitwidth.W)
val pool_out_dim = UInt(small_iterator_bitwidth.W)
val stride = UInt(tiny_iterator_bitwidth.W)
Expand Down Expand Up @@ -693,10 +694,10 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth:
val och = Reg(UInt(large_iterator_bitwidth.W))

// Addresses
val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels + och) * (input_w/8).U
val dram_addr = req.dram_addr + ((b*out_dim*out_dim + orow*out_dim + ocol) * out_channels_stride + och) * (input_w/8).U
val spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol

val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels + och) * (input_w/8).U
val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels_stride + och) * (input_w/8).U
val pool_spad_addr = acc_addr_start +& (och / block_size.U) * batches * orows * ocols +& b * orows * ocols

// Sizes
Expand Down Expand Up @@ -730,13 +731,13 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth:
pre_pool_config_cmd.rs1 := (ocols << 56) | (orows << 48) | (pocols << 40) | (porows << 32) | (pool_out_dim << 24) |
(plpad << 10) | (pupad << 8) | (pool_size << 6) | (pool_stride << 4) | // TODO magic numbers
CONFIG_STORE
pre_pool_config_cmd.rs2 := out_channels * (input_w / 8).U
pre_pool_config_cmd.rs2 := out_channels_stride * (input_w / 8).U

val post_pool_config_cmd = Wire(new RoCCCommand)
post_pool_config_cmd := DontCare
post_pool_config_cmd.inst.funct := CONFIG_CMD
post_pool_config_cmd.rs1 := CONFIG_STORE
post_pool_config_cmd.rs2 := out_channels * (input_w / 8).U
post_pool_config_cmd.rs2 := out_channels_stride * (input_w / 8).U

val pool_cmd = Wire(new RoCCCommand)
pool_cmd := DontCare
Expand Down Expand Up @@ -1035,6 +1036,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I
loop_being_configured.inner_bounds.pupad := cmd.bits.rs1(31, 16)
loop_being_configured.inner_bounds.pdpad := cmd.bits.rs1(15, 0)

loop_being_configured.outer_bounds.out_channels_stride := cmd.bits.rs2(31, 16)
loop_being_configured.inner_bounds.ocols := cmd.bits.rs2(15, 0)
}

Expand Down