/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import scala.math.pow
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** TensorLoad.
*
* Load cachelines from main memory (DRAM) into SRAM, and
* mux the cachelines into tensor-sized memory blocks in the
* scratchpads (SRAM). Zero padding is also supported during
* the load: it works on the y and x axes and is managed by
* ZeroPadding.
* Read tensors from SRAM.
* number of banks (BN) = CacheLineSize (CS) / Tensor bit size (TS)
* the number of banks is a power of 2
* Scratchpad: Seq(BN) {Mem(TensorsNb/BN, TS)}
* Cacheline: Vec(BN, CS/BN)
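* Example (illustrative sizes, not fixed by this file): CS = 512 bits and
* TS = 128 bits give BN = 4, so a cacheline arrives as Vec(4, 128.W) and
* bank b holds every tensor whose index is congruent to b mod 4.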
* Load:
* Scratchpad
* bank1 bank2
* | |
* --- ---
* wmask-/ \ -/ \
* ----- -----
* | | | |
* c | | | |
* a -----|-------- |
* c | |
* h | |
* e | |
* l | |
* i ------------------
* n
* e
*/
class TensorLoadWideVME(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val vme_rd = new VMEReadMaster
val tensor = new TensorClient(tensorType)
})
// number of delay cycles in the write pipe, needed to deliver the signal over a physical distance
val writePipeLatency = tp.writePipeLatency
val sIdle :: sBusy :: Nil = Enum(2)
val state = RegInit(sIdle)
val isBusy = state === sBusy
val localDone = Wire(Bool())
when(io.start) {
state := sBusy
}.elsewhen(localDone) {
state := sIdle
}
val dec = io.inst.asTypeOf(new MemDecode)
val readVMEDataLatency = tp.readVMEDataLatency
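// Pipe the VME read-data bits/valid/ready through readVMEDataLatency register
// stages to ease timing; vmeDataFirePipe marks a delayed valid && ready beat.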
val vmeDataBitsPipe = ShiftRegister(io.vme_rd.data.bits, readVMEDataLatency, en = true.B)
val vmeDataValidPipe = ShiftRegister(io.vme_rd.data.valid, readVMEDataLatency, resetData = false.B, en = true.B)
val vmeDataReadyPipe = ShiftRegister(io.vme_rd.data.ready, readVMEDataLatency, resetData = true.B, en = true.B)
val vmeDataFirePipe = vmeDataValidPipe & vmeDataReadyPipe
//--------------------------------------
//--- Generate data load VME command ---
//--------------------------------------
val vmeCmd = Module (new GenVMECmdWideTL(tensorType, debug))
vmeCmd.io.start := io.start
vmeCmd.io.isBusy := isBusy
vmeCmd.io.inst := io.inst
vmeCmd.io.baddr := io.baddr
vmeCmd.io.vmeCmd <> io.vme_rd.cmd
val readLen = vmeCmd.io.readLen
val commandsDone = vmeCmd.io.done
require (mp.dataBits >= tp.tensorSizeBits,
"-F- Cacheline width must be at least the tensor bit size")
require(pow(2, log2Ceil(mp.dataBits)) == mp.dataBits,
"-F- Cacheline width must be a power of 2")
require(pow(2, log2Ceil(tp.tensorSizeBits)) == tp.tensorSizeBits,
"-F- Tensor size bits must be a power of 2")
// the mux packs tensors into a single memory line of Cacheline (CL) bits
val tensorsInClNb = tp.clSizeRatio
val tensorsInClNbWidth = log2Ceil(tensorsInClNb)
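// Example (illustrative, assuming tp.clSizeRatio = mp.dataBits / tp.tensorSizeBits):
// with 512-bit cachelines and 128-bit tensors, tensorsInClNb = 4 and
// tensorsInClNbWidth = 2, so the two LSBs of a tensor index select its CL slot.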
//--------------------------------------
//--- count how many CLs are not yet received ---
//--------------------------------------
// the address size of scratchpad memory
val clCntIdxWdth = log2Ceil(tp.memDepth/tensorsInClNb) + 1
// Nb of CLs requested, not received.
val clInFlight = Reg(UInt(clCntIdxWdth.W))
when(io.start) {
clInFlight := 0.U
}.elsewhen(isBusy && io.vme_rd.cmd.fire && !vmeDataFirePipe) {
clInFlight := clInFlight + readLen
}.elsewhen(isBusy && io.vme_rd.cmd.fire && vmeDataFirePipe) {
clInFlight := clInFlight + readLen - 1.U
}.elsewhen(isBusy && !io.vme_rd.cmd.fire && vmeDataFirePipe) {
assert(clInFlight > 0.U)
clInFlight := clInFlight - 1.U
}.otherwise {
clInFlight := clInFlight
}
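// Example: a cmd fire with readLen = 4 adds 4 cachelines in flight; each piped
// data beat retires one, and the load is drained once clInFlight returns to 0.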
//---------------------
//--- Read VME data ---
//---------------------
val readData = Module(new ReadVMEDataWide(tensorType, debug))
readData.io.start := io.start
readData.io.vmeData.valid := vmeDataValidPipe
readData.io.vmeData.bits := vmeDataBitsPipe
assert(!readData.io.vmeData.valid || readData.io.vmeData.ready,
"-F- Expecting const ready. Fix ReadVMEData to receive data piped after ready")
io.vme_rd.data.ready := readData.io.vmeData.ready
// the write mask enables a number of elements starting at an offset in the SRAM line
val rdDataWrIdx = readData.io.destIdx // SP index vector
val rdDataWrData = readData.io.destData // SP data vector
val rdDataWrEn = readData.io.destMask // write enable vector
//-------------------------
//--- Fill zero padding ---
//-------------------------
val fillPadding = Module(new ZeroPadding(tensorType, debug))
fillPadding.io.canWriteMem := !vmeDataFirePipe
fillPadding.io.inst := io.inst
fillPadding.io.start := io.start
val isZeroPadWrite = fillPadding.io.tensorIdx.valid // Store zero filled tensor, zpDestIdx is valid
val zpDestIdx = fillPadding.io.tensorIdx.bits >> tensorsInClNbWidth // SP idx
val zpDestMask =
if (tensorsInClNb == 1) 1.U
else UIntToOH(fillPadding.io.tensorIdx.bits (tensorsInClNbWidth - 1, 0)) // tensor in a memory line
val paddingDone = fillPadding.io.done
//--------------------
//--- Write memory ---
//--------------------
// depth is reduced by dataBlock/tensorSize ratio
// width is dataBlock bits split into tensor bits
// each tensor is split into group bits
// group bits can be read/written independently
val splitDataFactor = tp.splitWidth * tp.splitLength
val splitMemFactor = tp.splitMemsFactor
val groupSizeBits = tp.tensorSizeBits/splitDataFactor
val memSizeBits = groupSizeBits/splitMemFactor
val tensorFile = Seq.fill(tensorsInClNb * splitDataFactor*splitMemFactor) {
SyncReadMem(tp.memDepth/tensorsInClNb, UInt(memSizeBits.W))
}
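// Example (illustrative): tensorsInClNb = 4, splitDataFactor = 2, splitMemFactor = 1
// builds 8 SyncReadMems, each tp.memDepth/4 deep and tp.tensorSizeBits/2 wide,
// so one cacheline write can touch every bank in the same cycle.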
// direct write
val directWrIdx = for (grpIdx <- 0 until splitDataFactor) yield {
io.tensor.wr(grpIdx).bits.idx >> tensorsInClNbWidth // SP idx
}
val directWrMask = for (grpIdx <- 0 until splitDataFactor) yield {
Mux(
io.tensor.wr(grpIdx).valid,
if(tensorsInClNb == 1) 1.U
else UIntToOH(io.tensor.wr(grpIdx).bits.idx(tensorsInClNbWidth - 1, 0)),// tensor in a memory line
0.U)
}
// THIS directWrData writes a contiguous scratchpad data space.
// It is WRONG for ACC if batch is > 1:
// it maps group data bits to a contiguous sequence of mem blocks,
// but wr(x).bits.data is a window into a tensor
val directWrData = VecInit(for (grpIdx <- 0 until splitDataFactor) yield {
io.tensor.wr(grpIdx).bits.data
}).asTypeOf(UInt(tp.tensorSizeBits.W))
val wmask = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, Bool()))
for (i <- 0 until tensorsInClNb) {
for (grpIdx <- 0 until splitDataFactor) {
for (memIdx <- 0 until splitMemFactor) { // duplicate control
wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx) :=
Mux(
ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
directWrMask(grpIdx)(i),
Mux(
ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
ShiftRegister(zpDestMask(i), writePipeLatency),
Mux(
ShiftRegister(vmeDataFirePipe, writePipeLatency, resetData = false.B, en = true.B),
ShiftRegister(rdDataWrEn(i), writePipeLatency),
false.B)))
}
}
}
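// Write-source priority, as encoded above: direct client writes win while the FSM
// is idle; when busy, zero-padding writes take precedence over arriving VME data.
// All selects are delayed by writePipeLatency to line up with wdata/widx below.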
val wdata = Wire(Vec(tensorsInClNb*splitDataFactor, UInt(groupSizeBits.W)))
for (i <- 0 until tensorsInClNb){
for (grpIdx <- 0 until splitDataFactor) {
val zpDestData = 0.U
wdata(i*splitDataFactor + grpIdx) := Mux(
ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
io.tensor.wr(grpIdx).bits.data.asTypeOf(UInt(groupSizeBits.W)),
Mux(
ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
ShiftRegister(zpDestData /* group size zero */, writePipeLatency),
ShiftRegister(
(rdDataWrData(i).asTypeOf(Vec(splitDataFactor, UInt(groupSizeBits.W))))(grpIdx), writePipeLatency)))
}
}
val widx = Wire(Vec(tensorsInClNb*splitDataFactor*splitMemFactor, UInt(tp.memAddrBits.W)))
for (i <- 0 until tensorsInClNb) {
for (grpIdx <- 0 until splitDataFactor) {
for (memIdx <- 0 until splitMemFactor) { // duplicate control
widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx) :=
Mux(
ShiftRegister(state === sIdle, writePipeLatency, resetData = true.B, en = true.B),
directWrIdx(grpIdx),
Mux(
ShiftRegister(isZeroPadWrite, writePipeLatency, resetData = false.B, en = true.B),
ShiftRegister(zpDestIdx, writePipeLatency),
ShiftRegister(rdDataWrIdx(i), writePipeLatency)))
}
}
}
for (i <- 0 until tensorsInClNb) {
for (grpIdx <- 0 until splitDataFactor) {
for (memIdx <- 0 until splitMemFactor) { // duplicate control
when(wmask(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx)) {
tensorFile(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx).write(
widx(i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx),
wdata(i*splitDataFactor + grpIdx).asTypeOf(
Vec(splitMemFactor, UInt(memSizeBits.W)))(memIdx))
}
}
}
}
if (debug) {
when(isZeroPadWrite) {
printf(s"[TensorLoad] $tensorType isZeroPadWrite data zpDestIdx:%d\n",
zpDestIdx)
}
}
// read-from-sram
for (grpIdx <- 0 until splitDataFactor) {
val rIdx = io.tensor.rd(grpIdx).idx.bits >> tensorsInClNbWidth // SP idx
val rMask =
Mux(
io.tensor.rd(grpIdx).idx.valid,
if(tensorsInClNb == 1) 1.U
else UIntToOH(io.tensor.rd(grpIdx).idx.bits(tensorsInClNbWidth - 1, 0)),// tensor in a memory line
0.U)
val rdataVec = for (i <- 0 until tensorsInClNb) yield {
VecInit(for (memIdx <- 0 until splitMemFactor) yield {
tensorFile(
i*splitDataFactor*splitMemFactor + grpIdx * splitMemFactor + memIdx).read(
ShiftRegister(rIdx, tp.readTensorLatency),
ShiftRegister(VecInit(rMask.asBools)(i), tp.readTensorLatency, resetData = false.B, en = true.B))
}).asUInt
}
val rdata = Wire(UInt(tp.tensorSizeBits.W))
rdata := Mux1H(ShiftRegister(rMask, tp.readTensorLatency + 1), rdataVec)
io.tensor.rd(grpIdx).data.bits := rdata.asTypeOf(io.tensor.rd(grpIdx).data.bits.cloneType)
val rvalid = ShiftRegister(
io.tensor.rd(grpIdx).idx.valid, tp.readTensorLatency + 1, resetData = false.B, en = true.B)
io.tensor.rd(grpIdx).data.valid := rvalid
}
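// Read timing: data is valid tp.readTensorLatency + 1 cycles after idx.valid,
// i.e. readTensorLatency stages on the index/enable path plus one SyncReadMem cycle.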
// done
val loadDone = clInFlight === 0.U && commandsDone && state === sBusy
localDone := loadDone && paddingDone
io.done := ShiftRegister(localDone, writePipeLatency, resetData = false.B, en = true.B)
}
//----------------------------------------------------------------------------
// Read VME data. Generate the memory index and data.
// The transaction TAG is a data block offset in the scratchpad.
// Different transactions are identified by a tag change.
// SUBSEQUENT REQUESTS TO THE SAME DESTINATION WITHIN ONE INSTRUCTION LEAD TO UNDEFINED BEHAVIOR
//----------------------------------------------------------------------------
class ReadVMEDataWide(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val wmaskWidth = mp.dataBits/tp.tensorSizeBits
val io = IO(new Bundle {
val start = Input(Bool())
val vmeData = Flipped(Decoupled(new VMEData))
val destIdx = Output(Vec(tp.clSizeRatio, UInt(tp.memAddrBits.W)))
val destData = Output(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
val destMask = Output(Vec(tp.clSizeRatio, Bool()))
})
io.vmeData.ready := true.B // always ready to read VME data
require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
"-F- Tensor length must be a power of 2. Using shift and bit slicing to divide.")
val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad in cache lines
//decode data destination
val vmeTagDecode = io.vmeData.bits.tag
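// Tag layout, msb to lsb (must match the Cat() encode in GenVMECmdWide):
// destination element idx | 1st-pulse tensor offset | last-pulse tensor count,
// where the two low fields are rdDataMaskDecodeWidth bits each.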
val vmeTagDecodeLast = Reg(vmeTagDecode.cloneType) // store tag to identify a new burst
val clBytes = mp.dataBits / 8 // cacheline bytes
val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // bytes in tensor
val rdDataMaskDecodeWidth = if (wmaskWidth == 1) 1 else (log2Ceil(wmaskWidth) + 1)
val rdDataElemIdx = vmeTagDecode(vmeTagDecode.getWidth - 1, 2 * rdDataMaskDecodeWidth)
val rdFstOffsetNb = if (rdDataMaskDecodeWidth == 0) {
0.U
} else {
val readOffset = vmeTagDecode(2 * rdDataMaskDecodeWidth - 1, rdDataMaskDecodeWidth)
readOffset
}
val rdLstNb = if (rdDataMaskDecodeWidth == 0) {
1.U
} else {
val readNb = vmeTagDecode(rdDataMaskDecodeWidth - 1, 0)
assert(!io.vmeData.valid || readNb > 0.U,"-F- Expecting some elements to read")
readNb
}
val wrMask1st = if (rdDataMaskDecodeWidth == 0) {
1.U
} else {
Reverse(VecInit(for(idx <- 0 until wmaskWidth) yield {
idx.U < tp.clSizeRatio.U - rdFstOffsetNb
}).asUInt)
}
val wrMaskLast = if (rdDataMaskDecodeWidth == 0) {
1.U
} else {
VecInit(for(idx <- 0 until wmaskWidth) yield {
idx.U < rdLstNb
}).asUInt
}
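// Worked example (illustrative, wmaskWidth = 4): rdFstOffsetNb = 1 makes the
// first-pulse mask b1110 (the first tensor slot is skipped), and rdLstNb = 3
// makes the last-pulse mask b0111 (only the first three slots are kept).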
val rdDataElemDestIdx = Wire(UInt(tp.memAddrBits.W)) // this is an idx of a tensor
val rdDataElemDestIdxNext = Reg(UInt(tp.memAddrBits.W))
val rdDataClDestIdx = rdDataElemDestIdx >> log2Ceil(tp.clSizeRatio)
val rdDataDestElemOffset = rdDataElemDestIdx % tp.clSizeRatio.U
val vmeTagDecodeLastValid = Wire(Bool())
val vmeTagDecodeLastValidNext = RegNext(
next = vmeTagDecodeLastValid,
init = false.B)
when(io.start) {
vmeTagDecodeLastValid := false.B // reset tag valid
}.elsewhen(io.vmeData.fire) {
vmeTagDecodeLastValid := true.B // set tag valid on a new read
}.otherwise {
vmeTagDecodeLastValid := vmeTagDecodeLastValidNext // keep value
}
val isFirstPulse = Wire(Bool())
val isLastPulse = io.vmeData.bits.last
val wmaskSel =
Mux(
isFirstPulse && isLastPulse,
wrMask1st & wrMaskLast,
Mux(
isFirstPulse,
wrMask1st,
Mux(
isLastPulse,
wrMaskLast,
((1 << wmaskWidth) - 1).U)))
val wmask = Mux(io.vmeData.fire, wmaskSel, 0.U)
rdDataElemDestIdx := DontCare
isFirstPulse := false.B
when(io.vmeData.fire) {
when (
!vmeTagDecodeLastValidNext ||
(vmeTagDecodeLastValidNext &&
vmeTagDecode.asUInt =/= vmeTagDecodeLast.asUInt)) {
vmeTagDecodeLast := vmeTagDecode // a new burst
isFirstPulse := true.B
rdDataElemDestIdx := rdDataElemIdx
// don't increment on the first partial read pulse
rdDataElemDestIdxNext := rdDataElemIdx + PopCount(wmask)
}.otherwise {
rdDataElemDestIdxNext := rdDataElemDestIdxNext + PopCount(wmask)
rdDataElemDestIdx := rdDataElemDestIdxNext
}
}
val srcData = io.vmeData.bits.data.asTypeOf(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
val srcOffset = Wire(Vec(tp.clSizeRatio, UInt((log2Ceil(tp.clSizeRatio) + 1).W)))
val srcIdx = Wire(Vec(tp.clSizeRatio, UInt(log2Ceil(tp.clSizeRatio).W)))
// D(j+d) = S(j+s) replace i=j+d --> D(i) = S(i-d+s)
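// e.g. with tp.clSizeRatio = 4, dest element offset d = 2 and first-pulse source
// offset s = 1: dest slot i takes source slot (i + s - d) mod 4, so dest 2 <- src 1
// and dest 3 <- src 2, while slots whose srcOffset underflows the dest offset
// address the next cacheline row via incrIdx.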
for (i <- 0 until tp.clSizeRatio) {
srcOffset(i) := i.U + Mux(isFirstPulse, rdFstOffsetNb, 0.U)
srcIdx(i) := srcOffset(i) -% rdDataDestElemOffset
val srcIdxOH = UIntToOH(srcIdx(i))
io.destData(i) := Mux1H(srcIdxOH,srcData)
io.destMask(i) := Mux1H(srcIdxOH, wmask)
//if dest offset overflow, incr that dest idx
val incrIdx = if (tp.clSizeRatio == 1 ) {
0.U
} else {
Mux(srcOffset(i) >= rdDataDestElemOffset, 0.U, 1.U)
}
io.destIdx(i) := rdDataClDestIdx + incrIdx
}
}
// The transaction TAG is a data block offset in the scratchpad.
// Different transactions are identified by a tag change.
// SUBSEQUENT REQUESTS TO THE SAME DESTINATION WITHIN ONE INSTRUCTION LEAD TO UNDEFINED BEHAVIOR
class GenVMECmdWide(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val isBusy = Input(Bool())
val updateState = Input(Bool())
val canSendCmd = Input(Bool())
val baddr = Input(UInt(mp.addrBits.W))
val vmeCmd = Decoupled(new VMECmd)
val readLen = Output(UInt((mp.lenBits + 1).W))
val done = Output(Bool())
val fstPulseDataStart = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
val lstPulseDataEnd = Output(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
val spElemIdx = Output(UInt(tp.memAddrBits.W))
val ysize = Input(UInt(M_SIZE_BITS.W))
val xsize = Input(UInt(M_SIZE_BITS.W))
val xstride = Input(UInt(M_STRIDE_BITS.W))
val dram_offset = Input(UInt(M_DRAM_OFFSET_BITS.W))
val sram_offset = Input(UInt(M_SRAM_OFFSET_BITS.W))
val xpad_0 = Input(UInt(M_PAD_BITS.W))
val xpad_1 = Input(UInt(M_PAD_BITS.W))
val ypad_0 = Input(UInt(M_PAD_BITS.W))
})
val clBytes = mp.dataBits / 8 // cacheline bytes
val elemBytes = tp.tensorLength * tp.tensorWidth * tp.tensorElemBits / 8 // bytes in tensor
val stride = Wire(Bool()) // flags change to the next row to read
//----------------------------------------
//--- Count DRAM memory lines ---
//----------------------------------------
// set which source row of data to read. io.ysize defines the number of rows
val dramLineIdx = Reg(UInt(io.ysize.getWidth.W)) // current row of stride read
when (io.start) {
dramLineIdx := 0.U // 1st row
}.elsewhen (stride) {
dramLineIdx := dramLineIdx + 1.U // increment row
}.otherwise {
dramLineIdx := dramLineIdx // stay in the row
}
// calculate address of DRAM memory line begin (initial/stride)
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val dramInitialAddr = (io.dram_offset << log2Ceil(elemBytes)).asTypeOf(UInt(mp.addrBits.W))
val xferElemInitAddr = io.baddr | dramInitialAddr // SHOULD have + here?
// align address to CL size
// lower bits - elem offset in a cacheline
val dramClAddrAlignNotMask = ((BigInt(1) << log2Ceil(clBytes)) - 1).U.asTypeOf(xferElemInitAddr)
// upper bits - cacheline alignment
val dramClAddrAlignMask = ~dramClAddrAlignNotMask
val xferClInitAddr = xferElemInitAddr & dramClAddrAlignMask
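// Example (illustrative, clBytes = 64): dramClAddrAlignMask = ~0x3F, so an
// element begin address of 0x1234 yields a cacheline-aligned start of 0x1200.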
val rdLineElemBeginAddr = Reg(UInt(mp.addrBits.W)) // DRAM address of xsize tensors memory line
val rdLineClBeginAddr = rdLineElemBeginAddr & dramClAddrAlignMask
// begin of the next DRAM memory line
val nextLineBeginElemAddr = rdLineElemBeginAddr + (io.xstride << log2Ceil(elemBytes))
val nextLineBeginClAddr = nextLineBeginElemAddr & dramClAddrAlignMask
when (io.start) {
rdLineElemBeginAddr := xferElemInitAddr
}.elsewhen (stride) {
rdLineElemBeginAddr := nextLineBeginElemAddr
}.otherwise {
rdLineElemBeginAddr := rdLineElemBeginAddr
}
//-----------------------------------------------------
//--- Calculate current DRAM address of transaction ---
//-----------------------------------------------------
val rdLen = Wire(UInt((mp.lenBits + 1).W)) // read cmd transaction length. It is <= maxTransfer
val rdLineAddr = Reg(UInt(mp.addrBits.W)) // current DRAM address of command
when (io.start) {
rdLineAddr := xferClInitAddr
}.elsewhen (io.updateState) {
when(stride) {
rdLineAddr := nextLineBeginClAddr
}.otherwise {
rdLineAddr := rdLineAddr + (rdLen << log2Ceil(clBytes))
}
}.otherwise {
rdLineAddr := rdLineAddr
}
//total load length in cachelines
val rdLineBytes = io.xsize << log2Ceil(elemBytes)
// length of the first transaction in a line (initial read or after a stride)
val maxTransfer = (1 << mp.lenBits).U // max number of pulses in transfer
val maxTrBytes = maxTransfer << log2Ceil(clBytes)
val rdLen1stMaxTransBytes = maxTrBytes - rdLineClBeginAddr % maxTrBytes
// get the number of cachelines till maxTrBytes aligned address
val rdLen1stMaxTransClNb = rdLen1stMaxTransBytes >> log2Ceil(clBytes)
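// Example (illustrative, mp.lenBits = 8, clBytes = 64): maxTransfer = 256 pulses
// and maxTrBytes = 16 KiB; a line starting 0x400 bytes past a 16 KiB boundary
// limits the first transaction to (16384 - 1024)/64 = 240 cachelines, so a
// single VME command never crosses that boundary.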
//Transaction begin mask. Number of tensors to read from right
val rd1stPulseOffsetBytes = rdLineElemBeginAddr % clBytes.U
assert(rd1stPulseOffsetBytes >> log2Ceil(elemBytes) <= tp.clSizeRatio.U,
"-F- The number of tensors to skip must fit within a CL")
val rd1stPulseOffsetTensNb = Wire(UInt((log2Ceil(tp.clSizeRatio) + 1).W))
rd1stPulseOffsetTensNb := rd1stPulseOffsetBytes >> log2Ceil(elemBytes)
val rdLineClNbTmp = (rdLineBytes + rd1stPulseOffsetBytes) >> log2Ceil(clBytes)
val rdLineClNb =
Mux((rdLineBytes + rd1stPulseOffsetBytes) % clBytes.U === 0.U, rdLineClNbTmp, rdLineClNbTmp + 1.U)
//Transaction end mask. Number of tensors to read from left
val rdLastPulseBytes = (rdLineElemBeginAddr + rdLineBytes) % clBytes.U
assert(rdLastPulseBytes >> log2Ceil(elemBytes) <= (clBytes/elemBytes).U,
"-F- The number of active tensors must fit within a CL")
val rdLastPulseTensNb = Wire(UInt((log2Ceil(clBytes/elemBytes) + 1).W))
val rdLastPulseTensNbTmp = rdLastPulseBytes >> log2Ceil(elemBytes)
rdLastPulseTensNb := Mux(rdLastPulseTensNbTmp === 0.U, (clBytes/elemBytes).U, rdLastPulseTensNbTmp)
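// Example (illustrative, clBytes = 64, elemBytes = 16): xsize = 10 tensors gives
// rdLineBytes = 160; with a 16-byte first-pulse offset the span is 176 bytes, so
// rdLineClNb = 3, rd1stPulseOffsetTensNb = 1 and rdLastPulseTensNb = 3.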
//--------------------------------------
//--- Generate data load VME command ---
//--------------------------------------
val rdCmdStartIdxValid = Wire(Bool()) // Command is valid
val startIssueCmdRead = Wire(Bool()) // First transaction in io.xsize transfer
val rdCmdStartIdx = Reg(UInt(log2Ceil(tp.memDepth).W)) // Scratchpad data block index for the first transaction
val commandsDone = RegInit(true.B) // Done generating VME commands
// counts the number of CLs read in an xsize line
val clReadIdx = Reg(UInt((io.xsize.getWidth + log2Ceil(elemBytes) - log2Ceil(clBytes)).W))
val newReadRow = clReadIdx === 0.U // flags the first read of io.xsize
// set how many blocks of data are being loaded
commandsDone := commandsDone
when (io.start || stride) {
clReadIdx := 0.U
commandsDone := false.B
}.elsewhen (io.updateState) {
val nextClIdx = clReadIdx + rdLen
clReadIdx := nextClIdx // THIS IS WHEN A NEW VME CMD HAPPENS
when (nextClIdx === rdLineClNb && dramLineIdx === io.ysize - 1.U) {
commandsDone := true.B
}
}.otherwise {
clReadIdx := clReadIdx
}
// when all read commands for the xsize row have been sent, move on to the next src row
when((clReadIdx === rdLineClNb - rdLen) && (dramLineIdx =/= io.ysize - 1.U) && io.updateState) {
stride := true.B
}.otherwise {
stride := false.B
}
// number of tensors to read in the 1st and last pulses of the current transaction
val rdCmd1stPluseOffsetTensNb = Wire(rd1stPulseOffsetTensNb.cloneType)
val rdCmdLastPluseTensNb = Wire(rdLastPulseTensNb.cloneType)
when(newReadRow) {
// first read in line
rdCmd1stPluseOffsetTensNb := rd1stPulseOffsetTensNb
}.otherwise {
// any other read
rdCmd1stPluseOffsetTensNb := 0.U
}
when (clReadIdx === rdLineClNb - rdLen) {
// last read in line
rdCmdLastPluseTensNb := rdLastPulseTensNb
}.otherwise {
// any other read
rdCmdLastPluseTensNb := (clBytes/elemBytes).U
}
assert(!io.isBusy || rdLineClNb >= clReadIdx)
// define how many cachelines to read this cycle
val clRemained = rdLineClNb - clReadIdx
when (newReadRow) {
when(clRemained < rdLen1stMaxTransClNb) {
rdLen := clRemained
}.otherwise {
rdLen := rdLen1stMaxTransClNb
}
}.otherwise {
when(clRemained < maxTransfer) {
rdLen := clRemained
}.otherwise {
rdLen := maxTransfer
}
}
// block index of the read data row (xsize). Modified by zero padding
val totalWidth = io.xsize + io.xpad_0 + io.xpad_1 // width of scratchpad matrix in tensors
// instead of multiplying the total width by ypad_0, do incremental addition;
// should cost ypad_0 cycles to issue the 1st read cmd
// counts rows of tensors in the src matrix, including y padding rows
val currentRowIdx = Reg(UInt((io.ysize.getWidth + io.ypad_0.getWidth).W))
// start to issue read cmd
rdCmdStartIdxValid := currentRowIdx >= io.ypad_0 &&
currentRowIdx < (io.ysize + io.ypad_0) &&
io.isBusy &&
!commandsDone
when (io.start) {
currentRowIdx := 0.U
rdCmdStartIdx := io.sram_offset + io.xpad_0 // this index is in tensors
}.elsewhen (io.isBusy && (currentRowIdx < io.ypad_0 || stride)) {
rdCmdStartIdx := rdCmdStartIdx + totalWidth
currentRowIdx := currentRowIdx + 1.U
}
startIssueCmdRead := false.B
when(newReadRow && rdCmdStartIdxValid) {
startIssueCmdRead := true.B
}
//-------------------------------------
//--- execute VME data load command ---
//-------------------------------------
require(pow(2, log2Ceil(tp.tensorLength)) == tp.tensorLength,
"-F- Tensor length must be a power of 2. Using shift and bit slicing to divide.")
val blkIdxWdth = log2Ceil(tp.memDepth) // the size of scratchpad
val rdCmdDestElemIdx = Wire(UInt(tp.memAddrBits.W)) // element(tensor) size block index in a scratchpad
val rdCmdDestElemIdxNext = Reg(rdCmdDestElemIdx.cloneType)
rdCmdDestElemIdxNext := rdCmdDestElemIdxNext
rdCmdDestElemIdx := rdCmdDestElemIdxNext
val rdCmdValid = Wire(Bool())
// the number of tensors read in transaction
val rdCmdTransactionTensNb = (rdLen << log2Ceil(clBytes/elemBytes)) - rdCmd1stPluseOffsetTensNb
// increment scratchpad destination index
when(rdCmdStartIdxValid) {
rdCmdValid := true.B
when(startIssueCmdRead) {
rdCmdDestElemIdx := rdCmdStartIdx
rdCmdDestElemIdxNext:= rdCmdStartIdx + rdCmdTransactionTensNb
}.elsewhen (io.updateState) {
// increment block position by transaction length
rdCmdDestElemIdxNext:= rdCmdDestElemIdxNext + rdCmdTransactionTensNb
}
}.otherwise {
rdCmdValid := false.B
}
// read-from-dram
require(io.vmeCmd.bits.tag.getWidth >= rdCmdDestElemIdx.getWidth +
rdCmdLastPluseTensNb.getWidth + rdCmd1stPluseOffsetTensNb.getWidth,
s"-F- Tensor ${tensorType} Not enough VME tag bits to store transaction" +
s" tag. need:${rdCmdDestElemIdx.getWidth + rdCmdLastPluseTensNb.getWidth + rdCmd1stPluseOffsetTensNb.getWidth}")
io.vmeCmd.valid := rdCmdValid && io.canSendCmd
io.vmeCmd.bits.addr := rdLineAddr
io.vmeCmd.bits.len := rdLen - 1.U
assert(!io.vmeCmd.valid || ((rdLen << log2Ceil(clBytes)) <= maxTrBytes - rdLineAddr % maxTrBytes),
s"-F- ${tensorType} DRAM page alignment failure. DRAM " +
s"address + len overlaps mp.lenBits*memBlockSize alignment %x %x",
rdLineAddr, rdLen)
io.vmeCmd.bits.tag := Cat(rdCmdDestElemIdx, Cat(rdCmd1stPluseOffsetTensNb, rdCmdLastPluseTensNb))
io.readLen := rdLen
io.spElemIdx := rdCmdDestElemIdx // scratchpad tensor idx
io.fstPulseDataStart := rdCmd1stPluseOffsetTensNb // first pulse data start
io.lstPulseDataEnd := rdCmdLastPluseTensNb // last pulse data end
io.done := commandsDone
}
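// Instruction-level wrapper: decodes the MemDecode fields of an instruction and
// drives GenVMECmdWide, advancing its state each time a VME command fires.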
class GenVMECmdWideTL(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val isBusy = Input(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val vmeCmd = Decoupled(new VMECmd)
val readLen = Output(UInt((mp.lenBits + 1).W))
val done = Output(Bool())
})
val dec = io.inst.asTypeOf(new MemDecode)
val cmdGen = Module (new GenVMECmdWide(tensorType, debug))
cmdGen.io.start := io.start
cmdGen.io.isBusy := io.isBusy
cmdGen.io.baddr := io.baddr
io.vmeCmd <> cmdGen.io.vmeCmd
io.readLen := cmdGen.io.readLen
io.done := cmdGen.io.done
cmdGen.io.ysize := dec.ysize
cmdGen.io.xsize := dec.xsize
cmdGen.io.xstride := dec.xstride
cmdGen.io.dram_offset := dec.dram_offset
cmdGen.io.sram_offset := dec.sram_offset
cmdGen.io.xpad_0 := dec.xpad_0
cmdGen.io.xpad_1 := dec.xpad_1
cmdGen.io.ypad_0 := dec.ypad_0
cmdGen.io.updateState := io.vmeCmd.fire
cmdGen.io.canSendCmd := true.B
}