blob: 40c7ab3c54ac5951b68babfac6a4c766e73ad2d2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file narrow_predicate_expression.cc
* \brief Utility to narrow a boolean predicate so that it no longer depends on free parameters
*/
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace arith {
using namespace tir;
/* \brief Given a true expression that includes free parameters,
* generate a true expression without the free parameters.
*
* This function provides two guarantees:
*
* 1. If the resulting expression evaluates to True, then the original
* expression also evaluates to True.
*
* 2. The resulting expression does not contain any of the free
* parameters.
*
*/
/*!
 * \brief Mutator that generates a known-true expression from a boolean
 * expression with free parameters and the ranges of those parameters.
 *
 * Every variable found in `free_parameters_` is replaced by one end of
 * its range.  Which end is selected is decided by the current Context,
 * kept on `context_stack_`; the context is flipped when descending into
 * an operand whose contribution is reversed (see the individual
 * visitors below).
 *
 * Sub-expressions that cannot be handled by substituting a range end
 * (FloorMod, BufferLoad, FloorDiv whose visited divisor is zero) set
 * `contains_unknown_expr_`.  The next boolean node processed by
 * VisitInequality while that flag is set clears the flag and is
 * replaced by a boolean constant.
 */
class ExpressionNarrower : public tir::ExprMutator {
 public:
  /*!
   * \brief Narrow `expr` with respect to `free_parameters`.
   *
   * \param expr The expression to narrow.  Must have boolean dtype,
   * otherwise the ICHECK below fails.
   *
   * \param free_parameters The variables to eliminate, each mapped to
   * the range of values it may take.
   *
   * \return The mutated expression.
   */
  static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) {
    ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr;
    // The map is a by-value sink parameter; hand it over instead of copying it again.
    ExpressionNarrower mutator(std::move(free_parameters));
    return mutator(expr);
  }

 private:
  /*! \brief Construct a mutator that takes ownership of the parameter map. */
  explicit ExpressionNarrower(Map<Var, Range> free_parameters)
      : free_parameters_(std::move(free_parameters)) {}

  using Parent = tir::ExprMutator;
  using Parent::VisitExpr_;

  /*!
   * \brief Which end of a free parameter's range is substituted.
   *
   * Minimize substitutes `range->min`, Maximize substitutes
   * `range->min + range->extent - 1` (see the VarNode visitor).
   */
  enum class Context {
    Maximize,
    Minimize,
  };

  /*!
   * \brief Visit the two operands of a binary node under separate contexts.
   *
   * Despite the name, this is also used for Sub, whose operands need
   * opposite contexts just like those of a comparison.
   *
   * \param t The binary node; must expose `a`, `b` and be constructible as `T(a, b)`.
   * \param a_ctx Context pushed while visiting `t->a`.
   * \param b_ctx Context pushed while visiting `t->b`.
   *
   * \return A boolean constant if an unknown sub-expression was
   * recorded and `t` is boolean, the original node if neither operand
   * changed, and a rebuilt node otherwise.
   */
  template <typename T>
  PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) {
    // Each lambda scopes its WithContext guard, so the pushed context
    // is popped again before the next operand is visited.
    PrimExpr a = [&]() {
      WithContext context(this, a_ctx);
      return this->VisitExpr(t->a);
    }();

    PrimExpr b = [&]() {
      WithContext context(this, b_ctx);
      return this->VisitExpr(t->b);
    }();

    if (contains_unknown_expr_ && t.dtype().is_bool()) {
      // Both guards are gone here, so CurrentContext() is the context
      // this node itself was visited in: the constant is false under
      // Maximize and true under Minimize.
      contains_unknown_expr_ = false;
      return Bool(CurrentContext() == Context::Minimize);
    } else if (a.same_as(t->a) && b.same_as(t->b)) {
      return std::move(t);
    } else {
      return T(a, b);
    }
  }

  PrimExpr VisitExpr_(const FloorModNode* op) override {
    // FloorMod is non-monotonic, so inserting min/max won't remove
    // the free parameters.
    contains_unknown_expr_ = true;
    return Parent::VisitExpr_(op);
  }

  /*! \brief Visit both operands in the current context; a zero divisor is recorded as unknown. */
  PrimExpr VisitExpr_(const FloorDivNode* op) override {
    auto res_a = this->VisitExpr(op->a);
    auto res_b = this->VisitExpr(op->b);
    if (is_zero(res_b)) {
      // Do not build a division by zero; return a zero placeholder and
      // flag the expression as unknown instead.
      contains_unknown_expr_ = true;
      return IntImm(op->dtype, 0);
    } else {
      return floordiv(res_a, res_b);
    }
  }

  /*! \brief `a > b`: visit `a` in the opposite context and `b` in the current one. */
  PrimExpr VisitExpr_(const GTNode* op) override {
    auto current = CurrentContext();
    return VisitInequality(GetRef<GT>(op), OppositeContext(current), current);
  }

  /*! \brief `a >= b`: visit `a` in the opposite context and `b` in the current one. */
  PrimExpr VisitExpr_(const GENode* op) override {
    auto current = CurrentContext();
    return VisitInequality(GetRef<GE>(op), OppositeContext(current), current);
  }

  /*! \brief `a < b`: visit `a` in the current context and `b` in the opposite one. */
  PrimExpr VisitExpr_(const LTNode* op) override {
    auto current = CurrentContext();
    return VisitInequality(GetRef<LT>(op), current, OppositeContext(current));
  }

  /*! \brief `a <= b`: visit `a` in the current context and `b` in the opposite one. */
  PrimExpr VisitExpr_(const LENode* op) override {
    auto current = CurrentContext();
    return VisitInequality(GetRef<LE>(op), current, OppositeContext(current));
  }

  /*! \brief `a == b` is visited as `(a <= b) && (b <= a)`. */
  PrimExpr VisitExpr_(const EQNode* op) override {
    auto res_a = this->VisitExpr(op->a <= op->b);
    auto res_b = this->VisitExpr(op->b <= op->a);
    return res_a && res_b;
  }

  /*! \brief `a != b` is visited as `(a < b) || (b < a)`. */
  PrimExpr VisitExpr_(const NENode* op) override {
    auto res_a = this->VisitExpr(op->a < op->b);
    auto res_b = this->VisitExpr(op->b < op->a);
    return res_a || res_b;
  }

  /*! \brief `a - b`: visit `a` in the current context and `b` in the opposite one. */
  PrimExpr VisitExpr_(const SubNode* op) override {
    auto current = CurrentContext();
    return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current));
  }

  /*! \brief `!a`: visit `a` in the opposite context, then negate the result. */
  PrimExpr VisitExpr_(const NotNode* op) override {
    auto current = CurrentContext();
    WithContext context(this, OppositeContext(current));
    return !VisitExpr(op->a);
  }

  /*! \brief Buffer loads are left untouched and recorded as unknown. */
  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
    contains_unknown_expr_ = true;
    return GetRef<PrimExpr>(op);
  }

  /*!
   * \brief Replace a free parameter by one end of its range.
   *
   * Variables that are not in `free_parameters_` are handled by the
   * parent mutator.
   */
  PrimExpr VisitExpr_(const VarNode* op) override {
    auto it = free_parameters_.find(GetRef<Var>(op));
    if (it == free_parameters_.end()) {
      return Parent::VisitExpr_(op);
    }

    Range range = (*it).second;

    switch (CurrentContext()) {
      case Context::Minimize:
        return range->min;

      case Context::Maximize:
        // Last value contained in [min, min + extent).
        return range->min + range->extent - 1;
    }

    // Only reached if Context holds a value outside the enumerators above.
    return Parent::VisitExpr_(op);
  }

  /*! \brief Innermost active context; Maximize when no context has been pushed. */
  Context CurrentContext() const {
    return context_stack_.empty() ? Context::Maximize : context_stack_.back();
  }

  /*! \brief Return the other enumerator of Context. */
  Context OppositeContext(Context context) const {
    switch (context) {
      case Context::Minimize:
        return Context::Maximize;

      case Context::Maximize:
        return Context::Minimize;

      default:
        LOG(FATAL) << "Unhandled Context, all legal values should be handled";
    }
  }

  /*!
   * \brief Scope guard that pushes a context on construction and pops
   * it on destruction.
   *
   * Copying is disabled: a copy would pop `context_stack_` a second
   * time when it is destroyed.
   */
  struct WithContext {
    WithContext(ExpressionNarrower* self, Context context) : self(self) {
      self->context_stack_.push_back(context);
    }
    WithContext(const WithContext&) = delete;
    WithContext& operator=(const WithContext&) = delete;
    ~WithContext() { self->context_stack_.pop_back(); }
    ExpressionNarrower* self;
  };

  /*! \brief Stack of active contexts, maintained by WithContext. */
  std::vector<Context> context_stack_;
  /*! \brief Variables to eliminate, with the range of each. */
  Map<Var, Range> free_parameters_;
  /*! \brief Set when a sub-expression was visited that cannot be narrowed. */
  bool contains_unknown_expr_{false};
};
/*!
 * \brief Narrow a boolean predicate so that it no longer refers to the
 * given free parameters.
 *
 * \param expr The boolean expression to narrow.
 * \param free_parameters The parameters to eliminate, each mapped to its range.
 *
 * \return The result of ExpressionNarrower::Apply on the two arguments.
 */
PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) {
  // Both arguments are owned by this call, so forward them by move.
  PrimExpr narrowed = ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters));
  return narrowed;
}
// Register NarrowPredicateExpression as a global function under the
// name "arith.NarrowPredicateExpression".
TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression);
} // namespace arith
} // namespace tvm