blob: 8a3cea32f50f5514eb2ee0e424fdc54324ca0601 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import scala.math.pow
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** TensorStoreWideVME.
 *
 * Store 1D and 2D tensors from the out-scratchpad (SRAM) to main memory (DRAM)
 * through a wide VME write interface, where one memory pulse may carry several
 * tensors.
 */
class TensorStoreWideVME(tensorType: String = "none", debug: Boolean = false)(
    implicit p: Parameters)
    extends Module {
  // Tensor geometry (size in bits, scratchpad depth, cacheline ratio, split
  // factors) and shell memory parameters come from the configuration.
  val tp = new TensorParams(tensorType)
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val start = Input(Bool())                 // request to begin the store described by inst/baddr
    val done = Output(Bool())                 // high on the ack of the final command
    val inst = Input(UInt(INST_BITS.W))       // raw instruction bits; viewed as MemDecode below
    val baddr = Input(UInt(mp.addrBits.W))    // DRAM base address
    val vme_wr = new VMEWriteMaster           // write command/data/ack channel to DRAM
    val tensor = new TensorClient(tensorType) // scratchpad write ports from the core
  })
  val writePipeLatency = tp.writePipeLatency
  // Store write is delayed by writePipeLatency;
  // postpone start by the same number of cycles.
  // Expects inst and baddr to stay valid from start till done.
  val localStart = ShiftRegister(io.start, writePipeLatency, resetData = false.B, en = true.B)
  val dec = io.inst.asTypeOf(new MemDecode)
  // FSM: issue one VME write command, stream its data beats, wait for the
  // write ack, then either issue the next command (2D/strided transfer) or
  // return to idle.
  val sIdle :: sWriteCmd :: sWriteData :: sWriteAck :: Nil = Enum(4)
  val state = RegInit(sIdle)
  // Command generator: splits the (ysize x xsize) transfer into wide VME
  // write commands and reports per-command alignment parameters.
  val cmdGen = Module (new GenVMECmdWide(tensorType, debug))
  cmdGen.io.ysize := dec.ysize
  cmdGen.io.xsize := dec.xsize
  cmdGen.io.xstride := dec.xstride
  cmdGen.io.dram_offset := dec.dram_offset
  cmdGen.io.sram_offset := dec.sram_offset
  // Padding inputs are tied to zero: this store path does not use padding.
  cmdGen.io.xpad_0 := 0.U
  cmdGen.io.xpad_1 := 0.U
  cmdGen.io.ypad_0 := 0.U
  cmdGen.io.start := localStart
  cmdGen.io.isBusy := state =/= sIdle
  cmdGen.io.baddr := io.baddr
  // A command is latched/advanced only while the FSM sits in sWriteCmd.
  cmdGen.io.updateState := state === sWriteCmd
  cmdGen.io.canSendCmd := cmdGen.io.updateState
  io.vme_wr.cmd <> cmdGen.io.vmeCmd
  val commandsDone = cmdGen.io.done
  // latch cmd parameters
  // Each parameter has a Wire/Reg pair: while updateState is high the wire
  // forwards the fresh cmdGen value (and the reg captures it); otherwise the
  // wire reads back the held register. This makes the value usable in the
  // same cycle the command is accepted.
  val readLenReg = Reg(cmdGen.io.readLen.cloneType)                     // data beats in this command
  val readLen = Wire(readLenReg.cloneType)
  val fstPulseDataStartReg = Reg(cmdGen.io.fstPulseDataStart.cloneType) // first valid tensor slot in first beat
  val fstPulseDataStart = Wire(fstPulseDataStartReg.cloneType)
  val lstPulseDataEndReg = Reg(cmdGen.io.lstPulseDataEnd.cloneType)     // first invalid tensor slot in last beat
  val lstPulseDataEnd = Wire(lstPulseDataEndReg.cloneType)
  val spElemIdxReg = Reg(cmdGen.io.spElemIdx.cloneType)                 // scratchpad element index of the command
  val spElemIdx = Wire(spElemIdxReg.cloneType)
  when (cmdGen.io.updateState) {
    readLen := cmdGen.io.readLen
    readLenReg := readLen
    fstPulseDataStart := cmdGen.io.fstPulseDataStart
    fstPulseDataStartReg := fstPulseDataStart
    lstPulseDataEnd := cmdGen.io.lstPulseDataEnd
    lstPulseDataEndReg := lstPulseDataEnd
    spElemIdx := cmdGen.io.spElemIdx
    spElemIdxReg := spElemIdx
  }.otherwise {
    readLenReg := readLenReg
    readLen := readLenReg
    fstPulseDataStartReg := fstPulseDataStartReg
    fstPulseDataStart := fstPulseDataStartReg
    lstPulseDataEndReg := lstPulseDataEndReg
    lstPulseDataEnd := lstPulseDataEndReg
    spElemIdxReg := spElemIdxReg
    spElemIdx := spElemIdxReg
  }
  // Data-beat counter within the current command (0 .. readLen-1).
  val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
  xcnt := xcnt
  // control
  val updateState = Wire(Bool())
  updateState := false.B
  switch(state) {
    is(sIdle) {
      when (localStart) {
        state := sWriteCmd
      }
    }
    is(sWriteCmd) {
      // Command accepted by the VME: latch its parameters (updateState)
      // and start streaming data from beat 0.
      when(io.vme_wr.cmd.fire()) {
        state := sWriteData
        updateState := true.B
        xcnt := 0.U
      }
    }
    is(sWriteData) {
      // Count accepted beats; after the last one wait for the write ack.
      when(io.vme_wr.data.fire()) {
        when(xcnt === readLen - 1.U) {
          state := sWriteAck
        }.otherwise {
          xcnt := xcnt + 1.U
        }
      }
    }
    is(sWriteAck) {
      when(io.vme_wr.ack) {
        when(commandsDone) {
          state := sIdle
        }.otherwise { // stride: more commands to issue for this instruction
          state := sWriteCmd
        }
      }
    }
  }
  //--------------------
  //--- Write memory ---
  //--------------------
  // The scratchpad is banked two ways: clSizeRatio tensors share one memory
  // row (wide-cacheline packing), and each tensor is split into
  // splitDataFactor groups of groupSizeBits.
  val splitDataFactor = tp.splitWidth * tp.splitLength
  val groupSizeBits = tp.tensorSizeBits/splitDataFactor
  val tensorFile = Seq.fill(tp.clSizeRatio * splitDataFactor) {
    SyncReadMem(tp.memDepth/tp.clSizeRatio, UInt(groupSizeBits.W))
  }
  // direct write
  // The incoming tensor index selects a memory row (upper bits) and a bank
  // within the row (lower bits). Valid, index, and data are all delayed by
  // writePipeLatency to match the pipelined store timing (see localStart).
  for (grpIdx <- 0 until splitDataFactor) {
    val directWrIdx = io.tensor.wr(grpIdx).bits.idx >> log2Ceil(tp.clSizeRatio) // SP idx
    val directWrTensorIdx =
      if(tp.clSizeRatio == 1) 0.U
      else io.tensor.wr(grpIdx).bits.idx(log2Ceil(tp.clSizeRatio) - 1, 0)
    for (i <- 0 until tp.clSizeRatio) {
      when(ShiftRegister(io.tensor.wr(grpIdx).valid && directWrTensorIdx === i.U, writePipeLatency,
          resetData = false.B, en = true.B)) {
        tensorFile(i*splitDataFactor + grpIdx).write(ShiftRegister(directWrIdx, writePipeLatency),
          ShiftRegister(io.tensor.wr(grpIdx).bits.data.asUInt, writePipeLatency))
      }
    }
  }
  //--------------------
  //--- Read memory ---
  //--------------------
  // The first pulse doesn't carry a whole data-width of tensors; it is
  // bounded by DRAM data alignment. Legend:
  // ! - data pulse boundary
  // . - tensor boundary
  // tz - not used
  // TZ - tensor to store
  // =TZ= - first pulse tensor
  // DRAM !-tz-.-tz-.=TZ=!-TZ-.-TZ-.-TZ-!
  //
  // SRAM !-tz-.=TZ=.-TZ-!-TZ-.-TZ-.-tz-!
  val isFirstPulse = io.vme_wr.data.fire() && xcnt === 0.U
  // A command must always stream at least one beat.
  assert(state =/= sWriteData || readLen > 0.U)
  // Number of tensors carried by the (possibly partial) first pulse.
  val firstPulseTenzorsNb = tp.clSizeRatio.U - fstPulseDataStart
  val isLastPulse = io.vme_wr.data.fire() && xcnt === readLen - 1.U
  // Current scratchpad row address and the tensor offset within that row,
  // each with the same Wire/Reg forwarding pattern as the cmd parameters.
  val spReadAddrReg = Reg(UInt(M_SRAM_OFFSET_BITS.W))
  val spReadAddr = Wire(spReadAddrReg.cloneType)
  val srcElemOffsetReg = Reg(UInt(log2Ceil(tp.clSizeRatio).W))
  val srcElemOffset = Wire(srcElemOffsetReg.cloneType)
  // NOTE(review): row-increment correction for the second beat — when the
  // SRAM offset within the row is >= the first-pulse DRAM start, the first
  // beat consumes tensors from the next row already; confirm against
  // GenVMECmdWide's fstPulseDataStart semantics.
  val incrFstIdx = Mux((spElemIdx % tp.clSizeRatio.U) < fstPulseDataStart.asTypeOf(UInt(width = 8.W)), 0.U , 1.U)
  spReadAddr := DontCare
  when(state === sWriteCmd) {
    // init by data block index
    spReadAddr := spElemIdx >> log2Ceil(tp.clSizeRatio)
    spReadAddrReg := spReadAddr + incrFstIdx
    srcElemOffset := spElemIdx % tp.clSizeRatio.U
    srcElemOffsetReg := (spElemIdx + firstPulseTenzorsNb) % tp.clSizeRatio.U
  }.elsewhen(io.vme_wr.data.fire()) {
    // Advance one scratchpad row per accepted data beat.
    spReadAddrReg := spReadAddrReg + 1.U
    spReadAddr := spReadAddrReg
    srcElemOffset := (spElemIdx + firstPulseTenzorsNb) % tp.clSizeRatio.U
    srcElemOffsetReg := srcElemOffset
  }.otherwise {
    spReadAddrReg := spReadAddrReg
    srcElemOffsetReg := srcElemOffsetReg
    srcElemOffset := srcElemOffsetReg
  }
  val dstData = Wire(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
  val srcData = Wire(Vec(tp.clSizeRatio, UInt(tp.tensorSizeBits.W)))
  val srcMemIdx = Wire(Vec(tp.clSizeRatio, spReadAddr.cloneType))
  val dstOffset = Wire(Vec(tp.clSizeRatio, UInt((log2Ceil(tp.clSizeRatio) + 1).W)))
  val dstIdx = Wire(Vec(tp.clSizeRatio, UInt(log2Ceil(tp.clSizeRatio).W)))
  // Alignment crossbar: D(j+d) = S(j+s); substituting i=j+d --> D(i) = S(i-d+s)
  for (i <- 0 until tp.clSizeRatio) {
    // if src offset overflows the row, increment that bank's row address
    // so it reads from the next memory row
    val incrIdx = if (tp.clSizeRatio == 1 ) {
      0.U
    } else {
      Mux(i.U >= srcElemOffset, 0.U, 1.U)
    }
    srcMemIdx(i) := spReadAddr + incrIdx
    // read memory — SyncReadMem: the read is enabled the cycle before the
    // data beat, so srcData lines up with the beat presented on vme_wr.data
    srcData(i) := VecInit(for (grpIdx <- 0 until splitDataFactor) yield {
      tensorFile(i*splitDataFactor + grpIdx).read(
        srcMemIdx(i),
        state === sWriteCmd | (state === sWriteData && io.vme_wr.data.fire()))
    }).asTypeOf(UInt(tp.tensorSizeBits.W))
    // crossbar src to dst: rotate source lanes by (sram offset - dram start)
    dstOffset(i) := i.U + spElemIdx % tp.clSizeRatio.U
    dstIdx(i) := dstOffset(i) -% fstPulseDataStart
    dstData(i) := Mux1H(UIntToOH(dstIdx(i)), srcData)
  }
  // build valid bytes strb
  // Byte enables per tensor slot: on the first pulse, slots before
  // fstPulseDataStart are masked; on the last pulse, slots at/after
  // lstPulseDataEnd are masked (each only when the bound is nonzero);
  // full middle pulses enable every byte.
  val tensorSizeBytes = tp.tensorSizeBits/8
  val validBytes = Wire(Vec(tp.clSizeRatio, UInt(tensorSizeBytes.W)))
  val tensorBytesOnes = (BigInt(1) << tensorSizeBytes) - 1
  when(isFirstPulse && !isLastPulse) {
    for (i <- 0 until tp.clSizeRatio) {
      validBytes(i) := Mux(i.U < fstPulseDataStart && fstPulseDataStart =/= 0.U, 0.U, tensorBytesOnes.U)
    }
  }.elsewhen (!isFirstPulse && isLastPulse) {
    for (i <- 0 until tp.clSizeRatio) {
      validBytes(i) := Mux(i.U >= lstPulseDataEnd && lstPulseDataEnd =/= 0.U, 0.U, tensorBytesOnes.U)
    }
  }.elsewhen (isFirstPulse && isLastPulse) {
    // Single-beat command: apply both bounds.
    for (i <- 0 until tp.clSizeRatio) {
      validBytes(i) := Mux((i.U < fstPulseDataStart && fstPulseDataStart =/= 0.U)
        || (i.U >= lstPulseDataEnd && lstPulseDataEnd =/= 0.U), 0.U, tensorBytesOnes.U)
    }
  }.otherwise {
    for (i <- 0 until tp.clSizeRatio) {
      validBytes(i) := tensorBytesOnes.U
    }
  }
  // Data is valid for the whole sWriteData phase; beats advance on fire().
  io.vme_wr.data.valid := state === sWriteData
  io.vme_wr.data.bits.data := dstData.asUInt
  io.vme_wr.data.bits.strb := validBytes.asUInt
  // disable external read-from-sram requests
  io.tensor.tieoffRead()
  // done: final ack of the last generated command
  io.done := state === sWriteAck & commandsDone & io.vme_wr.ack
  // debug
  if (debug) {
    when(io.vme_wr.data.fire()) {
      printf("[TensorStore] data:%x\n", io.vme_wr.data.bits.data)
      printf("[TensorStore] strb:%x\n", io.vme_wr.data.bits.strb)
    }
    when(io.vme_wr.ack) {
      printf("[TensorStore] ack\n")
    }
  }
}