blob: f2d295f66220dae255609918758aeffa0d45fa90 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package vta.core
import chisel3._
import chisel3.util._
import chisel3.experimental._
import vta.util.config._
import scala.math.pow
/** Pipelined multiply and accumulate.
 * Computes `y = a * b + c` on the values presented one cycle earlier: all three
 * inputs are registered, and the multiply and add are combinational on the
 * registered values. The add uses `+&`, which widens the result by one bit, so
 * the sum cannot overflow `outBits`.
 *
 * @param aBits width of the signed multiplicand `a`
 * @param bBits width of the signed multiplier `b`
 * @param cBits width of the signed addend `c`
 */
class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
  val outBits = Math.max(aBits + bBits, cBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val c = Input(SInt(cBits.W))
    val y = Output(SInt(outBits.W))
  })
  val mult = Wire(SInt((aBits + bBits).W))
  val add = Wire(SInt(outBits.W))
  // Input registers: the single pipeline stage of this unit.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  val rC = RegNext(io.c)
  // A signed product of aBits- and bBits-wide operands is exactly aBits + bBits wide.
  mult := rA * rB
  // Expanding add: max(aBits + bBits, cBits) + 1 bits, i.e. outBits.
  add := rC +& mult
  io.y := add
}
/** PipeAdder
 * This unit loads its input bits into registers and performs the addition in the
 * next cycle, so `y` carries the sum of the operands presented one cycle earlier.
 * The add uses `+&`, which widens the result by one bit, so the sum cannot overflow.
 *
 * @param aBits width of the signed input `a`
 * @param bBits width of the signed input `b`
 */
class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
  val outBits = Math.max(aBits, bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })
  val add = Wire(SInt(outBits.W))
  // Input registers: the single pipeline stage of this unit.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  add := rA +& rB
  io.y := add
}
/** Adder
 * This unit wires its input bits to an adder directly: the output comes out of
 * combinational logic in the same cycle, without waiting for another cycle
 * (the combinational counterpart of PipeAdder). The add uses `+&`, which widens
 * the result by one bit, so the sum cannot overflow.
 *
 * @param aBits width of the signed input `a`
 * @param bBits width of the signed input `b`
 */
class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
  val outBits = Math.max(aBits, bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })
  val add = Wire(SInt(outBits.W))
  // Plain pass-through wires, not registers despite the `r` prefix; they mirror
  // PipeAdder's structure without adding a pipeline stage.
  val rA = Wire(SInt(aBits.W))
  val rB = Wire(SInt(bBits.W))
  rA := io.a
  rB := io.b
  add := rA +& rB
  io.y := add
}
/** Pipelined DotProduct based on MAC and PipeAdder.
 * One MAC per element pair forms the products (one register stage). A binary
 * reduction tree then sums them: its first layer is made of PipeAdders (a second
 * register stage), and the remaining layers are combinational Adders. Every adder
 * widens its operands by one bit, so no intermediate sum overflows.
 *
 * @param aBits width of each signed element of `io.a`
 * @param bBits width of each signed element of `io.b`
 * @param size  vector length; must be a power of 2 and at least 2
 */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
  val errorMsg =
    s"\n\n[VTA] [DotProduct] size must be at least 2 and a power of 2\n\n"
  require(size >= 2 && isPow2(size), errorMsg)
  // Width of a single product.
  val b = aBits + bBits
  // Product width + 1 bit from the MAC + 1 bit per adder layer (log2Ceil(size) layers).
  val outBits = b + log2Ceil(size) + 1
  val io = IO(new Bundle {
    val a = Input(Vec(size, SInt(aBits.W)))
    val b = Input(Vec(size, SInt(bBits.W)))
    val y = Output(SInt(outBits.W))
  })
  // s(i) = number of values at tree level i: s(0) = size products, ..., s(p) = 1 sum.
  val s = Seq.tabulate(log2Ceil(size) + 1)(i => size >> i)
  // Number of adder layers in the reduction tree.
  val p = log2Ceil(size)
  // One MAC per vector element pair; the addend input is unused (tied to zero below).
  val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1)))
  // a(i) holds the s(i + 1) adders of layer i; layer i adds (b + i + 1)-bit operands.
  val a = Seq.tabulate(p)(
    i =>
      Seq.fill(s(i + 1))(
        if (i == 0)
          Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))
        else
          Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1)))))
  // Vector MACs
  for (i <- 0 until s(0)) {
    m(i).io.a := io.a(i)
    m(i).io.b := io.b(i)
    m(i).io.c := 0.S
  }
  // Reduction tree: layer 0 consumes MAC outputs, layer i > 0 consumes layer i - 1.
  for (i <- 0 until p) {
    for (j <- 0 until s(i + 1)) {
      if (i == 0) {
        // First layer of PipeAdders
        a(i)(j).io.a := m(2 * j).io.y
        a(i)(j).io.b := m(2 * j + 1).io.y
      } else {
        a(i)(j).io.a := a(i - 1)(2 * j).io.y
        a(i)(j).io.b := a(i - 1)(2 * j + 1).io.y
      }
    }
  }
  // The single adder of the last layer produces the dot product.
  io.y := a(p - 1)(0).io.y
}
/** Perform matrix-vector-multiplication based on DotProduct.
 *
 * One DotProduct unit is instantiated per output lane (`size` = p(CoreKey).blockOut
 * lanes), each over `size`-element vectors. Per lane, a 2-stage Pipe delays the
 * incoming accumulator value by the same two register stages that DotProduct has
 * (the MAC stage plus its first, pipelined adder layer), and the delayed value is
 * summed with the dot product into add(i).
 *
 * NOTE(review): in this copy every closing brace and several operands are missing.
 * Affected: the operands of the `enq.valid` conjunction, the right-hand sides of
 * `enq.bits` and `dot(i).io.a(j)` / `io.b(j)`, and the targets of the output
 * assignments merged into the `add(i)` line and the final line. The class does
 * not parse as written; TODO restore from upstream.
 */
class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
// Widths and lane count come from the CoreKey configuration.
val accBits = p(CoreKey).accBits
val size = p(CoreKey).blockOut
val inpBits = p(CoreKey).inpBits
val wgtBits = p(CoreKey).wgtBits
val outBits = p(CoreKey).outBits
val io = IO(new Bundle {
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp")
val wgt = new TensorMasterData(tensorType = "wgt")
val acc_i = new TensorMasterData(tensorType = "acc")
val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
// One DotProduct per output lane.
val dot = Seq.fill(size)(
Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
// Latency is defined as two in the following, because there is one cycle in the MAC module,
// and another cycle in the pipelined adders as the first layer of the accumulator
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2)))
// add(i): delayed accumulator + dot product for lane i. `+` does not widen, so
// overflow beyond accBits wraps.
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
// vld(i): lane i's delayed accumulator value is valid (deq side of its Pipe).
val vld = Wire(Vec(size, Bool()))
for (i <- 0 until size) {
// Enqueue into the delay Pipe only when io.reset is low.
// NOTE(review): the three operands before `& ~io.reset` are missing in this copy.
acc(i).io.enq.valid := & & & ~io.reset
// NOTE(review): right-hand side missing in this copy.
acc(i).io.enq.bits :=
for (j <- 0 until size) {
// NOTE(review): right-hand sides of the two vector-element assignments are missing.
dot(i).io.a(j) :=
dot(i).io.b(j) :=
// Sum of the delayed accumulator and the dot product. NOTE(review): three statements are
// merged on this line; the targets driven with Mux(io.reset, 0.U, ...) (zero while
// io.reset is set) and with the raw sum are missing in this copy.
add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y := Mux(io.reset, 0.U, add(i).asUInt) := add(i).asUInt
vld(i) := acc(i).io.deq.valid
// Output valids: all lanes valid (one variant also forced true by io.reset).
// NOTE(review): the assignment targets on this merged line are missing in this copy.
} := vld.asUInt.andR | io.reset := vld.asUInt.andR
/** TensorGemm.
 * This unit instantiates the MatrixVectorMultiplication and goes over the
 * micro-ops (uops) which are used to read inputs, weights and biases,
 * and writes results back to the acc and out scratchpads.
 * Also, the TensorGemm uses the reset field in the Gemm instruction to
 * clear or zero-out the acc-scratchpad locations based on the micro-ops.
 *
 * Control flow:
 * - The FSM steps sIdle -> sReadUop -> sComputeIdx -> sReadTensor -> sExe.
 * - It repeats that sequence for every (outer iteration, inner iteration, uop).
 * - After the last one it goes to sWait if results are still in flight, otherwise
 *   straight back to sIdle.
 * - sWait returns to sIdle once the in-flight count reaches zero.
 *
 * NOTE(review): this copy is missing every closing brace plus several operands and
 * lines, so it does not parse as written. The gaps are:
 * - the state `Enum` on the state declaration;
 * - the `when(` that opens the sExe condition and the uop_idx condition;
 * - the `}.elsewhen(` of the outer-counter update;
 * - the `inflight` commit condition;
 * - the uop fields read in sComputeIdx;
 * - the mvc and write-port data/valid wiring.
 * TODO restore from upstream.
 *
 * @param debug when true, emits printf traces of uops and tensor data
 */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val uop = new UopMaster
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
// Controller states. NOTE(review): the right-hand side of this declaration is missing in
// this copy (presumably a 6-way Enum) — TODO confirm.
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorMultiplication)
val dec = io.inst.asTypeOf(new GemmDecode)
// uop_idx walks from dec.uop_begin up to uop_end - 1; uop_acc/uop_inp/uop_wgt hold the
// scratchpad read indices for the current uop.
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
val uop_acc = Reg(chiselTypeOf(dec.uop_end))
val uop_inp = Reg(chiselTypeOf(dec.uop_end))
val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
// Outer-loop (lp_0) iteration counter and accumulated per-tensor offsets.
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
val acc_o = Reg(chiselTypeOf(dec.uop_end))
val inp_o = Reg(chiselTypeOf(dec.uop_end))
val wgt_o = Reg(chiselTypeOf(dec.uop_end))
// Inner-loop (lp_1) iteration counter and accumulated per-tensor offsets.
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
val acc_i = Reg(chiselTypeOf(dec.uop_end))
val inp_i = Reg(chiselTypeOf(dec.uop_end))
val wgt_i = Reg(chiselTypeOf(dec.uop_end))
// inflight: number of tensors issued in sReadTensor that have not yet committed.
val pBits = log2Ceil(p(CoreKey).blockOut) + 1
val inflight = Reg(UInt(pBits.W))
// Latency is defined as two in the following, because there is one cycle in the MAC module,
// and another cycle in the pipelined adders as the first layer of the accumulator
// NOTE(review): its enq/deq wiring is garbled below; it presumably delays uop_acc to the acc
// write port — TODO confirm.
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
// done: nothing in flight, and either the last uop of the last iteration is executing or
// the unit is draining in sWait. NOTE(review): `inflight === 0.U` is also tested inside
// the sExe term, which is redundant given the outer check.
val done = inflight === 0.U &
((state === sExe &
cnt_o === dec.lp_0 - 1.U &
cnt_i === dec.lp_1 - 1.U &
uop_idx === uop_end - 1.U &
inflight === 0.U) |
state === sWait)
switch(state) {
is(sIdle) {
when(io.start) {
state := sReadUop
is(sReadUop) {
state := sComputeIdx
is(sComputeIdx) {
state := sReadTensor
is(sReadTensor) {
state := sExe
is(sExe) {
// Last uop of the last inner and outer iteration? NOTE(review): the opening `when(` of
// this condition is missing in this copy.
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
when(inflight =/= 0.U) {
state := sWait
}.otherwise {
state := sIdle
}.otherwise {
state := sReadUop
is(sWait) {
when(inflight === 0.U) {
state := sIdle
// In-flight tracking: cleared in sIdle; not updated while dec.reset is set.
// NOTE(review): the commit condition (second operand of the first `when`, and the
// `elsewhen` argument) is missing in this copy.
when(state === sIdle) {
inflight := 0.U
}.elsewhen(!dec.reset) {
when((state === sReadTensor) && { // issue & commit
inflight := inflight
}.elsewhen(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U
}.elsewhen( { // commit a tensor
inflight := inflight - 1.U
// uop_idx restarts at dec.uop_begin in sIdle or after the last uop, and otherwise
// advances by one per sExe (unless the uop range is empty).
// NOTE(review): the opening `when(` of this condition is missing in this copy.
state === sIdle ||
(state === sExe &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
}.elsewhen(state === sExe && dec.uop_begin =/= uop_end) {
uop_idx := uop_idx + 1.U
// Outer-loop counter/offsets: cleared in sIdle; advanced by dec.*_0 after the last uop of
// the last inner iteration. NOTE(review): the `}.elsewhen(` before the advance condition
// is missing in this copy.
when(state === sIdle) {
cnt_o := 0.U
acc_o := 0.U
inp_o := 0.U
wgt_o := 0.U
state === sExe &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
acc_o := acc_o + dec.acc_0
inp_o := inp_o + dec.inp_0
wgt_o := wgt_o + dec.wgt_0
// Inner-loop counter/offsets: cleared in sIdle; reloaded from the outer offsets in
// sReadUop once cnt_i reaches dec.lp_1; advanced by dec.*_1 after the last uop.
when(state === sIdle) {
cnt_i := 0.U
acc_i := 0.U
inp_i := 0.U
wgt_i := 0.U
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
acc_i := acc_o
inp_i := inp_o
wgt_i := wgt_o
}.elsewhen(state === sExe && uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1
inp_i := inp_i + dec.inp_1
wgt_i := wgt_i + dec.wgt_1
// In sComputeIdx, each read index = (value read from the uop) + inner-loop offset.
// NOTE(review): the second operand of the condition and the left operands of the three
// additions (presumably the fetched uop fields) are missing in this copy.
when(state === sComputeIdx && {
uop_acc := + acc_i
uop_inp := + inp_i
uop_wgt := + wgt_i
// NOTE(review): merged line; the targets (presumably wrpipe.io.enq valid/bits) are missing.
} := state === sExe & ~dec.reset := uop_acc
// uop
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
// inp
io.inp.rd.idx.valid := state === sReadTensor
io.inp.rd.idx.bits := uop_inp
io.inp.tieoffWrite() // read-only
// wgt
io.wgt.rd.idx.valid := state === sReadTensor
io.wgt.rd.idx.bits := uop_wgt
io.wgt.tieoffWrite() // read-only
// acc_i
io.acc.rd.idx.valid := state === sReadTensor
io.acc.rd.idx.bits := uop_acc
// NOTE(review): the mvc wiring appears to have been merged into the comment line below.
// As a result nothing in this copy drives mvc's ports. TODO restore.
// mvc := dec.reset & state === sExe <> <> <>
// acc_o
// NOTE(review): operands of the write-valid/idx expressions below are missing in this copy.
io.acc.wr.valid := &
Mux(dec.reset, true.B,
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, <>
// out
// NOTE(review): operands of the write-valid/idx expressions below are missing in this copy.
io.out.wr.valid := &
io.out.wr.bits.idx := <>
io.out.tieoffRead() // write-only
io.done := done
// Debug traces; all are suppressed while dec.reset is set.
// NOTE(review): the iterated collections and the valid conditions in the `when(`s are
// missing in this copy.
if (debug) {
when(state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
when(state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
} {
case (r, i) =>
when( && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
} {
case (r, i) =>
when( && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
} { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when( && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
} { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when( && ~dec.reset) {
printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
} { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when( && ~dec.reset) {
printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)