blob: c2129ea7a9443ed3ccc4d7e0696ec04ad2f8931f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package unittest
import chisel3._
import chisel3.util._
import chiseltest._
import chiseltest.iotesters._
import scala.math.pow
import unittest.util._
import vta.core._
/** Peek-poke test bench for [[MatrixVectorMultiplication]].
  *
  * Every iteration pokes one randomly generated input vector and weight matrix
  * into the DUT together with an all-zero accumulator input, holds the three
  * valid flags high for a single clock cycle, waits (with a bounded timeout)
  * for the output valid flag and then compares each output lane against the
  * software reference computed by `mvmRef`.
  *
  * @param c the module under test
  */
class TestMatrixVectorMultiplication(c: MatrixVectorMultiplication) extends PeekPokeTester(c) {

  /* mvmRef
   *
   * Software reference for the hardware: multiplies the square weight matrix
   * `wgt` by the vector `inp` and scales every resulting element by
   * 2^shift (the factor is computed as pow(2, shift).toInt).
   *
   * inp   - input vector; its length defines the number of rows and columns used
   * wgt   - weight matrix, indexed as wgt(row)(col)
   * shift - exponent of the power-of-two scale factor
   *
   * Returns one Int per row of the matrix.
   */
  def mvmRef(inp: Array[Int], wgt: Array[Array[Int]], shift: Int): Array[Int] = {
    val size = inp.length
    // The scale factor does not depend on the row, so compute it once.
    val scale = pow(2, shift).toInt
    Array.tabulate(size) { row =>
      val dot = (0 until size).map(col => wgt(row)(col) * inp(col)).sum
      dot * scale
    }
  }

  val cycles = 5

  // Upper bound on the number of clock steps spent waiting for the output
  // valid flag. Without it a DUT that never answers would hang the test
  // instead of failing it.
  private val maxWaitCycles = 1000

  for (_ <- 0 until cycles) {
    // generate data based on bits
    val inpGen = new RandomArray(c.size, c.inpBits)
    val wgtGen = new RandomArray(c.size, c.wgtBits)
    val in_a = inpGen.any
    val in_b = Array.fill(c.size) { wgtGen.any }
    val res = mvmRef(in_a, in_b, 0)
    val inpMask = Helper.getMask(c.inpBits)
    val wgtMask = Helper.getMask(c.wgtBits)
    val accMask = Helper.getMask(c.accBits)

    // Drive the operands: masked input vector, zero accumulator, masked weights.
    for (row <- 0 until c.size) {
      poke(c.io.inp.data.bits(0)(row), in_a(row) & inpMask)
      poke(c.io.acc_i.data.bits(0)(row), 0)
      for (col <- 0 until c.size) {
        poke(c.io.wgt.data.bits(row)(col), in_b(row)(col) & wgtMask)
      }
    }

    // Hold the module's io.reset input low and present the operands as valid
    // for exactly one clock cycle.
    poke(c.io.reset, 0)
    poke(c.io.inp.data.valid, 1)
    poke(c.io.wgt.data.valid, 1)
    poke(c.io.acc_i.data.valid, 1)
    step(1)
    poke(c.io.inp.data.valid, 0)
    poke(c.io.wgt.data.valid, 0)
    poke(c.io.acc_i.data.valid, 0)

    // wait for valid signal, but give up after maxWaitCycles clock steps
    var waited = 0
    while (peek(c.io.acc_o.data.valid) == BigInt(0) && waited < maxWaitCycles) {
      step(1) // advance clock
      waited += 1
    }

    // Fail explicitly when the result never became valid; previously this
    // case either looped forever or skipped every check below.
    val outputValid = expect(
      c.io.acc_o.data.valid,
      BigInt(1),
      s"acc_o valid was not asserted within $maxWaitCycles cycles"
    )
    if (outputValid) {
      for (lane <- 0 until c.size) {
        expect(c.io.acc_o.data.bits(0)(lane), res(lane) & accMask)
      }
    }
  }
}