blob: 427001ffa5ed45624c1c21e0898c4e6179aa49d5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""An NNVM implementation of graph packing."""
import nnvm
from nnvm.compiler import graph_attr, graph_util
def _pack_batch_channel(data, dshape, bfactor, cfactor):
"""Pack the data channel dimension.
"""
assert dshape[0] % bfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // bfactor, bfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _unpack_batch_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = nnvm.sym.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = nnvm.sym.reshape(data, shape=old_shape)
return data
def _pack_weight(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _pack_weight_conv2d_transpose(data, dshape, cfactor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % cfactor == 0
assert dshape[1] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor, cfactor,
dshape[1] // cfactor, cfactor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(2, 0, 4, 5, 3, 1))
return data
def _pack_bias(data, dshape, bfactor, cfactor):
"""Pack the bias parameter.
"""
assert len(dshape) == 3
assert dshape[0] % cfactor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // cfactor,
cfactor, dshape[1],
dshape[2], 1))
data = nnvm.sym.transpose(
data, axes=(0, 2, 3, 4, 1))
# broadcast batch dimension to bfactor
data = nnvm.sym.broadcast_to(
data,
shape=(dshape[0] // cfactor, dshape[1], dshape[2], bfactor, cfactor))
return data
def _get_shape(sym, shape_dict):
"""Get the shape of a node.
"""
return graph_util.infer_shape(
nnvm.graph.create(sym), **shape_dict)[1][0]
def nnvm_graph_pack(graph,
shape_dict,
bfactor,
cfactor,
weight_bits,
start_name="max_pool2d0",
stop_name="global_avg_pool2d0"):
"""Pack the graph into batch&channel packed format.
Parameters
----------
graph : Graph
The input graph.
shape_dict : dict of str to shape
The input shape.
bfactor : int
The packing factor in batch
cfactor : int
The packing factor in channel
start_name: str, optional
Start packing from certain known node.
start_name: str, optional
Stop packing from certain known node.
Returns
-------
graph : Graph
The transformed graph.
"""
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
shape = graph.json_attr("shape")
gidx = graph.index
node_map = {}
dset = set()
start_pack = False
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
oshape = shape[gidx.entry_id(nid, 0)]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
if start_name and node_name == start_name:
start_pack = True
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
if start_pack and "_begin_state_" in node_name: # RNN -> CNN, pack
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == start_name:
assert not start_pack
start_pack = True
new_node = get_clone(children, op_name, node_name, attrs)
new_node = _pack_batch_channel(new_node, oshape, bfactor, cfactor)
elif node_name == stop_name:
if start_pack:
start_pack = False
children[0] = _unpack_batch_channel(children[0], ishape[0])
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "OIHW%do%di%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight(weight, ishape[1], cfactor)
# insert bit packing when necessary
if w_lanes != 1:
assert 8 % w_lanes == 0
weight = nnvm.sym.bitpack(weight, lanes=w_lanes)
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "conv2d_transpose" and attrs.get("out_dtype", None) == "int32":
assert 8 % weight_bits == 0
w_lanes = 8 // weight_bits
if start_pack:
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "IOHW%di%do%dp" % (cfactor, cfactor, w_lanes)
data, weight = children
weight = _pack_weight_conv2d_transpose(weight, ishape[1], cfactor)
new_node = nnvm.sym.conv2d_transpose(
data, weight, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast_") and tuple(ishape[0]) == tuple(ishape[1]):
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast") and len(ishape[1]) == 3:
if start_pack:
children[1] = _pack_bias(children[1], ishape[1], bfactor, cfactor)
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("elementwise_add"):
new_node = get_clone(children, op_name, node_name, attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
dset.add(op_name)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
if start_pack:
oshape = shape[graph.index.output_entries[0][0]]
ret = _unpack_batch_channel(ret, oshape)
graph = nnvm.graph.create(ret)
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
return graph