# blob: bd058e89797967813429a942f8bfbbf304ee33be [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, List, Optional, Union
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.relax import Expr
from tvm.relax.utils import args_converter
def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None:
# Test converting to `Expr`
assert f_checker(arg)
# Test converting `*args`
assert isinstance(args, tuple)
assert all([f_checker(arg) for arg in args])
# Test converting `**kwargs`
assert isinstance(kwargs, dict)
assert all([f_checker(arg) for arg in kwargs.values()])
def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None:
    """Check that ``arg``, every item of ``args`` and every value of ``kwargs`` is an ``Expr``."""

    def _is_expr(value):
        return isinstance(value, Expr)

    _test_base(_is_expr, arg, *args, **kwargs)
def _test_optional_expr(
    arg: Optional[Expr], *args: Optional[Expr], **kwargs: Optional[Expr]
) -> None:
    """Check that each received value is either ``None`` or an ``Expr``."""

    def _is_expr_or_none(value):
        if value is None:
            return True
        return isinstance(value, Expr)

    _test_base(_is_expr_or_none, arg, *args, **kwargs)
def _test_list_expr(arg: List[Expr], *args: List[Expr], **kwargs: List[Expr]) -> None:
    """Check that each received value is a ``list`` whose items are all ``Expr``."""

    def _is_expr_list(value):
        if not isinstance(value, list):
            return False
        return all(isinstance(item, Expr) for item in value)

    _test_base(_is_expr_list, arg, *args, **kwargs)
def _test_optional_list_expr(
    arg: Optional[List[Expr]], *args: Optional[List[Expr]], **kwargs: Optional[List[Expr]]
) -> None:
    """Check that each received value is ``None`` or a ``list`` containing only ``Expr``."""

    def _is_expr_list_or_none(value):
        if value is None:
            return True
        if not isinstance(value, list):
            return False
        return all(isinstance(item, Expr) for item in value)

    _test_base(_is_expr_list_or_none, arg, *args, **kwargs)
# Shared inputs handed to the wrapped checker functions by the tests below.
prim_value = 1  # plain Python int
str_value = "value_to_be_convert"  # plain Python str
shape_value = (1, 1)  # tuple holding only Python ints
tuple_value = (relax.const(1), (1, 1))  # tuple mixing a relax.const result with a nested int tuple
placeholder = relax.const(0)  # built via relax.const; presumably already an Expr -- TODO confirm
# Iterated by test_args_to_expr, test_args_to_list_expr and test_auto_convert.
test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder]
def test_args_to_expr():
    """Feed every test case through checkers wrapped by ``args_converter.to_expr``.

    The wrapped checker asserts on what it receives, so a failed check surfaces
    here as an ``AssertionError``.
    """
    for checker in (_test_expr, _test_optional_expr):
        wrapped = args_converter.to_expr("arg", "args", "kwargs")(checker)
        accepts_none = checker is _test_optional_expr
        for case in test_cases:
            # The same value is used as `arg`, as both entries of `*args`,
            # and as a keyword value.
            wrapped(case, case, case, test_kwargs=case)

            if accepts_none:
                # Only the Optional-annotated checker is exercised with `None`.
                wrapped(None, None, case, test_kwargs=None)
def test_args_to_list_expr():
    """Feed lists built from every test case through ``args_converter.to_list_expr`` wrappers.

    The wrapped checker asserts on what it receives, so a failed check surfaces
    here as an ``AssertionError``.
    """
    for checker in (_test_list_expr, _test_optional_list_expr):
        wrapped = args_converter.to_list_expr("arg", "args", "kwargs")(checker)
        accepts_none = checker is _test_optional_list_expr
        for case in test_cases:
            single = [case]  # passed as `arg`
            first_extra = [case]  # the first argument in *args
            second_extra = [case, case]  # the second argument in *args
            wrapped(single, first_extra, second_extra, test_kwargs=[case, (case,)])

            if accepts_none:
                # Only the Optional-annotated checker is exercised with `None`.
                wrapped(None, None, [case], test_kwargs=None)
def test_error():
    """A non-list value given to a ``to_list_expr``-wrapped function must raise ``TypeError``."""
    wrapped = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr)
    with pytest.raises(TypeError):
        # fail to convert prim_value to `List[Expr]`
        wrapped(prim_value)
def test_auto_convert():
    """Feed every test case through checkers wrapped by ``args_converter.auto``.

    Unlike the explicit ``to_expr`` / ``to_list_expr`` tests, no parameter names
    are passed to the wrapper here; only the checker function itself is.
    """
    # Checkers whose parameters are annotated with (optional) `Expr`.
    for checker in (_test_expr, _test_optional_expr):
        wrapped = args_converter.auto(checker)
        accepts_none = checker is _test_optional_expr
        for case in test_cases:
            wrapped(case, (case,), test_kwargs=case)

            if accepts_none:
                wrapped(None, case, test_kwargs=None)

    # Checkers whose parameters are annotated with (optional) `List[Expr]`.
    for checker in (_test_list_expr, _test_optional_list_expr):
        wrapped = args_converter.auto(checker)
        accepts_none = checker is _test_optional_list_expr
        for case in test_cases:
            wrapped([case], [case, case], test_kwargs=[case, (case,)])

            if accepts_none:
                wrapped(None, None, [case], test_kwargs=None)
def test_auto_convert_skip():
    """Values must reach an ``auto``-wrapped function as non-``Expr`` objects.

    The inner function is annotated ``int``, ``Union[str, Expr]`` and
    ``List[Optional[Expr]]``; its checker rejects any value that is an ``Expr``.
    """

    def _test_expr_skip(arg: int, *args: Union[str, Expr], **kwargs: List[Optional[Expr]]) -> None:
        def _is_not_expr(value):
            return not isinstance(value, Expr)

        _test_base(_is_not_expr, arg, *args, **kwargs)

    wrapped = args_converter.auto(_test_expr_skip)
    wrapped(1, "str", test_kwargs=[None])
def test_empty_tuple():
    """An empty Python tuple given to an ``Expr``-annotated parameter must arrive as ``relax.Tuple``."""

    def _test(arg: Expr):
        assert isinstance(arg, relax.Tuple)

    wrapped = args_converter.auto(_test)
    empty = ()
    wrapped(empty)
# When this file is executed directly as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()