# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
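"""Tests for cross-thread allreduce on CUDA and Metal targets."""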

import numpy as np
import pytest

import tvm
import tvm.contrib.xcode
import tvm.testing
from tvm.script import tir as T
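

# Sum-reduce A over its last axis into B: B[i, j, k] = sum(A[i, j, k, :]).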
@T.prim_func
def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
    A = T.match_buffer(a, [1, d1, d2, d3])
    B = T.match_buffer(b, [1, d1, d2])
    for i, j, k, l in T.grid(1, d1, d2, d3):
        with T.block("reduce"):
            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
            with T.init():
                B[vi, vj, vk] = 0.0
            B[vi, vj, vk] = B[vi, vj, vk] + A[vi, vj, vk, vl]
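

# Max-reduce A over its last axis into B: B[i, j, k] = max(A[i, j, k, :]),
# with the accumulator initialized to the lowest finite float32 value.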
@T.prim_func
def reduce_max(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None:
    A = T.match_buffer(a, [1, d1, d2, d3])
    B = T.match_buffer(b, [1, d1, d2])
    for i, j, k, l in T.grid(1, d1, d2, d3):
        with T.block("reduce"):
            vi, vj, vk, vl = T.axis.remap("SSSR", [i, j, k, l])
            with T.init():
                B[vi, vj, vk] = T.float32(-3.4028234663852886e38)
            B[vi, vj, vk] = T.max(B[vi, vj, vk], A[vi, vj, vk, vl])
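

# Enumerate (d1, d2, d3) shapes whose product stays below 1024, the usual
# per-block thread limit, since all three dims are bound to thread axes.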
def generate_param_sets():
    for d1 in range(1, 5):
        for d2 in range(1, 5):
            for d3 in [2, 4, 8, 12, 16, 32, 48, 64, 100, 128, 201, 256, 512, 1024]:
                if d1 * d2 * d3 < 1024:
                    yield (d1, d2, d3)


dims = tvm.testing.parameter(*generate_param_sets())


@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_sum(dims, target, dev):
    d1, d2, d3 = dims
    # Specialize the symbolic shape parameters to the concrete test dims.
    _, _, _d1, _d2, _d3 = reduce.params
    mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
    sch = tvm.tir.Schedule(mod)
    blk = sch.get_block("reduce")
    i, j, k, l = sch.get_loops(blk)
    sch.bind(i, "blockIdx.x")
    sch.bind(j, "threadIdx.z")
    sch.bind(k, "threadIdx.y")
    sch.bind(l, "threadIdx.x")
    f = tvm.build(sch.mod["main"], target=target)
    # Prepare input and output arrays.
    a_np = np.random.rand(1, d1, d2, d3).astype("float32")
    b_np = a_np.sum(axis=-1).astype("float32")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros_like(b_np), dev)
    # Launch the kernel and verify against NumPy.
    f(a, b)
    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)
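

# Run the compile-only test below both with and without an explicit
# "tvm_callback_metal_compile" codegen callback registered.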
define_metal_compile_callback = tvm.testing.parameter(True, False)


@pytest.fixture
def optional_metal_compile_callback(define_metal_compile_callback):
    name = "tvm_callback_metal_compile"
    cached = tvm.get_global_func(name, allow_missing=True)

    if define_metal_compile_callback:

        @tvm.register_func(name, override=True)
        def compile_metal(src, target):
            return tvm.contrib.xcode.compile_metal(src, sdk="macosx")

    yield

    # Restore whatever was registered before the test ran.
    if define_metal_compile_callback:
        if cached is None:
            tvm._ffi.registry.remove_global_func(name)
        else:
            tvm.register_func(name, cached, override=True)


@tvm.testing.requires_metal(support_required="compile-only")
def test_allreduce_sum_compile(optional_metal_compile_callback):
    # Disable the parametrization over dims, at least for now.
    dims = (1, 1, 2)
    target = "metal"

    d1, d2, d3 = dims
    _, _, _d1, _d2, _d3 = reduce.params
    mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3})
    sch = tvm.tir.Schedule(mod)
    blk = sch.get_block("reduce")
    i, j, k, l = sch.get_loops(blk)
    sch.bind(i, "blockIdx.x")
    sch.bind(j, "threadIdx.z")
    sch.bind(k, "threadIdx.y")
    sch.bind(l, "threadIdx.x")
    # Only check that compilation succeeds; nothing is executed on device.
    tvm.build(sch.mod["main"], target=target)
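

# Same schedule as test_allreduce_sum, but exercising a max reduction.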
@tvm.testing.parametrize_targets("cuda", "metal")
def test_allreduce_max(dims, target, dev):
    d1, d2, d3 = dims
    _, _, _d1, _d2, _d3 = reduce_max.params
    mod = reduce_max.specialize({_d1: d1, _d2: d2, _d3: d3})
    sch = tvm.tir.Schedule(mod)
    blk = sch.get_block("reduce")
    i, j, k, l = sch.get_loops(blk)
    sch.bind(i, "blockIdx.x")
    sch.bind(j, "threadIdx.z")
    sch.bind(k, "threadIdx.y")
    sch.bind(l, "threadIdx.x")
    f = tvm.build(sch.mod["main"], target=target)
    # Use negative inputs so the expected maxima are nonzero and cannot be
    # confused with the zero-filled output buffer.
    a_np = -np.random.rand(1, d1, d2, d3).astype("float32")
    b_np = a_np.max(axis=-1).astype("float32")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros_like(b_np), dev)
    # Launch the kernel and verify against NumPy.
    f(a, b)
    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
    tvm.testing.main()