Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion software/libgemmini
Submodule libgemmini updated 2 files
+34 −21 gemmini.cc
+1 −0 gemmini.h
53 changes: 34 additions & 19 deletions src/main/scala/gemmini/LoopConv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b
val in_channels = UInt(large_iterator_bitwidth.W)
val out_channels = UInt(large_iterator_bitwidth.W)
val out_dim = UInt(large_iterator_bitwidth.W)
val out_stride = UInt(large_iterator_bitwidth.W) //stride for output activation
val in_stride = UInt(large_iterator_bitwidth.W) //stride for input activation
val weight_stride = UInt(large_iterator_bitwidth.W) //stride for weight
val pool_out_dim = UInt(small_iterator_bitwidth.W)
val stride = UInt(tiny_iterator_bitwidth.W)
val padding = UInt(tiny_iterator_bitwidth.W)
Expand Down Expand Up @@ -272,11 +275,11 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
val icol_padded = icol +& undilated(lpad).zext
val is_zeros = irow < 0.S || irow >= irows_unpadded.zext || icol < 0.S || icol >= icols_unpadded.zext

val dram_stride = Mux(req.trans_input_3120, batch_size * (input_w/8).U, in_channels * (input_w/8).U)
val dram_stride = Mux(req.trans_input_3120, batch_size * (input_w/8).U, in_stride * (input_w/8).U)

// Addresses
val dram_offset = Mux(req.trans_input_3120, (((ich * in_dim * in_dim +& irow*in_dim +& icol) * batches +& b) * (input_w/8).U).asUInt,
(((b * in_dim * in_dim +& irow*in_dim +& icol) * in_channels +& ich) * (input_w/8).U).asUInt)
(((b * in_dim * in_dim +& irow*in_dim +& icol) * in_stride +& ich) * (input_w/8).U).asUInt)
val dram_addr = Mux(is_zeros, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset))
val spad_addr = Mux(req.trans_input_3120,
// To prevent Verilator errors, we replace some "/ block_size.U" calls here with ">> log2Up(block_size)"
Expand Down Expand Up @@ -333,7 +336,7 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
io.idle := state === idle && !command_p.io.busy
io.loop_id := req.loop_id

command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop
command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop && (req.dram_addr =/= 0.U)
command_p.io.in.bits.cmd := Mux(state === config, config_cmd, mvin_cmd)
command_p.io.in.bits.dram_addr := dram_addr
command_p.io.in.bits.spad_addr := spad_addr
Expand All @@ -355,7 +358,9 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
}

// Sending outputs
when(command_p.io.in.fire) {
when(req.dram_addr === 0.U){
state := idle
}.elsewhen(command_p.io.in.fire) {
when (state === config) {
state := ld
}.otherwise {
Expand Down Expand Up @@ -442,7 +447,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
out_channels_per_bank * kcols * krows * kchs)
val addr_start = req.addr_end - B_rows

val dram_stride = MuxCase(out_channels, Seq(
val dram_stride = MuxCase(weight_stride, Seq(
req.dw -> 1.U,
req.trans_weight_1203 -> (kernel_dim * kernel_dim * out_channels),
req.trans_weight_0132 -> in_channels
Expand All @@ -455,7 +460,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
val kch = Reg(UInt(large_iterator_bitwidth.W))

// Addresses
val dram_offset = MuxCase(((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * out_channels +& och) * (input_w/8).U, Seq(
val dram_offset = MuxCase(((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * weight_stride +& och) * (input_w/8).U, Seq(
req.dw -> (krow * kernel_dim +& kcol) * (input_w/8).U,
req.trans_weight_1203 -> (((kch*kernel_dim*kernel_dim +& krow*kernel_dim +& kcol) * out_channels +& och) * (input_w/8).U),
req.trans_weight_0132 -> (((krow*kernel_dim*out_channels +& kcol*out_channels +& och) * in_channels +& kch) * (input_w/8).U)
Expand Down Expand Up @@ -512,7 +517,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
io.idle := state === idle && !command_p.io.busy
io.loop_id := req.loop_id

command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop
command_p.io.in.valid := state =/= idle && !io.wait_for_prev_loop && (req.dram_addr =/= 0.U)
command_p.io.in.bits.cmd := Mux(state === config, config_cmd, mvin_cmd)
command_p.io.in.bits.dram_addr := dram_addr
command_p.io.in.bits.spad_addr := spad_addr
Expand All @@ -534,7 +539,9 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
}

// Sending outputs
when(command_p.io.in.fire) {
when(req.dram_addr === 0.U){
state := idle
}.elsewhen(command_p.io.in.fire) {
when (state === config) {
state := ld
}.otherwise {
Expand Down Expand Up @@ -880,11 +887,11 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth:
// Addresses
val dram_offset = Mux(req.trans_output_1203,
((orow*out_dim*batch_size +& ocol*batch_size +& b) * out_channels +& och) * (input_w/8).U,
((b*out_dim*out_dim +& orow*out_dim +& ocol) * out_channels +& och) * (input_w/8).U)
((b*out_dim*out_dim +& orow*out_dim +& ocol) * out_stride +& och) * (input_w/8).U)
val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset)
val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol

val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_channels + och) * (input_w/8).U
val pool_dram_addr = req.dram_addr + ((b * pool_out_dim * pool_out_dim) * out_stride + och) * (input_w/8).U
val pool_spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols

// Sizes
Expand Down Expand Up @@ -933,7 +940,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth:
val pre_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType)
pre_pool_config_cmd_rs2 := DontCare
pre_pool_config_cmd_rs2.acc_scale := ACC_SCALE_NO_CHANGE
pre_pool_config_cmd_rs2.stride := out_channels * (input_w / 8).U
pre_pool_config_cmd_rs2.stride := out_stride * (input_w / 8).U
pre_pool_config_cmd.rs2 := pre_pool_config_cmd_rs2.asUInt

val post_pool_config_cmd = Wire(new RoCCCommand)
Expand All @@ -949,7 +956,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth:
val post_pool_config_cmd_rs2 = Wire(config_mvout_rs2_t.cloneType)
post_pool_config_cmd_rs2 := DontCare
post_pool_config_cmd_rs2.acc_scale := ACC_SCALE_NO_CHANGE
post_pool_config_cmd_rs2.stride := out_channels * (input_w / 8).U
post_pool_config_cmd_rs2.stride := out_stride * (input_w / 8).U
post_pool_config_cmd.rs2 := post_pool_config_cmd_rs2.asUInt

val pool_cmd = Wire(new RoCCCommand)
Expand Down Expand Up @@ -1070,6 +1077,8 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s
val dw = Bool()

val max_pixels_per_row = UInt(small_iterator_bitwidth.W)
val a_ex_spad_id = UInt(2.W)
val b_ex_spad_id = UInt(2.W)

val configured = Bool()

Expand Down Expand Up @@ -1306,11 +1315,14 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
is (LOOP_CONV_WS_CONFIG_4) {
loop_being_configured.inner_bounds.orows := cmd.bits.cmd.rs1(63, 48)
loop_being_configured.inner_bounds.prad := cmd.bits.cmd.rs1(47, 32)
loop_being_configured.inner_bounds.pupad := cmd.bits.cmd.rs1(31, 16)
loop_being_configured.inner_bounds.pdpad := cmd.bits.cmd.rs1(15, 0)
loop_being_configured.inner_bounds.pupad := cmd.bits.cmd.rs1(31, 21)
loop_being_configured.inner_bounds.pdpad := cmd.bits.cmd.rs1(20, 10)
loop_being_configured.outer_bounds.kernel_dilation := cmd.bits.cmd.rs1(9, 0)

loop_being_configured.inner_bounds.ocols := cmd.bits.cmd.rs2(15, 0)
loop_being_configured.outer_bounds.kernel_dilation := cmd.bits.cmd.rs2(31, 16)
loop_being_configured.outer_bounds.in_stride := cmd.bits.cmd.rs2(63, 48)
loop_being_configured.outer_bounds.weight_stride := cmd.bits.cmd.rs2(47, 32)
loop_being_configured.outer_bounds.out_stride := cmd.bits.cmd.rs2(31, 16)
}

is (LOOP_CONV_WS_CONFIG_5) {
Expand All @@ -1334,6 +1346,9 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
!has_first_layer_optimizations.B || config_max_pixels_per_row === 0.U,
1.U, config_max_pixels_per_row)

loop_being_configured.a_ex_spad_id := cmd.bits.cmd.rs1(19, 18)
loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16)

loop_being_configured.wrot180 := has_training_convs.B && cmd.bits.cmd.rs1(1)
loop_being_configured.input_dilated := has_training_convs.B && cmd.bits.cmd.rs2(2)
loop_being_configured.trans_output_1203 := has_training_convs.B && cmd.bits.cmd.rs1(2)
Expand Down Expand Up @@ -1387,7 +1402,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ld_input.io.req.bits.outer_bounds := loop_requesting_ld_input.outer_bounds
ld_input.io.req.bits.inner_bounds := loop_requesting_ld_input.inner_bounds
ld_input.io.req.bits.derived_params := loop_requesting_ld_input.derived_params()
ld_input.io.req.bits.addr_start := loop_requesting_ld_input.a_addr_start
ld_input.io.req.bits.addr_start := Mux(loop_requesting_ld_input.a_ex_spad_id === 0.U, loop_requesting_ld_input.a_addr_start, (loop_requesting_ld_input.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U)
ld_input.io.req.bits.dram_addr := loop_requesting_ld_input.input_dram_addr
ld_input.io.req.bits.downsample := loop_requesting_ld_input.downsample
ld_input.io.req.bits.max_pixels_per_row := loop_requesting_ld_input.max_pixels_per_row
Expand All @@ -1407,7 +1422,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ld_weights.io.req.bits.outer_bounds := loop_requesting_ld_weights.outer_bounds
ld_weights.io.req.bits.inner_bounds := loop_requesting_ld_weights.inner_bounds
ld_weights.io.req.bits.derived_params := loop_requesting_ld_weights.derived_params()
ld_weights.io.req.bits.addr_end := loop_requesting_ld_weights.b_addr_end
ld_weights.io.req.bits.addr_end := Mux(loop_requesting_ld_weights.b_ex_spad_id === 0.U, loop_requesting_ld_weights.b_addr_end, (loop_requesting_ld_weights.b_ex_spad_id) * (max_addr / concurrent_loops).U)
ld_weights.io.req.bits.dram_addr := loop_requesting_ld_weights.weights_dram_addr
ld_weights.io.req.bits.trans_weight_1203 := loop_requesting_ld_weights.trans_weight_1203
ld_weights.io.req.bits.trans_weight_0132 := loop_requesting_ld_weights.trans_weight_0132
Expand All @@ -1426,8 +1441,8 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ex.io.req.bits.outer_bounds := loop_requesting_ex.outer_bounds
ex.io.req.bits.inner_bounds := loop_requesting_ex.inner_bounds
ex.io.req.bits.derived_params := loop_requesting_ex.derived_params()
ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start
ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end
ex.io.req.bits.a_addr_start := Mux(loop_requesting_ex.a_ex_spad_id === 0.U, loop_requesting_ex.a_addr_start, (loop_requesting_ex.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U)
ex.io.req.bits.b_addr_end := Mux(loop_requesting_ex.b_ex_spad_id === 0.U, loop_requesting_ex.b_addr_end, (loop_requesting_ex.b_ex_spad_id) * (max_addr / concurrent_loops).U)
ex.io.req.bits.c_addr_start := ex_c_addr_start
ex.io.req.bits.wrot180 := loop_requesting_ex.wrot180
ex.io.req.bits.downsample := loop_requesting_ex.downsample
Expand Down
24 changes: 16 additions & 8 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
io.k := k
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U
io.cmd.bits := mvin_cmd

io.loop_id := req.loop_id

when (io.cmd.fire) {
when(req.dram_addr === 0.U){
state := idle
}.elsewhen(io.cmd.fire) {
// The order here is k, j, i
val i_blocks = Mux(req.transpose, max_blocks, 1.U)
val k_blocks = Mux(req.transpose, 1.U, max_blocks)
Expand Down Expand Up @@ -194,12 +196,14 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
io.j := j
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U
io.cmd.bits := mvin_cmd

io.loop_id := req.loop_id

when (io.cmd.fire) {
when(req.dram_addr === 0.U){
state := idle
}.elsewhen(io.cmd.fire) {
// The order here is k, j, i
val j_blocks = Mux(req.transpose, 1.U, max_blocks)
val k_blocks = Mux(req.transpose, max_blocks, 1.U)
Expand Down Expand Up @@ -698,6 +702,8 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val
val full_c = Bool()
val ex_accumulate = Bool()

val a_ex_spad_id = UInt(2.W)
val b_ex_spad_id = UInt(2.W)
val configured = Bool()

val running = Bool()
Expand Down Expand Up @@ -896,6 +902,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
loop_being_configured.low_d := cmd.bits.cmd.rs1(2)
loop_being_configured.act := cmd.bits.cmd.rs1(8+Activation.bitwidth-1, 8) // TODO magic numbers

loop_being_configured.a_ex_spad_id := cmd.bits.cmd.rs1(19, 18)
loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16)
loop_being_configured.a_transpose := cmd.bits.cmd.rs2(0)
loop_being_configured.b_transpose := cmd.bits.cmd.rs2(1)

Expand All @@ -920,7 +928,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
ldA.io.req.bits.dram_addr := loop_requesting_ldA.a_dram_addr
ldA.io.req.bits.dram_stride := loop_requesting_ldA.a_dram_stride
ldA.io.req.bits.transpose := loop_requesting_ldA.a_transpose
ldA.io.req.bits.addr_start := loop_requesting_ldA.a_addr_start
ldA.io.req.bits.addr_start := Mux(loop_requesting_ldA.a_ex_spad_id === 0.U, loop_requesting_ldA.a_addr_start, (loop_requesting_ldA.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U)
ldA.io.req.bits.loop_id := loop_requesting_ldA_id

ldA.io.req.valid := !loop_requesting_ldA.lda_started && loop_requesting_ldA.configured
Expand All @@ -939,7 +947,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
ldB.io.req.bits.dram_addr := loop_requesting_ldB.b_dram_addr
ldB.io.req.bits.dram_stride := loop_requesting_ldB.b_dram_stride
ldB.io.req.bits.transpose := loop_requesting_ldB.b_transpose
ldB.io.req.bits.addr_end := loop_requesting_ldB.b_addr_end
ldB.io.req.bits.addr_end := Mux(loop_requesting_ldB.b_ex_spad_id === 0.U, loop_requesting_ldB.b_addr_end, (loop_requesting_ldB.b_ex_spad_id) * (max_addr / concurrent_loops).U)
ldB.io.req.bits.loop_id := loop_requesting_ldB_id

ldB.io.req.valid := !loop_requesting_ldB.ldb_started && loop_requesting_ldB.configured
Expand All @@ -958,8 +966,8 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
ex.io.req.bits.pad_k := loop_requesting_ex.pad_k
ex.io.req.bits.pad_i := loop_requesting_ex.pad_i
ex.io.req.bits.accumulate := loop_requesting_ex.ex_accumulate
ex.io.req.bits.a_addr_start := loop_requesting_ex.a_addr_start
ex.io.req.bits.b_addr_end := loop_requesting_ex.b_addr_end
ex.io.req.bits.a_addr_start := Mux(loop_requesting_ex.a_ex_spad_id === 0.U, loop_requesting_ex.a_addr_start, (loop_requesting_ex.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U)
ex.io.req.bits.b_addr_end := Mux(loop_requesting_ex.b_ex_spad_id === 0.U, loop_requesting_ex.b_addr_end, (loop_requesting_ex.b_ex_spad_id) * (max_addr / concurrent_loops).U)
ex.io.req.bits.a_tranpose := loop_requesting_ex.a_transpose
ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose
ex.io.req.bits.c_addr_start := ex_c_addr_start
Expand Down