# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
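"""Test the relay.transform.SplitArgs pass, which splits ops with many
inputs (here a wide concatenate) so that no generated kernel exceeds the
target's limit on the number of function arguments."""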
import tvm
import tvm.testing
from tvm import relay
from tvm.relay import transform
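
# tvm.testing.parameter registers these values as pytest fixtures;
# test_split_concat receives them by argument name, so it runs once per
# (target, shape kind) combination.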
target_name = tvm.testing.parameter("opencl", "metal", "cuda")
shape_type = tvm.testing.parameter("dynamic", "static")


def run_opt_pass(expr, opt_pass):
    """Wrap expr in an IRModule, run type inference, then apply opt_pass."""
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_split_concat(target_name, shape_type):
    """SplitArgs should split a wide concatenate into chained concatenates
    that each respect the target's argument limit."""
    if shape_type == "dynamic":
        shape = (tvm.tir.Any(), 1, 1, 3)
        number_of_any_dims = 1
    else:
        shape = (1, 1, 1, 3)
        number_of_any_dims = 0
    ndims = len(shape)
    dtype = "float32"
    axis = 1
    tensors_num = 300
    inputs = []
    for i in range(tensors_num):
        inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))

    def before():
        inp = relay.Tuple(inputs)
        return relay.op.concatenate(inp, axis)
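
    # expected() mirrors the splitting strategy the pass is expected to
    # apply: greedily pack inputs into groups that fit the argument limit,
    # concatenate each full group, and carry the intermediate result into
    # the next group.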
    def expected(limit):
        if limit == 0:
            # A limit of 0 means the target imposes no restriction.
            return before()
        limit = limit - 1  # reserve one argument for the output buffer
        if number_of_any_dims > 0:
            limit -= ndims

        new_args = []
        added_args = 0
        num_inputs = 0
        for inp in inputs:
            # Each tensor costs one buffer argument, plus extra arguments
            # for shape information when it has dynamic (Any) dimensions.
            curr_args = 1 + number_of_any_dims
            if number_of_any_dims > 0:
                curr_args += ndims
            num_inputs += curr_args
            if added_args + curr_args > limit:
                # The current group is full: concatenate it and start a
                # new group seeded with the intermediate result.
                t = relay.Tuple(new_args)
                stop = relay.annotation.stop_fusion(t)
                concat = relay.op.concatenate(stop, axis)
                new_args = [concat]
                added_args = curr_args
            added_args += curr_args
            new_args.append(inp)
        t = relay.Tuple(new_args)
        stop = relay.annotation.stop_fusion(t)
        concat = relay.op.concatenate(stop, axis)

        if num_inputs < limit:
            # All inputs fit within the limit, so no split should happen.
            return before()
        return concat

    # max_function_args is the target's limit on kernel arguments;
    # SplitArgs uses it to decide where to split.
    limit = tvm.target.Target(target_name).max_function_args
    res = run_opt_pass(before(), transform.SplitArgs(limit))
    exp = run_opt_pass(expected(limit), transform.InferType())
    tvm.ir.assert_structural_equal(res, exp)


if __name__ == "__main__":
    tvm.testing.main()