# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
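
# Codegen tests for the Metal backend: special float constants (inf/nan),
# unaligned vectorized access, erf lowering, ramp vectors, vectorized Select,
# uint8 -> float32 casts, and kernels with trailing scalar (POD) parameters.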
import numpy as np
import tvm
import tvm.script
import tvm.testing
from tvm import te
from tvm.script import tir as T


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_inf_nan():
    target = "metal"

    def check_inf_nan(dev, n, value, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        inf_value = tvm.tir.const(value, dtype=dtype)
        C = te.compute((n,), lambda i: inf_value, name="C")
        prim_func = te.create_prim_func([A, C])
        sch = tvm.tir.Schedule(prim_func)
        (x,) = sch.get_loops(sch.get_block("C"))
        sch.bind(x, "threadIdx.x")
        fun = tvm.build(sch.mod, target=target)
        a = tvm.nd.empty((n,), A.dtype, dev)
        c = tvm.nd.empty((n,), A.dtype, dev)
        # Only need to test compiling here
        fun(a, c)

    dev = tvm.device(target, 0)

    check_inf_nan(dev, 1, -float("inf"), "float32")
    check_inf_nan(dev, 1, -float("inf"), "float16")
    check_inf_nan(dev, 1, float("inf"), "float32")
    check_inf_nan(dev, 1, float("inf"), "float16")
    check_inf_nan(dev, 1, float("nan"), "float32")
    check_inf_nan(dev, 1, float("nan"), "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_unaligned_vectorize():
    @tvm.script.ir_module
    class IRModule:
        @T.prim_func
        def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")):
            T.func_attr({"global_symbol": "main"})
            for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
                for i0_0 in T.vectorized(2):
                    with T.block("block"):
                        # Vector lanes touch B with stride 3, so the stores are
                        # non-contiguous (unaligned) vector accesses.
                        vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
                        B[vi0] = A[vi0 // 3, vi0 % 3]

    target = "metal"
    dev = tvm.metal()
    a = (np.arange(6).reshape(2, 3)).astype("float32")
    a_nd = tvm.nd.array(a, dev)
    b_nd = tvm.nd.empty((6,), "float32", dev)
    f = tvm.build(IRModule, target=target)
    f(a_nd, b_nd)
    np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_erf():
    target = "metal"

    def check_erf(dev, n, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
        func = te.create_prim_func([A, C])
        sch = tvm.tir.Schedule(func)
        (x,) = sch.get_loops(sch.get_block("C"))
        sch.bind(x, "threadIdx.x")
        fun = tvm.build(sch.mod, target=target)
        a = tvm.nd.empty((n,), A.dtype, dev)
        c = tvm.nd.empty((n,), A.dtype, dev)
        # Only need to test compiling here
        fun(a, c)

    dev = tvm.device(target, 0)

    check_erf(dev, 1, "float32")
    check_erf(dev, 1, "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_ramp():
    target = "metal"

    @tvm.script.ir_module
    class IRModule:
        @T.prim_func
        def main(A: T.Buffer((1, 2), "int32")):
            T.func_attr({"global_symbol": "main"})
            for i in T.thread_binding(1, thread="threadIdx.x"):
                with T.block("block"):
                    tx = T.axis.spatial(1, i)
                    # Store a ramp vector (base tx, stride 3, 2 lanes) into A[0, 0:2].
                    r = T.ramp(tx, 3, 2)
                    A[0, T.ramp(0, 1, 2)] = r

    f = tvm.build(IRModule, target=target)
    dev = tvm.metal()
    a_nd = tvm.nd.empty((1, 2), "int32", dev)
    f(a_nd)
    assert tuple(a_nd.numpy()[0, :]) == (0, 3)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_select_vectorize():
    @tvm.script.ir_module
    class IRModule:
        @T.prim_func
        def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")):
            T.func_attr({"global_symbol": "main"})
            for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
                for i0_0 in T.vectorized(2):
                    with T.block("block"):
                        vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
                        B[vi0] = T.Select((vi0 % 2) == 0, A[vi0], T.float32(0))

    target = "metal"
    dev = tvm.metal()
    a = np.arange(6).astype("float32")
    a_nd = tvm.nd.array(a, dev)
    b_nd = tvm.nd.empty((6,), "float32", dev)
    f = tvm.build(IRModule, target=target)
    f(a_nd, b_nd)
    # The Select keeps even indices and zeros odd ones; build the expected
    # result by zeroing the odd entries of the input.
    a.reshape(3, 2)[:, 1] = 0
    np.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_vectorized_uint8():
    @T.prim_func
    def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
        for i in T.thread_binding(4, thread="threadIdx.x"):
            for j in T.vectorized(4):
                with T.block("block"):
                    vi = T.axis.spatial(16, i * 4 + j)
                    B[vi] = T.Cast("float32", A[vi])

    dev = tvm.metal()
    a = np.arange(16).astype("uint8")
    a_nd = tvm.nd.array(a, dev)
    b_nd = tvm.nd.empty((16,), "float32", dev)
    f = tvm.build(func, target="metal")
    f(a_nd, b_nd)
    np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)


@tvm.testing.requires_metal(support_required="compile-only")
def test_func_with_trailing_pod_params():
    from tvm.contrib import xcode  # pylint: disable=import-outside-toplevel

    @T.prim_func
    def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32):
        for i in T.thread_binding(16, thread="threadIdx.x"):
            with T.block("block"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi] + x

    @tvm.register_func("tvm_callback_metal_compile")
    def compile_metal(src, target):
        return xcode.compile_metal(src)

    mod = tvm.IRModule({"main": func})
    f = tvm.build(mod, target="metal")
    src: str = f.imported_modules[0].get_source()
    # The trailing scalar parameter should be packed into a single kernel
    # argument struct in the generated Metal source.
    occurrences = src.count("struct func_kernel_args_t")
    assert occurrences == 1, occurrences


if __name__ == "__main__":
    tvm.testing.main()