# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R, tir as T


@pytest.mark.parametrize("consume_params", [True, False])
def test_basic(consume_params):
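    """Weight transformations expressed with R.call_tir are lifted

    The `transform_layout_IOHW_to_OIHW` call depends only on the model
    weight `w1`, so it is moved into a new `main_transform_params`
    function.  With `consume_params` enabled, the lifted function also
    calls `vm.builtin.tuple_reset_item` after each parameter's last
    use, allowing the input weights to be released early.
    """
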
@tvm.script.ir_module
class Before:
@T.prim_func
def transform_layout_IOHW_to_OIHW(
w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
) -> None:
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
out[o, i, h, w] = w1[i, o, h, w]
@R.function
def main(
x: R.Tensor((1, 3, 224, 224), "float32"),
w1: R.Tensor((3, 16, 3, 3), "float32"),
w2: R.Tensor((16, 16, 3, 3), "float32"),
) -> R.Tensor((1, 16, 224, 224), "float32"):
R.func_attr({"num_input": 1})
cls = Before
with R.dataflow():
w1_transformed = R.call_tir(
cls.transform_layout_IOHW_to_OIHW, w1, R.Tensor((16, 3, 3, 3), "float32")
)
conv1 = R.nn.conv2d(
x, w1_transformed, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW"
)
conv2 = R.nn.conv2d(
conv1, w2, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW"
)
R.output(conv2)
return conv2
@tvm.script.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((1, 3, 224, 224), dtype="float32"),
w2: R.Tensor((16, 16, 3, 3), dtype="float32"),
w1_transformed: R.Tensor((16, 3, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
x,
w1_transformed,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
conv1,
w2,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
R.output(conv2)
return conv2
@T.prim_func
def transform_layout_IOHW_to_OIHW(
w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
):
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]
@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
):
R.func_attr({"num_input": 0})
cls = Expected
with R.dataflow():
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
gv: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 3, 3, 3), dtype="float32"),
) = (lv, lv2)
R.output(gv)
return gv
@tvm.script.ir_module
class ExpectedConsumeParams:
@R.function
def main(
x: R.Tensor((1, 3, 224, 224), dtype="float32"),
w2: R.Tensor((16, 16, 3, 3), dtype="float32"),
w1_transformed: R.Tensor((16, 3, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
x,
w1_transformed,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
conv1,
w2,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
R.output(conv2)
return conv2
@T.prim_func
def transform_layout_IOHW_to_OIHW(
w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
):
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
with T.block("layout_transform"):
o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w1[i, o, h, w])
T.writes(out[o, i, h, w])
out[o, i, h, w] = w1[i, o, h, w]
@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
):
R.func_attr({"num_input": 0})
cls = ExpectedConsumeParams
with R.dataflow():
lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
_1: R.Tuple = R.call_pure_packed(
"vm.builtin.tuple_reset_item",
params,
R.prim_value(T.int32(0)),
sinfo_args=(R.Tuple,),
)
lv2 = R.call_tir(
cls.transform_layout_IOHW_to_OIHW,
(lv1,),
out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
)
lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
_2: R.Tuple = R.call_pure_packed(
"vm.builtin.tuple_reset_item",
params,
R.prim_value(T.int32(1)),
sinfo_args=(R.Tuple,),
)
gv: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 3, 3, 3), dtype="float32"),
) = (lv, lv2)
R.output(gv)
return gv
mod = Before
expected = Expected if not consume_params else ExpectedConsumeParams
with tvm.transform.PassContext(
config={"relax.lift_transform_params.consume_params": consume_params}
):
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, expected)
names_after = [param.name_hint for param in after["main"].params]
names_expected = [param.name_hint for param in expected["main"].params]
assert names_after == names_expected


def test_tuple():
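    """Tuple construction and indexing of weights are lifted

    The bindings `l0` through `l3` depend only on the weight `w1`, so
    the equivalent accesses are folded into `main_transform_params`,
    which returns `w1` twice.
    """
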
@tvm.script.ir_module
class Before:
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), "float32"), w1: R.Tensor((16, 16, 3, 3), "float32")
) -> R.Tensor((1, 16, 224, 224), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
l0 = (w1,)
l1 = (l0,)
l2 = l1[0]
l3 = l2[0]
conv1 = R.nn.conv2d(x, l3, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
conv2 = R.nn.conv2d(
conv1, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW"
)
R.output(conv2)
return conv2
@tvm.script.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), dtype="float32"),
param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
conv1: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
x,
param1,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
conv2: R.Tensor((1, 16, 224, 224), dtype="float32") = R.nn.conv2d(
conv1,
param0,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
R.output(conv2)
return conv2
@R.function
def main_transform_params(
params: R.Tuple(R.Tensor((16, 16, 3, 3), dtype="float32")),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
):
R.func_attr({"num_input": 0})
with R.dataflow():
l3 = params[0]
w1 = params[0]
gv = (w1, l3)
R.output(gv)
return gv
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_condition():
    """Test that bindings inside a conditional statement cannot be lifted"""

@tvm.script.ir_module
class Before:
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), "float32"),
w1: R.Tensor((16, 16, 3, 3), "float32"),
w2: R.Tensor((16, 16, 3, 3), "float32"),
cond: R.Tensor((), "bool"),
) -> R.Tensor((1, 16, 224, 224), "float32"):
R.func_attr({"num_input": 1})
if cond:
w = w1
else:
w = w2
with R.dataflow():
conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
R.output(conv1)
return conv1
@tvm.script.ir_module
class Expected:
@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((), dtype="bool"),
),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
R.Tensor((), dtype="bool"),
):
R.func_attr({"num_input": 0})
return params
@R.function
def main(
x: R.Tensor((1, 16, 224, 224), "float32"),
param0: R.Tensor((16, 16, 3, 3), dtype="float32"),
param1: R.Tensor((16, 16, 3, 3), dtype="float32"),
param2: R.Tensor((), dtype="bool"),
) -> R.Tensor((1, 16, 224, 224), "float32"):
R.func_attr({"num_input": 1})
if param2:
w: R.Tensor((16, 16, 3, 3), dtype="float32") = param0
else:
w: R.Tensor((16, 16, 3, 3), dtype="float32") = param1
with R.dataflow():
conv1 = R.nn.conv2d(x, w, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
R.output(conv1)
return conv1
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_multiple_functions():
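    """Without shared_transform, each function is lifted independently

    `func1` and `func2` each have a "num_input" attribute, so each
    receives its own `<name>_transform_params` function.  `func3` has
    no "num_input" attribute and is left unchanged.
    """
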
@tvm.script.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1, [1, 0])
y = R.matmul(x, w1_t)
R.output(y)
return y
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((128, 256), "float32"),
) -> R.Tensor((256, 128), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1, [1, 0])
y = R.matmul(x, w1_t)
R.output(y)
return y
@R.function
def func3(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
with R.dataflow():
w1_t = R.permute_dims(w1, [1, 0])
y = R.matmul(x, w1_t)
R.output(y)
return y
@tvm.script.ir_module
class Expected:
@R.function
def func1(
x: R.Tensor((256, 256), dtype="float32"),
param0: R.Tensor((256, 256), dtype="float32"),
) -> R.Tensor((256, 256), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, param0, out_dtype="void")
R.output(y)
return y
@R.function
def func1_transform_params(
params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((256, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@R.function
def func2(
x: R.Tensor((256, 256), dtype="float32"),
param0: R.Tensor((256, 128), dtype="float32"),
) -> R.Tensor((256, 128), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y: R.Tensor((256, 128), dtype="float32") = R.matmul(x, param0, out_dtype="void")
R.output(y)
return y
@R.function
def func2_transform_params(
params: R.Tuple(R.Tensor((128, 256), dtype="float32")),
) -> R.Tuple(R.Tensor((256, 128), dtype="float32")):
R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((128, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 128), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
gv: R.Tuple(R.Tensor((256, 128), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@R.function
def func3(
x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32")
) -> R.Tensor((256, 256), dtype="float32"):
with R.dataflow():
w1_t: R.Tensor((256, 256), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_t, out_dtype="void")
R.output(y)
return y
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_share_identical_transform_across_multiple_functions():
    """Like test_multiple_functions, but producing a single transform_params

    `func1` and `func2` compute the same values `w1_t` and `w2_t`.

    When `shared_transform=True`, all eligible publicly-exposed
    functions must be usable with the same shared transform.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w1 = params[0]
w1_t = R.permute_dims(w1)
w2 = params[1]
w2_t = R.permute_dims(w2)
output = (w1_t, w2_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_incompatible_weights_in_shared_transform_raises_error():
    """Functions must accept the same number of weights for shared_transform

    Here, `func1` accepts one model weight, but `func2` accepts two.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
output = y1
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
with pytest.raises(tvm.TVMError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)


def test_incompatible_shape_in_shared_transform_raises_error():
    """Model weights must have matching shapes for shared_transform

    Here, `func1` accepts `w1` and `w2` with shape `[256, 256]`, but
    `func2` requires shape `[128, 256]`.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((128, 256), "float32"),
w2: R.Tensor((128, 256), "float32"),
) -> R.Tensor((256, 128), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
with pytest.raises(tvm.TVMError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)


def test_incompatible_dtype_in_shared_transform_raises_error():
    """Model weights must have matching dtypes for shared_transform

    Here, `func1` accepts `w1` and `w2` with "float32" dtype, but
    `func2` requires "float16".
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float16"),
w1: R.Tensor((128, 256), "float16"),
w2: R.Tensor((128, 256), "float16"),
) -> R.Tensor((256, 128), "float16"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
with pytest.raises(tvm.TVMError):
relax.transform.LiftTransformParams(shared_transform=True)(Before)


def test_share_transform_across_multiple_functions_has_intersection_of_transforms():
    """Like test_multiple_functions, but producing a single transform_params

    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
    `func2`, only `w1_t` could be lifted out of the function.
    Therefore, the shared `transform_params` can pre-compute `w1_t`,
    but must preserve `w2`.

    When `shared_transform=True`, all eligible publicly-exposed
    functions must be usable with the same shared transform.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
y2 = Before.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function(private=True)
def fused_permute_dims_matmul(
x: R.Tensor((256, 256), "float32"),
weight: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
with R.dataflow():
weight_t = R.permute_dims(weight)
y = R.matmul(x, weight_t)
R.output(y)
return y
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w1 = params[0]
w1_t = R.permute_dims(w1)
w2 = params[1]
output = (w2, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = Expected.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function(private=True)
def fused_permute_dims_matmul(
x: R.Tensor((256, 256), "float32"),
weight: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
with R.dataflow():
weight_t = R.permute_dims(weight)
y = R.matmul(x, weight_t)
R.output(y)
return y
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_share_transforms_with_different_binding_order():
    """Like test_share_transform_across_multiple_functions, but the
    lifted bindings are in a different order in each function.

    Both `func1` and `func2` compute the same value for `w1_t` and
    `w2_t`.  However, the bindings occur in different orders.  The
    shared `transform_params` can pre-compute both `w1_t` and `w2_t`,
    even though they occur in different orders.

    For consistency in testing and pre-computing weights, the order of
    `transform_params` should be deterministic.  When lifting from a
    single function, the bindings in `transform_params` may be
    determined from the order in that function.  When lifting from
    multiple functions, the order should be deterministic.  Since
    `IRModule::functions` has unspecified order, the order in this
    test assumes that public functions are visited in alphabetical
    order by name.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w2_t = R.permute_dims(w2)
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w2 = params[1]
w2_t = R.permute_dims(w2)
w1 = params[0]
w1_t = R.permute_dims(w1)
output = (w2_t, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.multiply(y1, y2)
R.output(output)
return output
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_share_transforms_resulting_in_identical_functions():
    """Functions in the public interface must be preserved

    When lifting functions, the resulting functions may be identical.
    Even though the `relax.BlockBuilder` de-duplicates identical
    functions, functions that are part of the IRModule's public
    interface must be preserved.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w2_t = R.permute_dims(w2)
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w2 = params[1]
w2_t = R.permute_dims(w2)
w1 = params[0]
w1_t = R.permute_dims(w1)
output = (w2_t, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_share_transform_across_specified_functions():
    """Like test_multiple_functions, but producing a single transform_params

    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
    `func2`, only `w1_t` could be lifted out of the function.
    Therefore, the shared `transform_params` can pre-compute `w1_t`,
    but must preserve `w2`.

    If `func3` were included in the `transform_params`, the same logic
    would prevent `w1_t` from being computed in the shared
    `transform_params`.  However, the
    `shared_transform=['func1', 'func2']` argument means that `func3`
    does not have any parameter transformations lifted out.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
y2 = Before.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func3(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = Before.fused_permute_dims_matmul(x, w1)
y2 = Before.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function(private=True)
def fused_permute_dims_matmul(
x: R.Tensor((256, 256), "float32"),
weight: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
with R.dataflow():
weight_t = R.permute_dims(weight)
y = R.matmul(x, weight_t)
R.output(y)
return y
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w1 = params[0]
w1_t = R.permute_dims(w1)
w2 = params[1]
output = (w2, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
y2 = Expected.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func3(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = Expected.fused_permute_dims_matmul(x, w1)
y2 = Expected.fused_permute_dims_matmul(x, w2)
output = R.add(y1, y2)
R.output(output)
return output
@R.function(private=True)
def fused_permute_dims_matmul(
x: R.Tensor((256, 256), "float32"),
weight: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
with R.dataflow():
weight_t = R.permute_dims(weight)
y = R.matmul(x, weight_t)
R.output(y)
return y
after = relax.transform.LiftTransformParams(shared_transform=["func1", "func2"])(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_share_transform_with_unused_parameter():
    """Like test_share_transform_across_specified_functions, but not
    all functions use every model weight.

    In `func1`, both `w1_t` and `w2_t` could be lifted out.  In
    `func2`, only `w1_t` could be lifted out of the function.  Because
    `w2_t` is only generated in one of the functions, the `w2`
    parameter is preserved, even though `func2` never uses it.

    For example, an `embed_vocab` function would only use the
    embedding weights.  It could accept the full set of model weights
    for consistency, leaving the weights it does not use untouched.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
R.output(y1)
return y1
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w1 = params[0]
w1_t = R.permute_dims(w1)
w2 = params[1]
output = (w2, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
output = R.add(y1, y2)
R.output(output)
return output
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
R.output(y1)
return y1
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


@pytest.mark.xfail
def test_share_transform_with_no_shared_preprocessing():
    """Like test_share_transform_with_unused_parameter, but each
    function uses a single model weight.

    In `func1`, `w2_t` can be lifted out and `w1` is unused.  In
    `func2`, `w1_t` can be lifted out, and `w2` is unused.  In their
    shared `transform_params`, both `w1_t` and `w2_t` can be computed.

    For consistency in testing and pre-computing weights, the order of
    `transform_params` should be deterministic.  When lifting from a
    single function, the bindings in `transform_params` may be
    determined from the order in that function.  When lifting from
    multiple functions, the order should be deterministic.  Since
    `IRModule::functions` has unspecified order, the order in this
    test assumes that public functions are visited in alphabetical
    order by name.
    """

@I.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w2_t = R.permute_dims(w2)
y2 = R.matmul(x, w2_t)
R.output(y2)
return y2
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
w2: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1)
y1 = R.matmul(x, w1_t)
R.output(y1)
return y1
@I.ir_module
class Expected:
@R.function
def transform_params(
params: R.Tuple(
R.Tensor((256, 256), dtype="float32"),
R.Tensor((256, 256), dtype="float32"),
),
):
R.func_attr({"num_input": 0})
with R.dataflow():
w1 = params[0]
w1_t = R.permute_dims(w1)
w2 = params[1]
w2_t = R.permute_dims(w2)
output = (w2_t, w1_t)
R.output(output)
return output
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y2 = R.matmul(x, w2_t)
R.output(y2)
return y2
@R.function
def func2(
x: R.Tensor((256, 256), "float32"),
w2_t: R.Tensor((256, 256), "float32"),
w1_t: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
y1 = R.matmul(x, w1_t)
R.output(y1)
return y1
after = relax.transform.LiftTransformParams(shared_transform=True)(Before)
tvm.ir.assert_structural_equal(after, Expected)


def test_stop_lifting():
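    """R.builtin.stop_lift_params bounds the lifted computation

    The `R.permute_dims` feeding `stop_lift_params` is lifted into
    `func1_transform_params`, while the `R.add` on its result stays
    in `func1`.
    """
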
@tvm.script.ir_module
class Before:
@R.function
def func1(
x: R.Tensor((256, 256), "float32"),
w1: R.Tensor((256, 256), "float32"),
) -> R.Tensor((256, 256), "float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_t = R.permute_dims(w1, [1, 0])
w1_t1 = R.builtin.stop_lift_params(w1_t)
w1_add = R.add(w1_t1, R.const(1, "float32"))
y = R.matmul(x, w1_add)
R.output(y)
return y
@I.ir_module
class Expected:
@R.function
def func1(
x: R.Tensor((256, 256), dtype="float32"),
param0: R.Tensor((256, 256), dtype="float32"),
) -> R.Tensor((256, 256), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
w1_add: R.Tensor((256, 256), dtype="float32") = R.add(param0, R.const(1, "float32"))
y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, out_dtype="void")
R.output(y)
return y
@R.function
def func1_transform_params(
params: R.Tuple(R.Tensor((256, 256), dtype="float32")),
) -> R.Tuple(R.Tensor((256, 256), dtype="float32")):
R.func_attr({"num_input": 0})
with R.dataflow():
lv: R.Tensor((256, 256), dtype="float32") = params[0]
lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0])
gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,)
R.output(gv)
return gv
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_symbolic_var_1():
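    """Bindings that depend on symbolic shape variables are not lifted

    The symbolic variable `n` is defined by the runtime argument
    `shape`, so `zeros` stays in `main` and the generated
    `main_transform_params` is an identity over an empty tuple.
    """
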
@tvm.script.ir_module
class Before:
@R.function
def main(shape: R.Shape(["n"])):
R.func_attr({"num_input": 1})
n = T.int64()
with R.dataflow():
zeros = R.zeros((n, n), "float32")
return shape
@I.ir_module
class Expected:
@R.function
def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
            # All instances of the empty tuple are normalized to be
            # in-line.
return R.tuple()
@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
R.func_attr({"num_input": 1})
n = T.int64()
with R.dataflow():
zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
R.output()
return shape
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_symbolic_var_2():
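    """Like test_symbolic_var_1, but the value comes from R.call_tir

    The output of the `zeros` PrimFunc depends on the symbolic
    variable `n`, which is only defined by the runtime argument
    `shape`, so the call stays in `main` and `main_transform_params`
    is empty.
    """
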
@I.ir_module
class Before:
@T.prim_func
def zeros(var_T_full: T.handle):
T.func_attr({"tir.noalias": True})
n = T.int64()
T_full = T.match_buffer(var_T_full, (n, n))
for ax0, ax1 in T.grid(n, n):
with T.block("T_full"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads()
T.writes(T_full[v_ax0, v_ax1])
T_full[v_ax0, v_ax1] = T.float32(0)
@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
R.func_attr({"num_input": 1})
n = T.int64()
cls = Before
with R.dataflow():
zeros = R.call_tir(
cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
)
R.output()
return shape
@I.ir_module
class Expected:
@T.prim_func
def zeros(var_T_full: T.handle):
T.func_attr({"tir.noalias": True})
n = T.int64()
T_full = T.match_buffer(var_T_full, (n, n))
# with T.block("root"):
for ax0, ax1 in T.grid(n, n):
with T.block("T_full"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads()
T.writes(T_full[v_ax0, v_ax1])
T_full[v_ax0, v_ax1] = T.float32(0)
@R.function
def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
return R.tuple()
@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
R.func_attr({"num_input": 1})
n = T.int64()
cls = Expected
with R.dataflow():
zeros = R.call_tir(
cls.zeros, R.tuple(), out_sinfo=R.Tensor((n, n), dtype="float32")
)
R.output()
return shape
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


def test_symbolic_var_from_shape():
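    """Symbolic variables may be defined by lifted R.Shape parameters

    The slice of `B` depends on `slice_index`, which is defined by the
    `shape` parameter.  Both occur after the "num_input" boundary, so
    `main_transform_params` receives the shape, computes `B_slice`,
    and passes the shape through to the rewritten `main`.
    """
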
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor([16, 16], "int32"),
B: R.Tensor([16, 16], "int32"),
shape: R.Shape(["slice_index"]),
) -> R.Tensor([16], "int32"):
R.func_attr({"num_input": 1})
slice_index = T.int64()
cls = Before
with R.dataflow():
B_slice = R.call_tir(
cls.slice,
[B],
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
A_slice = R.call_tir(
cls.slice,
[A],
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
A_scale = R.multiply(A_slice, B_slice)
R.output(A_scale)
return A_scale
@T.prim_func(private=True)
def slice(
Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
Output_Slice: T.Buffer(shape=[16], dtype="int32"),
slice_index: T.int64,
):
T.func_attr({"tir.noalias": True})
for j in range(16):
with T.block("T_full"):
vj = T.axis.remap("S", [j])
Output_Slice[vj] = Input_2d[slice_index, vj]
@I.ir_module
class Expected:
@R.function
def main(
A: R.Tensor([16, 16], "int32"),
shape: R.Shape(["slice_index"]),
B_slice: R.Tensor([16], "int32"),
) -> R.Tensor([16], "int32"):
R.func_attr({"num_input": 1})
slice_index = T.int64()
cls = Expected
with R.dataflow():
A_slice = R.call_tir(
cls.slice,
[A],
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
A_scale = R.multiply(A_slice, B_slice)
R.output(A_scale)
return A_scale
@R.function
def main_transform_params(
params: R.Tuple(R.Tensor([16, 16], "int32"), R.Shape(["slice_index"])),
):
R.func_attr({"num_input": 0})
slice_index = T.int64()
cls = Expected
with R.dataflow():
B = params[0]
# extra_symbolic_vars = params[1]
B_slice = R.call_tir(
cls.slice,
[B],
tir_vars=R.ShapeExpr([slice_index]),
out_sinfo=R.Tensor([16], dtype="int32"),
)
output = (R.ShapeExpr([slice_index]), B_slice)
R.output(output)
return output
@T.prim_func(private=True)
def slice(
Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
Output_Slice: T.Buffer(shape=[16], dtype="int32"),
slice_index: T.int64,
):
T.func_attr({"tir.noalias": True})
for j in range(16):
with T.block("T_full"):
vj = T.axis.remap("S", [j])
Output_Slice[vj] = Input_2d[slice_index, vj]
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(Expected, after)


def test_symbolic_var_in_param_shape():
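    """Symbolic variables in weight shapes are preserved when lifting

    The weights share the symbolic dimension `m`, which remains
    well-defined in both the rewritten `main` and the generated
    `main_transform_params`.
    """
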
@tvm.script.ir_module
class Before:
@R.function
def main(
x: R.Tensor((1, 16, 224, "n"), "float32"),
w1: R.Tensor((16, "m", 3, 3), "float32"),
w2: R.Tensor((16, "m", 3, 3), "float32"),
) -> R.Tensor((1, 16, 224, "n"), "float32"):
m = T.int64()
n = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
zeros = R.zeros((n, n), "float32")
w1 = R.add(w1, R.const(1, "float32"))
conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
conv2 = R.nn.conv2d(
conv1, w2, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW"
)
R.output(conv2)
return conv2
@I.ir_module
class Expected:
@R.function
def main_transform_params(
params: R.Tuple(
R.Tensor((16, "m", 3, 3), dtype="float32"),
R.Tensor((16, "m", 3, 3), dtype="float32"),
),
) -> R.Tuple(
R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32")
):
R.func_attr({"num_input": 0})
m = T.int64()
with R.dataflow():
lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1, R.const(1, "float32"))
lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
gv: R.Tuple(
R.Tensor((16, m, 3, 3), dtype="float32"),
R.Tensor((16, m, 3, 3), dtype="float32"),
) = (lv, lv2)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
) -> R.Tensor((1, 16, 224, "n"), dtype="float32"):
n = T.int64()
m = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
conv1: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
x,
transformed_param_1,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
conv2: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
conv1,
transformed_param_0,
strides=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
groups=1,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="NCHW",
out_dtype="void",
)
R.output(conv2)
return conv2
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)


# Not supported yet
@pytest.mark.xfail
def test_symbolic_var_defined_in_params_but_used_in_weights():
    """A symbolic variable's occurrence in the weights may not define it

    In order to be a source of definition, a symbolic variable in the
    parameters must occur as a distinct parameter, as a tensor shape
    `R.Tensor(["var"])`, an explicit `R.Shape(["var"])`, or as a
    `R.Prim(value="var")`.  A variable that is part of a larger
    expression, such as `R.Tensor(["m * n"])`, is a variable usage,
    not a variable definition.
    """

@tvm.script.ir_module
class Before:
@R.function
def main(
x: R.Tensor(["m", "n"], "float32"),
weight: R.Tensor(["m * n"], "float32"),
) -> R.Tensor(["m", "n"], "float32"):
m = T.int64()
n = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
weight = R.add(weight, R.const(1, "float32"))
weight = R.reshape(weight, [m, n])
output = R.multiply(x, weight)
R.output(output)
return output
@tvm.script.ir_module
class Expected:
@R.function
def main_transform_params(
params: R.Tuple(R.Tensor(("k",), dtype="float32"))
) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
R.func_attr({"num_input": 0})
k = T.int64()
with R.dataflow():
lv: R.Tensor((k,), dtype="float32") = params[0]
gv: R.Tuple(R.Tensor((k,), dtype="float32")) = (lv,)
R.output(gv)
return gv
@R.function
def main(
x: R.Tensor(("m", "n"), dtype="float32"),
transformed_param_0: R.Tensor(dtype="float32", ndim=1),
) -> R.Tensor(("m", "n"), dtype="float32"):
m = T.int64()
n = T.int64()
R.func_attr({"num_input": 1})
with R.dataflow():
lv: R.Tensor(dtype="float32", ndim=1) = transformed_param_0
weight: R.Tensor(dtype="float32", ndim=1) = R.add(lv, R.const(1, "float32"))
weight_1: R.Tensor((m, n), dtype="float32") = R.reshape(weight, R.shape([m, n]))
output: R.Tensor((m, n), dtype="float32") = R.multiply(x, weight_1)
R.output(output)
return output
After = relax.transform.LiftTransformParams()(Before)
tvm.ir.assert_structural_equal(Expected, After)


def test_only_lift_when_variable_uses_constants():
    """A variable that has no inputs should not be lifted on its own

    For example, `R.zeros`, or the result of an allocation function
    call.  Here, `offset` is kept in `main` for use in `A_offset`,
    and is re-computed in `main_transform_params` for use in
    `B_offset`.
    """

@tvm.script.ir_module
class Before:
@R.function
def main(
A: R.Tensor([16], "int32"),
B: R.Tensor([16], "int32"),
):
R.func_attr({"num_input": 1})
with R.dataflow():
offset = R.ones([16], "int32")
A_offset = R.add(A, offset)
B_offset = R.add(B, offset)
output = R.multiply(A_offset, B_offset)
R.output(output)
return output
@tvm.script.ir_module
class Expected:
@R.function
def main(
A: R.Tensor([16], "int32"),
B_offset: R.Tensor([16], "int32"),
):
R.func_attr({"num_input": 1})
with R.dataflow():
offset = R.ones([16], "int32")
A_offset = R.add(A, offset)
output = R.multiply(A_offset, B_offset)
R.output(output)
return output
@R.function
def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])):
R.func_attr({"num_input": 0})
with R.dataflow():
offset = R.ones([16], "int32")
B = params[0]
B_offset = R.add(B, offset)
output = (B_offset,)
R.output(output)
return output
mod = Before
after = relax.transform.LiftTransformParams()(mod)
tvm.ir.assert_structural_equal(after, Expected)
@pytest.mark.parametrize("shared_transform", [True, False])
def test_lift_transform_is_idempotent(shared_transform):
"""Multiple applicates of LiftTransformParams are allowed"""
@I.ir_module
class Module:
@R.function
def main(
state: R.Tensor(["batch_size", 4096], "float16"),
base_weights: R.Tensor([4096, 4096], "float16"),
lora_A: R.Tensor([4096, "lora_rank"], "float16"),
lora_B: R.Tensor(["lora_rank", 4096], "float16"),
):
R.func_attr({"num_input": 1})
folded_weights = base_weights + R.matmul(lora_A, lora_B)
output = R.matmul(state, folded_weights)
return output
transform = relax.transform.LiftTransformParams(shared_transform=shared_transform)
AfterOneRound = transform(Module)
assert len(AfterOneRound.functions) == 2
AfterTwoRounds = transform(AfterOneRound)
assert len(AfterTwoRounds.functions) == 2
tvm.ir.assert_structural_equal(AfterOneRound, AfterTwoRounds)


def test_lift_transform_when_one_already_exists():
"""If the module already contains `transform_params`, the
functions are composed together"""
@I.ir_module
class Module:
@R.function
def main(
state: R.Tensor(["batch_size", 4096], "float16"),
base_weights: R.Tensor([4096, 4096], "float16"),
lora_A: R.Tensor([4096, "lora_rank"], "float16"),
lora_B: R.Tensor(["lora_rank", 4096], "float16"),
):
R.func_attr({"num_input": 1})
folded_weights = base_weights + R.matmul(lora_A, lora_B)
output = R.matmul(state, folded_weights)
return output
@R.function
def main_transform_params(
model_params: R.Tuple(
R.Tensor([4096, 4096], "float16"),
R.Tensor([4096, "lora_rank"], "float16"),
R.Tensor(["lora_rank", 4096], "float16"),
),
):
R.func_attr({"num_input": 0})
return model_params
transform = relax.transform.LiftTransformParams(shared_transform=False)
after_lift_with_previous_identity_function = transform(Module)
del Module["main_transform_params"]
after_lift_without_previous_identity_function = transform(Module)
tvm.ir.assert_structural_equal(
after_lift_without_previous_identity_function,
after_lift_with_previous_identity_function,
)


if __name__ == "__main__":
tvm.testing.main()