blob: bba6d983c7a14f6782de1aa790056c23fb491067 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the thin ``tvm.testing.env`` capability probes."""
import pytest
import tvm
import tvm.testing
from tvm.testing import env
# Probes that take no arguments and must return a plain bool without raising.
_BOOL_PROBES = (
    # runtime device
    [
        env.has_cuda, env.has_rocm, env.has_vulkan, env.has_metal,
        env.has_opencl, env.has_nvptx, env.has_llvm, env.has_gpu,
    ]
    # build support
    + [
        env.has_cudnn, env.has_cublas, env.has_nccl, env.has_hipblas,
        env.has_cutlass, env.has_rpc, env.has_nnapi, env.has_openclml,
        env.has_mrvl, env.has_nvshmem,
    ]
    # version / capability
    + [env.has_tensorcore, env.has_matrixcore, env.has_cudagraph]
    # toolchain / environment
    + [
        env.has_hexagon, env.has_hexagon_toolchain,
        env.has_adreno_opencl, env.has_aprofile_aem_fvp,
    ]
    # cpu features
    + [
        env.has_arm_dot, env.has_arm_fp16, env.has_aarch64_sve,
        env.has_aarch64_sme, env.has_x86_vnni, env.has_x86_avx512,
        env.has_x86_amx,
    ]
    # host architecture
    + [env.is_x86, env.is_aarch64]
)
@pytest.mark.parametrize("probe", _BOOL_PROBES, ids=lambda fn: fn.__name__)
def test_probe_returns_bool(probe):
    """Each zero-argument probe yields an actual ``bool`` and does not raise."""
    outcome = probe()
    assert isinstance(outcome, bool)
def test_has_cuda_implies_device():
    """A positive has_cuda() must be backed by an existing CUDA device.

    (has_cuda() additionally requires the kind to be enabled; only the device
    half of that contract is checked here.)
    """
    if not env.has_cuda():
        return
    assert tvm.cuda().exist
def test_has_gpu_is_raw_any_device():
    """has_gpu() equals the OR of the raw per-kind device checks, with no target gating."""
    # Mirror ``a or b or ...`` exactly: stop at the first truthy check,
    # otherwise keep the last value seen.
    any_device = False
    for kind in ("cuda", "rocm", "opencl", "metal", "vulkan"):
        any_device = env._device_exists(kind)  # pylint: disable=protected-access
        if any_device:
            break
    assert env.has_gpu() == any_device
def test_target_enabled_respects_tvm_test_targets(monkeypatch):
    """Kinds listed in TVM_TEST_TARGETS are enabled; kinds left out are not."""
    # _target_enabled is cached (it exposes cache_clear), so start from an
    # empty cache and leave an empty one behind for later tests.
    env._target_enabled.cache_clear()  # pylint: disable=protected-access
    monkeypatch.setenv("TVM_TEST_TARGETS", "cuda;llvm")
    expectations = (
        ("cuda", True),
        ("llvm", True),
        ("opencl", False),
        ("metal", False),
    )
    try:
        for kind, enabled in expectations:
            assert bool(env._target_enabled(kind)) is enabled  # pylint: disable=protected-access
    finally:
        env._target_enabled.cache_clear()  # pylint: disable=protected-access
def test_tensorcore_implies_cuda():
    """has_tensorcore() may only be True when has_cuda() is True as well."""
    # Logical implication: tensorcore => cuda.
    assert not env.has_tensorcore() or env.has_cuda()
def test_cudagraph_implies_cuda():
    """has_cudagraph() may only be True when has_cuda() is True as well."""
    # Logical implication: cudagraph => cuda.
    assert not env.has_cudagraph() or env.has_cuda()
def test_cuda_compute_is_monotonic():
    """has_cuda_compute is monotone in the requested version.

    Queries an ascending ladder of ``(major, minor)`` floors and checks that the
    answers never flip from False back to True. A device that satisfies some
    floor must also satisfy every lower floor.
    """
    cuda_present = env.has_cuda()
    # Strictly ascending compute-capability floors. The values only need to be
    # ordered and to reach past real devices, so that both outcomes can appear.
    floors = [
        (0, 0), (1, 0), (3, 5), (5, 0), (6, 0), (7, 0), (7, 5),
        (8, 0), (8, 6), (8, 9), (9, 0), (10, 0), (12, 0), (99, 0),
    ]
    results = [env.has_cuda_compute(major, minor) for major, minor in floors]
    # Included in assertion messages so a failure shows the whole ladder.
    observed = list(zip(floors, results))
    if not cuda_present:
        # Without a CUDA device every query is False, including the (0, 0) floor.
        assert not any(results), observed
        return
    # Any CUDA device satisfies the trivial (0, 0) and (1, 0) floors.
    assert results[0] and results[1], observed
    # Monotone: the answers must have the shape [True, ..., True, False, ..., False].
    assert results == sorted(results, reverse=True), observed
def test_has_multi_gpu_is_bool():
    """has_multi_gpu() yields a bool, and asking for fewer devices is never stricter."""
    assert isinstance(env.has_multi_gpu(), bool)
    assert isinstance(env.has_multi_gpu(1), bool)
    # Monotonicity: if one device is not available, two cannot be either.
    if not env.has_multi_gpu(1):
        assert not env.has_multi_gpu(2)
@pytest.mark.parametrize(
    ("probe", "flag"),
    [
        (env.has_cutlass, "USE_CUTLASS"),
        (env.has_rpc, "USE_RPC"),
        (env.has_nnapi, "USE_NNAPI_CODEGEN"),
        (env.has_openclml, "USE_CLML"),
        (env.has_mrvl, "USE_MRVL"),
    ],
    # Probes are labelled by function name; flag strings are used verbatim.
    ids=lambda param: getattr(param, "__name__", param),
)
def test_build_flag_probe_matches_libinfo(probe, flag):
    """A probe that only wraps a build flag must report exactly that flag's value."""
    reported = probe()
    assert reported == env._build_flag_enabled(flag)  # pylint: disable=protected-access
@pytest.mark.parametrize(
    ("probe", "parent"),
    [
        (env.has_cudnn, env.has_cuda),
        (env.has_cublas, env.has_cuda),
        (env.has_nccl, env.has_cuda),
        (env.has_hipblas, env.has_rocm),
    ],
    ids=lambda fn: fn.__name__,
)
def test_library_probe_implies_parent_device(probe, parent):
    """A vendor library probe (CUDA/ROCm) is never True while its parent device probe is False."""
    if not probe():
        return
    assert parent()
def test_llvm_min_version_is_monotone():
    """has_llvm_min_version is monotone in the requested major version.

    An LLVM that satisfies a higher major-version floor must also satisfy
    every lower one, so the answers over ascending floors never flip from
    False back to True.
    """
    llvm_present = env.has_llvm()
    # Ascending major-version floors, chosen to reach past current LLVM
    # releases so that both outcomes can appear.
    majors = list(range(1, 41))
    results = [env.has_llvm_min_version(major) for major in majors]
    # Included in assertion messages so a failure shows the whole ladder.
    observed = list(zip(majors, results))
    if not llvm_present:
        # Without LLVM every floor is unsatisfied.
        assert not any(results), observed
        return
    # Any LLVM satisfies the trivial floor.
    assert results[0], observed
    # Monotone: the answers must have the shape [True, ..., True, False, ..., False].
    assert results == sorted(results, reverse=True), observed
def test_hexagon_run_implies_toolchain():
    """Whenever full Hexagon support is reported, the compile-time toolchain is reported too."""
    # Logical implication: hexagon => hexagon_toolchain.
    assert not env.has_hexagon() or env.has_hexagon_toolchain()
def test_probes_are_memoized():
    """Probes are cached so the driver/subprocess is hit once per process.

    A repeated query for the same device kind must be answered from the cache.
    It must register as a hit and must not record a new miss, which would mean
    the driver was probed again.
    """
    env.has_cuda()
    # Ensure the "cuda" entry is populated; it may already be from collection.
    env._device_exists("cuda")  # pylint: disable=protected-access
    before = env._device_exists.cache_info()  # pylint: disable=protected-access
    env._device_exists("cuda")  # pylint: disable=protected-access
    after = env._device_exists.cache_info()  # pylint: disable=protected-access
    assert after.hits == before.hits + 1
    assert after.misses == before.misses
# --- demonstration of the target idiom -------------------------------------
#
# The standard gating idiom: a plain registered pytest marker (for ``-m``
# selection) plus a skipif backed by a thin env probe (for runtime gating).
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_thin_cuda_idiom():
    """Example of the marker-plus-skipif pattern for a CUDA-only test."""
    assert tvm.cuda().exist
if __name__ == "__main__":
    # Allow running this file directly as a script through TVM's test entry point.
    tvm.testing.main()