| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/te/tensor.h |
| * \brief Dataflow tensor object |
| */ |
| #ifndef TVM_TE_TENSOR_H_ |
| #define TVM_TE_TENSOR_H_ |
| |
| #include <tvm/arith/bound.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/op.h> |
| |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| namespace tvm { |
| namespace te { |
| |
| using arith::IntSet; |
| using namespace tvm::tir; |
| |
| // internal node container for Operation |
| class OperationNode; |
| class Tensor; |
| |
| /*! \brief Operation that produces tensors */ |
| class Operation : public ObjectRef { |
| public: |
| /*! \brief default constructor */ |
| Operation() {} |
| explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {} |
| /*! |
| * \brief access the internal node container |
| * \return the pointer to the internal node container |
| */ |
| inline const OperationNode* operator->() const; |
| /*! |
| * \brief get the i-th output of the operation. |
| * \param i the output index. |
| * \return The i-th output. |
| */ |
| TVM_DLL Tensor output(size_t i) const; |
| /*! \brief specify container node */ |
| using ContainerType = OperationNode; |
| }; |
| |
| /*! \brief Node to represent a tensor */ |
| class TensorNode : public DataProducerNode { |
| public: |
| /*! \brief The shape of the tensor */ |
| Array<PrimExpr> shape; |
| /*! \brief data type in the content of the tensor */ |
| DataType dtype; |
| /*! \brief the source operation, can be None */ |
| Operation op; |
| /*! \brief the output index from source operation */ |
| int value_index{0}; |
| /*! \brief constructor */ |
| TensorNode() {} |
| |
| void VisitAttrs(AttrVisitor* v) { |
| v->Visit("shape", &shape); |
| v->Visit("dtype", &dtype); |
| v->Visit("op", &op); |
| v->Visit("value_index", &value_index); |
| } |
| |
| Array<PrimExpr> GetShape() const final { return shape; } |
| |
| DataType GetDataType() const final { return dtype; } |
| |
| TVM_DLL String GetNameHint() const final; |
| |
| static constexpr const char* _type_key = "Tensor"; |
| TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); |
| }; |
| |
| /*! |
| * \brief Tensor structure representing a possible input, |
| * or intermediate computation result. |
| */ |
| class Tensor : public DataProducer { |
| public: |
| TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index); |
| /*! |
| * \brief check if two tensors equals each other. |
| * \param other tensor to be checked. |
| * \return whether the two tensors equals each other. |
| */ |
| inline bool operator==(const Tensor& other) const; |
| /*! |
| * \brief check if two tensors are different. |
| * \param other tensor to be checked. |
| * \return whether the two tensors are different. |
| */ |
| inline bool operator!=(const Tensor& other) const; |
| /*! \return The dimension of the tensor */ |
| inline size_t ndim() const; |
| /*! |
| * \brief Take elements from the tensor |
| * \param args The indices |
| * \return the result expression representing tensor read. |
| */ |
| template <typename... Args> |
| inline PrimExpr operator()(Args&&... args) const { |
| Array<PrimExpr> indices{std::forward<Args>(args)...}; |
| return operator()(indices); |
| } |
| /*! |
| * \brief Take elements from the tensor |
| * \param indices the indices. |
| * \return the result expression representing tensor read. |
| */ |
| TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const; |
| /*! |
| * \brief Take elements from the tensor |
| * \param indices the indices. |
| * \return the result expression representing tensor read. |
| */ |
| TVM_DLL PrimExpr operator()(Array<Var> indices) const; |
| /*! |
| * \brief data structure to represent a slice that fixes first k coordinates. |
| * This is used to enable syntax sugar of Tensor[x][y][z] to get the element. |
| */ |
| class Slice { |
| public: |
| // construct via tensor and indices |
| Slice(const Tensor& tensor, std::vector<PrimExpr> indices) |
| : tensor_(tensor), indices_(indices) {} |
| /*! |
| * \brief get i-th slice from the current slice. |
| * \param i the index of the coordinate |
| * \return the subsequent slice. |
| */ |
| inline Slice operator[](PrimExpr i) { |
| std::vector<PrimExpr> other = indices_; |
| other.emplace_back(i); |
| return Slice(tensor_, other); |
| } |
| /*! |
| * \brief Convert slice to expression. |
| * This is only valid when all the coordinates are fully specified. |
| * \return the corresponding expression of this slice. |
| */ |
| inline operator PrimExpr() const { return tensor_(indices_); } |
| |
| private: |
| const Tensor& tensor_; |
| std::vector<PrimExpr> indices_; |
| }; |
| /*! |
| * \brief get i-th slice from the current Tensor. |
| * \param i the index of the coordinate |
| * \return the subsequent slice. |
| */ |
| inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } |
| |
| TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); |
| }; |
| |
| // Implementations of inline functions |
| inline size_t Tensor::ndim() const { return (*this)->shape.size(); } |
| |
| inline bool Tensor::operator==(const Tensor& other) const { |
| if (get() == other.get()) return true; |
| if (get() == nullptr || other.get() == nullptr) return false; |
| if ((*this)->op.defined() || other->op.defined()) { |
| return (*this)->op == other->op && (*this)->value_index == other->value_index; |
| } else { |
| return false; |
| } |
| } |
| |
| inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); } |
| |
| // macro to turn every operation of slice to expression |
| #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ |
| inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } |
| |
| #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ |
| template <typename T> \ |
| inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ |
| return a.operator PrimExpr() Op b; \ |
| } \ |
| template <typename T> \ |
| inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ |
| return a Op b.operator PrimExpr(); \ |
| } \ |
| inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ |
| return a.operator PrimExpr() Op b.operator PrimExpr(); \ |
| } |
| |
| DEFINE_OVERLOAD_SLICE_UNARY_OP(!); |
| DEFINE_OVERLOAD_SLICE_UNARY_OP(-); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(+); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(-); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(*); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(==); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(<=); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(>=); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(!=); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(&&); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(||); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(>>); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(<<); |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*) |
| DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) |
| |
| } // namespace te |
| } // namespace tvm |
| |
| namespace std { |
| template <> |
| struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {}; |
| |
| template <> |
| struct hash<::tvm::te::Tensor> { |
| std::size_t operator()(const ::tvm::te::Tensor& k) const { |
| ::tvm::ObjectPtrHash hasher; |
| if (k.defined() && k->op.defined()) { |
| return hasher(k->op); |
| } else { |
| return hasher(k); |
| } |
| } |
| }; |
| } // namespace std |
| #endif // TVM_TE_TENSOR_H_ |