[VTA][Chisel] TSIM VTA Source Refactor (#4163)

* app init push

* fix on readme

* change name, add bit serial explanantion

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax

* init commit

* fix empty line

* fix typo
diff --git a/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala b/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
index 325fce1..6bfe3e0 100644
--- a/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
+++ b/apps/gemm/hardware/chisel/src/main/scala/accel/Compute.scala
@@ -22,21 +22,31 @@
 import chisel3._
 import chisel3.util._
 import vta.dpi._
+import vta.core._
+import vta.util.config._
+import vta.shell._
 
+class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
 /** Compute
   *
   * Bit Slice GEMM:
   *
   * 1. Wait for launch to be asserted
-  * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
+  * 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
   * 3. Wait for the value
   * 4. Increment read-address for next value
-  * 5. Wait for sliced accumulator
-  * 6. Check if counter (cnt) is equal to length process,
-       otherwise goto step 2
-  * 7. Check if reset slice accumulator
-  * 8. Wait for overall accumulator
-  * 8. Issue a write request for 8-byte value at out_baddr address
+  * 5. Repeat until all inp1 data have been read
+
+  * 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
+  * 7. Wait for the value
+  * 8. Increment read-address for next value
+  * 9. Repeat until all inp2 data have been read
+
+  * 10. Wait for output to be calculated
+  * 11. Issue a write request for 8-byte value at out_baddr address
+  * 12. Increment write-address for next value to write
+  * 13. Check if counter (cntout) is equal to length to asser finish,
+       otherwise go to step 11
   */
 class Compute(implicit config: AccelConfig) extends Module {
   val io = IO(new Bundle {
@@ -47,19 +57,24 @@
     val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
     val mem = new VTAMemDPIMaster
   })
-  val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
+  implicit val p: Parameters = new TestConfig
+  val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
   val state = RegInit(sIdle)
   val shift = io.vals(0)
   val length = io.vals(1)
   val rstAccum = io.vals(2)
   val startDot = io.vals(3)
   val cycles = RegInit(0.U(config.regBits.W))
-  val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
-  val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
-  val cnt = Reg(UInt(config.regBits.W))
+  val mvc = Module(new MatrixVectorMultiplication)
+  val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
+  val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
+  val cntwgt = Reg(UInt(config.regBits.W))
+  val cntinp = Reg(UInt(config.regBits.W))
+  val cntout = Reg(UInt(config.regBits.W))
   val raddr1 = Reg(UInt(config.ptrBits.W))
   val raddr2 = Reg(UInt(config.ptrBits.W))
   val waddr = Reg(UInt(config.ptrBits.W))
+  val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
 
   switch (state) {
     is (sIdle) {
@@ -73,7 +88,14 @@
     }
     is (sReadAData) {
       when (io.mem.rd.valid) {
+        state := sReadADone
+      }   
+    }
+    is (sReadADone) {
+      when (cntwgt === (length * length) - 1.U) {
         state := sReadBReq
+      } .otherwise {
+        state := sReadAReq
       }
     }
     is (sReadBReq) {
@@ -81,6 +103,23 @@
     }
     is (sReadBData) {
       when (io.mem.rd.valid) {
+        state := sReadBDone
+      }
+    }
+    is (sReadBDone) {
+      when (cntinp === length-1.U) {
+        state := sInpDone
+      } .otherwise {
+        state := sReadBReq
+      }
+    }
+    // Both input is processed
+    is (sInpDone) {
+      state := sWait
+    }
+    // Wait for computation
+    is (sWait) {
+      when (accum.io.ready) {
         state := sWriteReq
       }
     }
@@ -89,15 +128,18 @@
       state := sWriteData
     }
     is (sWriteData) {
-      when (cnt === (length - 1.U)) {
+        state := sWriteDone
+    }
+    is (sWriteDone) {
+      when (cntout === (length - 1.U)) {
         state := sIdle
       } .otherwise {
-        state := sReadAReq
+        state := sWriteReq
       }
     }
   }
 
-  val last = state === sWriteData && cnt === (length - 1.U)
+  val last = state === sWriteDone && cntout === (length - 1.U)
 
   // cycle counter
   when (state === sIdle) {
@@ -114,10 +156,12 @@
     raddr1 := io.ptrs(0)
     raddr2 := io.ptrs(1)
     waddr := io.ptrs(2)
-  } .elsewhen (state === sWriteData) { // increment input array by 1-byte
+  } .elsewhen (state === sReadADone) { // increment input array by 1-byte
     raddr1 := raddr1 + 1.U
+  } .elsewhen (state === sReadBDone) { // increment input array by 1-byte
     raddr2 := raddr2 + 1.U
-    waddr := waddr
+  } .elsewhen (state === sWriteDone) {
+    waddr := waddr + 4.U // writing 4 bytes
   }
 
   // create request
@@ -128,59 +172,70 @@
 
   // read
   when (state === sReadAData && io.mem.rd.valid) {
-    reg1 := io.mem.rd.bits(7, 0)
+    reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
   }
 
   when (state === sReadBData && io.mem.rd.valid) {
-    reg2 := io.mem.rd.bits(7, 0)
+    reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
   }
 
   io.mem.rd.ready := state === sReadAData | state === sReadBData
+  mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed 
+  mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed 
 
-  
-  val sliceAccum = Module(new Accumulator(63))
-  val overallAccum = Module(new Accumulator(64))
+  mvc.io.wgt.data.bits <> reg1
+  mvc.io.inp.data.bits <> reg2
+  // Modify when shift operation is supported
+  mvc.io.reset := false.B
+  mvc.io.acc_i.data.valid := true.B
+  for (i <- 0 until p(CoreKey).blockOut) {
+    mvc.io.acc_i.data.bits(0)(i) := 0.U
+  }
 
-  sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed 
-  sliceAccum.io.in := reg1 * reg2
-  sliceAccum.io.clear := startDot
-  overallAccum.io.clear := rstAccum
-  overallAccum.io.valid := last // last element has been processed
-  overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits 
+  accum.io.in := mvc.io.acc_o.data.bits
+  accum.io.shift := shift
+  accum.io.clear := rstAccum
+  accum.io.valid := mvc.io.acc_o.data.valid
 
   // write
-  io.mem.wr.valid := overallAccum.io.ready 
-  io.mem.wr.bits := overallAccum.io.sum
-  
+  io.mem.wr.valid := state === sWriteData 
+  io.mem.wr.bits := accum.io.sum(cntout)
 
   // count read/write
   when (state === sIdle) {
-    cnt := 0.U
-  } .elsewhen (state === sWriteData) {
-    cnt := cnt + 1.U
+    cntwgt := 0.U
+    cntinp := 0.U
+    cntout := 0.U
+  } .elsewhen (state === sReadADone) {
+    cntwgt := cntwgt + 1.U
+  } .elsewhen (state === sReadBDone) {
+    cntinp := cntinp + 1.U
+  } .elsewhen (state === sWriteDone) {
+    cntout := cntout + 1.U
   }
 
-  io.finish := overallAccum.io.ready // data has been added
+  io.finish := last // data has been added
 }
-
-
-class Accumulator(dataBits: Int = 8) extends Module {
+// Shift operation until supported in MVM
+class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
   val io = IO(new Bundle {
     val clear = Input(Bool())
     val valid = Input(Bool())
     val ready = Output(Bool())
-    val in = Input(UInt(dataBits.W))
-    val sum = Output(UInt((dataBits).W))
+    val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
+    val shift = Input(UInt(8.W))
+    val sum = Output(Vec(size, (UInt(accBits.W))))
   })
+    val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
 
-  val reg = RegInit(0.U((dataBits).W))
-  val ready = RegNext(io.valid)
-  when (io.clear) {
-    reg := 0.U
-  } .elsewhen (io.valid) {
-    reg := reg + io.in
-  } 
-  io.ready := ready
-  io.sum := reg
+    for (i <- 0 until size) {
+      when (io.clear) {
+        reg(i) := 0.U
+      } .elsewhen(io.valid) {
+        reg(i) := reg(i) + (io.in(0)(i) << io.shift)
+      }
+    }
+    io.ready := RegNext(io.valid)
+    io.sum := reg
 }
 
diff --git a/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala b/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
index 6f0bdbb..10c40b5 100644
--- a/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
+++ b/apps/gemm/hardware/chisel/src/main/scala/accel/RegFile.scala
@@ -35,13 +35,9 @@
   *  Shift value             | 0x08
   *  Vector length           | 0x0c
   *  Reset Accumulator       | 0x10
-  *  Reset Dot Module        | 0x14
-  *  Input1 pointer lsb      | 0x18
-  *  Input1 pointer msb      | 0x1c
-  *  Input2 pointer lsb      | 0x20
-  *  Input2 pointer msb      | 0x24
-  *  Output pointer lsb      | 0x28
-  *  Output pointer msb      | 0x2c
+  *  Input1 pointer          | 0x18
+  *  Input2 pointer          | 0x20
+  *  Output pointer          | 0x28
   * -------------------------------
 
   * ------------------------------
diff --git a/apps/gemm/src/driver.cc b/apps/gemm/src/driver.cc
index 8d380c3..24b998e 100644
--- a/apps/gemm/src/driver.cc
+++ b/apps/gemm/src/driver.cc
@@ -66,10 +66,12 @@
 
   uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
     uint32_t cycles;
-    uint32_t length = inp1->shape[0];
-    size_t size1 = (inp1->dtype.bits >> 3) * length;
+    uint32_t length = inp2->shape[0];
+    // 1 matrix 1 vector input
+    size_t size1 = (inp1->dtype.bits >> 3) * length * length;
     size_t size2 = (inp2->dtype.bits >> 3) * length;
-    size_t size3 = (64 >> 3);
+    // 1 vector output
+    size_t size3 = (32 >> 3) * length;
     inp1_ = this->MemAlloc(size1);
     inp2_ = this->MemAlloc(size2);
     out_ = this->MemAlloc(size3);
@@ -115,19 +117,17 @@
 
   void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
     dpi_->WriteReg(0x08, shiftVal);
-    dpi_->WriteReg(0x0c, length); // vector length
+    dpi_->WriteReg(0x0c, length); // tensor size
     dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
     dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
     dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
     dpi_->WriteReg(0x00, 0x1); // launch
-    dpi_->WriteReg(0x00, 0x0); // launch
+    dpi_->WriteReg(0x00, 0x0); 
 
     if (reset == 1) {
-      dpi_->WriteReg(0x10, 0x1); // reset accum
-      dpi_->WriteReg(0x10, 0x0); // stop reset accum
+      dpi_->WriteReg(0x10, 0x1); // reset accumulator
+      dpi_->WriteReg(0x10, 0x0); 
     }
-    dpi_->WriteReg(0x14, 0x1); // reset dot
-    dpi_->WriteReg(0x14, 0x0); // stop reset dot
   }
 
   uint32_t WaitForCompletion() {
diff --git a/apps/gemm/tests/python/chisel_accel.py b/apps/gemm/tests/python/chisel_accel.py
index 4aed563..4666661 100644
--- a/apps/gemm/tests/python/chisel_accel.py
+++ b/apps/gemm/tests/python/chisel_accel.py
@@ -26,7 +26,7 @@
 A : Vector to be sliced and packed
 slice_width : slice width
 
-Returnsi
+Returns
 ---------
 C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A
 """
@@ -39,7 +39,7 @@
     elif dtype is np.uint16: row = 16 // slice_width
     elif dtype is np.uint32: row = 32 // slice_width
     elif dtype is np.uint64: row = 64 // slice_width
-    else: raise ValueError("datatype " + str(dtype) + "currently not supported")
+    else: raise ValueError("datatype currently not supported")
     if (row >= 8):
         dtype = 'uint' + str(row)
     else:
@@ -55,64 +55,88 @@
             C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
     return C
 
+def slice_mat(A, slice_width):
+    assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
+    dtype = type(A[0][0]) 
+    row = 0
+    # currently only supports uint
+    if dtype is np.uint8: row = 8 // slice_width
+    elif dtype is np.uint16: row = 16 // slice_width
+    elif dtype is np.uint32: row = 32 // slice_width
+    elif dtype is np.uint64: row = 64 // slice_width
+    else: raise ValueError("datatype currently not supported")
+    if (row >= 8):
+        dtype = 'uint' + str(row)
+    else:
+        dtype = 'uint8'
+
+    # 3d array (bits, row, clmn)
+    C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype) # sliced and transform 
+
+    # create mask
+    slice_mask = 2**(slice_width)-1
+    # slice and pack
+    for z in range(A.shape[0]):
+        C[:, z, :] = slice(A[z], slice_width)
+    return C
+
 """ Matrix Multiplication Function
 Parameters
 ----------
 A : Matrix A
 B: Matrix B
-w_width : weight slice width
-a_width : activation slice width
+i_width : weight slice width
+w_width : activation slice width
 
 Returns
 ---------
 C: result of A * B
 """
 # A is a n*m matrix, B is a m*p matrix(not transposed yet)
-def matrix_multiply(A, B, w_width, a_width):
+def matrix_multiply(A, B, i_width, w_width):
     assert A.shape[1] == B.shape[0], "can't perform multiplication"
     BT = B.transpose()
     cycles = 0
+    B_sliced = slice_mat(BT, w_width)
     C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
     for i in range(A.shape[0]):
-        for j in range(B.shape[1]):
-            # C[i, j] = A[i].dot(BT[j])
-            A_sliced = slice(A[i], w_width)
-            B_sliced = slice(BT[j], a_width)
+        A_sliced = slice(A[i], i_width)
+        test = test_accel(A_sliced, B_sliced, i_width, w_width)
+        C[i] = test[0]
+        cycles += test[1]
+        np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width))
+        print("PASS row " + str(i))
 
-            C[i, j] = compute(A_sliced, B_sliced, w_width, a_width)
-            test = test_accel(A_sliced, B_sliced, w_width, a_width)
-            cycles += test[1]
-            np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
-            print("PASS SW serial & parallel")
-
-            np.testing.assert_equal(test[0], C[i, j])
-            print("PASS SW & HW bit serial")
-
-            np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
-            print("PASS SW bit parallel & HW bit parallel")
-
+    np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B))
     print("result: ")
     print(C)
-    print("ALL TESTS PASSED, cycles: " + str(cycles))
+    print("TEST PASSED, cycles: " + str(cycles))
     return C
 
-""" Software Verification Function"""
-# takes 2 matrix input (sliced and packed)
-def compute(A, B, w_width, a_width):
+""" Software Verification Function
+Parameter Dimesions
+---------
+A (bits, y) and B (bits, y, x) (transposed)
+
+Takes 1 vector and 1 matrix input (sliced and packed)
+
+Returns
+---------
+Resulting vector
+"""
+def compute(A, B, i_width, w_width):
     assert A.shape[1] == B.shape[1], "sliced shape not match"
     # reset hardware accumulator
-    accum = 0
+    accum = np.zeros(A.shape[1])
     for x in range(A.shape[0]):
         for y in range(B.shape[0]):
-            # hardware implementation
-            accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
+            accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width)
     # get value from accumulator
     return accum
 
-"""Testing Function for Dot Product"""
-def test_accel(A, B, w_width, a_width):
-    assert A.shape[1] == B.shape[1], "sliced shape not match"
-
+"""Testing Function for Matrix Vector Multiplication"""
+def test_accel(A, B, i_width, w_width):
+    assert A.shape[1] == B.shape[2], "sliced shape not match"
     dtype = A.dtype
     ctx = tvm.cpu(0)
     f = tsim.load_module()
@@ -126,57 +150,54 @@
         a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
 
     for i in range(B.shape[0]):
-        list_b = np.zeros(B.shape[1]).astype(dtype)
-        for j in range(B.shape[1]):
-            list_b[j] = B[i][j]
+        # transpose
+        list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype)
+        for j in range(B.shape[2]):
+            for k in range(B.shape[1]):
+                list_b[j][k] = B[i][j][k]
         b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
 
     cycles = 0
-
-    accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
+    accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx)
     for i in range(len(a_arr)):
         for j in range(len(b_arr)):
-            shift = np.uint8(i*w_width + j*a_width)
+            shift = np.uint8(i*i_width + j*w_width)
             if i == 0 and j == 0: 
-                cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator
+                cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator
             else: 
-                cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset
+                cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset
 
-    return (accum.asnumpy()[0], cycles)
+    return (accum.asnumpy(), cycles)
 
 """ Matrix Generator
 Parameters
 ----------     
 dtype : String, datatype generated (supports only uint)
-w_width : weight bit slices(needs to be less than actual bit width)
-a_width : activation bit slices(needs to be less than actual bit width)
+i_width : weight bit slices(needs to be less than actual bit width)
+w_width : activation bit slices(needs to be less than actual bit width)
 """
-def top_test(dtype, w_width, a_width):
+def top_test(dtype, i_width, w_width):
 
-    rmax = np.random.randint(256)
-    # random matrix generation (dimension up to 8)
-    rrow = np.random.randint(7) + 1
-    rclmn = np.random.randint(7) + 1
-    rrow2 = np.random.randint(7) + 1 
-    A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
-    B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
+    # only supports positive values (up to 2**(bits-1))
+    rmax = 127 
+    # (m,16) * (16,16) GEMM
+    rrow = np.random.randint(7) + 1 
+    clmn = 16
+    A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype)
+    B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype)
 
-    print("A: ")
-    print(A)
-    print("\n")
-    print("B: ")
-    print(B)
-    print("\n")
-    matrix_multiply(A, B, w_width, a_width)
-
+    print("A: " + str(A))
+    print("B: " + str(B))
+    # perform GEMM
+    matrix_multiply(A, B, i_width, w_width)
 
 if __name__ == "__main__":
     tsim.init("chisel")
     for i in range(1):
-        # reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits
+        # reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits
         if sys.argv[1] == 'serial':
-          # generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation 
-          top_test("uint8",4, 2)
+          # generates a random uint8 GEMM with 2-bit(8/4) input and 4-bit(8/2) weight 
+          top_test("uint8", 4, 2)
         elif sys.argv[1] == 'parallel':
-          # generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel) 
-          top_test('uint8', 1, 1)
+          # generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel) 
+          top_test('uint8', 8, 8)