# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from __future__ import absolute_import as _abs
import logging
import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from nnvm.top import nn as _nn
from .vta_conv2d import is_packed_layout
from ..environment import get_env
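
# Everything below re-registers NNVM hooks for the VTA backend:
#   * tvm.register_func(..., override=True) replaces the build/lower
#     callbacks that nnvm.compiler installs by default.
#   * register_compute / register_schedule / register_pattern at level=15
#     take precedence over the stock definitions, which are registered at
#     lower priority levels.
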
@tvm.register_func("nnvm.compiler.build_target", override=True)
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev", target_host=target_host)
if tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
return tvm.build(funcs, target=target_host)
return tvm.build(funcs, target=target)
@tvm.register_func("nnvm.compiler.lower", override=True)
def _lower(sch, inputs, func_name, graph):
import traceback
# pylint: disable=broad-except
try:
f = tvm.lower(sch, inputs, name=func_name)
if "quantized_conv2d" in func_name:
logging.info(graph.ir(join_entry_attrs=["shape"]))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile graph\n"
msg += "--------------------------\n"
msg += graph.ir(join_entry_attrs=["shape"])
raise RuntimeError(msg)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]


# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)


@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, _):
    """ Clip operator. """
    x = inputs[0]
    a_min = attrs.get_float("a_min")
    a_max = attrs.get_float("a_max")
    const_min = tvm.const(a_min, x.dtype)
    const_max = tvm.const(a_max, x.dtype)
    # Clip is expressed as two elementwise stages (min, then max) so that
    # both are tagged ELEMWISE and can be fused like any other injective op.
    with tvm.tag_scope(topi.tag.ELEMWISE):
        x = tvm.compute(
            x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
        x = tvm.compute(
            x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
    return x
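

# is_packed_layout() separates VTA's tiled layouts from plain "NCHW". As a
# point of reference (the layout string comes from the graph itself, not from
# this module), a packed layout typically looks like "NCHW1n16c", i.e. batch
# and output channel tiled by env.BATCH / env.BLOCK_OUT.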
@reg.register_compute("conv2d", level=15)
def compute_conv2d(attrs, inputs, out):
""" Compute definition of conv2d """
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs["layout"]
out_dtype = attrs['out_dtype']
assert dilation == (1, 1), "not support dilate now"
if is_packed_layout(layout):
if groups == 1:
assert groups == 1
env = get_env()
assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now"
inputs = list(inputs)
assert inputs[1].dtype == "int8"
return topi.nn.conv2d(inputs[0], inputs[1], strides,
padding, dilation, layout, out_dtype)
return topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides,
padding, dilation, groups, out_dtype)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.compute_conv2d(attrs, inputs, out)
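

# The schedule side mirrors the compute dispatch: for packed layouts on the
# "vta" device we go through the topi.generic entry points, which are expected
# to resolve to the VTA schedules registered in vta_conv2d.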
@reg.register_schedule("conv2d", level=15)
def schedule_conv2d(attrs, outs, target):
""" Schedule definition of conv2d """
layout = attrs["layout"]
groups = attrs.get_int('groups')
if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
if groups == 1:
return topi.generic.schedule_conv2d_nchw(outs)
return topi.generic.schedule_group_conv2d_nchw(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
with tvm.target.arm_cpu(tvm.target.current_target().model):
return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target())


@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(attrs, inputs, out):
    """ Layout transformation hook for conv2d; packed layouts are left as-is. """
    layout = attrs["layout"]
    if is_packed_layout(layout):
        return None

    with tvm.target.arm_cpu(tvm.target.current_target().model):
        return _nn.alter_conv2d_layout(attrs, inputs, out)
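

# Rough usage sketch, kept as a comment so importing this module stays free of
# side effects beyond the registrations above. The target strings, shapes and
# the `sym` / `params` variables are illustrative assumptions, not defined in
# this file:
#
#   import nnvm.compiler
#
#   # "llvm -device=vta" makes tvm.target.create(...).device_name == "vta",
#   # which is the case the overrides above key on.
#   with nnvm.compiler.build_config(opt_level=3):
#       graph, lib, params = nnvm.compiler.build(
#           sym, "llvm -device=vta", shape={"data": (1, 3, 224, 224)},
#           params=params, target_host="llvm")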