blob: 3e9fd2cd6b67b07e071cd1d6720a636e6afd2398 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm_ffi
import tvm.testing
import numpy as np
from tvm.script import tir as T, ir as I
import pytest
def _reduce_sum_module(d1, d2, d3):
    """Build an IRModule that sums ``A`` of shape (1, d1, d2, d3) over its
    last axis into ``B`` of shape (1, d1, d2).

    Every loop is bound to a GPU thread axis and the reduction axis ``l`` is
    bound to threadIdx.x, so lowering must implement the reduction as a
    cross-thread allreduce.
    """
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")):
            for i in T.thread_binding(1, thread="blockIdx.x"):
                for j in T.thread_binding(d1, thread="threadIdx.z"):
                    for k in T.thread_binding(d2, thread="threadIdx.y"):
                        for l in T.thread_binding(d3, thread="threadIdx.x"):
                            with T.sblock("reduce"):
                                # "SSSR": i/j/k are spatial axes, l is the reduction axis.
                                vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
                                T.reads(A[vi, vj, vk, vl])
                                T.writes(B[vi, vj, vk])
                                with T.init():
                                    # Identity element for sum.
                                    B[vi, vj, vk] = T.float32(0.0)
                                B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
    return Module
def _reduce_max_module(d1, d2, d3):
    """Build an IRModule taking the max of ``A`` of shape (1, d1, d2, d3)
    over its last axis into ``B`` of shape (1, d1, d2).

    Mirrors ``_reduce_sum_module``: all loops are thread-bound and the
    reduction axis ``l`` maps to threadIdx.x, exercising a cross-thread
    allreduce with ``max`` as the combiner.
    """
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1, d1, d2, d3), "float32"), B: T.Buffer((1, d1, d2), "float32")):
            for i in T.thread_binding(1, thread="blockIdx.x"):
                for j in T.thread_binding(d1, thread="threadIdx.z"):
                    for k in T.thread_binding(d2, thread="threadIdx.y"):
                        for l in T.thread_binding(d3, thread="threadIdx.x"):
                            with T.sblock("reduce"):
                                # "SSSR": i/j/k are spatial axes, l is the reduction axis.
                                vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
                                T.reads(A[vi, vj, vk, vl])
                                T.writes(B[vi, vj, vk])
                                with T.init():
                                    # Identity element for max: lowest finite float32 (-FLT_MAX).
                                    B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
                                B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
    return Module
def generate_param_sets():
    """Yield every (d1, d2, d3) shape triple whose element product stays
    strictly under the 1024-thread-per-block limit."""
    d3_candidates = (2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024)
    return (
        (d1, d2, d3)
        for d1 in range(1, 5)
        for d2 in range(1, 5)
        for d3 in d3_candidates
        if d1 * d2 * d3 < 1024
    )
# Pytest parameter fixture: every test that takes ``dims`` runs once per
# (d1, d2, d3) triple produced by generate_param_sets().
dims = tvm.testing.parameter(*generate_param_sets())
@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_sum(dims, target, dev):
    """The compiled sum-reduction kernel must match numpy's sum over the
    last axis for every parametrized shape."""
    d1, d2, d3 = dims
    func = tvm.compile(_reduce_sum_module(d1, d2, d3), target=target)
    # Host-side reference computed with numpy.
    input_np = np.random.rand(1, d1, d2, d3).astype("float32")
    expected = input_np.sum(axis=-1).astype("float32")
    input_dev = tvm.runtime.tensor(input_np, dev)
    output_dev = tvm.runtime.tensor(np.zeros_like(expected), dev)
    # Launch the kernel and compare against the reference.
    func(input_dev, output_dev)
    tvm.testing.assert_allclose(output_dev.numpy(), expected, rtol=1e-6, atol=1e-6)
# Boolean parameter fixture: run the compile-only test both with and without
# a user-registered "tvm_callback_metal_compile" callback in place.
define_metal_compile_callback = tvm.testing.parameter(True, False)
@pytest.fixture
def optional_metal_compile_callback(define_metal_compile_callback):
    """Optionally install a Metal source-compile callback for one test.

    When ``define_metal_compile_callback`` is True, register a global
    "tvm_callback_metal_compile" function that compiles Metal source via the
    macOS SDK, then restore the previous registration (or remove the entry
    entirely if none existed) on teardown.
    """
    name = "tvm_callback_metal_compile"
    previous = tvm.get_global_func(name, allow_missing=True)

    if define_metal_compile_callback:
        def _compile_with_xcode(src, target):
            from tvm.contrib.xcode import compile_metal  # pylint: disable=import-outside-toplevel

            return compile_metal(src, sdk="macosx")

        tvm.register_global_func(name, _compile_with_xcode, override=True)

    yield

    if not define_metal_compile_callback:
        return
    if previous is None:
        tvm_ffi.registry.remove_global_func(name)
    else:
        tvm.register_global_func(name, previous, override=True)
@tvm.testing.requires_metal(support_required="compile-only")
def test_allreduce_sum_compile(optional_metal_compile_callback):
    """Smoke-test that the sum-reduction module compiles for Metal."""
    # Parametrization over dims is intentionally disabled here, at least for
    # now: one small shape is enough to exercise codegen.
    d1, d2, d3 = (1, 1, 2)
    tvm.compile(_reduce_sum_module(d1, d2, d3), target="metal")
@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_max(dims, target, dev):
    """The compiled max-reduction kernel must match numpy's max over the
    last axis for every parametrized shape."""
    d1, d2, d3 = dims
    func = tvm.compile(_reduce_max_module(d1, d2, d3), target=target)
    # All inputs are non-positive, so a zero-filled output buffer would be a
    # wrong answer: the kernel's own init value (-FLT_MAX) is being checked.
    input_np = -np.random.rand(1, d1, d2, d3).astype("float32")
    expected = input_np.max(axis=-1).astype("float32")
    input_dev = tvm.runtime.tensor(input_np, dev)
    output_dev = tvm.runtime.tensor(np.zeros_like(expected), dev)
    # Launch the kernel and compare against the reference.
    func(input_dev, output_dev)
    tvm.testing.assert_allclose(output_dev.numpy(), expected, rtol=1e-6, atol=1e-6)
# Allow running this file directly; tvm.testing.main() dispatches to pytest.
if __name__ == "__main__":
    tvm.testing.main()