[VTA] TSIM improvements and fixes (#3505)
* add tsim init function
* add sim device
* test wait and resume
* launch simulation thread from DPILoader
* add VTASimDPI module to handle all simulation related stuff
* test tsim init
* move exit to simdpi module
* update vta driver
* add chisel DPI module
* get back simshell
* update vta to support dpi sim
* update unittests
* add tsim to integration-conv2d test
* run resnet on tsim
* remove max-cycles
* match tsim counters with sim counters
* use env in simulator to switch between sim and tsim
* update unittest
* rollback conv2d test
* update resnet
* add stats to matrix multiply
* add stats
* print stats after assert
* update other tests
* add stats to gemm
* add return and remove unused libs
* add missing arg
* return lib
* update comments for linter
* add more comments to VTASimDPI module
* remove trailing spaces
* remove trailing spaces
diff --git a/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala b/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala
index 2c02ff3..d931620 100644
--- a/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala
+++ b/apps/tsim_example/hardware/chisel/src/test/scala/dut/TestAccel.scala
@@ -20,6 +20,7 @@
package test
import chisel3._
+import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._
@@ -28,19 +29,23 @@
* Instantiate Host and Memory DPI modules.
*
*/
-class VTASimShell extends Module {
- val io = IO(new Bundle {
- val host = new VTAHostDPIMaster
- val mem = new VTAMemDPIClient
- })
- val host = Module(new VTAHostDPI)
- val mem = Module(new VTAMemDPI)
- mem.io.dpi <> io.mem
- mem.io.reset := reset
- mem.io.clock := clock
- io.host <> host.io.dpi
- host.io.reset := reset
- host.io.clock := clock
+class VTASimShell extends MultiIOModule {
+ val host = IO(new VTAHostDPIMaster)
+ val mem = IO(new VTAMemDPIClient)
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
+ val mod_sim = Module(new VTASimDPI)
+ val mod_host = Module(new VTAHostDPI)
+ val mod_mem = Module(new VTAMemDPI)
+ mod_mem.io.clock := clock
+ mod_mem.io.reset := reset
+ mod_mem.io.dpi <> mem
+ mod_host.io.clock := clock
+ mod_host.io.reset := reset
+ host <> mod_host.io.dpi
+ mod_sim.io.clock := sim_clock
+ mod_sim.io.reset := reset
+ sim_wait := mod_sim.io.dpi_wait
}
/** Test accelerator.
@@ -48,12 +53,15 @@
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
-class TestAccel extends Module {
- val io = IO(new Bundle {})
+class TestAccel extends MultiIOModule {
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
- vta_accel.io.host <> sim_shell.io.host
- sim_shell.io.mem <> vta_accel.io.mem
+ sim_shell.sim_clock := sim_clock
+ sim_wait := sim_shell.sim_wait
+ sim_shell.mem <> vta_accel.io.mem
+ vta_accel.io.host <> sim_shell.host
}
/** Generate TestAccel as top module */
diff --git a/apps/tsim_example/hardware/verilog/src/TestAccel.v b/apps/tsim_example/hardware/verilog/src/TestAccel.v
index f3bcc86..cc1ec85 100644
--- a/apps/tsim_example/hardware/verilog/src/TestAccel.v
+++ b/apps/tsim_example/hardware/verilog/src/TestAccel.v
@@ -25,7 +25,9 @@
module TestAccel
(
input clock,
- input reset
+ input reset,
+ input sim_clock,
+ output sim_wait
);
localparam HOST_ADDR_BITS = 8;
@@ -53,6 +55,14 @@
logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready;
+ VTASimDPI sim
+ (
+ .clock (sim_clock),
+ .reset (reset),
+
+ .dpi_wait (sim_wait)
+ );
+
VTAHostDPI host
(
.clock (clock),
@@ -114,4 +124,5 @@
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready)
);
+
endmodule
diff --git a/apps/tsim_example/python/tsim.py b/apps/tsim_example/python/tsim.py
index a41d904..1e8bbd9 100644
--- a/apps/tsim_example/python/tsim.py
+++ b/apps/tsim_example/python/tsim.py
@@ -20,7 +20,22 @@
import os.path as osp
from sys import platform
-def driver(hw_backend):
+def get_ext():
+ return ".dylib" if platform == "darwin" else ".so"
+
+def load_dll(dll):
+ try:
+ return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
+ except OSError:
+ return []
+
+def load_sw():
+ cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ sw_libname = "libsw" + get_ext()
+ sw_lib = osp.join(cur_path, "..", "build", sw_libname)
+ load_dll(sw_lib)
+
+def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
@@ -29,23 +44,15 @@
Hardware backend can be verilog or chisel
"""
- _ext = ".dylib" if platform == "darwin" else ".so"
- _hw_libname = "libhw" + _ext
- _sw_libname = "libsw" + _ext
- _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
- _hw_lib = osp.join(_cur_path, "..", "hardware", hw_backend, "build", _hw_libname)
- _sw_lib = osp.join(_cur_path, "..", "build", _sw_libname)
+ hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
+ m = tvm.module.load(hw_lib, "vta-tsim")
+ load_sw()
+ f = tvm.get_global_func("tvm.vta.tsim.init")
+ f(m)
- def load_dll(dll):
- try:
- return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
- except OSError:
- return []
-
- def run(a, b, c):
- load_dll(_sw_lib)
- f = tvm.get_global_func("tvm.vta.driver")
- m = tvm.module.load(_hw_lib, "vta-tsim")
- return f(m, a, b, c)
- return run
+def load_module():
+ load_sw()
+ return tvm.get_global_func("tvm.vta.driver")
diff --git a/apps/tsim_example/src/driver.cc b/apps/tsim_example/src/driver.cc
index c1dc61f..0d75d37 100644
--- a/apps/tsim_example/src/driver.cc
+++ b/apps/tsim_example/src/driver.cc
@@ -35,25 +35,56 @@
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
+class DPILoader {
+ public:
+ ~DPILoader() {
+ dpi_->SimResume();
+ dpi_->SimFinish();
+ }
+
+ void Init(Module module) {
+ mod_ = module;
+ dpi_ = this->Get();
+ dpi_->SimLaunch();
+ dpi_->SimWait();
+ }
+
+ DPIModuleNode* Get() {
+ return static_cast<DPIModuleNode*>(mod_.operator->());
+ }
+
+ static DPILoader* Global() {
+ static DPILoader inst;
+ return &inst;
+ }
+
+ // TVM module
+ Module mod_;
+ // DPI Module
+ DPIModuleNode* dpi_{nullptr};
+};
+
class Device {
public:
- Device(Module module)
- : module_(module) {
- dpi_ = static_cast<DPIModuleNode*>(
- module.operator->());
+ Device() {
+ loader_ = DPILoader::Global();
}
uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) {
uint32_t cycles;
+ this->Init();
this->Launch(c, length, inp, out);
cycles = this->WaitForCompletion();
- dpi_->Finish();
return cycles;
}
private:
+ void Init() {
+ dpi_ = loader_->Get();
+ dpi_->SimResume();
+ }
+
void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
- dpi_->Launch(wait_cycles_);
dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false));
@@ -70,24 +101,33 @@
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
+ dpi_->SimWait();
return val;
}
+ // wait cycles
uint32_t wait_cycles_{100000000};
- DPIModuleNode* dpi_;
- Module module_;
+ // DPI loader
+ DPILoader* loader_{nullptr};
+ // DPI Module
+ DPIModuleNode* dpi_{nullptr};
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
+TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ Module m = args[0];
+ DPILoader::Global()->Init(m);
+ });
+
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- Module dev_mod = args[0];
- DLTensor* A = args[1];
- DLTensor* B = args[2];
- Device dev_(dev_mod);
- uint32_t cycles = dev_.Run(static_cast<int>(args[3]), A->shape[0], A->data, B->data);
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ Device dev_;
+ uint32_t cycles = dev_.Run(static_cast<int>(args[2]), A->shape[0], A->data, B->data);
*rv = static_cast<int>(cycles);
});
diff --git a/apps/tsim_example/tests/python/chisel_accel.py b/apps/tsim_example/tests/python/chisel_accel.py
index 26565c3..de67ba0 100644
--- a/apps/tsim_example/tests/python/chisel_accel.py
+++ b/apps/tsim_example/tests/python/chisel_accel.py
@@ -26,12 +26,13 @@
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
- f = tsim.driver("chisel")
+ f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
+ tsim.init("chisel")
for i in range(10):
test_accel()
diff --git a/apps/tsim_example/tests/python/verilog_accel.py b/apps/tsim_example/tests/python/verilog_accel.py
index d88964b..027e682 100644
--- a/apps/tsim_example/tests/python/verilog_accel.py
+++ b/apps/tsim_example/tests/python/verilog_accel.py
@@ -26,12 +26,13 @@
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
- f = tsim.driver("verilog")
+ f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
+ tsim.init("verilog")
for i in range(10):
test_accel()
diff --git a/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v b/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
index b466c79..3441e3e 100644
--- a/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
+++ b/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
@@ -35,7 +35,6 @@
import "DPI-C" function void VTAHostDPI
(
- output byte unsigned exit,
output byte unsigned req_valid,
output byte unsigned req_opcode,
output byte unsigned req_addr,
@@ -50,7 +49,6 @@
typedef logic [31:0] dpi32_t;
dpi1_t __reset;
- dpi8_t __exit;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_addr;
@@ -80,7 +78,6 @@
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
- __exit = 0;
__req_valid = 0;
__req_opcode = 0;
__req_addr = 0;
@@ -88,7 +85,6 @@
end
else begin
VTAHostDPI(
- __exit,
__req_valid,
__req_opcode,
__req_addr,
@@ -99,21 +95,4 @@
end
end
- logic [63:0] cycles;
-
- always_ff @(posedge clock) begin
- if (reset | __reset) begin
- cycles <= 'd0;
- end
- else begin
- cycles <= cycles + 1'b1;
- end
- end
-
- always_ff @(posedge clock) begin
- if (__exit == 'd1) begin
- $finish;
- end
- end
-
endmodule
diff --git a/hardware/chisel/src/main/resources/verilog/VTASimDPI.v b/hardware/chisel/src/main/resources/verilog/VTASimDPI.v
new file mode 100644
index 0000000..fc0d4c8
--- /dev/null
+++ b/hardware/chisel/src/main/resources/verilog/VTASimDPI.v
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+module VTASimDPI
+(
+ input clock,
+ input reset,
+ output logic dpi_wait
+);
+
+ import "DPI-C" function void VTASimDPI
+ (
+ output byte unsigned sim_wait,
+ output byte unsigned sim_exit
+ );
+
+ typedef logic dpi1_t;
+ typedef logic [7:0] dpi8_t;
+
+ dpi1_t __reset;
+ dpi8_t __wait;
+ dpi8_t __exit;
+
+ // reset
+ always_ff @(posedge clock) begin
+ __reset <= reset;
+ end
+
+ // evaluate DPI function
+ always_ff @(posedge clock) begin
+ if (reset | __reset) begin
+ __wait = 0;
+ __exit = 0;
+ end
+ else begin
+ VTASimDPI(
+ __wait,
+ __exit);
+ end
+ end
+
+ logic wait_reg;
+
+ always_ff @(posedge clock) begin
+ if (reset | __reset) begin
+ wait_reg <= 1'b0;
+ end else if (__wait == 1) begin
+ wait_reg <= 1'b1;
+ end else begin
+ wait_reg <= 1'b0;
+ end
+ end
+
+ assign dpi_wait = wait_reg;
+
+ always_ff @(posedge clock) begin
+ if (__exit == 1) begin
+ $finish;
+ end
+ end
+
+endmodule
diff --git a/hardware/chisel/src/main/scala/dpi/VTASimDPI.scala b/hardware/chisel/src/main/scala/dpi/VTASimDPI.scala
new file mode 100644
index 0000000..33b1101
--- /dev/null
+++ b/hardware/chisel/src/main/scala/dpi/VTASimDPI.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.dpi
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
+
+/** Sim DPI module.
+ *
+ * Wrapper for Sim Verilog DPI module.
+ */
+class VTASimDPI extends BlackBox with HasBlackBoxResource {
+ val io = IO(new Bundle {
+ val clock = Input(Clock())
+ val reset = Input(Bool())
+ val dpi_wait = Output(Bool())
+ })
+ setResource("/verilog/VTASimDPI.v")
+}
diff --git a/hardware/chisel/src/main/scala/shell/SimShell.scala b/hardware/chisel/src/main/scala/shell/SimShell.scala
index 3ad4b65..f3d74ef 100644
--- a/hardware/chisel/src/main/scala/shell/SimShell.scala
+++ b/hardware/chisel/src/main/scala/shell/SimShell.scala
@@ -20,6 +20,7 @@
package vta.shell
import chisel3._
+import chisel3.experimental.MultiIOModule
import vta.util.config._
import vta.interface.axi._
import vta.shell._
@@ -61,18 +62,37 @@
mem_axi.io.axi <> io.axi
}
+/** VTASim.
+ *
+ * This module is used to handle hardware simulation thread, such as halting
+ * or terminating the simulation thread. The sim_wait port is used to halt
+ * the simulation thread when it is asserted and resume it when it is
+ * de-asserted.
+ */
+class VTASim(implicit p: Parameters) extends MultiIOModule {
+ val sim_wait = IO(Output(Bool()))
+ val sim = Module(new VTASimDPI)
+ sim.io.reset := reset
+ sim.io.clock := clock
+ sim_wait := sim.io.dpi_wait
+}
/** SimShell.
*
- * The simulation shell instantiate a host and memory simulation modules and it is
- * intended to be connected to the VTAShell.
+ * The simulation shell instantiate the sim, host and memory DPI modules that
+ * are connected to the VTAShell. An extra clock, sim_clock, is used to eval
+ * the VTASim DPI function when the main simulation clock is on halt state.
*/
-class SimShell(implicit p: Parameters) extends Module {
- val io = IO(new Bundle {
- val mem = new AXIClient(p(ShellKey).memParams)
- val host = new AXILiteMaster(p(ShellKey).hostParams)
- })
- val host = Module(new VTAHost)
- val mem = Module(new VTAMem)
- io.mem <> mem.io.axi
- io.host <> host.io.axi
+class SimShell(implicit p: Parameters) extends MultiIOModule {
+ val mem = IO(new AXIClient(p(ShellKey).memParams))
+ val host = IO(new AXILiteMaster(p(ShellKey).hostParams))
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
+ val mod_sim = Module(new VTASim)
+ val mod_host = Module(new VTAHost)
+ val mod_mem = Module(new VTAMem)
+ mem <> mod_mem.io.axi
+ host <> mod_host.io.axi
+ mod_sim.reset := reset
+ mod_sim.clock := sim_clock
+ sim_wait := mod_sim.sim_wait
}
diff --git a/hardware/chisel/src/main/scala/test/Test.scala b/hardware/chisel/src/main/scala/test/Test.scala
index db06073..7749d95 100644
--- a/hardware/chisel/src/main/scala/test/Test.scala
+++ b/hardware/chisel/src/main/scala/test/Test.scala
@@ -20,14 +20,18 @@
package vta.test
import chisel3._
+import chisel3.experimental.MultiIOModule
import vta.util.config._
import vta.shell._
/** Test. This generates a testbench file for simulation */
-class Test(implicit p: Parameters) extends Module {
- val io = IO(new Bundle {})
+class Test(implicit p: Parameters) extends MultiIOModule {
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new SimShell)
val vta_shell = Module(new VTAShell)
- vta_shell.io.host <> sim_shell.io.host
- sim_shell.io.mem <> vta_shell.io.mem
+ sim_shell.sim_clock := sim_clock
+ sim_wait := sim_shell.sim_wait
+ sim_shell.mem <> vta_shell.io.mem
+ vta_shell.io.host <> sim_shell.host
}
diff --git a/hardware/dpi/tsim_device.cc b/hardware/dpi/tsim_device.cc
index aa05c8c..1dae273 100644
--- a/hardware/dpi/tsim_device.cc
+++ b/hardware/dpi/tsim_device.cc
@@ -17,6 +17,8 @@
* under the License.
*/
+#include <chrono>
+#include <thread>
#include <vta/dpi/tsim.h>
#if VM_TRACE
@@ -29,11 +31,17 @@
#endif
static VTAContextHandle _ctx = nullptr;
-static VTAMemDPIFunc _mem_dpi = nullptr;
+static VTASimDPIFunc _sim_dpi = nullptr;
static VTAHostDPIFunc _host_dpi = nullptr;
+static VTAMemDPIFunc _mem_dpi = nullptr;
-void VTAHostDPI(dpi8_t* exit,
- dpi8_t* req_valid,
+void VTASimDPI(dpi8_t* wait,
+ dpi8_t* exit) {
+ assert(_sim_dpi != nullptr);
+ (*_sim_dpi)(_ctx, wait, exit);
+}
+
+void VTAHostDPI(dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
@@ -41,7 +49,7 @@
dpi8_t resp_valid,
dpi32_t resp_value) {
assert(_host_dpi != nullptr);
- (*_host_dpi)(_ctx, exit, req_valid, req_opcode,
+ (*_host_dpi)(_ctx, req_valid, req_opcode,
req_addr, req_value, req_deq,
resp_valid, resp_value);
}
@@ -63,9 +71,11 @@
}
void VTADPIInit(VTAContextHandle handle,
+ VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi) {
_ctx = handle;
+ _sim_dpi = sim_dpi;
_host_dpi = host_dpi;
_mem_dpi = mem_dpi;
}
@@ -77,7 +87,7 @@
Verilated::gotFinish(true);
}
-int VTADPISim(uint64_t max_cycles) {
+int VTADPISim() {
uint64_t trace_count = 0;
Verilated::flushCall();
Verilated::gotFinish(false);
@@ -115,13 +125,15 @@
top->reset = 0;
// start simulation
- while (!Verilated::gotFinish() && trace_count < max_cycles) {
+ while (!Verilated::gotFinish()) {
+ top->sim_clock = 0;
top->clock = 0;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2));
#endif
+ top->sim_clock = 1;
top->clock = 1;
top->eval();
#if VM_TRACE
@@ -129,6 +141,14 @@
tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
#endif
trace_count++;
+ while (top->sim_wait) {
+ top->clock = 0;
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ top->sim_clock = 0;
+ top->eval();
+ top->sim_clock = 1;
+ top->eval();
+ }
}
#if VM_TRACE
diff --git a/include/vta/dpi/module.h b/include/vta/dpi/module.h
index c83dad1..c1fc6bf 100644
--- a/include/vta/dpi/module.h
+++ b/include/vta/dpi/module.h
@@ -34,11 +34,17 @@
*/
class DPIModuleNode : public tvm::runtime::ModuleNode {
public:
-/*!
- * \brief Launch hardware simulation until accelerator finishes or reach max_cycles
- * \param max_cycles The maximum of cycles to wait
- */
- virtual void Launch(uint64_t max_cycles) = 0;
+/*! \brief Launch hardware simulation */
+ virtual void SimLaunch() = 0;
+
+/*! \brief Halt hardware simulation */
+ virtual void SimWait() = 0;
+
+/*! \brief Resume hardware simulation */
+ virtual void SimResume() = 0;
+
+/*! \brief Finish hardware simulation */
+ virtual void SimFinish() = 0;
/*!
* \brief Write an accelerator register
@@ -53,13 +59,9 @@
*/
virtual uint32_t ReadReg(int addr) = 0;
-/*! \brief Finish hardware simulation */
- virtual void Finish() = 0;
-
static tvm::runtime::Module Load(std::string dll_name);
};
} // namespace dpi
} // namespace vta
#endif // VTA_DPI_MODULE_H_
-
diff --git a/include/vta/dpi/tsim.h b/include/vta/dpi/tsim.h
index 6170cde..8e13def 100644
--- a/include/vta/dpi/tsim.h
+++ b/include/vta/dpi/tsim.h
@@ -36,9 +36,13 @@
/*! \brief the context handle */
typedef void* VTAContextHandle;
+typedef void (*VTASimDPIFunc)(
+ VTAContextHandle self,
+ dpi8_t* wait,
+ dpi8_t* exit);
+
/*!
* \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle
- * \param exit Host kill simulation
* \param req_valid Host has a valid request for read or write a register in Accel
* \param req_opcode Host request type, opcode=0 for read and opcode=1 for write
* \param req_addr Host request register address
@@ -50,7 +54,6 @@
*/
typedef void (*VTAHostDPIFunc)(
VTAContextHandle self,
- dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
@@ -84,28 +87,28 @@
/*! \brief The type of VTADPIInit function pointer */
typedef void (*VTADPIInitFunc)(VTAContextHandle handle,
+ VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
/*! \brief The type of VTADPISim function pointer */
-typedef int (*VTADPISimFunc)(uint64_t max_cycles);
+typedef int (*VTADPISimFunc)();
/*!
* \brief Set Host and Memory DPI functions
* \param handle DPI Context handle
+ * \param sim_dpi Sim DPI function
* \param host_dpi Host DPI function
* \param mem_dpi Memory DPI function
*/
TVM_DLL void VTADPIInit(VTAContextHandle handle,
+ VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
-/*!
- * \brief Instantiate VTA design and generate clock/reset
- * \param max_cycles The maximum number of simulation cycles
- */
-TVM_DLL int VTADPISim(uint64_t max_cycles);
+/*! \brief VTA hardware simulation thread */
+TVM_DLL int VTADPISim();
#ifdef __cplusplus
}
diff --git a/python/vta/exec/rpc_server.py b/python/vta/exec/rpc_server.py
index 0ac97a2..be9d91a 100644
--- a/python/vta/exec/rpc_server.py
+++ b/python/vta/exec/rpc_server.py
@@ -42,7 +42,7 @@
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
- dll_path = find_libvta()[0]
+ dll_path = find_libvta("libvta")[0]
cfg_path = os.path.abspath(os.path.join(proj_root, "build/vta_config.json"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.rpc.server.load_module")
diff --git a/python/vta/libinfo.py b/python/vta/libinfo.py
index 19f9d30..00a43c0 100644
--- a/python/vta/libinfo.py
+++ b/python/vta/libinfo.py
@@ -19,21 +19,48 @@
import sys
import os
-def _get_lib_name():
+def _get_lib_name(lib_name):
+ """Get lib name with extension
+
+ Returns
+ -------
+ lib_name_ext : str
+ Name of VTA shared library with extension
+
+ Parameters
+ ------------
+ lib_name : str
+ Name of VTA shared library
+ """
if sys.platform.startswith('win32'):
- return "vta.dll"
+ return lib_name + ".dll"
if sys.platform.startswith('darwin'):
- return "libvta.dylib"
- return "libvta.so"
+ return lib_name + ".dylib"
+ return lib_name + ".so"
-def find_libvta(optional=False):
- """Find VTA library"""
+def find_libvta(lib_vta, optional=False):
+ """Find VTA library
+
+ Returns
+ -------
+ lib_found : str
+ Library path
+
+ Parameters
+ ------------
+ lib_vta : str
+ Name of VTA shared library
+
+ optional : bool
+ Enable error check
+ """
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
lib_search = [curr_path]
+ lib_search += [os.path.join(curr_path, "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build", "Release")]
- lib_name = _get_lib_name()
+ lib_name = _get_lib_name(lib_vta)
lib_path = [os.path.join(x, lib_name) for x in lib_search]
lib_found = [x for x in lib_path if os.path.exists(x)]
if not lib_found and not optional:
diff --git a/python/vta/testing/simulator.py b/python/vta/testing/simulator.py
index 2d6cfe3..4024993 100644
--- a/python/vta/testing/simulator.py
+++ b/python/vta/testing/simulator.py
@@ -17,22 +17,34 @@
"""Utilities to start simulator."""
import ctypes
import json
-import sys
-import os
import tvm
+from ..environment import get_env
from ..libinfo import find_libvta
-def _load_lib():
- """Load local library, assuming they are simulator."""
- lib_path = find_libvta(optional=True)
- if not lib_path:
+
+def _load_sw():
+ """Load software library, assuming they are simulator."""
+ lib_sw = find_libvta("libvta", optional=True)
+ if not lib_sw:
return []
try:
- return [ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)]
+ return [ctypes.CDLL(lib_sw[0], ctypes.RTLD_GLOBAL)]
except OSError:
return []
+def _load_all():
+ """Load hardware library for tsim."""
+ lib = _load_sw()
+ env = get_env()
+ if env.TARGET == "tsim":
+ lib = find_libvta("libvta_hw", optional=True)
+ f = tvm.get_global_func("vta.tsim.init")
+ m = tvm.module.load(lib[0], "vta-tsim")
+ f(m)
+ return lib
+
+
def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
@@ -40,49 +52,31 @@
def clear_stats():
- """Clear profiler statistics"""
- f = tvm.get_global_func("vta.simulator.profiler_clear", True)
+ """Clear profiler statistics."""
+ env = get_env()
+ if env.TARGET == "sim":
+ f = tvm.get_global_func("vta.simulator.profiler_clear", True)
+ else:
+ f = tvm.get_global_func("vta.tsim.profiler_clear", True)
if f:
f()
def stats():
- """Clear profiler statistics
+ """Get profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
- x = tvm.get_global_func("vta.simulator.profiler_status")()
+ env = get_env()
+ if env.TARGET == "sim":
+ x = tvm.get_global_func("vta.simulator.profiler_status")()
+ else:
+ x = tvm.get_global_func("vta.tsim.profiler_status")()
return json.loads(x)
-def tsim_init(hw_lib):
- """Init hardware shared library for TSIM
-
- Parameters
- ------------
- hw_lib : str
- Name of hardware shared library
- """
- cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
- vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
- if not hw_lib.endswith(("dylib", "so")):
- hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
- lib = os.path.join(vta_build_path, hw_lib)
- f = tvm.get_global_func("tvm.vta.tsim.init")
- m = tvm.module.load(lib, "vta-tsim")
- f(m)
-
-def tsim_cycles():
- """Get tsim clock cycles
-
- Returns
- -------
- stats : int
- tsim clock cycles
- """
- return tvm.get_global_func("tvm.vta.tsim.cycles")()
# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1
@@ -97,4 +91,4 @@
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
-LIBS = _load_lib()
+LIBS = _load_all()
diff --git a/python/vta/testing/util.py b/python/vta/testing/util.py
index 7da98ef..b748cdf 100644
--- a/python/vta/testing/util.py
+++ b/python/vta/testing/util.py
@@ -22,6 +22,7 @@
from ..environment import get_env
from . import simulator
+
def run(run_func):
"""Run test function on all available env.
diff --git a/src/dpi/module.cc b/src/dpi/module.cc
index 4839b02..c1dcbb7 100644
--- a/src/dpi/module.cc
+++ b/src/dpi/module.cc
@@ -86,17 +86,28 @@
std::condition_variable cond_;
};
+class SimDevice {
+ public:
+ void Wait();
+ void Resume();
+ void Exit();
+ bool GetWaitStatus();
+ bool GetExitStatus();
+
+ private:
+ bool wait_{false};
+ bool exit_{false};
+ mutable std::mutex mutex_;
+};
+
class HostDevice {
public:
void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value);
bool TryPopRequest(HostRequest* r, bool pop);
void PushResponse(uint32_t value);
void WaitPopResponse(HostResponse* r);
- void Exit();
- uint8_t GetExitStatus();
private:
- uint8_t exit_{0};
mutable std::mutex mutex_;
ThreadSafeQueue<HostRequest> req_;
ThreadSafeQueue<HostResponse> resp_;
@@ -116,6 +127,31 @@
std::mutex mutex_;
};
+void SimDevice::Wait() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ wait_ = true;
+}
+
+void SimDevice::Resume() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ wait_ = false;
+}
+
+void SimDevice::Exit() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ exit_ = true;
+}
+
+bool SimDevice::GetWaitStatus() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ return wait_;
+}
+
+bool SimDevice::GetExitStatus() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ return exit_;
+}
+
void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) {
HostRequest r;
r.opcode = opcode;
@@ -141,16 +177,6 @@
resp_.WaitPop(r);
}
-void HostDevice::Exit() {
- std::unique_lock<std::mutex> lock(mutex_);
- exit_ = 1;
-}
-
-uint8_t HostDevice::GetExitStatus() {
- std::unique_lock<std::mutex> lock(mutex_);
- return exit_;
-}
-
void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
std::lock_guard<std::mutex> lock(mutex_);
if (opcode == 1) {
@@ -212,16 +238,29 @@
VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
GetSymbol("VTADPIInit"));
CHECK(finit != nullptr);
- finit(this, VTAHostDPI, VTAMemDPI);
- fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
- CHECK(fvsim_ != nullptr);
+ finit(this, VTASimDPI, VTAHostDPI, VTAMemDPI);
+ ftsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
+ CHECK(ftsim_ != nullptr);
}
- void Launch(uint64_t max_cycles) {
- auto frun = [this, max_cycles]() {
- (*fvsim_)(max_cycles);
+ void SimLaunch() {
+ auto frun = [this]() {
+ (*ftsim_)();
};
- vsim_thread_ = std::thread(frun);
+ tsim_thread_ = std::thread(frun);
+ }
+
+ void SimWait() {
+ sim_device_.Wait();
+ }
+
+ void SimResume() {
+ sim_device_.Resume();
+ }
+
+ void SimFinish() {
+ sim_device_.Exit();
+ tsim_thread_.join();
}
void WriteReg(int addr, uint32_t value) {
@@ -238,19 +277,20 @@
return value;
}
- void Finish() {
- host_device_.Exit();
- vsim_thread_.join();
- }
-
protected:
- VTADPISimFunc fvsim_;
+ VTADPISimFunc ftsim_;
+ SimDevice sim_device_;
HostDevice host_device_;
MemDevice mem_device_;
- std::thread vsim_thread_;
+ std::thread tsim_thread_;
- void HostDPI(dpi8_t* exit,
- dpi8_t* req_valid,
+ void SimDPI(dpi8_t* wait,
+ dpi8_t* exit) {
+ *wait = sim_device_.GetWaitStatus();
+ *exit = sim_device_.GetExitStatus();
+ }
+
+ void HostDPI(dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
@@ -258,7 +298,6 @@
dpi8_t resp_valid,
dpi32_t resp_value) {
HostRequest* r = new HostRequest;
- *exit = host_device_.GetExitStatus();
*req_valid = host_device_.TryPopRequest(r, req_deq);
*req_opcode = r->opcode;
*req_addr = r->addr;
@@ -290,9 +329,16 @@
}
}
+ static void VTASimDPI(
+ VTAContextHandle self,
+ dpi8_t* wait,
+ dpi8_t* exit) {
+ static_cast<DPIModule*>(self)->SimDPI(
+ wait, exit);
+ }
+
static void VTAHostDPI(
VTAContextHandle self,
- dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
@@ -301,7 +347,7 @@
dpi8_t resp_valid,
dpi32_t resp_value) {
static_cast<DPIModule*>(self)->HostDPI(
- exit, req_valid, req_opcode, req_addr,
+ req_valid, req_opcode, req_addr,
req_value, req_deq, resp_valid, resp_value);
}
diff --git a/src/pynq/pynq_driver.cc b/src/pynq/pynq_driver.cc
index be9d0fe..5f96b65 100644
--- a/src/pynq/pynq_driver.cc
+++ b/src/pynq/pynq_driver.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/src/sim/sim_driver.cc b/src/sim/sim_driver.cc
index 0691195..cf7d6dc 100644
--- a/src/sim/sim_driver.cc
+++ b/src/sim/sim_driver.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/src/tsim/tsim_driver.cc b/src/tsim/tsim_driver.cc
index 6dd273c..67716ea 100644
--- a/src/tsim/tsim_driver.cc
+++ b/src/tsim/tsim_driver.cc
@@ -17,32 +17,78 @@
* under the License.
*/
-#include <vta/driver.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
+#include <vta/driver.h>
#include <vta/dpi/module.h>
namespace vta {
namespace tsim {
-using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
+using vta::dpi::DPIModuleNode;
class Profiler {
public:
- /*! \brief cycle counter */
- uint64_t cycle_count{0};
+ Profiler() {
+ counters_ = new int[num_counters_];
+ this->ClearAll();
+ }
+
+ ~Profiler() {
+ delete [] counters_;
+ }
+
+ /*! \brief update one event counter */
+ void Update(uint32_t idx, uint32_t value) {
+ counters_[idx] += value;
+ }
+
+ /*! \brief clear one event counter*/
+ void Clear(uint32_t idx) {
+ counters_[idx] = 0;
+ }
+
+ /*! \brief clear all event counters */
+ void ClearAll() {
+ for (uint32_t i = 0; i < num_counters_; i++) {
+ counters_[i] = 0;
+ }
+ }
+
+ /*! \brief return counters as json */
+ std::string AsJSON() {
+ std::ostringstream os;
+ os << "{\n"
+ << " \"cycle_count\":" << counters_[0] << "\n"
+ <<"}\n";
+ return os.str();
+ }
static Profiler* Global() {
static Profiler inst;
return &inst;
}
+
+ private:
+ /*! \brief total number of event counters */
+ uint32_t num_counters_{1};
+ /*! \brief event counters */
+ int* counters_{nullptr};
};
class DPILoader {
public:
+ ~DPILoader() {
+ dpi_->SimResume();
+ dpi_->SimFinish();
+ }
+
void Init(Module module) {
mod_ = module;
+ dpi_ = this->Get();
+ dpi_->SimLaunch();
+ dpi_->SimWait();
}
DPIModuleNode* Get() {
@@ -54,13 +100,16 @@
return &inst;
}
+ // TVM module
Module mod_;
+ // DPI Module
+ DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device() {
- dpi_ = DPILoader::Global();
+ loader_ = DPILoader::Global();
prof_ = Profiler::Global();
}
@@ -82,13 +131,13 @@
insn_count,
wait_cycles);
this->WaitForCompletion(wait_cycles);
- dev_->Finish();
return 0;
}
private:
void Init() {
- dev_ = dpi_->Get();
+ dpi_ = loader_->Get();
+ dpi_->SimResume();
}
void Launch(vta_phy_addr_t insn_phy_addr,
@@ -99,57 +148,60 @@
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
- // launch simulation thread
- dev_->Launch(wait_cycles);
- // set counter to zero
- dev_->WriteReg(0x04, 0);
- dev_->WriteReg(0x08, insn_count);
- dev_->WriteReg(0x0c, insn_phy_addr);
- dev_->WriteReg(0x10, insn_phy_addr >> 32);
- dev_->WriteReg(0x14, 0);
- dev_->WriteReg(0x18, uop_phy_addr >> 32);
- dev_->WriteReg(0x1c, 0);
- dev_->WriteReg(0x20, inp_phy_addr >> 32);
- dev_->WriteReg(0x24, 0);
- dev_->WriteReg(0x28, wgt_phy_addr >> 32);
- dev_->WriteReg(0x2c, 0);
- dev_->WriteReg(0x30, acc_phy_addr >> 32);
- dev_->WriteReg(0x34, 0);
- dev_->WriteReg(0x38, out_phy_addr >> 32);
+ dpi_->WriteReg(0x04, 0);
+ dpi_->WriteReg(0x08, insn_count);
+ dpi_->WriteReg(0x0c, insn_phy_addr);
+ dpi_->WriteReg(0x10, insn_phy_addr >> 32);
+ dpi_->WriteReg(0x14, 0);
+ dpi_->WriteReg(0x18, uop_phy_addr >> 32);
+ dpi_->WriteReg(0x1c, 0);
+ dpi_->WriteReg(0x20, inp_phy_addr >> 32);
+ dpi_->WriteReg(0x24, 0);
+ dpi_->WriteReg(0x28, wgt_phy_addr >> 32);
+ dpi_->WriteReg(0x2c, 0);
+ dpi_->WriteReg(0x30, acc_phy_addr >> 32);
+ dpi_->WriteReg(0x34, 0);
+ dpi_->WriteReg(0x38, out_phy_addr >> 32);
// start
- dev_->WriteReg(0x00, 0x1);
+ dpi_->WriteReg(0x00, 0x1);
}
void WaitForCompletion(uint32_t wait_cycles) {
uint32_t i, val;
for (i = 0; i < wait_cycles; i++) {
- val = dev_->ReadReg(0x00);
+ val = dpi_->ReadReg(0x00);
val &= 0x2;
if (val == 0x2) break; // finish
}
- prof_->cycle_count = dev_->ReadReg(0x04);
+ prof_->Update(0, dpi_->ReadReg(0x04));
+ dpi_->SimWait();
}
// Profiler
Profiler* prof_;
// DPI loader
- DPILoader* dpi_;
+ DPILoader* loader_;
// DPI Module
- DPIModuleNode* dev_;
+ DPIModuleNode* dpi_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
-TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
+TVM_REGISTER_GLOBAL("vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
-TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles")
+TVM_REGISTER_GLOBAL("vta.tsim.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- *rv = static_cast<int>(Profiler::Global()->cycle_count);
+ Profiler::Global()->ClearAll();
+ });
+
+TVM_REGISTER_GLOBAL("vta.tsim.profiler_status")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = Profiler::Global()->AsJSON();
});
} // namespace tsim
diff --git a/tests/python/integration/test_benchmark_gemm.py b/tests/python/integration/test_benchmark_gemm.py
index 7a36352..d4eed91 100644
--- a/tests/python/integration/test_benchmark_gemm.py
+++ b/tests/python/integration/test_benchmark_gemm.py
@@ -18,6 +18,7 @@
import numpy as np
from tvm.contrib import util
import vta.testing
+from vta.testing import simulator
def test_gemm():
@@ -104,7 +105,14 @@
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
cost = time_f(data_arr, weight_arr, res_arr)
+ if env.TARGET in ["sim", "tsim"]:
+ stats = simulator.stats()
+ print("Execution statistics:")
+ for k, v in stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
diff --git a/tests/python/integration/test_benchmark_topi_conv2d.py b/tests/python/integration/test_benchmark_topi_conv2d.py
index 2aec471..daca936 100644
--- a/tests/python/integration/test_benchmark_topi_conv2d.py
+++ b/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -173,14 +173,20 @@
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
- if env.TARGET == "sim":
+ if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
- remote.get_function("vta.simulator.profiler_clear")()
+ if env.TARGET == "sim":
+ remote.get_function("vta.simulator.profiler_clear")()
+ else:
+ remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
- stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
+ if env.TARGET == "sim":
+ stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
+ else:
+ stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
@@ -215,7 +221,7 @@
def _run(env, remote):
if device == "vta":
target = env.target
- if env.TARGET != "sim":
+ if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
diff --git a/tests/python/integration/test_benchmark_topi_dense.py b/tests/python/integration/test_benchmark_topi_dense.py
index 12fbc45..174e966 100644
--- a/tests/python/integration/test_benchmark_topi_dense.py
+++ b/tests/python/integration/test_benchmark_topi_dense.py
@@ -131,14 +131,20 @@
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
- if env.TARGET == "sim":
+ if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
- remote.get_function("vta.simulator.profiler_clear")()
+ if env.TARGET == "sim":
+ remote.get_function("vta.simulator.profiler_clear")()
+ else:
+ remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, res_arr)
- stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
+ if env.TARGET == "sim":
+ stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
+ else:
+ stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr)
@@ -171,7 +177,7 @@
def _run(env, remote):
if device == "vta":
target = env.target
- if env.TARGET != "sim":
+ if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
diff --git a/tests/python/unittest/test_vta_insn.py b/tests/python/unittest/test_vta_insn.py
index 815f55b..25d7d8c 100644
--- a/tests/python/unittest/test_vta_insn.py
+++ b/tests/python/unittest/test_vta_insn.py
@@ -69,15 +69,18 @@
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
- if env.TARGET == "tsim":
- print("Load/store test took {} clock cycles".format(simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Save load execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
@@ -135,15 +138,18 @@
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
- if env.TARGET == "tsim":
- print("Padded load test took {} clock cycles".format(simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Padded load execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
@@ -213,20 +219,18 @@
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
-
- if env.TARGET == "sim":
+ if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
- f(x_nd, w_nd, y_nd)
- print(simulator.stats())
- else:
- f(x_nd, w_nd, y_nd)
+
+ f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
- if env.TARGET == "tsim":
- print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("GEMM schedule:{} execution statistics:".format(name))
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
def test_schedule1():
# default schedule with no smt
@@ -374,8 +378,8 @@
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
if use_imm:
f(a_nd, res_nd)
@@ -385,8 +389,11 @@
np.testing.assert_equal(res_np, res_nd.asnumpy())
- if env.TARGET == "tsim":
- print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("ALU {} execution statistics:".format(test_name))
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL")
check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX")
@@ -451,15 +458,18 @@
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
- if env.TARGET == "tsim":
- print("Relu test took {} clock cycles".format(simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Relu execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
@@ -518,15 +528,18 @@
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
- if env.TARGET == "tsim":
- simulator.tsim_init("libvta_hw")
+ if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
- if env.TARGET == "tsim":
- print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles()))
+ if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Shift and scale execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
diff --git a/tutorials/frontend/deploy_resnet_on_vta.py b/tutorials/frontend/deploy_resnet_on_vta.py
index c4e7aaf..550dac7 100644
--- a/tutorials/frontend/deploy_resnet_on_vta.py
+++ b/tutorials/frontend/deploy_resnet_on_vta.py
@@ -89,7 +89,7 @@
# When target is 'pynq', reconfigure FPGA and runtime.
# Otherwise, if target is 'sim', execute locally.
-if env.TARGET != "sim":
+if env.TARGET not in ["sim", "tsim"]:
# Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning
@@ -235,7 +235,7 @@
rep = 3 # number of measurements (we derive std dev from this)
timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
-if env.TARGET == "sim":
+if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
timer()
sim_stats = simulator.stats()
diff --git a/tutorials/matrix_multiply.py b/tutorials/matrix_multiply.py
index bf3960b..70a899b 100644
--- a/tutorials/matrix_multiply.py
+++ b/tutorials/matrix_multiply.py
@@ -66,7 +66,7 @@
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
-elif env.TARGET == "sim":
+elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
@@ -437,6 +437,10 @@
B_nd = tvm.nd.array(B_packed, ctx)
C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx)
+# Clear stats
+if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
+
# Invoke the module to perform the computation
f(A_nd, B_nd, C_nd)
@@ -452,8 +456,15 @@
C_ref = C_ref.reshape(
o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(C_ref, C_nd.asnumpy())
-print("Successful matrix multiply test!")
+# Print stats
+if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
+
+print("Successful matrix multiply test!")
######################################################################
# Summary
diff --git a/tutorials/optimize/convolution_opt.py b/tutorials/optimize/convolution_opt.py
index 67b5895..f1e0ba3 100644
--- a/tutorials/optimize/convolution_opt.py
+++ b/tutorials/optimize/convolution_opt.py
@@ -70,7 +70,7 @@
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
-elif env.TARGET == "sim":
+elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
@@ -412,6 +412,10 @@
kernel_nd = tvm.nd.array(kernel_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
+# Clear stats
+if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
+
# Invoke the module to perform the computation
f(data_nd, kernel_nd, res_nd)
@@ -430,6 +434,14 @@
fout_height,
fout_width)).transpose((0, 2, 4, 5, 1, 3))
tvm.testing.assert_allclose(res_ref, res_nd.asnumpy())
+
+# Print stats
+if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
+
print("Successful 2D convolution test!")
######################################################################
diff --git a/tutorials/optimize/matrix_multiply_opt.py b/tutorials/optimize/matrix_multiply_opt.py
index 10ba770..b20094a 100644
--- a/tutorials/optimize/matrix_multiply_opt.py
+++ b/tutorials/optimize/matrix_multiply_opt.py
@@ -69,7 +69,7 @@
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
-elif env.TARGET == "sim":
+elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
@@ -352,6 +352,10 @@
weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
+# Clear stats
+if env.TARGET in ["sim", "tsim"]:
+ simulator.clear_stats()
+
# Invoke the module to perform the computation
f(data_nd, weight_nd, res_nd)
@@ -366,6 +370,14 @@
out_channels // env.BLOCK_OUT,
env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.asnumpy())
+
+# Print stats
+if env.TARGET in ["sim", "tsim"]:
+ sim_stats = simulator.stats()
+ print("Execution statistics:")
+ for k, v in sim_stats.items():
+ print("\t{:<16}: {:>16}".format(k, v))
+
print("Successful blocked matrix multiply test!")
######################################################################