blob: ae0521a0e2f8ae47986e411ca5f44302cdfb0014 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R, tir as T
from tvm.script import ir as I
from tvm.relax.transform import LazyTransformParams
def test_lazy_transform_params():
    """Default LazyTransformParams rewrite: get_item/set_item packed calls.

    The ``params`` tuple argument is dropped; each ``params[i]`` access
    becomes an ``R.call_packed("get_item", i)`` followed by ``R.match_cast``
    to recover the static shape, each output tuple element is stored with
    ``R.call_packed("set_item", i, value)``, parameters are released after
    their last use via ``R.vm.kill_object``, and the rewritten (impure)
    function returns an empty tuple.
    """
    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv2)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function(pure=False)
        def main_transform_params() -> R.Tuple:
            cls = Expected
            # params[1] is fetched lazily and stored as output index 0.
            lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                lv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, sinfo_args=(R.Object,))
            _1: R.Tuple = R.vm.kill_object(lv_m)
            lv1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                lv1, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1_m,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            # Parameter freed after its final use, then the result is stored.
            _2: R.Tuple = R.vm.kill_object(lv1_m)
            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
            gv: R.Tuple = R.tuple()
            return gv
    after = LazyTransformParams()(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_get_item_only():
    """With ``fset_item=None``, only input loading is made lazy.

    The custom getter name "get_item_0" is used for the packed calls, and
    because no setter is configured the function still returns the full
    output tuple instead of storing results via set_item.
    """
    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3 = R.add(lv2, R.const(1, "float32"))
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv3)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function(pure=False)
        def main_transform_params() -> (
            R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
            )
        ):
            cls = Expected
            # Inputs are fetched through the renamed getter "get_item_0".
            gv: R.Object = R.call_packed("get_item_0", R.prim_value(1), sinfo_args=(R.Object,))
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            gv2: R.Object = R.call_packed("get_item_0", R.prim_value(0), sinfo_args=(R.Object,))
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
            # No set_item configured, so the output tuple is returned directly.
            gv_1: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
            ) = (lv, lv3)
            return gv_1
    after = LazyTransformParams(fget_item="get_item_0", fset_item=None)(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_extra_get_item_params():
    """``extra_get_item_params`` adds extra arguments to every get_item call.

    The extra ``loader`` variable becomes a parameter of the rewritten
    function and is passed as the first argument to each
    ``R.call_packed("get_item", ...)`` call, before the parameter index.
    """
    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3 = R.add(lv2, R.const(1, "float32"))
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv3)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function(pure=False)
        def main_transform_params(loader: R.Object) -> R.Tuple:
            cls = Expected
            # "loader" is threaded through every get_item call.
            gv: R.Object = R.call_packed(
                "get_item", loader, R.prim_value(1), sinfo_args=(R.Object,)
            )
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
            _1: R.Tuple = R.vm.kill_object(lv)
            gv2: R.Object = R.call_packed(
                "get_item", loader, R.prim_value(0), sinfo_args=(R.Object,)
            )
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            _2: R.Tuple = R.vm.kill_object(lv1)
            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,))
            gv_1: R.Tuple = R.tuple()
            return gv_1
    after = LazyTransformParams(
        extra_get_item_params=[relax.Var("loader", relax.ObjectStructInfo())]
    )(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_extra_set_item_params():
    """``extra_set_item_params`` adds extra arguments to every set_item call.

    The extra ``setter`` variable becomes a parameter of the rewritten
    function and is passed as the first argument to each
    ``R.call_packed("set_item", ...)`` call; get_item calls are unchanged.
    """
    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32")
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32")
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            lv3 = R.add(lv2, R.const(1, "float32"))
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((16, 3, 3, 3), dtype="float32"),
            ) = (lv, lv3)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(
            w1: T.Buffer((3, 16, 3, 3), "float32"), out: T.Buffer((16, 3, 3, 3), "float32")
        ):
            # with T.block("root"):
            for ax0, ax1, ax2, ax3 in T.grid(16, 3, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function(pure=False)
        def main_transform_params(setter: R.Object) -> R.Tuple:
            cls = Expected
            gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            # "setter" is threaded through every set_item call.
            _: R.Object = R.call_packed(
                "set_item", setter, R.prim_value(0), lv, sinfo_args=(R.Object,)
            )
            _1: R.Tuple = R.vm.kill_object(lv)
            gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, 16, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
            )
            _2: R.Tuple = R.vm.kill_object(lv1)
            lv3: R.Tensor((16, 3, 3, 3), dtype="float32") = R.add(lv2, R.const(1, "float32"))
            _3: R.Object = R.call_packed(
                "set_item", setter, R.prim_value(1), lv3, sinfo_args=(R.Object,)
            )
            gv_1: R.Tuple = R.tuple()
            return gv_1
    after = LazyTransformParams(
        extra_set_item_params=[relax.Var("setter", relax.ObjectStructInfo())]
    )(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_extra_set_item_params_with_const_output():
    """Constant outputs also receive set_item calls.

    When the output tuple contains ``R.const`` leaves (no VarBinding),
    the pass still emits one set_item call per tuple element, with the
    extra ``setter`` argument threaded through.
    """
    @I.ir_module
    class Before:
        @R.function
        def main_transform_params(
            params: R.Tuple(),
        ) -> R.Tuple(R.Tensor([2], dtype="float32"), R.Tensor([3], dtype="float32")):
            R.func_attr({"relax.force_pure": True})
            gv = (
                R.const(np.array([1, 2]).astype("float32")),
                R.const(np.array([3, 4]).astype("float32")),
            )
            return gv
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def main_transform_params(setter: R.Object) -> R.Tuple:
            output = R.tuple()
            _ = R.call_packed(
                "set_item",
                setter,
                R.prim_value(0),
                R.const(np.array([1, 2]).astype("float32")),
                sinfo_args=(R.Object,),
            )
            _ = R.call_packed(
                "set_item",
                setter,
                R.prim_value(1),
                R.const(np.array([3, 4]).astype("float32")),
                sinfo_args=(R.Object,),
            )
            return output
    after = LazyTransformParams(
        extra_set_item_params=[relax.Var("setter", relax.ObjectStructInfo())]
    )(Before)
    tvm.ir.assert_structural_equal(after, Expected)
def test_lazy_transform_params_with_symbolic_vars():
    """An ``R.Shape`` parameter carrying symbolic vars is kept as an argument.

    Tensor parameters are converted to lazy get_item/set_item calls, but
    the shape parameter that defines ``slice_index`` cannot be loaded
    lazily and remains an explicit function parameter of the rewritten
    function.
    """
    @I.ir_module
    class Before:
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((16, 16), dtype="float32"),
                R.Shape(
                    ["slice_index"],
                ),
            ),
        ):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            slice_index = T.int64()
            param = params[0]
            transformed = R.call_tir(
                cls.slice_buffer,
                (param,),
                tir_vars=[slice_index],
                out_sinfo=R.Tensor((16,), dtype="float32"),
            )
            output = (transformed,)
            return output
        @T.prim_func(private=True)
        def slice_buffer(
            Input: T.Buffer((16, 16), "float32"),
            Output: T.Buffer(16, "float32"),
            slice_index: T.int64,
        ):
            for i in T.grid(16):
                with T.block("slice_buffer"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = Input[slice_index, vi]
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])):
            cls = Expected
            slice_index = T.int64()
            param = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            gv: R.Tensor((16, 16), dtype="float32") = R.match_cast(
                param, R.Tensor((16, 16), dtype="float32")
            )
            param_m: R.Tensor((16, 16), dtype="float32") = gv
            transformed = R.call_tir(
                cls.slice_buffer,
                (param_m,),
                tir_vars=[slice_index],
                out_sinfo=R.Tensor((16,), dtype="float32"),
            )
            unused_1_ = R.vm.kill_object(param_m)
            unused_2_ = R.call_packed(
                "set_item", R.prim_value(0), transformed, sinfo_args=(R.Object,)
            )
            output = R.tuple()
            return output
        @T.prim_func(private=True)
        def slice_buffer(
            Input: T.Buffer((16, 16), "float32"),
            Output: T.Buffer(16, "float32"),
            slice_index: T.int64,
        ):
            for i in T.grid(16):
                with T.block("slice_buffer"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = Input[slice_index, vi]
    after = LazyTransformParams()(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_param_shape_symbolic():
    """Symbolic shape vars inside tensor params survive the rewrite.

    The symbolic dimension ``ic`` in the parameter tensor is re-bound by
    the ``R.match_cast`` that follows the lazy get_item call, so downstream
    shapes that reference ``ic`` keep working.
    """
    @I.ir_module
    class Before:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
            ic = T.int32()
            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function
        def main_transform_params(
            params: R.Tuple(
                R.Tensor((3, "ic", 3, 3), dtype="float32"),
                R.Tensor((16, 16, 3, 3), dtype="float32"),
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
        ):
            ic = T.int64()
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Before
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0]
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
            )
            gv: R.Tuple(
                R.Tensor((16, 16, 3, 3), dtype="float32"),
                R.Tensor((ic, 3, 3, 3), dtype="float32"),
            ) = (lv, lv2)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
            ic = T.int32()
            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
                with T.block("layout_transform"):
                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                    T.reads(w1[i, o, h, w])
                    T.writes(out[o, i, h, w])
                    out[o, i, h, w] = w1[i, o, h, w]
        @R.function(pure=False)
        def main_transform_params() -> R.Tuple:
            ic = T.int64()
            cls = Expected
            gv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
            )
            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,))
            _1: R.Tuple = R.vm.kill_object(lv)
            gv2: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            # match_cast re-establishes the symbolic dimension "ic".
            gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast(
                gv2, R.Tensor((3, ic, 3, 3), dtype="float32")
            )
            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3
            lv2 = R.call_tir(
                cls.transform_layout_IOHW_to_OIHW,
                (lv1,),
                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
            )
            _2: R.Tuple = R.vm.kill_object(lv1)
            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, sinfo_args=(R.Object,))
            gv4: R.Tuple = R.tuple()
            return gv4
    after = LazyTransformParams()(Before)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_output_with_use_site():
    """An output that is also used later is stored but not killed early.

    ``y`` appears both in the output tuple and as an input to a later
    call, so its set_item call is emitted after its last use and no
    kill_object is generated for it.
    """
    @I.ir_module
    class Module:
        @T.prim_func
        def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
            with T.block("block"):
                T.reads(x[()])
                T.writes(y[()])
                y[()] = x[()]
        @R.function
        def main_transform_params(
            params: R.Tuple(R.Tensor((), dtype="float32"))
        ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
            # we expect ToNonDataflow and RemovePurityTracking to be invoked first
            R.func_attr({"relax.force_pure": True})
            cls = Module
            x: R.Tensor((), dtype="float32") = params[0]
            y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), dtype="float32"))
            z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32"))
            gv: R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")) = (y, z)
            return gv
    @I.ir_module
    class Expected:
        @T.prim_func
        def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
            with T.block("block"):
                T.reads(x[()])
                T.writes(y[()])
                y[()] = x[()]
        @R.function(pure=False)
        def main_transform_params() -> R.Tuple:
            cls = Expected
            x: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), dtype="float32"))
            x_m: R.Tensor((), dtype="float32") = gv
            y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), dtype="float32"))
            _: R.Tuple = R.vm.kill_object(x_m)
            # y is still live at the z computation, so it is stored afterwards.
            z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), dtype="float32"))
            _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, sinfo_args=(R.Object,))
            _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, sinfo_args=(R.Object,))
            gv: R.Tuple = R.tuple()
            return gv
    after = LazyTransformParams()(Module)
    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
def test_output():
    """End-to-end check: compile and run the lazily-transformed function.

    Registers Python implementations for the "get_item"/"set_item" packed
    functions, executes the transformed module in the Relax VM, and
    verifies that the collected outputs match a NumPy reference.
    """
    target = "llvm"
    dev = tvm.device(target)
    @I.ir_module
    class TransformModule:
        @R.function
        def transform_params(
            params: R.Tuple(
                R.Tensor((3, "ic", 3, 3), dtype="float32"),
                R.Tensor((16, 16, 3, 3), dtype="float32"),
            ),
        ) -> R.Tuple(
            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32")
        ):
            R.func_attr({"relax.force_pure": True})
            param0 = params[0]
            param1 = params[1]
            transformed0 = R.permute_dims(param0, [1, 0, 2, 3])
            transformed = (transformed0, param1)
            return transformed
    mod = TransformModule
    mod = relax.transform.LazyTransformParams()(mod)
    mod = relax.transform.LegalizeOps()(mod)
    built = tvm.compile(mod, target=target)
    params = [
        np.random.random(size=(3, 64, 3, 3)).astype("float32"),
        np.random.random(size=(16, 16, 3, 3)).astype("float32"),
    ]
    # Results are collected by index via the set_item callback below.
    transformed = {}
    expected = [params[0].transpose(1, 0, 2, 3), params[1]]
    # NOTE(review): assumes `tvm.register_global_func` and
    # `tvm.runtime.tensor` are the current top-level FFI APIs — confirm
    # against the installed TVM version.
    @tvm.register_global_func("get_item", override=True)
    def get_item(i):
        return tvm.runtime.tensor(params[i], dev)
    @tvm.register_global_func("set_item", override=True)
    def set_item(i, value):
        assert i not in transformed, f"Set item called multiple times for index {i}"
        transformed[i] = value.numpy()
    vm = relax.VirtualMachine(built, dev)
    vm["transform_params"]()
    # set_item must have been called exactly once for each output index.
    assert sorted(transformed) == list(range(len(transformed)))
    transformed = [value for i, value in sorted(transformed.items())]
    assert len(transformed) == len(expected)
    for expected_i, transformed_i in zip(expected, transformed):
        tvm.testing.assert_allclose(expected_i, transformed_i)
def test_duplicate_outputs():
    """A tensor may be repeated in the output.

    This is something that should be avoided upstream, but is a legal
    parameter transformation, and should produce correct output: one
    set_item call is emitted for each output index that refers to the
    repeated tensor.
    """
    @I.ir_module
    class Before:
        @R.function
        def main_transform_params(
            params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")),
        ):
            R.func_attr({"relax.force_pure": True})
            param0 = params[0]
            param1 = params[1]
            transformed0 = R.add(param0, R.const(1, "int32"))
            transformed1 = R.add(param1, R.const(2, "int32"))
            output = (transformed0, transformed1, transformed0)
            return output
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def main_transform_params() -> R.Tuple:
            gv: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,))
            gv1: R.Tensor((16,), dtype="int32") = R.match_cast(gv, R.Tensor((16,), dtype="int32"))
            param0: R.Tensor((16,), dtype="int32") = gv1
            gv2: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,))
            gv3: R.Tensor((16,), dtype="int32") = R.match_cast(gv2, R.Tensor((16,), dtype="int32"))
            param1: R.Tensor((16,), dtype="int32") = gv3
            transformed0: R.Tensor((16,), dtype="int32") = R.add(param0, R.const(1, "int32"))
            _: R.Tuple = R.vm.kill_object(param0)
            # transformed0 appears at output indices 0 and 2: two set_item calls.
            _: R.Object = R.call_packed(
                "set_item", R.prim_value(0), transformed0, sinfo_args=(R.Object,)
            )
            _: R.Object = R.call_packed(
                "set_item", R.prim_value(2), transformed0, sinfo_args=(R.Object,)
            )
            transformed1: R.Tensor((16,), dtype="int32") = R.add(param1, R.const(2, "int32"))
            _ = R.vm.kill_object(param1)
            _ = R.call_packed("set_item", R.prim_value(1), transformed1, sinfo_args=(R.Object,))
            output = R.tuple()
            return output
    after = LazyTransformParams()(Before)
    tvm.ir.assert_structural_equal(after, Expected)
def test_params_without_tuple():
    """Parameters given as individual tensor args (no tuple) are also lazified.

    Each function parameter is replaced by a get_item call (indexed by its
    position) plus a match_cast; with ``fset_item=None`` the output tuple
    is still returned directly.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, B)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params():
            A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object])
            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
            C = R.multiply(A, R.const(2, "float32"))
            B = R.call_packed("get_item", R.prim_value(1), sinfo_args=[R.Object])
            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
            D = R.add(C, B)
            return (D, B)
    After = LazyTransformParams(fset_item=None)(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_retain_before_num_input():
    """Only lazily load parameters after num_input.

    Parameters covered by the "num_input" attribute (here ``relax_rank``)
    are runtime inputs, not weights, so they stay as explicit function
    parameters; the remaining weight parameters become get_item calls.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(
            relax_rank: R.Prim(value="rank"),
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
        ):
            R.func_attr({"num_input": 1})
            rank = T.int64()
            A_sharded = R.strided_slice(
                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True
            )
            B_sharded = R.strided_slice(
                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True
            )
            return (A_sharded, B_sharded)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(relax_rank: R.Prim(value="rank")):
            R.func_attr({"num_input": 1})
            rank = T.int64()
            A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object])
            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
            A_sharded = R.strided_slice(
                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True
            )
            B = R.call_packed("get_item", R.prim_value(1), sinfo_args=[R.Object])
            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
            B_sharded = R.strided_slice(
                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], assume_inbound=True
            )
            return (A_sharded, B_sharded)
    After = LazyTransformParams(fset_item=None)(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_params_without_tuple_with_symbolic_var():
    """An untyped ``R.Object`` parameter is lazified with an Object match_cast.

    Even without tensor shape information, the parameter is replaced by a
    get_item call followed by ``R.match_cast(..., R.Object)``.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Object):
            return (A,)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params():
            A = R.call_packed("get_item", R.prim_value(0), sinfo_args=[R.Object])
            A = R.match_cast(A, R.Object)
            return (A,)
    After = LazyTransformParams(fset_item=None)(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_get_item_callback():
    """LazyGetInput replaces weight params with an ``fget_param`` callable.

    The callback takes the parameter index and name; each former weight
    parameter becomes ``fget_param(index, name)`` plus a match_cast, and
    ``num_input`` is set to account for the new callable argument.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, B)
    @I.ir_module
    class Expected:
        @R.function
        def transform_params(fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)):
            R.func_attr({"num_input": 1})
            A = fget_param(R.prim_value(0), R.str("A"))
            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
            C = R.multiply(A, R.const(2, "float32"))
            B = fget_param(R.prim_value(1), R.str("B"))
            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
            D = R.add(C, B)
            return (D, B)
    After = relax.transform.LazyGetInput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_get_item_callback_num_attrs():
    """LazyGetInput appends ``fget_item`` after existing runtime arguments.

    With an existing ``num_input`` of 2, the callback is inserted after the
    runtime parameters (rank/world_size), ``num_input`` is incremented to 3,
    and only the weight parameters are converted to callback fetches.
    """
    @I.ir_module
    class Before:
        @R.function(pure=False)
        def transform_params(
            rank_arg: R.Prim(value="rank"),
            world_size_arg: R.Prim(value="world_size"),
            weight_A: R.Tensor([16, 64], "float32"),
            weight_B: R.Tensor([1024, 2048], "float32"),
        ):
            R.func_attr({"num_input": 2})
            rank = T.int64()
            world_size = T.int64()
            _ = R.assert_op(
                R.prim_value(16 % world_size == 0),
                [R.prim_value(16), R.prim_value(world_size)],
                format=(
                    "World size must evenly divide A.shape[0] ({}), "
                    "but received world size of {}."
                ),
            )
            weight_A = R.strided_slice(
                weight_A,
                axes=[0],
                begin=[rank * 16 // world_size],
                end=[(rank + 1) * 16 // world_size],
            )
            _ = R.assert_op(
                R.prim_value(2048 % world_size == 0),
                [R.prim_value(2048), R.prim_value(world_size)],
                format=(
                    "World size must evenly divide B.shape[1] ({}), "
                    "but received world size of {}."
                ),
            )
            weight_B = R.strided_slice(
                weight_B,
                axes=[1],
                begin=[rank * 2048 // world_size],
                end=[(rank + 1) * 2048 // world_size],
            )
            return (weight_A, weight_B)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            rank_arg: R.Prim(value="rank"),
            world_size_arg: R.Prim(value="world_size"),
            fget_item: R.Callable([R.Prim("int64"), R.Object], R.Object),
        ):
            R.func_attr({"num_input": 3})
            rank = T.int64()
            world_size = T.int64()
            _ = R.assert_op(
                R.prim_value(16 % world_size == 0),
                [R.prim_value(16), R.prim_value(world_size)],
                format=(
                    "World size must evenly divide A.shape[0] ({}), "
                    "but received world size of {}."
                ),
            )
            weight_A = fget_item(R.prim_value(0), R.str("weight_A"))
            weight_A = R.match_cast(weight_A, R.Tensor([16, 64], "float32"))
            weight_A = R.strided_slice(
                weight_A,
                axes=[0],
                begin=[rank * 16 // world_size],
                end=[(rank + 1) * 16 // world_size],
            )
            _ = R.assert_op(
                R.prim_value(2048 % world_size == 0),
                [R.prim_value(2048), R.prim_value(world_size)],
                format=(
                    "World size must evenly divide B.shape[1] ({}), "
                    "but received world size of {}."
                ),
            )
            weight_B = fget_item(R.prim_value(1), R.str("weight_B"))
            weight_B = R.match_cast(weight_B, R.Tensor([1024, 2048], "float32"))
            weight_B = R.strided_slice(
                weight_B,
                axes=[1],
                begin=[rank * 2048 // world_size],
                end=[(rank + 1) * 2048 // world_size],
            )
            return (weight_A, weight_B)
    After = relax.transform.LazyGetInput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_get_item_callback_dynamic_shape():
    """Symbolic dims are recovered via match_cast after the callback fetch.

    The return struct-info is loosened to ``ndim=2`` because ``m``/``n``
    are no longer defined by the signature; they are re-bound by the
    match_cast of the first fetched parameter.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(
            A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"], "float32")
        ) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, B)
    @I.ir_module
    class Expected:
        @R.function
        def transform_params(
            fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object),
        ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")):
            R.func_attr({"num_input": 1})
            m = T.int64()
            n = T.int64()
            A = fget_param(R.prim_value(0), R.str("A"))
            A = R.match_cast(A, R.Tensor([m, n], "float32"))
            C = R.multiply(A, R.const(2, "float32"))
            B = fget_param(R.prim_value(1), R.str("B"))
            B = R.match_cast(B, R.Tensor([m, n], "float32"))
            D = R.add(C, B)
            return (D, B)
    After = relax.transform.LazyGetInput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback():
    """fset_output is called for each element of the output tuple.

    The call is placed immediately after the corresponding
    `VarBinding`, and the function then returns an empty tuple.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, C)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
        ):
            C = R.multiply(A, R.const(2, "float32"))
            fset_output(R.prim_value(1), C)
            D = R.add(C, B)
            fset_output(R.prim_value(0), D)
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback_of_param():
    """fset_output may need to be called for parameters.

    A function parameter does not have a `VarBinding`. If a parameter
    is returned in the output tuple, the `fset_output` call is
    generated at the beginning of the function.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, B)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
        ):
            # B is a parameter returned as output index 1; emitted up front.
            fset_output(R.prim_value(1), B)
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            fset_output(R.prim_value(0), D)
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback_num_input():
    """The parameter transformation may have other runtime parameters.

    The new `fset_output` parameter is placed after the other runtime
    parameters, before any model weights, and ``num_input`` is incremented
    to cover it.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            R.func_attr({"num_input": 1})
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, B)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
            B: R.Tensor([16, 16], "float32"),
        ):
            R.func_attr({"num_input": 2})
            fset_output(R.prim_value(1), B)
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            fset_output(R.prim_value(0), D)
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback_with_duplicate_output():
    """fset_output may be called more than once for a variable.

    A variable may occur multiple times in the output tuple. The
    `fset_output` callback should be called once for each tuple
    element, even if they reuse the same variable.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (D, D)
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
        ):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            fset_output(R.prim_value(0), D)
            fset_output(R.prim_value(1), D)
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback_with_inline_const():
    """fset_output may be called for inline objects.

    The return tuple may contain inline leaf nodes, such as
    `relax.PrimValue` or `relax.Constant`. A call to `fset_output`
    must be generated, even though they do not have an associated
    `relax.VarBinding`.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return (C, D, R.prim_value(42), R.const(17.5, "float16"))
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
        ):
            C = R.multiply(A, R.const(2, "float32"))
            fset_output(R.prim_value(0), C)
            D = R.add(C, B)
            fset_output(R.prim_value(1), D)
            # Inline leaves get their own calls at the end of the function.
            fset_output(R.prim_value(2), R.prim_value(42))
            fset_output(R.prim_value(3), R.const(17.5, "float16"))
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
def test_set_output_callback_with_non_tuple_output():
    """Non-tuple outputs produce a single call to fset_output.

    A bare tensor return value is treated as a one-element output, stored
    at index 0, and the function returns an empty tuple.
    """
    @I.ir_module
    class Before:
        @R.function
        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            return D
    @I.ir_module
    class Expected:
        @R.function(pure=False)
        def transform_params(
            A: R.Tensor([16, 16], "float32"),
            B: R.Tensor([16, 16], "float32"),
            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
        ):
            C = R.multiply(A, R.const(2, "float32"))
            D = R.add(C, B)
            fset_output(R.prim_value(0), D)
            return R.tuple()
    After = relax.transform.LazySetOutput()(Before)
    tvm.ir.assert_structural_equal(After, Expected)
# Script entry point: run every test in this file through TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()