| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package unittest |
| |
| import chisel3._ |
| import chisel3.util._ |
| import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester} |
| import scala.math.pow |
| import unittest.util._ |
| import vta.core._ |
| |
class TestMatrixVectorMultiplication(c: MatrixVectorMultiplication) extends PeekPokeTester(c) {

  /** Software reference model used to check the hardware result.
    *
    * Computes `res(row) = (sum over col of wgt(row)(col) * inp(col)) << shift`
    * with 32-bit Int arithmetic (wrap-around on overflow, same as the
    * original accumulate-then-multiply form). The length of `inp` sets both
    * matrix dimensions, so `wgt` must be at least `inp.length` x `inp.length`.
    *
    * @param inp   input vector
    * @param wgt   weight matrix, indexed as wgt(row)(col)
    * @param shift left shift applied to every dot product; must be in [0, 31]
    * @return one shifted dot product per weight row
    * @throws IllegalArgumentException if `shift` is outside [0, 31]
    */
  def mvmRef(inp: Array[Int], wgt: Array[Array[Int]], shift: Int): Array[Int] = {
    // An integer shift replaces `pow(2, shift).toInt`, which went through
    // Double and clamped to Int.MaxValue at shift = 31.
    require(shift >= 0 && shift < 32, s"shift must be in [0, 31], got $shift")
    val size = inp.length
    Array.tabulate(size) { row =>
      val dot = (0 until size).foldLeft(0) { (acc, col) => acc + wgt(row)(col) * inp(col) }
      dot << shift
    }
  }

  /** Number of randomized matrix-vector products driven through the DUT. */
  val cycles = 5

  /** Upper bound on clock cycles to wait for `acc_o.data.valid` per product.
    * A DUT that never raises valid then fails the test instead of hanging it.
    */
  val maxWaitCycles = 1000

  // Loop-invariant setup: random generators sized to the DUT ports, and masks
  // that truncate values to each port's bit width before poking/expecting.
  private val inpGen = new RandomArray(c.size, c.inpBits)
  private val wgtGen = new RandomArray(c.size, c.wgtBits)
  private val inpMask = Helper.getMask(c.inpBits)
  private val wgtMask = Helper.getMask(c.wgtBits)
  private val accMask = Helper.getMask(c.accBits)

  for (iter <- 0 until cycles) {
    // Fresh random input vector and weight matrix for this product.
    val inA = inpGen.any
    val inB = Array.fill(c.size) { wgtGen.any }
    val expected = mvmRef(inA, inB, 0)

    // Drive row 0 of the input and a zero accumulator input, plus the full weight matrix.
    for (row <- 0 until c.size) {
      poke(c.io.inp.data.bits(0)(row), inA(row) & inpMask)
      poke(c.io.acc_i.data.bits(0)(row), 0)
      for (col <- 0 until c.size) {
        poke(c.io.wgt.data.bits(row)(col), inB(row)(col) & wgtMask)
      }
    }

    poke(c.io.reset, 0)

    // Pulse all input valids for exactly one clock cycle.
    poke(c.io.inp.data.valid, 1)
    poke(c.io.wgt.data.valid, 1)
    poke(c.io.acc_i.data.valid, 1)

    step(1)

    poke(c.io.inp.data.valid, 0)
    poke(c.io.wgt.data.valid, 0)
    poke(c.io.acc_i.data.valid, 0)

    // Wait for the output valid, but never longer than maxWaitCycles.
    var waited = 0
    while (peek(c.io.acc_o.data.valid) == BigInt(0) && waited < maxWaitCycles) {
      step(1)
      waited += 1
    }

    val gotValid = peek(c.io.acc_o.data.valid) == BigInt(1)
    expect(gotValid, s"acc_o.data.valid asserted within $maxWaitCycles cycles (iteration $iter)")
    if (gotValid) {
      for (idx <- 0 until c.size) {
        expect(c.io.acc_o.data.bits(0)(idx), expected(idx) & accMask)
      }
    }
  }
}