blob: 4d6aa88ede01e66909a9440ce8dc5cf69da54cb7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file narrow_datatype.cc
* \brief narrow the datatype of indexing vars
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tir {
// This pass narrows indexing expressions (like StoreNode::Index)
// that trivially fit into i32/i16 (denoted by `target_bits_`) to
// i32/i16. Considering that i32/i16 indices may be more
// efficient on some backends (while i64 may be more efficient
// on others, like llvm), we may want this pass when i32/i16
// indices are more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExpr
// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32/i16, then we think v can be narrowed
// to i32/i16.
//
// To make an indexing expression i32/i16, we must make sure that every
// component of that expression is of dtype i32/i16. So besides Var, we
// rewrite the following inside an indexing expression
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewriter to rewrite the components of an indexing expression.
using arith::Analyzer;
using arith::ConstIntBound;
using arith::IRMutatorWithAnalyzer;
// Determine the result dtype for Var, IntImm and Cast,
// which will be stored in `vmap` eventually.
//
// Algorithm:
// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
// To be more specific, if for each Expr `e` which contains `var`
// (`var` is a child node of `e` in AST), `e` fits into `target_bits_`,
// then we narrow `var` into `target_bits_`. That is,
// `vmap[var] = min(target_bits_, var.dtype.bits())`
// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`
class DataTypeVisitor final : public StmtExprVisitor {
public:
explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {}
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
if (bound_.find(e) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
ConstIntBound bound = bound_[e];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
(bound->max_value <= ubound && bound->min_value >= lbound)) {
bits = target_bits_;
}
int tmp = bits > bits_ ? bits : bits_;
std::swap(bits_, tmp);
StmtExprVisitor::VisitExpr(e);
std::swap(bits_, tmp);
} else {
StmtExprVisitor::VisitExpr(e);
}
}
void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
vextent_[iv->var.as<VarNode>()] = op->value.dtype();
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
}
// Recursively call simplification when necessary.
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const VarNode* op) {
if (vextent_.find(op) != vextent_.end()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(vextent_[op].bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
// We take maximum bits for all the possible Expr where a var occurs
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const IntImmNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const CastNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
// the narrowed datatype of Var and IntImm
std::unordered_map<const PrimExprNode*, DataType> vmap;
protected:
// internal analyzer
arith::Analyzer analyzer_;
private:
// the maximum possible bits, which serves as an init value
static constexpr const int max_bits_ = 64;
// the maximum possible bit of the current expression's return dtype
int bits_;
// the target bits
int target_bits_;
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memorized bound generated by ConstIntBoundAnalyzer
arith::ConstIntBoundAnalyzer::BoundMapType bound_;
};
/*!
 * \brief Rewrite a statement using the dtypes chosen by DataTypeVisitor.
 *
 * Var nodes recorded by the visitor are replaced everywhere; IntImm and Cast
 * nodes are replaced only while an index of a Load/Store is being visited.
 * Binary operators and selected intrinsic calls are rebuilt through the
 * corresponding builder functions after their operands have been rewritten.
 */
class DataTypeRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Construct the rewriter.
   * \param target_bits The bit width forwarded to the internal DataTypeVisitor.
   */
  explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {}

  /*!
   * \brief Analyze and then rewrite a statement.
   * \param s The statement to process.
   * \return The rewritten statement.
   */
  Stmt operator()(Stmt s) {
    visitor_(s);
    // Drop entries whose dtype is unchanged so that only nodes that really
    // need rewriting remain in the map.
    for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
      PrimExpr e = GetRef<PrimExpr>(i->first);
      if (e.dtype() == i->second) {
        i = visitor_.vmap.erase(i);
      } else {
        ++i;
      }
    }
    return VisitStmt(s);
  }

  /*! \brief Rewrite the index of a Store in index mode, then mutate the remaining fields. */
  Stmt VisitStmt_(const StoreNode* op) final {
    // Only the index needs index mode. The value and predicate are mutated by
    // the base-class visit below, so they are not visited separately here.
    PrimExpr index = VisitIndex(op->index);
    Stmt s = Store(op->buffer_var, op->value, index, op->predicate);
    return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
  }

  /*! \brief Rebuild a For so that min and extent are cast to the dtype of the rewritten loop var. */
  Stmt VisitStmt_(const ForNode* op) final {
    Stmt s = StmtExprMutator::VisitStmt_(op);
    op = s.as<ForNode>();
    CHECK(op != nullptr) << "Expected type to be ForNode"
                         << ", but get " << s->GetTypeKey();
    PrimExpr e = VisitExpr(op->loop_var);
    Var var = Downcast<Var>(e);
    return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type,
               op->device_api, op->body);
  }

  /*!
   * \brief For thread_extent / virtual_thread attributes, rebuild the IterVar
   *        around the rewritten var and cast the extent to the var's dtype.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
      Stmt s = StmtExprMutator::VisitStmt_(op);
      op = s.as<AttrStmtNode>();
      CHECK(op != nullptr) << "Expected type to be AttrStmtNode"
                           << ", but get " << s->GetTypeKey();
      const IterVarNode* iv = op->node.as<IterVarNode>();
      CHECK(iv != nullptr) << "Expected type to be IterVarNode"
                           << ", but get " << op->node->GetTypeKey();
      PrimExpr e = VisitExpr(iv->var);
      Var var = Downcast<Var>(e);
      // Reuse the IterVar created for an earlier occurrence of the same node.
      if (ivmap_.find(iv) == ivmap_.end()) {
        ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
      }
      return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  /*! \brief Replace a Var recorded by the visitor with a Var of the chosen dtype. */
  PrimExpr VisitExpr_(const VarNode* op) final {
    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
      // Create the replacement once and reuse it for every later occurrence.
      if (vmap_.find(op) == vmap_.end()) {
        vmap_[op] = Var(op->name_hint, visitor_.vmap[op]);
      }
      return vmap_[op];
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  /*! \brief Same as the Var case, but the replacement is created as a SizeVar. */
  PrimExpr VisitExpr_(const SizeVarNode* op) final {
    if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
      if (vmap_.find(op) == vmap_.end()) {
        vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]);
      }
      return vmap_[op];
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  /*! \brief Rewrite the index of a Load in index mode, then mutate the remaining fields. */
  PrimExpr VisitExpr_(const LoadNode* op) final {
    PrimExpr index = VisitIndex(op->index);
    PrimExpr e = Load(op->dtype, op->buffer_var, index, op->predicate);
    return StmtExprMutator::VisitExpr_(e.as<LoadNode>());
  }

  /*! \brief Inside an index, replace an IntImm recorded by the visitor. */
  PrimExpr VisitExpr_(const IntImmNode* op) final {
    if (is_index_) {
      if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
        return IntImm(visitor_.vmap[op], op->value);
      }
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  /*! \brief Inside an index, rebuild a Cast recorded by the visitor with the chosen dtype. */
  PrimExpr VisitExpr_(const CastNode* op) final {
    if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
      PrimExpr e = StmtExprMutator::VisitExpr_(op);
      const CastNode* new_op = e.as<CastNode>();
      CHECK(new_op != nullptr) << "Expected type to be CastNode"
                               << ", but get " << e->GetTypeKey();
      return Cast(visitor_.vmap[op], new_op->value);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const AddNode* op) final;
  PrimExpr VisitExpr_(const SubNode* op) final;
  PrimExpr VisitExpr_(const MulNode* op) final;
  PrimExpr VisitExpr_(const DivNode* op) final;
  PrimExpr VisitExpr_(const ModNode* op) final;
  PrimExpr VisitExpr_(const FloorDivNode* op) final;
  PrimExpr VisitExpr_(const FloorModNode* op) final;
  PrimExpr VisitExpr_(const MinNode* op) final;
  PrimExpr VisitExpr_(const MaxNode* op) final;
  PrimExpr VisitExpr_(const EQNode* op) final;
  PrimExpr VisitExpr_(const NENode* op) final;
  PrimExpr VisitExpr_(const LTNode* op) final;
  PrimExpr VisitExpr_(const LENode* op) final;
  PrimExpr VisitExpr_(const GTNode* op) final;
  PrimExpr VisitExpr_(const GENode* op) final;
  PrimExpr VisitExpr_(const CallNode* op) final;

 private:
  /*!
   * \brief Visit an index expression with is_index_ set.
   *
   * The previous flag value is restored afterwards instead of being reset to
   * false. Otherwise a Load nested inside another index (e.g. A[B[i] + 1])
   * would switch index mode off for the rest of the enclosing index.
   *
   * \param index The index expression of a Load or Store.
   * \return The rewritten index.
   */
  PrimExpr VisitIndex(const PrimExpr& index) {
    const bool outer_is_index = is_index_;
    is_index_ = true;
    PrimExpr result = this->VisitExpr(index);
    is_index_ = outer_is_index;
    return result;
  }

  // the internal visitor to deduce the narrowed dtype
  DataTypeVisitor visitor_;
  // a map from Var before rewrite to that after rewrite,
  // ensures one old Var maps to exactly one new Var
  std::unordered_map<const VarNode*, Var> vmap_;
  // a map from IterVar before rewrite to that after rewrite,
  // ensures one old IterVar maps to exactly one new IterVar
  std::unordered_map<const IterVarNode*, IterVar> ivmap_;
  // true while the index of a LoadNode or StoreNode is being visited
  bool is_index_{false};
  // cached ops
  const Op& builtin_pow_ = Op::Get("tir.pow");
};
// Define DataTypeRewriter::VisitExpr_ for the binary node type OP.
// Both operands are mutated first. If neither changed, the original node is
// returned; otherwise the expression is rebuilt by calling FUNC on the
// rewritten operands.
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)            \
  PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) {              \
    PrimExpr lhs = this->VisitExpr(op->a);                           \
    PrimExpr rhs = this->VisitExpr(op->b);                           \
    const bool unchanged = lhs.same_as(op->a) && rhs.same_as(op->b); \
    if (!unchanged) {                                                \
      return FUNC(lhs, rhs);                                         \
    }                                                                \
    return GetRef<PrimExpr>(op);                                     \
  }

DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<);  // NOLINT(*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>);  // NOLINT(*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
/*!
 * \brief Mutate a call and rebuild selected intrinsics from their rewritten arguments.
 *
 * After the base mutator has processed the arguments, if_then_else, the shift
 * and bitwise operators, and tir.pow are reconstructed through their builder
 * functions. Any other call is returned as produced by the base mutator.
 */
PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
  PrimExpr e = StmtExprMutator::VisitExpr_(op);
  const CallNode* call = e.as<CallNode>();
  CHECK(call != nullptr) << "Expected type to be CallNode"
                         << ", but get " << e->GetTypeKey();
  const auto& callee = call->op;
  const auto& args = call->args;

  if (callee.same_as(builtin::if_then_else())) {
    return if_then_else(args[0], args[1], args[2]);
  }
  if (callee.same_as(builtin::shift_right())) {
    return args[0] >> args[1];
  }
  if (callee.same_as(builtin::shift_left())) {
    return args[0] << args[1];
  }
  if (callee.same_as(builtin::bitwise_and())) {
    return args[0] & args[1];
  }
  if (callee.same_as(builtin::bitwise_or())) {
    return args[0] | args[1];
  }
  if (callee.same_as(builtin::bitwise_xor())) {
    return args[0] ^ args[1];
  }
  if (callee.same_as(builtin_pow_)) {
    return pow(args[0], args[1]);
  }
  return e;
}
/*!
 * \brief Run DataTypeRewriter over a statement.
 * \param stmt The statement to rewrite.
 * \param target_bits The bit width passed to the rewriter.
 * \return The rewritten statement.
 */
Stmt NarrowDataType(Stmt stmt, int target_bits) {
  DataTypeRewriter rewriter(target_bits);
  return rewriter(stmt);
}
namespace transform {

/*!
 * \brief Create the "tir.NarrowDataType" PrimFunc pass.
 * \param target_bits The bit width passed to DataTypeRewriter for every function body.
 * \return The pass object.
 */
Pass NarrowDataType(int target_bits) {
  auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
    // Obtain a writable function node, then replace its body with the rewritten one.
    auto* func_node = f.CopyOnWrite();
    DataTypeRewriter rewriter(target_bits);
    func_node->body = rewriter(std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {});
}

TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType);

}  // namespace transform
} // namespace tir
} // namespace tvm