blob: e8463027dbe5cfee7bc0071903a3e39555dce2a1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import tir
from tvm.tir import (
EQ,
LT,
Add,
Cast,
Evaluate,
FloatImm,
For,
IfThenElse,
IntImm,
Max,
Min,
Mul,
PyStmtExprMutator,
PyStmtExprVisitor,
StringImm,
Sub,
Var,
)
class ASTLog:
"""Helper class to log AST"""
def __init__(self) -> None:
self.log = []
self.indent = "\t"
self.level = 0
def push_scope(self):
self.level += 1
def pop_scope(self):
self.level -= 1
def add(self, s: str):
self.log.append(self.indent * self.level + s)
def __str__(self) -> str:
return "\n".join(self.log)
@tir.functor.visitor
class ASTPrinter(PyStmtExprVisitor):
"""Print tir AST in structured format. The shape of Node is ignored."""
def __init__(self) -> None:
super().__init__()
self.log = ASTLog()
def visit_var_(self, op: Var) -> None:
self.log.add("Stmt: Var")
super().visit_var_(op)
def visit_add_(self, op: Add) -> None:
self.log.add("Stmt: Add")
super().visit_add_(op)
@tir.functor.visitor
class SimpleExprCounter(PyStmtExprVisitor):
"""Count expressions without recursion"""
def __init__(self):
super().__init__()
self.var_count = 0
self.add_count = 0
self.mul_count = 0
def visit_var_(self, op: Var):
self.var_count += 1
# Don't recursively visit children to avoid infinite recursion
def visit_add_(self, op: Add):
self.add_count += 1
# Visit children manually
super().visit_add_(op)
def visit_mul_(self, op: Mul):
self.mul_count += 1
# Visit children manually
super().visit_mul_(op)
@tir.functor.mutator
class VariableReplacer(PyStmtExprMutator):
"""Replace variables with constants"""
def __init__(self, replacements):
super().__init__()
self.replacements = replacements
def visit_var_(self, op: Var):
if op.name in self.replacements:
return IntImm("int32", self.replacements[op.name])
return op
@tir.functor.mutator
class AddToSubMutator(PyStmtExprMutator):
"""Convert Add operations to Sub operations"""
def visit_add_(self, op: Add):
# First mutate the operands
a = self.visit_expr(op.a)
b = self.visit_expr(op.b)
# Convert Add to Sub
return Sub(a, b)
@tir.functor.visitor
class SimpleStmtCounter(PyStmtExprVisitor):
"""Count statements without recursion"""
def __init__(self):
super().__init__()
self.for_count = 0
self.if_count = 0
self.evaluate_count = 0
def visit_for_(self, op: For):
self.for_count += 1
super().visit_for_(op)
def visit_if_then_else_(self, op: IfThenElse):
self.if_count += 1
super().visit_if_then_else_(op)
def visit_evaluate_(self, op: Evaluate):
self.evaluate_count += 1
super().visit_evaluate_(op)
@tir.functor.mutator
class ForLoopUnroller(PyStmtExprMutator):
"""Simple loop unroller for demonstration"""
def __init__(self, unroll_factor=2):
super().__init__()
self.unroll_factor = unroll_factor
def visit_for_(self, op: For):
# For demonstration, just return the original for now
# In a real implementation, we would unroll small loops
return super().visit_for_(op)
@tir.functor.visitor
class SimpleStmtExprVisitor(PyStmtExprVisitor):
"""Visitor that handles both statements and expressions"""
def __init__(self):
super().__init__()
self.expr_count = 0
self.stmt_count = 0
self.var_names = set()
def visit_var_(self, op: Var):
self.var_names.add(op.name)
self.expr_count += 1
def visit_evaluate_(self, op: Evaluate):
self.stmt_count += 1
# Visit the expression
self.visit_expr(op.value)
@tir.functor.mutator
class ComplexMutator(PyStmtExprMutator):
"""Mutator that handles both statements and expressions"""
def __init__(self):
super().__init__()
self.modifications = 0
def visit_add_(self, op: Add):
self.modifications += 1
# Convert a + b to a * 2 + b for demonstration
a = self.visit_expr(op.a)
b = self.visit_expr(op.b)
return Add(Mul(a, IntImm("int32", 2)), b)
def test_basic_visitor():
"""Test the basic AST printer visitor"""
expr = Add(Var("x", dtype="int32"), Var("y", dtype="int32"))
printer = ASTPrinter()
printer.visit_expr(expr)
assert str(printer.log) == "\n".join(["Stmt: Add", "Stmt: Var", "Stmt: Var"])
def test_simple_expr_counter():
"""Test simple expression counting visitor"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
# Create simple expression: x + y
expr = Add(x, y)
counter = SimpleExprCounter()
counter.visit_expr(expr)
assert counter.var_count == 2 # x and y
assert counter.add_count == 1 # one add
def test_variable_replacer():
"""Test expression mutator that replaces variables"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
expr = Add(x, Mul(y, IntImm("int32", 3)))
replacer = VariableReplacer({"x": 10, "y": 5})
result = replacer.visit_expr(expr)
# Should be Add(IntImm(10), Mul(IntImm(5), IntImm(3)))
assert isinstance(result, Add)
assert isinstance(result.a, IntImm)
assert result.a.value == 10
assert isinstance(result.b, Mul)
assert isinstance(result.b.a, IntImm)
assert result.b.a.value == 5
def test_add_to_sub_mutator():
"""Test mutator that converts Add to Sub"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
expr = Add(x, y)
mutator = AddToSubMutator()
result = mutator.visit_expr(expr)
assert isinstance(result, Sub)
assert isinstance(result.a, Var)
assert isinstance(result.b, Var)
assert result.a.name == "x"
assert result.b.name == "y"
def test_simple_stmt_counter():
"""Test statement visitor that counts statements"""
i = Var("i", dtype="int32")
# Create a simple for loop
loop_body = Evaluate(IntImm("int32", 0))
for_stmt = For(i, IntImm("int32", 0), IntImm("int32", 10), tir.ForKind.SERIAL, loop_body)
counter = SimpleStmtCounter()
counter.visit_stmt(for_stmt)
assert counter.for_count == 1 # One for loop
assert counter.evaluate_count == 1 # One evaluate in the body
def test_if_then_else_visitor():
"""Test visitor with if-then-else statements"""
x = Var("x", dtype="int32")
condition = EQ(x, IntImm("int32", 0))
then_stmt = Evaluate(IntImm("int32", 1))
else_stmt = Evaluate(IntImm("int32", 2))
if_stmt = IfThenElse(condition, then_stmt, else_stmt)
counter = SimpleStmtCounter()
counter.visit_stmt(if_stmt)
assert counter.if_count == 1
assert counter.for_count == 0
def test_simple_stmt_expr_visitor():
"""Test stmt_expr_visitor with mixed statements and expressions"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
# Create an evaluate statement with an expression
expr = Add(x, y)
stmt = Evaluate(expr)
visitor = SimpleStmtExprVisitor()
visitor.visit_stmt(stmt)
assert visitor.stmt_count == 1 # One Evaluate statement
assert visitor.expr_count == 2 # Two variables
assert "x" in visitor.var_names
assert "y" in visitor.var_names
def test_complex_mutator():
"""Test stmt_expr_mutator"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
# Expression with Add operations
expr = Add(x, y)
stmt = Evaluate(expr)
mutator = ComplexMutator()
result = mutator.visit_stmt(stmt)
print(type(mutator))
assert mutator.modifications == 1 # One Add operation modified
assert isinstance(result, Evaluate)
# Check that the expression was modified
modified_expr = result.value
assert isinstance(modified_expr, Add)
assert isinstance(modified_expr.a, Mul) # First operand should be multiplied by 2
def test_different_expr_types():
"""Test visitor with various expression types"""
x = Var("x", dtype="int32")
# Test different expression types individually
exprs = [
IntImm("int32", 42),
FloatImm("float32", 3.14),
StringImm("hello"),
Cast("float32", x),
Min(x, IntImm("int32", 10)),
Max(x, IntImm("int32", 0)),
LT(x, IntImm("int32", 5)),
]
# Just test that we can create and visit each type
counter = SimpleExprCounter()
for expr in exprs:
try:
counter.visit_expr(expr)
except Exception as e:
# Some expressions might not be supported, that's ok
pass
def test_decorator_functionality():
"""Test that decorators work correctly"""
# Test that decorated classes are properly wrapped
visitor = SimpleExprCounter()
assert hasattr(visitor, "_outer") # Should have the wrapper functionality
mutator = VariableReplacer({})
assert hasattr(mutator, "_outer")
def test_empty_expressions():
"""Test handling of simple expressions"""
counter = SimpleExprCounter()
# Test with just a variable
x = Var("x", dtype="int32")
counter.visit_expr(x)
assert counter.var_count == 1
# Test with just a constant
counter = SimpleExprCounter()
const = IntImm("int32", 5)
counter.visit_expr(const)
# Constants don't increase var_count
assert counter.var_count == 0
def test_stmt_mutator():
"""Test basic statement mutator functionality"""
x = Var("x", dtype="int32")
stmt = Evaluate(Add(x, IntImm("int32", 1)))
unroller = ForLoopUnroller()
result = unroller.visit_stmt(stmt)
# Should return the same statement (no actual unrolling implemented)
assert isinstance(result, Evaluate)
def test_nested_expressions():
"""Test with nested expressions"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
z = Var("z", dtype="int32")
# Create nested expression: (x + y) * z
inner_add = Add(x, y)
expr = Mul(inner_add, z)
counter = SimpleExprCounter()
counter.visit_expr(expr)
assert counter.var_count == 3 # x, y, z
assert counter.add_count == 1 # one add
assert counter.mul_count == 1 # one mul
def test_simple_mutations():
"""Test simple expression mutations"""
x = Var("x", dtype="int32")
y = Var("y", dtype="int32")
# Test multiple replacements
expr = Add(x, y)
replacer = VariableReplacer({"x": 1, "y": 2})
result = replacer.visit_expr(expr)
assert isinstance(result, Add)
assert isinstance(result.a, IntImm)
assert isinstance(result.b, IntImm)
assert result.a.value == 1
assert result.b.value == 2
if __name__ == "__main__":
test_basic_visitor()
tvm.testing.main()