# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""A Relay implementation of graph packing."""
import tvm
from tvm import relay
from tvm.relay import op, transform
from tvm.relay import ExprMutator
def run_opt_pass(expr, opt_pass):
"""Exectue a relay pass."""
assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def _to_shape(shape):
"""convert shape into tuple."""
return tuple(int(sh) for sh in shape)
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension."""
assert int(dshape[0]) % bfactor == 0
assert int(dshape[1]) % cfactor == 0
data = op.reshape(
data,
newshape=(
int(dshape[0]) // bfactor,
bfactor,
int(dshape[1]) // cfactor,
cfactor,
int(dshape[2]),
int(dshape[3]),
),
)
data = op.transpose(data, axes=(0, 2, 4, 5, 1, 3))
return data
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension."""
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = op.reshape(data, newshape=old_shape)
return data
def _const_shape_match(data, dshape, cfactor_out):
"""Pad the constant if the shape[0] not divisible by cfactor_out."""
assert len(dshape) == 3
pad_width = int(dshape[0]) % cfactor_out
if pad_width != 0:
pad_width = cfactor_out - pad_width
data = op.nn.pad(data, [[0, pad_width], [0, 0], [0, 0]])
dshape = tuple([dshape[0] + pad_width, dshape[1], dshape[2]])
return data, dshape
def _weight_shape_match(data, dshape, channels, cfactor_out, transpose=False):
"""Pad the weight if the shape[0] not divisible by cfactor_out."""
assert len(dshape) == 4
pad_width = int(dshape[0]) % cfactor_out
channels_pad = int(channels) % cfactor_out
if pad_width != 0:
pad_width = cfactor_out - pad_width
data = op.nn.pad(data, [[0, pad_width], [0, 0], [0, 0], [0, 0]])
dshape = tuple([dshape[0] + pad_width, dshape[1], dshape[2], dshape[3]])
if channels_pad != 0:
channels = channels + (cfactor_out - channels_pad)
return data, dshape, channels
def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
"""Pad the weight if the shape[1] not divisible by cfactor_out."""
assert len(dshape) == 4
pad_width = int(dshape[1]) % cfactor_out
channels_pad = int(channels) % cfactor_out
if pad_width != 0:
pad_width = cfactor_out - pad_width
data = op.nn.pad(data, [[0, 0], [0, pad_width], [0, 0], [0, 0]])
        dshape = tuple([dshape[0], dshape[1] + pad_width, dshape[2], dshape[3]])
if channels_pad != 0:
channels = channels + (cfactor_out - channels_pad)
return data, dshape, channels
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format."""
assert len(dshape) == 4
assert int(dshape[0]) % cfactor == 0
assert int(dshape[1]) % cfactor == 0
data = op.reshape(
data,
newshape=(
int(dshape[0]) // cfactor,
cfactor,
int(dshape[1]) // cfactor,
cfactor,
int(dshape[2]),
int(dshape[3]),
),
)
data = op.transpose(data, axes=(0, 2, 4, 5, 1, 3))
return data
def _pack_weight_conv2d_transpose(data, dshape, cfactor):
"""Pack the weight into packed format."""
dshape = _to_shape(dshape)
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = op.reshape(
data,
newshape=(
dshape[0] // cfactor,
cfactor,
dshape[1] // cfactor,
cfactor,
dshape[2],
dshape[3],
),
)
data = op.transpose(data, axes=(2, 0, 4, 5, 3, 1))
return data
def _pack_const(data, dshape, dtype, bfactor, cfactor):
"""Pack a constant parameter."""
dshape = _to_shape(dshape)
assert len(dshape) == 3
assert dshape[0] % cfactor == 0
data = op.reshape(data, newshape=(dshape[0] // cfactor, cfactor, dshape[1], dshape[2], 1))
data = op.transpose(data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = op.broadcast_to(
data, shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor)
)
return data
def _get_tensor_shape(node):
"""Get node shape."""
if isinstance(node.checked_type, relay.ty.TensorType):
return _to_shape(node.checked_type.shape)
return []
def _get_tensor_type(node):
"""Get node type."""
if isinstance(node.checked_type, relay.ty.TensorType):
return node.checked_type.dtype
return "float32"
def _operator_idx_inc(expr, count_meta, operator_current_idx):
"""Increase operator index"""
if isinstance(expr, relay.expr.Constant):
operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
else:
operator_current_idx = operator_current_idx + 1
return operator_current_idx
class ExprDeviceAnnot(ExprMutator):
"""Visitor to perform graph annotation on an AST.
Parameters
----------
start: int
the start location to mark run on vta (inclusive)
end: int
the end location to mark run on vta (exclusive)
Returns
---------
None
"""
def __init__(self, start=-1, end=-1):
self.ext_dev = tvm.device("ext_dev")
self.cpu_dev = tvm.device("cpu")
self.cast = op.op.get("cast")
self.counter = -1
self.start = start
self.end = end
super().__init__()
def visit_call(self, call):
""" Visit the children. """
# First visit the children.
args = [self.visit(arg) for arg in call.args]
self.counter += 1
if self.counter == self.start:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.ext_dev)
return ret
if self.counter == self.end:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.cpu_dev)
return ret
if self.counter > self.start and self.counter < self.end:
ret = relay.Call(call.op, args, call.attrs)
# skip the float op, i.e., float->int cast
if self.is_float_op(call):
return ret
return relay.annotation.on_device(ret, self.ext_dev)
return relay.Call(self.visit(call.op), args, call.attrs)
def is_float_op(self, call):
"""check if this op belongs to a float op
in general, float op's odtype is float;
a special case is float->int cast, which follow this op sequence:
multiply(float) -> round(float) -> clip(float) -> cast(int);
"""
args = call.args
odtype = _get_tensor_type(call)
if odtype == "float32":
return True
if call.op == self.cast:
idtype = _get_tensor_type(args[0])
if idtype == "float32":
return True
return False
class ExprLocator(ExprMutator):
"""Visitor to locate op on an AST."""
def __init__(self):
self.counter = -1
self.op2nodes = {}
super().__init__()
def visit_call(self, call):
""" Visit the children. """
# First visit the children.
args = [self.visit(arg) for arg in call.args]
odtype = _get_tensor_type(call)
self.counter += 1
if (call.op, odtype) in self.op2nodes:
self.op2nodes[(call.op, odtype)].append(self.counter)
else:
self.op2nodes[(call.op, odtype)] = [self.counter]
return relay.Call(self.visit(call.op), args, call.attrs)
class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST."""
def __init__(self, bfactor, cfactor, weight_bits):
self.bfactor = bfactor
self.cfactor = cfactor
self.weight_bits = weight_bits
self.start_pack = False
# Cache Operator the algorithm matches against.
self.bitpack_start = op.op.get("annotation.bitpack_start")
self.bitpack_end = op.op.get("annotation.bitpack_end")
self.conv2d = op.op.get("nn.conv2d")
self.conv2d_transpose = op.op.get("nn.conv2d_transpose")
self.add = op.op.get("add")
self.multiply = op.op.get("multiply")
self.bias_add = op.op.get("nn.bias_add")
self.pad = op.op.get("nn.pad")
self.upsampling = op.op.get("nn.upsampling")
self.reshape = op.op.get("reshape")
self.number_of_conv2d = 0
super().__init__()
def visit_call(self, call):
""" Visit the children. """
# First visit the children.
oshape = _get_tensor_shape(call)
odtype = _get_tensor_type(call)
input_types = [arg.checked_type for arg in call.args]
args = [self.visit(arg) for arg in call.args]
# Start and stop cases.
if call.op == self.bitpack_start:
assert not self.start_pack
self.start_pack = True
return _pack_batch_channel(args[0], oshape, self.bfactor, self.cfactor)
if call.op == self.bitpack_end:
if self.start_pack:
self.start_pack = False
data = args[0]
data_shape = _get_tensor_shape(call.args[0])
return _unpack_batch_channel(data, data_shape)
if self.start_pack:
# Operator cases
if call.op == self.conv2d and odtype == "int32":
self.number_of_conv2d += 1
assert 8 % self.weight_bits == 0
w_lanes = 8 // self.weight_bits
data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
kernel_layout = "OIHW%do%di" % (self.cfactor, self.cfactor)
data, weight = args
data_shape = _to_shape(input_types[0].shape)
kernel_shape = _to_shape(input_types[1].shape)
channels = call.attrs.channels
weight, kernel_shape, channels = _weight_shape_match(
weight, kernel_shape, channels, self.cfactor
)
kernel = _pack_weight(weight, kernel_shape, self.cfactor)
# insert bit packing when necessary
if w_lanes != 1:
assert 8 % w_lanes == 0
kernel = op.bitpack(kernel, lanes=w_lanes)
conv2d = op.nn.conv2d(
data,
kernel,
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
groups=call.attrs.groups,
channels=channels,
kernel_size=call.attrs.kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=call.attrs.out_dtype,
)
return conv2d
if call.op == self.conv2d_transpose and odtype == "int32":
self.number_of_conv2d += 1
assert 8 % self.weight_bits == 0
w_lanes = 8 // self.weight_bits
if self.start_pack:
data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
kernel_layout = "IOHW%di%do" % (self.cfactor, self.cfactor)
data, weight = args
data_shape = _to_shape(input_types[0].shape)
kernel_shape = _to_shape(input_types[1].shape)
channels = call.attrs.channels
weight, kernel_shape, channels = _weight_shape_match_transpose(
weight, kernel_shape, channels, self.cfactor
)
kernel = _pack_weight_conv2d_transpose(weight, kernel_shape, self.cfactor)
conv2d = op.nn.conv2d_transpose(
data,
kernel,
strides=call.attrs.strides,
padding=call.attrs.padding,
dilation=call.attrs.dilation,
groups=call.attrs.groups,
channels=call.attrs.channels,
kernel_size=call.attrs.kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
output_padding=call.attrs.output_padding,
out_dtype=call.attrs.out_dtype,
)
return conv2d
if call.op == self.add and tuple(input_types[0].shape) == tuple(input_types[1].shape):
pass
elif call.op == self.add and len(input_types[1].shape) == 3:
data, const = args
const, input_shape = _const_shape_match(const, input_types[1].shape, self.cfactor)
const = _pack_const(
const, _to_shape(input_shape), input_types[1].dtype, self.bfactor, self.cfactor
)
return relay.Call(self.add, [data, const])
elif call.op == self.multiply and tuple(input_types[0].shape) == tuple(
input_types[1].shape
):
pass
elif call.op == self.multiply and len(input_types[1].shape) == 3:
data, const = args
const = _pack_const(
const,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor,
)
return relay.Call(self.multiply, [data, const])
elif self.start_pack and call.op == self.bias_add:
data, bias = args
bias = _pack_const(
bias,
_to_shape(input_types[1].shape),
input_types[1].dtype,
self.bfactor,
self.cfactor,
)
return relay.Call(self.add, [data, bias])
elif (
self.start_pack and call.op == op.op.get("cast") and input_types[0].dtype == "int32"
):
cast = relay.Call(op.op.get("cast"), [args[0]], call.attrs)
return cast
elif call.op == self.pad:
pad_width = call.attrs.pad_width
if len(pad_width) == 6:
pass
elif len(pad_width) == 4:
(data,) = args
new_pad_width = []
new_pad_width.extend(pad_width)
for _ in range(2):
new_pad_width.append([0, 0])
return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width)
elif call.op == self.upsampling:
(data,) = args
scale_h = call.attrs.scale_h
scale_w = call.attrs.scale_w
data_layout = "NCHW%dn%dc" % (self.bfactor, self.cfactor)
method = call.attrs.method
align_corners = call.attrs.align_corners
return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners)
elif call.op == self.reshape and len(input_types[0].shape) == 4:
(data,) = args
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
return op.reshape(data, [int(x) for x in input_types[0].shape])
return relay.Call(self.visit(call.op), args, call.attrs)
class BT(Exception):
pass
def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
"""We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive.
"""
bitpack_start = op.op.get("annotation.bitpack_start")
bitpack_end = op.op.get("annotation.bitpack_end")
anf = run_opt_pass(expr, transform.ToANormalForm())
operator_current_idx = 0
def _recursion(anf, start_found, stop_found, operator_current_idx):
"""Helper to obtain the subgraph."""
if isinstance(anf, relay.Function):
return relay.Function(
anf.params,
_recursion(anf.body, start_found, stop_found, operator_current_idx),
anf.ret_type,
anf.type_params,
anf.attrs,
)
if isinstance(anf, relay.expr.Let):
value = anf.value
if isinstance(value, relay.expr.Call):
if isinstance(value.op, tvm.ir.Op):
if value.op.name == start_name and not start_found:
if operator_current_idx == start_name_idx or start_name_idx is None:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
elif value.op.name == stop_name:
if operator_current_idx == stop_name_idx or stop_name_idx is None:
raise BT()
operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
try:
return relay.expr.Let(
anf.var,
value,
_recursion(anf.body, start_found, stop_found, operator_current_idx),
)
except BT:
assert start_found
assert not stop_found
stop_found = True
value = relay.expr.Call(bitpack_end, [value])
                # TODO: check that anf.body contains no other stop_name occurrence besides this one
return relay.expr.Let(anf.var, value, anf.body)
else:
assert start_found
assert stop_found
return anf
annotated = _recursion(anf, False, False, operator_current_idx)
return run_opt_pass(annotated, transform.ToGraphNormalForm())
def graph_pack(
expr,
bfactor,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False,
device_annot=False,
annot_start_name="nn.conv2d",
annot_end_name="annotation.stop_fusion",
):
"""Pack the graph into batch&channel packed format.
Parameters
----------
expr : relay.Expr
The input program.
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
weight_bits: int
The bit-width of the weights.
    start_name: str, optional
        Start packing from a certain known node when start_name_idx is None.
    stop_name: str, optional
        Stop packing at a certain known node when stop_name_idx is None.
    start_name_idx: int, optional
        When start_name_idx is not None, start packing only when the node name
        equals start_name and the node index equals start_name_idx.
    stop_name_idx: int, optional
        When stop_name_idx is not None, stop packing only when the node name
        equals stop_name and the node index equals stop_name_idx.
    count_meta: boolean, optional
        When count_meta is False, the operator-index counter does not count
        constants (nodes of type 'relay.expr.Constant'), so start_name_idx and
        stop_name_idx follow the indexing of 'expr.astext(show_meta_data=False)'.
        When count_meta is True, constants are counted as well.
    device_annot: boolean, optional
        Whether to annotate the device type of each node.
    annot_start_name: str, optional
        Device annotation start node, from which we mark the nodes as 'ext_dev'.
    annot_end_name: str, optional
        Device annotation end node, after which we mark the nodes as 'cpu'.
Returns
-------
expr : Expr
The transformed expression.
"""
assert isinstance(expr, relay.Function)
    assert (
        (start_name != stop_name)
        or ((start_name_idx is None) != (stop_name_idx is None))
        or (not (start_name_idx is None and stop_name_idx is None))
        or (start_name_idx < stop_name_idx)
    )
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(bfactor, cfactor, weight_bits)
expr = packer.visit(expr)
assert not packer.start_pack
expr = run_opt_pass(expr, transform.InferType())
if device_annot:
expr_locator = ExprLocator()
expr_locator.visit(expr)
annot_start = op.op.get(annot_start_name)
start = expr_locator.op2nodes[(annot_start, "int32")][0]
annot_end = op.op.get(annot_end_name)
        # mark the op right after the last stop_fusion to run on the cpu device
end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1
device_annot = ExprDeviceAnnot(start=start, end=end)
expr = device_annot.visit(expr)
return run_opt_pass(expr, transform.InferType())
return expr
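# A minimal usage sketch (hypothetical factor values; in practice they come
# from the VTA environment, e.g. env = vta.get_env() with env.BATCH,
# env.BLOCK_OUT, and env.WGT_WIDTH):
#
#   mod = tvm.IRModule.from_expr(net)  # net: a relay.Function
#   mod = relay.transform.InferType()(mod)
#   packed = graph_pack(
#       mod["main"],
#       bfactor=1,
#       cfactor=16,
#       weight_bits=8,
#       start_name="nn.max_pool2d",
#       stop_name="nn.global_avg_pool2d",
#   )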