blob: 81abb8e48cde9755a5deac5e8e40141a0aa516f6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
/** ALU datapath.
  *
  * Combinational unit over two signed `aluBits`-wide operands. `io.opcode`
  * selects the result through the `ALU_OP` table
  * (min:0, max:1, add:2, shr:3, shl:4); any opcode not present in the table
  * yields `io.a` unchanged.
  */
class Alu(implicit p: Parameters) extends Module {
  val aluBits = p(CoreKey).accBits
  val io = IO(new Bundle {
    val opcode = Input(UInt(C_ALU_OP_BITS.W))
    val a = Input(SInt(aluBits.W))
    val b = Input(SInt(aluBits.W))
    val y = Output(SInt(aluBits.W))
  })

  // FIXME: width/ub/n/m will change once SHR and SHL are properly supported.
  // The right-shift distance `n` is the low log2Ceil(aluBits) bits of b; the
  // left-shift distance `m` is the two's-complement negation of `n`
  // (modulo 2^width).
  val width = log2Ceil(aluBits)
  val ub = io.b.asUInt
  val n = ub(width - 1, 0)
  val m = ~n + 1.U

  // A single comparison shared by min and max.
  val aLessThanB = io.a < io.b

  // Candidate results, indexed by opcode.
  val fop = Seq(
    Mux(aLessThanB, io.a, io.b), // 0: min
    Mux(aLessThanB, io.b, io.a), // 1: max
    io.a + io.b,                 // 2: add
    io.a >> n,                   // 3: shr
    io.a << m                    // 4: shl
  )

  // Pair each opcode encoding with its candidate result.
  val opmux = (0 until ALU_OP_NUM).map(i => ALU_OP(i) -> fop(i))
  io.y := MuxLookup(io.opcode, io.a, opmux)
}
/** Pipelined ALU.
  *
  * Wraps a single [[Alu]] with registered operands. `a` and `b` are each
  * captured in a register on cycles where their `valid` is high. The ALU result
  * computed from the registered operands is presented on `io.y`.
  *
  * `io.y.valid` is asserted exactly one cycle after `io.b.valid`. The valid of
  * `a` does not gate the output, so callers must present `a` no later than
  * `b`. `io.opcode` is not registered; it is applied combinationally to the
  * registered operands.
  */
class AluReg(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val opcode = Input(UInt(C_ALU_OP_BITS.W))
    val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
    val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
    val y = ValidIO(UInt(p(CoreKey).accBits.W))
  })
  val alu = Module(new Alu)
  val rA = RegEnable(io.a.bits, io.a.valid)
  val rB = RegEnable(io.b.bits, io.b.valid)
  // Reset the output-valid flag to false. Without an init value this
  // register powers up undefined (X in simulation, random with randomized
  // init). It could then raise y.valid spuriously before the first sampled
  // b.valid.
  val valid = RegNext(io.b.valid, false.B)
  alu.io.opcode := io.opcode
  // register input
  alu.io.a := rA.asSInt
  alu.io.b := rB.asSInt
  // output
  io.y.valid := valid
  io.y.bits := alu.io.y.asUInt
}
/** Vector of pipelined ALUs.
  *
  * Instantiates `blockOut` [[AluReg]] lanes. Lane `i` reads element `i` of
  * entry 0 of `acc_a` and `acc_b`, and drives element `i` of entry 0 of both
  * `acc_y` and `out`. Every lane shares the same opcode and input valids.
  * Both outputs are flagged valid only when every lane reports a valid result.
  */
class AluVector(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val opcode = Input(UInt(C_ALU_OP_BITS.W))
    val acc_a = new TensorMasterData(tensorType = "acc")
    val acc_b = new TensorMasterData(tensorType = "acc")
    val acc_y = new TensorClientData(tensorType = "acc")
    val out = new TensorClientData(tensorType = "out")
  })
  val blockOut = p(CoreKey).blockOut
  val f = Seq.fill(blockOut)(Module(new AluReg))
  val valid = Wire(Vec(blockOut, Bool()))

  for ((lane, i) <- f.zipWithIndex) {
    // Control shared by all lanes.
    lane.io.opcode := io.opcode
    lane.io.a.valid := io.acc_a.data.valid
    lane.io.b.valid := io.acc_b.data.valid
    // Per-lane operands.
    lane.io.a.bits := io.acc_a.data.bits(0)(i)
    lane.io.b.bits := io.acc_b.data.bits(0)(i)
    // Per-lane results, fanned out to both destinations.
    valid(i) := lane.io.y.valid
    io.acc_y.data.bits(0)(i) := lane.io.y.bits
    io.out.data.bits(0)(i) := lane.io.y.bits
  }

  // Both destinations become valid only once every lane is valid.
  val allLanesValid = valid.asUInt.andR
  io.acc_y.data.valid := allLanesValid
  io.out.data.valid := allLanesValid
}
/** TensorAlu.
*
* This unit instantiates the ALU vector unit (AluVector) and iterates over
* the micro-ops (uops) of an ALU instruction. Each uop provides the indices
* used to read the source operands (vectors) from the acc-scratchpad. The
* results are written back to the same acc-scratchpad (at the destination
* index) and are also written to the out-scratchpad.
*
* The uop range [uop_begin, uop_end) is swept inside a two-level loop: an
* outer loop of `lp_0` iterations and an inner loop of `lp_1` iterations.
* The dst/src offsets advance by `dst_0`/`src_0` per outer iteration and by
* `dst_1`/`src_1` per inner iteration. They are added to the uop's `u0`/`u1`
* fields to form the scratchpad indices actually accessed.
*
* @param debug when true, elaborate simulation `printf` traces
*/
class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val aluBits = p(CoreKey).accBits
// io.start launches execution from sIdle; io.done is asserted on the cycle
// the final result is produced. uop/acc/out are the scratchpad ports.
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val uop = new UopMaster
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
// Control FSM; each uop walks through:
// sReadUop -> sComputeIdx -> sReadTensorA -> sReadTensorB -> sExe
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
Enum(6)
val state = RegInit(sIdle)
val alu = Module(new AluVector)
// Decoded view of the instruction (AluDecode is defined elsewhere).
// NOTE(review): `dec` is decoded combinationally from io.inst on every cycle,
// so the caller presumably holds io.inst stable until io.done — confirm.
val dec = io.inst.asTypeOf(new AluDecode)
// Current uop index, and the scratchpad indices computed from that uop.
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
val uop_dst = Reg(chiselTypeOf(dec.uop_end))
val uop_src = Reg(chiselTypeOf(dec.uop_end))
// Outer-loop counter and accumulated outer dst/src offsets.
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
val dst_o = Reg(chiselTypeOf(dec.uop_end))
val src_o = Reg(chiselTypeOf(dec.uop_end))
// Inner-loop counter and accumulated inner dst/src offsets.
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
val dst_i = Reg(chiselTypeOf(dec.uop_end))
val src_i = Reg(chiselTypeOf(dec.uop_end))
// Asserted in sExe when the result for the last uop of the last inner
// iteration of the last outer iteration is valid.
val done =
state === sExe &
alu.io.out.data.valid &
(cnt_o === dec.lp_0 - 1.U) &
(cnt_i === dec.lp_1 - 1.U) &
(uop_idx === uop_end - 1.U)
switch(state) {
is(sIdle) {
// Wait for a start request.
when(io.start) {
state := sReadUop
}
}
is(sReadUop) {
// The uop read request for uop_idx is issued in this state (io.uop.idx).
state := sComputeIdx
}
is(sComputeIdx) {
// The uop data is latched into uop_dst/uop_src in this state.
// NOTE(review): the FSM advances without checking io.uop.data.valid, which
// appears to assume a fixed one-cycle uop read latency — confirm.
state := sReadTensorA
}
is(sReadTensorA) {
// Issue the acc read of operand A at uop_dst.
state := sReadTensorB
}
is(sReadTensorB) {
// Operand A data is consumed here. The acc read of operand B at uop_src is
// issued unless the instruction uses an immediate.
state := sExe
}
is(sExe) {
// Wait for the ALU result, then either finish or move to the next uop.
when(alu.io.out.data.valid) {
when(
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
state := sIdle
}.otherwise {
state := sReadUop
}
}
}
}
// uop_idx restarts at uop_begin when idle or after the last uop of the range.
// Otherwise it advances by one per completed result.
when(
state === sIdle ||
(state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
}.elsewhen(state === sExe && alu.io.out.data.valid) {
uop_idx := uop_idx + 1.U
}
// The outer loop advances once the last uop of the last inner iteration
// completes.
when(state === sIdle) {
cnt_o := 0.U
dst_o := 0.U
src_o := 0.U
}.elsewhen(
state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
dst_o := dst_o + dec.dst_0
src_o := src_o + dec.src_0
}
// The inner loop advances after each full pass over the uop range. Once it
// has counted up to lp_1, the next sReadUop resets it and reloads the inner
// offsets from the (already advanced) outer offsets.
when(state === sIdle) {
cnt_i := 0.U
dst_i := 0.U
src_i := 0.U
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
dst_i := dst_o
src_i := src_o
}.elsewhen(state === sExe && alu.io.out.data.valid && uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
dst_i := dst_i + dec.dst_1
src_i := src_i + dec.src_1
}
// Scratchpad indices for this uop: the uop fields plus the loop offsets.
when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i
uop_src := io.uop.data.bits.u1 + src_i
}
// uop
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
// acc (input): operand A is read at uop_dst. Operand B is read at uop_src,
// and that second read is skipped when the immediate is used.
io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
// imm: dec.alu_imm is sign-extended to aluBits by prepending ones when its
// MSB is set, and broadcast to every tensor element. It is valid in
// sReadTensorB, the same cycle operand A is presented.
val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
tensorImm.data.valid := state === sReadTensorB
tensorImm.data.bits.foreach { b =>
b.foreach { c =>
c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
}
}
// alu: an SHR instruction with a negative immediate is executed as SHL.
val isSHR = dec.alu_op === ALU_OP(3)
val isSHL = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
// opcode - min:0, max:1, add:2, shr:3, shl:4
// SHL maps to 4 (binary 100); every other opcode keeps its low two bits.
val fixme_alu_op = Cat(isSHL, Mux(isSHL, 0.U, dec.alu_op(1, 0)))
alu.io.opcode := fixme_alu_op
// Operand A is valid when its read data returns in sReadTensorB.
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
// Operand B comes from the immediate, or from the second acc read returning
// in sExe.
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
tensorImm.data.valid,
io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
tensorImm.data.bits,
io.acc.rd.data.bits)
// acc (output): each result is written back at uop_dst.
io.acc.wr.valid := alu.io.acc_y.data.valid
io.acc.wr.bits.idx := uop_dst
io.acc.wr.bits.data <> alu.io.acc_y.data.bits
// out: the same result is also written to the out-scratchpad at uop_dst.
io.out.wr.valid := alu.io.out.data.valid
io.out.wr.bits.idx := uop_dst
io.out.wr.bits.data <> alu.io.out.data.bits
io.out.tieoffRead() // write-only
io.done := done
// Simulation trace output, elaborated only when debug is true.
if (debug) {
when(state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
}
when(state === sReadTensorA) {
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
}
when(state === sIdle && io.start) {
printf(p"[TensorAlu] decode:$dec\n")
}
alu.io.acc_a.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_a.data.valid) {
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_b.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_b.data.valid) {
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_y.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_y.data.valid) {
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.out.data.valid) {
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
}
}
}
}
}