blob: 87b9a2628cc71d3234bc72ef986e4cdc9ee39248 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file buffer.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <iterator>
#include <stack>
#include "../../arith/pattern_match.h"
namespace tvm {
namespace tir {
// Register BufferNode's reflection metadata once at static-initialization time.
TVM_FFI_STATIC_INIT_BLOCK() { BufferNode::RegisterReflection(); }
// Buffer index arithmetic uses floor-based division/modulo semantics.
using IndexMod = tir::FloorModNode;
using IndexDiv = tir::FloorDivNode;
// Run the analyzer's simplifier over every element of the array in place
// (copy-on-write) and return the simplified array.
ffi::Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, ffi::Array<PrimExpr> array) {
  const size_t num_elems = array.size();
  size_t idx = 0;
  while (idx < num_elems) {
    array.Set(idx, ana->Simplify(array[idx]));
    ++idx;
  }
  return array;
}
// Declare a fresh buffer of the given shape/dtype in the given storage scope.
// Boolean buffers are physically backed by int8 storage; the buffer itself
// keeps the bool dtype.
Buffer decl_buffer(ffi::Array<PrimExpr> shape, DataType dtype, ffi::String name,
                   ffi::String storage_scope, ffi::Optional<ffi::Array<IntImm>> axis_separators,
                   Span span) {
  const bool is_bool = (dtype == DataType::Bool());
  DataType storage_dtype = is_bool ? DataType::Int(8) : dtype;
  // The backing pointer variable carries the storage dtype and scope.
  Var data(name, PointerType(PrimType(storage_dtype), storage_scope), span);
  ffi::Array<IntImm> separators = axis_separators.value_or(ffi::Array<IntImm>());
  return Buffer(data, dtype, shape, ffi::Array<PrimExpr>(), PrimExpr(), name, 0, 0, kDefault,
                separators, span);
}
// Split the given expression w.r.t the add operator
inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
using namespace tir;
std::vector<const PrimExpr*> ret;
std::stack<const PrimExpr*> split_buffer;
split_buffer.push(&expr);
while (!split_buffer.empty()) {
const PrimExpr* top_ele = split_buffer.top();
split_buffer.pop();
auto expr_add_match = top_ele->as<AddNode>();
if (expr_add_match) {
split_buffer.push(&expr_add_match->b);
split_buffer.push(&expr_add_match->a);
} else {
ret.emplace_back(top_ele);
}
}
return ret;
}
// Searches for the following types of expr:
//   mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   mod_l_expr = c2
//   mod_r_expr = k1 * k2 * ... * ki
// where c1 ~= c2 mod k1 * k2 * ... * ki
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1)
// Currently we will not search the add/mult combinations exhaustively
// as it will take too much computation.
inline std::pair<bool, PrimExpr> MergeMulModInner(arith::Analyzer* analyzer,
                                                  const PrimExpr& mult_expr,
                                                  const PrimExpr& mod_l_expr,
                                                  const PrimExpr& mod_r_expr) {
  using namespace tir;
  const MulNode* mult_ptr = mult_expr.as<MulNode>();
  if (!mult_ptr) return std::make_pair(false, PrimExpr());
  PrimExpr mult_outer = mult_ptr->b;
  const PrimExpr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier by peeling nested Mul nodes:
  //    ((inner * ka) * kb) -> inner, with mult_outer = ka * kb.
  while (true) {
    mult_ptr = inner->as<MulNode>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  // If Add is found, we need to continue our search for the rhs
  // If Mult is found, we will expand the inner multiplication factor
  // If Div is found, we will go on testing whether lhs matches the lhs of mod expr
  // and returns the optimization result.
  const PrimExpr* search_ptr = inner;
  PrimExpr mult_inner;  // The inner multiplication factor
  PrimExpr no_opt_sum;  // Sum of the exprs that cannot be optimized
  tir::ExprDeepEqual expr_equal;
  while (true) {
    auto inner_div_ptr = search_ptr->as<IndexDiv>();
    auto inner_mult_ptr = search_ptr->as<MulNode>();
    auto inner_add_ptr = search_ptr->as<AddNode>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      // Leaf expression that matches none of the patterns: no merge possible.
      return std::make_pair(false, PrimExpr());
    } else if (inner_div_ptr) {
      // The merge fires only when the divisor, the accumulated multiplier,
      // and the modulus all coincide, and the div and mod numerators are
      // provably congruent modulo mod_r_expr.
      PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
          analyzer->CanProveEqual(floormod(inner_div_ptr->a - mod_l_expr, mod_r_expr), 0)) {
        // Found!
        PrimExpr ret =
            no_opt_sum.get() ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, PrimExpr());
      }
    } else if (inner_mult_ptr) {
      // Accumulate inner multiplication factors while descending into the lhs.
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      // Additions may only be skipped before any Mul has been seen; the
      // skipped lhs terms accumulate into no_opt_sum.
      if (mult_inner.get()) {
        return std::make_pair(false, PrimExpr());
      }
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, PrimExpr());
}
// Insert the elements into the corresponding mult_exprs and mod_exprs.
// If an element is found to match Mul, it is pushed to mult_exprs.
// If an element is found to match Mod, it is pushed to mod_exprs.
// Otherwise, the element is folded into the no_opt_sum accumulator.
// has_mult / has_mod report whether any Mul / Mod element was seen.
inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
                                      std::list<PrimExpr>* mult_exprs,
                                      std::list<std::pair<PrimExpr, PrimExpr>>* mod_exprs,
                                      PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) {
  using namespace tir;
  *has_mult = false;
  *has_mod = false;
  for (const PrimExpr* ele : eles) {
    auto mod_ptr = ele->as<IndexMod>();
    auto mult_ptr = ele->as<MulNode>();
    if (mod_ptr) {
      *has_mod = true;
      // The node's fields are const, so std::move would silently copy anyway;
      // construct the (lhs, rhs) pair in place instead.
      mod_exprs->emplace_back(mod_ptr->a, mod_ptr->b);
    } else if (mult_ptr) {
      *has_mult = true;
      mult_exprs->emplace_back(*ele);
    } else {
      *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele;
    }
  }
}
// Searches for this types of expr:
//   (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
//   + c % (k1 * k2 * ... * ki)
// and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c
// The search will be performed repeatedly until no pattern is found.
// Returns the simplified base expression unchanged when no pattern is found,
// otherwise the merged expression.
inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
  using namespace tir;
  // 1. Prepare the lists.
  // We store two lists, a list that contain all the elements that match Mul and
  // a list that contain all the elements that match Mod.
  // The elements in the Mod will be used to match against the elements in Mul.
  // The result will then be split and pushed back to these two lists.
  PrimExpr simplified_base = base;
  arith::PVar<PrimExpr> x, y;
  // Fast path: x/y*y + x%y reduces directly to x.
  if ((floordiv(x, y) * y + floormod(x, y)).Match(simplified_base)) {
    simplified_base = x.Eval();
  }
  simplified_base = analyzer->Simplify(simplified_base);
  std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
  std::list<PrimExpr> mult_exprs;
  std::list<std::pair<PrimExpr, PrimExpr>> mod_exprs;
  PrimExpr no_opt_sum;
  bool has_mult;
  bool has_mod;
  MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod);
  bool find_opt = false;
  std::list<std::pair<PrimExpr, PrimExpr>>::iterator search_mod_it = mod_exprs.begin();
  // 2. Exhaustive Search: try to merge every Mod element with some Mul element.
  while (search_mod_it != mod_exprs.end()) {
    std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
    bool inner_find_opt = false;
    while (mult_it != mult_exprs.end()) {
      std::pair<bool, PrimExpr> ret =
          MergeMulModInner(analyzer, *mult_it, search_mod_it->first, search_mod_it->second);
      if (ret.first) {
        inner_find_opt = true;
        // Remove the merged pair; advance search_mod_it before erasing so the
        // iterator stays valid.
        auto temp_mod_it = search_mod_it;
        ++search_mod_it;
        mod_exprs.erase(temp_mod_it);
        mult_exprs.erase(mult_it);
        // The merged result may itself contain new Mul/Mod terms; re-split
        // it and feed the pieces back into the work lists.
        std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
        MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult,
                                  &has_mod);
        if (has_mult) {
          // New Mul terms appeared: restart the scan over the Mod list.
          search_mod_it = mod_exprs.begin();
        } else if (has_mod && search_mod_it == mod_exprs.end()) {
          // A new Mod term was appended at the end; step back to visit it.
          search_mod_it--;
        }
        break;
      } else {
        ++mult_it;
      }
    }
    find_opt = find_opt || inner_find_opt;
    if (!inner_find_opt) {
      ++search_mod_it;
    }
  }
  if (!find_opt) {
    return simplified_base;
  }
  // 3. Re-assemble the unmatched Mul and Mod terms into the final sum.
  for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
  }
  for (std::list<std::pair<PrimExpr, PrimExpr>>::iterator it = mod_exprs.begin();
       it != mod_exprs.end(); ++it) {
    no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second)
                                  : indexmod(it->first, it->second);
  }
  return no_opt_sum;
}
// Compute the flattened element offset(s) of the given indices in this buffer.
ffi::Array<PrimExpr> Buffer::OffsetOf(ffi::Array<PrimExpr> input_indices) const {
  const BufferNode* node = operator->();
  return node->ElemOffset(std::move(input_indices));
}
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
// Returns one offset per output axis (axis_separators partition the input
// axes into output axes; a flat buffer has a single output axis).
ffi::Array<PrimExpr> BufferNode::ElemOffset(ffi::Array<PrimExpr> input_indices) const {
  ICHECK_EQ(shape.size(), input_indices.size())
      << "Buffer " << this->name << " is " << shape.size()
      << "-dimensional, cannot be indexed with the " << input_indices.size()
      << "-dimensional indices provided.";
  if (strides.size()) {
    ICHECK_EQ(this->strides.size(), input_indices.size())
        << "If strides are defined, "
        << "the index's dimensionality must match the dimensionality of the index given.";
  }
  // TODO(Lunderberg): Better handling for cases where there is more
  // than one output index. Currently, this only allows elem_offset
  // to be non-zero for flat memory allocations.
  ffi::Array<PrimExpr> elem_offsets = {};
  if (elem_offset.defined() && !is_zero(elem_offset)) {
    elem_offsets = {elem_offset};
  }
  if (elem_offsets.size()) {
    ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1)
        << "If element offsets are defined, "
        << "there must be one element offset for each output index.";
  }
  // One accumulated index per output axis, initialized to zero.
  ffi::Array<PrimExpr> output_indices(axis_separators.size() + 1, 0);
  size_t current_output_axis = 0;
  arith::Analyzer ana;
  for (size_t i = 0; i < input_indices.size(); i++) {
    // Crossing a separator starts the next output axis.
    if ((current_output_axis < axis_separators.size()) &&
        (i == size_t(axis_separators[current_output_axis]->value))) {
      current_output_axis++;
    }
    PrimExpr output_index = output_indices[current_output_axis];
    if (strides.size()) {
      // Strided layout: accumulate index * stride per axis.
      output_index = output_index + input_indices[i] * strides[i];
    } else {
      // Compact row-major layout: out = out * extent_i + index_i.
      output_index = output_index * this->shape[i] + input_indices[i];
    }
    if (i > 0) {
      // Fold (c / k) * k + c % k patterns produced by the flattening above.
      output_index = MergeMulMod(&ana, output_index);
    }
    output_indices.Set(current_output_axis, output_index);
  }
  if (elem_offsets.size()) {
    for (size_t i = 0; i < output_indices.size(); i++) {
      output_indices.Set(i, output_indices[i] + elem_offsets[i]);
    }
  }
  return SimplifyArray(&ana, output_indices);
}
// Compute the buffer offsets for an access of type `dtype`, scaling for
// multi-lane buffer elements and vectorizing the innermost offset with a
// ramp when `dtype` itself is a vector.
inline ffi::Array<PrimExpr> BufferOffset(const BufferNode* n, ffi::Array<PrimExpr> index,
                                         DataType dtype) {
  ffi::Array<PrimExpr> offsets = n->ElemOffset(index);
  const size_t last = offsets.size() - 1;
  // A multi-lane buffer element type means ElemOffset counted vectors;
  // rescale the innermost offset to count scalars.
  if (n->dtype.lanes() != 1) {
    PrimExpr inner = offsets[last];
    offsets.Set(last, inner * make_const(inner.dtype(), dtype.lanes()));
  }
  // A vector access type turns the innermost offset into a unit-stride ramp.
  if (dtype.lanes() != 1) {
    PrimExpr ramp_base = offsets[last];
    PrimExpr unit = make_const(ramp_base.dtype(), 1);
    offsets.Set(last, tir::Ramp(ramp_base, unit, dtype.lanes()));
  }
  return offsets;
}
// Validate that axis separators are ordered and lie within [0, buffer_dim],
// so that every output axis contains at least one input axis.
static void ValidateAxisSeparators(const ffi::Array<IntImm>& axis_separators, size_t buffer_dim) {
  const size_t num_seps = axis_separators.size();
  // Pairwise ordering of adjacent separators.
  size_t i = 0;
  while (i + 1 < num_seps) {
    auto sep = axis_separators[i]->value;
    auto next_sep = axis_separators[i + 1]->value;
    CHECK_LE(sep, next_sep) << "ValueError: "
                            << "Axis separators must be in increasing order, "
                            << "but axis_separators[" << i << "] = " << sep
                            << " is greater than or equal to axis_separators[" << (i + 1)
                            << "] = " << next_sep << ".";
    ++i;
  }
  if (num_seps != 0) {
    // First separator must be non-negative, last must not exceed the rank.
    auto first_sep = axis_separators[0]->value;
    CHECK_GE(first_sep, 0) << "ValueError: "
                           << "All axis separators must be non-negative. "
                           << "However, the axis_separators[0] = " << first_sep;
    auto last_sep = axis_separators[num_seps - 1]->value;
    CHECK_LE(last_sep, buffer_dim)
        << "ValueError: "
        << "All axis separators must be within the range "
        << "0 <= sep <= buffer_dim. "
        << "However, the last axis_separators[" << (num_seps - 1) << "] = " << last_sep
        << " is greater than the buffer's dimensionality of " << buffer_dim;
  }
}
// Flatten this buffer to one output axis per axis-separator group.
// Returns *this unchanged when the buffer is already flat and stride-free.
Buffer Buffer::GetFlattenedBuffer() const {
  auto self = operator->();
  ValidateAxisSeparators(self->axis_separators, self->shape.size());
  ffi::Array<PrimExpr> output_shape;
  if (self->strides.size()) {
    // If strides are defined, then the extent of each flattened
    // buffer is the stride*size for the first input axis used for
    // each output axis.
    ICHECK_EQ(self->shape.size(), self->strides.size());
    output_shape.push_back(self->strides[0] * self->shape[0]);
    for (const auto& sep : self->axis_separators) {
      output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]);
    }
  } else {
    // Otherwise, the extent of each flattened buffer is the product
    // of the extents of each input axis used to generate that output
    // axis. This also "flattens" rank-0 tensors to a rank-1 buffer
    // of shape [1].
    output_shape = ffi::Array<PrimExpr>(self->axis_separators.size() + 1, 1);
    size_t current_output_index = 0;
    for (size_t i = 0; i < self->shape.size(); i++) {
      // Crossing a separator starts the next output axis.
      if ((current_output_index < self->axis_separators.size()) &&
          (i == size_t(self->axis_separators[current_output_index]->value))) {
        current_output_index += 1;
      }
      output_shape.Set(current_output_index, output_shape[current_output_index] * self->shape[i]);
    }
  }
  // The axis_separators for the output buffer: in the flattened buffer each
  // output axis is a single axis, so separator k sits at position k+1.
  ffi::Array<IntImm> output_axis_separators;
  for (size_t i = 0; i < self->axis_separators.size(); i++) {
    auto dtype = self->axis_separators[i]->dtype;
    output_axis_separators.push_back(IntImm(dtype, i + 1));
  }
  // Identity case: same rank and no strides means flattening changed nothing.
  if (output_shape.size() == self->shape.size() && self->strides.empty()) {
    return *this;
  } else {
    // Copy-on-write a new buffer with the flattened shape and no strides.
    Buffer output = *this;
    auto writer = output.CopyOnWrite();
    writer->shape = output_shape;
    writer->axis_separators = output_axis_separators;
    writer->strides = {};
    return output;
  }
}
// Build a BufferLoad at `begin`, widening the innermost index into a
// unit-stride ramp when a fixed-length vector of the element type is loaded.
PrimExpr Buffer::vload(ffi::Array<PrimExpr> begin, DataType value_dtype,
                       ffi::Optional<PrimExpr> predicate) const {
  // specially handle bool, stored as DataType::Int(8)
  const BufferNode* n = operator->();
  ICHECK(n != nullptr);
  ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
         value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0)
      << "Cannot load " << value_dtype << " from buffer of " << n->dtype;
  ffi::Array<PrimExpr> indices = begin;
  const size_t last = indices.size() - 1;
  PrimExpr base = indices[last];
  if (value_dtype.is_fixed_length_vector()) {
    // Ratio of requested lanes to buffer-element lanes; only a scalar base
    // index needs to be widened into a ramp.
    const int lanes_ratio = value_dtype.lanes() / n->dtype.lanes();
    if (lanes_ratio > 1 && base.dtype().is_scalar()) {
      indices.Set(last, Ramp(base, 1, lanes_ratio));
    }
  }
  return BufferLoad(*this, indices, predicate);
}
// Build a BufferStore of `value` at `begin`, widening the innermost index
// into a unit-stride ramp when a fixed-length vector value is stored.
Stmt Buffer::vstore(ffi::Array<PrimExpr> begin, PrimExpr value,
                    ffi::Optional<PrimExpr> predicate) const {
  // specially handle bool, stored as DataType::Int(8)
  const BufferNode* n = operator->();
  ICHECK(n != nullptr);
  DataType value_dtype = value.dtype();
  ICHECK(value_dtype.element_of() == n->dtype.element_of() &&
         value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0)
      << "Cannot store " << value_dtype << " to buffer of " << n->dtype;
  ffi::Array<PrimExpr> indices = begin;
  const size_t last = indices.size() - 1;
  PrimExpr base = indices[last];
  if (value_dtype.is_fixed_length_vector()) {
    // Ratio of stored lanes to buffer-element lanes; only a scalar base
    // index needs to be widened into a ramp.
    const int lanes_ratio = value_dtype.lanes() / n->dtype.lanes();
    if (lanes_ratio > 1 && base.dtype().is_scalar()) {
      indices.Set(last, Ramp(base, 1, lanes_ratio));
    }
  }
  return BufferStore(*this, value, indices, predicate);
}
// Return the storage scope of the underlying data pointer; an unspecified
// scope defaults to "global".
ffi::String Buffer::scope() const {
  const auto* ptr_type = (*this)->data->type_annotation.as<PointerTypeNode>();
  ICHECK(ptr_type) << "Buffer variable is not of pointer type";
  if (!ptr_type->storage_scope.empty()) {
    return ptr_type->storage_scope;
  }
  return "global";
}
// Return an equivalent buffer with explicit compact row-major strides:
// stride[i] = prod(shape[i+1:]). Buffers that already carry strides, or
// rank-0 buffers, are returned unchanged.
Buffer Buffer::MakeStrideView() const {
  const BufferNode* self = operator->();
  ICHECK(self != nullptr);
  if (self->strides.size() != 0) return *this;
  if (self->shape.size() == 0) return *this;
  auto n = ffi::make_object<BufferNode>(*self);
  // Accumulate strides innermost-first.
  std::vector<PrimExpr> rev_strides;
  PrimExpr running = make_const(n->DefaultIndexType(), 1);
  size_t axis = n->shape.size();
  while (axis != 0) {
    rev_strides.push_back(running);
    running = running * n->shape[axis - 1];
    --axis;
  }
  // Reverse into the outermost-first order stored on the buffer.
  for (size_t k = rev_strides.size(); k != 0; --k) {
    n->strides.push_back(rev_strides[k - 1]);
  }
  return Buffer(n);
}
// Create a view covering `extents` starting at `begins`. Falls back to a
// strided view when the slice is not contiguous in the compact layout.
Buffer Buffer::MakeSlice(ffi::Array<PrimExpr> begins, ffi::Array<PrimExpr> extents) const {
  const BufferNode* n = operator->();
  ICHECK(n != nullptr);
  arith::Analyzer ana;
  begins = SimplifyArray(&ana, begins);
  ffi::Array<PrimExpr> elem_offset =
      n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); });
  ffi::Array<PrimExpr> strides = n->strides;
  if (strides.size() == 0) {
    bool can_relax = true;
    bool need_stride = false;
    // Check if stride is needed. Scanning outermost-to-innermost: trailing
    // unit-extent axes keep the slice contiguous (can_relax), but once a
    // non-unit extent is seen, any later axis that is offset or narrowed
    // makes the compact layout invalid and explicit strides are required.
    for (size_t i = 0; i < extents.size(); ++i) {
      if (!can_relax) {
        if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) {
          need_stride = true;
        }
      }
      if (!is_one(extents[i])) can_relax = false;
    }
    // Make stride by delegating to the explicitly-strided view.
    if (need_stride) {
      return MakeStrideView().MakeSlice(begins, extents);
    }
  }
  Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
               n->data_alignment, 0, n->buffer_type);
  // Buffer must be constructed with a singular element offset which means there is no
  // support for n-dimensional buffers where n > 1. Insert sentinel value for
  // ArgBinder::BindBuffer to state that any usage of element offset is invalid
  // in this case. This allows for construction of a Buffer with multiple element offsets
  // but disallows any usage of those element offsets. See PR #10816 for discussion on
  // supporting multiple element offsets in TIR Buffer.
  // TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in TIR Buffer
  if (elem_offset.size() != 1) {
    slice.CopyOnWrite()->elem_offset = PrimExpr();
  }
  return slice;
}
// Build a tvm_access_ptr intrinsic call describing an access window into
// this buffer: (type annotation, data, elem_offset, extent, access_mask).
// `content_lanes` > 1 reinterprets the window in vector units of that width.
PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset,
                            ffi::Optional<PrimExpr> input_extent) const {
  const BufferNode* self = operator->();
  ICHECK(self != nullptr);
  PrimExpr e_dtype;
  PrimExpr extent;
  if (self->shape.size() == 0) {
    // Rank-0 buffer: a single element.
    extent = make_const(self->DefaultIndexType(), 1);
  } else if (self->strides.size() == self->shape.size()) {
    // Fully strided: the outermost stride * extent bounds the footprint.
    int highest_dim = 0;
    extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
  } else {
    // Compact layout: total number of elements.
    extent = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                   make_const(DataType::Int(32), 1), self->shape) -
             offset;
  }
  PrimExpr elem_offset = self->elem_offset + offset;
  if (content_lanes > 1) {
    // Re-express extent and offset in units of `content_lanes`-wide vectors.
    // NOTE(review): this recomputes elem_offset from self->elem_offset only,
    // discarding the `offset` argument added above — presumably callers pass
    // offset == 0 together with content_lanes > 1; confirm before relying on
    // a nonzero offset here.
    e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
    extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
    elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes);
  } else {
    e_dtype = tir::TypeAnnotation(self->dtype);
  }
  // An explicit extent from the caller overrides the computed one.
  if (input_extent.defined()) {
    extent = input_extent.value();
  }
  ffi::Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent,
                                make_const(DataType::Int(32), access_mask)};
  return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args);
}
// Construct a Buffer, validating the data pointer's type annotation and
// filling defaults: zero elem_offset, kAllocAlignment, offset_factor 1, and
// symbolic strides for kAutoBroadcast buffers.
Buffer::Buffer(Var data, DataType dtype, ffi::Array<PrimExpr> shape, ffi::Array<PrimExpr> strides,
               PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor,
               BufferType buffer_type, ffi::Array<IntImm> axis_separators, Span span) {
  DataType storage_dtype = dtype;
  // specially handle bool
  if (storage_dtype == DataType::Bool()) {
    storage_dtype = DataType::Int(8);
  }
  // The buffer dtype may differ from the dtype of the underlying
  // allocation, such as a single allocation that backs multiple
  // tensors without a common datatype. Therefore, we check that the
  // data pointer is a pointer, but not the exact type of the
  // pointed-to values.
  // TODO(Lunderberg): Use an explicit pointer cast for the data
  // pointer. Should be done alongside extensions to StmtExprMutator
  // to more easily handle buffer/buffer_var updates.
  ICHECK(data->type_annotation.defined())
      << "Variable " << data->name_hint << " is missing a type annotation.";
  ICHECK(data->type_annotation.as<PointerTypeNode>())
      << "Variable " << data->name_hint << " is not a pointer.";
  ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
      << "Variable " << data->name_hint << " does not point to a primitive.";
  ValidateAxisSeparators(axis_separators, shape.size());
  auto n = ffi::make_object<BufferNode>();
  n->data = std::move(data);
  n->dtype = dtype;
  n->shape = std::move(shape);
  n->strides = std::move(strides);
  n->axis_separators = std::move(axis_separators);
  n->name = std::move(name);
  // Default the element offset to zero in the buffer's index type.
  if (!elem_offset.defined()) {
    elem_offset = make_const(n->DefaultIndexType(), 0);
  }
  // Non-positive alignment falls back to the runtime allocation alignment.
  if (data_alignment <= 0) {
    data_alignment = runtime::kAllocAlignment;
  }
  // Zero offset factor means "no constraint", normalized to 1.
  if (offset_factor == 0) {
    offset_factor = 1;
  }
  n->elem_offset = std::move(elem_offset);
  n->data_alignment = data_alignment;
  n->offset_factor = offset_factor;
  n->buffer_type = buffer_type;
  // Auto-broadcast buffers without explicit strides get one symbolic stride
  // variable per axis, to be bound at call time.
  if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
    for (size_t i = 0; i < n->shape.size(); ++i) {
      n->strides.push_back(Var("stride", n->shape[i].dtype()));
    }
  }
  n->span = std::move(span);
  data_ = std::move(n);
}
// Construct a buffer for argument binding: non-compact buffers with any
// symbolic extent become kAutoBroadcast, and a nonzero offset_factor makes
// the element offset a bindable symbolic variable.
tir::Buffer BufferWithOffsetAlignment(ffi::Array<PrimExpr> shape, DataType dtype, std::string name,
                                      int data_alignment, int offset_factor, bool compact,
                                      std::string memory_scope) {
  // Boolean buffers are physically stored as int8.
  DataType storage_dtype = (dtype == DataType::Bool()) ? DataType::Int(8) : dtype;
  auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope));
  // Detect whether any axis extent is a free variable.
  bool has_any = false;
  if (!compact) {
    for (const PrimExpr& dim : shape) {
      if (dim.as<tir::VarNode>() != nullptr) {
        has_any = true;
        break;
      }
    }
  }
  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;
  // A nonzero offset_factor means the element offset is bound at call time.
  PrimExpr elem_offset;
  if (offset_factor != 0) {
    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
  }
  return tir::Buffer(data, dtype, shape, ffi::Array<PrimExpr>(), elem_offset, name, data_alignment,
                     offset_factor, buffer_type);
}
// Expose Buffer construction and member accessors through the FFI registry.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def_packed("tir.Buffer",
                  [](ffi::PackedArgs args, ffi::Any* ret) {
                    // Packed argument order:
                    //   0 data, 1 dtype, 2 shape, 3 strides, 4 elem_offset,
                    //   5 name, 6 data_alignment, 7 offset_factor,
                    //   8 buffer_type (string), 9 axis_separators, 10 span
                    ICHECK_EQ(args.size(), 11);
                    // Any buffer-type string other than "auto_broadcast" maps
                    // to the default buffer type.
                    auto buffer_type = args[8].cast<ffi::String>();
                    BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
                    auto data = args[0].cast<Var>();
                    auto dtype = args[1].cast<DataType>();
                    auto shape = args[2].cast<ffi::Array<PrimExpr>>();
                    auto strides = args[3].cast<ffi::Array<PrimExpr>>();
                    auto elem_offset = args[4].cast<PrimExpr>();
                    auto name = args[5].cast<ffi::String>();
                    auto data_alignment = args[6].cast<int>();
                    auto offset_factor = args[7].cast<int>();
                    auto axis_separators = args[9].cast<ffi::Array<IntImm>>();
                    auto span = args[10].cast<Span>();
                    *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment,
                                  offset_factor, type, axis_separators, span);
                  })
      .def_method("tir.BufferAccessPtr", &Buffer::access_ptr)
      .def_method("tir.BufferGetFlattenedBuffer", &Buffer::GetFlattenedBuffer)
      .def_method("tir.BufferOffsetOf", &Buffer::OffsetOf)
      .def_method("tir.BufferVLoad", &Buffer::vload)
      .def_method("tir.BufferVStore", &Buffer::vstore)
      .def_method("tir.BufferStorageScope", &Buffer::scope);
}
} // namespace tir
} // namespace tvm