blob: 5e87d907fb561958270803c1bccf55d444b03682 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer.
*/
#ifndef TVM_TIR_BUFFER_H_
#define TVM_TIR_BUFFER_H_
#include <tvm/ir/expr.h>
#include <tvm/node/script_printer.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/tir/var.h>
#include <string>
namespace tvm {
namespace tir {
// forward declare Stmt
class Stmt;
/*! \brief buffer type */
enum BufferType : int {
kDefault = 1,
// Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
kAutoBroadcast = 2,
};
/*! \brief Node to represent a buffer */
class BufferNode : public Object {
public:
// Data fields.
/*!
* \brief The pointer to the head of the data
* \sa data_alignment The alignment of data in bytes.
*/
Var data;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The type of the buffer prior to flattening
*
* This contains the shape as it is accessed by
* BufferLoad/BufferStore nodes, and used by the low-level code
* generators.
*/
Array<PrimExpr> shape;
/*!
* \brief Separators between input axes when generating flattened output axes
*
* For buffers representing flat 1-d memory (e.g. any buffer in
* RAM), this should be an empty array. For buffers representing
* non-flat memory, each entry in axis_separators should be the
* first input axis that is part of a new flattened axis.
*/
Array<IntImm> axis_separators;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
Array<PrimExpr> strides;
/*! \brief The offset in terms of number of dtype elements (including lanes) */
PrimExpr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
String name;
/*! \brief Alignment requirement of data pointer in bytes. */
int data_alignment;
/*!
* \brief Factor of elem_offset field,
* elem_offset is guaranteed to be multiple of offset_factor.
*/
int offset_factor;
/*! \brief buffer type */
BufferType buffer_type;
/*!
* \brief Span that points to the original source code.
* Reserved debug information.
*/
mutable Span span;
/*! \brief constructor */
BufferNode() {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("axis_separators", &axis_separators);
v->Visit("elem_offset", &elem_offset);
v->Visit("name", &name);
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("buffer_type", &buffer_type);
v->Visit("span", &span);
}
bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
// Use DefEqual as buffer can define variables in its semantics,
// skip name as name is not important.
return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) &&
equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) &&
equal.DefEqual(axis_separators, other->axis_separators) &&
equal.DefEqual(elem_offset, other->elem_offset) &&
equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(data);
hash_reduce(dtype);
hash_reduce.DefHash(shape);
hash_reduce.DefHash(strides);
hash_reduce.DefHash(elem_offset);
hash_reduce.DefHash(axis_separators);
hash_reduce(data_alignment);
hash_reduce(buffer_type);
}
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}
/*! \brief Determine the offset in the buffer of the given index.
*
* Returns the buffer offset, in number of elements of type dtype,
* without adjusting for number of lanes. (e.g. The number of
* float16x4 elements in a buffer of type float16x4.)
*/
Array<PrimExpr> ElemOffset(Array<PrimExpr> index) const;
static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
};
/*!
* \brief Buffer is a symbolic n-darray structure.
* It is a composition of primitive symbolic types,
* used to specify the memory layout of the Tensor used in program input.
*/
class Buffer : public ObjectRef {
public:
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, int data_alignment, int offset_factor,
BufferType buffer_type, Array<IntImm> axis_separators = {}, Span span = Span());
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
* \return The strided version of the buffer.
*/
TVM_DLL Buffer MakeStrideView() const;
/*!
* \brief Make a new symbolic buffer representing a slice of the buffer.
* \param begins The beginning position of each dimension.
* \param extents The extent of each dimension.
* \note This function will make target buffer as compact as possible.
* If stride is not needed in the slice, it won't be presented
* \return the result buffer.
*/
TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
/*!
* \brief Get access ptr to the entire buffer.
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
* \param input_extent The extent of ptr.
*/
TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0),
Optional<PrimExpr> input_extent = NullOpt) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
*/
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
/*!
* \brief Get a flattened version of the buffer
*/
Buffer GetFlattenedBuffer() const;
/*! \brief Determine the offset in the buffer of the given index.
*
* Returns the buffer offset, in number of elements of type dtype,
* without adjusting for number of lanes. (e.g. The number of
* float16x4 elements in a buffer of type float16x4.)
*/
Array<PrimExpr> OffsetOf(Array<PrimExpr> index) const;
/*!
* \brief Return the storage scope associated with this buffer.
*/
TVM_DLL String scope() const;
TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};
/*!
* \brief Construct a new buffer given shape, and dtype.
* \param shape The shape of the buffer,
* \param dtype The content data type.
* \param name The name of the buffer
* \param storage_scope The storage scope associated with this buffer
* \param axis_separators Divisions defining the groups of axes that will be flattened together.
* \param span The location of this object in the source code.
* \return The created buffer.
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", String storage_scope = "",
Array<IntImm> axis_separators = {}, Span span = Span());
/*!
* \brief Base node for data producers.
*
* A DataProducer stores necessary information(e.g. a tensor expression) to produce
* a multi-dimensional array. The stored information is opaque to the TIR.
* DataProducer can appear in high-level DSLs that are built on top of the TIR.
*
* A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower
* all DataProducers to Buffers before TIR transformations.
*
* \sa tvm::te::Tensor
*/
class DataProducerNode : public Object {
public:
/*! \brief destructor. */
virtual ~DataProducerNode() {}
/*!
* \brief Get the shape of the result.
* \return The shape.
*/
virtual Array<PrimExpr> GetShape() const = 0;
/*!
* \brief Get the data type of the result.
* \return The data type.
*/
virtual DataType GetDataType() const = 0;
/*!
* \brief Get the name hint of the data producer.
* \return The data type.
*/
virtual String GetNameHint() const = 0;
bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const {
// because buffer producer is opaque, we just do pointer equality.
return this == other;
}
void SHashReduce(SHashReducer hash_reduce) const {}
static constexpr const char* _type_key = "tir.DataProducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
};
/*!
* \brief Managed reference to DataProducerNode.
* \sa DataProducerNode
*/
class DataProducer : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode);
};
/*!
* \brief Creates TIR Buffer for provided parameters
* \param shape shape of the buffer
* \param dtype data type
* \param name buffer name
* \param data_alignment alignment requirement of data pointer in bytes
* \param offset_factor Factor of elem_offset field, elem_offset is guaranteed to be
* multiple of offset_factor
User can specify data_alignment and offset_factor to be 0
* A default value will be picked.
* \param compact If the statement has already bound to a compact buffer.
* \param memory_scope memory scope of the buffer
*/
TVM_DLL tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype,
std::string name, int data_alignment,
int offset_factor, bool compact,
std::string memory_scope = "");
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_BUFFER_H_