# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import tvm.testing
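# Tests for TE reduction scheduling: plain reducers with masks and init values,
# rfactor (including tuple/argmax reductions), and warp-level reductions.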
@tvm.testing.requires_gpu
def test_reduce_prims():
def test_prim(reducer, np_reducer):
# graph
        n = te.size_var("n")
        m = te.size_var("m")
A = te.placeholder((n, m), name="A")
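        # R is a 0/1 mask; the `where` clause below restricts the reduction to
        # rows with i > 1, leaving the remaining rows at the reducer's identity.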
R = te.compute((n,), lambda i: tvm.tir.Select((i > 1), 1, 0), name="R")
k = te.reduce_axis((0, m))
B = te.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i] == 1)), name="B")
# schedule
s = te.create_schedule(B.op)
        # create iter vars and assign thread tags to them
num_thread = 1
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
s[R].compute_inline()
        # build and run the function on each requested device
def check_device(device, host="llvm"):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
freduce = tvm.build(s, args=[A, B], target=device, target_host=host, name="myreduce")
# launch the kernel.
n = 1028
m = 129
x = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
y = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
freduce(x, y)
npy = y.asnumpy()
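            # rows 0 and 1 are masked out by `where`, so exclude them from the comparison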
npy[:2] = 0
res = np_reducer(x.asnumpy(), axis=1)
res[:2] = 0
tvm.testing.assert_allclose(npy, res, rtol=1e-4)
check_device("metal")
check_device("vulkan")
check_device("cuda")
check_device("opencl")
check_device("rocm")
    test_prim(te.sum, np.sum)
    test_prim(te.min, np.amin)
    test_prim(te.max, np.amax)
def test_init_imm():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
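    # init=10.0 seeds the sum with 10 instead of the reducer's identity (0)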
B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name="B")
# schedule
s = te.create_schedule(B.op)
    # build and run the function on the requested target
def check_target(target="llvm"):
if not tvm.runtime.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi, target=target, name="mysum")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
res = 10.0 + np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target()
def test_init():
n = tvm.runtime.convert(1027)
A = te.placeholder((n, n), name="A")
C = te.placeholder((n, n), name="C")
I = te.placeholder((n, n), name="I")
k = te.reduce_axis((0, n))
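    # init=I[i, j] accumulates the matmul on top of an initializer tensor, i.e. B = I + A @ C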
B = te.compute((n, n), lambda i, j: te.sum(A[i, k] * C[k, j], axis=k, init=I[i, j]), name="B")
# schedule
s = te.create_schedule(B.op)
    # build and run the function on the requested target
def check_target(target="llvm"):
if not tvm.runtime.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, C, I, B])
print(fapi)
mmult = tvm.build(fapi, target=target, name="mmult")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n, n)).astype(A.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=(n, n)).astype(C.dtype), ctx)
ii = tvm.nd.array(np.random.uniform(size=(n, n)).astype(B.dtype), ctx)
b = tvm.nd.array(np.zeros((n, n), dtype=B.dtype), ctx)
mmult(a, c, ii, b)
res = ii.asnumpy() + np.matmul(a.asnumpy(), c.asnumpy())
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target()
def test_rfactor():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
# schedule
s = te.create_schedule(B.op)
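    # split the reduction axis and rfactor it: BF holds the partial sums, which
    # can run in parallel, and B combines them in a second reduction stage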
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf)
s[BF].parallel(BF.op.axis[0])
    # build and run the function on the requested target
def check_target(target="llvm"):
if not tvm.testing.device_enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi, target=target, name="mysum")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target()
def test_rfactor_init():
n = tvm.runtime.convert(1027)
A = te.placeholder((n, n), name="A")
C = te.placeholder((n, n), name="C")
I = te.placeholder((n, n), name="I")
k = te.reduce_axis((0, n))
B = te.compute((n, n), lambda i, j: te.sum(A[i, k] * C[k, j], axis=k, init=I[i, j]), name="B")
# schedule
s = te.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
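    # factor_axis=1 places the new rfactor dimension at axis 1 of BF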
BF = s.rfactor(B, kf, 1)
s[BF].parallel(BF.op.axis[0])
    # build and run the function on the requested target
def check_target(target="llvm"):
if not tvm.runtime.enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, C, I, B])
print(fapi)
mmult = tvm.build(fapi, target=target, name="mmult")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n, n)).astype(A.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=(n, n)).astype(C.dtype), ctx)
ii = tvm.nd.array(np.random.uniform(size=(n, n)).astype(B.dtype), ctx)
b = tvm.nd.array(np.zeros((n, n), dtype=B.dtype), ctx)
mmult(a, c, ii, b)
res = ii.asnumpy() + np.matmul(a.asnumpy(), c.asnumpy())
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target()
def test_rfactor_factor_axis():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name="A")
k = te.reduce_axis((0, n))
B = te.compute((1,), lambda i: te.sum(A[k], axis=k), name="B")
# schedule
s = te.create_schedule(B.op)
kf, ki = s[B].split(k, nparts=4)
BF = s.rfactor(B, kf, 1)
s[BF].parallel(BF.op.axis[0])
    # build and run the function on the requested target
def check_target(target="llvm"):
if not tvm.testing.device_enabled(target):
return
ctx = tvm.cpu(0)
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi, target=target, name="mysum")
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=0)
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target()
@tvm.testing.requires_gpu
def test_rfactor_threads():
nn = 1027
mm = 10
n = tvm.runtime.convert(nn)
m = tvm.runtime.convert(mm)
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n))
nthread = 16
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k, where=(i > 1)), name="B")
# schedule
s = te.create_schedule(B.op)
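    # factor the reduction across threadIdx.x: each thread accumulates a partial
    # sum in BF, and B reduces the per-thread partials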
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
bx, ty = s[B].split(s[B].op.axis[0], factor=nthread)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(ty, te.thread_axis("threadIdx.y"))
tx = s[B].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[B].bind(tx, thread_x)
s[BF].compute_at(s[B], tx)
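    # the cross-thread reduction leaves the result in every lane, so let only
    # threadIdx.x == 0 perform the write-back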
s[B].set_store_predicate(thread_x.var.equal(0))
    # build and run the function on each requested device
def check_target(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
fapi = tvm.lower(s, args=[A, B])
fsum = tvm.build(fapi, target=device, name="mysum")
# launch the kernel.
n = nn
m = mm
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1)
res[:2] = 0
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
check_target("rocm")
@tvm.testing.requires_gpu
def test_rfactor_elemwise_threads():
n = 1025
m = 10
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n))
nthread = 16
B = te.compute((m,), lambda i: te.sum(A[i, k], axis=k), name="B")
BB = te.compute((m,), lambda i: B[i] + 1, name="BB")
C = te.compute((m,), lambda i: BB[i] + 1, name="C")
# schedule
s = te.create_schedule(C.op)
s[BB].compute_inline()
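    # inline BB so the two elementwise +1 ops fuse into C, then compute the
    # factored reduction B inside C's thread loop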
bx, ty = s[C].split(s[C].op.axis[0], factor=nthread)
ko, kf = s[B].split(k, factor=nthread)
BF = s.rfactor(B, kf)
s[B].compute_at(s[C], ty)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
tx = s[B].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[B].bind(tx, thread_x)
s[BF].compute_at(s[B], tx)
    # Since thread_x is shared across reductions,
    # only one of them needs to do the write-back
s[B].set_store_predicate(thread_x.var.equal(0))
s[C].set_store_predicate(thread_x.var.equal(0))
    # build and run the function on each requested device
def check_target(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
fapi = tvm.lower(s, args=[A, C])
fsum = tvm.build(fapi, target=device, name="mysum")
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=(m, n)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
fsum(a, b)
res = np.sum(a.asnumpy(), axis=1) + 2
tvm.testing.assert_allclose(b.asnumpy(), res, rtol=1e-4)
check_target("vulkan")
check_target("cuda")
check_target("metal")
check_target("opencl")
check_target("rocm")
def test_argmax():
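    # argmax expressed as a commutative reducer over (index, value) pairs:
    # fcombine keeps the pair with the larger value, fidentity supplies
    # (-1, min_value) as the identity element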
def fcombine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
m = te.size_var("m")
n = te.size_var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="float32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
s = te.create_schedule(T0.op)
def check_target():
device = "cpu"
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
ctx = tvm.context(device, 0)
fapi = tvm.lower(s, args=[idx, val, T0, T1])
fargmax = tvm.build(fapi, target="llvm", name="argmax")
mm = 12
nn = 16
np_idx = np.repeat(np.arange(nn, dtype="int32").reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype("float32")
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype="float32"), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target()
@tvm.testing.requires_gpu
def test_rfactor_argmax():
def fcombine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
def fidentity(t0, t1):
return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
nn = 1027
mm = 10
n = tvm.runtime.convert(nn)
m = tvm.runtime.convert(mm)
A0 = te.placeholder((m, n), name="A0", dtype="int32")
A1 = te.placeholder((m, n), name="A1", dtype="float32")
k = te.reduce_axis((0, n))
B0, B1 = te.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name="B")
# schedule
s = te.create_schedule(B0.op)
nthread = 16
ko, kf = s[B0].split(k, factor=nthread)
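    # rfactor on a tuple (multi-output) reduction returns one factored tensor
    # per output: BF0 for the indices, BF1 for the values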
BF0, BF1 = s.rfactor(B0, kf)
bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread)
s[B0].bind(bx, te.thread_axis("blockIdx.x"))
s[B0].bind(ty, te.thread_axis("threadIdx.y"))
tx = s[B0].op.reduce_axis[0]
thread_x = te.thread_axis("threadIdx.x")
s[B0].bind(tx, thread_x)
s[BF0.op].compute_at(s[B0], tx)
s[B0].set_store_predicate(thread_x.var.equal(0))
def check_target(device):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
fapi = tvm.lower(s, args=[A0, A1, B0, B1])
fargmax = tvm.build(fapi, target=device, name="argmax")
np_idx = np.repeat(np.arange(nn, dtype="int32").reshape(1, nn), mm, axis=0)
np_val = np.random.uniform(size=(mm, nn)).astype("float32")
np_res = np.argmax(np_val, axis=1)
nd_idx = tvm.nd.array(np_idx, ctx)
nd_val = tvm.nd.array(np_val, ctx)
nd_res0 = tvm.nd.array(np.zeros(mm, dtype="int32"), ctx)
nd_res1 = tvm.nd.array(np.zeros(mm, dtype="float32"), ctx)
fargmax(nd_idx, nd_val, nd_res0, nd_res1)
tvm.testing.assert_allclose(np_res, nd_res0.asnumpy())
check_target("cuda")
check_target("vulkan")
check_target("rocm")
@tvm.testing.requires_gpu
def test_warp_reduction1():
nthx = 32
nthy = 4
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, nthx), "threadIdx.x")
thread_y = te.thread_axis((0, nthy), "threadIdx.y")
def check_target(device, m, n):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
# compute
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n))
B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name="B")
s = te.create_schedule(B.op)
# schedule
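        # bind the outer (32-lane) part of the reduction axis to threadIdx.x;
        # this is expected to lower to a warp-level reduction on CUDA/ROCm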
k = s[B].op.reduce_axis[0]
ko, _ = s[B].split(k, nparts=nthx)
s[B].bind(ko, thread_x)
xo, xi = s[B].split(s[B].op.axis[0], factor=nthy)
s[B].bind(xi, thread_y)
s[B].bind(xo, block_x)
tvm.lower(s, [A, B], simple_mode=True)
# validation
func = tvm.build(s, [A, B], device, name="warp_reduction")
a_np = np.random.uniform(size=(m, n)).astype(A.dtype)
b_np = np.zeros((m,), dtype=A.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
b_np = np.max(a_np, axis=1)
func(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
check_target("cuda", m=32, n=256)
check_target("cuda", m=10, n=20)
check_target("rocm", m=32, n=256)
check_target("rocm", m=10, n=20)
    # disabled: m=10, n=37 currently hits a bug in the normal reduction path
# check_target("cuda", m=10, n=37)
@tvm.testing.requires_gpu
def test_warp_reduction2():
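    # a two-output reducer: element 0 accumulates a sum, element 1 a product,
    # with fidentity supplying the matching identities (0, 1)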
def fcombine(x, y):
return x[0] + y[0], x[1] * y[1]
def fidentity(t0, t1):
return tvm.tir.const(0, t0), tvm.tir.const(1, t1)
add_mul_reducer = te.comm_reducer(fcombine, fidentity, name="add_mul_reducer")
# compute
m = 16
n = 256
A0 = te.placeholder((m, n), name="A0", dtype="float32")
    A1 = te.placeholder((m, n), name="A1", dtype="float32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name="T")
nthdx, nthdy = 32, 2
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, nthdx), "threadIdx.x")
thread_y = te.thread_axis((0, nthdy), "threadIdx.y")
def check_target(device):
ctx = tvm.context(device, 0)
if not tvm.testing.device_enabled(device):
print("skip because %s is not enabled.." % device)
return
# schedule
s = te.create_schedule(T0.op)
ko, _ = s[T0].split(k, nparts=nthdx)
xo, xi = s[T0].split(s[T0].op.axis[0], factor=nthdy)
s[T0].bind(ko, thread_x)
s[T0].bind(xi, thread_y)
s[T0].bind(xo, block_x)
# validation
a0_np = np.random.uniform(size=(m, n)).astype(A0.dtype)
a1_np = np.random.uniform(size=(m, n)).astype(A1.dtype)
t0_np = np.zeros((m,), dtype=A0.dtype)
t1_np = np.zeros((m,), dtype=A1.dtype)
a0 = tvm.nd.array(a0_np, ctx)
a1 = tvm.nd.array(a1_np, ctx)
t0 = tvm.nd.array(t0_np, ctx)
t1 = tvm.nd.array(t1_np, ctx)
func = tvm.build(s, [A0, A1, T0, T1], device, name="reduction")
func(a0, a1, t0, t1)
t0_np = np.sum(a0_np, axis=1)
        t1_np = np.prod(a1_np, axis=1)
tvm.testing.assert_allclose(t0.asnumpy(), t0_np, rtol=1e-3, atol=1e-3)
tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3)
check_target("cuda")
check_target("rocm")
if __name__ == "__main__":
test_rfactor_elemwise_threads()
test_rfactor_threads()
test_rfactor_factor_axis()
test_rfactor()
test_reduce_prims()
test_argmax()
test_rfactor_argmax()
test_warp_reduction1()
test_warp_reduction2()
test_init()
test_init_imm()
test_rfactor_init()