"""Conv2D operator declaration and schedule registration for VTA."""
import numpy as np
import tvm
from tvm import autotvm
import topi
from .util import is_packed_layout
from ..environment import get_env
def conv2d_packed(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """Packed conv2d compute declaration for VTA.

    Parameters
    ----------
    cfg : autotvm config entity
        Used here only to record the FLOP count via ``cfg.add_flop``.
    data : tvm.Tensor
        6-D packed input, indexed as
        (batch_outer, in_chan_outer, height, width, batch_inner, in_chan_inner).
    kernel : tvm.Tensor
        6-D packed weights, indexed as
        (out_chan_outer, in_chan_outer, kernel_h, kernel_w, out_chan_inner, in_chan_inner).
    strides : tuple of int
        (stride_h, stride_w).
    padding : tuple of int
        (pad_h, pad_w), passed to ``topi.nn.pad`` as the before-padding of the
        height and width axes (``pad_after`` is left at its default).
    dilation : tuple of int
        Must be (1, 1); dilated convolution is not supported.
    layout : str
        Data layout; must satisfy ``is_packed_layout``.
    out_dtype : str
        Dtype both operands are cast to before the multiply-accumulate.

    Returns
    -------
    res : tvm.Tensor
        6-D output indexed as (batch_outer, out_chan_outer, out_h, out_w,
        batch_inner, out_chan_inner), tagged ``"conv2d_dense"``.

    Raises
    ------
    topi.InvalidShapeError
        If ``layout`` is not a packed layout.
    """
    if not is_packed_layout(layout):
        raise topi.InvalidShapeError()
    assert dilation == (1, 1)
    assert len(data.shape) == 6
    assert len(kernel.shape) == 6

    # Pad only the spatial axes (2 = height, 3 = width). Test both entries so
    # width-only padding such as (0, 1) is not silently dropped.
    if padding[0] or padding[1]:
        pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data")
    else:
        pad_data = data

    oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
    owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
    oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4])

    ishape = topi.util.get_const_tuple(data.shape)
    kshape = topi.util.get_const_tuple(kernel.shape)
    # The reduction runs over kernel height/width and both input-channel axes.
    d_i = tvm.reduce_axis((0, kshape[2]), name='d_i')
    d_j = tvm.reduce_axis((0, kshape[3]), name='d_j')
    k_o = tvm.reduce_axis((0, ishape[1]), name='k_o')
    k_i = tvm.reduce_axis((0, ishape[-1]), name='k_i')
    hstride, wstride = strides
    res = tvm.compute(
        oshape,
        lambda b_o, c_o, i, j, b_i, c_i: tvm.sum(
            pad_data[b_o, k_o, i*hstride+d_i, j*wstride+d_j, b_i, k_i].astype(out_dtype) *
            kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype),
            axis=[k_o, d_i, d_j, k_i]),
        name="res", tag="conv2d_dense")

    # Two ops (multiply + add) per MAC. Each output element performs
    # kernel_h * kernel_w * in_chan_outer * in_chan_inner MACs.
    cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) *
                 kshape[2] * kshape[3] * ishape[1] * ishape[-1])
    return res
def schedule_conv2d_packed(cfg, outs):
    """Schedule packed conv2d for VTA.

    Walks back from ``outs[0]`` through broadcast (element-wise) ops to the
    single ``"conv2d_dense"`` stage. It then:

    * places the conv2d and element-wise stages in ``env.acc_scope``;
    * tiles the output;
    * DMA-copies data and weights per ``tile_ci`` chunk;
    * tensorizes the inner loop with ``env.gemm``.

    Parameters
    ----------
    cfg : autotvm config entity
        The split/knob search space is defined on it, and the chosen values
        are read back to drive tiling and virtual threading.
    outs : list of tvm.Tensor
        Exactly one tensor. Its producer chain must consist of broadcast ops
        ending in one ``"conv2d_dense"`` compute.

    Returns
    -------
    s : tvm.Schedule
    """
    assert len(outs) == 1
    output = outs[0]
    const_ops = []
    ewise_inputs = []
    ewise_ops = []
    conv2d_res = []
    assert "int" in output.op.input_tensors[0].dtype

    def _traverse(op):
        """Recursively classify the ops that feed ``output``."""
        if topi.tag.is_broadcast(op.tag):
            if not op.same_as(output.op):
                if not op.axis:
                    # Zero-dimensional broadcast ops are scalar constants.
                    const_ops.append(op)
                else:
                    ewise_ops.append(op)
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
                    ewise_inputs.append((op, tensor))
                else:
                    _traverse(tensor.op)
        else:
            assert op.tag == "conv2d_dense"
            conv2d_res.append(op)

    _traverse(output.op)
    assert len(conv2d_res) == 1
    conv2d_stage = conv2d_res[0].output(0)
    s = tvm.create_schedule(output.op)

    ##### space definition begin #####
    b, c_o, x_i, x_j, _, _ = s[conv2d_stage].op.axis
    # reduce_axis order is [k_o, d_i, d_j, k_i]; the first is the outer input channel.
    c_i, _, _, _ = s[conv2d_stage].op.reduce_axis
    # NOTE: 'tile_b' is part of the search space but is not applied below.
    cfg.define_split('tile_b', b, num_outputs=2)
    cfg.define_split('tile_h', x_i, num_outputs=2)
    cfg.define_split('tile_w', x_j, num_outputs=2)
    cfg.define_split('tile_ci', c_i, num_outputs=2)
    cfg.define_split('tile_co', c_o, num_outputs=2)
    cfg.define_knob('oc_nthread', [1, 2])
    cfg.define_knob('h_nthread', [1, 2])
    ###### space definition end ######

    data, kernel = conv2d_stage.op.input_tensors
    # If the conv reads a pad stage, DMA the padded tensor itself and track
    # the unpadded source separately.
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        temp = data.op.input_tensors[0]
        pad_data = data
        data = temp
    else:
        pad_data = None

    env = get_env()

    # setup pad
    if pad_data is not None:
        cdata = pad_data
        s[pad_data].set_scope(env.inp_scope)
    else:
        cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
    ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
    s[conv2d_stage].set_scope(env.acc_scope)

    # cache read input
    cache_read_ewise = []
    for consumer, tensor in ewise_inputs:
        cache_read_ewise.append(
            s.cache_read(tensor, env.acc_scope, [consumer]))

    # set ewise scope
    for op in ewise_ops:
        s[op].set_scope(env.acc_scope)
        s[op].pragma(s[op].op.axis[0], env.alu)

    for op in const_ops:
        s[op].compute_inline()

    # tile
    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
    x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
    x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
    x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
    s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
    store_pt = x_j0

    # set all compute scopes
    s[conv2d_stage].compute_at(s[output], store_pt)
    for op in ewise_ops:
        s[op].compute_at(s[output], store_pt)

    for tensor in cache_read_ewise:
        s[tensor].compute_at(s[output], store_pt)
        s[tensor].pragma(s[tensor].op.axis[0], env.dma_copy)

    # virtual threading along output channel axes
    if cfg['oc_nthread'].val > 1:
        _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    # virtual threading along spatial rows
    if cfg['h_nthread'].val > 1:
        _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis
    k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
    s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)

    # Load one tile_ci chunk of input channels (data and weights) at a time.
    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
    s[cdata].compute_at(s[conv2d_stage], k_o)
    s[ckernel].compute_at(s[conv2d_stage], k_o)

    # Use VTA instructions
    s[cdata].pragma(s[cdata].op.axis[0], env.dma_copy)
    s[ckernel].pragma(s[ckernel].op.axis[0], env.dma_copy)
    s[conv2d_stage].tensorize(x_bi, env.gemm)
    s[output].pragma(x_co1, env.dma_copy)

    return s