blob: 20f2c4ac128c9046a9b573a8f2d14e45749ba20b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
from collections import namedtuple
import tvm
from tvm import autotvm, te
from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity
from tvm.target.x86 import get_simd_32bit_lanes
from ..nn.pad import pad
from ..nn.utils import get_pad_tuple3d, infer_pad3d
from ..utils import get_const_int, get_const_tuple, simplify, traverse_inline
# Record describing one conv3d problem instance.  Note that the generated
# class is named "Workload" even though it is bound to ``Workload3D``.
Workload3D = namedtuple(
    "Workload",
    (
        # element types
        "in_dtype",
        "out_dtype",
        # input spatial extents
        "depth",
        "height",
        "width",
        # channel configuration
        "in_filter",
        "groups",
        "out_filter",
        # kernel extents
        "dkernel",
        "hkernel",
        "wkernel",
        # padding per spatial axis
        "dpad",
        "hpad",
        "wpad",
        # strides per spatial axis
        "dstride",
        "hstride",
        "wstride",
    ),
)
@autotvm.register_topi_compute("conv3d_ndhwc.x86")
def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype):
    """3D convolution forward operator for NDHWC input.

    Parameters
    ----------
    cfg : autotvm config object
        The tuning space is defined on it; when it is a fallback config,
        default split sizes are filled in before the compute is built.

    data : tvm.te.Tensor
        5-D input data with shape
        [batch, in_depth, in_height, in_width, in_channel]

    kernel : tvm.te.Tensor
        5-D filter with shape
        [kernel_depth, kernel_height, kernel_width, in_channels, out_channels]

    strides : int or a list/tuple of three ints
        stride size, or [stride_depth, stride_height, stride_width]

    padding : int or a list/tuple of three ints
        padding size, or [pad_depth, pad_height, pad_width]

    dilation : int or a list/tuple of three ints
        dilation size, or [dilation_depth, dilation_height, dilation_width]

    groups : int
        Number of groups

    out_dtype : str or None
        Output data type; ``data.dtype`` is used when None.

    Returns
    -------
    output : tvm.te.Tensor
        5-D with shape [batch, out_depth, out_height, out_width, out_channel]
    """
    if out_dtype is None:
        out_dtype = data.dtype
    # Scalars are broadcast to one value per spatial axis (depth, height, width).
    if not isinstance(strides, (tuple, list)):
        strides = (strides,) * 3
    if not isinstance(dilation, (tuple, list)):
        dilation = (dilation,) * 3

    layout = "NDHWC"
    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, groups, layout)
    if cfg.is_fallback:
        # No tuned record: populate cfg with default split sizes.
        _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout)
    return _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype)
@autotvm.register_topi_compute("conv3d_ncdhw.x86")
def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, groups, out_dtype):
    """3D convolution forward operator for NCDHW input.

    Parameters
    ----------
    cfg : autotvm config object
        The tuning space is defined on it; when it is a fallback config,
        default split sizes are filled in before the compute is built.

    data : tvm.te.Tensor
        5-D input data with shape
        [batch, in_channel, in_depth, in_height, in_width]

    kernel : tvm.te.Tensor
        5-D filter with shape
        [out_channels, in_channels, kernel_depth, kernel_height, kernel_width]

    strides : int or a list/tuple of three ints
        stride size, or [stride_depth, stride_height, stride_width]

    padding : int or a list/tuple of three ints
        padding size, or [pad_depth, pad_height, pad_width]

    dilation : int or a list/tuple of three ints
        dilation size, or [dilation_depth, dilation_height, dilation_width]

    groups : int
        Number of groups

    out_dtype : str or None
        Output data type; ``data.dtype`` is used when None.

    Returns
    -------
    output : tvm.te.Tensor
        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
    """
    # ``groups`` is forwarded unchecked; no ``groups == 1`` restriction is enforced here.
    if out_dtype is None:
        out_dtype = data.dtype
    # Scalars are broadcast to one value per spatial axis (depth, height, width).
    if not isinstance(strides, (tuple, list)):
        strides = (strides,) * 3
    if not isinstance(dilation, (tuple, list)):
        dilation = (dilation,) * 3

    layout = "NCDHW"
    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, groups, layout)
    if cfg.is_fallback:
        # No tuned record: populate cfg with default split sizes.
        _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout)
    return _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, groups, out_dtype)
@autotvm.register_topi_schedule("conv3d_ndhwc.x86")
def schedule_conv3d_ndhwc(cfg, outs):
    """TOPI schedule callback for NDHWC conv3d.

    Parameters
    ----------
    cfg : autotvm config object
        Forwarded to ``_schedule_conv3d_ndhwc``.

    outs : Array of Tensor
        The computation graph description of conv3d
        in the format of an array of tensors.

    Returns
    -------
    s : Schedule
        The computation schedule for conv3d.
    """
    s = te.create_schedule([tensor.op for tensor in outs])

    def _traverse(op):
        # Only the unpack stage tagged by the compute definition is scheduled here.
        if "conv3d_ndhwc" not in op.tag:
            return

        output = op.output(0)
        conv_out = op.input_tensors[0]

        # Kernel side: second input of the conv stage, then the tensor it was packed from.
        kernel_vec = conv_out.op.input_tensors[1]
        kernel = kernel_vec.op.input_tensors[0]
        if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
            s[kernel].compute_inline()

        # Data side: first input of the conv stage, optionally preceded by a pad stage.
        data_vec = conv_out.op.input_tensors[0]
        data = data_vec.op.input_tensors[0]
        data_pad = None
        if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
            data_pad = data
            data = data_pad.op.input_tensors[0]

        # The values are unused; the unpacking fails unless the kernel shape has five entries.
        _, _, _, _, _ = get_const_tuple(kernel.shape)

        _schedule_conv3d_ndhwc(
            s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]
        )

    traverse_inline(s, outs[0].op, _traverse)
    return s
@autotvm.register_topi_schedule("conv3d_ncdhw.x86")
def schedule_conv3d_ncdhw(cfg, outs):
    """TOPI schedule callback for NCDHW conv3d.

    Parameters
    ----------
    cfg : autotvm config object
        Forwarded to ``_schedule_conv3d_ncdhw``.

    outs : Array of Tensor
        The computation graph description of conv3d
        in the format of an array of tensors.

    Returns
    -------
    s : Schedule
        The computation schedule for conv3d.
    """
    s = te.create_schedule([tensor.op for tensor in outs])

    def _traverse(op):
        # Only the unpack stage tagged by the compute definition is scheduled here.
        if "conv3d_ncdhw" not in op.tag:
            return

        output = op.output(0)
        conv_out = op.input_tensors[0]

        # Kernel side: second input of the conv stage, then the tensor it was packed from.
        kernel_vec = conv_out.op.input_tensors[1]
        kernel = kernel_vec.op.input_tensors[0]
        if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
            s[kernel].compute_inline()

        # Data side: first input of the conv stage, optionally preceded by a pad stage.
        data_vec = conv_out.op.input_tensors[0]
        data = data_vec.op.input_tensors[0]
        data_pad = None
        if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag:
            data_pad = data
            data = data_pad.op.input_tensors[0]

        # The values are unused; the unpacking fails unless the kernel shape has five entries.
        _, _, _, _, _ = get_const_tuple(kernel.shape)

        _schedule_conv3d_ncdhw(
            s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]
        )

    traverse_inline(s, outs[0].op, _traverse)
    return s
def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype):
    """Build the channel-blocked compute for an NDHWC conv3d.

    The input is optionally padded, data and kernel are repacked into blocked
    layouts using the innermost ``tile_ic`` / ``tile_oc`` split sizes from
    ``cfg``, the convolution is computed in the blocked layout, and the result
    is unpacked back to NDHWC.

    Parameters
    ----------
    cfg : autotvm config object
        Only ``cfg["tile_ic"].size[-1]`` and ``cfg["tile_oc"].size[-1]`` are read.
    data : tensor whose shape unpacks as [batch, depth, height, width, channel]
    kernel : tensor whose shape unpacks as [kd, kh, kw, <ignored>, num_filter]
    strides : sequence of three values (depth, height, width)
    padding : forwarded to ``get_pad_tuple3d``
    dilation : int or sequence of three values (depth, height, width)
    groups : int
    out_dtype : str or None
        ``data.dtype`` is used when None.

    Returns
    -------
    Tensor of shape [batch, out_depth, out_height, out_width, num_filter],
    named "output_unpack" and tagged "conv3d_ndhwc".
    """
    out_dtype = data.dtype if out_dtype is None else out_dtype

    assert isinstance(dilation, int) or len(dilation) == 3
    if isinstance(dilation, int):
        dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation)
    else:
        dilation_d, dilation_h, dilation_w = dilation

    DSTR, HSTR, WSTR = strides
    batch_size, in_depth, in_height, in_width, in_channel = get_const_tuple(data.shape)
    kernel_depth, kernel_height, kernel_width, _, num_filter = get_const_tuple(kernel.shape)

    assert in_channel % groups == 0, "input channels must be a multiple of group size"
    assert num_filter % groups == 0, "number of filters must be a multiple of group size"

    # Extent covered by the kernel once dilation gaps are included.
    dilated_kernel_d = (kernel_depth - 1) * dilation_d + 1
    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
    )

    # Total padding (both sides) per spatial axis.
    pad_d = pad_front + pad_back
    pad_h = pad_top + pad_down
    pad_w = pad_left + pad_right

    pad_depth = in_depth + pad_d
    pad_height = in_height + pad_h
    pad_width = in_width + pad_w

    out_depth = simplify((in_depth + pad_d - dilated_kernel_d) // DSTR + 1)
    out_height = simplify((in_height + pad_h - dilated_kernel_h) // HSTR + 1)
    out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)

    # pack data
    # The pad stage is only created when at least one axis has padding.
    DOPAD = pad_d != 0 or pad_h != 0 or pad_w != 0
    if DOPAD:
        data_pad = pad(
            data,
            (0, pad_front, pad_top, pad_left, 0),
            (0, pad_back, pad_down, pad_right, 0),
            name="data_pad",
        )
    else:
        data_pad = data

    # fetch schedule
    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]

    assert groups == 1 or ic_bn <= groups
    assert groups == 1 or oc_bn <= groups

    # Blocked data layout: [n, ic_chunk, d, h, ic_block, w].
    shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
    data_vec = te.compute(
        shape, lambda n, C, d, h, c, w: data_pad[n, d, h, w, C * ic_bn + c], name="data_vec"
    )

    # Number of ic_bn-sized blocks per group, bumped by one when the blocks
    # would otherwise not cover every input channel.
    # NOTE(review): when the bump happens, kernel_vec below indexes
    # CI * ic_bn + ci beyond in_channel // groups - confirm this is intended.
    ci_tile = in_channel // groups // ic_bn
    if ci_tile == 0 or ci_tile * ic_bn * groups < in_channel:
        ci_tile += 1

    # pack kernel
    # Blocked kernel layout: [oc_chunk, ic_chunk, kd, kh, kw, ic_block, oc_block].
    shape = (num_filter // oc_bn, ci_tile, kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = te.compute(
        shape,
        lambda CO, CI, d, h, w, ci, co: kernel[d, h, w, CI * ic_bn + ci, CO * oc_bn + co],
        name="kernel_vec",
    )

    # convolution
    oshape = (batch_size, num_filter // oc_bn, out_depth, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, out_depth, out_height, out_width, num_filter)

    # The reduction over channels covers a single group only.
    ic = te.reduce_axis((0, in_channel // groups), name="ic")
    kh = te.reduce_axis((0, kernel_height), name="kh")
    kw = te.reduce_axis((0, kernel_width), name="kw")
    kd = te.reduce_axis((0, kernel_depth), name="kd")
    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    # The data channel is derived from the output channel: the output channel
    # (oc_chunk * oc_bn + oc_block) is mapped to its group index, scaled to the
    # first input channel of that group, and offset by ``ic``; idxdiv/idxmod by
    # ic_bn then split it into (ic_chunk, ic_block) for data_vec.
    # data_vec is read before kernel_vec, and the reduce axes are listed as
    # [kd, kh, kw, ic]; _schedule_conv3d_ndhwc and schedule_conv3d_ndhwc unpack
    # both positionally.
    conv = te.compute(
        oshape,
        lambda n, oc_chunk, od, oh, ow, oc_block: te.sum(
            data_vec[
                n,
                idxdiv(
                    (oc_chunk * oc_bn + oc_block) // (num_filter // groups) * (in_channel // groups)
                    + ic,
                    ic_bn,
                ),
                od * DSTR + kd * dilation_d,
                oh * HSTR + kh * dilation_h,
                idxmod(
                    (oc_chunk * oc_bn + oc_block) // (num_filter // groups) * (in_channel // groups)
                    + ic,
                    ic_bn,
                ),
                ow * WSTR + kw * dilation_w,
            ].astype(out_dtype)
            * kernel_vec[
                oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw, idxmod(ic, ic_bn), oc_block
            ].astype(out_dtype),
            axis=[kd, kh, kw, ic],
        ),
        name="conv",
    )
    # Unpack the blocked result back to NDHWC; the tag is what
    # schedule_conv3d_ndhwc looks for.
    conv_unpacked = te.compute(
        unpack_shape,
        lambda n, d, h, w, c: conv[n, idxdiv(c, oc_bn), d, h, w, idxmod(c, oc_bn)].astype(
            out_dtype
        ),
        name="output_unpack",
        tag="conv3d_ndhwc",
    )
    return conv_unpacked
def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, groups, out_dtype):
    """Build the channel-blocked compute for an NCDHW conv3d.

    The input is optionally padded, data and kernel are repacked into blocked
    layouts using the innermost ``tile_ic`` / ``tile_oc`` split sizes from
    ``cfg``, the convolution is computed in the blocked layout, and the result
    is unpacked back to NCDHW.

    Parameters
    ----------
    cfg : autotvm config object
        Only ``cfg["tile_ic"].size[-1]`` and ``cfg["tile_oc"].size[-1]`` are read.
    data : tensor whose shape unpacks as [batch, channel, depth, height, width]
    kernel : tensor whose shape unpacks as [num_filter, <ignored>, kd, kh, kw]
    strides : sequence of three values (depth, height, width)
    padding : forwarded to ``get_pad_tuple3d``
    dilation : int or sequence of three values (depth, height, width)
    layout : not used in this function body
    groups : int
    out_dtype : str or None
        ``data.dtype`` is used when None.

    Returns
    -------
    Tensor of shape [batch, num_filter, out_depth, out_height, out_width],
    named "output_unpack" and tagged "conv3d_ncdhw".
    """
    out_dtype = data.dtype if out_dtype is None else out_dtype

    assert isinstance(dilation, int) or len(dilation) == 3
    if isinstance(dilation, int):
        dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation)
    else:
        dilation_d, dilation_h, dilation_w = dilation

    DSTR, HSTR, WSTR = strides
    batch_size, in_channel, in_depth, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    # NOTE(review): unlike _conv3d_ndhwc, divisibility of in_channel and
    # num_filter by ``groups`` is not asserted here.

    # Extent covered by the kernel once dilation gaps are included.
    dilated_kernel_d = (kernel_depth - 1) * dilation_d + 1
    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1

    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
    )

    # Total padding (both sides) per spatial axis.
    pad_d = pad_front + pad_back
    pad_h = pad_top + pad_down
    pad_w = pad_left + pad_right

    pad_depth = in_depth + pad_d
    pad_height = in_height + pad_h
    pad_width = in_width + pad_w

    out_depth = simplify((in_depth + pad_d - dilated_kernel_d) // DSTR + 1)
    out_height = simplify((in_height + pad_h - dilated_kernel_h) // HSTR + 1)
    out_width = simplify((in_width + pad_w - dilated_kernel_w) // WSTR + 1)

    # pack data
    # The pad stage is only created when at least one axis has padding.
    DOPAD = pad_d != 0 or pad_h != 0 or pad_w != 0
    if DOPAD:
        data_pad = pad(
            data,
            (0, 0, pad_front, pad_top, pad_left),
            (0, 0, pad_back, pad_down, pad_right),
            name="data_pad",
        )
    else:
        data_pad = data

    # fetch schedule
    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]

    # Blocked data layout: [n, ic_chunk, d, h, ic_block, w].
    shape = (batch_size, in_channel // ic_bn, pad_depth, pad_height, ic_bn, pad_width)
    data_vec = te.compute(
        shape, lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w], name="data_vec"
    )

    # Number of ic_bn-sized blocks per group, bumped by one when the blocks
    # would otherwise not cover every input channel.
    # NOTE(review): when the bump happens, kernel_vec below indexes
    # CI * ic_bn + ci beyond in_channel // groups - confirm this is intended.
    ci_tile = in_channel // groups // ic_bn
    if ci_tile == 0 or ci_tile * ic_bn * groups < in_channel:
        ci_tile += 1

    # pack kernel
    # Blocked kernel layout: [oc_chunk, ic_chunk, kd, kh, kw, ic_block, oc_block].
    shape = (num_filter // oc_bn, ci_tile, kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = te.compute(
        shape,
        lambda CO, CI, d, h, w, ci, co: kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w],
        name="kernel_vec",
    )

    # convolution
    oshape = (batch_size, num_filter // oc_bn, out_depth, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, num_filter, out_depth, out_height, out_width)

    # The reduction over channels covers a single group only.
    ic = te.reduce_axis((0, in_channel // groups), name="ic")
    kh = te.reduce_axis((0, kernel_height), name="kh")
    kw = te.reduce_axis((0, kernel_width), name="kw")
    kd = te.reduce_axis((0, kernel_depth), name="kd")
    idxmod = tvm.tir.indexmod
    idxdiv = tvm.tir.indexdiv

    # The data channel is derived from the output channel: the output channel
    # (oc_chunk * oc_bn + oc_block) is mapped to its group index, scaled to the
    # first input channel of that group, and offset by ``ic``; idxdiv/idxmod by
    # ic_bn then split it into (ic_chunk, ic_block) for data_vec.
    # data_vec is read before kernel_vec, and the reduce axes are listed as
    # [ic, kd, kh, kw] (a different order from the NDHWC variant);
    # _schedule_conv3d_ncdhw and schedule_conv3d_ncdhw unpack both positionally.
    conv = te.compute(
        oshape,
        lambda n, oc_chunk, od, oh, ow, oc_block: te.sum(
            data_vec[
                n,
                idxdiv(
                    (oc_chunk * oc_bn + oc_block) // (num_filter // groups) * (in_channel // groups)
                    + ic,
                    ic_bn,
                ),
                od * DSTR + kd * dilation_d,
                oh * HSTR + kh * dilation_h,
                idxmod(
                    (oc_chunk * oc_bn + oc_block) // (num_filter // groups) * (in_channel // groups)
                    + ic,
                    ic_bn,
                ),
                ow * WSTR + kw * dilation_w,
            ].astype(out_dtype)
            * kernel_vec[
                oc_chunk, idxdiv(ic, ic_bn), kd, kh, kw, idxmod(ic, ic_bn), oc_block
            ].astype(out_dtype),
            axis=[ic, kd, kh, kw],
        ),
        name="conv",
    )
    # Unpack the blocked result back to NCDHW; the tag is what
    # schedule_conv3d_ncdhw looks for.
    conv_unpacked = te.compute(
        unpack_shape,
        lambda n, c, d, h, w: conv[n, idxdiv(c, oc_bn), d, h, w, idxmod(c, oc_bn)].astype(
            out_dtype
        ),
        name="output_unpack",
        tag="conv3d_ncdhw",
    )
    return conv_unpacked
def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, groups, layout):
    """Create schedule configuration from input arguments.

    Defines the tunable entries ``tile_ic``, ``tile_oc``, ``tile_ow`` and
    ``unroll_kw`` on ``cfg``.

    Parameters
    ----------
    cfg : autotvm config object
        Receives the ``define_split`` / ``define_knob`` calls.
    data, kernel : tensors
        Their shapes are interpreted according to ``layout``.
    strides : int or a list/tuple of three ints
    padding : forwarded to ``get_pad_tuple3d``
    dilation : int or a list/tuple of three ints
        Used to compute the dilated kernel extent so that the output width
        split over here matches the one produced by the compute definitions.
    groups : int
        Not used when building the space.
    layout : str
        "NDHWC" or "NCDHW".

    Raises
    ------
    ValueError
        If ``layout`` is neither "NDHWC" nor "NCDHW".
    """
    dshape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    if layout == "NDHWC":
        _, _, _, w, ic = dshape
        kd, kh, kw, _, oc = kshape
    elif layout == "NCDHW":
        _, ic, _, _, w = dshape
        oc, _, kd, kh, kw = kshape
    else:
        raise ValueError(f"Not support this layout {layout} with schedule template.")

    if isinstance(dilation, (tuple, list)):
        dilation_d, dilation_h, dilation_w = dilation
    else:
        dilation_d = dilation_h = dilation_w = dilation

    # Extent covered by the kernel once dilation gaps are included; the compute
    # definitions use the same formula for padding and output shape.
    dilated_kd = (kd - 1) * dilation_d + 1
    dilated_kh = (kh - 1) * dilation_h + 1
    dilated_kw = (kw - 1) * dilation_w + 1

    # pad_front, pad_top, pad_left, pad_back, pad_down(bottom), pad_right
    _, _, pl, _, _, pr = get_pad_tuple3d(padding, (dilated_kd, dilated_kh, dilated_kw))
    _, _, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides, strides)
    # Only the output width is tiled, so it is the only output extent needed.
    ow = (w - dilated_kw + pl + pr) // sw + 1

    # Create schedule config
    cfg.define_split("tile_ic", ic, num_outputs=2)
    cfg.define_split("tile_oc", oc, num_outputs=2)
    # Restrict the inner width tile to at most 8 elements.
    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 8)
    cfg.define_knob("unroll_kw", [True, False])
def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype, layout):
    """Fill ``cfg`` with the fallback schedule for this workload.

    Any data dimension that is a ``tvm.tir.Var`` is replaced by 1 before the
    workload is derived, so the fallback is always computed from a static
    shape.

    Raises
    ------
    ValueError
        If ``layout`` is neither "NDHWC" nor "NCDHW".
    """
    if layout not in ("NDHWC", "NCDHW"):
        raise ValueError(f"Layout {layout} is not supported")

    static_data_shape = [
        1 if isinstance(extent, tvm.tir.Var) else extent
        for extent in get_const_tuple(data.shape)
    ]
    static_data = te.placeholder(static_data_shape, dtype=data.dtype)
    workload = _get_conv3d_workload(
        static_data, kernel, strides, padding, groups, out_dtype, layout
    )
    _fallback_schedule(cfg, workload)
def _get_conv3d_workload(data, kernel, stride, padding, groups, out_dtype, data_layout="NCHW"):
    """Summarise a conv3d problem as a ``Workload3D`` record.

    Only "NCDHW" and "NDHWC" are accepted for ``data_layout``; any other value,
    including the default "NCHW", raises ``ValueError``.  The padding fields of
    the result hold the sum of both sides for each spatial axis, and dilation
    is not part of the record.
    """
    if data_layout not in ("NCDHW", "NDHWC"):
        raise ValueError(f"not support this layout {data_layout} yet")

    if data_layout == "NCDHW":
        _, in_channels, in_d, in_h, in_w = get_const_tuple(data.shape)
        out_channels, _, k_d, k_h, k_w = get_const_tuple(kernel.shape)
    else:
        _, in_d, in_h, in_w, in_channels = get_const_tuple(data.shape)
        k_d, k_h, k_w, _, out_channels = get_const_tuple(kernel.shape)

    kernel_extent = tuple(get_const_int(k) for k in (k_d, k_h, k_w))
    front, top, left, back, down, right = get_pad_tuple3d(padding, kernel_extent)

    # A scalar stride applies to all three spatial axes.
    stride_d, stride_h, stride_w = (
        stride if isinstance(stride, (tuple, list)) else (stride,) * 3
    )

    same_dtype = data.dtype == kernel.dtype
    uint8_by_int8 = data.dtype == "uint8" and kernel.dtype == "int8"
    assert (
        same_dtype or uint8_by_int8
    ), f"Do not support inputs with different data types now. {data.dtype} vs. {kernel.dtype}"

    return Workload3D(
        in_dtype=data.dtype,
        out_dtype=out_dtype,
        depth=in_d,
        height=in_h,
        width=in_w,
        in_filter=in_channels,
        groups=groups,
        out_filter=out_channels,
        dkernel=k_d,
        hkernel=k_h,
        wkernel=k_w,
        dpad=front + back,
        hpad=top + down,
        wpad=left + right,
        dstride=stride_d,
        hstride=stride_h,
        wstride=stride_w,
    )
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with default split sizes for workload ``wkl``.

    Parameters
    ----------
    cfg : autotvm config object
        ``tile_ic``, ``tile_oc``, ``tile_ow`` and ``unroll_kw`` are assigned.
    wkl : Workload3D
        Workload record; ``wpad`` is expected to be the total (left + right)
        width padding, which is how ``_get_conv3d_workload`` fills it.

    Notes
    -----
    The workload record carries no dilation, so the output width used here is
    the one for an undilated kernel.
    """
    # presumably the number of 32-bit SIMD lanes of the current target
    simd_width = get_simd_32bit_lanes()

    # wkl.wpad already sums both sides, so it must not be doubled.
    out_width = (wkl.width + wkl.wpad - wkl.wkernel) // wkl.wstride + 1

    # Output-channel block: largest divisor of out_filter that fits in simd_width.
    oc_bn = _largest_divisor_at_most(wkl.out_filter, simd_width)
    # Input-channel block: largest divisor of in_filter not exceeding oc_bn.
    ic_bn = _largest_divisor_at_most(wkl.in_filter, oc_bn)
    # Width tile: largest divisor of out_width not exceeding 7.
    reg_n = _largest_divisor_at_most(out_width, 7)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)


def _largest_divisor_at_most(value, upper):
    """Return the largest integer in ``[1, upper]`` dividing ``value``, or 1 if there is none."""
    return next((bn for bn in range(upper, 0, -1) if value % bn == 0), 1)
def _schedule_conv3d_ndhwc(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
    """Apply the x86 schedule to the stages of a blocked NDHWC conv3d.

    Parameters
    ----------
    s : schedule being built; it is mutated and returned
    cfg : autotvm config object
        ``tile_ic``, ``tile_oc``, ``tile_ow`` (innermost split sizes) and
        ``unroll_kw`` are read.
    data : unpadded input tensor
    data_pad : padded input tensor, or None when the caller found no pad stage
    data_vec : blocked data tensor
    kernel_vec : blocked kernel tensor
    conv_out : blocked convolution result
    output : unpack stage of the convolution
    last : final output tensor of the graph (may differ from ``output``)

    Returns
    -------
    The schedule ``s``.
    """
    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = (
        cfg["tile_ic"].size[-1],
        cfg["tile_oc"].size[-1],
        cfg["tile_ow"].size[-1],
        cfg["unroll_kw"].val,
    )

    # get padding size
    # presumably all zeros when data_pad is None - confirm in infer_pad3d
    padding = infer_pad3d(data, data_pad, "NDHWC")
    DPAD, HPAD, WPAD = padding
    DOPAD = DPAD != 0 or HPAD != 0 or WPAD != 0

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data
    # The pad stage is inlined only when some padding was inferred.
    if DOPAD:
        s[A0].compute_inline()
    # Fuse every axis except the two innermost (ic_block, iw) and parallelize.
    batch, ic_chunk, idd, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(batch, ic_chunk, idd, ih)
    s[A1].parallel(parallel_axis)

    # schedule kernel pack
    # For the kernel stage, od/oh/ow name the kernel depth/height/width axes.
    oc_chunk, ic_chunk, od, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, od, oh, ic_chunk, ow, ic_block, oc_block)
    # Vectorization is skipped when the output-channel block has a single element.
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, od, oh)
    s[W].parallel(parallel_axis)

    # schedule conv
    C, O0, O = conv_out, output, last
    # CC is the cache-write stage of the convolution; the reduce axes are taken
    # from it below.
    CC = s.cache_write(C, "global")

    _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
    s[C].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, od, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, od, oh, ow, oc_block = s[CC].op.axis
    # Positional unpack: must match axis=[kd, kh, kw, ic] in _conv3d_ndhwc.
    kd, kh, kw, ic = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)

    # The two variants differ only in where kw sits relative to ic_block;
    # od is not listed in either reorder but is fused right after.
    if unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, kw, ic_block, ow_block, oc_block)

    s[CC].fuse(oc_chunk, od, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # When further ops follow the unpack stage, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()

    # unpacking
    # The final stage is split like the conv stage, and the conv stage is
    # computed at its fused parallel axis.
    batch, od, oh, ow, oc = s[O].op.axis
    ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
    s[O].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(batch, oc_chunk, od, oh)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    return s
def _schedule_conv3d_ncdhw(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
    """Apply the x86 schedule to the stages of a blocked NCDHW conv3d.

    Parameters
    ----------
    s : schedule being built; it is mutated and returned
    cfg : autotvm config object
        ``tile_ic``, ``tile_oc``, ``tile_ow`` (innermost split sizes) and
        ``unroll_kw`` are read.
    data : unpadded input tensor
    data_pad : padded input tensor, or None when the caller found no pad stage
    data_vec : blocked data tensor
    kernel_vec : blocked kernel tensor
    conv_out : blocked convolution result
    output : unpack stage of the convolution
    last : final output tensor of the graph (may differ from ``output``)

    Returns
    -------
    The schedule ``s``.
    """
    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = (
        cfg["tile_ic"].size[-1],
        cfg["tile_oc"].size[-1],
        cfg["tile_ow"].size[-1],
        cfg["unroll_kw"].val,
    )

    # get padding size
    # presumably all zeros when data_pad is None - confirm in infer_pad3d
    padding = infer_pad3d(data, data_pad, "NCDHW")
    DPAD, HPAD, WPAD = padding
    DOPAD = DPAD != 0 or HPAD != 0 or WPAD != 0

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data
    # The pad stage is inlined only when some padding was inferred.
    if DOPAD:
        s[A0].compute_inline()
    # Fuse every axis except the two innermost (ic_block, iw) and parallelize.
    batch, ic_chunk, idd, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(batch, ic_chunk, idd, ih)
    s[A1].parallel(parallel_axis)

    # schedule kernel pack
    # For the kernel stage, od/oh/ow name the kernel depth/height/width axes.
    oc_chunk, ic_chunk, od, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, od, oh, ic_chunk, ow, ic_block, oc_block)
    # Vectorization is skipped when the output-channel block has a single element.
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, od, oh)
    s[W].parallel(parallel_axis)

    # schedule conv
    C, O0, O = conv_out, output, last
    # CC is the cache-write stage of the convolution; the reduce axes are taken
    # from it below.
    CC = s.cache_write(C, "global")

    _, oc_chunk, od, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
    s[C].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, od, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, od, oh, ow, oc_block = s[CC].op.axis
    # Positional unpack: must match axis=[ic, kd, kh, kw] in _conv3d_ncdhw.
    ic, kd, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)

    # The two variants differ only in where kw sits relative to ic_block;
    # od is not listed in either reorder but is fused right after.
    if unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kd, kh, kw, ic_block, ow_block, oc_block)

    s[CC].fuse(oc_chunk, od, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    # When further ops follow the unpack stage, inline the unpack stage.
    if O0 != O:
        s[O0].compute_inline()

    # unpacking
    # The final stage is split like the conv stage, and the conv stage is
    # computed at its fused parallel axis.
    batch, oc, od, oh, ow = s[O].op.axis
    ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
    oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
    s[O].reorder(oc_chunk, od, oh, ow_chunk, ow_block, oc_block)
    parallel_axis = s[O].fuse(batch, oc_chunk, od, oh)
    s[C].compute_at(s[O], parallel_axis)
    s[O].vectorize(oc_block)
    s[O].parallel(parallel_axis)
    return s