| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file buffer.cc |
| */ |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/runtime/device_api.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/analysis.h> |
| #include <tvm/tir/buffer.h> |
| #include <tvm/tir/builtin.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/op.h> |
| |
| #include <iterator> |
| #include <stack> |
| |
| namespace tvm { |
| namespace tir { |
| |
| using IndexMod = tir::FloorModNode; |
| using IndexDiv = tir::FloorDivNode; |
| |
| Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) { |
| for (size_t i = 0; i < array.size(); ++i) { |
| array.Set(i, ana->Simplify(array[i])); |
| } |
| return array; |
| } |
| |
| Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name) { |
| return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array<PrimExpr>(), |
| PrimExpr(), name, "", 0, 0, kDefault); |
| } |
| |
| // Split the given expression w.r.t the add operator |
| inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) { |
| using namespace tir; |
| std::vector<const PrimExpr*> ret; |
| std::stack<const PrimExpr*> split_buffer; |
| split_buffer.push(&expr); |
| while (!split_buffer.empty()) { |
| const PrimExpr* top_ele = split_buffer.top(); |
| split_buffer.pop(); |
| auto expr_add_match = top_ele->as<AddNode>(); |
| if (expr_add_match) { |
| split_buffer.push(&expr_add_match->b); |
| split_buffer.push(&expr_add_match->a); |
| } else { |
| ret.emplace_back(top_ele); |
| } |
| } |
| return ret; |
| } |
| |
// Try to merge one Mul-shaped summand with one FloorMod summand.
//
// Searches for the following shape of expressions:
//   mult_expr  = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1) * kt * ... * ki
//   mod_l_expr = c
//   mod_r_expr = k1 * k2 * ... * ki
// If it matches, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c). This is valid
// because IndexDiv/IndexMod are the floor div/mod nodes, for which
// (c / K) * K + c % K == c.
// Otherwise returns (false, PrimExpr()).
//
// Matching is structural (ExprDeepEqual), and the add/mult combinations are deliberately
// not searched exhaustively because that would take too much computation. In particular,
// the div term must be reachable by repeatedly taking the right operand of Add nodes,
// and no Add may appear below a Mul inside the parenthesis.
inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr& mult_expr,
                                                  const PrimExpr& mod_l_expr,
                                                  const PrimExpr& mod_r_expr) {
  using namespace tir;
  const MulNode* mult_ptr = mult_expr.as<MulNode>();
  if (!mult_ptr) return std::make_pair(false, PrimExpr());
  PrimExpr mult_outer = mult_ptr->b;
  const PrimExpr* inner = &(mult_ptr->a);
  // 1. Calculate the outer multiplier: walk down the left operands of the Mul chain,
  //    folding every right-hand factor into mult_outer. When j == 0 (no Add), this
  //    also absorbs k1 ... kt-1 and stops at the div node itself.
  while (true) {
    mult_ptr = inner->as<MulNode>();
    if (mult_ptr) {
      inner = &(mult_ptr->a);
      mult_outer = mult_ptr->b * mult_outer;
    } else {
      break;
    }
  }
  // 2. Search for the pattern c / (...) * (...) + c % (...)
  // We match the search element with Add, Mul and Div.
  // If Add is found, its lhs is accumulated into no_opt_sum and the search continues
  // in its rhs. An Add found after a Mul has been seen aborts the match.
  // If Mul is found, we fold its rhs into the inner multiplication factor and continue
  // in its lhs.
  // If Div is found, we test whether its divisor equals the overall multiplier (and
  // the mod's rhs) and its lhs equals the lhs of the mod expr, and return the result.
  const PrimExpr* search_ptr = inner;
  PrimExpr mult_inner;  // The inner multiplication factor (k1 * ... * kt-1), if any
  PrimExpr no_opt_sum;  // Sum of the exprs that cannot be optimized (a1 + ... + aj)
  tir::ExprDeepEqual expr_equal;

  while (true) {
    auto inner_div_ptr = search_ptr->as<IndexDiv>();
    auto inner_mult_ptr = search_ptr->as<MulNode>();
    auto inner_add_ptr = search_ptr->as<AddNode>();
    if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
      return std::make_pair(false, PrimExpr());
    } else if (inner_div_ptr) {
      PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
      if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) &&
          expr_equal(inner_div_ptr->a, mod_l_expr)) {
        // Found! (S + c / K * k_inner) * k_outer + c % K == S * k_outer + c
        PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
        return std::make_pair(true, ret);
      } else {
        return std::make_pair(false, PrimExpr());
      }
    } else if (inner_mult_ptr) {
      mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
      search_ptr = &(inner_mult_ptr->a);
    } else if (inner_add_ptr) {
      if (mult_inner.get()) {
        return std::make_pair(false, PrimExpr());
      }
      no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
      search_ptr = &(inner_add_ptr->b);
    } else {
      // Unreachable: the first branch already handles "none of Div/Mul/Add".
      LOG(FATAL) << "Unexpected search result!";
      break;
    }
  }
  return std::make_pair(false, PrimExpr());
}
| |
| // Insert the elements into the corresponding mult_exprs and mod_exprs. |
| // If the element is found to match Mul, it will be pushed to the mult_exprs. |
| // If the element it found to match Mod, it will be pused to the mod_exprs. |
| // Otherwise, the elements will be added to the no_opt_sum variable |
| inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles, |
| std::list<PrimExpr>* mult_exprs, |
| std::list<std::pair<PrimExpr, PrimExpr> >* mod_exprs, |
| PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { |
| using namespace tir; |
| *has_mult = false; |
| *has_mod = false; |
| for (const PrimExpr* ele : eles) { |
| auto mod_ptr = ele->as<IndexMod>(); |
| auto mult_ptr = ele->as<MulNode>(); |
| if (mod_ptr) { |
| *has_mod = true; |
| mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); |
| } else if (mult_ptr) { |
| *has_mult = true; |
| mult_exprs->emplace_back(*ele); |
| } else { |
| *no_opt_sum = no_opt_sum->get() ? *no_opt_sum + *ele : *ele; |
| } |
| } |
| } |
| |
| // Searches for this types of expr: |
| // (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki |
| // + c % (k1 * k2 * ... * ki) |
| // and simplifies to (a1 + a2 + ... + aj) * kt * ... * ki + c |
| // The search will be performed repeatively until no pattern is found. |
| // Return: a pair with (false, Expr()) if cannot be optimized. |
| // a pair with (true, optimized_expr) if can be optimized |
| inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { |
| using namespace tir; |
| // 1. Prepare the lists. |
| // We store two lists, a list that contain all the elements that match Mul and |
| // a list that contain all the elements that match Mod. |
| // The elements in the Mod will be used to match against the elements in Mul. |
| // The result will then be split and pushed back to these two lists. |
| PrimExpr simplified_base = analyzer->Simplify(base); |
| std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base); |
| std::list<PrimExpr> mult_exprs; |
| std::list<std::pair<PrimExpr, PrimExpr> > mod_exprs; |
| PrimExpr no_opt_sum; |
| bool has_mult; |
| bool has_mod; |
| MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); |
| bool find_opt = false; |
| std::list<std::pair<PrimExpr, PrimExpr> >::iterator search_mod_it = mod_exprs.begin(); |
| // 2. Exhaustive Search |
| while (search_mod_it != mod_exprs.end()) { |
| std::list<PrimExpr>::iterator mult_it = mult_exprs.begin(); |
| bool inner_find_opt = false; |
| while (mult_it != mult_exprs.end()) { |
| std::pair<bool, PrimExpr> ret = |
| MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); |
| if (ret.first) { |
| inner_find_opt = true; |
| auto temp_mod_it = search_mod_it; |
| ++search_mod_it; |
| mod_exprs.erase(temp_mod_it); |
| mult_exprs.erase(mult_it); |
| std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second); |
| MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, |
| &has_mod); |
| if (has_mult) { |
| search_mod_it = mod_exprs.begin(); |
| } else if (has_mod && search_mod_it == mod_exprs.end()) { |
| search_mod_it--; |
| } |
| break; |
| } else { |
| ++mult_it; |
| } |
| } |
| find_opt = find_opt || inner_find_opt; |
| if (!inner_find_opt) { |
| ++search_mod_it; |
| } |
| } |
| if (!find_opt) { |
| return simplified_base; |
| } |
| for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) { |
| no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; |
| } |
| for (std::list<std::pair<PrimExpr, PrimExpr> >::iterator it = mod_exprs.begin(); |
| it != mod_exprs.end(); ++it) { |
| no_opt_sum = no_opt_sum.get() ? no_opt_sum + indexmod(it->first, it->second) |
| : indexmod(it->first, it->second); |
| } |
| return no_opt_sum; |
| } |
| |
| // The buffer offset in convention of number of elements of |
| // original data ignoring number of lanes. |
| // We also perform optimization to simplify the indexing expression. |
| inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) { |
| PrimExpr base = n->elem_offset; |
| arith::Analyzer ana; |
| if (n->strides.size() == 0) { |
| // Scalar case |
| if (n->shape.size() == 0 && index.size() == 1) { |
| auto is_int = index[0].as<IntImmNode>(); |
| CHECK(is_int && is_int->value == 0); |
| base = base + index[0]; |
| } else { |
| CHECK_EQ(n->shape.size(), index.size()); |
| if (index.size() > 0) { |
| PrimExpr offset = index[0]; |
| for (size_t i = 1; i < index.size(); ++i) { |
| offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]); |
| } |
| base = base + offset; |
| } |
| } |
| } else { |
| CHECK_EQ(n->strides.size(), index.size()); |
| if (is_zero(base)) { |
| base = MergeMulMod(&ana, index[0] * n->strides[0]); |
| } else { |
| base = MergeMulMod(&ana, base + index[0] * n->strides[0]); |
| } |
| for (size_t i = 1; i < index.size(); ++i) { |
| base = MergeMulMod(&ana, base + index[i] * n->strides[i]); |
| } |
| } |
| return base; |
| } |
| |
| inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) { |
| PrimExpr offset = ElemOffset(n, index); |
| if (n->dtype.lanes() != 1) { |
| offset = offset * make_const(offset.dtype(), dtype.lanes()); |
| } |
| if (dtype.lanes() != 1) { |
| return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); |
| } else { |
| return offset; |
| } |
| } |
| |
| PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType dtype) const { |
| // specially handle bool, stored as DataType::Int(8) |
| const BufferNode* n = operator->(); |
| CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) |
| << "Cannot load " << dtype << " from buffer of " << n->dtype; |
| if (dtype == DataType::Bool()) { |
| return tir::Cast(DataType::Bool(), |
| tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), |
| const_true())); |
| } else { |
| return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); |
| } |
| } |
| |
| Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const { |
| // specially handle bool, stored as DataType::Int(8) |
| const BufferNode* n = operator->(); |
| DataType dtype = value.dtype(); |
| CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) |
| << "Cannot store " << dtype << " to buffer of " << n->dtype; |
| if (value.dtype() == DataType::Bool()) { |
| return tir::Store(n->data, tir::Cast(DataType::Int(8), value), |
| BufferOffset(n, begin, DataType::Int(8)), const_true()); |
| } else { |
| return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); |
| } |
| } |
| |
| Buffer Buffer::MakeStrideView() const { |
| if ((*this)->strides.size() != 0) return *this; |
| if ((*this)->shape.size() == 0) return *this; |
| std::vector<PrimExpr> temp; |
| auto n = make_object<BufferNode>(*operator->()); |
| PrimExpr acc = make_const(n->DefaultIndexType(), 1); |
| for (size_t i = n->shape.size(); i != 0; --i) { |
| temp.push_back(acc); |
| acc = acc * n->shape[i - 1]; |
| } |
| for (size_t i = temp.size(); i != 0; --i) { |
| n->strides.push_back(temp[i - 1]); |
| } |
| return Buffer(n); |
| } |
| |
| Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const { |
| const BufferNode* n = operator->(); |
| arith::Analyzer ana; |
| begins = SimplifyArray(&ana, begins); |
| PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins)); |
| Array<PrimExpr> strides = n->strides; |
| if (strides.size() == 0) { |
| bool can_relax = true; |
| bool need_stride = false; |
| // check if stride is needed. |
| for (size_t i = 0; i < extents.size(); ++i) { |
| if (!can_relax) { |
| if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { |
| need_stride = true; |
| } |
| } |
| if (!is_one(extents[i])) can_relax = false; |
| } |
| // make stride. |
| if (need_stride) { |
| return MakeStrideView().MakeSlice(begins, extents); |
| } |
| } |
| return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, |
| n->data_alignment, 0, n->buffer_type); |
| } |
| |
| PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, |
| PrimExpr offset) const { |
| const BufferNode* self = operator->(); |
| PrimExpr e_dtype; |
| PrimExpr extent; |
| if (self->shape.size() == 0) { |
| extent = make_const(self->DefaultIndexType(), 1); |
| } else if (self->strides.size() == self->shape.size()) { |
| int highest_dim = 0; |
| extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; |
| } else { |
| auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; |
| extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset; |
| } |
| PrimExpr elem_offset = self->elem_offset + offset; |
| if (content_lanes > 1) { |
| e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); |
| extent = extent / make_const(self->elem_offset.dtype(), content_lanes); |
| elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); |
| } else { |
| e_dtype = tir::TypeAnnotation(self->dtype); |
| } |
| Array<PrimExpr> acc_args{e_dtype, self->data, elem_offset, extent, |
| make_const(DataType::Int(32), access_mask)}; |
| return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); |
| } |
| |
| Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides, |
| PrimExpr elem_offset, String name, String scope, int data_alignment, |
| int offset_factor, BufferType buffer_type) { |
| CHECK(IsPointerType(data->type_annotation, dtype)) |
| << "Buffer data field expect to have the right pointer type annotation" |
| << " annotation=" << data->type_annotation << ", dtype=" << dtype; |
| |
| auto n = make_object<BufferNode>(); |
| n->data = std::move(data); |
| n->dtype = dtype; |
| |
| n->shape = std::move(shape); |
| n->strides = std::move(strides); |
| n->name = std::move(name); |
| if (scope.length() == 0) { |
| scope = "global"; |
| } |
| n->scope = std::move(scope); |
| if (!elem_offset.defined()) { |
| elem_offset = make_const(n->DefaultIndexType(), 0); |
| } |
| if (data_alignment <= 0) { |
| data_alignment = runtime::kAllocAlignment; |
| } |
| if (offset_factor == 0) { |
| offset_factor = 1; |
| } |
| n->elem_offset = std::move(elem_offset); |
| n->data_alignment = data_alignment; |
| n->offset_factor = offset_factor; |
| n->buffer_type = buffer_type; |
| if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { |
| for (size_t i = 0; i < n->shape.size(); ++i) { |
| n->strides.push_back(Var("stride", n->shape[i].dtype())); |
| } |
| } |
| data_ = std::move(n); |
| } |
| |
| TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
| .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) { |
| auto* op = static_cast<const BufferNode*>(node.get()); |
| p->stream << "buffer(" << op->name << ", " << op << ")"; |
| }); |
| |
| TVM_REGISTER_NODE_TYPE(BufferNode); |
| |
| TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { |
| CHECK_EQ(args.size(), 10); |
| auto buffer_type = args[9].operator String(); |
| BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; |
| *ret = |
| Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type); |
| }); |
| |
| TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); |
| |
| TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); |
| |
| TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); |
| |
| } // namespace tir |
| } // namespace tvm |