| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name |
| """A prelude containing useful global functions and ADT definitions.""" |
| from tvm.ir import IRModule, TypeCall |
| from tvm.tir import Any |
| from tvm.relay.transform import ToANormalFormExpr |
| |
| from .ty import GlobalTypeVar, TensorType, scalar_type |
| from .expr import Var, GlobalVar, If, const |
| from .function import Function |
| from .op.tensor import add, subtract, equal |
| from .adt import Constructor, TypeData, Clause, Match |
| from .adt import PatternConstructor, PatternVar, PatternWildcard |
| from . import op, transform |
| from .analysis import free_vars |
| |
| |
def get_tensor_array_shape(expr, dtype, prelude):
    """Get the static shape of a tensor array if it has fixed rank shape.

    By design, static ADT tensor in TVM has type name in the format
    of static_tensor_<dtype>_dim0_dim1_..._dimN_t, or
    static_tensor_<dtype>_scalar_t for a rank-0 tensor.

    Note that this function overwrites ``prelude.mod["main"]`` with a
    function wrapping ``expr`` so that type inference can be run on it.

    Parameters
    ----------
    expr : Relay Expr
        Input expression.

    dtype : str
        Data type.

    prelude : Prelude
        Tensor array prelude

    Returns
    -------
    shape : tuple of (int, Any) or None
        The output shape. None if input tensor array
        has dynamic shape.
    """
    mod = prelude.mod
    mod["main"] = Function(free_vars(expr), expr)
    mod = transform.InferType()(mod)
    checked_type = mod["main"].body.checked_type
    assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
    ta_type_str = checked_type.args[0].func.name_hint

    # Match the complete "static_tensor_<dtype>_" prefix (including the
    # trailing underscore) so that a dtype which is itself a prefix of
    # another dtype name cannot match by accident.
    prefix = "static_tensor_{}_".format(dtype)
    suffix = "_t"
    if not ta_type_str.startswith(prefix):
        return None

    # Strip the prefix and suffix only at the two ends of the name;
    # str.replace would remove every occurrence anywhere in the string.
    shape_str = ta_type_str[len(prefix) :]
    if shape_str.endswith(suffix):
        shape_str = shape_str[: -len(suffix)]

    shape = []
    if "scalar" not in shape_str:
        for dim_str in shape_str.split("_"):
            if dim_str in ("?", "any"):
                shape.append(Any())
            else:
                shape.append(int(dim_str))
    return tuple(shape)
| |
| |
def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None):
    """Get name for static shape tensor array op

    By design, static ADT tensor in TVM has type name in the format
    of static_tensor_<dtype>_dim0_dim1_..._dimN_t. Ops are named
    <canonical>_<dtype>_dim0_..._dimN, or
    <canonical>_<dtype>_batch1_dim0_..._dimN if tensorlist stack only
    have one item.

    Parameters
    ----------
    canonical : String
        Tensor array op name

    dtype : str
        Data type.

    shape : tuple of (int, Any) or None
        Tensor array shape

    batch_dim: None or int
        1 if tensorlist stack only have one item.
        None by default

    extra_shapes : None or dict of str to shape
        Additional named shapes. Each entry appends "_<key>_<dims>"
        to the shape part of the name, in dict iteration order.

    Returns
    -------
    name : String
        The tensor array op name
    """
    dims = _to_str(shape)
    if extra_shapes is not None:
        dims += "".join(
            "_{}_{}".format(key, _to_str(extra)) for key, extra in extra_shapes.items()
        )

    # An empty shape string (after the extras were appended) means scalar.
    shape_str = dims if dims else "scalar"

    if canonical == "tensor_t":
        return "static_tensor_{}_{}_t".format(dtype, shape_str)

    # Only ops of a single-item stack carry the batch tag; the ADT
    # constructors never do.
    tagged = batch_dim == 1 and canonical not in ("tensor_constructor", "tensor_nil")
    if tagged:
        return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str)
    return "{}_{}_{}".format(canonical, dtype, shape_str)
| |
| |
| def _to_str(shape): |
| dim_names = [] |
| for dim in shape: |
| if isinstance(dim, Any): |
| dim_names.append("any") |
| else: |
| dim_names.append(str(dim)) |
| return "_".join(dim_names) |
| |
| |
| class StaticTensorArrayOps(object): |
| """Contains tensor array related ops for fixed rank tensor array""" |
| |
| def __init__(self, prelude, dtype, shape, batch_dim=None): |
| """Create tensor array ops registry""" |
| self.prelude = prelude |
| self.dtype = dtype |
| self.shape = shape |
| self.batch_dim = batch_dim |
| self.list, self.cons, self.nil = self.prelude.mod.get_type("List") |
| |
def get_name(self, canonical, extra_shapes=None):
    """Get name corresponding to the canonical name

    The name is built by _get_name_static from this instance's dtype,
    shape and batch_dim, plus the optional extra_shapes.
    """
    return _get_name_static(
        canonical,
        self.dtype,
        self.shape,
        batch_dim=self.batch_dim,
        extra_shapes=extra_shapes,
    )
| |
| def get_global_var(self, canonical): |
| """Get global corresponding to the canonical name""" |
| return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim) |
| |
| def get_type(self, canonical): |
| """Get type corresponding to the canonical name""" |
| return self.prelude.get_type_static(canonical, self.dtype, self.shape) |
| |
| def get_ctor(self, canonical): |
| """Get ctor corresponding to the canonical name""" |
| return self.prelude.get_ctor_static("tensor_t", canonical, self.dtype, self.shape) |
| |
def define_tensor_adt(self):
    """Defines the static tensor ADT, which is the container for tensors
    with fixed shapes.

    Sets ``self.tensor_type_var``. When a global type with the same name
    is already present in the module, that type is reused and nothing
    new is registered.
    """
    adt_name = self.get_name("tensor_t")

    # This is effectively functioning as a monomorphizer.
    # TODO(@jroesch): we should add full shape polymorphism
    # and do monomorphization.
    #
    # Skip register if tensor type is already registered.
    registered_names = {
        type_var.name_hint for type_var in self.prelude.mod.get_global_type_vars()
    }
    if adt_name in registered_names:
        self.tensor_type_var = self.get_type("tensor_t")
        return

    adt_var = GlobalTypeVar(adt_name)
    self.tensor_type_var = adt_var

    payload_type = TensorType(self.shape, self.dtype)
    data_ctor_name = self.get_name("tensor_constructor")
    nil_ctor_name = self.get_name("tensor_nil")

    # The ADT has two constructors: an empty one and one wrapping a tensor.
    nil_ctor = Constructor(nil_ctor_name, [], adt_var)
    data_ctor = Constructor(data_ctor_name, [payload_type], adt_var)

    self.prelude.mod[adt_var] = TypeData(adt_var, [], [nil_ctor, data_ctor])
| |
def define_tensor_array(self):
    """Defines a function to create a tensor array with size n.
    tensor_array(n) : Tensor[(), int32] -> list[tensor_t]

    The generated function recurses on n - 1 and conses one tensor_nil
    per step until n equals 0.
    """
    array_ctor_var = self._create_global_var(self.get_name("tensor_array"))

    nil_tensor = self.get_ctor("tensor_nil")
    # NOTE(review): "tensor_t" is looked up through get_ctor and the result
    # is called with no arguments to form the element type - confirm that
    # prelude.get_ctor_static resolves the type name this way.
    element_type = self.get_ctor("tensor_t")

    size = Var("x", scalar_type("int32"))
    body = If(
        equal(size, const(0)),
        self.nil(),
        self.cons(nil_tensor(), array_ctor_var(subtract(size, const(1)))),
    )
    return_type = self.list(element_type())
    self.prelude.mod[array_ctor_var] = Function([size], body, return_type, [])
| |
def define_tensor_take(self):
    """Defines a function to return a range of tensor_t on axis 0.
    tensor_take(t, lower, upper) :
    tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t

    The result type is the ADT whose first dimension is Any and whose
    remaining dimensions equal self.shape[1:].
    """
    # We don't register take for scalar tensor.
    if len(self.shape) == 0:
        return

    take_name = self.get_name("tensor_take")
    if self.is_cached(take_name):
        return
    take_var = GlobalVar(take_name)

    source_ctor = self.get_ctor("tensor_constructor")

    sliced_shape = [Any()] + list(self.shape[1:])
    result_type_var, result_ctor, _ = self._get_adt_by_shape(sliced_shape)

    tensor_arg = Var("tensor", self.tensor_type_var())
    lower = Var("lower", scalar_type("int32"))
    upper = Var("upper", scalar_type("int32"))
    payload = Var("t")

    # Unwrap the input, take rows [lower, upper) and re-wrap the result.
    row_ids = op.arange(lower, upper, dtype="int32")
    take_clause = Clause(
        PatternConstructor(source_ctor, [PatternVar(payload)]),
        result_ctor(op.take(payload, row_ids, axis=0)),
    )
    self.prelude.mod[take_var] = Function(
        [tensor_arg, lower, upper],
        Match(tensor_arg, [take_clause], False),
        result_type_var(),
        [],
    )
| |
def define_tensor_concatenate(self):
    """Defines a function to concatenate two tensor_t on axis 0.
    tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t

    The result type is the ADT whose first dimension is Any and whose
    remaining dimensions equal self.shape[1:].
    """
    # We don't register concatenate for scalar tensor.
    if len(self.shape) == 0:
        return

    concat_name = self.get_name("tensor_concatenate")
    if self.is_cached(concat_name):
        return
    concat_var = GlobalVar(concat_name)

    joined_shape = [Any()] + list(self.shape[1:])
    result_type_var, result_ctor, _ = self._get_adt_by_shape(joined_shape)

    source_ctor = self.get_ctor("tensor_constructor")
    source_type_var = self.tensor_type_var
    x = Var("x", source_type_var())
    y = Var("y", source_type_var())
    lhs = Var("t1")
    rhs = Var("t2")

    # Inner match unwraps y once x has been unwrapped by the outer match.
    unwrap_rhs = Clause(
        PatternConstructor(source_ctor, [PatternVar(rhs)]),
        result_ctor(op.concatenate([lhs, rhs], axis=0)),
    )
    unwrap_lhs = Clause(
        PatternConstructor(source_ctor, [PatternVar(lhs)]),
        Match(y, [unwrap_rhs], False),
    )

    self.prelude.mod[concat_var] = Function(
        [x, y], Match(x, [unwrap_lhs], False), result_type_var(), []
    )
| |
def define_tensor_expand_dims(self):
    """Defines a function to grow a tensor_t's rank by adding one dimension in front
    of the original tensor_t.
    tensor_expand_dims(t) : tensor_t -> tensor_t

    The global var is also stored on the prelude as an attribute named
    after the op.
    """
    op_name = self.get_name("tensor_expand_dims")
    op_var = self._create_global_var(op_name)
    setattr(self.prelude, op_name, op_var)

    source_type_var = self.tensor_type_var
    source_ctor = self.get_ctor("tensor_constructor")
    x = Var("x", source_type_var())

    # Note: we set the added axis to be Any() instead of 1 due to
    # in stack op, we need to recursively concatenate.
    if self.batch_dim is not None and self.batch_dim == 1:
        leading_dim = self.batch_dim
    else:
        leading_dim = Any()
    expanded_shape = [leading_dim] + list(self.shape)
    result_type_var, result_ctor, _ = self._get_adt_by_shape(expanded_shape)

    payload = Var("t")
    expand_clause = Clause(
        PatternConstructor(source_ctor, [PatternVar(payload)]),
        result_ctor(op.expand_dims(payload, 0, 1)),
    )

    self.prelude.mod[op_var] = Function(
        [x], Match(x, [expand_clause], False), result_type_var(), []
    )
| |
def define_tensor_array_read(self):
    """Defines a function to get the nth element of a list. Assume the list has at least one
    element.
    tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
    static_tensor_t

    The body is ``prelude.nth`` applied to the two parameters. Nothing is
    registered when the op already exists in the module.
    """
    read_name = self.get_name("tensor_array_read")
    if self.is_cached(read_name):
        return
    read_var = GlobalVar(read_name)

    element_type_var = self.tensor_type_var
    tensor_array = Var("tensor_array", self.list(element_type_var()))
    index = Var("x", scalar_type("int32"))
    body = self.prelude.nth(tensor_array, index)
    self.prelude.mod[read_var] = Function([tensor_array, index], body, element_type_var(), [])
| |
| def is_cached(self, name): |
| try: |
| self.prelude.mod.get_global_var(name) |
| return True |
| except ValueError: |
| return False |
| |
def define_tensor_array_write(self):
    """Defines a function to update a tensor array at index n with value v.
    tensor_array_write(ta, n, v) :
    list[static_tensor_t] -> Tensor[(), int32] -> static_tensor_t ->
    list[static_tensor_t]

    The body is ``prelude.update`` applied to the three parameters.
    Nothing is registered when the op already exists in the module.
    """
    write_name = self.get_name("tensor_array_write")
    if self.is_cached(write_name):
        return
    write_var = GlobalVar(write_name)

    element_type_var = self.tensor_type_var
    tensor_array = Var("tensor_array", self.list(element_type_var()))
    index = Var("x", scalar_type("int32"))
    value = Var("v", element_type_var())
    params = [tensor_array, index, value]
    body = self.prelude.update(tensor_array, index, value)
    self.prelude.mod[write_var] = Function(params, body, self.list(element_type_var()), [])
| |
def define_tensor_array_unstack(self):
    """Defines a function to unstack the values of a tensor in a tensor array.
    tensor_array_unstack(t) : Tensor[self.shape, self.dtype] -> list[tensor_t]

    The elements of the resulting list use the ADT for shape
    self.shape[1:]. A recursive helper walks index i from 0 up to the
    size of axis 0. Both global vars are also stored on the prelude as
    attributes named after the ops.
    """
    # We don't register unstack for scalar tensor array
    if len(self.shape) == 0:
        return

    helper_name = self.get_name("tensor_array_unstack_helper")
    helper_var = self._create_global_var(helper_name)
    setattr(self.prelude, helper_name, helper_var)

    helper_tensor = Var("t", TensorType(self.shape, self.dtype))
    up = Var("up", scalar_type("int32"))
    i = Var("i", scalar_type("int32"))
    entry_tensor = Var("tensor", TensorType(self.shape, self.dtype))

    row_type_var, row_ctor, _ = self._get_adt_by_shape(self.shape[1:])

    # helper(i, up, t): stop when i reaches up, otherwise emit row i and recurse.
    current_row = row_ctor(op.take(helper_tensor, i, axis=0))
    remaining_rows = helper_var(add(i, const(1)), up, helper_tensor)
    helper_body = If(equal(i, up), self.nil(), self.cons(current_row, remaining_rows))
    self.prelude.mod[helper_var] = Function(
        [i, up, helper_tensor], helper_body, self.list(row_type_var()), []
    )

    unstack_name = self.get_name("tensor_array_unstack")
    unstack_var = self._create_global_var(unstack_name)
    setattr(self.prelude, unstack_name, unstack_var)

    # The number of rows is read from the runtime shape of the input.
    row_count = op.take(op.shape_of(entry_tensor), const(0))
    self.prelude.mod[unstack_var] = Function(
        [entry_tensor],
        helper_var(const(0), row_count, entry_tensor),
        self.list(row_type_var()),
        [],
    )
| |
def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
    """Defines a function to scatter a list of tensor_t into indices of a tensor array.
    tensor_array_scatter(ta, indices, values) :
    list[tensor_t] -> Tensor[(Any), int32] -> list[tensor_t] -> list[tensor_t]

    Set static indices shape by specifying indices_shape.
    Set force_update to get static indices shape operator.
    """

    def array_type():
        # Build a fresh list[tensor_t] type for every use.
        return self.list(self.tensor_type_var())

    def indices_type():
        # Static indices shape when given, otherwise a single dynamic dim.
        return TensorType(indices_shape or [Any()], "int32")

    # When this operator has already been registered, only update
    # when force_update is set. This should be used only when we need to
    # redefine this op for static indices shape.
    extra_shapes = None if indices_shape is None else {"indices": indices_shape}
    scatter_name = self.get_name("tensor_array_scatter", extra_shapes)
    if not force_update and hasattr(self.prelude, scatter_name):
        return

    helper_name = self.get_name("tensor_array_scatter_helper", extra_shapes)
    helper_var = self._create_global_var(helper_name)

    ta = Var("ta", array_type())
    current = Var("current", scalar_type("int32"))
    limit = Var("limit", scalar_type("int32"))
    indices_ = Var("indices_", indices_type())
    values_ = Var("values_", array_type())
    write_var = self.get_global_var("tensor_array_write")
    read_var = self.get_global_var("tensor_array_read")

    # helper: write values_[current] at position indices_[current], then
    # recurse with current + 1 until current reaches limit.
    updated_ta = write_var(ta, op.take(indices_, current), read_var(values_, current))
    helper_body = If(
        equal(current, limit),
        ta,
        helper_var(updated_ta, add(current, const(1)), limit, indices_, values_),
    )
    self.prelude.mod[helper_var] = Function(
        [ta, current, limit, indices_, values_], helper_body, array_type(), []
    )

    scatter_var = self._create_global_var(scatter_name)
    setattr(self.prelude, scatter_name, scatter_var)

    tensor_array = Var("tensor_array", array_type())
    indices = Var("indices", indices_type())
    values = Var("values", array_type())

    # The iteration count is a constant for static indices, otherwise it
    # is read from the runtime shape of the indices tensor.
    if indices_shape is None:
        num_indices = op.take(op.shape_of(indices), const(0))
    else:
        num_indices = const(indices_shape[0])

    body = helper_var(tensor_array, const(0), num_indices, indices, values)
    self.prelude.mod[scatter_var] = Function(
        [tensor_array, indices, values], body, array_type(), []
    )
| |
def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_update=False):
    """Defines a function to split the values of a tensor_t into a tensor array.
    tensor_array_split(ta, value, lengths) :
    list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]

    Set static value and lengths shapes by specifying value_shape and lengths_shape.
    Set force_update to get static value and lengths shape operator.

    The tensor array elements use the ADT whose first dimension is Any and
    whose remaining dimensions equal self.shape[1:]. Note that the op name
    does not encode value_shape or lengths_shape, so a forced update
    redefines the functions registered under the same names.
    """
    # Skip scalar case
    ndim = len(self.shape)
    if ndim == 0:
        return

    # When this operator has already been registered, only update
    # when force_update is set. This should be used only when we need to
    # redefine this op for static value/indices shape.
    split_name = self.get_name("tensor_array_split")

    if self.is_cached(split_name):
        if not force_update:
            return
        # Reuse the existing global vars so the definitions are replaced.
        tensor_array_split_helper_var = self.get_global_var("ta_split_helper")
        split_var = self.get_global_var("tensor_array_split")
    else:
        tensor_array_split_helper_name = self.get_name("ta_split_helper")
        tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name)
        split_var = GlobalVar(split_name)

    # Each split piece has a dynamic first dimension.
    output_shape = [
        Any(),
    ] + list(self.shape[1:])
    output_tensor_type_var, _, output_ops = self._get_adt_by_shape(output_shape)
    output_ops.define_tensor_array_write()
    write_var = output_ops.get_global_var("tensor_array_write")

    # Pick the tensor_take matching the type of the value being split.
    if value_shape is None:
        value_type_var = self.tensor_type_var
        take_var = self.get_global_var("tensor_take")
    else:
        value_type_var, _, value_adts = self._get_adt_by_shape(value_shape)
        value_adts.define_tensor_take()
        take_var = value_adts.get_global_var("tensor_take")

    ta1 = Var("tensor_array", self.list(output_tensor_type_var()))
    value1 = Var("value1", value_type_var())
    offset1 = Var("offset1", scalar_type("int32"))
    current1 = Var("current1", scalar_type("int32"))
    limit1 = Var("limit1", scalar_type("int32"))
    lengths1 = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))

    # helper(ta, value, offset, current, limit, lengths): the recursive call
    # handles the remaining pieces first, then the slice
    # value[offset : offset + lengths[current]] is written at index current.
    helper1_body = If(
        equal(current1, limit1),
        ta1,
        write_var(
            tensor_array_split_helper_var(
                ta1,
                value1,
                add(offset1, op.take(lengths1, current1)),
                add(current1, const(1)),
                limit1,
                lengths1,
            ),
            current1,
            take_var(value1, offset1, add(op.take(lengths1, current1), offset1)),
        ),
    )

    self.prelude.mod[tensor_array_split_helper_var] = Function(
        [ta1, value1, offset1, current1, limit1, lengths1],
        helper1_body,
        self.list(output_tensor_type_var()),
        [],
    )
    tensor_array = Var("tensor_array", self.list(output_tensor_type_var()))

    value = Var("value", value_type_var())
    lengths = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))
    # The number of pieces is a constant for static lengths, otherwise it
    # is read from the runtime shape of the lengths tensor.
    if lengths_shape is None:
        lengths_shape = op.shape_of(lengths)
        lengths_limit = op.take(lengths_shape, const(0))
    else:
        lengths_limit = const(lengths_shape[0])
    body = tensor_array_split_helper_var(
        tensor_array, value, const(0), const(0), lengths_limit, lengths
    )

    self.prelude.mod[split_var] = Function(
        [tensor_array, value, lengths], body, self.list(output_tensor_type_var()), []
    )
| |
def define_tensor_array_concat(self):
    """Defines a function to return the values in the tensor array as concatenated tensor_t.
    tensor_array_concat(ta) : list[tensor_t] -> tensor_t

    Both the list elements and the result use the ADT whose first
    dimension is Any and whose remaining dimensions equal self.shape[1:].
    An empty list yields tensor_nil of that ADT.
    """
    # We don't register concat for scalar tensor array.
    if len(self.shape) == 0:
        return

    concat_name = self.get_name("tensor_array_concat")
    if self.is_cached(concat_name):
        return
    concat_var = GlobalVar(concat_name)

    merged_shape = [Any()] + list(self.shape[1:])
    merged_type_var, _, merged_ops = self._get_adt_by_shape(merged_shape)

    # Register tensor concatenate and get tensor_nil var for output shape
    merged_ops.define_tensor_concatenate()
    pairwise_concat_var = merged_ops.get_global_var("tensor_concatenate")
    empty_tensor_ctor = merged_ops.get_ctor("tensor_nil")

    tensor_array = Var("tensor_array", self.list(merged_type_var()))
    hd = Var("hd")
    tl = Var("tl")

    on_empty = Clause(PatternConstructor(self.nil), empty_tensor_ctor())

    # A single remaining element is returned as is; otherwise the head is
    # concatenated with the recursively concatenated tail.
    on_last = Clause(PatternConstructor(self.nil), hd)
    on_more = Clause(PatternWildcard(), pairwise_concat_var(hd, concat_var(tl)))
    on_cons = Clause(
        PatternConstructor(self.cons, [PatternVar(hd), PatternVar(tl)]),
        Match(tl, [on_last, on_more], False),
    )

    self.prelude.mod[concat_var] = Function(
        [tensor_array], Match(tensor_array, [on_empty, on_cons], False), merged_type_var(), []
    )
| |
def define_tensor_array_stack(self):
    """Defines a function to get the values in the tensor array as a stack tensor_t.
    tensor_array_stack(l) : list[tensor_t] -> tensor_t

    The result uses the ADT for shape [new_axis] + self.shape, where
    new_axis is batch_dim when batch_dim is 1 and Any() otherwise. The
    global var is also stored on the prelude as an attribute named after
    the op.
    """
    stack_name = self.get_name("tensor_array_stack")
    stack_var = self._create_global_var(stack_name)
    setattr(self.prelude, stack_name, stack_var)
    tensor_array = Var("tensor_array", self.list(self.tensor_type_var()))
    expand_dims_var = self.get_global_var("tensor_expand_dims")

    # Same single-item test as in define_tensor_expand_dims and
    # define_tensor_array_gather.
    single_item = self.batch_dim is not None and self.batch_dim == 1

    # Register tensor_concatenate for output_shape. The ADT for the output
    # shape is resolved once and reused for the return type below.
    # NOTE(review): this assumes a second _get_adt_by_shape call for the same
    # shape yields the same registered type var - confirm against
    # prelude.get_type_static.
    new_axis = self.batch_dim if single_item else Any()
    output_shape = [new_axis] + list(self.shape)
    output_tensor_type_var, _, output_ops = self._get_adt_by_shape(output_shape)
    output_ops.define_tensor_concatenate()
    concat_var = output_ops.get_global_var("tensor_concatenate")

    tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
    if single_item:
        # only one element
        tensors = self.prelude.id(
            self.prelude.hd(tensor_array_expand_dims),
        )
    else:
        # Fold concatenate over the expanded elements, seeded with the head.
        tensors = self.prelude.foldl(
            concat_var,
            self.prelude.hd(tensor_array_expand_dims),
            self.prelude.tl(tensor_array_expand_dims),
        )

    self.prelude.mod[stack_var] = Function(
        [tensor_array], tensors, output_tensor_type_var(), []
    )
| |
def define_tensor_array_gather(self):
    """Defines a function to return the selected values in a tensor array as tensor_t.
    tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t

    A recursive helper walks the indices from the last to the first,
    prepending each selected element to an accumulator list, and stacks
    the accumulator once the counter reaches 0.
    """

    def array_type():
        # Build a fresh list[tensor_t] type for every use.
        return self.list(self.tensor_type_var())

    helper_name = self.get_name("tensor_array_gather_helper")
    helper_var = self._create_global_var(helper_name)

    if self.batch_dim is not None and self.batch_dim == 1:
        leading_dim = self.batch_dim
    else:
        leading_dim = Any()
    stacked_shape = [leading_dim] + list(self.shape)
    stacked_type_var, _, _ = self._get_adt_by_shape(stacked_shape)

    stack_var = self.get_global_var("tensor_array_stack")
    read_var = self.get_global_var("tensor_array_read")

    ta = Var("ta", array_type())
    accu = Var("accu", array_type())
    current = Var("current", scalar_type("int32"))
    limit = Var("limit", scalar_type("int32"))
    indices_ = Var("indices_", TensorType([Any()], "int32"))

    # Element selected by indices_[current - 1], prepended to the accumulator.
    picked = read_var(ta, op.take(indices_, subtract(current, const(1))))
    helper_body = If(
        equal(current, const(0)),
        stack_var(accu),
        helper_var(ta, self.cons(picked, accu), subtract(current, const(1)), limit, indices_),
    )
    self.prelude.mod[helper_var] = Function(
        [ta, accu, current, limit, indices_], helper_body, stacked_type_var(), []
    )

    gather_name = self.get_name("tensor_array_gather")
    gather_var = self._create_global_var(gather_name)

    tensor_array = Var("tensor_array", array_type())
    indices = Var("indices", TensorType([Any()], "int32"))
    num_indices = op.take(op.shape_of(indices), const(0))
    body = helper_var(tensor_array, self.nil(), num_indices, num_indices, indices)
    self.prelude.mod[gather_var] = Function(
        [tensor_array, indices], body, stacked_type_var(), []
    )
| |
def define_tensor_get_data(self):
    """Defines a function to get a Tensor from tensor_t with given shape.
    tensor_get_data(t) : tensor_t -> Tensor[self.shape, self.dtype]

    Only the tensor_constructor case is matched.
    """
    get_data_var = self._create_global_var(self.get_name("tensor_get_data"))

    data_ctor = self.get_ctor("tensor_constructor")
    wrapped = Var("tensor", self.tensor_type_var())
    payload = Var("t")
    unwrap_clause = Clause(PatternConstructor(data_ctor, [PatternVar(payload)]), payload)
    raw_type = TensorType(self.shape, self.dtype)
    self.prelude.mod[get_data_var] = Function(
        [wrapped], Match(wrapped, [unwrap_clause], False), raw_type, []
    )
| |
| def register(self): |
| """Register all tensor array ops in Prelude""" |
| self.define_tensor_adt() |
| self.define_tensor_take() |
| self.define_tensor_concatenate() |
| self.define_tensor_expand_dims() |
| self.define_tensor_array() |
| self.define_tensor_array_read() |
| self.define_tensor_array_write() |
| self.define_tensor_array_unstack() |
| self.define_tensor_array_scatter() |
| self.define_tensor_array_split() |
| self.define_tensor_array_concat() |
| self.define_tensor_array_stack() |
| self.define_tensor_array_gather() |
| self.define_tensor_get_data() |
| |
def _get_adt_by_shape(self, shape):
    """Get ADT type and constructor with given shape.

    Creates a StaticTensorArrayOps for ``shape`` (same prelude, dtype and
    batch_dim as this instance), makes sure its ADT is defined, and
    returns the tuple (type var, tensor constructor, ops instance).
    """
    shaped_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
    shaped_ops.define_tensor_adt()
    type_var = shaped_ops.get_type("tensor_t")
    data_ctor = shaped_ops.get_ctor("tensor_constructor")
    return type_var, data_ctor, shaped_ops
| |
| def _create_global_var(self, name): |
| """Create a GlobalVar if doesn't exist in prelude.""" |
| global_var_name_set = set() |
| for g_var_name in self.prelude.mod.get_global_vars(): |
| global_var_name_set.add(g_var_name.name_hint) |
| if name not in global_var_name_set: |
| gvar = GlobalVar(name) |
| else: |
| gvar = self.prelude.mod.get_global_var(name) |
| |
| return gvar |
| |
| |
| class TensorArrayOps(object): |
| """Contains tensor array related ops""" |
| |
| def __init__(self, prelude, dtype): |
| """Create tensor array ops registry""" |
| self.prelude = prelude |
| self.dtype = dtype |
| self.list, self.cons, self.nil = self.prelude.mod.get_type("List") |
| |
| def get_name(self, canonical): |
| """Get name corresponding to the canonical name""" |
| return self.prelude.get_name(canonical, self.dtype) |
| |
| def get_global_var(self, canonical): |
| """Get global corresponding to the canonical name""" |
| return self.prelude.get_global_var(canonical, self.dtype) |
| |
| def get_type(self, canonical): |
| """Get type corresponding to the canonical name""" |
| return self.prelude.get_type(canonical, self.dtype) |
| |
| def get_ctor(self, canonical): |
| """Get ctor corresponding to the canonical name""" |
| return self.prelude.get_ctor(self.tensor_type_var.name_hint, canonical, self.dtype) |
| |
def define_tensor_adt(self):
    """Defines the dynamic tensor ADT, which is the container for tensors
    with variable shapes.

    The ADT has a tensor_nil constructor plus one constructor per rank
    from 0 to 6 (tensor0 ... tensor6), each wrapping a tensor whose
    dimensions are all Any. Sets ``self.tensor_type_var``.
    """
    adt_var = GlobalTypeVar(self.get_name("tensor_t"))
    self.tensor_type_var = adt_var

    ranks = range(7)
    # A separate Any() instance is created for every dimension.
    payload_types = [
        TensorType([Any() for _ in range(rank)], self.dtype) for rank in ranks
    ]
    nil_name = self.get_name("tensor_nil")
    ctor_names = [self.get_name("tensor{}".format(rank)) for rank in ranks]

    # Constructor order is kept: nil first, then rank 0 through rank 6.
    constructors = [Constructor(nil_name, [], adt_var)]
    for ctor_name, payload_type in zip(ctor_names, payload_types):
        constructors.append(Constructor(ctor_name, [payload_type], adt_var))

    self.prelude.mod[adt_var] = TypeData(adt_var, [], constructors)
| |
def define_tensor_take(self):
    """Defines a function to return a range of tensor_t on axis 0.
    tensor_take(t, lower, upper) :
    tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t

    Clauses are generated for ranks 1 through 6; tensor_nil and tensor0
    are not matched.
    """
    take_var = GlobalVar(self.get_name("tensor_take"))
    tensor_t = self.tensor_type_var

    t = Var("tensor", tensor_t())
    lower = Var("lower", scalar_type("int32"))
    upper = Var("upper", scalar_type("int32"))

    clauses = []
    for rank in range(1, 7):
        ctor = self.get_ctor("tensor{}".format(rank))
        payload = Var("t{}".format(rank))
        row_ids = op.arange(lower, upper, dtype="int32")
        if rank == 1:
            # The rank-1 case calls take without an explicit axis.
            taken = op.take(payload, row_ids)
        else:
            taken = op.take(payload, row_ids, axis=0)
        clauses.append(Clause(PatternConstructor(ctor, [PatternVar(payload)]), ctor(taken)))

    self.prelude.mod[take_var] = Function(
        [t, lower, upper],
        Match(t, clauses, False),
        tensor_t(),
        [],
    )
| |
def define_tensor_expand_dims(self):
    """Defines a function to grow a tensor_t's rank by adding one dimension in front
    of the original tensor_t.
    tensor_expand_dims(t) : tensor_t -> tensor_t

    A tensor wrapped by the rank-k constructor (k from 0 to 5) is
    re-wrapped by the rank-(k + 1) constructor; tensor_nil and tensor6
    are not matched.
    """
    expand_dims_var = GlobalVar(self.get_name("tensor_expand_dims"))
    tensor_type_var = self.tensor_type_var

    x = Var("x", tensor_type_var())
    payloads = [Var("t{}".format(rank)) for rank in range(6)]
    ctors = [self.get_ctor("tensor{}".format(rank)) for rank in range(7)]

    clauses = []
    for rank, payload in enumerate(payloads):
        unwrap = PatternConstructor(ctors[rank], [PatternVar(payload)])
        rewrap = ctors[rank + 1](op.expand_dims(payload, 0, 1))
        clauses.append(Clause(unwrap, rewrap))

    self.prelude.mod[expand_dims_var] = Function(
        [x],
        Match(x, clauses, False),
        tensor_type_var(),
    )
| |
def define_tensor_concat(self):
    """Register tensor_concatenate, joining two tensor_t values on axis 0.

    tensor_concatenate(x, y) : tensor_t -> tensor_t -> tensor_t

    The outer match dispatches on the constructor of x (tensor1 .. tensor4);
    inside each clause y is matched against that same constructor only, and
    the two payloads are combined with op.concatenate(..., axis=0) and
    re-wrapped in the shared constructor.
    """
    fn_var = GlobalVar(self.get_name("tensor_concatenate"))
    tensor_t = self.tensor_type_var
    lhs = Var("x", tensor_t())
    rhs = Var("y", tensor_t())

    def same_rank_clause(rank):
        """Build the clause for the case where both operands are tensor<rank>."""
        ctor = self.get_ctor(f"tensor{rank}")
        first = Var(f"t{rank}1")
        second = Var(f"t{rank}2")
        inner = Clause(
            PatternConstructor(ctor, [PatternVar(second)]),
            ctor(op.concatenate([first, second], axis=0)),
        )
        return Clause(
            PatternConstructor(ctor, [PatternVar(first)]),
            Match(rhs, [inner], False),
        )

    # op.concatenate does not support tensor with rank higher than 4
    clauses = [same_rank_clause(rank) for rank in range(1, 5)]
    self.prelude.mod[fn_var] = Function([lhs, rhs], Match(lhs, clauses, False), tensor_t())
| |
def define_tensor_array(self):
    """Register tensor_array, which builds a tensor array with n entries.

    tensor_array(n) : Tensor[(), int32] -> list[tensor_t]

    The generated function recurses on n: when n equals 0 it yields nil,
    otherwise it conses a tensor_nil value onto tensor_array(n - 1). The
    global var is also attached to the prelude under the function's name.
    """
    name = self.get_name("tensor_array")
    ctor_fn = GlobalVar(name)
    setattr(self.prelude, name, ctor_fn)

    empty_tensor = self.get_ctor("tensor_nil")
    # NOTE(review): the tensor type is fetched through get_ctor here, while
    # sibling methods read self.tensor_type_var -- confirm both yield the same var.
    tensor_t = self.get_ctor("tensor_t")
    size = Var("x", scalar_type("int32"))

    rest = ctor_fn(subtract(size, const(1)))
    body = If(equal(size, const(0)), self.nil(), self.cons(empty_tensor(), rest))
    self.prelude.mod[ctor_fn] = Function([size], body, self.list(tensor_t()), [])
| |
def define_tensor_array_read(self):
    """Register tensor_array_read, which fetches the n-th entry of a tensor array.

    tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t

    The body simply applies the prelude's nth to the array and the index.
    The list is assumed to contain at least one element. The global var is
    also attached to the prelude under the function's name.
    """
    name = self.get_name("tensor_array_read")
    read_fn = GlobalVar(name)
    setattr(self.prelude, name, read_fn)

    tensor_t = self.tensor_type_var
    array = Var("tensor_array", self.list(tensor_t()))
    index = Var("x", scalar_type("int32"))

    body = self.prelude.nth(array, index)
    self.prelude.mod[read_fn] = Function([array, index], body, tensor_t(), [])
| |
def define_tensor_array_write(self):
    """Register tensor_array_write, which replaces the entry at index n with v.

    tensor_array_write(ta, n, v) :
        list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t]

    The body applies the prelude's update to the array, index and value.
    """
    write_fn = GlobalVar(self.get_name("tensor_array_write"))

    tensor_t = self.tensor_type_var
    array = Var("tensor_array", self.list(tensor_t()))
    index = Var("x", scalar_type("int32"))
    value = Var("v", tensor_t())

    body = self.prelude.update(array, index, value)
    self.prelude.mod[write_fn] = Function(
        [array, index, value], body, self.list(tensor_t()), []
    )
| |
def define_tensor_array_unstack_tensor1(self):
    """Register tensor_array_unstack_tensor1 and its recursive helper.

    tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with one Any dimension
    and self.dtype, reads the size of axis 0 from its shape, and calls the
    helper with a start index of 0. The helper stops with nil once i equals
    up; otherwise it wraps op.take(t, i) in the tensor0 constructor and
    conses it onto the result for i + 1.
    """
    tensor_t = self.tensor_type_var
    scalar_ctor = self.get_ctor("tensor0")

    # Recursive helper: helper(i, up, t).
    helper_fn = GlobalVar(self.get_name("tensor_array_unstack_tensor1_helper"))
    src = Var("t", TensorType([Any()], self.dtype))
    stop = Var("up", scalar_type("int32"))
    idx = Var("i", scalar_type("int32"))
    element = scalar_ctor(op.take(src, idx))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(element, remainder)),
        self.list(tensor_t()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_fn = GlobalVar(self.get_name("tensor_array_unstack_tensor1"))
    data = Var("tensor", TensorType([Any()], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(tensor_t()), []
    )
| |
def define_tensor_array_unstack_tensor2(self):
    """Register tensor_array_unstack_tensor2 and its recursive helper.

    tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with two Any dimensions
    and self.dtype, reads the size of axis 0 from its shape, and calls the
    helper starting at index 0. The helper yields nil once i equals up;
    otherwise it wraps op.take(t, i, axis=0) in the tensor1 constructor and
    conses it onto the result for i + 1. Both global vars are also attached
    to the prelude under their names.
    """
    int32 = scalar_type("int32")

    # Recursive helper: helper(i, up, t).
    helper_name = self.get_name("tensor_array_unstack_tensor2_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    src = Var("t", TensorType([Any() for _ in range(2)], self.dtype))
    stop = Var("up", int32)
    idx = Var("i", int32)

    row = self.get_ctor("tensor1")(op.take(src, idx, axis=0))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(row, remainder)),
        self.list(self.tensor_type_var()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_name = self.get_name("tensor_array_unstack_tensor2")
    unstack_fn = GlobalVar(unstack_name)
    setattr(self.prelude, unstack_name, unstack_fn)
    data = Var("tensor", TensorType([Any() for _ in range(2)], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(self.tensor_type_var()), []
    )
| |
def define_tensor_array_unstack_tensor3(self):
    """Register tensor_array_unstack_tensor3 and its recursive helper.

    tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with three Any
    dimensions and self.dtype, reads the size of axis 0 from its shape, and
    calls the helper starting at index 0. The helper yields nil once i
    equals up; otherwise it wraps op.take(t, i, axis=0) in the tensor2
    constructor and conses it onto the result for i + 1. Both global vars
    are also attached to the prelude under their names.
    """
    int32 = scalar_type("int32")

    # Recursive helper: helper(i, up, t).
    helper_name = self.get_name("tensor_array_unstack_tensor3_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    src = Var("t", TensorType([Any() for _ in range(3)], self.dtype))
    stop = Var("up", int32)
    idx = Var("i", int32)

    piece = self.get_ctor("tensor2")(op.take(src, idx, axis=0))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(piece, remainder)),
        self.list(self.tensor_type_var()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_name = self.get_name("tensor_array_unstack_tensor3")
    unstack_fn = GlobalVar(unstack_name)
    setattr(self.prelude, unstack_name, unstack_fn)
    data = Var("tensor", TensorType([Any() for _ in range(3)], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(self.tensor_type_var()), []
    )
| |
def define_tensor_array_unstack_tensor4(self):
    """Register tensor_array_unstack_tensor4 and its recursive helper.

    tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with four Any
    dimensions and self.dtype, reads the size of axis 0 from its shape, and
    calls the helper starting at index 0. The helper yields nil once i
    equals up; otherwise it wraps op.take(t, i, axis=0) in the tensor3
    constructor and conses it onto the result for i + 1. Both global vars
    are also attached to the prelude under their names.
    """
    int32 = scalar_type("int32")

    # Recursive helper: helper(i, up, t).
    helper_name = self.get_name("tensor_array_unstack_tensor4_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    src = Var("t", TensorType([Any() for _ in range(4)], self.dtype))
    stop = Var("up", int32)
    idx = Var("i", int32)

    piece = self.get_ctor("tensor3")(op.take(src, idx, axis=0))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(piece, remainder)),
        self.list(self.tensor_type_var()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_name = self.get_name("tensor_array_unstack_tensor4")
    unstack_fn = GlobalVar(unstack_name)
    setattr(self.prelude, unstack_name, unstack_fn)
    data = Var("tensor", TensorType([Any() for _ in range(4)], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(self.tensor_type_var()), []
    )
| |
def define_tensor_array_unstack_tensor5(self):
    """Register tensor_array_unstack_tensor5 and its recursive helper.

    tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with five Any
    dimensions and self.dtype, reads the size of axis 0 from its shape, and
    calls the helper starting at index 0. The helper yields nil once i
    equals up; otherwise it wraps op.take(t, i, axis=0) in the tensor4
    constructor and conses it onto the result for i + 1. Both global vars
    are also attached to the prelude under their names.
    """
    int32 = scalar_type("int32")

    # Recursive helper: helper(i, up, t).
    helper_name = self.get_name("tensor_array_unstack_tensor5_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    src = Var("t", TensorType([Any() for _ in range(5)], self.dtype))
    stop = Var("up", int32)
    idx = Var("i", int32)

    piece = self.get_ctor("tensor4")(op.take(src, idx, axis=0))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(piece, remainder)),
        self.list(self.tensor_type_var()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_name = self.get_name("tensor_array_unstack_tensor5")
    unstack_fn = GlobalVar(unstack_name)
    setattr(self.prelude, unstack_name, unstack_fn)
    data = Var("tensor", TensorType([Any() for _ in range(5)], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(self.tensor_type_var()), []
    )
| |
def define_tensor_array_unstack_tensor6(self):
    """Register tensor_array_unstack_tensor6 and its recursive helper.

    tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]

    The registered function takes a tensor declared with six Any dimensions
    and self.dtype, reads the size of axis 0 from its shape, and calls the
    helper starting at index 0. The helper yields nil once i equals up;
    otherwise it wraps op.take(t, i, axis=0) in the tensor5 constructor and
    conses it onto the result for i + 1. Both global vars are also attached
    to the prelude under their names.
    """
    int32 = scalar_type("int32")

    # Recursive helper: helper(i, up, t).
    helper_name = self.get_name("tensor_array_unstack_tensor6_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    src = Var("t", TensorType([Any() for _ in range(6)], self.dtype))
    stop = Var("up", int32)
    idx = Var("i", int32)

    piece = self.get_ctor("tensor5")(op.take(src, idx, axis=0))
    remainder = helper_fn(add(idx, const(1)), stop, src)
    self.prelude.mod[helper_fn] = Function(
        [idx, stop, src],
        If(equal(idx, stop), self.nil(), self.cons(piece, remainder)),
        self.list(self.tensor_type_var()),
        [],
    )

    # Entry point: derive the element count from the shape, then delegate.
    unstack_name = self.get_name("tensor_array_unstack_tensor6")
    unstack_fn = GlobalVar(unstack_name)
    setattr(self.prelude, unstack_name, unstack_fn)
    data = Var("tensor", TensorType([Any() for _ in range(6)], self.dtype))
    length = op.take(op.shape_of(data), const(0))
    self.prelude.mod[unstack_fn] = Function(
        [data], helper_fn(const(0), length, data), self.list(self.tensor_type_var()), []
    )
| |
def define_tensor_array_scatter(self):
    """Register tensor_array_scatter and its recursive helper.

    tensor_array_scatter(ta, indices, values) :
        list[tensor_t] -> Tensor[(Any), int32] -> list[tensor_t] -> list[tensor_t]

    The helper walks `current` upward until it equals `limit`. At each step
    it writes the `current`-th entry of values into the array at position
    indices[current] (via tensor_array_write / tensor_array_read) and
    recurses with current + 1. The entry point starts the walk at 0 with the
    limit taken from the first entry of the shape of indices. Only the entry
    point's global var is attached to the prelude.
    """
    tensor_t = self.tensor_type_var
    int32 = scalar_type("int32")

    # Recursive helper: helper(ta, current, limit, indices_, values_).
    helper_fn = GlobalVar(self.get_name("tensor_array_scatter_helper"))
    acc = Var("ta", self.list(tensor_t()))
    pos = Var("current", int32)
    stop = Var("limit", int32)
    idx_tensor = Var("indices_", TensorType([Any()], "int32"))
    src_values = Var("values_", self.list(tensor_t()))
    write_fn = self.get_global_var("tensor_array_write")
    read_fn = self.get_global_var("tensor_array_read")

    updated = write_fn(acc, op.take(idx_tensor, pos), read_fn(src_values, pos))
    recurse = helper_fn(updated, add(pos, const(1)), stop, idx_tensor, src_values)
    self.prelude.mod[helper_fn] = Function(
        [acc, pos, stop, idx_tensor, src_values],
        If(equal(pos, stop), acc, recurse),
        self.list(tensor_t()),
        [],
    )

    # Entry point: compute the number of indices and start the walk at 0.
    scatter_name = self.get_name("tensor_array_scatter")
    scatter_fn = GlobalVar(scatter_name)
    setattr(self.prelude, scatter_name, scatter_fn)
    array = Var("tensor_array", self.list(tensor_t()))
    indices = Var("indices", TensorType([Any()], "int32"))
    values = Var("values", self.list(tensor_t()))
    count = op.take(op.shape_of(indices), const(0))
    self.prelude.mod[scatter_fn] = Function(
        [array, indices, values],
        helper_fn(array, const(0), count, indices, values),
        self.list(tensor_t()),
        [],
    )
| |
def define_tensor_array_split(self):
    """Register tensor_array_split and its recursive helper.

    tensor_array_split(ta, value, lengths) :
        list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]

    The helper returns the array unchanged once `current` equals `limit`.
    Otherwise it first recurses with the offset advanced by lengths[current]
    and current + 1, then writes into that result, at position `current`,
    the slice tensor_take(value, offset, lengths[current] + offset). The
    entry point starts with offset 0 and current 0, and takes the limit from
    the first entry of the shape of lengths. Both global vars are attached
    to the prelude under their names.
    """
    tensor_t = self.tensor_type_var
    int32 = scalar_type("int32")

    # Recursive helper: helper(ta, value, offset, current, limit, lengths).
    helper_name = self.get_name("ta_split_helper")
    helper_fn = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_fn)
    acc = Var("tensor_array", self.list(tensor_t()))
    whole = Var("value1", tensor_t())
    start = Var("offset1", int32)
    pos = Var("current1", int32)
    stop = Var("limit1", int32)
    sizes = Var("lengths", TensorType([Any()], "int32"))
    write_fn = self.get_global_var("tensor_array_write")
    take_fn = self.get_global_var("tensor_take")

    later_chunks = helper_fn(
        acc,
        whole,
        add(start, op.take(sizes, pos)),
        add(pos, const(1)),
        stop,
        sizes,
    )
    this_chunk = take_fn(whole, start, add(op.take(sizes, pos), start))
    self.prelude.mod[helper_fn] = Function(
        [acc, whole, start, pos, stop, sizes],
        If(equal(pos, stop), acc, write_fn(later_chunks, pos, this_chunk)),
        self.list(tensor_t()),
        [],
    )

    # Entry point: count the requested pieces and start at offset 0, index 0.
    split_name = self.get_name("tensor_array_split")
    split_fn = GlobalVar(split_name)
    setattr(self.prelude, split_name, split_fn)
    array = Var("tensor_array", self.list(tensor_t()))
    value = Var("value", tensor_t())
    lengths = Var("lengths", TensorType([Any()], "int32"))
    pieces = op.take(op.shape_of(lengths), const(0))
    self.prelude.mod[split_fn] = Function(
        [array, value, lengths],
        helper_fn(array, value, const(0), const(0), pieces, lengths),
        self.list(tensor_t()),
        [],
    )
| |
def define_tensor_array_concat(self):
    """Register tensor_array_concat, folding a tensor array into one tensor_t.

    tensor_array_concat(ta) : list[tensor_t] -> tensor_t

    An empty list yields tensor_nil. For a non-empty list, the tail decides:
    if it is nil the head is returned as is, otherwise the head is joined
    (via tensor_concatenate) with the recursive result for the tail. The
    global var is also attached to the prelude under the function's name.
    """
    name = self.get_name("tensor_array_concat")
    concat_fn = GlobalVar(name)
    setattr(self.prelude, name, concat_fn)

    pairwise_concat = self.get_global_var("tensor_concatenate")
    tensor_t = self.tensor_type_var
    empty_tensor = self.get_ctor("tensor_nil")
    array = Var("tensor_array", self.list(tensor_t()))
    head = Var("hd")
    tail = Var("tl")

    # Inner match on the tail: single-element list vs. longer list.
    on_tail = Match(
        tail,
        [
            Clause(PatternConstructor(self.nil), head),
            Clause(PatternWildcard(), pairwise_concat(head, concat_fn(tail))),
        ],
        False,
    )
    clauses = [
        Clause(PatternConstructor(self.nil), empty_tensor()),
        Clause(PatternConstructor(self.cons, [PatternVar(head), PatternVar(tail)]), on_tail),
    ]
    self.prelude.mod[concat_fn] = Function([array], Match(array, clauses, False), tensor_t(), [])
| |
def define_tensor_array_gather(self):
    """Register tensor_array_gather and its recursive helper.

    tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t

    The helper counts `current` down to 0. At each step it reads the array
    entry at indices[current - 1] and conses it onto the accumulator, so the
    accumulator ends up in index order; when `current` reaches 0 the
    accumulator is handed to tensor_array_stack. The `limit` argument is
    threaded through the recursion unchanged. The entry point starts with an
    empty accumulator and `current` equal to the first entry of the shape of
    indices. Both global vars are attached to the prelude under their names.

    tensor_array_stack and tensor_array_read are looked up in the module, so
    they must already be registered when this method runs.
    """
    helper_name = self.get_name("tensor_array_gather_helper")
    helper_var = GlobalVar(helper_name)
    setattr(self.prelude, helper_name, helper_var)
    tensor_type_var = self.tensor_type_var
    # Resolve dependencies through get_global_var, as every sibling
    # definition in this class does.
    stack_var = self.get_global_var("tensor_array_stack")
    read_var = self.get_global_var("tensor_array_read")
    ta = Var("ta", self.list(tensor_type_var()))
    accu = Var("accu", self.list(tensor_type_var()))
    current = Var("current", scalar_type("int32"))
    limit = Var("limit", scalar_type("int32"))
    indices_ = Var("indices_", TensorType([Any()], "int32"))
    helper_body = If(
        equal(current, const(0)),
        stack_var(accu),
        helper_var(
            ta,
            self.cons(read_var(ta, op.take(indices_, subtract(current, const(1)))), accu),
            subtract(current, const(1)),
            limit,
            indices_,
        ),
    )
    self.prelude.mod[helper_var] = Function(
        [ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []
    )

    gather_name = self.get_name("tensor_array_gather")
    gather_var = GlobalVar(gather_name)
    setattr(self.prelude, gather_name, gather_var)
    tensor_array = Var("tensor_array", self.list(tensor_type_var()))
    indices = Var("indices", TensorType([Any()], "int32"))
    # Number of indices; used both as the starting counter and as the limit.
    num_indices = op.take(op.shape_of(indices), const(0))
    body = helper_var(tensor_array, self.nil(), num_indices, num_indices, indices)
    self.prelude.mod[gather_var] = Function(
        [tensor_array, indices], body, tensor_type_var(), []
    )
| |
def define_tensor_array_stack(self):
    """Register tensor_array_stack, combining a tensor array into one tensor_t.

    tensor_array_stack(l) : list[tensor_t] -> tensor_t

    Every element is first mapped through tensor_expand_dims; the head of
    the mapped list is then folded (foldl) with tensor_concatenate over its
    tail. The resulting expression is passed through ToANormalFormExpr
    before being registered. The global var is also attached to the prelude
    under the function's name.
    """
    name = self.get_name("tensor_array_stack")
    stack_fn = GlobalVar(name)
    setattr(self.prelude, name, stack_fn)

    tensor_t = self.tensor_type_var
    array = Var("tensor_array", self.list(tensor_t()))
    expand_fn = self.get_global_var("tensor_expand_dims")
    concat_fn = self.get_global_var("tensor_concatenate")

    # The mapped list is referenced twice below (head and tail).
    expanded = self.prelude.map(expand_fn, array)
    first = self.prelude.hd(expanded)
    others = self.prelude.tl(expanded)
    stacked = self.prelude.foldl(concat_fn, first, others)

    self.prelude.mod[stack_fn] = Function([array], ToANormalFormExpr(stacked), tensor_t(), [])
| |
def register(self):
    """Register all tensor array ops in Prelude.

    The definitions are run in the order listed; several of them look up
    globals registered by earlier ones.
    """
    definitions = (
        self.define_tensor_adt,
        self.define_tensor_take,
        self.define_tensor_expand_dims,
        self.define_tensor_concat,
        self.define_tensor_array,
        self.define_tensor_array_read,
        self.define_tensor_array_write,
        self.define_tensor_array_unstack_tensor1,
        self.define_tensor_array_unstack_tensor2,
        self.define_tensor_array_unstack_tensor3,
        self.define_tensor_array_unstack_tensor4,
        self.define_tensor_array_unstack_tensor5,
        self.define_tensor_array_unstack_tensor6,
        self.define_tensor_array_scatter,
        self.define_tensor_array_split,
        self.define_tensor_array_concat,
        self.define_tensor_array_stack,
        # TODO(wweic): Gather fails in PartialEvaluate
        # self.define_tensor_array_gather,
    )
    for define in definitions:
        define()
| |
| |
class Prelude:
    """Contains standard definitions."""

    def __init__(self, mod=None):
        # A fresh module is created when the caller does not supply one.
        self.mod = IRModule() if mod is None else mod
        self.load_prelude()

    def get_name(self, canonical, dtype):
        """Return the dtype-specific name for a canonical name.

        "tensor_t" maps to "tensor_<dtype>_t"; every other canonical name
        maps to "<canonical>_<dtype>".
        """
        if canonical != "tensor_t":
            return f"{canonical}_{dtype}"
        return f"tensor_{dtype}_t"

    def get_global_var(self, canonical, dtype):
        """Look up the module's global var for a canonical name and dtype."""
        return self.mod.get_global_var(self.get_name(canonical, dtype))

    def get_type(self, canonical, dtype):
        """Look up the module's global type var for a canonical name and dtype."""
        return self.mod.get_global_type_var(self.get_name(canonical, dtype))

    def get_ctor(self, ty_name, canonical, dtype):
        """Return the entry of type `ty_name` whose name_hint matches.

        The entries returned by the module's get_type(ty_name) are scanned
        in order; an Exception is raised when none carries the
        dtype-specific name.
        """
        name = self.get_name(canonical, dtype)
        candidates = (entry for entry in self.mod.get_type(ty_name) if entry.name_hint == name)
        found = next(candidates, None)
        if found is None:
            raise Exception(f"could not find {name}")
        return found

    def get_tensor_ctor(self, canonical, dtype):
        """Return the constructor of the dtype's tensor_t type for a canonical name."""
        tensor_ty = self.get_type("tensor_t", dtype)
        return self.get_ctor(tensor_ty.name_hint, canonical, dtype)

    def get_name_static(self, canonical, dtype, shape, batch_dim=None):
        """Return the static (dtype- and shape-specific) name for a canonical name."""
        return _get_name_static(canonical, dtype, shape, batch_dim)

    def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
        """Look up the module's global var for a static name."""
        static_name = self.get_name_static(canonical, dtype, shape, batch_dim)
        return self.mod.get_global_var(static_name)

    def get_type_static(self, canonical, dtype, shape):
        """Look up the module's global type var for a static name."""
        static_name = self.get_name_static(canonical, dtype, shape)
        return self.mod.get_global_type_var(static_name)

    def get_ctor_static(self, ty_name, name, dtype, shape):
        """Return the entry of a static type whose name_hint matches.

        Both the type name and the constructor name are first converted to
        their static form; an Exception is raised when no entry matches.
        """
        static_ty_name = self.get_name_static(ty_name, dtype, shape)
        static_name = self.get_name_static(name, dtype, shape)
        candidates = (
            entry for entry in self.mod.get_type(static_ty_name) if entry.name_hint == static_name
        )
        found = next(candidates, None)
        if found is None:
            raise Exception(f"could not find {static_name}")
        return found

    def get_tensor_ctor_static(self, name, dtype, shape):
        """Return the constructor of the static tensor_t type for a canonical name."""
        return self.get_ctor_static("tensor_t", name, dtype, shape)

    def load_prelude(self):
        """Parses the Prelude from Relay's text format into a module."""
        # TODO(@jroesch): we should remove this helper when we port over prelude
        self.mod.import_from_std("prelude.rly")

        # Globals from prelude.rly that are exposed as attributes of this object.
        global_defs = (
            "id",
            "compose",
            "flip",
            "hd",
            "tl",
            "nth",
            "update",
            "map",
            "foldl",
            "foldr",
            "foldr1",
            "concat",
            "filter",
            "zip",
            "rev",
            "map_accuml",
            "map_accumr",
            "unfoldl",
            "unfoldr",
            "sum",
            "length",
            "tmap",
            "size",
            "iterate",
        )
        for def_name in global_defs:
            setattr(self, def_name, self.mod.get_global_var(def_name))

        # Register the tensor array ops once per supported element dtype.
        dtypes = (
            "float32",
            "float16",
            "float64",
            "int32",
            "uint8",
            "int8",
            "int16",
            "uint16",
            "int64",
        )
        for dtype in dtypes:
            TensorArrayOps(self, dtype).register()

        # Renamer doesn't properly deal with constructors, etc
        # self.mod = AnnotateSpans()(self.mod)