# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
"""PT: PyTorch frontend."""
import functools
import itertools
import math
import re
import sys
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple
from .. import analysis as _analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn, transform
from ..expr_functor import ExprMutator
from ..loops import while_loop
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind, fold_constant
from .common import set_span
from .pytorch_utils import is_version_greater_than, getattr_attr_name
__all__ = ["from_pytorch"]
# This returns a "subgraph" which puts variables whenever
# the type is known. It also records things to map the input
# nodes to the extracted graph's nodes.
# As Python objects are not round-trippable through C++, and
# our type annotations only live in Python, we need to map the
# nodes we get in visiting to the nodes we used to construct the
# graph (they are the same in C++ and match each other in
# dictionary lookups, but are not the same objects in Python).
# We do this with the hint dictionary filled as
# {node: node for node in nodes} to get the type annotations.
# https://discuss.tvm.apache.org/t/round-tripping-objects-through-the-ffi/8440
class _TypeFinder(ExprMutator):
def __init__(self, types):
super().__init__()
self.counter = 0
self.vars = {}
self.types = types
self.leave = set() # some variables are not inputs
def visit_let(self, let):
self.leave.add(let.var)
return super().visit_let(let)
def visit_function(self, fn):
self.leave.update(fn.params)
return super().visit_function(fn)
def visit(self, expr):
if expr in self.leave:
return super().visit(expr)
if expr in self.vars:
return self.vars[expr]
if isinstance(expr, tvm.relay.Var):
self.vars[expr] = expr
return expr
if expr in self.types:
ty = self.types[expr]
v = tvm.relay.var(f"_{self.counter}", type_annotation=ty)
self.counter += 1
self.vars[expr] = v
return v
v = super().visit(expr)
return v
def _should_construct_dynamic_list(list_construct_node):
# if this list is element-accessed or modified at runtime, generate List ADT
def inplace_add_to_add(op_name):
if op_name == "aten::add_":
return "aten::add"
else:
return op_name
uses = _get_uses(list_construct_node)
for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses):
block_input_index = loop_use.offset - 1
block = list(loop_use.user.blocks())[0]
list_loop_var = list(block.inputs())[block_input_index]
uses += _get_uses(list_loop_var.node())
op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
list_ops = set(["aten::add", "aten::__getitem__"])
intersect = list_ops.intersection(op_names)
if len(intersect) > 0 and intersect != set(["aten::add"]):
return True
# if add op outputs list, it is dynamic so we need to construct List ADT
for use in filter(lambda use: use.user.kind() in ["aten::add", "aten::add_"], uses):
output_type = _get_node_type(use.user)
if output_type == "ListType":
return True
return False
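# Illustrative sketch of the rule above: a list that is later indexed with
# aten::__getitem__, or whose aten::add/aten::add_ use produces another list,
# takes the dynamic List ADT path; a list that is only consumed whole can stay a
# static Python list.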
def _is_int_seq(seq):
# TODO (t-vi): handle non-int constants? (like numpy.intXX)
return len(seq) > 0 and all([isinstance(i, int) for i in seq])
# operator implementation
class PyTorchOpConverter:
"""A helper class for holding PyTorch op converters."""
def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
self.prelude = prelude
self.default_dtype = default_dtype
self.create_convert_map()
self.types = {} # map from nodes to (Relay) type annotations
self.source_map = {} # map from graph node to its source name
self.op_type_dict = {} # map from op type to its presenting order
self.current_op = [] # stack for recording current processing op
self.use_parser_friendly_name = use_parser_friendly_name
# this incrementally infers the type, see the comments on the type visitor
# above.
def infer_type(self, node, mod=None):
"""An incremental method to infer the type of a node in the relay graph."""
if node in self.types:
return self.types[node]
if isinstance(node, tvm.relay.Var):
return node.type_annotation
tf = _TypeFinder(types=self.types)
new_node = tf.visit(node)
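        # from here on: wrap the expression in a temporary "main" function, run the
        # InferType pass over that module, and read the checked_type off its body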
fn = _function.Function(list(tf.vars.values()), new_node)
new_mod = IRModule({"main": fn})
if mod is not None:
new_mod.update(mod)
new_mod = transform.RemoveUnusedFunctions()(new_mod)
new_mod = transform.InferType()(new_mod)
entry = new_mod["main"]
ty = entry.body.checked_type
self.types[node] = ty
return self.types[node]
def infer_type_with_prelude(self, val):
body = self.infer_type(val, self.prelude.mod)
return body
# list ADT utilities
def convert_to_list_adt(self, py_lst):
elem_tys = [self.infer_type_with_prelude(elem) for elem in py_lst]
msg = "List elements should have identical types"
assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
        # get_type returns (type_name, ctor1, ..., ctorN); for List the
        # constructors are unpacked below as (cons, nil)
_, cons, nil = self.prelude.mod.get_type("List")
adt_lst = nil()
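        # e.g. for py_lst = [a, b] the loop below builds cons(a, cons(b, nil()))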
for elem in reversed(py_lst):
adt_lst = cons(elem, adt_lst)
return adt_lst
def map_tensor_array_constructor(self, adt_lst, shape):
static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", shape)
static_tensor_array_ops.register()
tensor_create = self.prelude.get_tensor_ctor_static("tensor_constructor", "float32", shape)
return self.prelude.map(tensor_create, adt_lst)
def convert_to_tensor_array(self, adt_lst):
_, cons, nil = self.prelude.mod.get_type("List")
if self.prelude.length(adt_lst) == 0:
return nil()
checked_type = self.infer_type_with_prelude(self.prelude.hd(adt_lst))
shape = checked_type.shape
tensor_array = self.map_tensor_array_constructor(adt_lst, shape)
return tensor_array, tuple(shape)
def infer_shape(self, inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
typ = self.infer_type(inputs, mod=mod)
if hasattr(typ, "shape"):
# Regular operator that outputs tensors
return get_const_tuple(typ.shape)
# The return type is not a tensor, for example List
return typ
def infer_shape_with_prelude(self, inputs):
return self.infer_shape(inputs, mod=self.prelude.mod)
def is_empty_shape(self, shape):
rank = len(shape)
if rank:
is_empty = False
for i in range(rank):
if shape[i] == 0:
is_empty = True
break
return is_empty
else:
return True
def record_output_type(self, output):
if isinstance(output, tuple):
cleaned_output = [o for o in output if o is not None]
types = self.infer_type_with_prelude(_expr.Tuple(cleaned_output))
for o, t in zip(cleaned_output, types.fields):
self.types[o] = t
elif isinstance(output, _expr.Expr):
self.infer_type_with_prelude(output)
        # the output can also be a plain Python value (e.g. an int), which needs no type record
def pytorch_promote_types(self, inputs, dtypes):
"""This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
actual_dtypes = []
for i, inp in enumerate(inputs):
if isinstance(inp, _expr.Expr):
idt = self.infer_type(inp).dtype
actual_dtypes.append(idt)
else:
actual_dtypes.append(dtypes[i])
dtypes = actual_dtypes
tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
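        # e.g. (a sketch of PyTorch-style promotion) an "int64" tensor combined with a
        # Python float scalar promotes to "float32": the scalar becomes a constant of
        # the result dtype below and the tensor is cast to match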
results = []
for inp, dt in zip(inputs, dtypes):
if np.isscalar(inp):
results.append(_expr.const(inp, dtype=result_type))
elif dt == result_type:
results.append(inp)
else:
results.append(_op.cast(inp, result_type))
return results
def is_quantized_tensor(self, data):
# If a quantized Torch module is saved and loaded back, dtype will be dropped
# Since dtypes from Torch tensors are not reliable in such cases, we use
# Relay's type inference result to decide if an input tensor is quantized
ty = self.infer_type_with_prelude(data)
return ty.dtype == "uint8"
# Operator implementations
def make_elemwise(self, name):
def elemwise(inputs, input_types):
if name == "divide":
# https://pytorch.org/docs/stable/generated/torch.div.html#torch.div
# None - default behavior. Performs no rounding and, if both input and
# other are integer types, promotes the inputs to the default scalar type.
if all(["int" in input_type for input_type in input_types[:2]]):
input_types[:2] = ["float32"] * 2
cast_inputs = []
for inp in inputs[:2]:
if np.isscalar(inp):
cast_inputs.append(_expr.const(inp, dtype="float32"))
else:
cast_inputs.append(_op.cast(inp, "float32"))
inputs[:2] = cast_inputs
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return get_relay_op(name)(data0, data1)
return elemwise
def min_max_common(self, name_elemwise, name_reduce, inputs, input_types):
if len(inputs) == 1:
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name_reduce)(data[0])
elif len(inputs) >= 2 and isinstance(inputs[1], (list, int)):
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
dim = inputs[1]
keepdims = inputs[2] if len(inputs) > 2 else False
# also return dummy indices
return get_relay_op(name_reduce)(data[0], axis=dim, keepdims=keepdims), None
else:
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return get_relay_op(name_elemwise)(data0, data1)
def max(self, inputs, input_types):
return self.min_max_common("maximum", "max", inputs, input_types)
def min(self, inputs, input_types):
return self.min_max_common("minimum", "min", inputs, input_types)
def maximum(self, inputs, input_types):
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return _op.maximum(data0, data1)
def minimum(self, inputs, input_types):
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return _op.minimum(data0, data1)
def make_unary(self, name):
def unary(inputs, input_types):
# this is just to ensure tensor input
(data,) = self.pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name)(data)
return unary
def log1p(self, inputs, input_types):
        # log1p(x) = log(x + 1)
(dtype,) = input_types
one = _expr.const(1, dtype=dtype)
return _op.log(inputs[0] + one)
def square(self, inputs, input_types):
(dtype,) = input_types
return _op.power(inputs[0], _expr.const(2, dtype))
def lerp(self, inputs, input_types):
if len(inputs) != 3:
msg = f"Wrong number of arguments ({len(inputs)}) to parse."
raise AssertionError(msg)
start = inputs[0]
end = inputs[1]
weight = inputs[2]
return start + weight * (end - start)
def arange(self, inputs, input_types):
def _get_value(val, dtype):
# dtype is a tvm dtype
if isinstance(val, _expr.Expr):
# since "arange" op will fill expr into its attribute
# invoke set_span here to prevent expr-rewritten occurrs in span-filling stage
source_name = self.source_map[self.current_op[-1]]
inp = set_span(_op.cast(val, dtype), source_name)
ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
else:
ret = _create_typed_const(val, dtype)
return ret
def _get_type(val, inp_type):
if isinstance(val, _expr.Expr):
dtype = str(self.infer_type(val))
return dtype
return inp_type
# PyTorch arange uses the following type semantics:
# - if a dtype is given, start, stop, step are converted to that dtype
# - if no dtype is given and all args are integral, dtype is int64
# - if no dtype is given and there is a float arg, dtype is float32
if len(inputs) in {5, 6, 7}:
# inputs look like [_,_,_,dtype,layout,device,requires_grad]
# therefore dtype_idx is always the length of inputs minus 4
dtype_idx = len(inputs) - 4
types = [_get_type(inputs[i], input_types[i]) for i in range(dtype_idx)]
if inputs[dtype_idx] is not None:
dtype = _convert_dtype_value(inputs[dtype_idx])
elif any([t.startswith("float") for t in types]):
dtype = "float32"
else:
dtype = "int64"
# - if len(inputs) == 5, inputs = [stop, dtype, ...]
# - if len(inputs) == 6, inputs = [start, stop, dtype, ...]
# - if len(inputs) == 7, inputs = [start, stop, step, dtype, ...]
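            # e.g. (a sketch) torch.arange(5) arrives with len(inputs) == 5,
            # inputs[0] == 5 and a None dtype, so start=0, stop=5, step=1 and
            # dtype "int64" per the rules above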
start = _get_value(inputs[0], dtype) if len(inputs) > 5 else _expr.const(0, dtype)
stop = _get_value(inputs[1 if len(inputs) > 5 else 0], dtype)
step = _get_value(inputs[2], dtype) if len(inputs) > 6 else _expr.const(1, dtype)
else:
msg = f"Unknown number of arguments ({len(inputs)}) to parse."
raise AssertionError(msg)
return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
def squeeze(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 1:
axis = None
else:
# TODO (t-vi): why is the cast to int needed? similarly elsewhere
inputs = [inputs[1]] if not isinstance(inputs[1], list) else inputs[1]
axis = [int(v) for v in inputs]
return _op.transform.squeeze(data, axis)
def unsqueeze(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
return _op.transform.expand_dims(data, int(axis), 1)
def concatenate(self, inputs, input_types):
def tensor_array_concat(lst, axis):
assert axis == 0, "Tensor array concat supported only for axis 0"
tensor_array, shape = self.convert_to_tensor_array(lst)
concat_shape = (Any(),) + shape[1:]
concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape)
concatenated = concat(tensor_array)
static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", concat_shape)
static_tensor_array_ops.register()
get_tensor = self.prelude.get_global_var_static(
"tensor_get_data", "float32", concat_shape
)
return get_tensor(concatenated)
data = inputs[0]
axis = inputs[1]
if not isinstance(data, list):
return tensor_array_concat(data, axis)
if isinstance(data, _expr.Expr):
data = [data]
return _op.tensor.concatenate(data, int(axis))
def slice(self, inputs, input_types):
axis_dtype = "int64"
index_size_limit = sys.maxsize
data = inputs[0]
dshape = self.infer_shape(data)
ndim = len(dshape)
dim = int(inputs[1])
stride = inputs[4]
target_begin, is_begin_const = try_infer_value(
            inputs[2], lambda ret: ret.astype(np.int64).item(0)
)
target_end, is_end_const = try_infer_value(
            inputs[3], lambda ret: ret.astype(np.int64).item(0)
)
# A fast path when slicing is nop.
if (
isinstance(target_begin, int)
and isinstance(target_end, int)
and target_begin == 0
and target_end >= index_size_limit
and stride == 1
):
return data
if target_begin is None and target_end is None:
return data
# Process begin
begin = [0] * ndim
if target_begin is not None:
begin[dim] = target_begin
if target_begin is not None and not isinstance(begin[dim], int):
tmp = []
for b in begin:
if isinstance(b, int):
tmp.append(_op.expand_dims(_expr.const(b, axis_dtype), axis=0))
else:
tmp.append(_op.cast(_op.expand_dims(b, axis=0), axis_dtype))
begin = _op.concatenate(tmp, axis=0)
btype = self.infer_type(begin).dtype
if str(btype) != axis_dtype:
begin = _op.cast(begin, axis_dtype)
# Process end
if isinstance(target_end, int) and target_end >= index_size_limit:
target_end = dshape[dim]
if any([isinstance(d, tvm.tir.Any) for d in dshape]):
end = _op.shape_of(data)
else:
end = dshape
if isinstance(target_end, int):
if isinstance(end, list):
end[dim] = target_end
else:
all_static = True
for i, shape_dim in enumerate(dshape):
if i != dim and isinstance(shape_dim, tvm.tir.Any):
all_static = False
if all_static:
end = list(get_const_tuple(dshape))
end[dim] = target_end
else:
target_end = _expr.const(target_end)
end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
axis=0,
)
else:
end = _op.cast(_op.shape_of(data), axis_dtype)
if target_end is not None and not isinstance(target_end, tvm.tir.Any):
ttype = self.infer_type(target_end).dtype
if str(ttype) != axis_dtype:
target_end = _op.cast(target_end, axis_dtype)
end = _op.scatter_elements(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
axis=0,
)
if not isinstance(end, list):
etype = self.infer_type(end).dtype
if str(etype) != axis_dtype:
end = _op.cast(end, axis_dtype)
strides = [1] * ndim
strides[dim] = stride
return _op.transform.strided_slice(
data, begin=begin, end=end, strides=strides, slice_mode="end"
)
def narrow(self, inputs, input_types):
# Inputs are:
# 0 - the tensor to narrow
# 1 - the dimension along which to narrow
        # 2 - the starting index
        # 3 - the length of the slice
        # Let's find the ending index
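        # e.g. (a sketch) aten::narrow(x, dim=1, start=2, length=3) selects indices
        # [2, 5) along dim 1, i.e. end = start + length as computed below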
end = self.add(inputs[2:4], input_types[2:4])
stride = 1
slice_input = inputs[:3] + [end, stride]
slice_types = input_types + ["int32"]
return self.slice(slice_input, slice_types)
def split(self, inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
dim = int(inputs[2])
split_index = split_size
indices = []
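        # e.g. (a sketch) a dim of size 5 with split_size 2 yields split points [2, 4],
        # i.e. chunks of sizes 2, 2 and 1, matching torch.split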
while split_index < self.infer_shape(data)[dim]:
indices.append(split_index)
split_index += split_size
return _op.split(data, indices, dim)
def split_with_sizes(self, inputs, input_types):
data = inputs[0]
sections = inputs[1]
dim = int(inputs[2])
if len(sections) == 1:
# a special case used in torchvision detection models
return _expr.TupleWrapper(_expr.Tuple([data]), 1)
split_index = 0
indices = []
for i in range(len(sections) - 1):
index, _ = try_infer_value(sections[i], lambda ret: int(ret))
split_index += index
indices.append(split_index)
return _op.split(data, indices, dim)
def tensor_split(self, inputs, input_types):
# Reference: https://pytorch.org/docs/stable/generated/torch.tensor_split.html
import torch
if not isinstance(inputs[1], (int, list, tuple, torch.Tensor)):
msg = (
f"indices_or_sections type {type(inputs[1])} could not be parsed in "
f"tensor_split op"
)
raise AssertionError(msg)
if isinstance(inputs[1], torch.Tensor) and not (
            list(inputs[1].shape) == [] or len(inputs[1].shape) == 1
):
msg = "indices_or_sections must be a zero-dimensional or one-dimensional long tensor"
raise AssertionError(msg)
if isinstance(inputs[1], int) or (
isinstance(inputs[1], torch.Tensor) and list(inputs[1].shape) == []
):
data = inputs[0]
n = int(inputs[1])
dim = int(inputs[2])
split_size = int(self.infer_shape(data)[dim] / n)
split_rest = int(self.infer_shape(data)[dim] % n)
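            # e.g. (a sketch) splitting a length-7 dim into n=3 parts gives
            # split_size=2, split_rest=1 and split points [3, 5], i.e. chunk sizes
            # 3, 2, 2 as torch.tensor_split does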
indices = []
split_index = split_size
if split_rest == 0:
for i in range(n - 1):
indices.append(split_index)
split_index += split_size
else:
for i in range(split_rest):
indices.append(split_index + 1)
split_index = (i + 1) * (split_index + 1)
for i in range(n - split_rest - 1):
split_index += split_size
indices.append(split_index)
return _op.split(data, indices, dim)
else:
data = inputs[0]
sections = inputs[1]
dim = int(inputs[2])
if isinstance(sections, tuple):
sections = list(sections)
elif isinstance(sections, torch.Tensor):
sections = sections.cpu().numpy().tolist()
return _op.split(data, sections, dim)
def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
index = _wrap_const(inputs[2])
return _op.transform.take(data, index, axis=dim, mode="wrap")
def take(self, inputs, input_types):
data = inputs[0]
indices = _op.cast(inputs[1], "int32")
return _op.transform.take(data, indices=indices, mode="wrap")
def topk(self, inputs, input_types):
data = inputs[0]
axis = int(inputs[2])
is_ascend = not bool(inputs[3])
sort = bool(inputs[4])
if isinstance(inputs[1], _expr.Expr):
k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
else:
k = inputs[1]
if not sort:
msg = "Currently supports only sorted output for topk operator."
raise AssertionError(msg)
outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int64")
return outs[0], outs[1]
def reciprocal(self, inputs, input_types):
data = inputs[0]
return _expr.const(1.0, dtype=input_types[0]) / data
def repeat(self, inputs, input_types):
data = inputs[0]
reps = []
for r in inputs[1]:
if isinstance(r, int):
reps.append(r)
else:
reps.append(int(_infer_value(r, {}).numpy()))
return _op.transform.tile(data, reps=reps)
def repeat_interleave(self, inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], int):
repeats = inputs[1]
axis = inputs[2]
elif isinstance(inputs[1], _expr.Expr):
if isinstance(inputs[1], _expr.Constant):
repeats = int(inputs[1].data.numpy())
else:
repeats, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
axis = inputs[2]
else:
msg = "Only repeat with one value as repeat is currently supported."
raise AssertionError(msg)
if axis is None: # Flatten the data if no axis is given from torch
data = _op.transform.reshape(data, [-1])
axis = 0
return _op.transform.repeat(data, repeats=repeats, axis=axis)
def addcdiv(self, inputs, input_types):
data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 / t2))
def addcmul(self, inputs, input_types):
data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 * t2))
def where(self, inputs, input_types):
if len(inputs) == 1:
return self.nonzero([inputs[0], True], input_types)
cond = inputs[0]
x, y = self.pytorch_promote_types(inputs[1:3], input_types[1:3])
return _op.where(cond, x, y)
def full_impl(self, data, fill_value, dtype):
size = []
need_reshape = False
new_shape = []
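        # sketch of the two paths below: if every dim resolves to a Python int, `size`
        # stays a static list; otherwise `size` becomes a dynamic 1-D tensor of dims
        # and the result is reshaped to `new_shape` at the end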
for dim in data:
if isinstance(dim, _expr.Expr):
if isinstance(dim, _expr.Constant):
dim = int(dim.data.numpy())
if isinstance(size, list):
size.append(dim)
new_shape.append(dim)
else:
dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0)
new_shape.append(dim)
if success:
if isinstance(size, list):
size.append(dim)
else:
size = None
need_reshape = True
else:
if isinstance(size, list):
size.append(dim)
new_shape.append(dim)
if size is None:
tmp = []
for dim in data:
tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64"))
size = _op.concatenate(tmp, axis=0)
if not isinstance(fill_value, _expr.Constant):
if isinstance(fill_value, _expr.Expr):
fill_value = _infer_value(fill_value, {})
fill_value = _expr.const(fill_value, dtype=dtype)
out = _op.full(fill_value, size, dtype=dtype)
if need_reshape:
out = _op.reshape(out, new_shape)
return out
def ones(self, inputs, input_types):
data = inputs[0]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = f"Data type {type(data)} could not be parsed in ones op"
raise AssertionError(msg)
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
return self.full_impl(data, 1, dtype)
def ones_like(self, inputs, input_types):
data = inputs[0]
out = _op.ones_like(data)
# If the input and the output datatype is different, do a cast
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def new_ones(self, inputs, input_types):
size = inputs[1]
import torch
if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
msg = f"Data type {type(size)} could not be parsed in ones op"
raise AssertionError(msg)
if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
else:
dtype = input_types[0]
return self.full_impl(size, 1, dtype)
def zeros(self, inputs, input_types):
data = inputs[0]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = f"Data type {type(data)} could not be parsed in zeros op"
raise AssertionError(msg)
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
return self.full_impl(data, 0, dtype)
def zero_(self, inputs, input_types):
data = inputs[0]
return self.full_impl(self.infer_shape(data), 0, input_types[0])
def zeros_like(self, inputs, input_types):
data = inputs[0]
out = _op.zeros_like(data)
# If the input and the output datatype is different, do a cast
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
        if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def new_zeros(self, inputs, input_types):
data = inputs[1]
import torch
if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
msg = f"Data type {type(data)} could not be parsed in new_zeros op"
raise AssertionError(msg)
if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
else:
# if dtype is None, use the dtype of the input tensor
dtype = self.infer_type(inputs[0])
return self.full_impl(data, 0, dtype)
def full(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = f"Data type {type(data)} could not be parsed in full op"
raise AssertionError(msg)
if inputs[2] is not None: # dtype given
dtype = _convert_dtype_value(inputs[2])
else:
# if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
dtype = self.default_dtype
return self.full_impl(data, fill_value, dtype)
def full_like(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
out = _op.full_like(data, _expr.const(fill_value))
# If the input and the output datatype is different, do a cast
if inputs[2] is not None: # dtype given
dtype = _convert_dtype_value(inputs[2])
else:
# if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
dtype = self.default_dtype
        if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def new_full(self, inputs, input_types):
data = inputs[1]
fill_value = inputs[2]
import torch
if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
msg = f"Data type {type(data)} could not be parsed in full op"
raise AssertionError(msg)
if inputs[3] is not None: # dtype given
dtype = _convert_dtype_value(inputs[3])
else:
# if dtype is None, use the dtype of the input tensor
dtype = self.infer_type(inputs[0])
return self.full_impl(data, fill_value, dtype)
def fill_(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
if not isinstance(fill_value, (bool, int, float, complex)):
fill_value = fold_constant(fill_value)
return self.full_impl(self.infer_shape(data), fill_value, input_types[0])
def linspace(self, inputs, input_types):
start = inputs[0]
stop = inputs[1]
step = inputs[2]
# Find the spacing between values as step
if step != 1:
step = (stop - start) / (step - 1)
stop = stop + step
else:
stop = start + step
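        # e.g. (a sketch) torch.linspace(0, 10, steps=5): step becomes (10 - 0) / 4 = 2.5
        # and stop becomes 12.5, so the arange below yields [0, 2.5, 5, 7.5, 10]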
if inputs[3] is None:
import torch
dtype = _convert_data_type(str(torch.get_default_dtype()))
else:
dtype = _convert_dtype_value(inputs[3])
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)
return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
def relu(self, inputs, input_types):
data = inputs[0]
if self.is_quantized_tensor(data):
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.nn.relu(data)
def relu6(self, inputs, input_types):
data = inputs[0]
return _op.tensor.clip(data, 0.0, 6.0)
def prelu(self, inputs, input_types):
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU
data = inputs[0]
dim = self.get_dims(data)
ndims = len(dim)
axis = 0 if ndims == 1 else 1
alpha = _op.broadcast_to(inputs[1], (dim[axis]))
return _op.nn.prelu(data, alpha, axis)
def leaky_relu(self, inputs, input_types):
data = inputs[0]
alpha = float(inputs[1])
return _op.nn.leaky_relu(data, alpha)
def elu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
alpha = _expr.const(-float(inputs[1]), dtype=dtype)
return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
def celu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
alpha = _expr.const(float(inputs[1]), dtype=dtype)
zero = _op.const(0, dtype)
return alpha * _op.minimum(
zero, _op.exp(data / alpha) - _expr.const(1, dtype=dtype)
) + _op.nn.relu(data)
def gelu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
# gelu is data * normcdf(data)
# normcdf expressed as erf because we don't currently have that intrinsic
# note that there is also a fastgelu variant approximating normcdf
# with tanh and third order polynomials, but this is "true" gelu
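        # in closed form, the expression below computes
        #   gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))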
return data * (
_expr.const(0.5, dtype=dtype)
+ _op.erf(data * _expr.const(0.5**0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype)
)
def selu(self, inputs, input_types):
data = inputs[0]
# https://pytorch.org/docs/stable/nn.html#selu
dtype = input_types[0]
alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype)
gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype)
return gamma * (
alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
)
def silu(self, inputs, input_types):
data = inputs[0]
return data * _op.tensor.sigmoid(data)
def glu(self, inputs, input_types):
"""
        Applies the gated linear unit function GLU(a, b) = a * sigmoid(b),
        where a is the first half of the input tensor and b is the second half.
Link: https://pytorch.org/docs/stable/generated/torch.nn.GLU.html
"""
data = inputs[0]
dim = inputs[1]
relay_tup = _op.transform.split(data, 2, dim)
return relay_tup[0] * _op.tensor.sigmoid(relay_tup[1])
def log_sigmoid(self, inputs, input_types):
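        # numerically stable identity: log(sigmoid(x)) = min(0, x) - log1p(exp(-|x|))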
data = inputs[0]
mn = _op.minimum(_op.const(0, dtype=input_types[0]), data)
z = _op.exp(-_op.abs(data))
return mn - self.log1p([z], input_types)
def cross_entropy_loss_with_logits(self, inputs, input_types):
input = inputs[0]
target = inputs[1]
weights = inputs[2]
reduction = inputs[3]
ignore_index = inputs[4]
label_smoothing = inputs[5]
input_shape = self.infer_shape(input)
target_shape = self.infer_shape(target)
if input_shape != target_shape:
if reduction == 0:
reduction = "none"
elif reduction == 1:
reduction = "mean"
else:
reduction = "sum"
num_class = self.infer_shape(input)[1]
if weights is None:
weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
return _op.nn.nll_loss(
_op.nn.log_softmax(input), target, weights, reduction, ignore_index
)
assert reduction == 1, "reduction not supported in cross_entropy_loss"
assert ignore_index == -100, "ignore_index not supported in cross_entropy_loss"
assert label_smoothing == 0.0, "label_smoothing not supported in cross_entropy_loss"
assert weights is None, "weight not supported in cross_entropy_loss"
return _op.nn.cross_entropy_with_logits(_op.nn.log_softmax(input), target)
def l1_loss(self, inputs, input_types):
assert len(inputs) == 3
[predictions, targets, reduction] = inputs
delta = _op.abs(_op.subtract(predictions, targets))
if reduction == 0:
# reduction = "none"
return delta
elif reduction == 1:
# reduction = "mean"
return _op.mean(delta)
else:
# reduction = "sum"
return _op.sum(delta)
def mse_loss(self, inputs, input_types):
assert len(inputs) == 3
[predictions, targets, reduction] = inputs
delta = _op.subtract(predictions, targets)
delta = _op.power(delta, _expr.const(2, input_types[0]))
if reduction == 0:
# reduction = "none"
return delta
elif reduction == 1:
# reduction = "mean"
return _op.mean(delta)
else:
# reduction = "sum"
return _op.sum(delta)
def hard_sigmoid(self, inputs, input_types):
def _relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)
def func(x):
return _relu6(x + _expr.const(3.0)) / _expr.const(6.0)
if self.is_quantized_tensor(inputs[0]):
input_scale = _expr.const(inputs[1])
input_zero_point = _expr.const(inputs[2])
# PyTorch seems to use the following output qparams, but accuracy
# is broken if we use this.
# TODO(masahi): Revisit this parameter choice
#
# Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# output_scale = _expr.const(0.00390625) # 1.0 / 2^8
# output_zero_point = _expr.const(-128)
output_scale = input_scale
output_zero_point = input_zero_point
data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
out = func(data)
return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8")
return func(inputs[0])
def hard_swish(self, inputs, input_types):
data = inputs[0]
return data * self.hard_sigmoid(inputs, input_types)
def adaptive_avg_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
for i, item in enumerate(output_size):
if isinstance(item, tvm.relay.expr.Constant):
# convert Constant to int
output_size[i] = item.data.numpy()[()]
def func(x):
return op(x, output_size=output_size)
if self.is_quantized_tensor(data):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
def adaptive_max_pool(self, op, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
# returns dummy indices too
return op(data, output_size=output_size), None
@staticmethod
def convert_const_list(data):
if isinstance(data, list):
for i, _ in enumerate(data):
if isinstance(data[i], _expr.Expr):
data[i] = int(_infer_value_simulated(data[i], {}).numpy())
return data
def maxpool_2d(self, inputs, input_types):
data = inputs[0]
pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
return _op.nn.max_pool2d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
layout="NCHW",
ceil_mode=ceil_mode,
)
def maxpool_2d_with_indices(self, inputs, input_types):
# returns dummy indices too
return self.maxpool_2d(inputs, input_types), None
def maxpool_1d(self, inputs, input_types):
data = inputs[0]
pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
return _op.nn.max_pool1d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
layout="NCW",
ceil_mode=ceil_mode,
)
def maxpool_3d(self, inputs, input_types):
data = inputs[0]
need_squeeze = False
if len(self.get_dims(data)) == 4:
need_squeeze = True
data = _op.expand_dims(data, 0)
pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
res = _op.nn.max_pool3d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
ceil_mode=ceil_mode,
)
return res if not need_squeeze else _op.squeeze(res, [0])
def hardtanh(self, inputs, input_types):
a = inputs[0]
tanh_min = float(inputs[1])
tanh_max = float(inputs[2])
return _op.tensor.clip(a, tanh_min, tanh_max)
def convolution(self, inputs, input_types):
        # Decide between transposed and normal convolution
use_transpose = True if inputs[6] == 1 else False
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = tuple(inputs[3])
padding = tuple(inputs[4])
dilation = tuple(inputs[5])
if isinstance(weight, _expr.Expr):
inferred_shape = self.infer_shape(weight)
weight_shape = []
for infer in inferred_shape:
weight_shape.append(infer)
else:
msg = f"Data type {type(weight)} could not be parsed in conv op"
raise AssertionError(msg)
groups = int(inputs[8])
if use_transpose:
channels = weight_shape[1] * groups
in_channels = weight_shape[0]
else:
channels = weight_shape[0]
in_channels = weight_shape[1]
# Check if this is depth wise convolution
# We need to reshape weight so that Relay could recognize this is depth wise
# weight_shape[1] is always in_channels // groups
# For depthwise, in_channels == groups, so weight_shape[1] == 1
# If groups > 1 but weight_shape[1] != 1, this is group convolution
if groups > 1 and in_channels == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
weight = _op.transform.reshape(weight, new_weight_shape)
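            # e.g. (a sketch) a depthwise weight of shape (64, 1, 3, 3) with groups=32
            # (channel_multiplier 2) is reshaped to (32, 2, 3, 3) so Relay recognizes
            # the depthwise pattern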
kernel_size = weight_shape[2:]
use_bias = isinstance(bias, _expr.Expr)
# We are trying to invoke various relay operations through a single conv_op variable.
# However the function signatures for some operations have additional attributes so we
# pass these in along with the standard ones.
additional_arguments = dict()
if use_transpose:
if len(kernel_size) == 3:
conv_op = _op.nn.conv3d_transpose
elif len(kernel_size) == 2:
conv_op = _op.nn.conv2d_transpose
else:
conv_op = _op.nn.conv1d_transpose
output_padding = tuple(inputs[7])
additional_arguments["output_padding"] = output_padding
else:
if len(kernel_size) == 3:
conv_op = _op.nn.conv3d
elif len(kernel_size) == 2:
conv_op = _op.nn.conv2d
else:
conv_op = _op.nn.conv1d
if len(kernel_size) == 3:
data_layout = "NCDHW"
kernel_layout = "OIDHW"
if use_transpose:
# Transposed convolutions have IODHW layout.
kernel_layout = "IODHW"
elif len(kernel_size) == 2:
data_layout = "NCHW"
kernel_layout = "OIHW"
if use_transpose:
# Transposed convolutions have IOHW layout.
kernel_layout = "IOHW"
else:
data_layout = "NCW"
kernel_layout = "OIW"
if use_transpose:
# Transposed convolutions have IOW layout.
kernel_layout = "IOW"
# Conv1d does not currently support grouped convolution so we convert it to conv2d
is_grouped_conv1d = False
if groups > 1 and len(kernel_size) == 1 and not use_transpose:
is_grouped_conv1d = True
conv_op = _op.nn.conv2d
kernel_size = [1] + kernel_size
strides = (1,) + strides
padding = (0,) + padding
dilation = (1,) + dilation
data = _op.expand_dims(data, axis=2)
weight = _op.expand_dims(weight, axis=2)
data_layout = "NCHW"
kernel_layout = "OIHW"
conv_out = conv_op(
data,
weight,
strides=strides,
padding=padding,
dilation=dilation,
groups=groups,
channels=channels,
kernel_size=kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_layout="",
out_dtype="",
**additional_arguments,
)
if use_bias:
res = _op.nn.bias_add(conv_out, bias)
else:
res = conv_out
if is_grouped_conv1d:
# Because we conducted grouped conv1d convolution through conv2d we must
# squeeze the output to get the correct result.
res = _op.squeeze(res, axis=[2])
return res
def softmax(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
if isinstance(axis, str):
axis = int(axis)
return _op.nn.softmax(data, axis=axis)
def threshold(self, inputs, input_types):
data = inputs[0]
threshold_f = float(inputs[1])
threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold_f))
value_f = float(inputs[2])
value = _op.full_like(inputs[0], fill_value=_expr.const(value_f))
return _op.where(_op.greater(data, threshold_), data, value)
def contiguous(self, inputs, input_types):
return inputs[0]
def batch_norm(self, inputs, input_types):
data = inputs[0]
data_type = input_types[0]
channels = self.infer_shape(data)
scale = isinstance(inputs[1], _expr.Expr)
if scale:
gamma = inputs[1]
else:
gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
center = isinstance(inputs[2], _expr.Expr)
if center:
beta = inputs[2]
else:
beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
moving_mean = inputs[3]
moving_var = inputs[4]
epsilon = float(inputs[7])
return _op.nn.batch_norm(
data,
gamma,
beta,
moving_mean,
moving_var,
axis=1,
epsilon=epsilon,
center=center,
scale=scale,
)[0]
def instance_norm(self, inputs, input_types):
data = inputs[0]
data_type = input_types[0]
channels = self.infer_shape(data)
running_mean = inputs[3]
running_var = inputs[4]
use_input_stats = inputs[5]
if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr):
scale = center = True
weight = inputs[1]
beta = inputs[2]
gamma = weight
else:
scale = center = False
if not scale:
gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
if not center:
beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
epsilon = float(inputs[7])
if not use_input_stats:
return _op.nn.batch_norm(
data,
gamma,
beta,
running_mean,
running_var,
axis=1,
epsilon=epsilon,
center=center,
scale=scale,
)[0]
return _op.nn.instance_norm(
data, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale
)
def get_dims(self, data):
import torch
if isinstance(data, _expr.Expr):
dims = self.infer_shape(data)
elif isinstance(data, list):
dims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
dims = data.shape
else:
msg = f"Data type {type(data)} could not be parsed"
raise AssertionError(msg)
return dims
def layer_norm(self, inputs, input_types):
data = inputs[0]
ndims = len(self.get_dims(inputs[1]))
assert ndims == 1, "Support only normalization over last one dimension."
return _op.nn.layer_norm(
data,
gamma=inputs[2],
beta=inputs[3],
axis=-1,
epsilon=float(inputs[4]),
center=True,
scale=True,
)
def group_norm(self, inputs, input_types):
data = inputs[0]
gamma = inputs[2]
beta = inputs[3]
num_groups = inputs[1]
epsilon = float(inputs[4])
return _op.nn.group_norm(
data,
gamma=gamma,
beta=beta,
num_groups=num_groups,
axis=1,
epsilon=epsilon,
center=True,
scale=True,
)
def transpose(self, inputs, input_types):
data = inputs[0]
import torch
if isinstance(data, _expr.Expr):
ndims = len(self.infer_shape_with_prelude(data))
elif isinstance(data, list):
ndims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
ndims = data.shape
else:
msg = f"Data type {type(data)} could not be parsed in transpose op"
raise AssertionError(msg)
if isinstance(data, tvm.runtime.NDArray):
ndims = len(data.shape)
axes = list(range(ndims))
num_inputs = len(inputs)
if num_inputs == 1:
if ndims >= 2:
axes[-1] = ndims - 2
axes[-2] = ndims - 1
if not isinstance(data, _expr.Expr):
data = _expr.const(data)
elif num_inputs == 3:
parse = lambda i: ndims * (i < 0) + i
src, dst = [parse(int(inputs[i])) for i in [1, 2]]
axes[src] = dst
axes[dst] = src
else:
axes = inputs[1]
return _op.transform.transpose(data, axes)
def numpy_T(self, inputs, input_types):
data = inputs[0]
shape = self.infer_shape(data)
if len(shape) != 2:
logger.warning(
"The use of Tensor.T on tensors of dimensions != 2 is deprecated"
"and will be removed in a future release of PyTorch."
)
return _op.transform.transpose(data)
def flatten(self, inputs, input_types):
data = inputs[0]
start = int(inputs[1])
end = int(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
ndim = len(dshape)
if start < 0:
start += ndim
if end < 0:
end += ndim
assert start <= end, "start dim cannot come after end dim"
new_shape = [0] * start
new_shape.append(-1)
squeeze_axes = []
for i in range(start + 1, end + 1):
new_shape.append(1)
squeeze_axes.append(i)
for _ in range(end + 1, ndim):
new_shape.append(0)
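        # e.g. (a sketch) flattening dims 1..2 of a (2, 3, 4, 5) tensor builds
        # new_shape = [0, -1, 1, 0] -> reshape to (2, 12, 1, 5), then axis 2 is
        # squeezed away to give (2, 12, 5)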
out = _op.reshape(data, new_shape)
if squeeze_axes:
out = _op.squeeze(out, axis=squeeze_axes)
return out
def addmm(self, inputs, input_types):
input_mat = inputs[0]
mat1 = inputs[1]
data_type = input_types[1]
mat2 = inputs[2]
beta = inputs[3]
alpha = inputs[4]
if not isinstance(alpha, _expr.Expr) and alpha != 1:
alpha = _create_typed_const(alpha, data_type)
mat1 *= alpha
if not isinstance(beta, _expr.Expr) and beta != 1:
beta = _create_typed_const(beta, data_type)
mat2 *= beta
transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0])
units = self.infer_shape(transposed_mat2)[0]
dense_out = _op.nn.dense(mat1, transposed_mat2, units=units)
return dense_out + input_mat
def size(self, inputs, input_types):
shape = self.infer_shape_with_prelude(inputs[0])
axis = None
if len(inputs) > 1:
axis = int(inputs[1])
if any(map(lambda s: isinstance(s, tvm.tir.expr.Any), shape)):
if axis is None or isinstance(shape[axis], tvm.tir.expr.Any):
shape_dynamic = _op.shape_of(inputs[0], dtype="int32")
if axis is not None:
return _op.take(shape_dynamic, _expr.const(axis), 0)
return shape_dynamic
if axis is not None:
return _expr.const(shape[axis])
return _expr.const(shape)
def numtotensor(self, inputs, input_types):
val = inputs[0]
dtype = input_types[0]
if isinstance(val, _expr.Expr):
return val
if isinstance(val, tvm.tir.IntImm):
val = val.__int__()
dtype = int
arr = val * np.ones([]).astype(dtype)
return arr
def tensortonum(self, inputs, input_types):
return inputs[0]
def view(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 3:
shape_inp = [inputs[1], self.infer_shape(inputs[2])[0]]
else:
if isinstance(inputs[1], list):
shape_inp = inputs[1]
else:
shape_inp = self.infer_shape(inputs[1])
new_shape = shape_inp
for i, shape in enumerate(shape_inp):
if isinstance(shape, _expr.Expr):
val = _infer_value_simulated(shape, {})
new_shape[i] = val.numpy().item(0)
return _op.transform.reshape(data, new_shape)
def reshape(self, inputs, input_types):
data = inputs[0]
new_shape = inputs[1]
tmp_shape = []
is_dyn = False
for s in new_shape:
if isinstance(s, _expr.Constant):
tmp_shape.append(int(s.data.numpy()))
elif isinstance(s, _expr.Expr):
dim, success = try_infer_value(s, lambda ret: int(ret))
tmp_shape.append(dim)
if not success:
is_dyn = True
else:
tmp_shape.append(s)
if is_dyn:
new_shape = []
for i, s in enumerate(tmp_shape):
if not isinstance(s, _expr.Expr):
s = _expr.const(s, "int64")
else:
s = _op.cast(s, "int64")
new_shape.append(_op.expand_dims(s, axis=0))
new_shape = _op.concatenate(new_shape, axis=0)
else:
new_shape = tmp_shape
return _op.transform.reshape(data, new_shape)
def reshape_as(self, inputs, input_types):
data = inputs[0]
new_shape = self.infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)
def pixel_shuffle(self, inputs, input_types):
data = inputs[0]
upscale_factor = inputs[1]
upscale_squared = upscale_factor * upscale_factor
b, c, h, w = self.infer_shape(data)
assert (
c % upscale_squared == 0
), "input channel should be divisible by square of upscale_factor"
ndims = len(self.infer_shape_with_prelude(data))
axes = list(range(ndims))
num_inputs = len(inputs)
oc = c // upscale_squared
oh = h * upscale_factor
ow = w * upscale_factor
new_shape = [b, oc, upscale_factor, upscale_factor, h, w]
out_shape = [b, oc, oh, ow]
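        # e.g. (a sketch) with upscale_factor 2, a (1, 8, 4, 4) input is reshaped to
        # (1, 2, 2, 2, 4, 4), transposed to (1, 2, 4, 2, 4, 2) and reshaped to (1, 2, 8, 8)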
data = _op.transform.reshape(data, new_shape)
# The data will be transposed to
# [b, oc, h, upscale_factor, w, upscale_factor]
# for further reshape
axes = [0, 1, 4, 2, 5, 3]
data = _op.transform.transpose(data, axes)
return _op.transform.reshape(data, out_shape)
def clone(self, inputs, input_types):
data = inputs[0]
return _op.tensor.copy(data)
def log_softmax(self, inputs, input_types):
data = inputs[0]
axis = int(inputs[1])
return _op.nn.log_softmax(data, axis)
def sigmoid(self, inputs, input_types):
data = inputs[0]
def func(x):
return _op.tensor.sigmoid(x)
if self.is_quantized_tensor(data):
            assert len(inputs) == 5, "Input/Output quant param not found in op inputs"
return qnn_torch.quantized_sigmoid(inputs)
return func(data)
def softplus(self, inputs, input_types):
dtype = input_types[0]
beta = _expr.const(float(inputs[1]), dtype=dtype)
threshold = int(inputs[2]) if inputs[2] else 20
threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold))
softplus_value = _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
return _op.where(_op.greater(inputs[0] * beta, threshold_), inputs[0], softplus_value)
def make_avg_pool(self, dim):
def avg_pool(inputs, input_types):
data = inputs[0]
pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
ceil_mode = int(inputs[4])
count_include_pad = int(inputs[5])
def func(x):
if dim == 1:
return _op.nn.avg_pool1d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
dilation=(1,),
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)
elif dim == 2:
return _op.nn.avg_pool2d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
dilation=(1, 1),
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)
elif dim == 3:
return _op.nn.avg_pool3d(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
dilation=(1, 1, 1),
ceil_mode=ceil_mode,
count_include_pad=count_include_pad,
)
else:
msg = "Average Pooling dimension should be between 1 and 3"
raise RuntimeError(msg)
if self.is_quantized_tensor(data):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
return avg_pool
def linear(self, inputs, input_types):
# https://pytorch.org/docs/stable/nn.functional.html#linear
# 0 - input
# 1 - weight
bias = inputs[2]
a_shape = self.infer_shape_with_prelude(inputs[0])
b_shape = self.infer_shape_with_prelude(inputs[1])
if len(a_shape) == 2 and len(b_shape) == 2:
mm_out = _op.nn.dense(inputs[0], inputs[1])
elif len(b_shape) == 1:
mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
else:
mm_out = self.matmul(
[inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
)
if isinstance(bias, _expr.Expr):
bias_ndims = len(self.infer_shape_with_prelude(bias))
if bias_ndims == 1:
return _op.nn.bias_add(mm_out, bias, axis=-1)
mm_dtype = self.infer_type_with_prelude(mm_out).dtype
return self.add([mm_out, bias], [mm_dtype, input_types[2]])
return mm_out
def dropout(self, inputs, input_types):
data = inputs[0]
rate = float(inputs[1])
return _op.nn.dropout(data, rate)
def make_reduce(self, name):
def reduce(inputs, input_types):
data = inputs[0]
axis = None
keepdims = False
if len(inputs) > 2: # default, torch have only data, axis=None, keepdims=False
if isinstance(inputs[1], int):
axis = int(inputs[1])
elif _is_int_seq(inputs[1]):
axis = inputs[1]
else:
axis = list(self.infer_shape(inputs[1]))
keepdims = bool(inputs[2])
return get_relay_op(name)(data, axis=axis, keepdims=keepdims)
return reduce
def norm(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
axis = None
keepdims = False
if len(inputs) > 3:
axis = inputs[2]
keepdims = bool(inputs[3])
order = inputs[1]
if order == np.inf:
return _op.reduce.max(_op.abs(data), axis=axis, keepdims=keepdims)
        elif order == -np.inf:
return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims)
else:
reci_order = _expr.const(1.0 / order, dtype=dtype)
order = _expr.const(order)
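            # i.e. the general p-norm: ||x||_p = (sum(|x| ** p)) ** (1 / p) with p = order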
return _op.power(
_op.reduce.sum(_op.power(_op.abs(data), order), axis=axis, keepdims=keepdims),
reci_order,
)
def frobenius_norm(self, inputs, input_types):
data = inputs[0]
axis = None
keepdims = False
if len(inputs) > 2:
axis = inputs[1] if len(inputs[1]) > 0 else None
keepdims = bool(inputs[2])
return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))
def std(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 2:
axis = None
keepdims = False
unbiased = bool(inputs[1])
else:
axis = inputs[1]
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])
return _op.reduce.std(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
def variance(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 2:
axis = None
keepdims = False
unbiased = bool(inputs[1])
else:
axis = inputs[1]
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])
return _op.reduce.variance(data, axis=axis, keepdims=keepdims, unbiased=unbiased)
def mean(self, inputs, input_types):
data = inputs[0]
if inputs[1]:
axis = inputs[1]
else:
axis = None
if len(inputs) > 2 and inputs[2]:
keepdims = int(inputs[2])
else:
keepdims = False
if len(inputs) > 3 and inputs[3]:
exclude = int(inputs[3])
else:
exclude = False
def func(x):
return _op.mean(x, axis, keepdims, exclude)
if self.is_quantized_tensor(data):
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
# refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func)
return func(data)
def var_mean(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 2:
axis = None
keepdims = False
unbiased = bool(inputs[1])
else:
axis = inputs[1]
keepdims = bool(inputs[3])
unbiased = bool(inputs[2])
m, v = _op.reduce.mean_variance(data, axis, keepdims, False, unbiased)
return v, m
def chunk(self, inputs, input_types):
data = inputs[0]
num_chunks = int(inputs[1])
axis = int(inputs[2])
if isinstance(data, _expr.Expr):
inferred_shape = self.infer_shape_with_prelude(data)
shape = []
for infer in inferred_shape:
shape.append(infer)
dim = int(shape[axis])
if dim % num_chunks:
unif_size = int(dim / (num_chunks - 1))
else:
unif_size = int(dim / num_chunks)
        indices = []
        for i in range(unif_size, dim, unif_size):
            indices.append(i)
        return _op.split(data, indices, axis)
def baddbmm(self, inputs, _):
input = inputs[0]
batch1, batch2 = inputs[1:3]
beta = _expr.const(float(inputs[3]))
alpha = _expr.const(float(inputs[4]))
return beta * input + alpha * _op.nn.batch_matmul(batch1, batch2, transpose_b=False)
def matmul(self, inputs, input_types):
assert len(inputs) == 2, "Two tensors to be multiplied are expected."
a = inputs[0]
b = inputs[1]
# Need to check input shape as batch matmul must be supported.
a_shape = self.infer_shape_with_prelude(a)
b_shape = self.infer_shape_with_prelude(b)
a_ndims = len(a_shape)
b_ndims = len(b_shape)
# Check if both tensors are at least 1D.
if a_ndims == 0 or b_ndims == 0:
msg = "Both arguments to matmul must be at least 1D."
raise AssertionError(msg)
# Check if tensors can be multiplied.
b_mulaxis = b_shape[-2] if b_ndims > 1 else b_shape[0]
if a_shape[-1] != b_mulaxis:
msg = "Tensors being multiplied do not have compatible shapes."
raise AssertionError(msg)
# If 1D, remember axis that should be deleted at the end
squeeze_dims = []
if a_ndims == 1:
a = _op.expand_dims(a, axis=0)
squeeze_dims += [-2]
a_ndims = 2
a_shape = (1,) + a_shape
if b_ndims == 1:
b = _op.expand_dims(b, axis=1)
squeeze_dims += [-1]
b_ndims = 2
b_shape = b_shape + (1,)
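        # e.g. (a sketch) a (2, 3, 4) x (4, 5) matmul takes the batch_matmul path below:
        # b is broadcast to (2, 4, 5) and the (2, 3, 5) result is reshaped back at the end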
# Compute result
if a_ndims == 2 and b_ndims == 2:
# Result is obtained using matmul
out = _op.nn.dense(a, _op.transpose(b))
else:
# Result is obtained using batch_matmul
batch_shape = [1] * (max(a_ndims, b_ndims) - 2)
for i, j in enumerate(reversed(a_shape[:-2])):
batch_shape[i] = j
for i, j in enumerate(reversed(b_shape[:-2])):
# Need to check if axis can be broadcasted
if batch_shape[i] == 1 or j == 1 or batch_shape[i] == j:
batch_shape[i] = max(batch_shape[i], j)
else:
msg = "Batch dimensions are not broadcastable."
raise AssertionError(msg)
batch_shape = batch_shape[::-1]
a = _op.broadcast_to(a, batch_shape + list(a_shape[-2:]))
b = _op.broadcast_to(b, batch_shape + list(b_shape[-2:]))
out = _op.nn.batch_matmul(
_op.reshape(a, [-1, *a_shape[-2:]]),
_op.reshape(b, [-1, *b_shape[-2:]]),
transpose_b=False,
)
out_shape = batch_shape + [a_shape[-2]] + [b_shape[-1]]
out = _op.reshape(out, out_shape)
return _op.squeeze(out, axis=squeeze_dims)
def expand(self, inputs, input_types):
data_in = inputs[0]
shape = list(self.infer_shape(data_in))
ndims = len(shape)
sizes = inputs[1]
out = data_in
out_dims = len(sizes)
if ndims < out_dims:
num_newaxis = out_dims - ndims
out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
shape = [1] * num_newaxis + shape
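        # e.g. (a sketch) expanding a (3, 1) tensor with sizes [2, 3, 4]: a leading axis
        # is added to give (1, 3, 1), then each size-1 axis is repeated to reach (2, 3, 4)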
for i in range(out_dims):
if sizes[i] != -1 and shape[i] == 1:
if not isinstance(sizes[i], int):
sizes[i] = int(_infer_value(sizes[i], {}).numpy())
out = _op.repeat(out, sizes[i], axis=i)
return out
def int(self, inputs, input_types):
if isinstance(inputs[0], _expr.Expr):
return inputs[0]
return int(inputs[0])
def identity(self, inputs, input_types):
return inputs[0]
def none(self, inputs, input_types):
return None
def pad_common(self, mode, pad_value, inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], list):
pad_list = inputs[1]
else:
pad_list = list(self.infer_shape(inputs[1]))
# initialize paddings based on input len
pad_len = len(self.infer_shape(data)) * 2
paddings = [0] * pad_len
if len(pad_list) >= 2:
paddings[-1] = pad_list[1]
paddings[-2] = pad_list[0]
if len(pad_list) >= 4:
paddings[-3] = pad_list[3]
paddings[-4] = pad_list[2]
if len(pad_list) >= 6:
paddings[-5] = pad_list[5]
paddings[-6] = pad_list[4]
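        # e.g. (a sketch) F.pad(x, (l, r, t, b)) on an NCHW input arrives as
        # pad_list = [l, r, t, b] and is mapped above to
        # paddings = [0, 0, 0, 0, t, b, l, r], i.e. (before, after) pairs from the
        # last dim backwards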
# group into tuple of 2 ints
paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]
const_paddings = []
non_zero_found = False
for pad in paddings:
const_paddings.append([])
for p in pad:
if isinstance(p, _expr.Expr):
p = int(_infer_value(p, {}).numpy())
elif not isinstance(p, int):
raise NotImplementedError("pad width should be int/expr")
const_paddings[-1].append(p)
if p != 0:
non_zero_found = True
if not non_zero_found:
return data
elif mode == "constant":
return _op.nn.pad(data, const_paddings, pad_value=pad_value, pad_mode=mode)
else:
return _op.nn.pad(data, const_paddings, pad_mode=mode)
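# Example of the PyTorch -> Relay padding layout (illustrative): for a 4-D NCHW input with
# pad_list = [1, 2, 3, 4] (W-left, W-right, H-top, H-bottom), const_paddings becomes
# [[0, 0], [0, 0], [3, 4], [1, 2]]; PyTorch lists the innermost dims first, while Relay
# expects one (before, after) pair per axis in axis order.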
def pad(self, inputs, input_types):
# mode: Optional default "constant"
if len(inputs) > 2 and inputs[2] is not None:
mode = inputs[2]
else:
mode = "constant"
# pad_value: Optional default 0
if len(inputs) == 4 and inputs[3] is not None:
pad_value = inputs[3]
else:
pad_value = 0
# PyTorch's "replicate" padding corresponds to "edge" mode in TVM
if mode == "replicate":
mode = "edge"
elif mode == "circular":
raise ValueError("circular mode for torch.nn.functional.pad are not supported in TVM")
return self.pad_common(mode, pad_value, inputs, input_types)
def constant_pad_nd(self, inputs, input_types):
return self.pad_common("constant", _expr.const(inputs[2]), inputs, input_types)
def reflection_pad1d(self, inputs, input_types):
return self.pad_common("reflect", 0, inputs, input_types)
def reflection_pad2d(self, inputs, input_types):
return self.pad_common("reflect", 0, inputs, input_types)
def replication_pad1d(self, inputs, input_types):
return self.pad_common("edge", 0, inputs, input_types)
def replication_pad2d(self, inputs, input_types):
return self.pad_common("edge", 0, inputs, input_types)
def replication_pad3d(self, inputs, input_types):
return self.pad_common("edge", 0, inputs, input_types)
def clamp_common(self, data, min=None, max=None):
def get_v(v, default_v):
if isinstance(v, _expr.Constant):
return float(v.data.numpy())
if isinstance(v, _expr.Expr):
infer_v, success = try_infer_value(v, lambda ret: float(ret))
if success:
return infer_v
if v is not None:
return v
return default_v
dtype = self.infer_type(data).dtype
type_info = np.finfo(dtype) if "float" in dtype else np.iinfo(dtype)
# TODO(masahi): Properly handle inf in a one-way clamp case.
if min is not None and max is not None:
amin = get_v(min, type_info.min)
amax = get_v(max, type_info.max)
elif min is not None:
amin = get_v(min, type_info.min)
amax = type_info.max
else:
amin = type_info.min
amax = get_v(max, type_info.max)
return _op.clip(data, amin, amax)
def clamp(self, inputs, _):
return self.clamp_common(inputs[0], min=inputs[1], max=inputs[2])
def clamp_min(self, inputs, input_types):
return self.clamp_common(inputs[0], min=inputs[1])
def clamp_max(self, inputs, input_types):
return self.clamp_common(inputs[0], max=inputs[1])
def to(self, inputs, input_types):
data = inputs[0]
dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2]
# special handling for aten::to(data, 6, _, _, _) case
# 6 means dtype = float
# this happens when converting upsampling with scale factor
cast_map = {5: "float16", 6: "float32", 7: "float64", 3: "int32", 4: "int64"}
cast_func = {5: float, 6: float, 7: float, 3: int, 4: int}
ret = data
if isinstance(data, _expr.Expr):
actual_dtype = str(self.infer_type(data).dtype)
if dtype in cast_map and cast_map[dtype] != actual_dtype:
ret = _op.cast(data, cast_map[dtype])
elif dtype in cast_map:
ret = cast_func[dtype](data)
return ret
def get_upsample_out_size(self, inputs, method):
# This assumes a static shape
out_size = []
if inputs[1] is not None:
for size in inputs[1]:
if not isinstance(size, int):
out_size.append(int(_infer_value(size, {}).numpy()))
else:
out_size.append(size)
else:
scale_index = 3 if method != "nearest_neighbor" else 2
scales = inputs[scale_index]
assert scales is not None, "neither out size nor scale provided"
assert isinstance(scales, list)
ishape = self.infer_shape(inputs[0])
for i, scale in enumerate(scales):
out_size.append(int(math.floor(float(ishape[2 + i]) * scale)))
return out_size
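# Example (illustrative): for an NCHW input of shape (1, 3, 16, 16) with no explicit out
# size and scales = [2.0, 2.0], out_size is computed as [floor(16 * 2.0), floor(16 * 2.0)]
# = [32, 32].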
def make_upsample(self, method):
def upsample(inputs, input_types):
data = inputs[0]
out_size = self.get_upsample_out_size(inputs, method)
if len(inputs) > 2 and method != "nearest_neighbor":
align_corners = inputs[2]
else:
align_corners = False
if method == "nearest_neighbor":
coord_trans = "asymmetric"
elif align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
def func(x):
return _op.image.resize2d(
x, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
)
if self.is_quantized_tensor(data):
# input qparams are manually appended by us
assert isinstance(inputs[-2], float)
assert isinstance(inputs[-1], int)
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
# currently piggybacks on fp32; this produces output identical to torch
return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func)
return func(data)
return upsample
def make_upsample3d(self, method):
def upsample3d(inputs, input_types):
data = inputs[0]
out_size = self.get_upsample_out_size(inputs, method)
if len(inputs) > 2 and method == "linear":
align_corners = inputs[2]
else:
align_corners = False
if method == "nearest_neighbor":
coord_trans = "asymmetric"
elif align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
return _op.image.resize3d(data, out_size, None, "NCDHW", method, coord_trans)
return upsample3d
def expand_as(self, inputs, input_types):
target = inputs[1]
t0 = self.infer_type(inputs[0]).dtype
t1 = self.infer_type(inputs[1]).dtype
if str(t0) != str(t1):
target = _op.cast(target, t0)
return _op.broadcast_to_like(inputs[0], target)
def broadcast_tensors(self, inputs, input_types):
tensor_list = inputs[0]
import torch
infer_shape_value = [self.infer_shape(t) for t in tensor_list]
# "torch.broadcast_shapes" is available after PyTorch 1.8.0
if hasattr(torch, "broadcast_shapes"):
res_shape = list(torch.broadcast_shapes(*infer_shape_value))
else:
res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape)
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]
def Bool(self, inputs, input_types):
assert len(inputs) == 1
return inputs[0]
def Float(self, inputs, input_types):
assert len(inputs) == 1
return _op.cast(inputs[0], "float32")
def bitwise_not(self, inputs, input_types):
data = inputs[0]
# The input tensor must be of integral or Boolean types.
# For bool tensors, it computes the logical NOT
if input_types[0] == "bool":
out = _op.logical_not(_op.cast(data, "bool"))
else:
out = _op.bitwise_not(_op.cast(data, "int"))
return out
def bitwise_xor(self, inputs, input_types):
lhs = inputs[0]
rhs = inputs[1]
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
return _op.bitwise_xor(lhs, rhs)
def logical_not(self, inputs, input_types):
data = _wrap_const(inputs[0])
return _op.logical_not(_op.cast(data, "bool"))
def logical_xor(self, inputs, input_types):
lhs = _op.cast(inputs[0], "bool")
rhs = _op.cast(inputs[1], "bool")
return _op.logical_xor(lhs, rhs)
def list_getitem(self, inputs, input_types):
return self.prelude.nth(inputs[0], _wrap_const(inputs[1]))
def list_len(self, inputs, input_types):
return self.prelude.length(inputs[0])
def type_as(self, inputs, input_types):
assert len(inputs) == 2
assert len(input_types) == 2
return _op.cast(inputs[0], input_types[1])
def gather(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
indices = inputs[2]
return _op.gather(data, axis, indices)
def add(self, inputs, input_types):
# add_ is overloaded for tensor add and list concat
if input_types[0] == "ListType":
return self.prelude.concat(inputs[0], inputs[1])
return self.make_elemwise("add")(inputs, input_types)
def tensor_array_stack(self, inputs, input_types):
dim = inputs[1]
assert dim == 0, "stacking on a dynamic tensor list is only supported on the first axis"
tensor_array, shape = self.convert_to_tensor_array(inputs[0])
stacked_shape = (Any(),) + shape
stack = self.prelude.get_global_var_static("tensor_array_stack", "float32", shape)
stacked = stack(tensor_array)
static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", stacked_shape)
static_tensor_array_ops.register()
get_tensor = self.prelude.get_global_var_static("tensor_get_data", "float32", stacked_shape)
return get_tensor(stacked)
def stack(self, inputs, input_types):
if isinstance(inputs[0], list):
# a static python list of tensors
dim = inputs[1]
return _op.stack(inputs[0], dim)
else:
# List ADT case
assert isinstance(inputs[0], _expr.Expr)
ty = self.infer_type_with_prelude(inputs[0])
list_ty = self.prelude.mod.get_global_type_var("List")
msg = "The input list is expected to be List ADT"
assert isinstance(ty, tvm.ir.TypeCall) and ty.func == list_ty, msg
return self.tensor_array_stack(inputs, input_types)
def sub(self, inputs, input_types):
if len(inputs) == 3:
data0, data1, alpha = self.pytorch_promote_types(inputs, input_types)
return get_relay_op("subtract")(data0, alpha * data1)
else:
data0, data1 = self.pytorch_promote_types(inputs, input_types)
return get_relay_op("subtract")(data0, data1)
def rsub(self, inputs, input_types):
data0, data1, alpha = self.pytorch_promote_types(inputs, input_types)
# note: rsub means data0 and data1 swap places
return get_relay_op("subtract")(data1, alpha * data0)
def embedding(self, inputs, input_types):
weight = inputs[0]
indices = inputs[1]
return _op.take(weight, indices.astype("int32"), axis=0)
def one_hot(self, inputs, input_types):
indices = inputs[0].astype("int32")
num_classes = inputs[1]
if num_classes == -1:
msg = "Inferring the number of classes is not yet supported."
raise NotImplementedError(msg)
dtype = "int32"
on_value = tvm.relay.const(1.0, dtype)
off_value = tvm.relay.const(0.0, dtype)
return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype)
def index(self, inputs, input_types):
data = inputs[0]
data_shape = self.infer_type(data).shape
axes_adv_idx = [i for i, v in enumerate(inputs[1]) if v is not None]
axes_rest = [i for i in range(len(data_shape)) if i not in axes_adv_idx]
# check if the adv_index axes are consecutive
# if consecutive, result must be transposed again at the end
consecutive = True
for curr, nxt in zip(axes_adv_idx[:-1], axes_adv_idx[1:]):
if nxt - curr != 1:
consecutive = False
break
indices_list = []
axes_order = axes_adv_idx + axes_rest
for i in axes_adv_idx:
inp = inputs[1][i]
if self.infer_type(inp).dtype == "bool":
# adv_index does not support a mask as the index tensor (it will treat 0/1 as
# an index rather than a flag).
# So we use argwhere to turn the mask into indices, which will also take care
# of the dynamism in the indexing by mask.
indices_list.append(_op.squeeze(_op.transform.argwhere(inp), axis=[1]))
else:
indices_list.append(inp)
data_after_adv_index = _op.adv_index([_op.transpose(data, axes=axes_order)] + indices_list)
if consecutive:
num_dims = len(self.infer_type(data_after_adv_index).shape)
num_new_dims = num_dims - len(axes_rest)
axes_final_order = list(range(num_dims))
axes_final_order = (
axes_final_order[num_new_dims : num_new_dims + axes_adv_idx[0]]
+ axes_final_order[:num_new_dims]
+ axes_final_order[num_new_dims + axes_adv_idx[0] :]
)
return _op.transpose(data_after_adv_index, axes=axes_final_order)
else:
return data_after_adv_index
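# Worked example (illustrative): for data of shape (5, 6, 7) indexed as data[:, idx] with
# idx of shape (3,), the adv axis {1} is trivially consecutive, so the adv_index result is
# transposed back to (5, 3, 7). With non-consecutive adv axes, e.g. data[idx0, :, idx1],
# the indexed dims stay in front and the result is (3, 6), matching NumPy/PyTorch semantics.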
def meshgrid(self, inputs, input_types):
data = inputs[0]
return _op.meshgrid(data, indexing="ij")
def nms(self, inputs, input_types):
boxes = inputs[0]
scores = inputs[1]
iou_threshold = inputs[2]
# TVM NMS assumes score > 0
# - since there are multiple consumers of "scores" and "num_boxes",
#   invoke set_span here to prevent the exprs from being rewritten in the span-filling stage
source_name = self.source_map[self.current_op[-1]]
scores = set_span(scores - _op.min(scores) + _op.const(1.0), source_name)
num_boxes = set_span(_op.shape_of(scores), source_name)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
# - since "arange" op will fill expr into its attribute
# - invoke set_span here to prevent expr-rewritten occurrs in span-filling stage
indices = _op.transform.arange(set_span(_op.squeeze(num_boxes), source_name), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {})
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
# Perform Non-Maximum Suppression,
# PyTorch NMS doesn't have parameter top_k and max_output_size
score_index = 0
top_k = max_out_size = -1
nms_ret = get_relay_op("non_max_suppression")(
data=data,
valid_count=num_boxes,
indices=indices,
max_output_size=max_out_size,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False,
)
# squeeze the two outputs of nms for strided_slice
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
# strided slice to get the dynamic result
ret = get_relay_op("strided_slice")(
data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
)
# in torchvision, indices from nms are int64
return _op.cast(ret, "int64")
def logsumexp(self, inputs, input_types):
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
dim_list = inputs[1]
keepdim = inputs[2] if len(inputs) > 2 else False
# dim is output of prim::ListConstruct, even if it is int in python code
assert isinstance(dim_list, list), "dim is expected to be a list"
return _op.logsumexp(data[0], axis=dim_list, keepdims=keepdim)
def roi_align(self, inputs, input_types):
data = inputs[0]
boxes = inputs[1]
output_size = (inputs[3], inputs[4])
spatial_scale = inputs[2]
sample_ratio = inputs[5]
aligned = False if len(inputs) < 7 else inputs[6]
if aligned:
boxes -= _expr.const(0.5 / spatial_scale)
return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)
def deform_conv2d(self, inputs, input_types):
data = inputs[0]
weight = inputs[1]
offset = inputs[2]
if len(inputs) > 12:
strides_offset = 5
bias = inputs[4]
logger.warning("mask argument in deformable conv2d is not supported and ignored")
else:
strides_offset = 4
bias = inputs[3]
strides = (inputs[strides_offset], inputs[strides_offset + 1])
padding = (inputs[strides_offset + 2], inputs[strides_offset + 3])
dilation = (inputs[strides_offset + 4], inputs[strides_offset + 5])
groups = inputs[strides_offset + 6]
deformable_groups = inputs[strides_offset + 7]
weight_shape = self.infer_shape(weight)
output_channels = weight_shape[0]
kernel_size = (weight_shape[2], weight_shape[3])
conv_out = _op.nn.deformable_conv2d(
data,
offset,
weight,
strides,
padding,
dilation,
deformable_groups,
groups,
output_channels,
kernel_size,
)
return _op.nn.bias_add(conv_out, bias)
def stft(self, inputs, input_types):
data = inputs[0]
n_fft = inputs[1]
hop_length = inputs[2]
win_length = inputs[3]
window = inputs[4]
normalized = inputs[5]
onesided = inputs[6]
return _op.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
def unbind(self, inputs, input_types):
data = inputs[0]
axis = int(inputs[1])
return unbind(data, axis)
def shape_as_tensor(self, inputs, input_types):
is_symbolic_shape = False
input_shape = self.infer_shape(inputs[0], self.prelude.mod)
for axis in input_shape:
if not isinstance(axis, (int, tvm.tir.IntImm)):
is_symbolic_shape = True
break
if is_symbolic_shape:
ret = _op.shape_of(inputs[0], dtype="int64")
else:
ret = _expr.const(np.array(input_shape), dtype="int64")
return ret
def logical_and(self, inputs, input_types):
lhs = _op.cast(inputs[0], "bool")
rhs = _op.cast(inputs[1], "bool")
return _op.logical_and(lhs, rhs)
def nonzero(self, inputs, input_types, is_numpy_style=False):
data = inputs[0]
ret = _op.transform.argwhere(data)
if is_numpy_style or (len(inputs) > 1 and inputs[1]):
return unbind(ret, 1)
return ret
def nonzero_numpy(self, inputs, input_types):
return self.nonzero(inputs, input_types, is_numpy_style=True)
def scatter(self, inputs, input_types):
assert len(inputs) == 4 or len(inputs) == 5, (
f"scatter takes 4 or 5 inputs: data, dim, index, src, reduce (optional), "
f"but {len(inputs)} given"
)
data = inputs[0]
axis = int(inputs[1])
index = inputs[2]
src = inputs[3]
if len(inputs) == 5:
reduce = inputs[4]
else:
reduce = "update"
data_shape = self.infer_shape(data)
data_rank = len(data_shape)
index_shape = self.infer_shape(index)
index_rank = len(index_shape)
# When index is empty, the operation returns data unchanged
if self.is_empty_shape(index_shape):
return data
if np.isscalar(src):
assert self.infer_type(src).dtype == "float", "Scalar source can be float only"
src = _op.broadcast_to_like(src, data_shape)
src_shape = data_shape
else:
src_shape = self.infer_shape(src)
src_rank = len(src_shape)
assert data_rank == index_rank, "Index rank is not the same as data rank"
assert data_rank == src_rank, "Src rank is not the same as data rank"
assert 0 <= axis < data_rank, "Dim is out of bounds"
for i in range(data_rank):
index_dim = index_shape[i]
src_dim = src_shape[i]
data_dim = data_shape[i]
# Skip check for dynamic dimensions
if not any([isinstance(index_dim, tvm.tir.Any), isinstance(src_dim, tvm.tir.Any)]):
assert index_dim <= src_dim, "Index dim size should be less than src one"
if i != axis and not any(
[isinstance(index_dim, tvm.tir.Any), isinstance(data_dim, tvm.tir.Any)]
):
assert index_dim <= data_dim, "Index dim size should be less than data one"
if reduce is None:
reduce = "update"
elif reduce == "multiply":
reduce = "mul"
assert reduce in [
"update",
"add",
"mul",
], 'reduce arg is expected from "add", "multiply" or None'
return _op.scatter_elements(data, index, src, axis, reduce)
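# Scatter semantics sketch (illustrative, dim=0, reduce="update"): out[index[i][j]][j] =
# src[i][j]. E.g. data = zeros((3, 3)), index = [[0, 1, 2]], src = [[1, 2, 3]] produces
# [[1, 0, 0], [0, 2, 0], [0, 0, 3]].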
def index_put(self, inputs, input_types):
in_tensor = inputs[0]
indices = inputs[1]
values = inputs[2]
accumulate = inputs[3]
if not accumulate:
mode = "update"
else:
mode = "add"
# Combine array of index tensors into one index tensor with shape (N,_)
index_tensor = _op.stack(indices, axis=0)
return _op.scatter_nd(in_tensor, index_tensor, values, mode)
def scalar_tensor(self, inputs, input_types):
data = inputs[0]
cast_map = {6: "float32", 7: "float64", 3: "int32", 4: "int64"}
type_key = inputs[1]
if isinstance(data, _expr.Constant):
data = data.data.numpy().tolist()
return _expr.const(data, cast_map[type_key])
def interpolate(self, inputs, input_types):
if isinstance(inputs[1], _expr.Expr):
out_size = inputs[1]
elif isinstance(inputs[1], list):
out_size = []
for i in [0, 1]:
size, _ = try_infer_value(
inputs[1][i],
lambda ret: ret.astype(int),
lambda: _op.expand_dims(inputs[1][i], axis=0),
)
out_size.append(size)
out_size = _op.concatenate(out_size, axis=0)
data = inputs[0]
align_corners = inputs[4]
method = inputs[3]
if method.startswith("nearest"):
method = "nearest_neighbor"
elif method[0:2] == "bi":
method = method[2:]
if method == "nearest_neighbor":
coord_trans = "asymmetric"
elif align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
return _op.image.resize2d(
data, out_size, None, "NCHW", method, coord_trans, cubic_alpha=-0.75
)
def numel(self, inputs, input_types):
return _op.ndarray_size(inputs[0])
def empty(self, inputs, input_types):
shape = []
for s in inputs[0]:
if isinstance(s, _expr.Constant):
shape.append(s.data.numpy().item())
else:
assert isinstance(s, int)
shape.append(s)
return _op.zeros(shape, _convert_dtype_value(inputs[1]))
def empty_like(self, inputs, input_types):
shape = self.infer_shape(inputs[0])
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = input_types[0]
return _op.zeros(shape, dtype)
def new_empty(self, inputs, input_types):
size = inputs[1]
import torch
if not isinstance(size, (_expr.Expr, list, tuple, torch.Size, np.ndarray)):
msg = f"Data type {type(size)} could not be parsed in empty op"
raise AssertionError(msg)
if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
else:
dtype = input_types[0]
return _op.zeros(size, dtype)
def randn(self, inputs, input_types):
import time # use current time as seed
shape = inputs[0]
output = _op.random.normal(_op.random.threefry_key(int(time.time())), shape)
_, values = _expr.TupleWrapper(output, 2)
return values
def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
input_type = self.infer_type(data).dtype
if input_type == "int64":
logger.warning(
"Casting an int64 input to int32, since we do not have int64 atomic add"
"needed for bincount yet."
)
data = _op.cast(data, "int32")
maximum = _op.max(data)
dim = maximum + _expr.const(1, dtype="int32")
if weights:
weight_type = self.infer_type(weights)
out_dtype = weight_type.dtype
updates = weights
else:
out_dtype = "int32"
updates = _op.ones_like(data)
counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
out = _op.scatter_elements(counts, data, updates, axis=0, reduction="add")
if input_type == "int32":
# Torch always outputs int64 results for bincount
return _op.cast(out, "int64")
return out
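# Worked example (illustrative): data = [1, 1, 3] with no weights gives dim = 4,
# updates = [1, 1, 1], and the scatter-add produces [0, 2, 0, 1], matching
# torch.bincount(torch.tensor([1, 1, 3])).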
def scatter_add(self, inputs, input_types):
assert (
len(inputs) == 4
), f"scatter_add takes 4 inputs (data, dim, index, src), but {len(inputs)} given"
data = inputs[0]
axis = inputs[1]
index = inputs[2]
src = inputs[3]
data_shape = self.infer_shape(inputs[0])
data_rank = len(data_shape)
index_shape = self.infer_shape(inputs[2])
index_rank = len(index_shape)
# When index is empty, the operation returns data unchanged
if self.is_empty_shape(index_shape):
return data
src_shape = self.infer_shape(inputs[3])
src_rank = len(src_shape)
assert data_rank == index_rank, "Index rank is not the same as data rank"
assert data_rank == src_rank, "Src rank is not the same as data rank"
assert 0 <= axis < data_rank, "Dim is out of bounds"
for i in range(data_rank):
assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one"
if i != axis:
assert (
index_shape[i] <= data_shape[i]
), "Index dim size should be less than data one"
return _op.scatter_elements(data, index, src, axis=axis, reduction="add")
def scatter_reduce(self, inputs, input_types):
assert len(inputs) == 5 or len(inputs) == 6, (
f"scatter_reduce takes 5 or 6 inputs (data, dim, index, src, reduce, include_self), "
f"but {len(inputs)} given"
)
data = inputs[0]
dim = inputs[1]
index = inputs[2]
src = inputs[3]
reduce = inputs[4]
if len(inputs) == 6:
include_self = inputs[5]
# TODO(vvchernov): support include_self == False
assert include_self, "include_self=False has not been supported for scatter_reduce yet"
data_shape = self.infer_shape(inputs[0])
data_rank = len(data_shape)
index_shape = self.infer_shape(inputs[2])
index_rank = len(index_shape)
src_shape = self.infer_shape(inputs[3])
src_rank = len(src_shape)
assert data_rank == index_rank, "Index rank is not the same as data rank"
assert data_rank == src_rank, "Src rank is not the same as data rank"
assert 0 <= dim < data_rank, "Dim is out of bounds"
for i in range(data_rank):
assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one"
if i != dim:
assert (
index_shape[i] <= data_shape[i]
), "Index dim size should be less than data one"
red_valids = ["sum", "prod", "mean", "amax", "amin"]
assert (
reduce in red_valids
), f"Only {red_valids} modes are supported, but {reduce} is gotten"
if reduce == "sum":
reduce = "add"
elif reduce == "prod":
reduce = "mul"
elif reduce == "amin":
reduce = "min"
elif reduce == "amax":
reduce = "max"
return _op.scatter_elements(data, index, src, axis=dim, reduction=reduce)
def cumsum(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
dtype = inputs[2]
if inputs[2] is not None:
dtype = _convert_dtype_value(inputs[2])
return _op.cumsum(data, axis=dim, dtype=dtype)
def masked_fill(self, inputs, input_types):
mask = inputs[1]
value = _op.cast(_wrap_const(inputs[2]), input_types[0])
return _op.where(mask, value, inputs[0])
def masked_select(self, inputs, input_types):
mask = inputs[1]
indices = self.nonzero([mask], input_types, is_numpy_style=True)
return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)])
def sort(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
is_descending = inputs[2]
# pytorch sort returns both sorted indices and values
indices = _op.argsort(data, dim, not is_descending)
return _op.gather(data, dim, indices), indices
def argsort(self, inputs, input_types):
data = inputs[0]
dim = inputs[1]
is_descending = inputs[2]
return _op.argsort(data, dim, not is_descending)
def is_floating_point(self, inputs, input_types):
assert len(inputs) == 1
if isinstance(inputs[0], _expr.Expr):
input_type = self.infer_type(inputs[0]).dtype
else:
input_type = input_types[0]
is_float = input_type in ["float32", "float64", "float16", "bfloat16"]
return _expr.const(is_float)
def unique(self, inputs, input_types):
assert len(inputs) == 4
[data, is_sorted, return_inverse, return_counts] = inputs
if not is_sorted:
logger.warning("TVM always assumes sorted=True for torch.unique")
is_sorted = True
if return_counts:
[unique, indices, inverse_indices, num_uniq, counts] = _op.unique(
data, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, inverse_indices, counts_sliced)
else:
[unique, indices, inverse_indices, num_uniq] = _op.unique(
data, is_sorted=is_sorted, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
return (unique_sliced, inverse_indices)
def nll_loss(self, inputs, input_types):
assert len(inputs) == 5
[predictions, targets, weights, reduction, ignore_index] = inputs
num_class = self.infer_shape(predictions)[1]
if reduction == 0:
reduction = "none"
elif reduction == 1:
reduction = "mean"
else:
reduction = "sum"
if weights is None:
weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
def flip(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])
def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):
"""
Bidirectional RNN cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = rnn_cell(input_seqs, **weights_dicts[0], backwards=False, act=act)
reverse_outputs, rev_H_t = rnn_cell(input_seqs, **weights_dicts[1], backwards=True, act=act)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
)
return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
def rnn_layers(self, input_data, layer_weights_dicts, bidirectional, act, dropout_p=0.0):
"""
Iterates over layers for a stacked RNN
"""
layers_num = len(layer_weights_dicts)
# split input sequence to samples set
input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
output_hiddens = []
for i in range(layers_num):
weights_dicts = layer_weights_dicts[i]
# input_seqs shape = [seq_num, (batch, feature_size)] or
# [seq_num, (batch, 2*feature_size)] for bidirectional
if bidirectional:
input_seqs, H_t = self.bidir_rnn_cell(input_seqs, weights_dicts, act=act)
else:
input_seqs, H_t = rnn_cell(input_seqs, **weights_dicts[0], act=act)
output_hiddens.append(H_t)
# TODO (yuanfz98): in pytorch implementation train is also checked
# see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
# /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and i < layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
raise NotImplementedError("Dropout for GRU has not been supported yet!")
output_hiddens = (
_op.concatenate(output_hiddens, 0) if bidirectional else _op.stack(output_hiddens, 0)
)
return _op.stack(input_seqs, 0), output_hiddens
def rnn(self, inputs, input_types, nonlinearity):
"""
Description of RNN in pytorch:
https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN
Description of inputs:
https://github.com/pytorch/pytorch/blob/736fb7d22cc948b739db2c35aeb5ad4d19aea4f4/torch/overrides.py#L937
"""
# TODO (yuanfz98): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
# Unpack inputs, note that if optional and not provided then value will be None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
hidden_state = inputs[1]
# Hidden state shape (hidden_layers_num, batch, hidden_size)
_weights = inputs[2]
# Wi layer[0] shape (hidden_size, feature_size)
# Wh layer[0] shape (hidden_size, hidden_size)
# Bi layer[0] shape (hidden_size)
# Bh layer[0] shape (hidden_size)
# Wi layer[>0] shape (hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (hidden_size, hidden_size)
# Bi layer[>0] shape (hidden_size)
# Bh layer[>0] shape (hidden_size)
# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]
num_directions = 1
if bidirectional:
num_directions = 2
rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
assert (
rsd == 0
), "The number of weights in layer must be a multiple of the number of directions!"
weights_num = int(len(_weights) / num_layers / num_directions)
if has_biases:
assert weights_num == 4, "The number of weights per layer is expected to be 4"
else:
assert weights_num == 2, "The number of weights per layer is expected to be 2"
if nonlinearity == "tanh":
act = _op.tanh
elif nonlinearity == "relu":
act = _op.nn.relu
assert act, "The nonlinearity is unknown"
X = (
_op.transpose(_X, (1, 0, 2)) if batch_first else _X
) # always (seq_num, batch, feature_size)
# TODO (yuanfz98): Which data type should be used? from input or weights?
# Instead of it _infer_type(X).checked_type.dtype can be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)
hidden_size = int(_infer_shape(_weights[0])[0])
batch_size = X_shape[1]
# Initialize hidden states if not provided.
layers_h = []
hidden_layers_num = num_directions * num_layers
if hidden_state is None:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = unbind(hidden_state, 0)
layer_weights_dicts = []
k = 0 # layer counter
if has_biases:
names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
else:
names = ["hidden_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of RNN weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
assert (
len(layer_weights_dicts) == num_layers and k == num_layers
), "For stacked RNN number of weights sets should be the same as number of layers!"
output, out_hidden_state = self.rnn_layers(
X, layer_weights_dicts, bidirectional, act, dropout_p=dropout_p
)
# output shape = (seq_num, batch, hidden_size) or
# (seq_num, batch, 2*feature_size) for bidirectional
if batch_first:
output = _op.transpose(output, (1, 0, 2))
return (output, out_hidden_state)
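# Weight-count sketch (illustrative): for num_layers=2, bidirectional=True, has_biases=True,
# torch passes w_ih, w_hh, b_ih, b_hh per (layer, direction), so len(_weights) == 16 and
# weights_num == 16 / 2 / 2 == 4; layer_weights_dicts then holds two [forward, reverse] pairs.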
def bidir_gru_cell(self, input_seqs, weights_dicts):
"""
Bidirectional GRU cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = gru_cell(input_seqs, **weights_dicts[0])
reverse_outputs, rev_H_t = gru_cell(input_seqs, **weights_dicts[1], backwards=True)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
)
return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)
def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
"""
Iterates over layers for a stacked GRU
"""
layers_num = len(layer_weights_dicts)
# split input sequence to samples set
input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
output_hiddens = []
for i in range(layers_num):
weights_dicts = layer_weights_dicts[i]
# input_seqs shape = [seq_num, (batch, feature_size)] or
# [seq_num, (batch, 2*feature_size)] for bidirectional
if bidirectional:
input_seqs, H_t = self.bidir_gru_cell(input_seqs, weights_dicts)
else:
input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0])
output_hiddens.append(H_t)
# TODO (vvchernov): in pytorch implementation train is also checked
# see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
# /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and i < layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
raise NotImplementedError("Dropout for GRU has not been supported yet!")
return _op.stack(input_seqs, 0), _op.stack(output_hiddens, 0)
def gru(self, inputs, input_types):
"""
Description of GRU in pytorch:
https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU
"""
# TODO (vvchernov): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
# Unpack inputs, note that if optional and not provided then value will be None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
hidden_state = inputs[1]
# Hidden state shape (hidden_layers_num, batch, hidden_size)
_weights = inputs[2]
# Wi layer[0] shape (3 * hidden_size, feature_size)
# Wh layer[0] shape (3 * hidden_size, hidden_size)
# Bi layer[0] shape (3 * hidden_size)
# Bh layer[0] shape (3 * hidden_size)
# Wi layer[>0] shape (3 * hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (3 * hidden_size, hidden_size)
# Bi layer[>0] shape (3 * hidden_size)
# Bh layer[>0] shape (3 * hidden_size)
# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]
num_directions = 1
if bidirectional:
num_directions = 2
rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
assert (
rsd == 0
), "The number of weights in layer must be a multiple of the number of directions!"
weights_num = int(len(_weights) / num_layers / num_directions)
if has_biases:
assert weights_num == 4, "The number of weights per layer is expected to be 4"
else:
assert weights_num == 2, "The number of weights per layer is expected to be 2"
X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
# TODO (vvchernov): Which data type should be used? from input or weights?
# Instead of it _infer_type(X).checked_type.dtype can be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)
hidden_size = int(_infer_shape(_weights[0])[0] / 3)
batch_size = X_shape[1]
# Initialize hidden states if not provided.
layers_h = []
hidden_layers_num = num_directions * num_layers
if hidden_state is None:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = unbind(hidden_state, 0)
layer_weights_dicts = []
k = 0 # layer counter
if has_biases:
names = ["hidden_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 4]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
else:
names = ["hidden_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], *_weights[j : j + 2]]
rev_weights_dict = dict(zip(names, rev_tensors))
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of GRU weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
layer_weights_dicts.append([fw_weights_dict])
k += 1
assert (
len(layer_weights_dicts) == num_layers and k == num_layers
), "For stacked GRU number of weights sets should be the same as number of layers!"
output, out_hidden_state = self.gru_layers(
X, layer_weights_dicts, bidirectional, dropout_p=dropout_p
)
# output shape = (seq_num, batch, hidden_size) or
# (seq_num, batch, 2*feature_size) for bidirectional
if batch_first:
output = _op.transpose(output, (1, 0, 2))
return (output, out_hidden_state)
def bidir_lstm_cell(self, input_seqs, weights_dicts):
"""
Bidirectional LSTM cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t, fw_C_t = lstm_cell(input_seqs, **weights_dicts[0])
reverse_outputs, rev_H_t, rev_C_t = lstm_cell(
input_seqs, **weights_dicts[1], backwards=True
)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.concatenate([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=-1)
)
return final_outputs, (fw_H_t, fw_C_t), (rev_H_t, rev_C_t)
def lstm_layers(self, input_data, layer_weights_dicts, bidirectional, dtype, dropout_p=0.0):
"""
Iterates over layers for a stacked LSTM
"""
layers_num = len(layer_weights_dicts)
# split input sequence to samples set
input_seqs = unbind(input_data, 0) # [seq_num, (batch, feature_size)]
output_hiddens = []
for i in range(layers_num):
weights_dicts = layer_weights_dicts[i]
# input_seqs shape = [seq_num, (batch, feature_size)] or
# [seq_num, (batch, 2*feature_size)] for bidirectional
if bidirectional:
input_seqs, H_t, C_t = self.bidir_lstm_cell(input_seqs, weights_dicts)
else:
input_seqs, H_t, C_t = lstm_cell(input_seqs, **weights_dicts[0])
output_hiddens.append((H_t, C_t))
# TODO (vvchernov): in pytorch implementation train is also checked
# see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
# /aten/src/ATen/native/RNN.cpp#L1054
if dropout_p != 0 and i < layers_num - 1:
# for input in input_seqs:
# input = _op.dropout(input, dropout_p)
raise NotImplementedError("Dropout for LSTM has not been supported yet!")
final_hiddens = []
if bidirectional:
for output_hidden in output_hiddens:
final_hiddens.append(output_hidden[0])
final_hiddens.append(output_hidden[1])
else:
final_hiddens = output_hiddens
return _op.stack(input_seqs, 0), final_hiddens
def lstm(self, inputs, input_types):
"""
Description of LSTM in pytorch: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
Native implementation for torch version less than 1.8.0 (projection is unsupported):
https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/ \
src/ATen/native/RNN.cpp#L1396
Native implementation for torch version from 1.8.0 and higher (projection is supported):
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
"""
# TODO (vvchernov): support dropout
assert len(inputs) == 9, "Input of size 9 is expected"
# Unpack inputs, note that if optional and not provided then value will be None.
_X = inputs[0]
# _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
hidden_states = inputs[1]
assert len(hidden_states) == 2, "lstm expects two hidden states"
h_0 = hidden_states[0]
c_0 = hidden_states[1]
# H0 shape (hidden_layers_num, batch, proj_size) if projection
# else (hidden_layers_num, batch, hidden_size)
# C0 shape (hidden_layers_num, batch, hidden_size)
_weights = inputs[2]
# If no projection
# Wi layer[0] shape (4 * hidden_size, feature_size)
# Wh layer[0] shape (4 * hidden_size, hidden_size)
# Bi layer[0] shape (4 * hidden_size)
# Bh layer[0] shape (4 * hidden_size)
# Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
# Wh layer[>0] shape (4 * hidden_size, hidden_size)
# Bi layer[>0] shape (4 * hidden_size)
# Bh layer[>0] shape (4 * hidden_size)
# If projection
# Wi layer[0] shape (4 * hidden_size, feature_size)
# Wh layer[0] shape (4 * hidden_size, proj_size)
# Bi layer[0] shape (4 * hidden_size)
# Bh layer[0] shape (4 * hidden_size)
# P layer[0] shape (proj_size, hidden_size)
# Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
# Wh layer[>0] shape (4 * hidden_size, proj_size)
# Bi layer[>0] shape (4 * hidden_size)
# Bh layer[>0] shape (4 * hidden_size)
# P layer[>0] shape (proj_size, hidden_size)
# Scalar inputs
has_biases = inputs[3]
num_layers = inputs[4]
dropout_p = inputs[5] # dropout probability, if 0.0 it means there is no dropout
# train = inputs[6]
bidirectional = inputs[7]
batch_first = inputs[8]
num_directions = 1
if bidirectional:
num_directions = 2
rsd = len(_weights) % num_layers
assert rsd == 0, "The number of weights must be a multiple of the number of layers!"
rsd = (len(_weights) / num_layers) % num_directions
assert (
rsd == 0
), "The number of weights in layer must be a multiple of the number of directions!"
has_proj = False
proj_size = 0
weights_num = int(len(_weights) / num_layers / num_directions)
if has_biases:
if weights_num == 5:
has_proj = True
proj_size = _infer_shape(_weights[4])[0]
else:
assert weights_num == 4, "The number of weights per layer is expected to be 4"
else:
if weights_num == 3:
has_proj = True
proj_size = _infer_shape(_weights[2])[0]
else:
assert weights_num == 2, "The number of weights per layer is expected to be 2"
X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
# TODO (vvchernov): Which data type should be used? from input or weights?
# Instead of it _infer_type(X).checked_type.dtype can be used
X_dtype = input_types[0]
X_shape = _infer_shape(X) # (seq_num, batch, feature_size)
hidden_size = int(_infer_shape(_weights[0])[0] / 4)
batch_size = X_shape[1]
# Initialize hidden states if not provided.
layers_h = []
layers_c = []
hidden_layers_num = num_directions * num_layers
if h_0 is None:
if has_proj:
h_0 = _op.zeros((batch_size, proj_size), X_dtype)
else:
h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_h.append(h_0)
else:
layers_h = unbind(h_0, 0)
if c_0 is None:
c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
for i in range(hidden_layers_num):
layers_c.append(c_0)
else:
layers_c = unbind(c_0, 0)
layer_weights_dicts = []
k = 0 # layer counter
if has_biases:
names = ["hidden_state", "cell_state", "w_inp", "w_hid", "b_inp", "b_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
if has_proj:
fw_weights_dict["proj"] = _weights[i + 4]
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 4]]
rev_weights_dict = dict(zip(names, rev_tensors))
if has_proj:
rev_weights_dict["proj"] = _weights[j + 4]
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 4]]
fw_weights_dict = dict(zip(names, fw_tensors))
if has_proj:
fw_weights_dict["proj"] = _weights[i + 4]
layer_weights_dicts.append([fw_weights_dict])
k += 1
else:
names = ["hidden_state", "cell_state", "w_inp", "w_hid"]
if bidirectional:
rsd = len(_weights) % (2 * weights_num)
assert rsd == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), 2 * weights_num):
fw_tensors = [layers_h[2 * k], layers_c[2 * k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
if has_proj:
fw_weights_dict["proj"] = _weights[i + 2]
j = i + weights_num
rev_tensors = [layers_h[2 * k + 1], layers_c[2 * k + 1], *_weights[j : j + 2]]
rev_weights_dict = dict(zip(names, rev_tensors))
if has_proj:
rev_weights_dict["proj"] = _weights[j + 2]
layer_weights_dicts.append([fw_weights_dict, rev_weights_dict])
k += 1
else:
assert len(_weights) % weights_num == 0, "got an incorrect number of LSTM weights"
for i in range(0, len(_weights), weights_num):
fw_tensors = [layers_h[k], layers_c[k], *_weights[i : i + 2]]
fw_weights_dict = dict(zip(names, fw_tensors))
if has_proj:
fw_weights_dict["proj"] = _weights[i + 2]
layer_weights_dicts.append([fw_weights_dict])
k += 1
assert (
len(layer_weights_dicts) == num_layers and k == num_layers
), "For stacked LSTM number of weights sets should be the same as number of layers!"
outputs = self.lstm_layers(
X, layer_weights_dicts, bidirectional, dtype=X_dtype, dropout_p=dropout_p
)
# output shape = (seq_num, batch, hidden_size) or
# (seq_num, batch, 2*feature_size) for bidirectional
output = outputs[0]
hy = []
cy = []
for hidden in outputs[1]:
hy.append(hidden[0])
cy.append(hidden[1])
if batch_first:
output = _op.transpose(output, (1, 0, 2))
return (output, _op.stack(hy, 0), _op.stack(cy, 0))
def all_any_common(self, op, inputs, input_types):
if len(inputs) >= 2:
dim = inputs[1]
else:
dim = None
if len(inputs) >= 3:
keepdim = inputs[2]
else:
keepdim = False
if self.infer_type(inputs[0]).dtype != "bool":
# The input dtype can be uint8.
inp = _op.cast(inputs[0], "bool")
else:
inp = inputs[0]
return op(inp, axis=dim, keepdims=keepdim)
def searchsorted_common(
self, sorted_sequence, values, out_int32, right, side=None, out=None, sorter=None
):
assert side is None and out is None and sorter is None, "unsupported parameters"
dtype = "int32" if out_int32 else "int64"
values_shape = _infer_shape(values)
if len(values_shape) == 0:
values = _op.expand_dims(values, 0)
out = _op.searchsorted(sorted_sequence, values, right=right, dtype=dtype)
if len(values_shape) == 0:
return _op.squeeze(out)
return out
def searchsorted(self, inputs, input_types):
return self.searchsorted_common(*inputs)
def bucketize(self, inputs, input_types):
return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3])
def roll(self, inputs, input_types):
def slide_axes(inp, shape, ax):
axes = list(range(len(shape)))
axes = axes[:ax] + [-1] + axes[ax:-1]
return _op.transpose(inp, axes)
x = inputs[0]
shifts = inputs[1]
dims = inputs[2]
shape = self.infer_shape(x)
start = _expr.const(0, "int64")
step = _expr.const(1, "int64")
out = x
for i, dim in enumerate(dims):
roll_dim = _expr.const(shape[dim], "int64")
indices_1d = _op.mod(
_op.transform.arange(start, roll_dim, step, "int64")
- _expr.const(shifts[i], "int64")
+ roll_dim,
roll_dim,
)
# First fill in the last axis with roll indices, and then do transpose to
# bring the roll indices into the desired axis.
indices = slide_axes(
_op.tile(indices_1d, shape[:dim] + shape[dim + 1 :] + (1,)), shape, dim
)
out = _op.gather(out, dim, indices)
return out
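# Worked example (illustrative): for x = [x0, x1, x2, x3], shifts = [1], dims = [0],
# indices_1d = (arange(4) - 1 + 4) % 4 = [3, 0, 1, 2], so gather returns [x3, x0, x1, x2],
# matching torch.roll(x, 1).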
def einsum(self, inputs, input_types):
equation = inputs[0]
data = inputs[1]
return _op.einsum(data, equation)
def dot(self, inputs, _):
lhs, rhs = inputs
return _op.sum(_op.multiply(lhs, rhs))
def mv(self, inputs, _):
lhs, rhs = inputs
# Convert the 1D matrix (vector) into a 2D matrix with the extra
# dimension=1
rhs_matrix = _op.transform.expand_dims(rhs, 0)
# Run multiplication
dense_result = _op.nn.dense(lhs, rhs_matrix, units=None)
# Chop off the extra result dimension
return _op.transform.squeeze(dense_result)
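# Shape sketch (descriptive): nn.dense computes lhs @ rhs_matrix.T, so for lhs of shape
# (M, K) and rhs of shape (K,), rhs_matrix is (1, K), dense_result is (M, 1), and the
# squeeze yields the expected (M,) matrix-vector product.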
def grid_sampler(self, inputs, input_types):
interpolate_mode = inputs[2]
padding_mode = inputs[3]
align_corners = inputs[4]
data_shape = self.infer_shape_with_prelude(inputs[0])
if len(data_shape) == 4:
layout = "NCHW"
axes = [0, 3, 1, 2]
grid = _op.transform.transpose(inputs[1], axes)
elif len(data_shape) == 5:
layout = "NCDHW"
axes = [0, 4, 1, 2, 3]
grid = _op.transform.transpose(inputs[1], axes)
else:
msg = "only 4D and 5D are supported."
raise ValueError(msg)
if interpolate_mode == 0:
interpolate_str = "bilinear"
elif interpolate_mode == 1:
interpolate_str = "nearest"
elif interpolate_mode == 2:
interpolate_str = "bicubic"
else:
msg = f"interpolation method {interpolate_mode} is not supported"
raise ValueError(msg)
if padding_mode == 0:
padding_mode_str = "zeros"
elif padding_mode == 1:
padding_mode_str = "border"
elif padding_mode == 2:
padding_mode_str = "reflection"
else:
msg = f"padding_mode {padding_mode} is not supported"
raise ValueError(msg)
return _op.image.grid_sample(
inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
)
def trilu(self, inputs, input_types, mode):
data = inputs[0]
k = inputs[1] if inputs[1] else 0
upper = True if mode == "triu" else False
return _op.trilu(data, k, upper)
def multinomial(self, inputs, input_types):
probs = inputs[0]
num_samples = inputs[1]
replacement = inputs[2] if inputs[2] else True
assert not (
replacement is False and num_samples > 1
), "Multinomial without replacement is not yet supported."
# Ideally this seed would be generated by a previous threefry operation.
# Eventually we might want to add a global store for random keys.
seed = np.random.randint(1e6)
key = _op.random.threefry_key(seed)
output = _op.random.multinomial(key, probs, num_samples)
_, indices = _expr.TupleWrapper(output, 2)
return indices
def weight_norm(self, inputs, input_types):
weight_v, weight_g = inputs[0], inputs[1]
dim = inputs[2]
dtype = input_types[0]
order = 2.0
reci_order = _expr.const(1.0 / order, dtype=dtype)
order = _expr.const(order)
norm_v = _op.power(
_op.reduce.sum(_op.power(_op.abs(weight_v), order), axis=dim, exclude=2, keepdims=True),
reci_order,
)
return weight_g * (weight_v / norm_v)
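# Formula sketch (descriptive): this computes w = g * v / ||v||_2, where the L2 norm is
# reduced over every axis except `dim` (keepdims retained for broadcasting), mirroring
# torch.nn.utils.weight_norm.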
# Operator mappings
def create_convert_map(self):
self.convert_map = {
"aten::is_floating_point": self.is_floating_point,
"aten::pixel_shuffle": self.pixel_shuffle,
"aten::device": self.none,
"prim::device": self.none,
"aten::sub": self.sub,
"aten::max": self.max,
"aten::min": self.min,
"aten::maximum": self.maximum,
"aten::minimum": self.minimum,
"aten::amax": self.max,
"aten::amin": self.min,
"aten::stft": self.stft,
"aten::mul": self.make_elemwise("multiply"),
"aten::pow": self.make_elemwise("power"),
"aten::lerp": self.lerp,
"aten::arange": self.arange,
"aten::meshgrid": self.meshgrid,
"aten::div": self.make_elemwise("divide"),
"aten::floor_divide": self.make_elemwise("floor_divide"),
"aten::true_divide": self.make_elemwise("divide"),
"aten::fmod": self.make_elemwise("trunc_mod"),
"aten::remainder": self.make_elemwise("floor_mod"),
"aten::addcdiv": self.addcdiv,
"aten::addcmul": self.addcmul,
"aten::ones": self.ones,
"aten::ones_like": self.ones_like,
"aten::zeros": self.zeros,
"aten::zero_": self.zero_,
"aten::zeros_like": self.zeros_like,
"aten::new_zeros": self.new_zeros,
"aten::new_ones": self.new_ones,
"aten::full": self.full,
"aten::full_like": self.full_like,
"aten::new_full": self.new_full,
"aten::fill_": self.fill_,
"aten::linspace": self.linspace,
"aten::reciprocal": self.reciprocal,
"aten::repeat": self.repeat,
"aten::repeat_interleave": self.repeat_interleave,
"aten::to": self.to,
"aten::squeeze": self.squeeze,
"aten::unsqueeze": self.unsqueeze,
"aten::cat": self.concatenate,
"aten::slice": self.slice,
"aten::narrow": self.narrow,
"aten::split": self.split,
"aten::tensor_split": self.tensor_split,
"aten::split_with_sizes": self.split_with_sizes,
"aten::select": self.select,
"aten::take": self.take,
"aten::where": self.where,
"aten::topk": self.topk,
"aten::relu": self.relu,
"aten::relu6": self.relu6,
"aten::prelu": self.prelu,
"aten::leaky_relu": self.leaky_relu,
"aten::elu": self.elu,
"aten::celu": self.celu,
"aten::gelu": self.gelu,
"aten::selu": self.selu,
"aten::silu": self.silu,
"aten::glu": self.glu,
"aten::log_sigmoid": self.log_sigmoid,
"aten::adaptive_avg_pool1d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d
),
"aten::adaptive_avg_pool2d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool2d
),
"aten::adaptive_avg_pool3d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool3d
),
"aten::adaptive_max_pool1d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool1d
),
"aten::adaptive_max_pool2d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool2d
),
"aten::adaptive_max_pool3d": functools.partial(
self.adaptive_max_pool, _op.nn.adaptive_max_pool3d
),
"aten::max_pool2d": self.maxpool_2d,
"aten::max_pool2d_with_indices": self.maxpool_2d_with_indices,
"aten::max_pool1d": self.maxpool_1d,
"aten::max_pool3d": self.maxpool_3d,
"aten::hardtanh": self.hardtanh,
"aten::_convolution": self.convolution,
"aten::softmax": self.softmax,
"aten::threshold": self.threshold,
"aten::contiguous": self.contiguous,
"aten::batch_norm": self.batch_norm,
"aten::instance_norm": self.instance_norm,
"aten::layer_norm": self.layer_norm,
"aten::group_norm": self.group_norm,
"aten::transpose": self.transpose,
"aten::t": self.transpose,
"aten::numpy_T": self.numpy_T,
"aten::flatten": self.flatten,
"aten::addmm": self.addmm,
"aten::size": self.size,
"aten::view": self.view,
"aten::reshape": self.reshape,
"aten::reshape_as": self.reshape_as,
"aten::clone": self.clone,
"aten::log_softmax": self.log_softmax,
"aten::sigmoid": self.sigmoid,
"aten::softplus": self.softplus,
"aten::avg_pool1d": self.make_avg_pool(1),
"aten::avg_pool2d": self.make_avg_pool(2),
"aten::avg_pool3d": self.make_avg_pool(3),
"aten::linear": self.linear,
"aten::dropout": self.dropout,
"aten::feature_dropout": self.dropout,
"aten::alpha_dropout": self.dropout,
"aten::mean": self.mean,
"aten::chunk": self.chunk,
"aten::unsafe_chunk": self.chunk,
"aten::matmul": self.matmul,
"aten::bmm": self.matmul,
"aten::baddbmm": self.baddbmm,
"aten::expand": self.expand,
"aten::Int": self.int,
"prim::NumToTensor": self.numtotensor,
"prim::ImplicitTensorToNum": self.tensortonum,
"aten::ScalarImplicit": self.tensortonum,
"aten::pad": self.pad,
"aten::constant_pad_nd": self.constant_pad_nd,
"aten::reflection_pad1d": self.reflection_pad1d,
"aten::reflection_pad2d": self.reflection_pad2d,
"aten::replication_pad1d": self.replication_pad1d,
"aten::replication_pad2d": self.replication_pad2d,
"aten::replication_pad3d": self.replication_pad3d,
"aten::permute": self.transpose,
"aten::sum": self.make_reduce("sum"),
"aten::prod": self.make_reduce("prod"),
"aten::argmin": self.make_reduce("argmin"),
"aten::argmax": self.make_reduce("argmax"),
"aten::norm": self.norm,
"aten::frobenius_norm": self.frobenius_norm,
"aten::std": self.std,
"aten::var": self.variance,
"aten::var_mean": self.var_mean,
"aten::abs": self.make_unary("abs"),
"aten::neg": self.make_unary("negative"),
"aten::cos": self.make_unary("cos"),
"aten::cosh": self.make_unary("cosh"),
"aten::sin": self.make_unary("sin"),
"aten::sinh": self.make_unary("sinh"),
"aten::tan": self.make_unary("tan"),
"aten::tanh": self.make_unary("tanh"),
"aten::acos": self.make_unary("acos"),
"aten::asin": self.make_unary("asin"),
"aten::atan": self.make_unary("atan"),
"aten::log": self.make_unary("log"),
"aten::log2": self.make_unary("log2"),
"aten::log10": self.make_unary("log10"),
"aten::log1p": self.log1p,
"aten::exp": self.make_unary("exp"),
"aten::erf": self.make_unary("erf"),
"aten::trunc": self.make_unary("trunc"),
"aten::sign": self.make_unary("sign"),
"aten::sqrt": self.make_unary("sqrt"),
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::square": self.square,
"aten::tril": functools.partial(self.trilu, mode="tril"),
"aten::triu": functools.partial(self.trilu, mode="triu"),
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
"aten::isfinite": self.make_unary("isfinite"),
"aten::isinf": self.make_unary("isinf"),
"aten::isnan": self.make_unary("isnan"),
"aten::clamp": self.clamp,
"aten::clamp_min": self.clamp_min,
"aten::clamp_max": self.clamp_max,
"aten::detach": self.identity,
"aten::upsample_bilinear2d": self.make_upsample("linear"),
"aten::upsample_bicubic2d": self.make_upsample("cubic"),
"aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"),
"aten::upsample_trilinear3d": self.make_upsample3d("linear"),
"aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
"aten::expand_as": self.expand_as,
"aten::broadcast_tensors": self.broadcast_tensors,
"aten::lt": self.make_elemwise("less"),
"aten::gt": self.make_elemwise("greater"),
"aten::le": self.make_elemwise("less_equal"),
"aten::ge": self.make_elemwise("greater_equal"),
"aten::ne": self.make_elemwise("not_equal"),
"aten::eq": self.make_elemwise("equal"),
"aten::logical_not": self.logical_not,
"aten::logical_xor": self.logical_xor,
"aten::bitwise_not": self.bitwise_not,
"aten::bitwise_xor": self.bitwise_xor,
"aten::Bool": self.Bool,
"aten::Float": self.Float,
"aten::rsub": self.rsub,
"aten::embedding": self.embedding,
"aten::one_hot": self.one_hot,
"aten::mm": self.matmul,
"aten::add": self.add,
"aten::stack": self.stack,
"aten::__getitem__": self.list_getitem,
"aten::len": self.list_len,
"aten::type_as": self.type_as,
"aten::gather": self.gather,
"aten::index_select": self.select,
"aten::index": self.index,
"torchvision::nms": self.nms,
"aten::logsumexp": self.logsumexp,
"torchvision::roi_align": self.roi_align,
"torchvision::deform_conv2d": self.deform_conv2d,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
"aten::logical_and": self.logical_and,
"aten::_shape_as_tensor": self.shape_as_tensor,
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
"aten::scatter": self.scatter,
"aten::scatter_add": self.scatter_add,
"aten::scatter_reduce": self.scatter_reduce,
"aten::index_put": self.index_put,
"aten::scalar_tensor": self.scalar_tensor,
"aten::__interpolate": self.interpolate,
"aten::IntImplicit": self.identity,
"aten::tensor": self.identity, # used for example in tensor(1.0)
"aten::numel": self.numel,
"aten::empty": self.empty,
"aten::empty_like": self.empty_like,
"aten::new_empty": self.new_empty,
"aten::randn": self.randn,
"aten::bincount": self.bincount,
"aten::__not__": self.logical_not,
"aten::hardswish": self.hard_swish,
"aten::hardsigmoid": self.hard_sigmoid,
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
"aten::argsort": self.argsort,
"aten::sort": self.sort,
"aten::_unique2": self.unique,
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::nll_loss_nd": self.nll_loss,
"aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
"aten::l1_loss": self.l1_loss,
"aten::mse_loss": self.mse_loss,
"aten::flip": self.flip,
"aten::rnn_tanh": functools.partial(self.rnn, nonlinearity="tanh"),
"aten::rnn_relu": functools.partial(self.rnn, nonlinearity="relu"),
"aten::gru": self.gru,
"aten::lstm": self.lstm,
"aten::all": functools.partial(self.all_any_common, _op.all),
"aten::any": functools.partial(self.all_any_common, _op.any),
"aten::searchsorted": self.searchsorted,
"aten::bucketize": self.bucketize,
"aten::roll": self.roll,
"aten::einsum": self.einsum,
"aten::dot": self.dot,
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
"aten::__ior__": self.make_elemwise("bitwise_or"),
"aten::__iand__": self.make_elemwise("bitwise_and"),
"aten::__ixor__": self.make_elemwise("bitwise_xor"),
"aten::__lshift__": self.make_elemwise("left_shift"),
"aten::__rshift__": self.make_elemwise("right_shift"),
"aten::multinomial": self.multinomial,
"aten::_weight_norm": self.weight_norm,
}
def update_convert_map(self, custom_map):
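"""Merge a user-supplied op conversion map into the default convert_map.
Each converter is called as fn(inputs, input_types). A minimal sketch of an
entry (the op name and converter below are hypothetical, for illustration only):
{"my_ns::my_relu": lambda inputs, input_types: _op.nn.relu(inputs[0])}
"""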
self.convert_map.update(custom_map)
def report_missing_conversion(self, op_names):
"""Check if all ops in an input graph are supported by TVM"""
known_ops = [
"prim::Constant",
"prim::GetAttr",
"prim::ListConstruct",
"prim::ListUnpack",
"prim::TupleConstruct",
"prim::TupleUnpack",
"prim::RaiseException",
"prim::If",
"prim::Loop",
]
known_ops += list(self.convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
missing = []
for op_name in op_names:
# Also take care of in-place variant ops like aten::relu_
if op_name not in known_ops and not (
op_name.endswith("_") and op_name[:-1] in known_ops
):
missing.append(op_name)
if missing:
msg = f"The following operators are not implemented: {missing}"
raise NotImplementedError(msg)
def convert_block(self, block, outputs):
"""Translate Torch "Block", used for prim::If and prim::Loop"""
ops = _get_operator_nodes(
block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
)
ret_names = _get_input_names(block.returnNode())
return self.convert_operators(ops, outputs, ret_names)
def convert_if(self, if_node, outputs):
"""Translate Torch prim::If to Relay If"""
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
true_branch = self.convert_block(blocks[0], outputs)
false_branch = self.convert_block(blocks[1], outputs)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
def convert_loop(self, loop_node, outputs):
"""Translate Torch prim::Loop to Relay while_loop"""
def get_input(index):
ivalue = loop_node.inputsAt(index)
inode = ivalue.node()
if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode))
var_name = ivalue.debugName()
assert var_name in outputs
return _wrap_const(outputs[var_name])
# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
# The first input: %max_trip_count
# The second input: %initial_condition
# The rest of the inputs: loop variables
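# Schematically (following the spec; the value names below are illustrative):
#   %y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
#     block0(%i, %a_1, ..., %a_r):
#       ...
#       -> (%iter_condition, %b_1, ..., %b_r)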
max_loop_count = get_input(0)
init_cond = get_input(1)
num_loop_var = len(list(loop_node.inputs())) - 2
init_vals = [get_input(i + 2) for i in range(num_loop_var)]
# A while loop always has max_loop_count set to int64 max.
# max_loop_count.data (tvm.runtime.NDArray) is -1 here, so call _get_constant again.
is_while_loop = (
isinstance(max_loop_count, _expr.Constant)
and _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize
)
if is_while_loop:
loop_iter_dtype = "bool"
# For a while loop whose condition does not depend on inputs (e.g. while i < 10:),
# init_cond is an int and needs to be cast to bool to type check.
if isinstance(init_cond, _expr.Constant):
init_cond = _op.cast(init_cond, "bool")
init_loop_iter_val = init_cond
else:
loop_iter_dtype = "int32"
# always count from 0
init_loop_iter_val = _expr.const(0, dtype="int32")
body_block = list(loop_node.blocks())[0]
block_input_names = _get_input_names(body_block)
num_block_inputs = len(block_input_names)
name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals))
outputs.update(name_val_pairs)
def get_var(name, val):
if val:
checked_type = self.infer_type_with_prelude(val)
if hasattr(checked_type, "shape"):
shape = get_const_tuple(checked_type.shape)
actual_shape = []
for dim in shape:
if isinstance(dim, int) and dim == 0:
actual_shape.append(Any())
else:
actual_shape.append(dim)
expr = _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
else:
expr = _expr.var(name, type_annotation=checked_type)
return set_span(expr, val.span) if val.span else expr
return _expr.var(name)
source_name = self.source_map[loop_node]
loop_iter_var = set_span(
_expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype), span=source_name
)
loop_vars = set_span(
[get_var(name, val) for name, val in name_val_pairs[1:]], span=source_name
)
# Add non-constant free variables to the loop variables to prevent code blow-up.
# Without this, if there are two for loops in a row, which often happens
# if the outer loop is unrolled, the computation corresponding to the first for loop
# is inlined inside the loop body, turning O(N) + O(N) computation into O(N^2).
# This issue was found when converting the Stacked LSTM test. Torch does not add the
# output of the earlier loop into the loop variables of the next loop,
# so the variable corresponding to the first loop output appears free in the second
# loop body.
free_vars = [
var
for var in _get_free_vars_from_block(body_block)
if var in outputs
and not isinstance(outputs[var], (_expr.Constant, int, float, str))
and outputs[var]
]
prev_outputs = {}
for name in free_vars:
prev_output = outputs[name]
new_loop_var = get_var(name, prev_output)
prev_outputs[name] = prev_output
outputs[name] = set_span(new_loop_var, source_name)
loop_vars.append(new_loop_var)
init_vals.append(prev_output)
def cond(*current_vals):
i = current_vals[0]
if is_while_loop:
return _op.equal(i, _expr.const(True, "bool"))
return _op.less(i, max_loop_count)
def body(*current_vals):
# Update loop variables using the prev iteration outputs
assert len(current_vals) == num_block_inputs + len(free_vars)
for (i, val) in enumerate(current_vals):
if i < num_block_inputs:
outputs[block_input_names[i]] = val
else:
outputs[free_vars[i - num_block_inputs]] = val
block_outputs = self.convert_block(body_block, outputs)
block_outputs += [outputs[name] for name in free_vars]
if not is_while_loop:
# The iteration variable increment is implicit in Torch, so do it manually here.
# For a while loop, block_outputs[0] is already a boolean: the result of the
# termination check.
incr = _expr.const(1, dtype="int32")
block_outputs[0] = current_vals[0] + incr
return block_outputs
loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
loop_val = loop(init_loop_iter_val, *init_vals)
# restore original output values for free vars
outputs.update(prev_outputs)
# The first element is a loop counter or boolean condition, ignore it
return [_expr.TupleGetItem(loop_val, i + 1) for i in range(num_loop_var)]
def convert_operators(self, operators, outputs, ret_names):
"""Convert each Torch IR operators to Relay equivalent"""
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
# record the current operator so that we can provide the correct source name
# for operators that need special handling (e.g. nms / arange ...)
self.current_op.append(op_node)
if operator == "prim::Constant":
outputs[node_name] = _get_constant(op_node)
elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
outputs[node_name] = set_span(
self.convert_to_list_adt(inputs), self.source_map[op_node]
)
elif operator == "prim::ListConstruct":
# This assumes that no more elements will be appended to this list
# In this case, we keep the Python list
outputs[node_name] = inputs
elif operator == "prim::TupleConstruct":
def _handle_nested_input(inputs):
inputs_list = []
for i, _ in enumerate(inputs):
if isinstance(inputs[i], list):
inputs_list.append(_handle_nested_input(inputs[i]))
else:
assert isinstance(inputs[i], _expr.Expr)
inputs_list.append(inputs[i])
return _expr.Tuple(inputs_list)
outputs[node_name] = set_span(
_handle_nested_input(inputs), self.source_map[op_node]
)
elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
assert len(inputs) == 1
if isinstance(inputs[0], (list, _expr.TupleWrapper)):
unpacked = inputs[0]
else:
unpacked = _unpack_tuple(inputs[0])
outputs.update(
zip(_get_output_names(op_node), set_span(unpacked, self.source_map[op_node]))
)
elif operator == "prim::prim::RaiseException":
logger.warning("raising exceptions is ignored")
outputs[node_name] = None
elif operator == "prim::If":
if_out = self.convert_if(op_node, outputs)
outputs[node_name] = set_span(if_out, self.source_map[op_node])
elif operator == "prim::Loop":
loop_out = self.convert_loop(op_node, outputs)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
outputs.update(zip(unpacked_names, set_span(loop_out, self.source_map[op_node])))
else:
if operator not in self.convert_map:
# At this point, the only possible ops that are not in convert_map are
# in-place variant of ops like aten::relu_
assert operator.endswith("_")
logger.warning(
"An in-place op %s found, the result will not be correct "
"if the model depends on side-effects by this op.",
operator,
)
relay_op = self.convert_map[operator[:-1]]
else:
relay_op = self.convert_map[operator]
self._set_parameter_source_name(op_node, outputs)
relay_out = relay_op(
# since the elements in "outputs" may change due to span-filling process
# we have to call "_get_op_inputs" again rather than use "inputs" directly
_get_op_inputs(op_node, outputs),
_get_input_types(op_node, outputs, default_dtype=self.default_dtype),
)
relay_out = set_span(relay_out, self.source_map[op_node])
self.record_output_type(relay_out)
if isinstance(relay_out, tuple):
# This is for torch operators that return multiple outputs
# See _adaptive_max_2d above for example
out_names = _get_output_names(op_node)
outputs.update(zip(out_names, relay_out))
else:
assert op_node.outputsSize() == 1
outputs[node_name] = relay_out
self.current_op.pop()
return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
def _set_parameter_source_name(self, op_node, outputs):
"""A helper function to rewrite source_name of parameter."""
for name in _get_input_names(op_node):
expr = outputs[name]
if isinstance(expr, (_expr.Var, _expr.Constant)):
name_sep = "_" if self.use_parser_friendly_name else "."
source_name = [self.source_map[op_node]]
if isinstance(expr, _expr.Var):
# the variable name should already contain the node source name
# for ops whose attributes were handled in the convert_params stage,
# e.g. "aten::batch_norm_5.running_mean"
if expr.name_hint.startswith(source_name[0]):
source_name[0] = expr.name_hint
else:
source_name.append(expr.name_hint)
new_expr = set_span(expr, name_sep.join(source_name))
outputs[name] = new_expr
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
dtype_map = {
"float64": torch.float64,
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"int64": torch.int64,
"int32": torch.int32,
"int16": torch.int16,
"int8": torch.int8,
"uint8": torch.uint8,
"bool": torch.bool,
}
if len(dtypes) > 0:
result_type = dtypes[0]
for dt in dtypes[1:]:
if dt != result_type:  # we skip identical dtypes; quantized dtypes are
# not handled here (they cannot be promoted?)
result_type = _convert_data_type(
str(
torch.result_type(
torch.zeros((), dtype=dtype_map[result_type]),
torch.zeros((), dtype=dtype_map[dt]),
)
)
)
else:
result_type = "bool" # this is the smallest type...
for inp in non_tensor_inputs:
result_type = _convert_data_type(
str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]), inp))
)
return result_type
# Helper functions for operator implementation
def _convert_dtype_value(val):
"""converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
convert_torch_dtype_map = {
11: "torch.bool",
7: "torch.float64",
6: "torch.float32",
5: "torch.float16",
4: "torch.int64",
3: "torch.int32",
2: "torch.int16",
1: "torch.int8",
0: "torch.uint8",
None: "torch.int64",
} # Default is torch.int64
if val in convert_torch_dtype_map:
return _convert_data_type(convert_torch_dtype_map[val])
else:
msg = f"Torch data type value {val} is not handled yet."
raise NotImplementedError(msg)
def _convert_data_type(input_type, default_dtype=None):
"""converts the PyTorch scalar type input_type to a TVM dtype.
optionally, default_dtype can be a TVM dtype that is used
if input_type is None (but not when it is unknown)"""
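# For example, "double", "float64" and "torch.float64" all map to "float64".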
if input_type is None and default_dtype is not None:
return default_dtype
input_type = input_type.lower()
if input_type in ["double", "float64", "torch.float64"]:
return "float64"
elif input_type in ["float", "float32", "torch.float32"]:
return "float32"
elif input_type in ["half", "float16", "torch.float16"]:
return "float16"
elif input_type in ["long", "int64", "torch.int64"]:
return "int64"
elif input_type in ["int", "int32", "torch.int32"]:
return "int32"
elif input_type in ["short", "int16", "torch.int16"]:
return "int16"
elif input_type in ["char", "int8", "torch.int8"]:
return "int8"
elif input_type in ["byte", "uint8", "torch.uint8"]:
return "uint8"
elif input_type in ["quint8", "torch.quint8"]:
return "quint8"
elif input_type in ["qint8", "torch.qint8"]:
return "qint8"
elif input_type in ["qint32", "torch.qint32"]:
return "qint32"
elif input_type in ["bool", "torch.bool"]:
return "bool"
elif input_type in ["str"]:
return "str"
else:
raise NotImplementedError(f"input_type {input_type} is not handled yet")
return "float32" # Never reached
def _create_typed_const(data, dtype):
"""create a (scalar) constant of given value and dtype.
dtype should be a TVM dtype"""
if dtype == "float64":
typed_data = _expr.const(np.float64(data), dtype=dtype)
elif dtype == "float32":
typed_data = _expr.const(np.float32(data), dtype=dtype)
elif dtype == "float16":
typed_data = _expr.const(np.float16(data), dtype=dtype)
elif dtype == "int64":
typed_data = _expr.const(np.int64(data), dtype=dtype)
elif dtype == "int32":
typed_data = _expr.const(np.int32(data), dtype=dtype)
elif dtype == "int16":
typed_data = _expr.const(np.int16(data), dtype=dtype)
elif dtype == "int8":
typed_data = _expr.const(np.int8(data), dtype=dtype)
elif dtype == "uint8":
typed_data = _expr.const(np.uint8(data), dtype=dtype)
else:
raise NotImplementedError(f"input_type {dtype} is not handled yet")
return typed_data
def _wrap_const(c):
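"""Wrap a plain Python value in a Relay constant; Relay exprs, lists and Any are returned as-is."""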
if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
return _expr.const(c)
return c
def _run_jit_passes(graph, enable_lower_all_tuples=True):
"""The inline pass is necessary to unwrap prim::CallMethod"""
# pylint: disable=c-extension-no-member
import torch
if is_version_greater_than("1.5.1"):
# This is required for torchvision detection models from 1.6 onward.
# It is the same as _jit_pass_inline, except that it has some special-case
# behaviors for some ops such as aten::__interpolate()
torch._C._jit_pass_onnx_function_substitution(graph)
else:
torch._C._jit_pass_inline(graph)
if enable_lower_all_tuples:
torch._C._jit_pass_lower_all_tuples(graph)
def _get_tensor_and_var(torch_tensor, name):
tensor = tvm.nd.array(torch_tensor.cpu().numpy())
var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
return tensor, var
def _get_output_name(node):
assert node.outputsSize() == 1
return node.output().debugName()
def _get_output_names(node):
return [output.debugName() for output in node.outputs()]
def _get_input_names(node_or_graph):
return [inp.debugName() for inp in node_or_graph.inputs()]
def _get_op_inputs(op_node, outputs):
return [outputs[name] for name in _get_input_names(op_node)]
def _get_node_type(node):
assert node.outputsSize() == 1
return node.output().type().kind()
def _get_uses(node):
uses = []
for output in node.outputs():
uses += output.uses()
return uses
def _get_users(node):
return [use.user for use in _get_uses(node)]
def _getattr_full_name(getattrs, sep="."):
return sep.join([getattr_attr_name(node) for node in getattrs])
def _get_pytorch_value_type(typ, default_dtype="float32"):
kind = typ.kind()
if kind == "TensorType":
if typ.scalarType() is None:
# Tensor's type can be unknown if we use torch.jit.script(...)
# Defaults can be passed in, if not it is float32
logger.warning("Untyped Tensor found, assume it is %s", default_dtype)
return default_dtype
else:
return _convert_data_type(typ.scalarType())
elif kind == "ListType":
return "ListType"
elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]:
pt_dtype = str(typ).lower()
dtype = pt_dtype if kind == "OptionalType" else _convert_data_type(pt_dtype)
return dtype
else:
return "UnsupportedType"
def _get_input_types(op_node, outputs, default_dtype="float32"):
"""Returns a TVM dtype for each input nodes derived from the torch type"""
in_types = []
for inp in op_node.inputs():
if inp.node().kind() == "prim::GetAttr":
# GetAttr nodes always return None when we call scalarType() on them
name = inp.debugName()
assert name in outputs
if isinstance(outputs[name], _expr.Var):
in_types.append(outputs[name].type_annotation.dtype)
else:
# For quantized modules with parameters, here we would get
# "prim::GetAttr[name="_packed_params"]". Since the dtype corresponding to
# _packed_params is not needed by quantized ops, we return an arbitrary type.
in_types.append(default_dtype)
else:
in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype))
return in_types
def _get_constant(node):
"""Retrieve a constant associated with this prim::Constant node"""
attribute_names = node.attributeNames()
num_attributes = len(attribute_names)
if num_attributes == 1:
attr_name = attribute_names[0]
ty = node.output().type().kind()
if ty == "IntType":
return node.i(attr_name)
elif ty == "BoolType":
return bool(node.i(attr_name))
elif ty in ["FloatType", "LongType"]:
return node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]:
tensor = node.t(attr_name)
if tensor.is_cuda:
tensor = tensor.cpu()
if len(tensor.shape) == 0: # tensor(0.1)
# TODO(t-vi): When is this needed?
return tensor.item()
return _wrap_const(tensor.numpy())
elif ty in ["DeviceObjType", "StringType"]:
return node.s(attr_name)
elif ty == "FunctionType":
return None
else:
raise NotImplementedError(f"Unsupported type: {ty}")
else:
assert num_attributes == 0
return None
def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
"""Rewrite debug name of node outputs with its operator type"""
def _get_source_name(op_type):
op_idx = 0
if op_type in op_type_dict:
op_idx = op_type_dict[op_type] + 1
op_type_dict[op_type] = op_idx
return "_".join([op_type, str(op_idx)])
# get source name of operator and rename all of its outputs
# e.g. node.kind(): aten::adaptive_max_pool2d
# node_src_name -> aten::adaptive_max_pool2d_x
# output_1 -> aten::adaptive_max_pool2d_x_0
# output_2 -> aten::adaptive_max_pool2d_x_1
if node.kind() != "prim::GetAttr":
node_src_name = _get_source_name(node.kind())
for index, output in enumerate(node.outputs()):
output.setDebugName("_".join([node_src_name, str(index)]))
# update source map
# if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
if use_parser_friendly_name:
node_src_name = re.sub(r":|\.", "_", node_src_name)
source_map[node] = node_src_name
def _debug_rename(graph, use_parser_friendly_name):
"""Returns map between node and source name"""
source_map, op_type_dict = {}, {}
prim_with_blocks = ["prim::If", "prim::Loop"]
def _traverse_graph(nodes):
for node in nodes:
if node.outputsSize() == 0:
continue
if node.kind() in prim_with_blocks:
for block in node.blocks():
_traverse_graph(block.nodes())
_rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
_traverse_graph(graph.nodes())
return source_map
def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
"""Returns torch IR nodes that need conversion to Relay"""
ops, should_rename_graph = [], all(x is not None for x in [source_map, op_type_dict])
# Traverse nodes and add to graph
for node in nodes:
if node.outputsSize() == 0:
continue
if should_rename_graph:
_rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
else:
node_name = _get_output_name(node)
if node.kind() != "prim::GetAttr":
ops.append((node_name, node))
return ops
def _get_relay_input_vars(graph, input_infos, prelude, is_module=True, default_dtype="float32"):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
graph_inputs = list(graph.inputs())
if is_module:
# a module has "self" as first input, which we do not need/want
graph_inputs = graph_inputs[1:]
if not isinstance(input_infos, list):
msg = "Graph inputs input_infos should be a list"
raise RuntimeError(msg)
if len(graph_inputs) != len(input_infos):
msg = f"PyTorch has {len(graph_inputs)} inputs and input_infos lists {len(input_infos)}."
raise RuntimeError(msg)
def get_relay_ty(ishape, itype, pt_type):
if pt_type.kind() == "TensorType":
if not (_is_int_seq(ishape) or len(ishape) == 0):
msg = "Shape for Tensors must be lists of ints"
raise RuntimeError(msg)
if (pt_type.dim() is not None and pt_type.dim() != len(ishape)) or (
pt_type.sizes() is not None
and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)])
):
msg = "Shapes of input list and information in the graph do not match"
raise RuntimeError(msg)
if len(ishape) > 1 and any(dim <= 0 for dim in ishape[1:]):
msg = (
"Expected input's non-batch dimensions to have positive length, "
f"but input has a shape of {pt_type.sizes()}"
)
raise RuntimeError(msg)
pt_dtype = pt_type.scalarType()
if not pt_dtype and itype:
pt_dtype = itype
dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
return TensorType(ishape, dtype)
elif pt_type.kind() == "TupleType":
if not isinstance(ishape, tuple):
msg = "Shapes for tuples must be tuples"
raise RuntimeError(msg)
return TupleType(
[get_relay_ty(elem, itype, pt_t) for elem, pt_t in zip(ishape, pt_type.elements())]
)
elif pt_type.kind() == "ListType":
if not isinstance(ishape, list):
msg = "Shapes for lists must be lists"
raise RuntimeError(msg)
pt_elemtype = pt_type.getElementType()
elem_tys = [get_relay_ty(s, itype, pt_elemtype) for s in ishape]
if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0], elem_tys)):
msg = "List elements need have identical types"
raise RuntimeError(msg)
rlist, _, _ = prelude.mod.get_type("List")
return rlist(elem_tys[0])
elif pt_type.kind() == "OptionalType":
# we do not support None yet, so we fill in the type
return get_relay_ty(ishape, itype, pt_type.getElementType())
# TODO: scalar inputs
raise NotImplementedError("unsupported input type")
input_vars = {}
new_input_infos = []
for num, inp in enumerate(input_infos):
if not isinstance(inp, tuple):
msg = f"Graph input {num} is not a tuple"
raise RuntimeError(msg)
if len(inp) != 2 or not isinstance(inp[0], str):
msg = (
f"Graph input {inp} is not valid,"
f" expected ('name', shape) or ('name', (shape, dtype))"
)
raise RuntimeError(msg)
if not isinstance(inp[1], tuple) or len(inp[1]) == 0 or not isinstance(inp[1][-1], str):
new_input_infos.append((inp[0], (inp[1], default_dtype)))
else:
new_input_infos.append(inp)
input_types = [
(name, get_relay_ty(info[0], info[1], gi.type()))
for (name, info), gi in zip(new_input_infos, graph_inputs)
]
ir_inputs = [i.debugName() for i in graph_inputs]
for ir_input, (name, itype) in zip(ir_inputs, input_types):
inp = _expr.var(name, type_annotation=itype)
# Translate from graph input to user input name
input_vars[ir_input] = inp
return input_vars
def _unpack_tuple(tup):
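"""Expand a Relay tuple (or a var with a TupleType annotation) into a list of TupleGetItem exprs."""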
def unpack(tup, num_fields):
return [_expr.TupleGetItem(tup, i) for i in range(num_fields)]
if isinstance(tup, _expr.Tuple):
return unpack(tup, len(tup.fields))
elif isinstance(tup.type_annotation, TupleType):
return unpack(tup, len(tup.type_annotation.fields))
# shouldn't happen
assert False
def _get_free_vars_from_block(block):
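"""Return the set of input names used in the block that are not bound within the block itself."""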
block_inp_names = _get_input_names(block)
bound_names = block_inp_names
free_vars = set()
for node in block.nodes():
inp_names = _get_input_names(node)
list_diff = [name for name in inp_names if name not in bound_names]
free_vars.update(list_diff)
bound_names += _get_output_names(node)
return free_vars
def get_use_chains(root_node, terminate=lambda _: False):
"""
Track a chain of users of this node forward, returning a list of chains
See get_attr_chains below for its usage
"""
def concat_lists(lists):
return itertools.chain.from_iterable(lists)
def inner(current, accum):
users = _get_users(current)
if not users or terminate(users):
return [accum]
return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
return inner(root_node, [root_node])
def get_attr_chains(root_getattr_node):
"""Returns chains of attribute access starting from root_getattr_node
For example, given attribute "block", as in "self.block" when "self" points
to the top level torch.nn.Module, it returns lists of attribute "chains",
e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
These sets of attributes form full attribute accessors. For example,
"self.block.1", "self.block.2" will return the second and third submodule,
and "self.block.0._packed_params" will return the parameters of the first
submodule.
"""
def terminate(users):
next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
return len(next_attrs) == 0
return get_use_chains(root_getattr_node, terminate)
def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
"""
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
packed_param_map = {}
param_debug_name_map = {}
vars_by_name = {}
seen = set()
attr_name_sep = "_" if use_parser_friendly_name else "."
for node in getattr_nodes:
if _get_output_name(node) in seen:
continue
for getattrs in get_attr_chains(node):
seen.update(map(_get_output_name, getattrs))
full_attr = _getattr_full_name(getattrs, attr_name_sep)
full_attr_node_name = _get_output_name(getattrs[-1])
# set variable name by concatenating first consumer's name with full attribute
# e.g. "aten::batch_norm_5.running_mean"
var_name = attr_name_sep.join(
[source_map[_get_users(getattrs[-1])[0]], full_attr.split(attr_name_sep)[-1]]
)
if full_attr.endswith("_packed_params"): # for quantized models
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
if var_name in vars_by_name:
var = vars_by_name[var_name]
else:
torch_tensor = state_dict[full_attr]
tensor, var = _get_tensor_and_var(torch_tensor, var_name)
param_tensors[var_name] = tensor
# for quantized parameters to be correctly located
param_debug_name_map[full_attr_node_name] = var_name
vars_by_name[var_name] = var
params[full_attr_node_name] = var
return params, param_tensors, packed_param_map, param_debug_name_map
def get_all_op_names(graph):
"""Return all operator names in the input graph"""
nodes = list(graph.nodes())
prim_with_blocks = ["prim::If", "prim::Loop"]
for prim in prim_with_blocks:
prim_nodes = graph.findAllNodes(prim, recurse=True)
for prim_node in prim_nodes:
for block in prim_node.blocks():
nodes += block.nodes()
return set(node.kind() for node in nodes)
def export_c_graph(location, graph):
"""Convert the graph to an onnx model and export it to the location."""
import datetime
import os
if not os.path.exists(location):
os.makedirs(location)
time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
fname = os.path.join(location, f"tvm_exported_c_graph_{time_stamp}.txt")
with open(f"{fname}", "w") as f:
f.write(str(graph))
def from_pytorch(
script_module,
input_infos,
custom_convert_map=None,
default_dtype="float32",
use_parser_friendly_name=False,
keep_quantized_weight=False,
export_renamed_c_graph_path=None,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
Parameters
----------
script_module : TopLevelTracedModule object
TorchScripted PyTorch graph
Note: We currently only support traces (i.e., torch.jit.trace(model, input))
input_infos : List of tuples
Can be (input name, input shape) or (input name, (input shape, input dtype))
Graph-level list of input shapes and types
The same input names need to be used for deployment, so choose
easy-to-remember names (such as input0, input1)
e.g.
[('input0', (1, 2)), ('input1', (3, 4))]
or
[('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]
custom_convert_map : Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
default_dtype : str
The default dtype to use when type information is not provided by PyTorch.
use_parser_friendly_name : bool
When True, replace '.' with '_' in the original parameter names.
The Relay text parser treats a variable name followed by a period as a tuple element access,
so a variable name like "dense.weight" cannot be parsed correctly.
Use this option when you want to run the AnnotateSpans pass on the imported module.
keep_quantized_weight : bool
Return quantized weights and bias, rather than float ones. PyTorch stores quantized weights
in a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We use
a PyTorch function to unpack quantized weights into float32 arrays and quantization
parameters. By default, we return float32 weights and rely on the QNN lowering and the
Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however,
we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True,
we quantize weights in the frontend using a function that is equivalent to
qnn.op.quantize(...) operating on Numpy arrays.
export_renamed_c_graph_path : str, optional
Export the renamed torch._C.Graph to the path.
During the conversion, variable names in torch._C.Graph will be assigned based on their op
types. The exported text file can be used as a reference for spans.
Returns
-------
mod : tvm.IRModule
The module that optimizations will be performed on.
params : dict of str to tvm.runtime.NDArray
Dict of converted parameters stored in tvm.runtime.ndarray format
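Examples
--------
A minimal usage sketch; the torchvision model, input name and shape below are
illustrative assumptions rather than requirements of this API::

    import torch
    import torchvision

    model = torchvision.models.resnet18().eval()
    inp = torch.randn(1, 3, 224, 224)
    scripted_model = torch.jit.trace(model, inp)
    mod, params = from_pytorch(scripted_model, [("input0", (1, 3, 224, 224))])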
"""
import torch
mod = tvm.IRModule()
prelude = Prelude(mod)
enable_lower_all_tuples = True
converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
graph = script_module.graph.copy()
# Check if lower_all_tuples pass can be enabled
graph_inputs = list(graph.inputs())
for inp in graph_inputs:
if inp.type().kind() == "TupleType" or inp.type().kind() == "ListType":
enable_lower_all_tuples = False
break
_run_jit_passes(graph, enable_lower_all_tuples)
if custom_convert_map:
converter.update_convert_map(custom_convert_map)
op_names = get_all_op_names(graph)
converter.report_missing_conversion(op_names)
is_module = isinstance(script_module, torch.jit.ScriptModule)
params = script_module.state_dict() if is_module else {}
outputs = _get_relay_input_vars(
graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
)
if use_parser_friendly_name:
new_names = [key.replace(".", "_") for key in params.keys()]
params = dict(zip(new_names, params.values()))
# rename _C.Graph here to construct meaningful source names for graph nodes;
# by doing so, we can use source_map as the reference when renaming model parameters
source_map = _debug_rename(graph, use_parser_friendly_name)
param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
graph, params, source_map, use_parser_friendly_name
)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
outputs.update(param_vars)
# For quantized models
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
if len(quantized_ops.intersection(set(op_names))) > 0:
weight_quant_params = qnn_torch.get_weight_quant_params(
script_module, packed_param_map.values()
)
qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)
input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(
outputs,
packed_param_map,
weight_quant_params,
input_scales_for_bias,
keep_quantized_weight,
)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
converter.update_convert_map(qnn_torch.convert_map)
operator_nodes = _get_operator_nodes(
graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
)
ret_name = _get_input_names(graph.return_node())
outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
# ListConstruct keeps the original Python list. Convert it to a tuple.
outputs = [_expr.Tuple(output) if isinstance(output, list) else output for output in outputs]
if len(outputs) > 1:
ret = _expr.Tuple(outputs)
else:
ret = outputs[0]
# Separate data inputs and parameters to make sure data inputs come first.
func_args = []
data_inputs = []
for arg in _analysis.free_vars(ret):
if arg.name_hint not in tvm_params.keys():
data_inputs.append(arg)
else:
func_args.append(arg)
# Ensure that the order of data inputs matches the order specified in input_infos.
order_input_infos = {
input_info[0]: len(input_infos) - idx for idx, input_info in enumerate(input_infos)
}
data_inputs = sorted(
data_inputs,
key=lambda data_input: order_input_infos[data_input.name_hint]
if data_input.name_hint in order_input_infos
else -1,
reverse=True,
)
func_args = data_inputs + func_args
mod["main"] = tvm.relay.Function(func_args, ret)
if export_renamed_c_graph_path:
export_c_graph(export_renamed_c_graph_path, graph)
return transform.RemoveUnusedFunctions()(mod), tvm_params