# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
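"""TVM-generated element-wise operators (broadcast add, deg2rad, rad2deg) with
forward and backward kernels for both CPU and CUDA targets."""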
import tvm
from .. import defop, AllTypes, RealTypes
from .. import assign_by_req, reduce_axes
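# As used in this file: `defop` registers the returned (schedule, tensors) pair as a
# TVM operator for the given target and parameter grid, `AllTypes` lists the supported
# dtype strings, `reduce_axes` sums a tensor over the axes marked with 1, and
# `assign_by_req` merges a result into the output according to the kWriteTo/kAddTo
# write request.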
def compute_add(dtype, ndim):
    A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
    B = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='B', dtype=dtype)
    C = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                       lambda *index: A[index] + B[index], name='C')
    s = tvm.te.create_schedule(C.op)
    return s, A, B, C
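
# Forward kernels: `vadd` fuses all output axes and parallelizes them on the CPU;
# `cuda_vadd` additionally splits the fused axis into blocks of 64 threads and binds
# them to the CUDA grid.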
@defop(name="vadd", target="cpu", auto_broadcast=True,
dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
s[C].parallel(fused)
return s, [A, B, C]
@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=[5])
def vadd_gpu(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
s = tvm.te.create_schedule(C.op)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
bx, tx = s[C].split(fused, factor=64)
s[C].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B, C]
def compute_backward_vadd(dtype, ndim, reduce1st, req):
    # The backward of a broadcast op is basically a reduction over the broadcast axes.
    # We label the reduce axes as 1 and the other axes as 0, and they form a bit string.
    # Each bit string corresponds to a kernel, so there would be as many as `2^n` kernels.
    # To reduce this, the bit string is compressed by combining consecutive 0s or 1s.
    # In this way, the number of bit strings (and hence the number of kernels) drops to `2 * n`.
    # The compressed bit string is stored in `axes`, and `reduce1st` represents the first bit
    # of the compressed bit string. Credit to @junrushao1994 and @yzhliu.
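    # For example (ndim=5, reduce1st=1): axes = [1, 0, 1, 0, 1], i.e. the 1st, 3rd and
    # 5th compressed axis groups are reduced while the 2nd and 4th are kept.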
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    X = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='X', dtype=dtype)
    reducer = tvm.te.comm_reducer(lambda x, y: x + y,
                                  lambda t: tvm.tir.const(0, dtype=t), name="sum")
    ret = reduce_axes(X, axes, reducer)
    in_grad_a, in_grad = assign_by_req(ret, req)
    s = tvm.te.create_schedule(in_grad.op)
    return s, X, in_grad_a, in_grad, [ret, in_grad]
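
# The backward kernels schedule both stages in c_list: the reduction result `ret`
# and the final `in_grad` produced by assign_by_req.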
@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [X, in_grad_a, in_grad]
@defop(name="cuda_backward_vadd", target="gpu", dtype=["float32", "float64"],
ndim=[5], reduce1st=[0, 1],
req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd_gpu(dtype, ndim, reduce1st, req):
s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
num_thread = 64
for t in c_list:
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [X, in_grad_a, in_grad]
def compute_degandrad(dtype, ndim, n):
    A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
    import math
    if n == 0:
        B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                           lambda *index: A[index] * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype),
                           name='B')
    else:
        B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                           lambda *index: A[index] / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype),
                           name='B')
    s = tvm.te.create_schedule(B.op)
    return s, A, B
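
# The kernels below use compute_degandrad with n == 0 for deg2rad (scale by pi/180)
# and n == 1 for rad2deg (scale by 180/pi).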
@defop(name="deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 0)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]
@defop(name="rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
s[B].parallel(fused)
return s, [A, B]
@defop(name="cuda_deg2rad", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def deg2rad_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 0)
s = tvm.te.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B]
@defop(name="cuda_rad2deg", target="cuda", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)))
def rad2deg_gpu(dtype, ndim):
s, A, B = compute_degandrad(dtype, ndim, 1)
s = tvm.te.create_schedule(B.op)
axes = [axis for axis in B.op.axis]
fused = s[B].fuse(*axes)
bx, tx = s[B].split(fused, factor=64)
s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
return s, [A, B]
def compute_backward_degandrad(dtype, ndim, req, n):
    ishape = [tvm.te.size_var() for _ in range(ndim)]
    in_grad_tmp = tvm.te.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
    in_grad = tvm.te.placeholder(ishape, name='in_grad', dtype=dtype)
    out_grad = tvm.te.placeholder(ishape, name='out_grad', dtype=dtype)
    import math
    if n == 0:
        ret = tvm.te.compute(ishape, lambda *index: out_grad[index]
                             * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype))
    else:
        ret = tvm.te.compute(ishape, lambda *index: out_grad[index]
                             / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype))
    if req == "kAddTo":
        in_grad = tvm.te.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
    else:
        in_grad = tvm.te.compute(ishape, lambda *index: ret[index])
    s = tvm.te.create_schedule(in_grad.op)
    return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad]
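
# The gradient of deg2rad scales out_grad by pi/180 and the gradient of rad2deg by
# 180/pi; with req == "kAddTo" the result is accumulated into the existing gradient
# (read from in_grad_tmp) instead of overwriting it.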
@defop(name="backward_deg2rad", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]
@defop(name="backward_rad2deg", target="cpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
for t in c_list:
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
s[t].parallel(fused)
return s, [out_grad, in_grad, in_grad_tmp]
@defop(name="cuda_backward_deg2rad", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_deg2rad(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
num_thread = 64
for t in c_list:
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]
@defop(name="cuda_backward_rad2deg", target="gpu", auto_broadcast=False,
dtype=["float32", "float64"], ndim=list(range(0, 6)), req=["kWriteTo", "kAddTo"],
attrs=["req"])
def cuda_backward_rad2deg(dtype, ndim, req):
s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
num_thread = 64
for t in c_list:
block_x = tvm.te.thread_axis("blockIdx.x")
thread_x = tvm.te.thread_axis("threadIdx.x")
axes = [axis for axis in t.op.axis]
fused = s[t].fuse(*axes)
bx, tx = s[t].split(fused, factor=num_thread)
s[t].bind(bx, block_x)
s[t].bind(tx, thread_x)
return s, [out_grad, in_grad, in_grad_tmp]