blob: 46d46e239d3a180e3c4a3714a1dde9e8af38196d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relax, tir
from tvm.ir import Op
from tvm.ir.base import assert_structural_equal
from tvm.relax import PyExprMutator, PyExprVisitor
from tvm.relax.expr import (
BindingBlock,
Call,
Constant,
DataflowBlock,
DataflowVar,
Expr,
ExternFunc,
Function,
GlobalVar,
If,
MatchCast,
SeqExpr,
ShapeExpr,
Tuple,
TupleGetItem,
PrimValue,
StringImm,
DataTypeImm,
Var,
VarBinding,
)
from tvm.script import relax as R
import pytest
import tvm.testing
# Shared fixtures: two symbolic int64 dimensions, two float32 tensor variables
# shaped with them, and the BlockBuilder used by the helpers and tests below.
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor([n], "float32"))
y = relax.Var("y", R.Tensor([m, n], "float32"))
bb = relax.BlockBuilder()
@relax.expr_functor.visitor
class BasicVisitor(PyExprVisitor):
    """Default ExprVisitor.

    Overrides no hook; ``basic_check`` runs it as the "no overloading" case.
    """
class ASTLog:
    """Accumulate log lines, each prefixed by the current nesting depth.

    Attributes are plain and public: ``log`` is the list of recorded lines,
    ``indent`` is the string repeated once per nesting level, and ``level``
    is the current nesting depth.
    """

    def __init__(self) -> None:
        self.log = []
        self.indent = "\t"
        self.level = 0

    def push_scope(self):
        """Go one nesting level deeper."""
        self.level = self.level + 1

    def pop_scope(self):
        """Go back up one nesting level."""
        self.level = self.level - 1

    def add(self, s: str):
        """Record ``s`` prefixed with one ``indent`` per current level."""
        prefix = self.indent * self.level
        self.log.append(prefix + s)

    def __str__(self) -> str:
        """Return all recorded lines joined by newlines."""
        return "\n".join(self.log)
@relax.expr_functor.visitor
class ASTPrinter(PyExprVisitor):
    """Record the relax AST as an indented outline. The shape of a node is ignored.

    Every hook appends the node kind to ``self.log``. Hooks for composite
    nodes then visit their children one indentation level deeper.
    """

    def __init__(self) -> None:
        super().__init__()
        self.log = ASTLog()

    def visit_constant_(self, op: Constant) -> None:
        self.log.add("Constant")

    def visit_global_var_(self, op: GlobalVar) -> None:
        self.log.add("GlobalVar")

    def visit_tuple_(self, op: Tuple) -> None:
        out = self.log
        out.add("Tuple")
        out.push_scope()
        for elem in op.fields:
            self.visit_expr(elem)
        out.pop_scope()

    def visit_var_(self, op: Var) -> None:
        self.log.add("Var")

    def visit_dataflow_var_(self, op: DataflowVar) -> None:
        self.log.add("DataflowVar")

    def visit_function_(self, op: Function) -> None:
        out = self.log
        out.add("Function")
        out.push_scope()
        # Parameters are definitions, so they go through the var-def hook.
        for parameter in op.params:
            self.visit_var_def(parameter)
        self.visit_expr(op.body)
        out.pop_scope()

    def visit_call_(self, op: Call) -> None:
        out = self.log
        out.add("Call")
        out.push_scope()
        # Callee first, then the arguments in order.
        self.visit_expr(op.op)
        for argument in op.args:
            self.visit_expr(argument)
        out.pop_scope()

    def visit_if_(self, op: If) -> None:
        out = self.log
        out.add("If")
        out.push_scope()
        self.visit_expr(op.cond)
        self.visit_expr(op.true_branch)
        self.visit_expr(op.false_branch)
        out.pop_scope()

    def visit_op_(self, op: Op) -> None:
        self.log.add("Op")

    def visit_tuple_getitem_(self, op: TupleGetItem) -> None:
        out = self.log
        out.add("TupleGetItem")
        out.push_scope()
        self.visit_expr(op.tuple_value)
        out.pop_scope()

    def visit_prim_value_(self, op: PrimValue) -> None:
        self.log.add("PrimValue")

    def visit_string_imm_(self, op: StringImm) -> None:
        self.log.add("StringImm")

    def visit_data_type_imm_(self, op: DataTypeImm) -> None:
        self.log.add("DataTypeImm")

    def visit_shape_expr_(self, op: ShapeExpr) -> None:
        self.log.add("ShapeExpr")

    def visit_extern_func_(self, op: ExternFunc) -> None:
        self.log.add("ExternFunc")

    def visit_seq_expr_(self, op: SeqExpr) -> None:
        out = self.log
        out.add("SeqExpr")
        out.push_scope()
        for binding_block in op.blocks:
            self.visit_binding_block(binding_block)
        self.visit_expr(op.body)
        out.pop_scope()

    def visit_var_binding_(self, binding: VarBinding) -> None:
        out = self.log
        out.add("VarBinding")
        out.push_scope()
        # Bound value is logged before the variable being defined.
        self.visit_expr(binding.value)
        self.visit_var_def(binding.var)
        out.pop_scope()

    def visit_match_cast_(self, binding: MatchCast) -> None:
        out = self.log
        out.add("MatchCast")
        out.push_scope()
        # Here the defined variable is logged before the value.
        self.visit_var_def(binding.var)
        self.visit_expr(binding.value)
        out.pop_scope()

    def visit_binding_block_(self, block: BindingBlock) -> None:
        out = self.log
        out.add("BindingBlock")
        out.push_scope()
        for item in block.bindings:
            self.visit_binding(item)
        out.pop_scope()

    def visit_dataflow_block_(self, block: DataflowBlock) -> None:
        out = self.log
        out.add("DataflowBlock")
        out.push_scope()
        for item in block.bindings:
            self.visit_binding(item)
        out.pop_scope()

    def visit_var_def_(self, var: Var) -> None:
        self.log.add("VarDef")

    def visit_dataflow_var_def_(self, var: DataflowVar) -> None:
        self.log.add("DataflowVarDef")
@relax.expr_functor.mutator
class BasicMutator(PyExprMutator):
    """Default ExprMutator.

    Overrides no hook; ``basic_check`` runs it as the "no overloading" case
    and expects the result to be structurally equal to its input.
    """
@relax.expr_functor.mutator
class ASTPostPrinterMutator(PyExprMutator):
    """Print relax AST in the post order format.

    Every expression hook first delegates to ``visit_expr_post_order`` and
    only then appends the node kind to ``self.log``, and returns the
    expression produced by that call. The binding, block and var-def hooks
    are written out by hand and log their own kind as well.
    """

    def __init__(self) -> None:
        super().__init__()
        # Flat log; this class never calls push_scope/pop_scope on it.
        self.log = ASTLog()

    def visit_constant_(self, op: Constant) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Constant")
        return op

    def visit_global_var_(self, op: GlobalVar) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("GlobalVar")
        return op

    def visit_tuple_(self, op: Tuple) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Tuple")
        return op

    def visit_var_(self, op: Var) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Var")
        return op

    def visit_dataflow_var_(self, op: DataflowVar) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("DataflowVar")
        return op

    def visit_function_(self, op: Function) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Function")
        return op

    def visit_call_(self, op: Call) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Call")
        return op

    def visit_if_(self, op: If) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("If")
        return op

    def visit_op_(self, op: Op) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("Op")
        return op

    def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("TupleGetItem")
        return op

    def visit_prim_value_(self, op: PrimValue) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("PrimValue")
        return op

    def visit_string_imm_(self, op: StringImm) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("StringImm")
        return op

    def visit_data_type_imm_(self, op: DataTypeImm) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("DataTypeImm")
        return op

    def visit_shape_expr_(self, op: ShapeExpr) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("ShapeExpr")
        return op

    def visit_extern_func_(self, op: ExternFunc) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("ExternFunc")
        return op

    def visit_seq_expr_(self, op: SeqExpr) -> Expr:
        op = self.visit_expr_post_order(op)
        self.log.add("SeqExpr")
        return op

    def visit_var_binding_(self, binding: VarBinding) -> None:
        """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side."""
        # Value is visited before the variable definition.
        new_value = self.visit_expr(binding.value)
        new_var = self.visit_var_def(binding.var)

        self.log.add("VarBinding")
        # Nothing changed: re-emit the original binding unchanged.
        if binding.var.same_as(new_var) and binding.value.same_as(new_value):
            self.builder_.emit_normalized(binding)
            return

        # If with_struct_info returned a different var, register the remap
        # from the original var id to it.
        temp = self.with_struct_info(new_var, new_value.struct_info)
        if not temp.same_as(new_var):
            new_var = temp
            self.set_var_remap(binding.var.vid, new_var)

        self.builder_.emit_normalized(VarBinding(new_var, new_value))

    def visit_match_cast_(self, binding: MatchCast) -> None:
        """Identical with ExprMutator::VisitBinding_(const MatchCastNode* binding) on the C++ side."""
        # Variable definition is visited before the value here.
        new_var = self.visit_var_def(binding.var)
        new_value = self.visit_expr(binding.value)

        # If with_struct_info returned a different var, register the remap
        # from the original var id to it.
        temp = self.with_struct_info(new_var, binding.struct_info)
        if not temp.same_as(new_var):
            new_var = temp
            self.set_var_remap(binding.var.vid, new_var)

        self.log.add("MatchCast")
        # Unlike visit_var_binding_, a new MatchCast is always emitted.
        self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info))

    def visit_binding_block_(self, block: BindingBlock) -> BindingBlock:
        """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side."""
        self.builder_._begin_binding_block()
        for binding in block.bindings:
            self.visit_binding(binding)
        # Logged after all bindings of the block.
        self.log.add("BindingBlock")
        return self.builder_._end_block()

    def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock:
        """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side."""
        self.builder_._begin_dataflow_block()
        for binding in block.bindings:
            self.visit_binding(binding)
        # Logged after all bindings of the block.
        self.log.add("DataflowBlock")
        return self.builder_._end_block()

    def visit_var_def_(self, var: Var) -> Var:
        """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side."""
        self.log.add("VarDef")
        return var

    def visit_dataflow_var_def_(self, var: DataflowVar) -> DataflowVar:
        """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side."""
        self.log.add("DataflowVarDef")
        return var
def basic_check(expr, visitor_str, mutator_str):
    """Run the default and the logging functors over ``expr`` and check them.

    Four functors are applied in turn: ``BasicVisitor``, ``ASTPrinter``,
    ``BasicMutator`` and ``ASTPostPrinterMutator``. Both mutators must return
    something structurally equal to their (normalized) input.

    Parameters
    ----------
    expr : relax.Expr or relax.BindingBlock
        The node to traverse.
    visitor_str : str
        Expected content of the ``ASTPrinter`` log.
    mutator_str : str
        Expected content of the ``ASTPostPrinterMutator`` log.
    """

    def visit(f, expr):
        # Dispatch on the node category; binding blocks are not relax.Expr.
        if isinstance(expr, relax.Expr):
            return f.visit_expr(expr)
        elif isinstance(expr, relax.BindingBlock):
            return f.visit_binding_block(expr)

    # check no overloading case
    basic_visitor = BasicVisitor()
    visit(basic_visitor, expr)

    # check the output log
    log_visitor = ASTPrinter()
    visit(log_visitor, expr)
    assert str(log_visitor.log) == visitor_str

    # Normalize once and share the result between both mutator checks; the
    # original repeated this identical guarded call before each of them.
    # skip normalize GlobalVar since it requires context IRModule to get the struct_info_
    if isinstance(expr, relax.Expr) and not isinstance(expr, relax.GlobalVar):
        expr = bb.normalize(expr)

    # check no overloading case
    basic_mutator = BasicMutator()
    assert_structural_equal(visit(basic_mutator, expr), expr)

    # check the output log and return value
    post_log_mutator = ASTPostPrinterMutator()
    assert_structural_equal(visit(post_log_mutator, expr), expr)
    assert str(post_log_mutator.log) == mutator_str
def test_constant():
    """A constant yields the single entry "Constant" in both logs."""
    expected = "Constant"
    basic_check(relax.const(1.0), expected, expected)
def test_var():
    """The module-level variable ``x`` yields the single entry "Var" in both logs."""
    expected = "Var"
    basic_check(x, expected, expected)
def test_dataflow_var():
    """A dataflow variable yields the single entry "DataflowVar" in both logs."""
    dataflow_var = relax.DataflowVar("lv", R.Tensor([n], "float32"))
    expected = "DataflowVar"
    basic_check(dataflow_var, expected, expected)
def test_tuple():
    """A two-field tuple: parent first for the visitor, last for the mutator."""
    pair = relax.Tuple([x, y])
    visitor_log = "\n".join(["Tuple", "\tVar", "\tVar"])
    mutator_log = "\n".join(["Var", "Var", "Tuple"])
    basic_check(pair, visitor_log, mutator_log)
def test_global_var():
    """A global variable yields the single entry "GlobalVar" in both logs."""
    global_var = relax.GlobalVar("gv")
    expected = "GlobalVar"
    basic_check(global_var, expected, expected)
def test_seq_expr():
    """A SeqExpr with one block binding ``x`` to a constant and ``x`` as body."""
    binding = relax.VarBinding(x, relax.const(1))
    block = relax.BindingBlock([binding])
    seq_expr = relax.SeqExpr([block], x)

    visitor_lines = [
        "SeqExpr",
        "\tBindingBlock",
        "\t\tVarBinding",
        "\t\t\tConstant",
        "\t\t\tVarDef",
        "\tVar",
    ]
    mutator_lines = ["Constant", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"]
    basic_check(seq_expr, "\n".join(visitor_lines), "\n".join(mutator_lines))
def test_shape_expr():
    """A shape expression yields the single entry "ShapeExpr" in both logs."""
    shape = relax.ShapeExpr([m, n])
    expected = "ShapeExpr"
    basic_check(shape, expected, expected)
def test_call():
    """A call to ``relax.op.add`` on the two module-level variables."""
    call_node = relax.op.add(x, y)
    visitor_log = "\n".join(["Call", "\tOp", "\tVar", "\tVar"])
    mutator_log = "\n".join(["Op", "Var", "Var", "ShapeExpr", "Call"])
    basic_check(call_node, visitor_log, mutator_log)
def test_if():
    """An If using ``x`` as condition and as both branches.

    The expected logs list each branch as a SeqExpr containing the variable.
    """
    if_node = relax.If(x, x, x)
    visitor_log = "\n".join(["If", "\tVar", "\tSeqExpr", "\t\tVar", "\tSeqExpr", "\t\tVar"])
    mutator_log = "\n".join(["Var", "Var", "SeqExpr", "Var", "SeqExpr", "If"])
    basic_check(if_node, visitor_log, mutator_log)
def test_tuple_getitem():
    """Index 0 of a two-field tuple built from the module-level variables."""
    source_tuple = relax.Tuple([x, y])
    tuple_getitem_node = relax.TupleGetItem(source_tuple, 0)
    visitor_log = "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"])
    mutator_log = "\n".join(["Var", "Var", "Tuple", "TupleGetItem"])
    basic_check(tuple_getitem_node, visitor_log, mutator_log)
def test_binding_block():
    """A binding block holding one emitted ``add`` call and one match_cast.

    The block is built through the module-level ``bb``. The variables
    returned by ``emit`` and ``match_cast`` are not needed, so they are not
    kept; only the finished block is checked.
    """
    bb._begin_binding_block()
    bb.emit(relax.op.add(x, y))
    bb.match_cast(y, R.Tensor([m, n], "float32"))
    b0 = bb._end_block()
    basic_check(
        b0,
        "\n".join(
            [
                "BindingBlock",
                "\tVarBinding",
                "\t\tCall",
                "\t\t\tOp",
                "\t\t\tVar",
                "\t\t\tVar",
                "\t\tVarDef",
                "\tMatchCast",
                "\t\tVarDef",
                "\t\tVar",
            ]
        ),
        "\n".join(
            [
                "Op",
                "Var",
                "Var",
                "Call",
                "ShapeExpr",
                "VarDef",
                "VarBinding",
                "Var",
                "ShapeExpr",
                "ShapeExpr",
                "VarDef",
                "MatchCast",
                "BindingBlock",
            ]
        ),
    )
def test_dataflow_block():
    """A dataflow block holding one emitted ``add`` call and one match_cast.

    The block is built through the module-level ``bb``. The variables
    returned by ``emit`` and ``match_cast`` are not needed, so they are not
    kept; only the finished block is checked.
    """
    bb._begin_dataflow_block()
    bb.emit(relax.op.add(x, y))
    bb.match_cast(y, R.Tensor([m, n], "float32"))
    b0 = bb._end_block()
    basic_check(
        b0,
        "\n".join(
            [
                "DataflowBlock",
                "\tVarBinding",
                "\t\tCall",
                "\t\t\tOp",
                "\t\t\tVar",
                "\t\t\tVar",
                "\t\tDataflowVarDef",
                "\tMatchCast",
                "\t\tDataflowVarDef",
                "\t\tVar",
            ]
        ),
        "\n".join(
            [
                "Op",
                "Var",
                "Var",
                "Call",
                "ShapeExpr",
                "DataflowVarDef",
                "VarBinding",
                "Var",
                "ShapeExpr",
                "ShapeExpr",
                "DataflowVarDef",
                "MatchCast",
                "DataflowBlock",
            ]
        ),
    )
def test_function():
    """A function taking ``x`` whose body rebinds ``x`` to a constant."""
    binding = relax.VarBinding(x, relax.const(1))
    block = relax.BindingBlock([binding])
    body = relax.SeqExpr([block], x)
    func = relax.Function([x], body, R.Tensor([n], "float32"))

    visitor_lines = [
        "Function",
        "\tVarDef",
        "\tSeqExpr",
        "\t\tBindingBlock",
        "\t\t\tVarBinding",
        "\t\t\t\tConstant",
        "\t\t\t\tVarDef",
        "\t\tVar",
    ]
    mutator_lines = [
        "VarDef",
        "Constant",
        "VarDef",
        "VarBinding",
        "BindingBlock",
        "Var",
        "SeqExpr",
        "Function",
    ]
    basic_check(func, "\n".join(visitor_lines), "\n".join(mutator_lines))
def test_extern_func():
    """An extern function yields the single entry "ExternFunc" in both logs."""
    extern = relax.ExternFunc("f")
    expected = "ExternFunc"
    basic_check(extern, expected, expected)
def test_inherit():
    """A decorated leaf may inherit hooks from an undecorated base class.

    ``LeafVisitor`` overrides ``visit_call_`` and takes ``visit_var_`` and
    ``visit_op_`` from ``InternalVisitor``.
    """

    # The internal class is not instantiated.
    class InternalVisitor(PyExprVisitor):
        def __init__(self) -> None:
            super().__init__()
            self.log = ASTLog()

        def visit_call_(self, op: Call) -> None:
            out = self.log
            out.add("InternalCall")
            out.push_scope()
            self.visit_expr(op.op)
            for argument in op.args:
                self.visit_expr(argument)
            out.pop_scope()

        def visit_var_(self, op: Var) -> None:
            self.log.add("Var")

        def visit_op_(self, op: Op) -> None:
            self.log.add("Op")

    @relax.expr_functor.visitor
    class LeafVisitor(InternalVisitor):
        def visit_call_(self, op: Call) -> None:
            out = self.log
            out.add("LeafCall")
            out.push_scope()
            self.visit_expr(op.op)
            for argument in op.args:
                self.visit_expr(argument)
            out.pop_scope()

    call_node = relax.op.add(x, y)
    leaf = LeafVisitor()
    leaf.visit_expr(call_node)
    expected = "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"])
    assert str(leaf.log) == expected
def test_inherit_with_cls():
    """A decorated visitor can be subclassed through its ``_cls`` attribute.

    Both the decorated base and the decorated leaf are instantiated, and each
    must log its own ``visit_call_`` label.
    """

    # The decorator converts `InternalVisitor` to a wrapper class.
    @relax.expr_functor.visitor
    class InternalVisitor(PyExprVisitor):
        def __init__(self) -> None:
            super().__init__()
            self.log = ASTLog()

        def visit_call_(self, op: Call) -> None:
            out = self.log
            out.add("InternalCall")
            out.push_scope()
            self.visit_expr(op.op)
            for argument in op.args:
                self.visit_expr(argument)
            out.pop_scope()

        def visit_var_(self, op: Var) -> None:
            self.log.add("Var")

        def visit_op_(self, op: Op) -> None:
            self.log.add("Op")

    # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined.
    @relax.expr_functor.visitor
    class LeafVisitor(InternalVisitor._cls):
        def visit_call_(self, op: Call) -> None:
            out = self.log
            out.add("LeafCall")
            out.push_scope()
            self.visit_expr(op.op)
            for argument in op.args:
                self.visit_expr(argument)
            out.pop_scope()

    call_node = relax.op.add(x, y)

    internal = InternalVisitor()
    internal.visit_expr(call_node)
    internal_expected = "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"])
    assert str(internal.log) == internal_expected

    leaf = LeafVisitor()
    leaf.visit_expr(call_node)
    leaf_expected = "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"])
    assert str(leaf.log) == leaf_expected
def test_wrong_inherit():
    """Decorating a class that inherits directly from a decorated visitor fails.

    The subclass definition sits inside the ``pytest.raises`` block because
    the ``TypeError`` is expected while the class is being decorated.
    """

    @relax.expr_functor.visitor
    class InternalVisitor(PyExprVisitor):
        def visit_call_(self, op: Call) -> None:
            pass

    expected_message = "Inheritance from a decorated object `LeafVisitor` is not allowed. Please inherit from `LeafVisitor._cls`."
    with pytest.raises(TypeError, match=expected_message):

        @relax.expr_functor.visitor
        class LeafVisitor(InternalVisitor):
            def visit_call_(self, op: Call) -> None:
                pass
# Fixture traversed by test_call_visitor_super and test_call_mutator_super:
# one binding outside any dataflow block (`lv`), then a dataflow block that
# binds `gv` to `lv` and marks it as output; the function returns `gv`.
@R.function
def dummy(x: R.Tensor((10, 10))):
    lv = R.add(x, R.const(1))
    with R.dataflow():
        gv = lv
        R.output(gv)
    return gv
def test_call_visitor_super():
    """Visitor hooks may chain to their parent implementation via ``super()``.

    ``InternalVisitor`` logs a label and then defers to ``PyExprVisitor`` for
    most hooks; ``LeafVisitor`` adds one more ``visit_call_`` layer on top.
    The log is flat because no hook changes the scope level.
    """

    @relax.expr_functor.visitor
    class InternalVisitor(PyExprVisitor):
        def __init__(self) -> None:
            super().__init__()
            self.log = ASTLog()

        def visit_binding_block_(self, block: relax.BindingBlock) -> None:
            self.log.add("BindingBlock")
            super().visit_binding_block_(block)

        def visit_dataflow_block_(self, block: DataflowBlock) -> None:
            self.log.add("DataflowBlock")
            super().visit_dataflow_block_(block)

        def visit_var_binding_(self, binding: relax.VarBinding) -> None:
            self.log.add("VarBinding")
            super().visit_var_binding_(binding)

        def visit_call_(self, op: Call) -> None:
            self.log.add("InternalCall")
            super().visit_call_(op)  # call PyExprVisitor.visit_call_

        def visit_var_def_(self, var: Var) -> None:
            self.log.add("VarDef")
            super().visit_var_def_(var)

        def visit_dataflow_var_def_(self, var: DataflowVar) -> None:
            self.log.add("DataflowVarDef")
            super().visit_dataflow_var_def_(var)

        # The two leaf hooks only log; they do not chain to the parent.
        def visit_var_(self, op: Var) -> None:
            self.log.add("Var")

        def visit_op_(self, op: Op) -> None:
            self.log.add("Op")

    @relax.expr_functor.visitor
    class LeafVisitor(InternalVisitor._cls):
        def visit_call_(self, op: Call) -> None:
            self.log.add("LeafCall")
            super().visit_call_(op)  # call InternalVisitor.visit_call_

    call_node = relax.op.add(x, y)

    iv = InternalVisitor()
    iv.visit_expr(call_node)
    assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"])

    lv = LeafVisitor()
    lv.visit_expr(call_node)
    assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"])

    # Fresh instance so the log starts empty for the whole-function traversal.
    lv = LeafVisitor()
    lv.visit_expr(dummy)
    assert str(lv.log) == "\n".join(
        [
            "VarDef",
            "BindingBlock",
            "VarBinding",
            "LeafCall",
            "InternalCall",
            "Op",
            "Var",
            "VarDef",
            "DataflowBlock",
            "VarBinding",
            "Var",
            "VarDef",
            "Var",
        ]
    )
def test_call_mutator_super():
    """Mutator hooks may chain to their parent implementation via ``super()``.

    ``InternalMutator`` logs a label and then returns whatever the
    ``PyExprMutator`` hook returns; ``LeafMutator`` adds one more
    ``visit_call_`` layer on top. The log is flat because no hook changes the
    scope level.
    """

    @relax.expr_functor.mutator
    class InternalMutator(PyExprMutator):
        def __init__(self) -> None:
            super().__init__()
            self.log = ASTLog()

        def visit_binding_block_(self, block: relax.BindingBlock) -> BindingBlock:
            self.log.add("BindingBlock")
            return super().visit_binding_block_(block)

        def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock:
            self.log.add("DataflowBlock")
            return super().visit_dataflow_block_(block)

        def visit_var_binding_(self, binding: relax.VarBinding) -> None:
            self.log.add("VarBinding")
            return super().visit_var_binding_(binding)

        def visit_call_(self, op: Call) -> Expr:
            self.log.add("InternalCall")
            return super().visit_call_(op)  # call PyExprMutator.visit_call_

        def visit_var_def_(self, var: Var) -> Var:
            self.log.add("VarDef")
            return super().visit_var_def_(var)

        def visit_dataflow_var_def_(self, var: DataflowVar) -> DataflowVar:
            self.log.add("DataflowVarDef")
            return super().visit_dataflow_var_def_(var)

        def visit_var_(self, op: Var) -> Expr:
            self.log.add("Var")
            return super().visit_var_(op)  # call PyExprMutator.visit_var_

        def visit_op_(self, op: Op) -> Expr:
            self.log.add("Op")
            return super().visit_op_(op)  # call PyExprMutator.visit_op_

    @relax.expr_functor.mutator
    class LeafMutator(InternalMutator._cls):
        def visit_call_(self, op: Call) -> Expr:
            self.log.add("LeafCall")
            return super().visit_call_(op)  # call InternalMutator.visit_call_

    call_node = relax.op.add(x, y)

    im = InternalMutator()
    im.visit_expr(call_node)
    assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"])

    lm = LeafMutator()
    lm.visit_expr(call_node)
    assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"])

    # Fresh instance so the log starts empty for the whole-function traversal.
    lm = LeafMutator()
    lm.visit_expr(dummy)
    assert str(lm.log) == "\n".join(
        [
            "VarDef",
            "BindingBlock",
            "VarBinding",
            "LeafCall",
            "InternalCall",
            "Op",
            "Var",
            "VarDef",
            "DataflowBlock",
            "VarBinding",
            "Var",
            "VarDef",
            "Var",
        ]
    )
def test_function_parameter_mutation():
    """Replacing a parameter in ``visit_var_def_`` rewrites the function.

    ``ParamMutator`` swaps any variable whose name is a key of
    ``shape_replacements`` for a new variable with the mapped shape and the
    original dtype. Mutating ``before`` must give something structurally
    equal to ``expected``.
    """

    @relax.expr_functor.mutator
    class ParamMutator(PyExprMutator):
        def __init__(self, shape_replacements):
            super().__init__()
            self.shape_replacements = shape_replacements

        def visit_var_def_(self, var):
            # Variables without a replacement entry pass through untouched.
            if var.name_hint not in self.shape_replacements:
                return var
            replacement_shape = self.shape_replacements[var.name_hint]
            replacement_sinfo = relax.TensorStructInfo(
                replacement_shape, dtype=var.struct_info.dtype
            )
            return relax.Var(f"{var.name_hint}_with_new_shape", replacement_sinfo)

    @R.function(private=True)
    def before(
        A: R.Tensor((16, 32), "float32"), B: R.Tensor((32, 64), "float32")
    ) -> R.Tensor((16, 64), "float32"):
        return R.matmul(A, B)

    @R.function(private=True)
    def expected(
        A: R.Tensor((1, 32), "float32"), B: R.Tensor((32, 64), "float32")
    ) -> R.Tensor((1, 64), "float32"):
        return R.matmul(A, B)

    mutator = ParamMutator({"A": (1, 32)})
    after = mutator.visit_expr(before)
    tvm.ir.assert_structural_equal(expected, after)
# When executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()