blob: 52010ec322c80520b196dc8fc62720fa57e972b3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/transitive_comparison_analyzer.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <optional>
#include <vector>
#include "constraint_extract.h"
#include "pattern_match.h"
namespace tvm {
namespace arith {
using namespace tir;
/*! \brief Implementation backing TransitiveComparisonAnalyzer
*
* Stores known comparisons between expressions (`knowns_` and
* `scoped_knowns_`) and uses them, optionally through chains of
* transitive (in)equalities, to compare two expressions.
*/
class TransitiveComparisonAnalyzer::Impl {
public:
/*! \brief Using previously specified knowns, compare the expressions provided
*
* \param lhs The left-hand side of the comparison
*
* \param rhs The right-hand side of the comparison
*
* \param propagate_inequalities If true, attempt to find a sequence
* of transitive inequalities that allow the lhs and rhs to be
* compared. If false, only use the known comparisons that have been
* directly provided. Using `propagate_inequalities = false` is
* roughly equivalent to comparing against all known values with
* `ExprDeepEqual`, but also allowing for constant offsets on either
* side of the inequality.
*
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
bool propagate_inequalities = true) const;
/*! \brief Bind a variable as being equal to a known expression
*
* \param var The variable of interest.
* \param expr The bound expression
* \param allow_override Whether to allow override of existing information.
*/
void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
/*! \brief Bind a variable as being within a specified range
*
* \param var The variable of interest.
* \param expr The known range of the variable.
* \param allow_override Whether to allow override of existing information.
*/
void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
/*!
* \brief Update the internal state to enter constraint.
* \param expr A constraint expression.
*
* \return An exit function that must be called to cleanup. May be
* `nullptr`, if no cleanup is required.
*/
std::function<void()> EnterConstraint(const PrimExpr& expr);
private:
/*! \brief Internal representation of a PrimExpr
*
* The Key enum serves two purposes.
*
* 1. Providing efficiency, as compared to a PrimExpr. Two keys are
* equal if and only if the corresponding PrimExprs would satisfy
* ExprDeepEqual. This allows two expressions to be checked for
* equivalency, without requiring a call to ExprDeepEqual for
* each comparison.
* NOTE(review): the backing map `expr_to_key` is keyed with
* StructuralHash / StructuralEqual; confirm that this agrees with
* ExprDeepEqual for the expressions used here.
*
* 2. Providing type-safety, as compared to using `size_t` directly.
* Requiring an explicit conversion from an integer to a Key
* prevents accidental comparisons, especially if both loop
* iterators and Keys are used in the same scope.
*
* A Key should only be obtained using the methods `ExprToKey` and
* `ExprToPreviousKey`.
*/
enum class Key : size_t {};
/*! \brief Convert an expression to internal representation
*
* If the expression has previously been converted to the internal
* representation, returns the same Key as has been used previously.
* Otherwise, generate and return a new Key.
*
* \param expr The PrimExpr to be converted
*
* \returns The Key representing the expression
*
* \see ExprToPreviousKey
*/
Key ExprToKey(const PrimExpr& expr);
/*! \brief Look up the internal representation of an expression
*
* If the expression has previously been converted to the internal
* representation, returns the same Key as has been used previously.
* Otherwise, return `std::nullopt`. Unlike `ExprToKey`, this never
* generates a new Key, which is why it can be a const method.
*
* \param expr The PrimExpr to be converted
*
* \returns The Key representing the expression, if one exists.
*
* \see ExprToKey
*/
std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
/*! \brief The mapping from expression to Key
*
* Should not be used directly. Instead, use the helper functions
* `ExprToKey` and `ExprToPreviousKey`.
*
* \see ExprToKey
* \see ExprToPreviousKey
*/
std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
/*! \brief Internal representation of a comparison operator */
struct Comparison {
/*! \brief Construct a comparison that represents `lhs OP rhs +
* offset`, where the operation is specified by the CompareResult.
*
* The constructor normalizes kLT and kGT into kLE and kGE by
* adjusting the offset by one.
*/
Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
/*! \brief Utility function to validate that all GT and LT results
* have been normalized out
*/
bool IsNormalized() const;
/*! \brief Move the specified expression to the LHS.
*
* \param new_lhs The argument that should be moved to the LHS of the
* comparison.
*
* \return If possible, returns a comparison that is equivalent to
* the current comparison, but with the specified LHS. If not
* possible, returns nullopt.
*/
std::optional<Comparison> WithLHS(Key new_lhs) const;
/*! \brief Create the negation of the current comparison */
Comparison Negated() const;
/*! \brief Check whether this comparison implies another comparison
*
* Returns true if this comparison being true implies that the
* other comparison must also be true. Returns false if the other
* comparison cannot be shown to be true.
*
* \param other The comparison to test. It must have the same LHS
* and RHS as this comparison.
*/
bool Implies(const Comparison& other) const;
// The LHS of the comparison
Key lhs_;
// The RHS of the comparison, not including any constant offset.
Key rhs_;
// Additive offset on rhs
int64_t offset_{0};
// The comparison operator.
CompareResult result_{CompareResult::kInconsistent};
};
/*! \brief Generate a Comparison representing the given expression
*
* \param expr The expression to convert.
*
* \returns The Comparison, or `std::nullopt` if the expression is
* not one of the supported comparison operators.
*/
std::optional<Comparison> FromExpr(const PrimExpr& expr);
/*! \brief Utility function used by Bind and EnterConstraint
*
* \param expr The comparison expression, to be converted into
* internal Comparison objects.
*
* \param vec The vector to which the Comparison objects should be
* appended.
*/
void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
/*! \brief Collect known comparisons between LHS and RHS, without propagation
*
* Allows the internal representation to handle any constant
* offsets, without searching for a sequence of inequalities.
*
* \param lhs_key The left-hand side of the comparison
*
* \param rhs_key The right-hand side of the comparison
*
* \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
* only include comparisons between `lhs_key` and `rhs_key`,
* normalized such that `lhs_key` is on the left-hand side.
*/
std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) const;
/*! \brief Collect known comparisons between LHS and RHS, with propagation
*
* \param lhs_key The left-hand side of the comparison
*
* \param rhs_key The right-hand side of the comparison
*
* \returns All comparisons between `lhs_key` and `rhs_key`,
* including the explicitly-provided comparisons in `knowns_` and
* `scoped_knowns_`, and comparisons provable through a series of
* comparisons through other values. All comparisons returned are
* between `lhs_key` and `rhs_key`, and are normalized such that
* `lhs_key` is on the left-hand side.
*/
std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key) const;
/*! \brief Internal function used by CollectIndirectComparisons
*
* Search through the space of known expressions, starting at the
* LHS of a comparison. In this search, each expression is a node
* of a graph, and each known comparison is an edge of the graph.
*
* For example, suppose we have previous knowns of (A<=B), (B<=C+1)
* and (C<=D-5). The expressions [A,B,C,D] are the nodes of the
* search space. Each comparison is an edge connecting two
* expressions, such as (B<=C+1) connecting the expressions B and C.
* If we are attempting to compare expressions A and D, a search
* starting at expression A could follow each edge until reaching
* expression D, then combine the comparisons that compose the path
* into the expression A<=D-4.
*
* \param lhs_key The left-hand side of the comparison
*
* \param rhs_key The right-hand side of the comparison
*
* \returns A vector of comparisons between the two expressions.
*/
std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const;
/*! \brief Combine a set of comparisons that share a LHS and RHS
*
* \param lhs_to_rhs The comparisons to merge. These should all
* have the same LHS and RHS. This parameter will typically be the
* result from `CollectDirectComparisons` or
* `CollectIndirectComparisons`.
*
* \param offset The constant offset in the comparison being proven.
* This is extracted from any additive/subtractive constants in the
* `PrimExpr` arguments to `TryCompare`.
*
* \returns The possible comparisons between LHS and RHS provided
* inequalities.
*/
CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const;
/*! \brief Previous Range bindings
*
* Tracked separately to handle the `allow_override` option used by
* all sub-analyzers when binding variables.
*/
Map<Var, Range> prev_bindings_;
/*! \brief Known comparisons based on definitionally-true statements
*
* For example, a Let binding, or the range of an iterator. These
* known statements are always true, based on the definition site of
* the variable. e.g. A loop iterator may never exceed the bounds
* of its loop.
*/
std::vector<Comparison> knowns_;
/*! \brief Known comparisons based on scoped conditions
*
* For example, the condition of an IfThenElse. These known
* statements may only be used within the scope of the statement
* from which they were derived. e.g. After exiting an IfThenElse,
* the condition may no longer be true.
*/
std::vector<Comparison> scoped_knowns_;
};
namespace {
/*! \brief Swap the two sides of a comparison.
 *
 * \param res The operator of a comparison `a OP b`.
 *
 * \return The operator `OP'` such that `b OP' a` is equivalent to
 * `a OP b`.
 */
CompareResult Reverse(CompareResult res) {
  switch (res) {
    // Ordering relations point the other way once the operands swap.
    case CompareResult::kLT:
      return CompareResult::kGT;
    case CompareResult::kGT:
      return CompareResult::kLT;
    case CompareResult::kLE:
      return CompareResult::kGE;
    case CompareResult::kGE:
      return CompareResult::kLE;
    // These relations are symmetric in their operands.
    case CompareResult::kInconsistent:
    case CompareResult::kEQ:
    case CompareResult::kNE:
    case CompareResult::kUnknown:
      return res;
    default:
      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
  }
}
/*! \brief Logically negate a comparison operator.
 *
 * kInconsistent and kUnknown are returned unchanged. Every other
 * value has its bits complemented, keeping only the bits that are
 * set in kUnknown.
 *
 * NOTE(review): this relies on the bit encoding of CompareResult,
 * which is declared elsewhere -- confirm against its definition.
 *
 * \param res The operator to negate.
 *
 * \return The negated operator.
 */
CompareResult Negate(CompareResult res) {
  if (res == CompareResult::kInconsistent || res == CompareResult::kUnknown) {
    return res;
  }
  const int valid_bits = static_cast<int>(CompareResult::kUnknown);
  const int complemented = ~static_cast<int>(res) & valid_bits;
  return static_cast<CompareResult>(complemented);
}
// Internal utility, extract constant offsets out of the two sides of
// a comparison. Given lhs and rhs, return a tuple of three elements
// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
// (lhs_inner OP rhs_inner + offset) are equivalent.
std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
PVar<PrimExpr> x;
PVar<IntImm> c;
if ((x + c).Match(expr)) {
return {x.Eval(), c.Eval()->value};
} else if ((x - c).Match(expr)) {
return {x.Eval(), -c.Eval()->value};
} else if (c.Match(expr)) {
return {0, c.Eval()->value};
} else {
return {expr, 0};
}
};
auto lhs_split = extract_offset(lhs);
auto rhs_split = extract_offset(rhs);
return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
}
} // namespace
std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
CompareResult res;
PVar<PrimExpr> x, y;
if ((x <= y).Match(expr)) {
res = CompareResult::kLE;
} else if ((x >= y).Match(expr)) {
res = CompareResult::kGE;
} else if ((x < y).Match(expr)) {
res = CompareResult::kLT;
} else if ((x > y).Match(expr)) {
res = CompareResult::kGT;
} else if ((x == y).Match(expr)) {
res = CompareResult::kEQ;
} else if ((x != y).Match(expr)) {
res = CompareResult::kNE;
} else {
return std::nullopt;
}
PrimExpr lhs_expr = x.Eval();
PrimExpr rhs_expr = y.Eval();
if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
return std::nullopt;
}
auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
Key lhs_key = ExprToKey(lhs);
Key rhs_key = ExprToKey(rhs);
return Comparison(lhs_key, rhs_key, offset, res);
}
/*! \brief Construct the comparison `lhs OP rhs + offset`.
 *
 * Strict inequalities are rewritten into non-strict ones on
 * construction, so that the rest of the analyzer only has to reason
 * about `<=` and `>=`:
 *
 *   x < y + c   becomes   x <= y + (c - 1)
 *   x > y + c   becomes   x >= y + (c + 1)
 *
 * This rewrite is only valid for integers, which is what the
 * analyzer is intended for. Normalizing towards `<=`/`>=` (rather
 * than towards `<`/`>`) keeps transitive combination simple: from
 * `i <= j + c1` and `j <= k + c2` it follows directly that
 * `i <= k + (c1 + c2)`, whereas chaining `i < j + c1` and
 * `j < k + c2` yields `i < k + (c1 + c2 - 1)`, with an extra
 * off-by-one term to track at every point of use.
 */
TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
                                                           CompareResult result)
    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
  switch (result_) {
    case CompareResult::kLT:
      result_ = CompareResult::kLE;
      offset_ -= 1;
      break;
    case CompareResult::kGT:
      result_ = CompareResult::kGE;
      offset_ += 1;
      break;
    default:
      // Every other operator is stored as given.
      break;
  }
}
/*! \brief Look up the Key previously assigned to an expression.
 *
 * \param expr The expression to look up.
 *
 * \return The Key stored in `expr_to_key` for `expr`, or
 * `std::nullopt` if the expression has no entry.
 */
std::optional<TransitiveComparisonAnalyzer::Impl::Key>
TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
  const auto found = expr_to_key.find(expr);
  if (found == expr_to_key.end()) {
    return std::nullopt;
  }
  return found->second;
}
/*! \brief Return the Key for an expression, generating one if needed.
 *
 * \param expr The expression to convert.
 *
 * \return The Key already stored for `expr`, if any. Otherwise a new
 * Key, whose value is the number of entries in `expr_to_key` prior to
 * the insertion.
 */
TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
    const PrimExpr& expr) {
  if (const std::optional<Key> existing = ExprToPreviousKey(expr); existing.has_value()) {
    return *existing;
  }
  // The size must be read before inserting the new entry.
  const Key fresh = static_cast<Key>(expr_to_key.size());
  expr_to_key[expr] = fresh;
  return fresh;
}
/*! \brief Check that the comparison holds no strict inequality.
 *
 * \return False if the operator is kLT or kGT, true otherwise. The
 * `Comparison` constructor rewrites both of those operators, so this
 * is expected to hold for every constructed Comparison.
 */
bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
  const bool is_strict = (result_ == CompareResult::kLT || result_ == CompareResult::kGT);
  return !is_strict;
}
/*! \brief Rewrite the comparison so that `new_lhs` is on the LHS.
 *
 * \param new_lhs The key that should appear on the left-hand side.
 *
 * \return This comparison if `new_lhs` is already the LHS, the
 * mirrored comparison if `new_lhs` is the RHS, and `std::nullopt` if
 * `new_lhs` is neither side of the comparison.
 */
std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
  if (new_lhs == lhs_) {
    return *this;
  }
  if (new_lhs != rhs_) {
    return std::nullopt;
  }
  // (lhs OP rhs + c) is equivalent to (rhs Reverse(OP) lhs - c)
  return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
}
/*! \brief Create the logical negation of this comparison.
 *
 * \return A comparison over the same LHS, RHS and offset, using the
 * negated operator. The result goes through the `Comparison`
 * constructor, so a strict inequality produced by the negation is
 * normalized there.
 */
TransitiveComparisonAnalyzer::Impl::Comparison
TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
  const CompareResult negated_op = Negate(result_);
  return Comparison(lhs_, rhs_, offset_, negated_op);
}
/*! \brief Check whether this comparison implies another comparison.
 *
 * Writing this comparison as `x OP1 y + c1` and `other` as
 * `x OP2 y + c2`, returns true when the first being true guarantees
 * that the second is true.
 *
 * \param other The comparison to test. It must have the same LHS and
 * RHS as this comparison, and both comparisons must be normalized.
 *
 * \return True if the implication can be shown, false otherwise.
 */
bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
    const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
  ICHECK(lhs_ == other.lhs_);
  ICHECK(rhs_ == other.rhs_);
  ICHECK(IsNormalized());
  ICHECK(other.IsNormalized());

  // Identical operator and offset: the two statements are the same.
  if (result_ == other.result_ && offset_ == other.offset_) {
    return true;
  }

  const bool this_is_eq = (result_ == CompareResult::kEQ);
  const bool this_is_le = (result_ == CompareResult::kLE);
  const bool this_is_ge = (result_ == CompareResult::kGE);

  switch (other.result_) {
    case CompareResult::kLE:
      // if c1 <= c2, x <= y + c1 => x <= y + c2
      // if c1 <= c2, x == y + c1 => x <= y + c2
      return offset_ <= other.offset_ && (this_is_eq || this_is_le);

    case CompareResult::kGE:
      // if c1 >= c2, x >= y + c1 => x >= y + c2
      // if c1 >= c2, x == y + c1 => x >= y + c2
      return offset_ >= other.offset_ && (this_is_eq || this_is_ge);

    case CompareResult::kNE:
      // if c1 != c2, x == y + c1 => x != y + c2
      // if c1 <  c2, x <= y + c1 => x <  y + c2 => x != y + c2
      // if c1 >  c2, x >= y + c1 => x >  y + c2 => x != y + c2
      return (this_is_eq && offset_ != other.offset_) ||
             (this_is_le && offset_ < other.offset_) ||
             (this_is_ge && offset_ > other.offset_);

    default:
      return false;
  }
}
TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
/*! \brief Destructor, defined in this file where `Impl` is a complete type. */
TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() = default;
/*! \brief Compare two expressions using the known comparisons.
 *
 * Forwards to `Impl::TryCompare`.
 *
 * \param lhs The left-hand side of the comparison
 * \param rhs The right-hand side of the comparison
 * \param propagate_inequalities Passed through to `Impl::TryCompare`
 *
 * \return The result reported by the implementation.
 */
CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
                                                       bool propagate_inequalities) {
  const CompareResult outcome = impl_->TryCompare(lhs, rhs, propagate_inequalities);
  return outcome;
}
void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
impl_->Bind(var, expr, allow_override);
}
void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
impl_->Bind(var, range, allow_override);
}
/*! \brief Enter a scoped constraint.
 *
 * Forwards to `Impl::EnterConstraint`.
 *
 * \param constraint The constraint expression.
 *
 * \return The exit function produced by the implementation.
 */
std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
  std::function<void()> on_exit = impl_->EnterConstraint(constraint);
  return on_exit;
}
/*! \brief Record the comparisons contained in an expression.
 *
 * The expression is split into sub-constraints with
 * `ExtractConstraints`, and every sub-constraint that `FromExpr` can
 * convert is appended to `vec`. Nothing is recorded if the
 * expression as a whole has a side effect beyond
 * `CallEffectKind::kPure`.
 *
 * \param expr The expression to record.
 *
 * \param vec The vector to which the Comparison objects should be
 * appended.
 */
void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
                                                  std::vector<Comparison>* vec) {
  // The purity check inspects the full expression, not the individual
  // sub-constraint, so its result is the same for every sub-constraint.
  // Evaluate it once, and skip the extraction entirely when it fails.
  const bool is_pure = tir::SideEffect(expr) <= tir::CallEffectKind::kPure;
  if (!is_pure) {
    return;
  }

  for (const auto& subexpr : ExtractConstraints(expr, false)) {
    if (auto cmp = FromExpr(subexpr)) {
      vec->push_back(cmp.value());
    }
  }
}
/*! \brief Bind a variable as being within a specified range.
 *
 * If the variable was bound before to a range with a different min
 * or extent, `allow_override` must be set, and every entry of
 * `knowns_` that has the variable as its LHS is discarded before the
 * new comparisons are recorded.
 *
 * \param var The variable of interest.
 * \param range The known range.
 * \param allow_override Whether to allow override of existing information.
 */
void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
                                              bool allow_override) {
  if (auto it = prev_bindings_.find(var); it != prev_bindings_.end()) {
    const Range previous = (*it).second;
    ExprDeepEqual expr_equal;
    const bool same_as_previous =
        expr_equal(range->min, previous->min) && expr_equal(range->extent, previous->extent);
    if (!same_as_previous) {
      ICHECK(allow_override) << "Binding of variable " << var << " as " << range
                             << " conflicts with previous binding as " << previous;
      if (auto key = ExprToPreviousKey(var)) {
        // Drop the comparisons that were recorded with the variable
        // on their left-hand side.
        const Key var_key = key.value();
        auto stale_begin =
            std::remove_if(knowns_.begin(), knowns_.end(),
                           [var_key](const Comparison& known) { return known.lhs_ == var_key; });
        knowns_.erase(stale_begin, knowns_.end());
      }
    }
  }

  prev_bindings_.Set(var, range);

  if (is_const_int(range->extent, 1)) {
    // A range of extent one pins the variable to a single value.
    AddKnown(var == range->min, &knowns_);
  } else {
    // min <= var < min + extent
    AddKnown(var >= range->min, &knowns_);
    AddKnown(var < range->min + range->extent, &knowns_);
  }
}
/*! \brief Bind a variable as being equal to a known expression.
 *
 * Implemented as a range binding that starts at `expr` and has an
 * extent of one.
 *
 * \param var The variable of interest.
 * \param expr The bound expression.
 * \param allow_override Whether to allow override of existing information.
 */
void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
                                              bool allow_override) {
  const Range single_value = Range::FromMinExtent(expr, 1);
  Bind(var, single_value, allow_override);
}
/*! \brief Record a scoped constraint.
 *
 * The comparisons found in `expr` are appended to `scoped_knowns_`.
 *
 * \param expr The constraint expression.
 *
 * \return A function that removes the appended comparisons again.
 * When called, it checks that `scoped_knowns_` still has the size it
 * had right after this constraint was added.
 */
std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
  const size_t old_literal_size = scoped_knowns_.size();
  AddKnown(expr, &scoped_knowns_);
  const size_t new_literal_size = scoped_knowns_.size();

  return [this, old_literal_size, new_literal_size]() {
    ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
    // Truncate back to the size from before the constraint was entered.
    const auto first_added = scoped_knowns_.begin() + old_literal_size;
    scoped_knowns_.erase(first_added, scoped_knowns_.end());
  };
}
/*! \brief Using previously specified knowns, compare two expressions.
 *
 * \param lhs_expr The left-hand side of the comparison
 *
 * \param rhs_expr The right-hand side of the comparison
 *
 * \param propagate_inequalities If true, search for a chain of known
 * comparisons connecting the two expressions. If false, only use
 * known comparisons that directly relate the two expressions.
 *
 * \return The most specific result that can be proven. Returns
 * kUnknown if either expression is not of integer type, or if either
 * expression (with its constant offset removed) has never been seen
 * in a known comparison.
 */
CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
                                                             const PrimExpr& rhs_expr,
                                                             bool propagate_inequalities) const {
  // Only integer comparisons are supported.
  const bool both_integers = lhs_expr.dtype().is_int() && rhs_expr.dtype().is_int();
  if (!both_integers) {
    return CompareResult::kUnknown;
  }

  // Two constants can be compared directly. Such a comparison would
  // normally have been constant-folded before reaching this point.
  const auto* lhs_const = lhs_expr.as<IntImmNode>();
  const auto* rhs_const = rhs_expr.as<IntImmNode>();
  if (lhs_const && rhs_const) {
    if (lhs_const->value == rhs_const->value) {
      return CompareResult::kEQ;
    }
    return lhs_const->value < rhs_const->value ? CompareResult::kLT : CompareResult::kGT;
  }

  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);

  // An expression without a Key appears in no known comparison, so
  // nothing can be proven about it.
  const std::optional<Key> lhs_key = ExprToPreviousKey(lhs);
  const std::optional<Key> rhs_key = ExprToPreviousKey(rhs);
  if (!(lhs_key.has_value() && rhs_key.has_value())) {
    return CompareResult::kUnknown;
  }

  const std::vector<Comparison> lhs_to_rhs =
      propagate_inequalities ? CollectIndirectComparisons(lhs_key.value(), rhs_key.value())
                             : CollectDirectComparisons(lhs_key.value(), rhs_key.value());
  return MergeComparisons(lhs_to_rhs, offset);
}
/*! \brief Collect known comparisons between LHS and RHS, without propagation.
 *
 * \param lhs_key The left-hand side of the comparison
 *
 * \param rhs_key The right-hand side of the comparison
 *
 * \return The entries of `knowns_`, followed by those of
 * `scoped_knowns_`, that relate `lhs_key` to `rhs_key`, each
 * rewritten so that `lhs_key` is on the left-hand side.
 */
std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rhs_key) const {
  std::vector<Comparison> matches;
  for (const std::vector<Comparison>* source : {&knowns_, &scoped_knowns_}) {
    for (const Comparison& known : *source) {
      const std::optional<Comparison> oriented = known.WithLHS(lhs_key);
      if (oriented.has_value() && oriented->rhs_ == rhs_key) {
        matches.push_back(*oriented);
      }
    }
  }
  return matches;
}
/*! \brief Collect known comparisons between LHS and RHS, with propagation.
 *
 * Runs the search once starting from `lhs_key`, and once starting
 * from `rhs_key`. The results of the second search are rewritten so
 * that `lhs_key` is on the left-hand side, then appended to the
 * results of the first.
 *
 * \param lhs_key The left-hand side of the comparison
 *
 * \param rhs_key The right-hand side of the comparison
 *
 * \return The combined comparisons, all with `lhs_key` on the
 * left-hand side.
 */
std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key rhs_key) const {
  std::vector<Comparison> output = DFSFromLHS(lhs_key, rhs_key);

  const std::vector<Comparison> found_from_rhs = DFSFromLHS(rhs_key, lhs_key);
  for (const Comparison& cmp : found_from_rhs) {
    auto opt_normalized = cmp.WithLHS(lhs_key);
    ICHECK(opt_normalized.has_value());
    output.push_back(opt_normalized.value());
  }
  return output;
}
/*! \brief Search for comparisons between `lhs_key` and `rhs_key`.
 *
 * Starting from the known comparisons that involve `lhs_key`,
 * repeatedly combine a proven comparison `lhs OP1 middle` with a
 * known comparison `middle OP2 right` into `lhs OP3 right`, until no
 * unvisited expression remains.
 *
 * \param lhs_key The left-hand side of the comparison
 *
 * \param rhs_key The right-hand side of the comparison
 *
 * \return The comparisons proven between `lhs_key` and `rhs_key`,
 * all with `lhs_key` on the left-hand side. Empty if none was found.
 */
std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const {
  // `compared_to_lhs[k]` holds the comparisons proven so far that
  // have `lhs_key` on the left-hand side and `k` on the right-hand
  // side. `to_visit` holds the keys that still need to be expanded,
  // and `seen` holds every key that has ever been added to `to_visit`.
  std::unordered_set<Key> seen;
  std::unordered_set<Key> to_visit;
  std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs;

  // Utility function to add a new known statement
  auto declare_known = [&](Comparison cmp) {
    std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_];

    // The comparison adds no new information, no modification
    // required.
    for (auto& prev_known : knowns) {
      if (prev_known.Implies(cmp)) {
        return;
      }
    }

    // New information may require visiting a new expression. A key
    // is queued at most once, and `rhs_key` itself is never queued.
    if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
      to_visit.insert(cmp.rhs_);
      seen.insert(cmp.rhs_);
    }

    // This comparison is a stronger version of a previous constraint.
    // Therefore, replace the old version entirely.
    for (auto& prev_known : knowns) {
      if (cmp.Implies(prev_known)) {
        prev_known = cmp;
        return;
      }
    }

    // Neither a superset nor a subset of previously known
    // constraints, must be tracked separately.
    knowns.push_back(cmp);
  };

  // Initialize the search based on any known (in)equalities that use
  // the LHS of the comparison.
  for (const auto& known : knowns_) {
    if (auto normalized = known.WithLHS(lhs_key)) {
      declare_known(normalized.value());
    }
  }
  for (const auto& known : scoped_knowns_) {
    if (auto normalized = known.WithLHS(lhs_key)) {
      declare_known(normalized.value());
    }
  }

  // Walk through the space of all comparisons that can be made with
  // LHS.
  while (to_visit.size()) {
    Key middle_key = *to_visit.begin();
    to_visit.erase(to_visit.begin());

    // Validate before the lookup. `at()` throws on a missing key, so
    // a check placed after it could never report the failure.
    ICHECK(compared_to_lhs.count(middle_key));
    std::vector<Comparison>& prev_knowns_using_middle = compared_to_lhs.at(middle_key);

    // New comparisons are buffered here and only declared once all
    // knowns have been examined, so `compared_to_lhs` is not modified
    // while `prev_knowns_using_middle` is being iterated.
    std::vector<Comparison> new_knowns_using_lhs;

    auto attempt_transitive = [&](Comparison cmp) {
      ICHECK(cmp.IsNormalized());

      Key right_key = cmp.rhs_;

      // A comparison of the LHS against itself is of no use.
      if (right_key == lhs_key) {
        return;
      }

      for (const auto& prev : prev_knowns_using_middle) {
        CompareResult new_result = CompareResult::kUnknown;
        int64_t new_offset = prev.offset_ + cmp.offset_;

        if (prev.result_ == CompareResult::kEQ) {
          // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
          new_result = cmp.result_;
        } else if (cmp.result_ == CompareResult::kEQ) {
          // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
          new_result = prev.result_;
        } else if (prev.result_ == cmp.result_ &&
                   (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
          // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
          // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
          //
          // This condition is much simpler to write than the
          // equivalent handling of < or of >, which is why the
          // inequalities are normalized to <= and to >=. See
          // `TransitiveComparisonAnalyzer::Impl::Comparison::Comparison`
          // for further details.
          new_result = prev.result_;
        }

        if (new_result != CompareResult::kUnknown) {
          Comparison new_known(lhs_key, right_key, new_offset, new_result);
          new_knowns_using_lhs.push_back(new_known);
        }
      }
    };

    // Attempt to prove a new comparison using one of the original
    // known comparisons. We want to find a known such that
    // `(LHS OP1 middle) && (middle OP2 right)` can be simplified
    // into `(LHS OP3 right)`.
    //
    // Note: The right side in this step is not necessarily the RHS of
    // the comparison we're trying to prove, as we may need to find
    // intermediate comparisons first. For example, if we know that
    // `a<=b`, `b<=c`, and `c<=d`, and we wish to prove that `a<=d`,
    // we must first combine `a<=b` and `b<=c` into `a<=c`. During
    // this first step, `b` is the "middle" and `c` is the "right".
    // The next step can then combine `a<=c` and `c<=d` into `a<=d`.
    for (const auto& known : knowns_) {
      if (auto cmp = known.WithLHS(middle_key)) {
        attempt_transitive(cmp.value());
      }
    }
    for (const auto& known : scoped_knowns_) {
      if (auto cmp = known.WithLHS(middle_key)) {
        attempt_transitive(cmp.value());
      }
    }

    // Collect together all new knowns, marking new nodes for visiting
    // as needed.
    for (const auto& new_known : new_knowns_using_lhs) {
      declare_known(new_known);
    }
  }

  if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) {
    return it->second;
  } else {
    // No comparison between the LHS and the RHS was found, either
    // directly or through a path of other expressions.
    return {};
  }
}
/*! \brief Combine a set of comparisons that share a LHS and RHS.
 *
 * Each entry of `lhs_to_rhs` states `lhs OP rhs + c`. The question
 * being answered is how `lhs` compares to `rhs + offset`. The answer
 * starts out as kUnknown, and is narrowed with `operator&` by each
 * entry that says something about that specific offset.
 *
 * A known comparison between the two expressions is not necessarily
 * useful: knowing that `x < y` does not tell whether `x + 5 < y`.
 *
 * \param lhs_to_rhs The comparisons to merge.
 *
 * \param offset The constant offset in the comparison being proven.
 *
 * \return The merged result.
 */
CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons(
    const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const {
  CompareResult result = CompareResult::kUnknown;

  for (const Comparison& cmp : lhs_to_rhs) {
    const int64_t known_offset = cmp.offset_;

    switch (cmp.result_) {
      case CompareResult::kInconsistent:
        result = CompareResult::kInconsistent;
        break;

      case CompareResult::kEQ:
        // x == y + c1 gives x == y + c2 if c1 == c2, else x != y + c2
        result = result & (known_offset == offset ? CompareResult::kEQ : CompareResult::kNE);
        break;

      case CompareResult::kLE:
        // x <= y + c1 gives x < y + c2 if c1 < c2, x <= y + c2 if
        // c1 == c2, and nothing if c1 > c2.
        if (known_offset < offset) {
          result = result & CompareResult::kLT;
        } else if (known_offset == offset) {
          result = result & CompareResult::kLE;
        }
        break;

      case CompareResult::kGE:
        // x >= y + c1 gives x > y + c2 if c1 > c2, x >= y + c2 if
        // c1 == c2, and nothing if c1 < c2.
        if (known_offset > offset) {
          result = result & CompareResult::kGT;
        } else if (known_offset == offset) {
          result = result & CompareResult::kGE;
        }
        break;

      case CompareResult::kNE:
        // x != y + c1 only says something about the offset c1 itself.
        if (known_offset == offset) {
          result = result & CompareResult::kNE;
        }
        break;

      case CompareResult::kUnknown:
        break;

      case CompareResult::kGT:
      case CompareResult::kLT:
        LOG(FATAL) << "Internal error, normalized comparisons should only include <= and >=";

      default:
        LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(cmp.result_);
    }
  }

  return result;
}
} // namespace arith
} // namespace tvm