[REFACTOR][PY] Establish tvm.tir
- Move related files into the corresponding locations, mirroring the C++ structure
- Keep the top-level TVM API backward compatible to minimize changes in topi
diff --git a/python/vta/build_module.py b/python/vta/build_module.py
index df67faa..f368362 100644
--- a/python/vta/build_module.py
+++ b/python/vta/build_module.py
@@ -68,7 +68,7 @@
env.dev.command_handle,
debug_flag)
- return tvm.make.stmt_seq(debug, stmt)
+ return tvm.tir.stmt_seq(debug, stmt)
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
diff --git a/python/vta/environment.py b/python/vta/environment.py
index 83db612..8d58958 100644
--- a/python/vta/environment.py
+++ b/python/vta/environment.py
@@ -62,11 +62,11 @@
def __init__(self, env):
self.vta_axis = tvm.thread_axis("vta")
- self.vta_push_uop = tvm.make.StringImm("VTAPushGEMMOp")
+ self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
- self.command_handle = tvm.make.Call(
+ self.command_handle = tvm.tir.Call(
"handle", "tvm_thread_context", [ctx],
- tvm.expr.Call.Intrinsic, None, 0)
+ tvm.tir.Call.Intrinsic, None, 0)
self.DEBUG_NO_SYNC = False
env._dev_ctx = self
self.gemm = intrin.gemm(env, env.mock_mode)
@@ -256,29 +256,29 @@
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
- return tvm.make.node("MemoryInfo",
- unit_bits=spec.INP_ELEM_BITS,
- max_simd_bits=spec.INP_ELEM_BITS,
- max_num_bits=spec.INP_BUFF_SIZE * 8,
- head_address=None)
+ return tvm.ir.make_node("MemoryInfo",
+ unit_bits=spec.INP_ELEM_BITS,
+ max_simd_bits=spec.INP_ELEM_BITS,
+ max_num_bits=spec.INP_BUFF_SIZE * 8,
+ head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
- return tvm.make.node("MemoryInfo",
- unit_bits=spec.WGT_ELEM_BITS,
- max_simd_bits=spec.WGT_ELEM_BITS,
- max_num_bits=spec.WGT_BUFF_SIZE * 8,
- head_address=None)
+ return tvm.ir.make_node("MemoryInfo",
+ unit_bits=spec.WGT_ELEM_BITS,
+ max_simd_bits=spec.WGT_ELEM_BITS,
+ max_num_bits=spec.WGT_BUFF_SIZE * 8,
+ head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_acc_buffer():
spec = get_env()
- return tvm.make.node("MemoryInfo",
- unit_bits=spec.ACC_ELEM_BITS,
- max_simd_bits=spec.ACC_ELEM_BITS,
- max_num_bits=spec.ACC_BUFF_SIZE * 8,
- head_address=None)
+ return tvm.ir.make_node("MemoryInfo",
+ unit_bits=spec.ACC_ELEM_BITS,
+ max_simd_bits=spec.ACC_ELEM_BITS,
+ max_num_bits=spec.ACC_BUFF_SIZE * 8,
+ head_address=None)
# TVM related registration
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
diff --git a/python/vta/intrin.py b/python/vta/intrin.py
index 77c7ff2..a43fc75 100644
--- a/python/vta/intrin.py
+++ b/python/vta/intrin.py
@@ -98,7 +98,7 @@
0, 0, 0))
return irb.get()
# return a triple of normal-set, reset, update
- nop = tvm.make.Evaluate(0)
+ nop = tvm.tir.Evaluate(0)
if mock:
return (nop, nop, nop)
return (instr(0), instr(1), instr(2))
diff --git a/python/vta/ir_pass.py b/python/vta/ir_pass.py
index e42e3a0..8b8a2f0 100644
--- a/python/vta/ir_pass.py
+++ b/python/vta/ir_pass.py
@@ -59,8 +59,8 @@
def _fold_outermost_loop(body):
stmt = body
- while not isinstance(stmt, tvm.stmt.For):
- if isinstance(stmt, (tvm.stmt.ProducerConsumer,)):
+ while not isinstance(stmt, tvm.tir.For):
+ if isinstance(stmt, (tvm.tir.ProducerConsumer,)):
stmt = stmt.body
else:
return None, body, None
@@ -70,7 +70,7 @@
fail = [False]
def _post_order(op):
- assert isinstance(op, tvm.expr.Call)
+ assert isinstance(op, tvm.tir.Call)
base_args = 2
if op.name == "VTAUopPush":
args = []
@@ -112,7 +112,7 @@
def _do_fold(stmt):
if (stmt.attr_key == "coproc_uop_scope" and
- isinstance(stmt.value, tvm.expr.StringImm) and
+ isinstance(stmt.value, tvm.tir.StringImm) and
stmt.value.value == env.dev.vta_push_uop.value):
body = stmt.body
begins = []
@@ -133,8 +133,8 @@
if body == stmt.body:
return stmt
ends = list(reversed(ends))
- body = tvm.stmt.stmt_seq(*(begins + [body] + ends))
- return tvm.make.AttrStmt(
+ body = tvm.tir.stmt_seq(*(begins + [body] + ends))
+ return tvm.tir.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body)
return None
out = tvm.ir_pass.IRTransform(
@@ -163,40 +163,40 @@
env = get_env()
rw_info = {}
def _post_order(op):
- if isinstance(op, tvm.stmt.Allocate):
+ if isinstance(op, tvm.tir.Allocate):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
return None
new_var = rw_info[buffer_var]
- let_stmt = tvm.make.LetStmt(
+ let_stmt = tvm.tir.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
buffer_var), op.body)
- alloc = tvm.make.Allocate(
+ alloc = tvm.tir.Allocate(
buffer_var, op.dtype, op.extents,
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
- if isinstance(op, tvm.expr.Load):
+ if isinstance(op, tvm.tir.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
- return tvm.make.Load(op.dtype, new_var, op.index)
- if isinstance(op, tvm.stmt.Store):
+ return tvm.tir.Load(op.dtype, new_var, op.index)
+ if isinstance(op, tvm.tir.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
- return tvm.make.Store(new_var, op.value, op.index)
+ return tvm.tir.Store(new_var, op.value, op.index)
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
- stmt = tvm.make.LetStmt(
+ stmt = tvm.tir.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr",
env.dev.command_handle,
@@ -222,15 +222,15 @@
for op in slist:
if op.body == body:
body = op
- elif isinstance(op, tvm.stmt.Allocate):
- body = tvm.make.Allocate(
+ elif isinstance(op, tvm.tir.Allocate):
+ body = tvm.tir.Allocate(
op.buffer_var, op.dtype,
op.extents, op.condition, body)
- elif isinstance(op, tvm.stmt.AttrStmt):
- body = tvm.make.AttrStmt(
+ elif isinstance(op, tvm.tir.AttrStmt):
+ body = tvm.tir.AttrStmt(
op.node, op.attr_key, op.value, body)
- elif isinstance(op, tvm.stmt.For):
- body = tvm.make.For(
+ elif isinstance(op, tvm.tir.For):
+ body = tvm.tir.For(
op.loop_var, op.min, op.extent, op.for_type,
op.device_api, body)
else:
@@ -239,24 +239,24 @@
return body
def _pre_order(op):
- if isinstance(op, tvm.stmt.For):
+ if isinstance(op, tvm.tir.For):
lift_stmt.append([])
- elif isinstance(op, tvm.stmt.AttrStmt):
+ elif isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "virtual_thread":
lift_stmt.append([])
def _post_order(op):
- if isinstance(op, tvm.stmt.Allocate):
+ if isinstance(op, tvm.tir.Allocate):
lift_stmt[-1].append(op)
return op.body
- if isinstance(op, tvm.stmt.AttrStmt):
+ if isinstance(op, tvm.tir.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op)
return op.body
if op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
- if isinstance(op, tvm.stmt.For):
+ if isinstance(op, tvm.tir.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
@@ -280,7 +280,7 @@
"""
def _do_fold(stmt):
if _match_pragma(stmt, "skip_dma_copy"):
- return tvm.make.Evaluate(0)
+ return tvm.tir.Evaluate(0)
return None
return tvm.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
@@ -303,13 +303,13 @@
def _do_fold(stmt):
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
- sync = tvm.make.Call(
- "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
- return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)])
+ sync = tvm.tir.Call(
+ "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
+ return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
- assert isinstance(op, tvm.stmt.For)
- return tvm.make.For(
+ assert isinstance(op, tvm.tir.For)
+ return tvm.tir.For(
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
@@ -640,9 +640,9 @@
selects = []
def _find_basics(op):
- if isinstance(op, tvm.expr.Call):
+ if isinstance(op, tvm.tir.Call):
calls.append(op)
- elif isinstance(op, tvm.expr.Select):
+ elif isinstance(op, tvm.tir.Select):
selects.append(op)
def _do_fold(op):
@@ -665,7 +665,7 @@
args = op.body.body.args
res_tensor = op.body.body.func.output(0)
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
- inner = tvm.make.AttrStmt(
+ inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
@@ -697,19 +697,19 @@
args = conv_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_OUT)
- inner = tvm.make.AttrStmt(
+ inner = tvm.tir.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
- inner = tvm.make.AttrStmt(
+ inner = tvm.tir.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
1, 0, 1, 0, env.BLOCK_IN)
- inner = tvm.make.AttrStmt(
+ inner = tvm.tir.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
@@ -739,11 +739,11 @@
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(env.dev.QID_COMPUTE))
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
- tvm.make.StringImm("VTAPushALUOp"))
+ tvm.tir.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
if _match_pragma(stmt, "skip_alu"):
- return tvm.make.Evaluate(0)
+ return tvm.tir.Evaluate(0)
return stmt
stmt_out = tvm.ir_pass.IRTransform(
@@ -810,7 +810,7 @@
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
- while isinstance(loop_body, tvm.stmt.For):
+ while isinstance(loop_body, tvm.tir.For):
loop_body = loop_body.body
nest_size += 1
# Get the src/dst arguments
@@ -825,27 +825,27 @@
extents.append(tmp_body.extent)
tmp_body = tmp_body.body
# Derive opcode
- if isinstance(loop_body.value, tvm.expr.Add):
+ if isinstance(loop_body.value, tvm.tir.Add):
alu_opcode = env.dev.ALU_OPCODE_ADD
lhs = loop_body.value.a
rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.expr.Sub):
+ elif isinstance(loop_body.value, tvm.tir.Sub):
alu_opcode = env.dev.ALU_OPCODE_SUB
lhs = loop_body.value.a
rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.expr.Mul):
+ elif isinstance(loop_body.value, tvm.tir.Mul):
alu_opcode = env.dev.ALU_OPCODE_MUL
lhs = loop_body.value.a
rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.expr.Min):
+ elif isinstance(loop_body.value, tvm.tir.Min):
alu_opcode = env.dev.ALU_OPCODE_MIN
lhs = loop_body.value.a
rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.expr.Max):
+ elif isinstance(loop_body.value, tvm.tir.Max):
alu_opcode = env.dev.ALU_OPCODE_MAX
lhs = loop_body.value.a
rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.expr.Call):
+ elif isinstance(loop_body.value, tvm.tir.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
@@ -857,7 +857,7 @@
else:
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
- elif isinstance(loop_body.value, tvm.expr.Load):
+ elif isinstance(loop_body.value, tvm.tir.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value
rhs = tvm.const(0, "int32")
@@ -871,12 +871,12 @@
# Check if lhs/rhs is immediate
use_imm = False
imm_val = None
- if isinstance(rhs, tvm.expr.IntImm):
+ if isinstance(rhs, tvm.tir.IntImm):
assert lhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
use_imm = True
imm_val = rhs
- if isinstance(lhs, tvm.expr.IntImm):
+ if isinstance(lhs, tvm.tir.IntImm):
assert rhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
use_imm = True