/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package vta.shell

import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import vta.interface.axi._

/** VCR parameters.
 *
 * Register counts and register width shared by the VCR interface bundles
 * (VCRMaster, VCRClient) and by the VCR module itself.
 */
case class VCRParams() {
  /** Number of control registers (register 0 carries the launch/finish bits). */
  val nCtrl: Int = 1

  /** Number of event-counter (`ecnt`) registers the core can update. */
  val nECnt: Int = 1

  /** Number of value registers forwarded to the core. */
  val nVals: Int = 1

  /** Number of pointers forwarded to the core; each takes two registers when
   * memory addresses are not 32 bits wide.
   */
  val nPtrs: Int = 6

  /** Number of `ucnt` counter registers the core can update. */
  val nUCnt: Int = 1

  /** Width in bits of every VCR register. */
  val regBits: Int = 32
}

/** VCRBase. Parameterized base class for the VCR interface bundles.
 *
 * Forwards the implicit Parameters `p` to GenericParameterizedBundle.
 * NOTE(review): GenericParameterizedBundle lives in vta.util.genericbundle and
 * is not visible here; it presumably relies on this constructor signature
 * (e.g. for cloneType) — confirm before changing it.
 */
abstract class VCRBase(implicit p: Parameters) extends GenericParameterizedBundle(p)

/** VCRMaster.
 *
 * This is the master interface used by VCR in the VTAShell to control
 * the Core unit.
 *
 * Directions are given from the VCR's side: `launch`, `vals` and `ptrs` are
 * outputs to the core, while `finish`, `ecnt` and `ucnt` are driven by the
 * core. Field names and widths mirror VCRClient, with directions flipped.
 * NOTE(review): Chisel presumably derives bundle element order from the field
 * declaration order below; avoid reordering fields without checking users.
 */
class VCRMaster(implicit p: Parameters) extends VCRBase {
  val vp = p(ShellKey).vcrParams // register counts and register width
  val mp = p(ShellKey).memParams // memory parameters; mp.addrBits sets the pointer width
  val launch = Output(Bool()) // launch request to the core
  val finish = Input(Bool()) // finish flag from the core
  val ecnt = Vec(vp.nECnt, Flipped(ValidIO(UInt(vp.regBits.W)))) // event-counter updates from the core
  val vals = Output(Vec(vp.nVals, UInt(vp.regBits.W))) // value registers forwarded to the core
  val ptrs = Output(Vec(vp.nPtrs, UInt(mp.addrBits.W))) // pointers forwarded to the core
  val ucnt = Vec(vp.nUCnt, Flipped(ValidIO(UInt(vp.regBits.W)))) // ucnt counter updates from the core
}

/** VCRClient.
 *
 * This is the client interface used by the Core module to communicate
 * with the VCR in the VTAShell.
 *
 * Directions are given from the core's side: `launch`, `vals` and `ptrs` are
 * inputs from the VCR, while `finish`, `ecnt` and `ucnt` are driven by the
 * core. Field names and widths mirror VCRMaster, with directions flipped.
 * NOTE(review): Chisel presumably derives bundle element order from the field
 * declaration order below; avoid reordering fields without checking users.
 */
class VCRClient(implicit p: Parameters) extends VCRBase {
  val vp = p(ShellKey).vcrParams // register counts and register width
  val mp = p(ShellKey).memParams // memory parameters; mp.addrBits sets the pointer width
  val launch = Input(Bool()) // launch request from the VCR
  val finish = Output(Bool()) // finish flag sent to the VCR
  val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W))) // event-counter updates; `valid` qualifies `bits`
  val vals = Input(Vec(vp.nVals, UInt(vp.regBits.W))) // value registers from the VCR
  val ptrs = Input(Vec(vp.nPtrs, UInt(mp.addrBits.W))) // pointers from the VCR
  val ucnt = Vec(vp.nUCnt, ValidIO(UInt(vp.regBits.W))) // ucnt counter updates; `valid` qualifies `bits`
}

/** VTA Control Registers (VCR).
 *
 * This unit provides a bank of vp.regBits-wide control registers that a
 * control unit, typically a host processor, accesses over an AXI-Lite client
 * port. Registers sit at consecutive 4-byte offsets in this order: control,
 * event counters (ecnt), values, pointers, user counters (ucnt). A pointer
 * occupies one register when memory addresses are 32 bits wide and two
 * consecutive registers (low word first) otherwise.
 *
 * From the core's side, the VCR only accepts the finish flag and the
 * ecnt/ucnt updates; value and pointer registers are written by the host only.
 */
class VCR(implicit p: Parameters) extends Module {
  val io = IO(new Bundle {
    val host = new AXILiteClient(p(ShellKey).hostParams)
    val vcr = new VCRMaster
  })

  val vp = p(ShellKey).vcrParams
  val mp = p(ShellKey).memParams
  val hp = p(ShellKey).hostParams

  // AXI-Lite write path: address (AW), then data (W), then response (B).
  val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // not 4-byte aligned, so it matches no register
  val wdata = io.host.w.bits.data
  val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
  val wstate = RegInit(sWriteAddress)

  // AXI-Lite read path: address (AR), then data (R).
  val sReadAddress :: sReadData :: Nil = Enum(2)
  val rstate = RegInit(sReadAddress)
  val rdata = RegInit(0.U(vp.regBits.W))

  // Register file. A pointer needs two registers unless addresses are 32-bit.
  val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs
  val nTotal = Seq(vp.nCtrl, vp.nECnt, vp.nVals, nPtrs, vp.nUCnt).sum

  val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
  val addr: Seq[Int] = (0 until nTotal).map(idx => idx * 4)
  val reg_map = addr.zip(reg).map { case (offset, r) => offset.U -> r }

  // Index of the first register in each group.
  val eo = vp.nCtrl      // event counters
  val vo = eo + vp.nECnt // values
  val po = vo + vp.nVals // pointers
  val uo = po + nPtrs    // user counters

  /** High while a host W beat is accepted for the register at index `idx`. */
  private def hostWrites(idx: Int): Bool = io.host.w.fire && waddr === addr(idx).U

  // Write-channel state machine.
  when(wstate === sWriteAddress) {
    when(io.host.aw.valid) { wstate := sWriteData }
  }.elsewhen(wstate === sWriteData) {
    when(io.host.w.valid) { wstate := sWriteResponse }
  }.elsewhen(wstate === sWriteResponse) {
    when(io.host.b.ready) { wstate := sWriteAddress }
  }

  io.host.aw.ready := wstate === sWriteAddress
  io.host.w.ready := wstate === sWriteData
  io.host.b.valid := wstate === sWriteResponse
  io.host.b.bits.resp := 0.U

  when(io.host.aw.fire) { waddr := io.host.aw.bits.addr }

  // Read-channel state machine; the register is sampled on the AR handshake.
  when(rstate === sReadAddress) {
    when(io.host.ar.valid) { rstate := sReadData }
  }.elsewhen(rstate === sReadData) {
    when(io.host.r.ready) { rstate := sReadAddress }
  }

  io.host.ar.ready := rstate === sReadAddress
  io.host.r.valid := rstate === sReadData
  io.host.r.bits.data := rdata
  io.host.r.bits.resp := 0.U

  when(io.host.ar.fire) {
    rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map) // unmapped offsets read as 0
  }

  // Control register: `finish` from the core wins over a same-cycle host write.
  when(io.vcr.finish) {
    reg(0) := "b_10".U
  }.elsewhen(hostWrites(0)) {
    reg(0) := wdata
  }

  // Counter registers: a valid update from the core wins over a host write.
  Seq(eo -> io.vcr.ecnt, uo -> io.vcr.ucnt).foreach { case (base, counters) =>
    counters.zipWithIndex.foreach { case (cnt, i) =>
      when(cnt.valid) {
        reg(base + i) := cnt.bits
      }.elsewhen(hostWrites(base + i)) {
        reg(base + i) := wdata
      }
    }
  }

  // Value and pointer registers are written only by the host.
  (vo until uo).foreach { idx =>
    when(hostWrites(idx)) { reg(idx) := wdata }
  }

  // Outputs to the core.
  io.vcr.launch := reg(0)(0)
  io.vcr.vals.zipWithIndex.foreach { case (v, i) => v := reg(vo + i) }
  io.vcr.ptrs.zipWithIndex.foreach { case (ptr, i) =>
    ptr := (if (mp.addrBits == 32) reg(po + i) else Cat(reg(po + 2 * i + 1), reg(po + 2 * i)))
  }
}
