/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/te/operation.h
 * \brief Operation node can generate one or multiple Tensors
 */
#ifndef TVM_TE_OPERATION_H_
#define TVM_TE_OPERATION_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

namespace tvm {
/*! \brief Tensor expression language DSL. */
namespace te {

/*!
 * \brief Temporary container that stores, for every axis of a Tensor,
 *  the integer sets whose union forms the bound of that axis.
 */
struct TensorDom {
  /*! \brief The domain data; the outer index is the axis, the inner list holds its sets. */
  std::vector<std::vector<IntSet>> data;
  /*!
   * \brief Create a domain with one (initially empty) set list per axis.
   * \param ndim Number of axes.
   */
  explicit TensorDom(int ndim) : data(ndim) {}
};

/*!
 * \brief Base class of all operation nodes.
 *
 *  Declares the pure-virtual interface implemented by every concrete
 *  operation: output queries, input listing and replacement, bound
 *  propagation, and construction of realize/provide statements.
 */
class TVM_DLL OperationNode : public Object {
 public:
  /*! \brief optional name of the operation */
  std::string name;
  /*! \brief optional tag of the operation */
  std::string tag;
  /*! \brief additional attributes of the operation */
  Map<String, ObjectRef> attrs;
  // virtual destructor.
  virtual ~OperationNode() {}
  /*! \return number of outputs */
  virtual int num_outputs() const = 0;
  /*!
   * \return The list of iteration variables at root.
   * \note root_iter_vars decides the shape of the outputs.
   */
  virtual Array<IterVar> root_iter_vars() const = 0;
  /*!
   * \brief Get the data type of the i-th output tensor.
   * \param i The output index.
   * \return type of i-th output.
   */
  virtual DataType output_dtype(size_t i) const = 0;
  /*!
   * \brief Get shape of i-th output tensor.
   * \param i The output index.
   * \return shape of i-th output.
   */
  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
  /*!
   * \brief List all the input Tensors.
   * \return List of input tensors.
   */
  virtual Array<Tensor> InputTensors() const = 0;
  /*!
   * \brief Replace the input of the operation by pattern specified by rmap.
   *
   * \param self The reference to self.
   * \param rmap The replacement map.
   * \return self if nothing is replaced, otherwise return replaced op.
   */
  virtual Operation ReplaceInputs(const Operation& self,
                                  const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
  /*!
   * \brief Propagate the bounds to inputs.
   * \param self The reference to self.
   * \param analyzer The analyzer to be used in the function.
   * \param dom_map The domain map of Variables (corresponds to root_iter_vars).
   * \param out_dom_map The output domain.
   *  The function is only asked to fill the bounds for Tensors that
   *  are already in the out_dom_map.
   */
  virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                 const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                 std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
  /*!
   * \brief Gather the bound from output tensor.
   *  Set the range of each root_iter_vars in the op to out_dom_map.
   *
   * \param self The reference to self.
   * \param tensor_dom Domain map of Tensor->access set of each dimension.
   * \param out_dom_map The output domain map of each IterVar to be set.
   */
  virtual void GatherBound(const Operation& self,
                           const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                           std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
  /*!
   * \brief Build the Realize statement that realizes
   *   the op's output tensors.
   * \param stage the op's stage.
   * \param realize_map The realization domain map of the operators.
   * \param body The body statement that the returned realization wraps.
   * \param storage_scope The storage scope associated with this realization
   * \return A realization statement that wraps body.
   * \note Default arguments of virtual functions are selected by the static
   *  type of the call expression, so every override in this file repeats
   *  the same default ("") to keep calls consistent.
   */
  virtual Stmt BuildRealize(const Stage& stage,
                            const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
                            String storage_scope = "") const = 0;
  /*!
   * \brief Build the statement that provides the output tensors.
   * \param stage The schedule stage of the op.
   * \param dom_map The domain map of all iteration domains.
   * \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1
   * \return A statement that adds production and wraps consumer.
   */
  virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                            bool debug_keep_trivial_loop) const = 0;

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "Operation";

  TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
};

/*!
 * \brief A placeholder op represents an input placeholder.
 *
 *  Besides the common OperationNode fields it stores only the shape and
 *  data type of the input; every OperationNode virtual is overridden as final.
 */
class PlaceholderOpNode : public OperationNode {
 public:
  /*! \brief The shape of the input */
  Array<PrimExpr> shape;
  /*! \brief The data type of the input. */
  DataType dtype;
  // Overrides of the OperationNode interface; see OperationNode for the contracts.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "PlaceholderOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};

/*!
 * \brief Managed reference to PlaceholderOpNode
 * \sa PlaceholderOpNode
 */
class PlaceholderOp : public Operation {
 public:
  /*!
   * \brief Construct a placeholder operation.
   * \param name The name of the operation.
   * \param shape The shape of the input.
   * \param dtype The data type of the input.
   */
  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

  TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};

/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 * This is the base class for ComputeOp (operating on a scalar at a time) and
 * TensorComputeOp (operating on a TensorSlice at a time).
 *
 * It finalizes the parts of the OperationNode interface that are shared by
 * both subclasses and leaves the remaining virtuals to them.
 */
class TVM_DLL BaseComputeOpNode : public OperationNode {
 public:
  /*! \brief IterVar on each axis */
  Array<IterVar> axis;
  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
  Array<IterVar> reduce_axis;
  // Overrides shared by all compute ops; see OperationNode for the contracts.
  Array<IterVar> root_iter_vars() const final;
  Array<PrimExpr> output_shape(size_t idx) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  /*!
   * \return The number of schedulable dimensions of this op; the exact
   *  meaning is defined by each subclass implementation.
   */
  virtual size_t num_schedulable_dims() const = 0;

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "BaseComputeOp";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};

/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 *
 *  The computation is described by a list of expressions in `body`.
 */
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
 public:
  /*! \brief the compute expression */
  Array<PrimExpr> body;
  /*! \brief constructor */
  ComputeOpNode() {}
  // Overrides not already finalized by BaseComputeOpNode; see OperationNode for the contracts.
  int num_outputs() const final;
  DataType output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  size_t num_schedulable_dims() const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("body", &body);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "ComputeOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};

/*!
 * \brief Managed reference to ComputeOpNode
 * \sa ComputeOpNode
 */
class ComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a compute operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis IterVar on each axis.
   * \param body The compute expressions.
   */
  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                    Array<IterVar> axis, Array<PrimExpr> body);

  TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
};

/*!
 * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
 */
class TensorComputeOpNode : public BaseComputeOpNode {
 public:
  /*!
   * \brief number of axes that can be scheduled
   * \note Zero-initialized so that a default-constructed node never exposes an
   *  indeterminate value (e.g. through VisitAttrs) before the field is assigned.
   */
  int schedulable_ndim{0};
  /*! \brief TensorIntrin used to compute */
  TensorIntrin intrin;
  /*! \brief input tensors of intrin */
  Array<Tensor> inputs;
  /*! \brief region of input tensors */
  Array<Region> input_regions;
  /*! \brief scalar expression inputs */
  Array<PrimExpr> scalar_inputs;
  /*! \brief constructor */
  TensorComputeOpNode() {}
  // Overrides not already finalized by BaseComputeOpNode; see OperationNode for the contracts.
  int num_outputs() const final;
  DataType output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  size_t num_schedulable_dims() const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    // NOTE(review): unlike the sibling op nodes, "attrs" is not visited here.
    // Left as-is because adding a key changes the reflected field set — confirm intent.
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("schedulable_ndim", &schedulable_ndim);
    v->Visit("intrin", &intrin);
    v->Visit("inputs", &inputs);
    v->Visit("input_regions", &input_regions);
    v->Visit("scalar_inputs", &scalar_inputs);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "TensorComputeOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};

/*!
 * \brief Managed reference to TensorComputeOpNode
 * \sa TensorComputeOpNode
 */
class TensorComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a tensor compute operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param axis IterVar on each axis.
   * \param reduce_axis IterVar on each reduction axis.
   * \param schedulable_ndim Number of axes that can be scheduled.
   * \param intrin TensorIntrin used to compute.
   * \param tensors Input tensors of intrin.
   * \param regions Region of input tensors.
   * \param scalar_inputs Scalar expression inputs.
   */
  TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
                          Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
                          Array<Tensor> tensors, Array<Region> regions,
                          Array<PrimExpr> scalar_inputs);

  TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};

/*!
 * \brief Symbolic scan.
 */
class ScanOpNode : public OperationNode {
 public:
  /*! \brief IterVar to scan over */
  IterVar scan_axis;
  /*! \brief the initialization tensors */
  Array<Tensor> init;
  /*! \brief the update function represented by tensor */
  Array<Tensor> update;
  /*! \brief The placeholder to refer to as states in update. */
  Array<Tensor> state_placeholder;
  /*!
   * \brief The inputs to the scan. These are optional, but providing
   *  them can serve as a hint to speed up retrieval of the scan body.
   */
  Array<Tensor> inputs;
  /*!
   * \brief Spatial axis to indicate spatial dimension of each output.
   *  They correspond to the flattened spatial axes of the outputs.
   *
   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
   *  These are auxiliary data structures for storing the result of bound inference.
   *  They do not correspond to splittable iterations, thus the name comes
   *  with a trailing underscore.
   */
  Array<IterVar> spatial_axis_;
  /*! \brief constructor */
  ScanOpNode() {}
  // Overrides of the OperationNode interface; see OperationNode for the contracts.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("scan_axis", &scan_axis);
    v->Visit("init", &init);
    v->Visit("update", &update);
    v->Visit("state_placeholder", &state_placeholder);
    v->Visit("inputs", &inputs);
    v->Visit("spatial_axis_", &spatial_axis_);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "ScanOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};

/*!
 * \brief Managed reference to ScanOpNode
 * \sa ScanOpNode
 */
class ScanOp : public Operation {
 public:
  /*!
   * \brief Construct a scan operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis IterVar to scan over.
   * \param init The initialization tensors.
   * \param update The update function represented by tensor.
   * \param state_placeholder The placeholder to refer to as states in update.
   * \param input The inputs to the scan.
   */
  TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
                 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
                 Array<Tensor> input);

  TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};

/*!
 * \brief External computation that cannot be split.
 */
class ExternOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of inputs */
  Array<Buffer> input_placeholders;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Buffer> output_placeholders;
  /*! \brief the statement that generates the computation. */
  Stmt body;

  /*! \brief constructor */
  ExternOpNode() {}
  // Overrides of the OperationNode interface; see OperationNode for the contracts.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("inputs", &inputs);
    v->Visit("input_placeholders", &input_placeholders);
    v->Visit("output_placeholders", &output_placeholders);
    v->Visit("body", &body);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "ExternOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};

/*!
 * \brief Managed reference to ExternOpNode
 * \sa ExternOpNode
 */
class ExternOp : public Operation {
 public:
  /*!
   * \brief Construct an external operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param input_placeholders Symbolic placeholder representation of inputs.
   * \param output_placeholders Symbolic placeholder representation of outputs.
   * \param body The statement that generates the computation.
   */
  TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Buffer> input_placeholders,
                   Array<Buffer> output_placeholders, Stmt body);

  TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};

/*!
 * \brief A computation operator that is generated by hybrid script.
 */
class HybridOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Tensor> outputs;
  /*! \brief The axis of iterations */
  Array<IterVar> axis;
  /*! \brief The statement that generates the computation. This is
   * slightly different from the body in ExternOpNode. All the output
   * tensors keep the names specified by users in the script.
   * However, during compilation, these tensors will be replaced by the
   * actual output tensors. */
  Stmt body;

  /*! \brief constructor */
  HybridOpNode() {}
  // Overrides of the OperationNode interface; see OperationNode for the contracts.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body, String storage_scope = "") const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;

  /*!
   * \brief Pass each field of this node, together with its string key, to the visitor.
   * \param v The attribute visitor.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("axis", &axis);
    v->Visit("body", &body);
  }

  /*! \brief Type key under which this node type is registered. */
  static constexpr const char* _type_key = "HybridOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};

/*!
 * \brief Managed reference to HybridOpNode
 * \sa HybridOpNode
 */
class HybridOp : public Operation {
 public:
  /*!
   * \brief Construct a hybrid operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param outputs Symbolic placeholder representation of outputs.
   * \param body The statement that generates the computation.
   * \note The node's `axis` field is not a constructor parameter.
   */
  TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

  TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};

/*!
 * \brief Construct a new Var expression
 * \param name_hint The name hint for the expression
 * \param t The type of the expression; defaults to a 32-bit integer.
 * \return The created Var.
 */
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));

/*!
 * \brief Create a new IterVar that represents an axis in thread.
 *
 * \param dom Optional, domain of the thread axis.
 * \param tag The thread tag of the axis.
 * \return The created IterVar.
 */
TVM_DLL IterVar thread_axis(Range dom, std::string tag);

/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis; defaults to "rv".
 * \return The created IterVar.
 */
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");

/*!
 * \brief The compute function to specify the input source of a Tensor.
 *  It receives the index variables and returns a single expression.
 */
using FCompute = std::function<PrimExpr(const Array<Var>& i)>;

/*!
 * \brief The compute function to specify the input sources of Tensors.
 *  It receives the index variables and returns an array of expressions.
 */
using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;

/*!
 * \brief Create a placeholder tensor.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor; defaults to 32-bit float.
 * \param name The name of the Tensor; defaults to "placeholder".
 * \return The created Tensor.
 */
TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
                           std::string name = "placeholder");

/*!
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensor.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The constructed Tensor.
 */
TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
                       std::string tag = "", Map<String, ObjectRef> attrs = {});

/*!
 * \brief Construct new tensors by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis)
 *  where fcompute yields an array of expressions.
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensors.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The constructed Tensors.
 */
TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
                              std::string name = "tensor", std::string tag = "",
                              Map<String, ObjectRef> attrs = {});

/*!
 * \brief Construct new tensors by scan.
 *
 * \param init The initialization tensors of the first K steps.
 * \param update The update tensors indicating the updated result after each timestamp.
 * \param state_placeholder The placeholder for the states.
 * \param inputs The inputs to the scan body, this is optional,
 *    but recommended to provide concrete information about scan body.
 * \param name The optional name of the tensor.
 * \param tag The optional tag of the tensor.
 * \param attrs Optional additional attributes of the compute.
 * \return The constructed Tensors.
 */
TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
                           Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
                           std::string name = "scan", std::string tag = "",
                           Map<String, ObjectRef> attrs = {});

// Convenience overloads of compute for callbacks taking a fixed number of
// index variables instead of an Array<Var>.
/*!
 * \brief compute overload for a rule that takes one index variable.
 *  Wraps f into an FCompute that forwards element 0 of the index array.
 */
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  auto unpack = [f](const Array<Var>& indices) -> PrimExpr { return f(indices[0]); };
  return compute(shape, FCompute(unpack), name, tag, attrs);
}
/*!
 * \brief compute overload for a rule that takes two index variables.
 *  Wraps f into an FCompute that forwards elements 0 and 1 of the index array.
 */
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  auto unpack = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1]);
  };
  return compute(shape, FCompute(unpack), name, tag, attrs);
}
/*!
 * \brief compute overload for a rule that takes three index variables.
 *  Wraps f into an FCompute that forwards elements 0..2 of the index array.
 */
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  auto unpack = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1], indices[2]);
  };
  return compute(shape, FCompute(unpack), name, tag, attrs);
}
/*!
 * \brief compute overload for a rule that takes four index variables.
 *  Wraps f into an FCompute that forwards elements 0..3 of the index array.
 */
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  auto unpack = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1], indices[2], indices[3]);
  };
  return compute(shape, FCompute(unpack), name, tag, attrs);
}

// Inline definition of the member-access operator of Operation.
/*!
 * \brief Access the underlying node as an OperationNode.
 * \return The pointer returned by get(), statically cast to const OperationNode*.
 */
inline const OperationNode* Operation::operator->() const {
  const auto* node = get();
  return static_cast<const OperationNode*>(node);
}
}  // namespace te
}  // namespace tvm
#endif  // TVM_TE_OPERATION_H_
