blob: bc777db55dbe655810b02440d74b91bcdb294316 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/lang/data_layout.cc
* \brief Data Layout expression.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/stmt_functor.h>
#include <cctype>
namespace tvm {
namespace tir {
using tir::IterVar;
using tir::IterVarNode;
using tir::Var;
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode);
const LayoutAxis LayoutAxis::UPPER_CASE[] = {
LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'),
LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'),
LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'),
LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'),
LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'),
LayoutAxis('Z')};
const LayoutAxis LayoutAxis::LOWER_CASE[] = {
LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'),
LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'),
LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'),
LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'),
LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'),
LayoutAxis('z')};
const LayoutAxis& LayoutAxis::Get(const char name) {
CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z'))
<< "Invalid layout axis name: " << name << ". Has to be A-Z or a-z.";
return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A']
: LayoutAxis::LOWER_CASE[name - 'a'];
}
const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
const std::string axis = itvar->var.get()->name_hint;
CHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis;
return LayoutAxis::Get(axis[0]);
}
const LayoutAxis& LayoutAxis::Get(const std::string& name) {
CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
return LayoutAxis::Get(name[0]);
}
Layout::Layout(const Array<IterVar>& axes) {
auto node = make_object<LayoutNode>();
node->axes = axes;
std::ostringstream repr;
for (const IterVar& axis : axes) {
if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
CHECK_GT(factor->value, 0);
repr << factor->value;
}
CHECK_EQ(axis->var.get()->name_hint.size(), 1)
<< "Invalid layout axis " << axis->var.get()->name_hint;
char c = axis->var.get()->name_hint.operator std::string()[0];
CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c;
repr << axis->var.get()->name_hint;
}
node->name = repr.str();
data_ = std::move(node);
}
Layout::Layout(const std::string& name) { // NOLINT(*)
if (name == "__undef__") return;
auto node = make_object<LayoutNode>();
node->name = name;
if (name.empty()) return; // scalar
// parse layout string
int32_t factor = 0;
for (char c : name) {
if (c >= 'A' && c <= 'Z') {
CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
IterVar axis =
IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " for dimension " << c;
IterVar axis =
IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << name;
}
}
// validate layout
std::vector<bool> exist_axis(256, false);
for (const IterVar& v : node->axes) {
auto axis_str = v->var.get()->name_hint.operator std::string();
CHECK_EQ(axis_str.size(), 1);
char axis = axis_str[0];
CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z'));
CHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis;
exist_axis[axis] = true;
}
for (const IterVar& v : node->axes) {
char axis = v->var.get()->name_hint.operator std::string()[0];
if (axis >= 'a' && axis <= 'z') {
CHECK(exist_axis[axis - 'a' + 'A'])
<< "Invalid layout " << name << ": missing axis " << std::toupper(axis);
}
}
data_ = std::move(node);
}
Layout Layout::SubLayout(size_t pos, size_t len) const {
if (!defined() || pos > ndim()) return Layout::Undef();
if (len == 0) return Layout(Array<IterVar>());
if (pos + len > ndim()) len = ndim() - pos;
Array<IterVar> new_layout;
const auto axes = operator->()->axes;
for (size_t i = pos; i < pos + len; ++i) {
new_layout.push_back(axes[i]);
}
return Layout(new_layout);
}
Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const {
if (!defined()) return Layout::Undef();
const std::string& name = operator->()->name;
const auto axes = operator->()->axes;
CHECK(target_pos <= this->ndim())
<< "Invalid split position " << target_pos << " for layout " << name;
CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis;
CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name;
CHECK(!this->Contains(axis.ToSubordinate()))
<< "Axis " << axis << " has already been split in " << name;
CHECK(factor > 0) << "Invalid split size " << factor;
Array<IterVar> new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)),
Var(axis.ToSubordinate().name()), tir::kDataPar));
}
if (i == this->ndim()) break;
new_layout.push_back(axes[i]);
}
return Layout(new_layout);
}
int32_t Layout::FactorOf(const LayoutAxis& axis) const {
if (!defined()) return -1;
const LayoutAxis& sub = axis.ToSubordinate();
if (!this->defined()) return -1;
for (const IterVar& itvar : operator->()->axes) {
if (sub == LayoutAxis::Get(itvar)) {
const auto* factor = itvar->dom->extent.as<IntImmNode>();
CHECK(factor);
return factor->value;
}
}
return -1;
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
inline bool GetStoreRule(Array<PrimExpr>* rule, const Layout& src_layout,
const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
dst_layout.name().empty()) {
return false;
}
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
PrimExpr store(0);
for (size_t j = 0; j < src_layout.ndim(); ++j) {
const auto& orig_axis = src_layout[j];
const IterVar& orig_axis_impl = src_layout->axes[j];
if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
if (orig_axis.IsPrimal()) {
PrimExpr orig_var = orig_axis_impl->var;
const int32_t factor = src_layout.FactorOf(orig_axis);
if (factor > 0) {
orig_var = orig_var * PrimExpr(factor);
}
store = store + orig_var;
} else {
store = store + orig_axis_impl->var;
}
}
}
if (tir::is_zero(store)) {
// Not convertible
return false;
}
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
store = indexdiv(store, PrimExpr(factor));
}
} else {
store = indexmod(store, store_axis_impl->dom->extent);
}
rule->push_back(store);
}
return true;
}
inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index,
const Array<IterVar>& src_axis,
const Array<PrimExpr>& transform_rule) {
arith::Analyzer ana;
Array<PrimExpr> result;
std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
for (PrimExpr rule : transform_rule) {
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
}
return result;
}
Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(src_index.size(), self->src_layout->axes.size())
<< "Input mismatch with layout " << self->src_layout;
return TransformIndex(src_index, self->src_layout->axes, self->forward_rule);
}
Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
<< "Output mismatch with layout " << self->dst_layout;
return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
}
inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
const Array<IterVar>& src_axis,
const Array<IterVar>& target_axis,
const Array<PrimExpr>& transform_rule) {
arith::Analyzer ana;
CHECK_EQ(src_shape.size(), src_axis.size());
// bind variables for original axes
// for major-axis, bind the corresponding size
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
std::unordered_set<size_t> symbolic_var_set;
for (size_t i = 0; i < src_shape.size(); ++i) {
PrimExpr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
if (orig_shape.as<tir::AnyNode>()) {
symbolic_var_set.insert(i);
}
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
const auto* orig_shape_const = orig_shape.as<IntImmNode>();
const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
if (orig_shape_const) {
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent
<< ", get " << orig_shape;
}
}
bind_map[orig_axis->var.get()] = PrimExpr(0);
} else {
bind_map[orig_axis->var.get()] = orig_shape;
}
}
// infer the target shape,
// for major-axis, use the forward/backward_rule directly,
// for minor-axis, simply use the extent.
Array<PrimExpr> result;
CHECK_EQ(transform_rule.size(), target_axis.size());
for (size_t i = 0; i < transform_rule.size(); ++i) {
PrimExpr rule = transform_rule[i];
IterVar axis = target_axis[i];
if (!LayoutAxis::Get(axis).IsPrimal()) {
result.push_back(axis->dom->extent);
} else {
if (symbolic_var_set.count(i)) {
result.push_back(tir::Any());
} else {
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
}
}
}
return result;
}
Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule);
}
Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule);
}
BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
auto n = make_object<BijectiveLayoutNode>();
n->src_layout = std::move(src_layout);
n->dst_layout = std::move(dst_layout);
// To be consistent with previous behavior, a nullptr layout is created
// when argument is invalid.
if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
data_ = std::move(n);
}
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name()
<< ")";
});
TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::Get(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::Get(axis));
});
TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
return BijectiveLayout(src_layout, dst_layout);
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex")
.set_body_method(&BijectiveLayout::BackwardIndex);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape")
.set_body_method(&BijectiveLayout::ForwardShape);
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape")
.set_body_method(&BijectiveLayout::BackwardShape);
} // namespace tir
} // namespace tvm