blob: 9618113ae3a9743e9dd4f24acc11d9e5c3d8c25b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import sys
import pytest
import tvm.testing
# This file tests features in tvm.testing, such as verifying that
# cached fixtures are run an appropriate number of times. As a
# result, the order of the tests is important. Use of --last-failed
# or --failed-first while debugging this file is not advised. If
# these tests are distributed/parallelized using pytest-xdist or
# similar, all tests in this file should run sequentially on the same
# node. (See https://stackoverflow.com/a/59504228)
class TestTargetAutoParametrization:
    """Verify that the implicit ``target``/``dev`` fixtures and the target
    decorators (explicit list, exclusion, known failures) parametrize over
    exactly the expected targets.

    Recording tests append to class-level lists and the checking tests
    that follow them read those lists, so method order matters.
    """

    # Shared, class-level accumulators filled in by the recording tests.
    targets_used = []
    devices_used = []
    enabled_targets = [tgt for tgt, _ in tvm.testing.enabled_targets()]
    enabled_devices = [device for _, device in tvm.testing.enabled_targets()]

    def test_target_parametrization(self, target):
        assert target in self.enabled_targets
        self.targets_used.append(target)

    def test_device_parametrization(self, dev):
        assert dev in self.enabled_devices
        self.devices_used.append(dev)

    def test_all_targets_used(self):
        observed = sorted(self.targets_used)
        expected = sorted(self.enabled_targets)
        assert observed == expected

    def test_all_devices_used(self):
        # Devices are ordered by (device type, index) so the two lists
        # can be compared regardless of the order tests ran in.
        def by_type_then_index(device):
            return (device.dlpack_device_type(), device.index)

        observed = sorted(self.devices_used, key=by_type_then_index)
        expected = sorted(self.enabled_devices, key=by_type_then_index)
        assert observed == expected

    targets_with_explicit_list = []

    @tvm.testing.parametrize_targets("llvm")
    def test_explicit_list(self, target):
        assert target == "llvm"
        self.targets_with_explicit_list.append(target)

    def test_no_repeats_in_explicit_list(self):
        # An explicit target list runs once per listed target, and only
        # when that target is enabled.
        expected = ["llvm"] if tvm.testing.device_enabled("llvm") else []
        assert self.targets_with_explicit_list == expected

    targets_with_exclusion = []

    @tvm.testing.exclude_targets("llvm")
    def test_exclude_target(self, target):
        assert "llvm" not in target
        self.targets_with_exclusion.append(target)

    def test_all_nonexcluded_targets_ran(self):
        observed = sorted(self.targets_with_exclusion)
        expected = sorted(
            [tgt for tgt in self.enabled_targets if not tgt.startswith("llvm")]
        )
        assert observed == expected

    run_targets_with_known_failure = []

    @tvm.testing.known_failing_targets("llvm")
    def test_known_failing_target(self, target):
        # Runs for every enabled target but deliberately fails on llvm;
        # correct behavior is for that case to be reported as xfail.
        # The target is recorded before the assertion so the llvm run is
        # still counted by test_all_targets_ran.
        self.run_targets_with_known_failure.append(target)
        assert "llvm" not in target

    def test_all_targets_ran(self):
        observed = sorted(self.run_targets_with_known_failure)
        expected = sorted(self.enabled_targets)
        assert observed == expected

    @tvm.testing.known_failing_targets("llvm")
    @tvm.testing.parametrize_targets("llvm")
    def test_known_failing_explicit_list(self, target):
        # Only llvm is listed and llvm is marked known-failing, so this
        # assertion is expected to fail (reported as xfail).
        assert target != "llvm"
class TestJointParameter:
    """Compare independent ``tvm.testing.parameter`` fixtures (full cross
    product) with joint ``tvm.testing.parameters`` (element-wise pairs)."""

    param1_vals = [1, 2, 3]
    param2_vals = ["a", "b", "c"]

    # Independent parameters: every param1 x param2 combination runs.
    independent_usages = 0
    param1 = tvm.testing.parameter(*param1_vals)
    param2 = tvm.testing.parameter(*param2_vals)

    # Joint parameters: values are paired up, one test per pair, with
    # human-readable ids.
    joint_usages = 0
    joint_param_vals = list(zip(param1_vals, param2_vals))
    joint_param_ids = ["apple", "pear", "banana"]
    joint_param1, joint_param2 = tvm.testing.parameters(*joint_param_vals, ids=joint_param_ids)

    def test_using_independent(self, param1, param2):
        # Counted on the class (not the instance) so later tests see the total.
        owner = type(self)
        owner.independent_usages += 1

    def test_independent(self):
        combinations = len(self.param1_vals) * len(self.param2_vals)
        assert self.independent_usages == combinations

    def test_using_joint(self, joint_param1, joint_param2):
        type(self).joint_usages += 1
        pair = (joint_param1, joint_param2)
        assert pair in self.joint_param_vals

    def test_joint(self):
        expected_runs = len(self.joint_param_vals)
        assert self.joint_usages == expected_runs

    def test_joint_test_id(self, joint_param1, joint_param2, request):
        # Strip the base test name and the surrounding brackets, leaving
        # only the parametrization id, e.g. "name[apple]" -> "apple".
        node = request.node
        suffix = node.name.replace(node.originalname, "")
        param_string = suffix.replace("[", "").replace("]", "")
        assert param_string in self.joint_param_ids
class TestFixtureCaching:
    """Count how often uncached vs. ``cache_return_value=True`` fixtures
    are actually invoked across a param1 x param2 parametrization."""

    param1_vals = [1, 2, 3]
    param2_vals = ["a", "b", "c"]

    param1 = tvm.testing.parameter(*param1_vals)
    param2 = tvm.testing.parameter(*param2_vals)

    # Invocation counters, kept on the class so they accumulate across tests.
    uncached_calls = 0
    cached_calls = 0

    @tvm.testing.fixture
    def uncached_fixture(self, param1):
        type(self).uncached_calls += 1
        return 2 * param1

    def test_use_uncached(self, param1, param2, uncached_fixture):
        assert 2 * param1 == uncached_fixture

    def test_uncached_count(self):
        # Without caching the fixture runs once per test, i.e. once per
        # (param1, param2) combination.
        expected = len(self.param1_vals) * len(self.param2_vals)
        assert self.uncached_calls == expected

    @tvm.testing.fixture(cache_return_value=True)
    def cached_fixture(self, param1):
        type(self).cached_calls += 1
        return 3 * param1

    def test_use_cached(self, param1, param2, cached_fixture):
        assert 3 * param1 == cached_fixture

    def test_cached_count(self):
        # With caching the fixture runs once per distinct param1 value;
        # when TVM_TEST_DISABLE_CACHE is set it behaves like the uncached
        # fixture and runs once per (param1, param2) combination.
        expected = len(self.param1_vals)
        if bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))):
            expected *= len(self.param2_vals)
        assert self.cached_calls == expected
class TestCachedFixtureIsCopy:
    """A cached fixture must give each test its own copy of the value."""

    param = tvm.testing.parameter(1, 2, 3, 4)

    @tvm.testing.fixture(cache_return_value=True)
    def cached_mutable_fixture(self):
        return dict(val=0)

    def test_modifies_fixture(self, param, cached_mutable_fixture):
        # Every parametrization must see the pristine value.  If the
        # cache handed out the original object instead of a copy, the
        # write below would leak into the next run and fail this check.
        assert cached_mutable_fixture["val"] == 0
        cached_mutable_fixture["val"] = param
class TestBrokenFixture:
    """Tests that use a fixture which raises must fail during setup.

    The test bodies are never run, so the usage counters must stay at
    zero.  This should hold whether or not the fixture's return value is
    cached.
    """

    # Tests that use a fixture that throws an exception fail, and are
    # marked as setup failures.  The tests themselves are never run.
    # This behavior should be the same whether or not the fixture
    # results are cached.

    # Incremented only if a test body ever runs, which it should not.
    num_uses_broken_uncached_fixture = 0
    num_uses_broken_cached_fixture = 0

    @tvm.testing.fixture
    def broken_uncached_fixture(self):
        raise RuntimeError("Intentionally broken fixture")

    # strict=True: if the fixture unexpectedly succeeds and the body
    # passes, the resulting XPASS is reported as a failure.
    @pytest.mark.xfail(True, reason="Broken fixtures should result in a failing setup", strict=True)
    def test_uses_broken_uncached_fixture(self, broken_uncached_fixture):
        # Must increment the same counter that test_num_uses_uncached
        # checks.  A misspelled attribute would raise AttributeError,
        # which the xfail marker would accept as an expected failure,
        # masking a fixture that had stopped raising.
        type(self).num_uses_broken_uncached_fixture += 1

    def test_num_uses_uncached(self):
        assert self.num_uses_broken_uncached_fixture == 0

    @tvm.testing.fixture(cache_return_value=True)
    def broken_cached_fixture(self):
        raise RuntimeError("Intentionally broken fixture")

    @pytest.mark.xfail(True, reason="Broken fixtures should result in a failing setup", strict=True)
    def test_uses_broken_cached_fixture(self, broken_cached_fixture):
        type(self).num_uses_broken_cached_fixture += 1

    def test_num_uses_cached(self):
        assert self.num_uses_broken_cached_fixture == 0
class TestAutomaticMarks:
    """Each way of parametrizing ``target`` must attach the pytest marks
    that ``tvm.testing.plugin._target_to_requirement`` reports for it."""

    @staticmethod
    def check_marks(request, target):
        # Gather the marks the target requires, then confirm each one
        # appears among the marks applied to the running test node.
        requirements = tvm.testing.plugin._target_to_requirement(target)
        needed = [requirement.mark for requirement in requirements]
        present = list(request.node.iter_markers())

        for mark in needed:
            assert mark in present

    def test_automatic_fixture(self, request, target):
        self.check_marks(request, target)

    @tvm.testing.parametrize_targets
    def test_bare_parametrize(self, request, target):
        self.check_marks(request, target)

    @tvm.testing.parametrize_targets("llvm", "cuda", "vulkan")
    def test_explicit_parametrize(self, request, target):
        self.check_marks(request, target)

    @pytest.mark.parametrize("target", ["llvm", "cuda", "vulkan"])
    def test_pytest_mark(self, request, target):
        self.check_marks(request, target)

    @pytest.mark.parametrize("target,other_param", [("llvm", 0), ("cuda", 1), ("vulkan", 2)])
    def test_pytest_mark_covariant(self, request, target, other_param):
        self.check_marks(request, target)
@pytest.mark.skipif(
    bool(int(os.environ.get("TVM_TEST_DISABLE_CACHE", "0"))),
    reason="Cannot test cache behavior while caching is disabled",
)
class TestCacheableTypes:
    """Which return types a ``cache_return_value=True`` fixture accepts."""

    class EmptyClass:
        """Plain class with no copy/pickle hooks; test_uses_uncacheable
        expects requesting a cached fixture of this type to raise TypeError."""

    @tvm.testing.fixture(cache_return_value=True)
    def uncacheable_fixture(self):
        return self.EmptyClass()

    def test_uses_uncacheable(self, request):
        # The fixture's use count is normally set before anything runs.
        # This test requests the fixture lazily via getfixturevalue, so
        # it has to bump that count by hand.
        use_count = self.uncacheable_fixture.num_tests_use_this_fixture
        use_count[0] += 1
        with pytest.raises(TypeError):
            request.getfixturevalue("uncacheable_fixture")

    class ImplementsReduce:
        """Explicitly overrides ``__reduce__`` (delegating to the default)."""

        def __reduce__(self):
            return super().__reduce__()

    @tvm.testing.fixture(cache_return_value=True)
    def fixture_with_reduce(self):
        return self.ImplementsReduce()

    def test_uses_reduce(self, fixture_with_reduce):
        """Succeeds as long as the fixture value can be provided."""

    class ImplementsDeepcopy:
        """Provides ``__deepcopy__`` by building a fresh instance."""

        def __deepcopy__(self, memo):
            return self.__class__()

    @tvm.testing.fixture(cache_return_value=True)
    def fixture_with_deepcopy(self):
        return self.ImplementsDeepcopy()

    def test_uses_deepcopy(self, fixture_with_deepcopy):
        """Succeeds as long as the fixture value can be provided."""
class TestPytestCache:
    """A plain class-scoped ``pytest.fixture`` that depends on a tvm
    parameter must yield the value for the current parameter."""

    param = tvm.testing.parameter(1, 2, 3)

    @pytest.fixture(scope="class")
    def cached_fixture(self, param):
        squared = param * param
        return squared

    def test_uses_cached_fixture(self, param, cached_fixture):
        expected = param * param
        assert cached_fixture == expected
if __name__ == "__main__":
    # Allow running this file directly; delegates to tvm.testing.main().
    tvm.testing.main()