blob: 1f55b9148d1d3c101f72bec5ad2d9e022dcf10e8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Fetch.
*
* The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
* VTA Memory Engine (VME), and pushes them into an instruction queue called
* inst_q. Once the instruction queue is full, instructions are dispatched to
* the Load, Compute and Store module queues based on the instruction opcode.
* After draining the queue, the fetch unit checks if there are more instructions
* via the ins_count register, which is written by the host.
*
* Additionally, instructions are read in two chunks (see sReadLSB and sReadMSB)
* because we are using a DRAM payload of 8 bytes, or half of a VTA instruction.
* This should be configurable for larger payloads, i.e. 64 bytes, which can load
* more than one instruction at a time. Finally, the instruction queue is
* sized (entries_q) depending on the maximum burst allowed in the memory.
*/
/** Fetch unit, wide-VME variant.
*
* Streams VTA instructions from DRAM over the VME read channel into a banked
* instruction queue (inst_q) — one SyncReadMem bank per tensor slot of a
* cacheline — then drains the queue: each instruction is decoded and routed
* to the Load, Compute or Store output queue. Transfers larger than the
* on-chip queue are split into chunks of at most xmax elements, alternating
* between the sRead and sDrain states until xrem reaches zero.
*/
class FetchWideVME(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val launch = Input(Bool())
val ins_baddr = Input(UInt(mp.addrBits.W))
val ins_count = Input(UInt(vp.regBits.W))
val vme_rd = new VMEReadMaster
val inst = new Bundle {
val ld = Decoupled(UInt(INST_BITS.W))
val co = Decoupled(UInt(INST_BITS.W))
val st = Decoupled(UInt(INST_BITS.W))
}
})
val tp = new TensorParams("fetch")
// Number of instruction-sized tensors per cacheline; the queue is banked by it.
val tensorsInClNb = tp.clSizeRatio
val tensorsInClNbWidth = log2Ceil(tensorsInClNb)
// Banked instruction queue: one SyncReadMem per cacheline slot so a whole
// cacheline of VME data can be written in a single cycle.
val inst_q = Seq.fill(tensorsInClNb) {
require((tp.memDepth/tensorsInClNb) * tensorsInClNb == tp.memDepth,
"-F- Unexpected queue depth to instructions in cacheline ratio")
SyncReadMem(tp.memDepth/tensorsInClNb, UInt(tp.tensorSizeBits.W))
}
// sample start
// Rising edge of launch starts a new fetch session.
val s1_launch = RegNext(io.launch, init = false.B)
val start = io.launch & ~s1_launch
// Remaining 64-bit elements after the current chunk.
val xrem = Reg(chiselTypeOf(io.ins_count))
// fit instruction into 64bit chunks
val elemsInInstr = INST_BITS/64
// Total transfer size in 64-bit elements (ins_count instructions).
val xsize = io.ins_count << log2Ceil(elemsInInstr)
// max size of transfer is limited by a buffer size
val xmax = (((1 << mp.lenBits) << log2Ceil(tp.clSizeRatio)).min(tp.memDepth)).U
// Number of elements requested in the current chunk.
val elemNb = Reg(xsize.cloneType)
// sIdle: wait for launch; sRead: issue VME reads and fill inst_q;
// sDrain: decode and dispatch queued instructions until the queue is empty.
val sIdle :: sRead :: sDrain :: Nil = Enum(3)
val state = RegInit(sIdle)
val isBusy = state === sRead
// Pulses on launch or when re-entering sRead from sDrain (next chunk).
val vmeStart = start || (state === sRead && RegNext(state, init = sIdle) === sDrain)
// DRAM offset (in tensors) of the next chunk relative to ins_baddr.
val dramOffset = RegInit(UInt(mp.addrBits.W), init = 0.U)
val vmeCmd = Module (new GenVMECmdWideFetch(debug))
vmeCmd.io.start := vmeStart
vmeCmd.io.isBusy := isBusy & ~vmeStart
vmeCmd.io.ins_baddr := Mux(start, io.ins_baddr, io.ins_baddr + (dramOffset << log2Ceil(tp.tensorSizeBits / 8)))
vmeCmd.io.vmeCmd <> io.vme_rd.cmd
val readLen = vmeCmd.io.readLen
val vmeCmdDone = vmeCmd.io.done & ~vmeStart
vmeCmd.io.xsize := elemNb
vmeCmd.io.sram_offset := 0.U // this is a queue we reload
// Always accept VME data; one register stage decouples it from the write path.
io.vme_rd.data.ready := true.B
val pipeDelayQueueDeqV = RegNext(io.vme_rd.data.valid, init = false.B)
val pipeDelayQueueDeqF = pipeDelayQueueDeqV // fire()
val pipeDelayQueueDeqB = RegNext(io.vme_rd.data.bits)
// Nb of CLs requestd, not received.
val clCntIdxWdth = log2Ceil(tp.memDepth/tensorsInClNb) + 1
val clInFlight = Reg(UInt(clCntIdxWdth.W))
// Outstanding-cacheline counter: +readLen per accepted VME command,
// -1 per received data beat; both may occur in the same cycle.
when(start) {
clInFlight := 0.U
}.elsewhen(isBusy && io.vme_rd.cmd.fire() && !pipeDelayQueueDeqF) {
clInFlight := clInFlight + readLen
}.elsewhen(isBusy && io.vme_rd.cmd.fire() && pipeDelayQueueDeqF) {
clInFlight := clInFlight + readLen - 1.U
}.elsewhen(isBusy && !io.vme_rd.cmd.fire() && pipeDelayQueueDeqF) {
assert(clInFlight > 0.U)
clInFlight := clInFlight - 1.U
}.otherwise {
clInFlight := clInFlight
}
// number of entries in a queue
val queueCount = Reg(UInt((tp.memAddrBits + 1).W))
// Head pointer (in elements) into the banked queue; queueHeadNext is the
// registered copy, queueHead is the combinational (read-ahead) view.
val queueHead = Wire(UInt(tp.memAddrBits.W))
val queueHeadNext = Reg(UInt(tp.memAddrBits.W))
// Forces one SRAM read on the sRead -> sDrain transition to prime the pipe.
val forceRead = Wire(Bool())
forceRead := false.B
// control
switch(state) {
is(sIdle) {
when(start) {
state := sRead
dramOffset := 0.U
// Split the total transfer into chunks of at most xmax elements.
when(xsize < xmax) {
elemNb := xsize
xrem := 0.U
}.otherwise {
elemNb := xmax
xrem := xsize - xmax
}
}
}
is(sRead) {
// All commands issued and all data beats received: start draining.
when(vmeCmdDone && clInFlight === 0.U) {
forceRead := true.B
state := sDrain
}
}
is(sDrain) {
// Queue drained: either finish or fetch the next chunk.
when(queueCount === 0.U) {
dramOffset := dramOffset + elemNb
when(xrem === 0.U) {
state := sIdle
}.elsewhen(xrem < xmax) {
state := sRead
elemNb := xrem
xrem := 0.U
}.otherwise {
state := sRead
elemNb := xmax
xrem := xrem - xmax
}
}
}
}
//---------------------
//--- Read VME data ---
//---------------------
// Splits each (delayed) VME data beat into per-bank write index/data/mask.
val readData = Module(new ReadVMEDataWide("fetch", debug))
readData.io.start := vmeStart
readData.io.vmeData.valid := pipeDelayQueueDeqV
readData.io.vmeData.bits := pipeDelayQueueDeqB
assert(readData.io.vmeData.ready === true.B)
//--------------------
//--- Write memory ---
//--------------------
val wmask = readData.io.destMask
val wdata = readData.io.destData
val widx = readData.io.destIdx
// Write every masked bank of the cacheline in parallel.
for (i <- 0 until tensorsInClNb) {
when(wmask(i) && pipeDelayQueueDeqF) {
inst_q(i).write(widx(i), wdata(i))
}
}
if (debug) {
when (io.vme_rd.data.fire()) {
printf(s"[TensorLoad] fetch data rdDataDestIdx:%x rdDataDestMask:%b\n",
widx.asUInt,
wmask.asUInt)
}
}
// read-from-sram
// queue head points to the first elem of instruction
val rIdx = queueHead >> tensorsInClNbWidth // SP idx
// rMask selects the first elem of instruction
val rMask = if (tensorsInClNbWidth > 0) {
UIntToOH(queueHead(tensorsInClNbWidth - 1, 0))
} else {
1.U
}
val deqElem = Wire(Bool())
// Read all banks; enable only those holding elements of the current instruction.
val rdataVec = for (i <- 0 until tensorsInClNb) yield {
// expand mask to select all elems of instruction
val maskShift = i%elemsInInstr
inst_q(i).read(rIdx, VecInit((rMask<<maskShift).asTypeOf(rMask).asBools)(i) && (deqElem || forceRead))
}
// instruction is a elemsInInstr number of elements
// combine them into one instruction
val rdata = Wire(Vec(elemsInInstr, UInt((tp.tensorSizeBits).W)))
for (i <- 0 until elemsInInstr) {
// expand mask to select all elems of instruction
// RegNext delays the select to match the one-cycle SyncReadMem read latency.
rdata(i) := Mux1H(RegNext((rMask << i).asTypeOf(rMask)), rdataVec)
}
val canRead = queueCount >= elemsInInstr.U && state === sDrain
// instruction queues
// use 2-enty queue to create one pipe stage for valid-ready interface
val readInstrPipe = Module(new Queue(UInt(INST_BITS.W), 2))
// decode
val dec = Module(new FetchDecode)
dec.io.inst := readInstrPipe.io.deq.bits
readInstrPipe.io.enq.valid := canRead
readInstrPipe.io.enq.bits := rdata.asTypeOf(UInt(INST_BITS.W))
deqElem := readInstrPipe.io.enq.fire()
// Dequeue only when the destination queue selected by the opcode is ready.
readInstrPipe.io.deq.ready := (
(dec.io.isLoad & io.inst.ld.ready) ||
(dec.io.isCompute & io.inst.co.ready) ||
(dec.io.isStore & io.inst.st.ready))
io.inst.ld.valid := dec.io.isLoad & readInstrPipe.io.deq.valid
io.inst.co.valid := dec.io.isCompute & readInstrPipe.io.deq.valid
io.inst.st.valid := dec.io.isStore & readInstrPipe.io.deq.valid
io.inst.ld.bits := readInstrPipe.io.deq.bits
io.inst.co.bits := readInstrPipe.io.deq.bits
io.inst.st.bits := readInstrPipe.io.deq.bits
// queueCount: elements currently in inst_q. Incremented by the number of
// banks written per data beat (PopCount(wmask)), decremented by
// elemsInInstr per dequeued instruction; both may happen in one cycle.
when(start) {
queueCount := 0.U
}.elsewhen(deqElem && pipeDelayQueueDeqF) {
assert(queueCount > 0.U, "-F- Decrement zero counter")
val readCount = PopCount(wmask)
assert(readCount > 0.U, "-F- Must push something")
queueCount := queueCount + readCount - elemsInInstr.U
}.elsewhen(deqElem) {
assert(queueCount > 0.U, "-F- Decrement zero counter")
queueCount := queueCount - elemsInInstr.U
}.elsewhen (pipeDelayQueueDeqF) {
val numLoaded = PopCount(wmask)
assert(tp.memDepth.U - numLoaded >= queueCount, "-F- Counter overflow")
queueCount := queueCount + PopCount(wmask)
}.otherwise {
queueCount := queueCount
}
// Head pointer: advances by one instruction per dequeue; queueHead looks
// one step ahead on a dequeue so the next SRAM read starts immediately,
// and the pointer wraps to 0 when the last instruction is consumed.
when(start) {
queueHead := 0.U
queueHeadNext := 0.U
}.elsewhen(deqElem) {
queueHead := queueHeadNext + elemsInInstr.U // read ahead
when (queueCount - elemsInInstr.U === 0.U) {
queueHeadNext := 0.U
}.otherwise {
queueHeadNext := queueHeadNext + elemsInInstr.U
}
}.otherwise {
assert(reset.asBool || state === sIdle || queueCount =/= 0.U ||
(queueCount === 0.U && queueHeadNext === 0.U))
queueHead := queueHeadNext
queueHeadNext := queueHeadNext
}
// check if selected queue is ready
// NOTE(review): deq_ready is not referenced anywhere in this module as
// shown — the deq.ready expression above duplicates it. Confirm it is
// dead before removing.
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
val deq_ready =
MuxLookup(deq_sel,
false.B, // default
Array(
"h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready
))
// debug
if (debug) {
when(start) {
printf("[Fetch] Launch\n")
}
when(io.inst.ld.fire()) {
printf("[Fetch] [instruction decode] [L] %x\n", dec.io.inst)
}
when(io.inst.co.fire()) {
printf("[Fetch] [instruction decode] [C] %x\n", dec.io.inst)
}
when(io.inst.st.fire()) {
printf("[Fetch] [instruction decode] [S] %x\n", dec.io.inst)
}
}
}
/** VME command generator for the wide fetch unit.
*
* Thin wrapper around the generic GenVMECmdWide, configured for the
* instruction stream: a single row (ysize = 1) of xsize elements with the
* stride equal to the row length, no padding, and a zero DRAM offset
* relative to the instruction base address.
*/
class GenVMECmdWideFetch(debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val isBusy = Input(Bool())
val ins_baddr = Input(UInt(mp.addrBits.W))
val vmeCmd = Decoupled(new VMECmd)
val readLen = Output(UInt((mp.lenBits + 1).W))
val done = Output(Bool())
val xsize = Input(UInt(M_SIZE_BITS.W))
val sram_offset = Input(UInt(M_SRAM_OFFSET_BITS.W))
})
// Underlying generic wide command generator.
val gen = Module(new GenVMECmdWide(tensorType = "fetch", debug))
// Outputs back to the fetch unit.
io.vmeCmd <> gen.io.vmeCmd
io.readLen := gen.io.readLen
io.done := gen.io.done
// Control and addressing.
gen.io.start := io.start
gen.io.isBusy := io.isBusy
gen.io.baddr := io.ins_baddr
gen.io.updateState := io.vmeCmd.fire()
gen.io.canSendCmd := true.B
// Transfer shape: one unpadded row, stride equal to its length.
gen.io.ysize := 1.U
gen.io.xsize := io.xsize
gen.io.xstride := io.xsize
gen.io.dram_offset := 0.U
gen.io.sram_offset := io.sram_offset
gen.io.xpad_0 := 0.U
gen.io.xpad_1 := 0.U
gen.io.ypad_0 := 0.U
// The DRAM base address must be aligned to the instruction (tensor) size.
when(io.start) {
val tp = new TensorParams("fetch")
assert(io.ins_baddr%(tp.tensorSizeBits/8).U === 0.U, "-F- Fetch DRAM address expected to be tensor size aligned.")
}
}