blob: ab78ec0b3bc74c12819f09aefb75fd297e6e3802 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
class ExtractCompare(tvm.testing.CompareBeforeAfter):
transform = relax.transform.ConvertToDataflow()
# functions that will not change
class TestTrivial(ExtractCompare):
@I.ir_module
class Before:
# already a DF block
@R.function
def main(A: R.Tensor, B: R.Tensor):
with R.dataflow():
x = R.add(A, B)
y = R.multiply(x, A)
z = R.add(x, y)
q = R.multiply(y, z)
p = R.add(z, q)
R.output(p)
return p
# too small
@R.function
def func(A: R.Tensor, B: R.Tensor) -> R.Tensor:
x = R.add(A, B)
return x
# too few pure ops between non-dataflow ops
@R.function(pure=False)
def func2(A: R.Tensor, B: R.Tensor) -> R.Tensor:
_ = R.print(format="Hi there!")
y = R.add(A, B)
_ = R.print(y, format="Sum: {}")
x = R.multiply(y, y)
if R.const(False):
_ = R.print(format="True branch")
q = R.add(x, y)
_ = R.print(q, format="Value of q: {}")
w = q
else:
_ = R.print(format="False branch")
q = R.subtract(x, y)
_ = R.print(q, format="Value of q: {}")
w = q
p = R.multiply(w, w)
return p
Expected = Before
class TestBasic(ExtractCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
return v
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
R.output(v)
return v
class TestMultipleBlocks(ExtractCompare):
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
_ = R.print(format="Hi mom!")
a = R.multiply(v, v)
b = R.add(a, a)
c = R.subtract(b, a)
d = R.add(c, c)
return d
@I.ir_module
class Expected:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
R.output(v)
_ = R.print(format="Hi mom!")
with R.dataflow():
a = R.multiply(v, v)
b = R.add(a, a)
c = R.subtract(b, a)
d = R.add(c, c)
R.output(d)
return d
class TestExtractInsideBranches(ExtractCompare):
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
if R.const(True):
q = R.multiply(v, v)
a = R.add(q, q)
b = R.multiply(a, a)
else:
q = R.add(v, v)
a = R.multiply(q, q)
b = R.add(a, a)
c = R.multiply(b, b)
d = R.add(c, c)
e = R.multiply(d, d)
return e
@I.ir_module
class Expected:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
R.output(v)
if R.const(True):
with R.dataflow():
q = R.multiply(v, v)
a = R.add(q, q)
b = R.multiply(a, a)
R.output(b)
# weird but the parser requires this construct
c = b
else:
with R.dataflow():
q = R.add(v, v)
a = R.multiply(q, q)
b = R.add(a, a)
R.output(b)
c = b
with R.dataflow():
d = R.multiply(c, c)
e = R.add(d, d)
f = R.multiply(e, e)
R.output(f)
return f
class TestTreatNonCallAsPure(ExtractCompare):
@I.ir_module
class Before:
@R.function
def tuples_and_const(x: R.Tensor, y: R.Tensor) -> R.Tensor:
t1 = (x, y, x)
t2 = (y, y, x)
c = R.const([1, 2, 3], dtype="int32")
return c
@R.function
def shapes() -> R.Shape:
s1 = R.shape((1, 2, 3))
s2 = R.shape((4, 5, 6))
s3 = R.shape((7, 8, 9))
return s3
@R.function
def prim_values():
x = R.prim_value(1)
y = R.prim_value(2)
z = R.prim_value(3)
return z
@R.function
def main(t: R.Tuple(R.Tensor, R.Tensor)) -> R.Tensor:
x = t[0]
y = t[1]
z = R.add(x, y)
w = R.multiply(z, z)
return w
@I.ir_module
class Expected:
@R.function
def tuples_and_const(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
t1 = (x, y, x)
t2 = (y, y, x)
c = R.const([1, 2, 3], dtype="int32")
R.output(c)
return R.const([1, 2, 3], dtype="int32")
@R.function
def shapes() -> R.Shape:
with R.dataflow():
s1 = R.shape((1, 2, 3))
s2 = R.shape((4, 5, 6))
s3 = R.shape((7, 8, 9))
R.output(s3)
return s3
@R.function
def prim_values():
with R.dataflow():
x = R.prim_value(1)
y = R.prim_value(2)
z = R.prim_value(3)
R.output(z)
return z
@R.function
def main(t: R.Tuple(R.Tensor, R.Tensor)) -> R.Tensor:
with R.dataflow():
x = t[0]
y = t[1]
z = R.add(x, y)
w = R.multiply(z, z)
R.output(w)
return w
class TestImpureInnerFunction(ExtractCompare):
@I.ir_module
class Before:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@R.function(pure=False)
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
_ = R.print(format="oops")
a = R.multiply(v, v)
b = R.add(a, a)
c = R.multiply(a, b)
return c
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
q = inner_func(w, v)
a = R.multiply(q, q)
b = R.add(a, a)
c = R.multiply(b, a)
return c
@I.ir_module
class Expected:
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
@R.function(pure=False)
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
R.output(v)
_ = R.print(format="oops")
with R.dataflow():
a = R.multiply(v, v)
b = R.add(a, a)
c = R.multiply(a, b)
R.output(c)
return c
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
R.output(inner_func, v, w)
q = inner_func(w, v)
with R.dataflow():
a = R.multiply(q, q)
b = R.add(a, a)
c = R.multiply(b, a)
R.output(c)
return c
class TestPureInnerFunction(ExtractCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
@R.function
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
return v
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
q = inner_func(w, v)
return q
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
@R.function
def inner_func(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(x, z)
v = R.add(y, w)
R.output(v)
return v
z = R.add(x, y)
w = R.multiply(z, z)
v = R.divide(w, z)
q = inner_func(w, v)
R.output(q)
return q
class TestImpureExternalFunction(ExtractCompare):
@I.ir_module
class Before:
@R.function(pure=False)
def extra(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
q = R.matmul(z, x)
w = R.nn.relu(q)
_ = R.print(format="Whoa")
return w
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(z, z)
q = Before.extra(z, w)
return q
@I.ir_module
class Expected:
@R.function(pure=False)
def extra(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
q = R.matmul(z, x)
w = R.nn.relu(q)
R.output(w)
_ = R.print(format="Whoa")
return w
@R.function(pure=False)
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, z)
R.output(z, w)
q = Expected.extra(z, w)
return q
class TestPureExternalFunction(ExtractCompare):
@I.ir_module
class Before:
@R.function
def extra(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
q = R.matmul(z, x)
w = R.nn.relu(q)
return w
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
z = R.add(x, y)
w = R.multiply(z, z)
q = Before.extra(z, w)
return q
@I.ir_module
class Expected:
@R.function
def extra(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
q = R.matmul(z, x)
w = R.nn.relu(q)
R.output(w)
return w
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, z)
q = Expected.extra(z, w)
R.output(q)
return q
class TestMergeWithPrecedingDataflowBlock(ExtractCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
R.output(w)
# The single binding of `v = R.add` would normally not be
# enough to make a dataflow block, as `1 < min_size == 2`.
v = R.add(w, x)
return v
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
# However, it occurs just after an existing dataflow
# block, and can be merged into it.
v = R.add(w, x)
R.output(v)
return v
class TestMergeWithNextDataflowBlock(ExtractCompare):
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
# The single binding of `z = R.add` would normally not be
# enough to make a dataflow block, as `1 < min_size == 2`.
z = R.add(x, y)
# However, it occurs just before an existing dataflow
# block, and can be merged into it.
with R.dataflow():
w = R.multiply(z, y)
v = R.add(w, x)
R.output(v)
return v
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor, y: R.Tensor) -> R.Tensor:
with R.dataflow():
z = R.add(x, y)
w = R.multiply(z, y)
v = R.add(w, x)
R.output(v)
return v
class TestPreserveExistingDataflowBlocksAtBeginning(ExtractCompare):
"""Preserve existing DataflowBlocks
This is a regression test. In previous implementations, a
DataflowBlock in the input, without enough bindings to become a
new dataflow block, could be accidentally ommitted.
This test is identical to
`TestPreserveExistingDataflowBlocksAtEnd`, except that the
existing dataflow block is at the beginning of the function.
"""
@I.ir_module
class Before:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This DataflowBlock is below the minimum size for a new
# block, but already exists in the input IRModule.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)
R.print(format="impure_function")
# This sequence is large enough that it may be converted
# to a DataflowBlock.
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
return (A1, B3)
@I.ir_module
class Expected:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This dataflow block should be preserved in the output.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)
R.print(format="impure_function")
with R.dataflow():
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
R.output(B3)
return (A1, B3)
class TestPreserveExistingDataflowBlocksAtEnd(ExtractCompare):
"""Preserve existing DataflowBlocks
This is a regression test. In previous implementations, a
DataflowBlock in the input, without enough bindings to become a
new dataflow block, could be accidentally ommitted.
This test is identical to
`TestPreserveExistingDataflowBlocksAtBeginning`, except that the
existing dataflow block is at the end of the function.
"""
@I.ir_module
class Before:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
# This sequence is large enough that it may be converted
# to a DataflowBlock.
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
R.print(format="impure_function")
# This DataflowBlock is below the minimum size for a new
# block, but already exists in the input IRModule.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)
return (A1, B3)
@I.ir_module
class Expected:
@R.function(pure=False)
def main(A0: R.Tensor, B0: R.Tensor):
with R.dataflow():
B1 = R.add(B0, B0)
B2 = R.add(B1, B1)
B3 = R.add(B2, B2)
R.output(B3)
R.print(format="impure_function")
# This dataflow block should be preserved in the output.
with R.dataflow():
A1 = R.add(A0, A0)
R.output(A1)
return (A1, B3)
if __name__ == "__main__":
tvm.testing.main()