blob: 8774fc37125fb34e7325077a467900f136f38587 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Lower intrinsic calls and ops to device specific ir when possible.
* \file lower_intrin.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/pattern_match.h"
namespace tvm {
namespace tir {
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
patterns_.push_back("tvm.intrin.rule." + target + ".");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
}
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
// Still use legacy string based rewriting
// TODO(tvm-team): migrate the pattern application from global function look up
// to an OpAttrMap<PackedFunc>
std::string name = ptr_op->name;
PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
if (r.defined()) return r;
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const AddNode* op) final {
if (const MulNode* mb = op->b.as<MulNode>()) {
return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const MulNode* ma = op->a.as<MulNode>()) {
return MakeFMA(ma->a, ma->b, op->b, op);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
PrimExpr VisitExpr_(const FloorDivNode* op) final {
auto e = GetRef<PrimExpr>(op);
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
if (op == nullptr) return ret;
int shift;
const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to right shift if possible.
return op->a >> make_const(dtype, shift);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common path, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1));
}
}
} else {
if (dtype.is_float()) {
// floor(a / b)
return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
} else {
// uncommon case
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
auto rmod = tir::Var("rmod", dtype);
auto rdiv = tir::Var("rdiv", dtype);
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
PrimExpr let_rdiv =
tir::Let(rdiv, truncdiv(op->a, op->b),
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
rdiv - make_const(dtype, 1)));
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
}
}
}
PrimExpr VisitExpr_(const FloorModNode* op) final {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
int shift;
const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
// lower to masking if possible.
int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
return op->a & make_const(dtype, mask);
}
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
// Common pass, positive divisor
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
return tir::Select(rmod >= 0, rmod, rmod + op->b);
}
}
} else {
if (dtype.is_float()) {
// a - floor(a / b) * b
return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
auto rmod = tir::Var("rmod", dtype);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
return Let(
rmod, truncmod(op->a, op->b),
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
}
}
}
PrimExpr VisitExpr_(const MaxNode* op) final {
using namespace arith;
PVar<PrimExpr> x, y;
PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const EQNode* op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr VisitExpr_(const NENode* op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
private:
PrimExpr SwapBroadcastCast(const PrimExpr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
// For some targets, LLVM will generate more efficient FMA
// instruction with the latter. For example, vmla vs. vmlal
// on ARM.
if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
if (const CastNode* cast = bcast->value.as<CastNode>()) {
auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
return true;
}
// Check both operands are integer-like.
if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
return false;
}
if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
return false;
}
// If both are integer-like, swap if we have a widening cast.
return cast->dtype.bits() > cast->value.dtype().bits();
};
if (should_swap()) {
PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes);
return Cast(bcast->dtype, new_bcast);
}
}
}
return e;
}
PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
// emit fma instruction: a * b + c
PrimExpr lhs = SwapBroadcastCast(a);
PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
return Add(mul, this->VisitExpr(c));
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
PrimExpr ApplyPattern(std::string name, const PrimExpr& e) {
if (name.compare(0, 4, "tir.") == 0) {
name = name.substr(4);
}
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
p.resize(psize + name.length());
name.copy(&p[0] + psize, name.length());
const runtime::PackedFunc* f = runtime::Registry::Get(p);
p.resize(psize);
// if pattern exists.
if (f != nullptr) {
PrimExpr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->VisitExpr(r);
}
}
}
return PrimExpr();
}
// patterns
std::vector<std::string> patterns_;
const PackedFunc* fma_{nullptr};
bool support_bitwise_op_{true};
};
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target)(std::move(stmt));
}
namespace transform {
Pass LowerIntrin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
auto mtriple = target.value()->GetAttr<runtime::String>("mtriple", "");
n->body =
IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin);
} // namespace transform
} // namespace tir
} // namespace tvm