[VTA][Chisel] TSIM VTA Source Refactor (#4163)
* app init push
* fix on readme
* change name, add bit serial explanation
* rm serialLoadMM, change doc
* syntax change for readme
* add parallel test functionality
* fix readme
* add python doc
* syntax
* init commit
* fix empty line
* fix typo
diff --git a/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala b/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
index 325fce1..6bfe3e0 100644
--- a/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
+++ b/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
@@ -22,21 +22,31 @@
import chisel3._
import chisel3.util._
import vta.dpi._
+import vta.core._
+import vta.util.config._
+import vta.shell._
+class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
- * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
+ * 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value
* 4. Increment read-address for next value
- * 5. Wait for sliced accumulator
- * 6. Check if counter (cnt) is equal to length process,
- otherwise goto step 2
- * 7. Check if reset slice accumulator
- * 8. Wait for overall accumulator
- * 8. Issue a write request for 8-byte value at out_baddr address
+ * 5. Repeat until all inp1 data have been read
+
+ * 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
+ * 7. Wait for the value
+ * 8. Increment read-address for next value
+ * 9. Repeat until all inp2 data have been read
+
+ * 10. Wait for output to be calculated
+ * 11. Issue a write request for 8-byte value at out_baddr address
+ * 12. Increment write-address for next value to write
+ * 13. Check if counter (cntout) is equal to length to assert finish,
+ otherwise go to step 11
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
@@ -47,19 +57,24 @@
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
- val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
+ implicit val p: Parameters = new TestConfig
+ val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
- val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
- val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
- val cnt = Reg(UInt(config.regBits.W))
+ val mvc = Module(new MatrixVectorMultiplication)
+ val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
+ val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
+ val cntwgt = Reg(UInt(config.regBits.W))
+ val cntinp = Reg(UInt(config.regBits.W))
+ val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
+ val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
switch (state) {
is (sIdle) {
@@ -73,7 +88,14 @@
}
is (sReadAData) {
when (io.mem.rd.valid) {
+ state := sReadADone
+ }
+ }
+ is (sReadADone) {
+ when (cntwgt === (length * length) - 1.U) {
state := sReadBReq
+ } .otherwise {
+ state := sReadAReq
}
}
is (sReadBReq) {
@@ -81,6 +103,23 @@
}
is (sReadBData) {
when (io.mem.rd.valid) {
+ state := sReadBDone
+ }
+ }
+ is (sReadBDone) {
+ when (cntinp === length-1.U) {
+ state := sInpDone
+ } .otherwise {
+ state := sReadBReq
+ }
+ }
+  // Both inputs are processed
+ is (sInpDone) {
+ state := sWait
+ }
+ // Wait for computation
+ is (sWait) {
+ when (accum.io.ready) {
state := sWriteReq
}
}
@@ -89,15 +128,18 @@
state := sWriteData
}
is (sWriteData) {
- when (cnt === (length - 1.U)) {
+ state := sWriteDone
+ }
+ is (sWriteDone) {
+ when (cntout === (length - 1.U)) {
state := sIdle
} .otherwise {
- state := sReadAReq
+ state := sWriteReq
}
}
}
- val last = state === sWriteData && cnt === (length - 1.U)
+ val last = state === sWriteDone && cntout === (length - 1.U)
// cycle counter
when (state === sIdle) {
@@ -114,10 +156,12 @@
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
- } .elsewhen (state === sWriteData) { // increment input array by 1-byte
+ } .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
+ } .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U
- waddr := waddr
+ } .elsewhen (state === sWriteDone) {
+ waddr := waddr + 4.U // writing 4 bytes
}
// create request
@@ -128,59 +172,70 @@
// read
when (state === sReadAData && io.mem.rd.valid) {
- reg1 := io.mem.rd.bits(7, 0)
+ reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
}
when (state === sReadBData && io.mem.rd.valid) {
- reg2 := io.mem.rd.bits(7, 0)
+ reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
}
io.mem.rd.ready := state === sReadAData | state === sReadBData
+ mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
+ mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
-
- val sliceAccum = Module(new Accumulator(63))
- val overallAccum = Module(new Accumulator(64))
+ mvc.io.wgt.data.bits <> reg1
+ mvc.io.inp.data.bits <> reg2
+ // Modify when shift operation is supported
+ mvc.io.reset := false.B
+ mvc.io.acc_i.data.valid := true.B
+ for (i <- 0 until p(CoreKey).blockOut) {
+ mvc.io.acc_i.data.bits(0)(i) := 0.U
+ }
- sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
- sliceAccum.io.in := reg1 * reg2
- sliceAccum.io.clear := startDot
- overallAccum.io.clear := rstAccum
- overallAccum.io.valid := last // last element has been processed
- overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
+ accum.io.in := mvc.io.acc_o.data.bits
+ accum.io.shift := shift
+ accum.io.clear := rstAccum
+ accum.io.valid := mvc.io.acc_o.data.valid
// write
- io.mem.wr.valid := overallAccum.io.ready
- io.mem.wr.bits := overallAccum.io.sum
-
+ io.mem.wr.valid := state === sWriteData
+ io.mem.wr.bits := accum.io.sum(cntout)
// count read/write
when (state === sIdle) {
- cnt := 0.U
- } .elsewhen (state === sWriteData) {
- cnt := cnt + 1.U
+ cntwgt := 0.U
+ cntinp := 0.U
+ cntout := 0.U
+ } .elsewhen (state === sReadADone) {
+ cntwgt := cntwgt + 1.U
+ } .elsewhen (state === sReadBDone) {
+ cntinp := cntinp + 1.U
+ } .elsewhen (state === sWriteDone) {
+ cntout := cntout + 1.U
}
- io.finish := overallAccum.io.ready // data has been added
+ io.finish := last // data has been added
}
-
-
-class Accumulator(dataBits: Int = 8) extends Module {
+// Shift operation until supported in MVM
+class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
- val in = Input(UInt(dataBits.W))
- val sum = Output(UInt((dataBits).W))
+ val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
+ val shift = Input(UInt(8.W))
+ val sum = Output(Vec(size, (UInt(accBits.W))))
})
+ val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
- val reg = RegInit(0.U((dataBits).W))
- val ready = RegNext(io.valid)
- when (io.clear) {
- reg := 0.U
- } .elsewhen (io.valid) {
- reg := reg + io.in
- }
- io.ready := ready
- io.sum := reg
+ for (i <- 0 until size) {
+ when (io.clear) {
+ reg(i) := 0.U
+ } .elsewhen(io.valid) {
+ reg(i) := reg(i) + (io.in(0)(i) << io.shift)
+ }
+ }
+ io.ready := RegNext(io.valid)
+ io.sum := reg
}
diff --git a/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala b/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
index 6f0bdbb..10c40b5 100644
--- a/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
+++ b/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
@@ -35,13 +35,9 @@
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
- * Reset Dot Module | 0x14
- * Input1 pointer lsb | 0x18
- * Input1 pointer msb | 0x1c
- * Input2 pointer lsb | 0x20
- * Input2 pointer msb | 0x24
- * Output pointer lsb | 0x28
- * Output pointer msb | 0x2c
+ * Input1 pointer | 0x18
+ * Input2 pointer | 0x20
+ * Output pointer | 0x28
* -------------------------------
* ------------------------------
diff --git a/apps/gemm/src/driver.cc b/apps/gemm/src/driver.cc
index 8d380c3..24b998e 100644
--- a/apps/gemm/src/driver.cc
+++ b/apps/gemm/src/driver.cc
@@ -66,10 +66,12 @@
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
- uint32_t length = inp1->shape[0];
- size_t size1 = (inp1->dtype.bits >> 3) * length;
+ uint32_t length = inp2->shape[0];
+ // 1 matrix 1 vector input
+ size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
- size_t size3 = (64 >> 3);
+ // 1 vector output
+ size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
@@ -115,19 +117,17 @@
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
- dpi_->WriteReg(0x0c, length); // vector length
+ dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
- dpi_->WriteReg(0x00, 0x0); // launch
+ dpi_->WriteReg(0x00, 0x0);
if (reset == 1) {
- dpi_->WriteReg(0x10, 0x1); // reset accum
- dpi_->WriteReg(0x10, 0x0); // stop reset accum
+ dpi_->WriteReg(0x10, 0x1); // reset accumulator
+ dpi_->WriteReg(0x10, 0x0);
}
- dpi_->WriteReg(0x14, 0x1); // reset dot
- dpi_->WriteReg(0x14, 0x0); // stop reset dot
}
uint32_t WaitForCompletion() {
diff --git a/apps/gemm/tests/python/chisel_accel.py b/apps/gemm/tests/python/chisel_accel.py
index 4aed563..4666661 100644
--- a/apps/gemm/tests/python/chisel_accel.py
+++ b/apps/gemm/tests/python/chisel_accel.py
@@ -26,7 +26,7 @@
A : Vector to be sliced and packed
slice_width : slice width
-Returnsi
+Returns
---------
C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A
"""
@@ -39,7 +39,7 @@
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
- else: raise ValueError("datatype " + str(dtype) + "currently not supported")
+ else: raise ValueError("datatype currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
@@ -55,64 +55,88 @@
C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
return C
+def slice_mat(A, slice_width):
+ assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
+ dtype = type(A[0][0])
+ row = 0
+ # currently only supports uint
+ if dtype is np.uint8: row = 8 // slice_width
+ elif dtype is np.uint16: row = 16 // slice_width
+ elif dtype is np.uint32: row = 32 // slice_width
+ elif dtype is np.uint64: row = 64 // slice_width
+ else: raise ValueError("datatype currently not supported")
+ if (row >= 8):
+ dtype = 'uint' + str(row)
+ else:
+ dtype = 'uint8'
+
+ # 3d array (bits, row, clmn)
+ C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype) # sliced and transform
+
+ # create mask
+ slice_mask = 2**(slice_width)-1
+ # slice and pack
+ for z in range(A.shape[0]):
+ C[:, z, :] = slice(A[z], slice_width)
+ return C
+
""" Matrix Multiplication Function
Parameters
----------
A : Matrix A
B: Matrix B
-w_width : weight slice width
-a_width : activation slice width
+i_width : input slice width
+w_width : weight slice width
Returns
---------
C: result of A * B
"""
# A is a n*m matrix, B is a m*p matrix(not transposed yet)
-def matrix_multiply(A, B, w_width, a_width):
+def matrix_multiply(A, B, i_width, w_width):
assert A.shape[1] == B.shape[0], "can't perform multiplication"
BT = B.transpose()
cycles = 0
+ B_sliced = slice_mat(BT, w_width)
C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
for i in range(A.shape[0]):
- for j in range(B.shape[1]):
- # C[i, j] = A[i].dot(BT[j])
- A_sliced = slice(A[i], w_width)
- B_sliced = slice(BT[j], a_width)
+ A_sliced = slice(A[i], i_width)
+ test = test_accel(A_sliced, B_sliced, i_width, w_width)
+ C[i] = test[0]
+ cycles += test[1]
+ np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width))
+ print("PASS row " + str(i))
- C[i, j] = compute(A_sliced, B_sliced, w_width, a_width)
- test = test_accel(A_sliced, B_sliced, w_width, a_width)
- cycles += test[1]
- np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
- print("PASS SW serial & parallel")
-
- np.testing.assert_equal(test[0], C[i, j])
- print("PASS SW & HW bit serial")
-
- np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
- print("PASS SW bit parallel & HW bit parallel")
-
+ np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B))
print("result: ")
print(C)
- print("ALL TESTS PASSED, cycles: " + str(cycles))
+ print("TEST PASSED, cycles: " + str(cycles))
return C
-""" Software Verification Function"""
-# takes 2 matrix input (sliced and packed)
-def compute(A, B, w_width, a_width):
+""" Software Verification Function
+Parameter Dimensions
+---------
+A (bits, y) and B (bits, y, x) (transposed)
+
+Takes 1 vector and 1 matrix input (sliced and packed)
+
+Returns
+---------
+Resulting vector
+"""
+def compute(A, B, i_width, w_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
# reset hardware accumulator
- accum = 0
+ accum = np.zeros(A.shape[1])
for x in range(A.shape[0]):
for y in range(B.shape[0]):
- # hardware implementation
- accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
+ accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width)
# get value from accumulator
return accum
-"""Testing Function for Dot Product"""
-def test_accel(A, B, w_width, a_width):
- assert A.shape[1] == B.shape[1], "sliced shape not match"
-
+"""Testing Function for Matrix Vector Multiplication"""
+def test_accel(A, B, i_width, w_width):
+ assert A.shape[1] == B.shape[2], "sliced shape not match"
dtype = A.dtype
ctx = tvm.cpu(0)
f = tsim.load_module()
@@ -126,57 +150,54 @@
a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
for i in range(B.shape[0]):
- list_b = np.zeros(B.shape[1]).astype(dtype)
- for j in range(B.shape[1]):
- list_b[j] = B[i][j]
+ # transpose
+ list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype)
+ for j in range(B.shape[2]):
+ for k in range(B.shape[1]):
+ list_b[j][k] = B[i][j][k]
b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
cycles = 0
-
- accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
+ accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx)
for i in range(len(a_arr)):
for j in range(len(b_arr)):
- shift = np.uint8(i*w_width + j*a_width)
+ shift = np.uint8(i*i_width + j*w_width)
if i == 0 and j == 0:
- cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator
+ cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator
else:
- cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset
+ cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset
- return (accum.asnumpy()[0], cycles)
+ return (accum.asnumpy(), cycles)
""" Matrix Generator
Parameters
----------
dtype : String, datatype generated (supports only uint)
-w_width : weight bit slices(needs to be less than actual bit width)
-a_width : activation bit slices(needs to be less than actual bit width)
+i_width : input bit slices(needs to be less than actual bit width)
+w_width : weight bit slices(needs to be less than actual bit width)
"""
-def top_test(dtype, w_width, a_width):
+def top_test(dtype, i_width, w_width):
- rmax = np.random.randint(256)
- # random matrix generation (dimension up to 8)
- rrow = np.random.randint(7) + 1
- rclmn = np.random.randint(7) + 1
- rrow2 = np.random.randint(7) + 1
- A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
- B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
+ # only supports positive values (up to 2**(bits-1))
+ rmax = 127
+ # (m,16) * (16,16) GEMM
+ rrow = np.random.randint(7) + 1
+ clmn = 16
+ A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype)
+ B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype)
- print("A: ")
- print(A)
- print("\n")
- print("B: ")
- print(B)
- print("\n")
- matrix_multiply(A, B, w_width, a_width)
-
+ print("A: " + str(A))
+ print("B: " + str(B))
+ # perform GEMM
+ matrix_multiply(A, B, i_width, w_width)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(1):
- # reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits
+ # reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits
if sys.argv[1] == 'serial':
- # generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation
- top_test("uint8",4, 2)
+ # generates a random uint8 GEMM with 2-bit(8/4) input and 4-bit(8/2) weight
+ top_test("uint8", 4, 2)
elif sys.argv[1] == 'parallel':
- # generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel)
- top_test('uint8', 1, 1)
+ # generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel)
+ top_test('uint8', 8, 8)