# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Schedule for bitserial dense operator."""
from __future__ import absolute_import as _abs
import tvm
from tvm import te
from tvm import autotvm
from tvm.topi.utils import get_const_int, get_const_tuple
from .. import tag
from ..nn.bitserial_util import bitpack, binary_op_multiplier


@autotvm.register_topi_compute("bitserial_dense.x86")
def bitserial_dense(
cfg, data, weight, data_bits, weight_bits, pack_dtype="uint32", out_dtype="int16", unipolar=True
):
"""Bitserial dense implementation. TODO: Why are these separate
Parameters
----------
data : tvm.te.Tensor
2-D with shape [batch, in_dim]
weight : tvm.te.Tensor
2-D with shape [out_dim, in_dim] or
3-D with shape [out_dim, weight_bits, in_dim]
Returns
-------
output : tvm.te.Tensor
2-D with shape [batch, out_dim]
"""
data_packed = bitpack(data, data_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
if len(weight.shape) == 2:
weight_packed = bitpack(weight, weight_bits, pack_axis=1, bit_axis=1, pack_type=pack_dtype)
else:
weight_packed = weight
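    # Shapes after packing: data_packed is [batch, data_bits, in_words] and
    # weight_packed is [out_dim, weight_bits, in_words], where in_words is
    # in_dim divided by the lane width of pack_dtype.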
Y, DB, K = get_const_tuple(data_packed.shape)
X, WB, _ = get_const_tuple(weight_packed.shape)
######## Search space
x, y = cfg.axis(X), cfg.axis(Y)
db, wb, k = cfg.reduce_axis(DB), cfg.reduce_axis(WB), cfg.reduce_axis(K)
ko, ki = cfg.define_split("tile_k", k, num_outputs=2)
yo, yi = cfg.define_split("tile_y", y, num_outputs=2)
xo, xi = cfg.define_split("tile_x", x, num_outputs=2)
cfg.define_reorder(
"reorder_0",
[yo, xo, ko, yi, wb, db, ki, xi],
policy="candidate",
candidate=[[yo, xo, ko, yi, wb, db, ki, xi], [yo, xo, yi, ko, wb, db, ki, xi]],
)
cfg.define_annotate("ann_reduce", [db, wb], policy="try_unroll")
cfg.define_annotate("ann_spatial", [yi, xi], policy="try_unroll_vec")
###### Compute rule
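    # VX is the innermost tile size of the output dimension chosen by the
    # tuner; the packed weights are rearranged so that VX output lanes are
    # contiguous for vectorization.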
VX = cfg["tile_x"].size[-1]
wvshape = (X // VX, WB, VX, K)
oshape = (Y, X)
k = te.reduce_axis((0, K), name="k")
db = te.reduce_axis((0, DB), name="db")
wb = te.reduce_axis((0, WB), name="wb")
    # Rearrange the packed weights into the vectorized layout; the packed
    # data is consumed directly.
weight_vec = te.compute(
wvshape, lambda xo, wb, vx, k: weight_packed[xo * VX + vx][wb][k], name="weight_vec"
)
idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
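    # Unipolar dot product: weight bits encode +1/-1, so each partial product
    # is popcount(w & d) - popcount(~w & d), shifted left by the combined bit
    # significance (db + wb).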
matmul_unipolar = te.compute(
oshape,
lambda i, j: te.sum(
(
tvm.tir.popcount(
weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
)
- tvm.tir.popcount(
~weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
)
).astype(out_dtype)
<< (db + wb).astype(out_dtype),
axis=[wb, db, k],
),
tag="bitserial_dense_unipolar",
)
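    # Plain binary dot product: a single popcount of the AND accumulates the
    # matching bits.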
matmul = te.compute(
oshape,
lambda i, j: te.sum(
tvm.tir.popcount(
weight_vec[idxdiv(j, VX), wb, idxmod(j, VX), k] & data_packed[i, db, k]
).astype(out_dtype)
<< (db + wb).astype(out_dtype),
axis=[wb, db, k],
),
tag="bitserial_dense",
)
    # Count binary ops for the tuner: Y * X * K packed words, each counted as
    # 2 ops per bit of pack_dtype (binary_op_multiplier gives that bit width).
cfg.add_flop(2 * Y * X * K * binary_op_multiplier(pack_dtype))
if unipolar:
return matmul_unipolar
return matmul


@autotvm.register_topi_schedule("bitserial_dense.x86")
def schedule_bitserial_dense(cfg, outs):
"""Schedule for bitserial_dense.
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial dense operator.
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for bitserial_dense.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data_vec, weight_vec, output):
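        # Parallelize the data-packing and weight-rearranging stages over
        # their outermost axes.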
s[data_vec].parallel(s[data_vec].op.axis[0])
s[weight_vec].parallel(s[weight_vec].op.axis[0])
y, x = s[output].op.axis
wb, db, k = s[output].op.reduce_axis
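        # Apply the tile sizes and loop order chosen by the tuner.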
yo, yi = cfg["tile_y"].apply(s, output, y)
xo, xi = cfg["tile_x"].apply(s, output, x)
ko, ki = cfg["tile_k"].apply(s, output, k)
cfg["reorder_0"].apply(s, output, [yo, xo, ko, yi, wb, db, ki, xi])
cfg["ann_reduce"].apply(
s,
output,
[db, wb],
axis_lens=[get_const_int(db.dom.extent), get_const_int(wb.dom.extent)],
max_unroll=8,
cfg=cfg,
)
cfg["ann_spatial"].apply(
s,
output,
[yi, xi],
axis_lens=[cfg["tile_y"].size[-1], cfg["tile_x"].size[-1]],
max_unroll=8,
cfg=cfg,
)
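        # Vectorize the innermost output tile and parallelize the outer rows.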
s[output].vectorize(xi)
s[output].parallel(yo)
return s

    def traverse(op):
"""Internal traverse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag) or "elemwise" in op.tag:
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.te.ComputeOp):
traverse(tensor.op)
        elif op.tag in ("bitserial_dense", "bitserial_dense_unipolar"):
output = op.output(0)
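            # input_tensors follow their order of appearance in the compute
            # expression: rearranged weights first, then the bit-packed data.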
weight_vec = op.input_tensors[0]
data_vec = op.input_tensors[1]
data = data_vec.op.input_tensors[0]
if "QuantizeInput" in data.op.name:
data = data.op.input_tensors[0]
_schedule(cfg, s, data_vec, weight_vec, output)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
traverse(outs[0].op)
return s
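

# Example usage (a minimal sketch, not part of this module; the exact calling
# convention of AutoTVM template wrappers can vary across TVM versions, and
# with no tuning records loaded a fallback config is applied):
#
#   with tvm.target.Target("llvm"):
#       A = te.placeholder((16, 1024), dtype="uint32", name="A")
#       W = te.placeholder((1024, 1024), dtype="uint32", name="W")
#       C = bitserial_dense(A, W, 1, 1)  # data_bits=1, weight_bits=1
#       s = schedule_bitserial_dense([C])
#       func = tvm.build(s, [A, W, C])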