# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import tvm.script
import tvm.testing
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T

use_np_array = tvm.testing.parameter(False, True)
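

# BindParams should inline the bound weight as a constant and drop it from the
# parameter list.  `use_np_array` covers passing the bound value either as a
# numpy array or as a tvm.nd.NDArray.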
def test_bind_params(use_np_array):
    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
            T.func_attr({"global_symbol": "tir_matmul"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16, 16))
            C = T.match_buffer(z, (16, 16))
            for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
                with T.block("matmul"):
                    vi = T.axis.S(16, i0 * 4 + i1)
                    vj = T.axis.S(16, j)
                    vk = T.axis.R(16, k0 * 4 + k1)
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        @R.function
        def main(
            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
        ) -> R.Tensor((16, 16), "float32"):
            gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((16, 16), dtype="float32"))
            return gv0

    x_np = np.random.rand(16, 16).astype(np.float32)
    w_np = np.random.rand(16, 16).astype(np.float32)
    x_tvm = tvm.nd.array(x_np)
    w_tvm = tvm.nd.array(w_np)

    # Bind "w" to a constant; the resulting "main" should take only "x".
    params_dict = {"w": w_np if use_np_array else w_tvm}
    mod = relax.transform.BindParams("main", params_dict)(InputModule)
    assert len(mod["main"].params) == 1

    # The bound module must produce the same result as the original module
    # called with the same weight.
    target = tvm.target.Target("llvm")
    ex_after = relax.build(mod, target)
    vm_after = relax.VirtualMachine(ex_after, tvm.cpu())
    res_after = vm_after["main"](x_tvm)

    ex_before = relax.build(InputModule, target)
    vm_before = relax.VirtualMachine(ex_before, tvm.cpu())
    res_before = vm_before["main"](x_tvm, w_tvm)

    tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy())
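

# Binding parameters whose shapes use the symbolic vars m, n, and k should
# specialize the remaining struct info to the concrete extents, while the
# unbound `batch` dimension stays symbolic.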
def test_bind_params_symbolic_vars():
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(("batch", "m"), dtype="float32"),
            w0: R.Tensor(("n", "m"), dtype="float32"),
            b0: R.Tensor(("n",), dtype="float32"),
            w1: R.Tensor(("k", "n"), dtype="float32"),
            b1: R.Tensor(("k",), dtype="float32"),
        ) -> R.Tensor(("batch", "k"), dtype="float32"):
            batch = T.Var("batch", "int64")
            k = T.Var("k", "int64")
            m = T.Var("m", "int64")
            n = T.Var("n", "int64")
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32")
                )
                out = R.call_dps_packed(
                    "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32")
                )
                R.output(out)
            return out

    m, n, k = 4, 6, 8
    w0_tvm = tvm.nd.array(np.random.rand(n, m).astype(np.float32))
    b0_tvm = tvm.nd.array(np.random.rand(n).astype(np.float32))
    w1_tvm = tvm.nd.array(np.random.rand(k, n).astype(np.float32))
    b1_tvm = tvm.nd.array(np.random.rand(k).astype(np.float32))
    params_dict = {"w0": w0_tvm, "b0": b0_tvm, "w1": w1_tvm, "b1": b1_tvm}
    mod = relax.transform.BindParams("main", params_dict)(Before)

    # The bound module contains ConstantNodes, so it is hard to compare with
    # structural equality.  Instead, check the struct info of the remaining
    # parameter, the return value, and the dataflow bindings.
    func = mod["main"]
    assert len(func.params) == 1
    batch = func.params[0].struct_info.shape[0]

    tvm.ir.assert_structural_equal(
        func.params[0].struct_info, relax.TensorStructInfo((batch, 4), "float32")
    )
    tvm.ir.assert_structural_equal(
        func.ret_struct_info, relax.TensorStructInfo((batch, 8), "float32")
    )

    bindings = func.body.blocks[0].bindings
    tvm.ir.assert_structural_equal(
        bindings[0].var.struct_info, relax.TensorStructInfo((batch, 6), "float32")
    )
    tvm.ir.assert_structural_equal(
        bindings[1].var.struct_info, relax.TensorStructInfo((batch, 8), "float32")
    )


param_specification = tvm.testing.parameter("by_string", "by_var")
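

# The parameter to bind may be identified either by its name ("by_string") or
# by the relax.Var object itself ("by_var").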
def test_bind_params_by_var_obj(param_specification):
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([16], "float32")):
            return A

    np_data = np.arange(16).astype("float32")
    inlined_relax_const = relax.const(np_data)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main():
            return inlined_relax_const

    if param_specification == "by_string":
        var = "A"
    elif param_specification == "by_var":
        var = Before["main"].params[0]
    else:
        raise ValueError(f"Unknown param_specification: {param_specification}")

    After = relax.transform.BindParams("main", {var: np_data})(Before)

    tvm.ir.assert_structural_equal(Expected, After)
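

# Binding by parameter name alone should produce the same inlined constant.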
def test_bind_params_by_var_name():
    @tvm.script.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([16], "float32")):
            return A

    np_data = np.arange(16).astype("float32")
    inlined_relax_const = relax.const(np_data)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main():
            return inlined_relax_const

    After = relax.transform.BindParams("main", {"A": np_data})(Before)

    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
    tvm.testing.main()