blob: 27e48999a7d10fc80e85f7aca2b08b2684a7df95 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/te/operation.h
* \brief Operation node can generate one or multiple Tensors
*/
#ifndef TVM_TE_OPERATION_H_
#define TVM_TE_OPERATION_H_
#include <tvm/arith/analyzer.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
/*! \brief Tensor expression language DSL. */
namespace te {
/*!
* \brief Temporary data structure to store union
* of bounds of each axis of Tensor.
*/
struct TensorDom {
// constructor
explicit TensorDom(int ndim) : data(ndim) {}
/*! \brief The domain data */
std::vector<std::vector<IntSet> > data;
};
/*!
 * \brief Base class of all operation nodes.
 *
 *  Declares the interface every operation implements: output count, dtype and
 *  shape queries, input listing/replacement, bound propagation/gathering, and
 *  construction of the Realize and Provide statements for its stage.
 */
class TVM_DLL OperationNode : public Object {
 public:
  /*! \brief optional name of the operation */
  std::string name;
  /*! \brief optional tag of the operation */
  std::string tag;
  /*! \brief additional attributes of the operation */
  Map<String, ObjectRef> attrs;
  // Virtual destructor: subclasses are handled through OperationNode pointers.
  virtual ~OperationNode() {}
  /*! \return number of outputs */
  virtual int num_outputs() const = 0;
  /*!
   * \return The list of iteration variables at root.
   * \note root_iter_vars decides the shape of the outputs.
   */
  virtual Array<IterVar> root_iter_vars() const = 0;
  /*!
   * \brief Get the data type of the i-th output tensor.
   * \param i The output index.
   * \return The data type of the i-th output.
   */
  virtual DataType output_dtype(size_t i) const = 0;
  /*!
   * \brief Get the shape of the i-th output tensor.
   * \param i The output index.
   * \return The shape of the i-th output.
   */
  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
  /*!
   * \brief List all the input Tensors.
   * \return List of input tensors.
   */
  virtual Array<Tensor> InputTensors() const = 0;
  /*!
   * \brief Replace the inputs of the operation according to the mapping rmap.
   *
   * \param self The reference to self.
   * \param rmap The replacement map (old input tensor -> new input tensor).
   * \return self if nothing is replaced, otherwise the op with replaced inputs.
   */
  virtual Operation ReplaceInputs(const Operation& self,
                                  const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
  /*!
   * \brief Propagate the bounds to inputs.
   * \param self The reference to self.
   * \param analyzer The analyzer to be used in the function.
   * \param dom_map The domain map of Variables (corresponds to root_iter_vars).
   * \param out_dom_map The output domain.
   *  The function is only asked to fill the bounds for Tensors that
   *  are already in the out_dom_map.
   */
  virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                                 const std::unordered_map<const VarNode*, IntSet>& dom_map,
                                 std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
  /*!
   * \brief Gather the bound from the output tensors.
   *  Sets the range of each of the op's root_iter_vars in out_dom_map.
   *
   * \param self The reference to self.
   * \param tensor_dom Domain map of Tensor -> access set of each dimension.
   * \param out_dom_map The output domain map of each IterVar, to be filled in.
   */
  virtual void GatherBound(const Operation& self,
                           const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                           std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
  /*!
   * \brief Build the Realize statement that realizes
   *  the op's output tensors.
   * \param stage The op's stage.
   * \param realize_map The realization domain map of the operators.
   * \param body The body statement that the realization will wrap.
   * \return A realization statement that wraps body.
   */
  virtual Stmt BuildRealize(const Stage& stage,
                            const std::unordered_map<IterVar, Range>& realize_map,
                            const Stmt& body) const = 0;
  /*!
   * \brief Build the statement that provides the output tensors.
   * \param stage The schedule stage of the op.
   * \param dom_map The domain map of all iteration domains.
   * \param debug_keep_trivial_loop Whether to keep trivial loops with an extent of 1.
   * \return A statement that adds production and wraps consumer.
   */
  virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                            bool debug_keep_trivial_loop) const = 0;
  static constexpr const char* _type_key = "Operation";
  TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
};
/*!
 * \brief A placeholder op represents an input placeholder.
 *  It records the shape and data type of that input; the OperationNode
 *  overrides are only declared here and are defined out of line.
 */
class PlaceholderOpNode : public OperationNode {
 public:
  /*! \brief The shape of the input */
  Array<PrimExpr> shape;
  /*! \brief The data type of the input. */
  DataType dtype;
  // override behavior.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  // Expose every field (the inherited name/tag/attrs first, then this node's
  // own fields) to the visitor under its field name, in this fixed order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
  }
  static constexpr const char* _type_key = "PlaceholderOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};
/*!
 * \brief Managed reference to PlaceholderOpNode
 * \sa PlaceholderOpNode
 */
class PlaceholderOp : public Operation {
 public:
  /*!
   * \brief Construct a placeholder operation.
   * \param name The name of the operation.
   * \param shape The shape of the input (presumably stored as PlaceholderOpNode::shape — confirm).
   * \param dtype The data type of the input (presumably PlaceholderOpNode::dtype — confirm).
   */
  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);
  TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};
/*!
 * \brief A Compute op that computes a tensor on a certain domain.
 *  This is the base class for ComputeOp (operating on a scalar at a time) and
 *  TensorComputeOp (operating on a TensorSlice at a time).
 */
class TVM_DLL BaseComputeOpNode : public OperationNode {
 public:
  /*! \brief IterVar on each axis */
  Array<IterVar> axis;
  /*! \brief IterVar on each reduction axis, if the body is a Reduce */
  Array<IterVar> reduce_axis;
  // override functions
  Array<IterVar> root_iter_vars() const final;
  Array<PrimExpr> output_shape(size_t idx) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body) const final;
  /*!
   * \brief Number of dimensions of this op that can be scheduled.
   *  Pure virtual: each concrete compute node provides its own answer.
   * \return The number of schedulable dimensions.
   */
  virtual size_t num_schedulable_dims() const = 0;
  static constexpr const char* _type_key = "BaseComputeOp";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};
/*!
 * \brief A Compute op that computes a tensor on a certain domain,
 *  one scalar expression per output.
 */
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
 public:
  /*! \brief the compute expression (one entry per output) */
  Array<PrimExpr> body;
  /*! \brief constructor */
  ComputeOpNode() {}
  // override functions
  int num_outputs() const final;
  DataType output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  size_t num_schedulable_dims() const final;
  // Expose the inherited and own fields to the visitor, in this fixed order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("body", &body);
  }
  static constexpr const char* _type_key = "ComputeOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};
/*!
 * \brief Managed reference to ComputeOpNode
 * \sa ComputeOpNode
 */
class ComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a compute operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis The IterVar on each output axis.
   * \param body The compute expressions.
   * \note No reduce_axis parameter is taken; presumably it is derived from body — confirm.
   */
  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                    Array<IterVar> axis, Array<PrimExpr> body);
  TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};
/*!
* \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
*/
class TensorComputeOpNode : public BaseComputeOpNode {
public:
/*! \brief number of axes that can be scheduled */
int schedulable_ndim;
/*! \brief TensorIntrin used to compute */
TensorIntrin intrin;
/*! \brief input tensors of intrin */
Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief scalar expression inputs */
Array<PrimExpr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("schedulable_ndim", &schedulable_ndim);
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};
/*!
 * \brief Managed reference to TensorComputeOpNode
 * \sa TensorComputeOpNode
 */
class TensorComputeOp : public Operation {
 public:
  /*!
   * \brief Construct a tensor-intrinsic compute operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param axis The IterVar on each axis.
   * \param reduce_axis The IterVar on each reduction axis.
   * \param schedulable_ndim Number of axes that can be scheduled.
   * \param intrin The TensorIntrin used to compute.
   * \param tensors Input tensors of intrin (presumably stored as TensorComputeOpNode::inputs — confirm).
   * \param regions Regions of the input tensors (presumably TensorComputeOpNode::input_regions — confirm).
   * \param scalar_inputs Scalar expression inputs.
   */
  TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
                          Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
                          Array<Tensor> tensors, Array<Region> regions,
                          Array<PrimExpr> scalar_inputs);
  TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};
/*!
 * \brief Symbolic scan.
 */
class ScanOpNode : public OperationNode {
 public:
  /*! \brief IterVar to scan over */
  IterVar scan_axis;
  /*! \brief the initialization tensors */
  Array<Tensor> init;
  /*! \brief the update function represented by tensor */
  Array<Tensor> update;
  /*! \brief The placeholder to refer as states in update. */
  Array<Tensor> state_placeholder;
  /*!
   * \brief The inputs to the scan. These are optional,
   *  but they can provide hints that speed up getting the scan body.
   */
  Array<Tensor> inputs;
  /*!
   * \brief Spatial axis to indicate the spatial dimension of each output.
   *  They correspond to the flattened spatial axes of the outputs:
   *
   *  [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
   *  These are an auxiliary data structure for storing the result of bound inference.
   *  They do not correspond to splittable iterations, hence the trailing
   *  underscore in the name.
   */
  Array<IterVar> spatial_axis_;
  /*! \brief constructor */
  ScanOpNode() {}
  // override behavior.
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  // Expose the inherited and own fields (including spatial_axis_) to the
  // visitor, in this fixed order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("scan_axis", &scan_axis);
    v->Visit("init", &init);
    v->Visit("update", &update);
    v->Visit("state_placeholder", &state_placeholder);
    v->Visit("inputs", &inputs);
    v->Visit("spatial_axis_", &spatial_axis_);
  }
  static constexpr const char* _type_key = "ScanOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};
/*!
 * \brief Managed reference to ScanOpNode
 * \sa ScanOpNode
 */
class ScanOp : public Operation {
 public:
  /*!
   * \brief Construct a scan operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param axis The IterVar to scan over.
   * \param init The initialization tensors.
   * \param update The update tensors.
   * \param state_placeholder The placeholders referring to the states in update.
   * \param input The optional inputs to the scan (presumably stored as ScanOpNode::inputs — confirm).
   */
  TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
                 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
                 Array<Tensor> input);
  TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};
/*!
 * \brief External computation that cannot be split.
 */
class ExternOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of inputs */
  Array<Buffer> input_placeholders;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Buffer> output_placeholders;
  /*! \brief the statement that generates the computation. */
  Stmt body;
  /*! \brief constructor */
  ExternOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  // Expose the inherited and own fields to the visitor, in this fixed order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("inputs", &inputs);
    v->Visit("input_placeholders", &input_placeholders);
    v->Visit("output_placeholders", &output_placeholders);
    v->Visit("body", &body);
  }
  static constexpr const char* _type_key = "ExternOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};
/*!
 * \brief Managed reference to ExternOpNode
 * \sa ExternOpNode
 */
class ExternOp : public Operation {
 public:
  /*!
   * \brief Construct an extern operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param input_placeholders Symbolic placeholder representation of inputs.
   * \param output_placeholders Symbolic placeholder representation of outputs.
   * \param body The statement that generates the computation.
   */
  TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Buffer> input_placeholders,
                   Array<Buffer> output_placeholders, Stmt body);
  TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};
/*!
 * \brief A computation operator that is generated by hybrid script.
 */
class HybridOpNode : public OperationNode {
 public:
  /*! \brief The input tensors */
  Array<Tensor> inputs;
  /*! \brief Symbolic placeholder representation of outputs */
  Array<Tensor> outputs;
  /*! \brief The axis of iterations */
  Array<IterVar> axis;
  /*!
   * \brief The statement that generates the computation. This is
   *  slightly different from the body in ExternOpNode: all the output
   *  tensors keep their own names as specified by users in the script.
   *  However, during compilation, these tensors will be replaced by the
   *  actual output tensors.
   */
  Stmt body;
  /*! \brief constructor */
  HybridOpNode() {}
  // override functions
  int num_outputs() const final;
  Array<IterVar> root_iter_vars() const final;
  DataType output_dtype(size_t i) const final;
  Array<PrimExpr> output_shape(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                   std::unordered_map<IterVar, Range>* out_dom_map) const final;
  Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
                    const Stmt& body) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  // Expose the inherited and own fields to the visitor, in this fixed order.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("attrs", &attrs);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("axis", &axis);
    v->Visit("body", &body);
  }
  static constexpr const char* _type_key = "HybridOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};
/*!
 * \brief Managed reference to HybridOpNode
 * \sa HybridOpNode
 */
class HybridOp : public Operation {
 public:
  /*!
   * \brief Construct a hybrid operation.
   * \param name The name of the operation.
   * \param tag The tag of the operation.
   * \param attrs Additional attributes of the operation.
   * \param inputs The input tensors.
   * \param outputs The output tensors as named in the script.
   * \param body The statement that generates the computation.
   * \note No axis parameter is taken; presumably HybridOpNode::axis is derived from body — confirm.
   */
  TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);
  TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};
/*!
 * \brief Construct a new Var expression.
 * \param name_hint The name hint for the expression.
 * \param t The type of the expression; defaults to DataType::Int(32).
 * \return The newly created Var.
 */
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
/*!
 * \brief Create a new IterVar that represents an axis in thread.
 *
 * \param dom The domain of the thread axis. Documented as optional; there is no
 *  default argument, so presumably an undefined Range is accepted — confirm.
 * \param tag The thread tag of the axis.
 * \return The newly created thread IterVar.
 */
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
/*!
 * \brief Create a new IterVar for reduction operations.
 *
 * \param dom The domain of the reduction axis.
 * \param name The name of the reduction axis; defaults to "rv".
 * \return The newly created reduction IterVar.
 */
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
/*!
 * \brief The compute function that specifies the input source of a Tensor.
 *  It receives the index variables of the output (one per axis) and returns
 *  the expression computed at that index.
 * \note Uses std::function, so this header must include <functional> directly
 *  rather than rely on a transitive include.
 */
using FCompute = std::function<PrimExpr(const Array<Var>& i)>;
/*!
 * \brief The compute function that specifies the input sources of Tensors.
 *  Like FCompute, but returns an array of expressions (presumably one per
 *  output tensor — confirm against the batch compute implementation).
 */
using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& i)>;
/*!
 * \brief Create a placeholder tensor.
 * \param shape The shape of the tensor.
 * \param dtype The data type of the tensor; defaults to DataType::Float(32).
 * \param name The name of the Tensor; defaults to "placeholder".
 * \return The placeholder tensor.
 */
TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
                           std::string name = "placeholder");
/*!
 * \brief Construct a new tensor by computing over shape,
 *  using the computation rule: result_tensor[axis] = fcompute(axis).
 * \param shape Shape of the tensor.
 * \param fcompute The compute function to create the tensor.
 * \param name The optional name of the tensor; defaults to "tensor".
 * \param tag The optional tag of the tensor; defaults to "".
 * \param attrs Optional additional attributes of the compute.
 * \return The created tensor.
 */
TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
                       std::string tag = "", Map<String, ObjectRef> attrs = {});
/*!
 * \brief Construct new tensors by computing over shape, where fcompute
 *  returns an array of expressions for each index (presumably one per
 *  output tensor — confirm).
 * \param shape Shape of the tensors.
 * \param fcompute The compute function to create the tensors.
 * \param name The optional name of the tensors; defaults to "tensor".
 * \param tag The optional tag of the tensors; defaults to "".
 * \param attrs Optional additional attributes of the compute.
 * \return The created tensors.
 */
TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
                              std::string name = "tensor", std::string tag = "",
                              Map<String, ObjectRef> attrs = {});
/*!
 * \brief Construct new tensors by scan.
 *
 * \param init The initial tensors of the first K steps.
 * \param update The update tensors indicating the updated result after each time step.
 * \param state_placeholder The placeholders for the states.
 * \param inputs The inputs to the scan body. This is optional,
 *  but recommended to provide concrete information about the scan body.
 * \param name The optional name of the tensor; defaults to "scan".
 * \param tag The optional tag of the tensor; defaults to "".
 * \param attrs Optional additional attributes of the compute.
 * \return The output tensors of the scan.
 */
TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
                           Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
                           std::string name = "scan", std::string tag = "",
                           Map<String, ObjectRef> attrs = {});
// Convenience overloads of compute for fixed-arity callbacks. Each one wraps
// the callback in an FCompute that forwards the leading index variables to it
// and then defers to the FCompute-based compute.
// Unary form: f receives the first index variable.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  return compute(shape, FCompute([f](const Array<Var>& args) { return f(args[0]); }), name, tag,
                 attrs);
}
// Binary form of compute: f receives the first two index variables.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  return compute(shape, FCompute([f](const Array<Var>& args) { return f(args[0], args[1]); }),
                 name, tag, attrs);
}
// Ternary form of compute: f receives the first three index variables.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  return compute(shape,
                 FCompute([f](const Array<Var>& args) { return f(args[0], args[1], args[2]); }),
                 name, tag, attrs);
}
// Four-argument form of compute: f receives the first four index variables.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  auto quaternary = [f](const Array<Var>& args) {
    return f(args[0], args[1], args[2], args[3]);
  };
  return compute(shape, FCompute(quaternary), name, tag, attrs);
}
// Inline accessor for Operation: views the referenced object as an
// OperationNode via static_cast (no runtime type check is performed here).
inline const OperationNode* Operation::operator->() const {
  const auto* node = get();
  return static_cast<const OperationNode*>(node);
}
} // namespace te
} // namespace tvm
#endif // TVM_TE_OPERATION_H_