/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package vta.core

import scala.math.pow
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._

/** Compute.
 *
 * The compute unit is in charge of the following:
 * - Loading micro-ops from memory (loadUop module)
 * - Loading biases (acc) from memory (tensorAcc module)
 * - Executing ALU instructions (tensorAlu module)
 * - Executing GEMM instructions (tensorGemm module)
 */
class Compute(debug: Boolean = false)(implicit val p: Parameters) extends Module {
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val i_post = Vec(2, Input(Bool()))
    val o_post = Vec(2, Output(Bool()))
    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
    val uop_baddr = Input(UInt(mp.addrBits.W))
    val acc_baddr = Input(UInt(mp.addrBits.W))
    val vme_rd = Vec(2, new VMEReadMaster)
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val out = new TensorMaster(tensorType = "out")
    val finish = Output(Bool())
    val acc_wr_event = Output(Bool())
  })
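  // i_post/o_post exchange dependency tokens with the neighboring Load and Store
  // modules: an incoming post increments the matching semaphore below, and an
  // instruction that pops a dependency waits until that semaphore is nonzero.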
  val sIdle :: sSync :: sExe :: Nil = Enum(3)
  val state = RegInit(sIdle)
  val s = Seq.tabulate(2)(_ =>
    Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
  val loadUop = Module(new LoadUop)
  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
  val tensorGemm = Module(new TensorGemm)
  val tensorAlu = Module(new TensorAlu)
  // pick the acc group closest to the top-level IO (used for the write event below)
  val topAccGrpIdx = tensorGemm.io.acc.closestIOGrpIdx
  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
  // decode
  val dec = Module(new ComputeDecode)
  dec.io.inst := inst_q.io.deq.bits
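  // Classify the decoded instruction into a one-hot type vector
  // {finish, alu, gemm, loadAcc, loadUop}; it selects the matching done signal below.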
  val inst_type =
    Cat(dec.io.isFinish,
      dec.io.isAlu,
      dec.io.isGemm,
      dec.io.isLoadAcc,
      dec.io.isLoadUop).asUInt
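  // An instruction may issue only when it is valid at the queue head and every
  // dependency it pops has a token available (semaphore ready).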
  val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
  val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
  val start = snext & sprev
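  // Completion comes from whichever unit the one-hot inst_type selects;
  // a Finish instruction completes immediately.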
  val done =
    MuxLookup(
      inst_type,
      false.B, // default
      Array(
        "h_01".U -> loadUop.io.done,
        "h_02".U -> tensorAcc.io.done,
        "h_04".U -> tensorGemm.io.done,
        "h_08".U -> tensorAlu.io.done,
        "h_10".U -> true.B // Finish
      )
    )
  // control
  switch(state) {
    is(sIdle) {
      when(start) {
        when(dec.io.isSync) {
          state := sSync
        }.elsewhen(inst_type.orR) {
          state := sExe
        }
      }
    }
    is(sSync) {
      state := sIdle
    }
    is(sExe) {
      when(done) {
        state := sIdle
      }
    }
  }
  // instructions
  inst_q.io.enq <> io.inst
  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
  // uop
  loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
  loadUop.io.inst := inst_q.io.deq.bits
  loadUop.io.baddr := io.uop_baddr
  io.vme_rd(0) <> loadUop.io.vme_rd
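  // The uop index port is shared: GEMM drives it for GEMM instructions, ALU otherwise.
  // The assert checks that both units never request a micro-op at the same time.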
  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
  assert(!tensorGemm.io.uop.idx.valid || !tensorAlu.io.uop.idx.valid)
  // acc
  tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
  tensorAcc.io.inst := inst_q.io.deq.bits
  tensorAcc.io.baddr := io.acc_baddr
  require(tensorAcc.io.tensor.lenSplit ==
    tensorAcc.io.tensor.tensorLength, "-F- Expecting a whole batch in acc group")
  // split factor of isGemm for many groups
  val splitFactorL0 = pow(2, log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
  val splitFactorL1 = pow(2, log2Ceil(tensorAcc.io.tensor.splitWidth)
    - log2Ceil(tensorAcc.io.tensor.splitWidth) / 2).toInt
  require(splitFactorL0 * splitFactorL1 == tensorAcc.io.tensor.splitWidth)
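  // The isGemm select for each acc group is registered in two levels (L0 x L1 groups),
  // so it fans out through a small pipelined tree instead of one wide net.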
  val accRdSelectL0 = for (idx <- 0 until splitFactorL1) yield {
    // can save 1 stage on small design
    if (splitFactorL1 > 1) RegNext(dec.io.isGemm, init = false.B) else dec.io.isGemm
  }
  for (idx <- 0 until tensorAcc.io.tensor.splitWidth) {
    tensorAcc.io.tensor.rd(idx).idx <> Mux(
      RegNext(accRdSelectL0(idx / splitFactorL0), init = false.B),
      tensorGemm.io.acc.rd(idx).idx,
      tensorAlu.io.acc.rd(idx).idx)
    tensorAcc.io.tensor.wr(idx) <> Mux(
      RegNext(accRdSelectL0(idx / splitFactorL0), init = false.B),
      tensorGemm.io.acc.wr(idx),
      tensorAlu.io.acc.wr(idx))
  }
  io.vme_rd(1) <> tensorAcc.io.vme_rd
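  // Pulse an accumulator-write event off the acc group closest to the top-level IO.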
  io.acc_wr_event := tensorAcc.io.tensor.wr(topAccGrpIdx).valid
  // gemm
  tensorGemm.io.start := RegNext(state === sIdle & start & dec.io.isGemm, init = false.B)
  tensorGemm.io.dec := inst_q.io.deq.bits.asTypeOf(new GemmDecode)
  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorGemm.io.inp <> io.inp
  tensorGemm.io.wgt <> io.wgt
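  // Forward acc and out read data to GEMM only while a GEMM instruction is active
  // (the gating flag is registered one cycle to line up with returning read data).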
  for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
    tensorGemm.io.acc.rd(idx).data.valid :=
      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
    tensorGemm.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
  }
  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
    tensorGemm.io.out.rd(idx).data.valid :=
      io.out.rd(idx).data.valid & RegNext(dec.io.isGemm, init = false.B)
    tensorGemm.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
  }
  // alu
  tensorAlu.io.start := RegNext(state === sIdle & start & dec.io.isAlu, init = false.B)
  tensorAlu.io.dec := inst_q.io.deq.bits.asTypeOf(new AluDecode)
  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
  for (idx <- 0 until tensorAlu.io.acc.splitWidth) {
    tensorAlu.io.acc.rd(idx).data.valid :=
      tensorAcc.io.tensor.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
    tensorAlu.io.acc.rd(idx).data.bits <> tensorAcc.io.tensor.rd(idx).data.bits
  }
  for (idx <- 0 until tensorAlu.io.out.splitWidth) {
    tensorAlu.io.out.rd(idx).data.valid :=
      io.out.rd(idx).data.valid & RegNext(dec.io.isAlu, init = false.B)
    tensorAlu.io.out.rd(idx).data.bits <> io.out.rd(idx).data.bits
  }
  // out
  for (idx <- 0 until tensorGemm.io.out.splitWidth) {
    io.out.rd(idx).idx <> Mux(dec.io.isGemm,
      tensorGemm.io.out.rd(idx).idx,
      tensorAlu.io.out.rd(idx).idx)
    assert(!tensorGemm.io.out.rd(idx).idx.valid || !tensorAlu.io.out.rd(idx).idx.valid)
    assert(!tensorGemm.io.out.rd(idx).data.valid || !tensorAlu.io.out.rd(idx).data.valid)
    assert(!tensorGemm.io.out.wr(idx).valid || !tensorAlu.io.out.wr(idx).valid)
  }
  require(tensorGemm.io.out.splitWidth == 1)
  require(tensorAlu.io.out.splitWidth == 1)
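  // Output write valid/index follow whichever unit the registered isGemm flag selects.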
  io.out.wr(0).valid := Mux(
    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).valid, tensorAlu.io.out.wr(0).valid)
  io.out.wr(0).bits.idx := Mux(
    RegNext(dec.io.isGemm, init = false.B), tensorGemm.io.out.wr(0).bits.idx, tensorAlu.io.out.wr(0).bits.idx)
  // put a mux with a registered select into every gemm group, building a pipelined
  // select tree over the routing distance
  val chunkWidth = io.out.wr(0).bits.data.getWidth / tensorGemm.io.acc.splitWidth
  val outDataBits = Wire(Vec(tensorGemm.io.acc.splitWidth, UInt(chunkWidth.W)))
  io.out.wr(0).bits.data := outDataBits.asTypeOf(io.out.wr(0).bits.data)
  for (idx <- 0 until tensorGemm.io.acc.splitWidth) {
    val lowBitIdx = idx * chunkWidth
    val highBitIdx = lowBitIdx + chunkWidth - 1
    val srcAluFlat = tensorAlu.io.out.wr(0).bits.data.asUInt
    val srcGemFlat = tensorGemm.io.out.wr(0).bits.data.asUInt
    outDataBits(idx) := Mux(
      RegNext(dec.io.isGemm, init = false.B),
      srcGemFlat(highBitIdx, lowBitIdx),
      srcAluFlat(highBitIdx, lowBitIdx))
  }
  // semaphore
  s(0).io.spost := io.i_post(0)
  s(1).io.spost := io.i_post(1)
  s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
  s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
  io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
  io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))
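  // Dependency tokens are pushed back to the neighboring queues once the instruction
  // that requested them completes (or right away for a sync instruction).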
  // finish
  io.finish := state === sExe & done & dec.io.isFinish
  // debug
  if (debug) {
    // start
    when(state === sIdle && start) {
      when(dec.io.isSync) {
        printf("[Compute] start sync\n")
      }.elsewhen(dec.io.isLoadUop) {
        printf("[Compute] start load uop\n")
      }.elsewhen(dec.io.isLoadAcc) {
        printf("[Compute] start load acc\n")
      }.elsewhen(dec.io.isGemm) {
        printf("[Compute] start gemm\n")
      }.elsewhen(dec.io.isAlu) {
        printf("[Compute] start alu\n")
      }.elsewhen(dec.io.isFinish) {
        printf("[Compute] start finish\n")
      }
    }
    // done
    when(state === sSync) {
      printf("[Compute] done sync\n")
    }
    when(state === sExe) {
      when(done) {
        when(dec.io.isLoadUop) {
          printf("[Compute] done load uop\n")
        }.elsewhen(dec.io.isLoadAcc) {
          printf("[Compute] done load acc\n")
        }.elsewhen(dec.io.isGemm) {
          printf("[Compute] done gemm\n")
        }.elsewhen(dec.io.isAlu) {
          printf("[Compute] done alu\n")
        }.elsewhen(dec.io.isFinish) {
          printf("[Compute] done finish\n")
        }
      }
    }
  }
}