blob: 9610c8dd3cdc422a6a2f0847fa36796c3919748f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the identity operator in TIR."""
from typing import Tuple
import tvm
from .spec import (
SerialBlockConfig,
SerialKernel,
SerialActivation,
SerialPooling,
SerialPadding,
SerialFeatureMap,
)
from .utils import get_op_attrs, get_base_address, get_strides, get_loads
from .producers_consumers import ProducersConsumers
def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatureMap, tvm.tir.Var]:
"""Get the feature map parameters from a loop nest of any shape (as long there are at
most 4 nested loops).
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a loop nest.
fm_type: str
Either "ifm" or "ofm", depending on whether it is an input or output feature map
Returns
-------
SerialFeatureMap
The serializable feature map.
output_pointer : tvm.tir.Var
The pointer produced by the operation.
"""
assert fm_type in ("ifm", "ofm")
attrs, body = get_op_attrs(stmt)
loops = []
inner = body
# extact the loops and the innermost statement
while hasattr(inner, "body"):
loops.append(inner)
inner = inner.body
# If the batch size loop is present, we need to remove it
if len(loops) > 3:
assert loops[0].extent == 1
loops = loops[1:]
fm_inner = inner.value if fm_type == "ifm" else inner
# Needed for stride calculation, can replace with
# inner.value.buffer.strides in future.
assert len(fm_inner.indices) == 1, "Ethos-U passes expect flattened buffers"
stride_vars = [l.loop_var for l in loops]
strides = get_strides(fm_inner.indices[0], stride_vars)
base_address = [get_base_address(index) for index in fm_inner.indices]
data_type = inner.buffer.data.type_annotation.element_type.dtype
serial_feature_map = SerialFeatureMap(
data_type=data_type,
height=loops[0].extent,
width=loops[1].extent if len(loops) > 1 else 1,
channels=loops[2].extent if len(loops) > 2 else 1,
tile_height_0=loops[0].extent,
tile_height_1=0,
tile_width_0=loops[1].extent if len(loops) > 1 else 1,
tile_address_0=tvm.tir.BufferLoad(fm_inner.buffer, base_address),
tile_address_1=0,
tile_address_2=0,
tile_address_3=0,
scale=attrs["scale"],
zero_point=attrs["zero_point"],
layout="NHWC",
stride_h=strides[0] if len(strides) > 0 else 1,
stride_w=strides[1] if len(strides) > 1 else 1,
stride_c=strides[2] if len(strides) > 2 else 1,
)
output_pointer = inner.buffer.data
return serial_feature_map, output_pointer
def get_identity_params(
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]:
"""Get the parameters necessary to construct a call_extern for an identity pooling.
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of an identity pooling loop nest.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
SerialPooling
The parameters needed to construct a 2D pooling.
output_pointer : tvm.tir.Var
The output pointer of the pooling operation.
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the pooling output pointer.
is_allocator : bool
Whether this operator allocates its output.
"""
attrs, _ = get_op_attrs(stmt)
# Find the inner loop
store = stmt
while hasattr(store, "body"):
store = store.body
# loads = [input, LUT, LUT]
loads = get_loads(store)
input_pointer = loads[0].buffer.data
output_pointer = store.buffer.data
read = producers_consumers.get_producer(input_pointer, stmt)
write = producers_consumers.get_consumer(output_pointer, stmt)
serial_ifm, _ = _get_feature_map(read, "ifm")
serial_ofm, write_output_pointer = _get_feature_map(write, "ofm")
replace_pointer = write_output_pointer
is_allocator = True
producer = producers_consumers.get_producer(write_output_pointer, write)
if producer is None or producer != write:
is_allocator = False
# TODO: We might want to support stand alone ReLU in the future by adding clip_min and
# clip max attributes to the identity operator
serial_activation = SerialActivation(op=attrs["activation"], clip_min=0, clip_max=0)
# Create a serialized identity pooling to be run on the NPU
return (
SerialPooling(
ifm=serial_ifm,
ofm=serial_ofm,
pooling_type="AVG",
pool_shape=SerialKernel(1, 1, 1, 1, 1, 1),
padding=SerialPadding(0, 0, 0, 0),
activation=serial_activation,
upscale="NONE",
rounding_mode=attrs["rounding_mode"],
block_config=SerialBlockConfig(0, 0, 0),
),
output_pointer,
replace_pointer,
is_allocator,
)