blob: d3199314150025a6dc1ee1d40fa55cd894914819 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm.testing
from tvm.script import ir as I, relax as R, tir as T
def _get_before_module():
    """Build the IRModule that every test in this file starts from.

    The module holds four functions that reference one another, so that
    renaming any single one of them affects references elsewhere:

    - ``relax_main``: public Relax function.  It calls
      ``relax_subroutine``, then references ``tir_main`` twice: once
      through ``R.call_tir`` and once as a direct call.
    - ``relax_subroutine``: private Relax function that adds 1.0.
    - ``tir_main``: public TIR function that calls ``tir_subroutine``.
    - ``tir_subroutine``: private TIR function that writes
      ``A[i] + 1.0`` into ``B`` for each of the 16 elements.

    Returns
    -------
    tvm.IRModule
        A freshly parsed module.  The TVMScript class is re-parsed on
        every call, so each test gets its own instance.
    """

    @I.ir_module
    class Module:
        @R.function
        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            # NOTE(review): "relax.force_pure" is presumably needed because the
            # direct `Module.tir_main(C, D)` call below writes into D in place.
            R.func_attr({"relax.force_pure": True})
            B = Module.relax_subroutine(A)
            # Two distinct references to `tir_main`: one via R.call_tir and
            # one as a direct call.  Renaming `tir_main` must update both.
            C = R.call_tir(Module.tir_main, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Module.tir_main(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            # Pass raw data pointers so the subroutine is referenced by
            # GlobalVar from inside TIR code.
            Module.tir_subroutine(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    return Module
def test_no_op_if_no_replacements():
    """An empty replacement map leaves the IRModule untouched.

    With nothing to rename, ``replace_global_vars`` must hand back the
    very same module object, not merely a structurally equal copy.
    """
    module = _get_before_module()
    result = module.replace_global_vars({})

    # Structural equality alone would also accept a copy; the identity
    # check that follows pins down that no new module was produced.
    tvm.ir.assert_structural_equal(module, result)
    assert module.same_as(result)
def test_replace_relax_main():
    """An externally-exposed Relax function may be replaced

    In this example, the "relax_main" function is renamed.  This
    requires changing both the GlobalVar used to refer to the
    function, and the "global_symbol" attribute of the
    externally-exposed function.
    """
    before = _get_before_module()
    after = before.replace_global_vars({"relax_main": "relax_main_with_new_name"})

    # Same as the module from `_get_before_module`, except that `relax_main`
    # is now named `relax_main_with_new_name`.  No other function references
    # `relax_main`, so no call sites change.
    @I.ir_module
    class Expected:
        @R.function
        def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            R.func_attr({"relax.force_pure": True})
            B = Expected.relax_subroutine(A)
            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Expected.tir_main(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            Expected.tir_subroutine(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    tvm.ir.assert_structural_equal(Expected, after)
def test_replace_relax_subroutine():
    """An internal Relax function may be replaced

    In this example, the "relax_subroutine" function is renamed.  This
    requires changing both the GlobalVar used to refer to the
    function, and the GlobalVar used to call the subroutine within
    "relax_main".  The "global_symbol" attribute does not need to be
    updated, because internal functions do not have this attribute.
    """
    before = _get_before_module()
    after = before.replace_global_vars({"relax_subroutine": "relax_subroutine_with_new_name"})

    # Same as the module from `_get_before_module`, except that the private
    # Relax subroutine is renamed, and so is its single call site in
    # `relax_main`.
    @I.ir_module
    class Expected:
        @R.function
        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            R.func_attr({"relax.force_pure": True})
            B = Expected.relax_subroutine_with_new_name(A)
            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Expected.tir_main(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine_with_new_name(
            A: R.Tensor([16], "float32"),
        ) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            Expected.tir_subroutine(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    tvm.ir.assert_structural_equal(Expected, after)
def test_replace_tir_main():
    """An externally-exposed TIR function may be replaced

    In this example, the "tir_main" function is renamed.  This
    requires changing both the GlobalVar used to refer to the
    function and the "global_symbol" attribute of the
    externally-exposed function.  In addition, calls to the TIR
    function should be updated to use the new GlobalVar.
    """
    before = _get_before_module()
    after = before.replace_global_vars({"tir_main": "tir_main_with_new_name"})

    # Same as the module from `_get_before_module`, except that `tir_main` is
    # renamed.  Both references in `relax_main` (the R.call_tir argument and
    # the direct call) must use the new GlobalVar.
    @I.ir_module
    class Expected:
        @R.function
        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            R.func_attr({"relax.force_pure": True})
            B = Expected.relax_subroutine(A)
            C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Expected.tir_main_with_new_name(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            Expected.tir_subroutine(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    tvm.ir.assert_structural_equal(Expected, after)
def test_replace_tir_subroutine():
    """An internal TIR function may be replaced

    In this example, the "tir_subroutine" function is renamed.  This
    requires changing both the GlobalVar used to refer to the
    function, and the GlobalVar used to call the subroutine within
    "tir_main".  Internal functions do not have the "global_symbol"
    attribute, so it does not need to be updated.
    """
    before = _get_before_module()
    after = before.replace_global_vars({"tir_subroutine": "tir_subroutine_with_new_name"})

    # Same as the module from `_get_before_module`, except that the private TIR
    # subroutine is renamed, and so is the call to it inside `tir_main`.
    @I.ir_module
    class Expected:
        @R.function
        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            R.func_attr({"relax.force_pure": True})
            B = Expected.relax_subroutine(A)
            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Expected.tir_main(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            Expected.tir_subroutine_with_new_name(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    tvm.ir.assert_structural_equal(Expected, after)
def test_simultaneous_replacements():
    """Multiple replacements may be performed simultaneously

    All four functions are renamed in a single call.  Every definition,
    and every call site that refers to a renamed function, must use the
    new name in the result.
    """
    before = _get_before_module()
    after = before.replace_global_vars(
        {
            "relax_main": "relax_main_with_new_name",
            "relax_subroutine": "relax_subroutine_with_new_name",
            "tir_main": "tir_main_with_new_name",
            "tir_subroutine": "tir_subroutine_with_new_name",
        }
    )

    # Same as the module from `_get_before_module`, with every function and
    # every reference between them renamed to its "*_with_new_name" form.
    @I.ir_module
    class Expected:
        @R.function
        def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"):
            R.func_attr({"relax.force_pure": True})
            B = Expected.relax_subroutine_with_new_name(A)
            C = R.call_tir(Expected.tir_main_with_new_name, B, out_sinfo=R.Tensor([16], "float32"))
            D = R.builtin.alloc_tensor(R.shape([16]), "float32", runtime_device_index=0)
            Expected.tir_main_with_new_name(C, D)
            return D

        @R.function(private=True)
        def relax_subroutine_with_new_name(
            A: R.Tensor([16], "float32"),
        ) -> R.Tensor([16], "float32"):
            B = R.add(A, R.prim_value(T.float32(1.0)))
            return B

        @T.prim_func
        def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
            Expected.tir_subroutine_with_new_name(A.data, B.data)

        @T.prim_func(private=True)
        def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
            A = T.decl_buffer(16, "float32", data=A_data)
            B = T.decl_buffer(16, "float32", data=B_data)
            for i in range(16):
                B[i] = A[i] + 1.0

    tvm.ir.assert_structural_equal(Expected, after)
if __name__ == "__main__":
    # Allow running this file directly as a script.  Presumably
    # tvm.testing.main() dispatches to pytest for this file's tests.
    tvm.testing.main()