blob: 1c5f31a913a13d8ca691de2d6014d4c87341441d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/conjunctive_normal_form.cc
*/
#include "conjunctive_normal_form.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <functional>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "pattern_match.h"
#include "rewrite_simplify.h"
namespace tvm {
namespace arith {
namespace {
/*! \brief A utility for simplifying expressions using conjunctive/disjunctive normal forms
 *
 * The expression is stored as an AND of OR-groups ("chunks").  Each term
 * of a chunk is interned as a `Key`, so that structurally-equal terms
 * share the same key.
 */
class AndOfOrs {
 public:
  /*! \brief Construct the simplifier
   *
   * Convert a PrimExpr to the internal representation.
   *
   * \param expr The PrimExpr to be simplified.
   */
  explicit AndOfOrs(const PrimExpr& expr);

  /*! \brief Convert internal representation to PrimExpr */
  PrimExpr AsPrimExpr() const;

  /*! \brief Simplify the internal representation
   *
   * \param analyzer The analyzer used to simplify pairs of terms.
   */
  void Simplify(Analyzer* analyzer);

 private:
  /*! \brief Internal utility, simplify within each group of expressions
   *
   * For each pair of values within a chunk, attempt to simplify them into
   * a single expression.
   *
   * For example,
   *     before = (a == 5) && ((b < 10) || (b > 10))
   *     after  = (a == 5) && ((b != 10) || false)
   */
  void SimplifyWithinChunks(Analyzer* analyzer);

  /*! \brief Internal utility, simplify across groups of expressions
   *
   * For each pair of chunks, if the two chunks differ by only a single
   * term, attempt to simplify those differing terms.
   *
   * For example,
   *     before = ((a == 5) || (b <= 10)) && ((a == 5) || (b >= 10))
   *     after  = ((a == 5) || (b == 10)) && ((a == 5) || true)
   */
  void SimplifyAcrossChunks(Analyzer* analyzer);

  /*! \brief Remove instances of true/false from internal representation
   *
   * To avoid invalidating iterators, `SimplifyWithinChunks` and
   * `SimplifyAcrossChunks` may replace keys, but may not remove keys
   * from the internal representation.  For example, `(a < 5) && (a <
   * 10)` would be simplified to `(a < 5) && true`.  The
   * `RemoveTrueFalse` function removes these leftover instances of
   * true/false.
   */
  void RemoveTrueFalse();

  /*! \brief Split `expr` into its top-level AND terms
   *
   * Invokes `callback` once per term.  OR is distributed over AND, so
   * `(A && B) || C` is reported as `(A || C)` and `(B || C)`.
   */
  static void VisitAndExpressions(const PrimExpr& expr,
                                  std::function<void(const PrimExpr&)> callback);

  /*! \brief Split `expr` into its top-level OR terms
   *
   * Invokes `callback` once per term.  AND is distributed over OR, so
   * `(A || B) && C` is reported as `(A && C)` and `(B && C)`.
   */
  static void VisitOrExpressions(const PrimExpr& expr,
                                 std::function<void(const PrimExpr&)> callback);

  /*! \brief Type-safe handle that represents an interned PrimExpr
   *
   * Because integer indices are used frequently throughout this class,
   * maintaining a separation between integer indices used to access
   * specific elements of the internal representation, and unique
   * identifiers used to represent PrimExpr expressions, is useful.
   */
  enum class Key : size_t {};

  /*! \brief Convert a PrimExpr to a Key, interning the expression on first use */
  Key GetKey(const PrimExpr& expr);

  /*! \brief Convert a Key to a PrimExpr */
  PrimExpr GetExpr(Key key) const;

  /*! \brief Attempt to simplify (a || b)
   *
   * If successful, will overwrite the parameters `a` and `b` with the
   * simplified form.  If the simplified form is still an OR, `a` and `b`
   * receive its two operands; otherwise `a` becomes false and `b` holds
   * the simplified expression.
   */
  void TrySimplifyOr(Key* a, Key* b, Analyzer* analyzer);

  /*! \brief Attempt to simplify (a && b)
   *
   * If successful, will overwrite the parameters `a` and `b` with the
   * simplified form.  If the simplified form is still an AND, `a` and `b`
   * receive its two operands; otherwise `a` becomes true and `b` holds
   * the simplified expression.
   */
  void TrySimplifyAnd(Key* a, Key* b, Analyzer* analyzer);

  /*! \brief The internal representation
   *
   * `chunks_[i][j]` is the j-th expression in the i-th OR-group.  The
   * represented expression is the AND over all OR-groups.
   */
  std::vector<std::vector<Key>> chunks_;

  /*! \brief Mapping from internal Key to PrimExpr */
  std::unordered_map<Key, PrimExpr, StructuralHash, StructuralEqual> key_to_expr_;

  /*! \brief Mapping from PrimExpr to internal Key */
  std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key_;

  // The constructor initializes key_true_/key_false_ by calling GetKey(),
  // which uses the two maps above; they must stay declared after those maps.

  /*! \brief Cached key representing tir::Bool(true) */
  Key key_true_;

  /*! \brief Cached key representing tir::Bool(false) */
  Key key_false_;
};
AndOfOrs::AndOfOrs(const PrimExpr& expr)
    : key_true_(GetKey(Bool(true))), key_false_(GetKey(Bool(false))) {
  // Turn one AND term into an OR-group of keys.  Repeated keys inside the
  // group are skipped, and the group is dropped if an existing group holds
  // the same keys in any order.
  auto add_chunk = [this](const PrimExpr& and_term) {
    std::vector<Key> terms;
    VisitOrExpressions(and_term, [&](const PrimExpr& or_term) {
      const Key candidate = GetKey(or_term);
      const bool already_present = std::any_of(
          terms.begin(), terms.end(), [&](Key existing) { return existing == candidate; });
      if (already_present) {
        return;
      }
      terms.push_back(candidate);
    });

    const bool seen_before =
        std::any_of(chunks_.begin(), chunks_.end(), [&](const std::vector<Key>& existing) {
          if (existing.size() != terms.size()) {
            return false;
          }
          return std::is_permutation(existing.begin(), existing.end(), terms.begin());
        });
    if (seen_before) {
      return;
    }
    chunks_.push_back(std::move(terms));
  };

  VisitAndExpressions(expr, add_chunk);
}
void AndOfOrs::VisitAndExpressions(const PrimExpr& expr,
std::function<void(const PrimExpr&)> callback) {
PVar<PrimExpr> x, y, z;
if ((x && y).Match(expr)) {
// These are separate AND conditions, recurse into them in case
// they contain AND internally.
VisitAndExpressions(x.Eval(), callback);
VisitAndExpressions(y.Eval(), callback);
} else if ((x || y).Match(expr)) {
// This may be the bottom-most breakdown, but either x or y may
// themselves contain AND. (e.g. (A && B) || (C && D) should be
// split into (A || C), (A || D), (B || C), and (B || D).)
// Recurse into each, then reconstruct an OR condition.
VisitAndExpressions(x.Eval(), [&](const PrimExpr& x_part) {
VisitAndExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part || y_part); });
});
} else {
// This is bottom-most breakdown.
callback(expr);
}
}
void AndOfOrs::VisitOrExpressions(const PrimExpr& expr,
std::function<void(const PrimExpr&)> callback) {
PVar<PrimExpr> x, y, z;
if ((x || y).Match(expr)) {
// These are separate OR conditions, recurse into them in case
// they contain OR internally.
VisitOrExpressions(x.Eval(), callback);
VisitOrExpressions(y.Eval(), callback);
} else if ((x && y).Match(expr)) {
// This may be the bottom-most breakdown, but either x or y may
// themselves contain OR. (e.g. (A || B) && (C || D) should be
// split into (A && C), (A && D), (B && C), and (B && D).)
// Recurse into each, then reconstruct an AND condition.
VisitOrExpressions(x.Eval(), [&](const PrimExpr& x_part) {
VisitOrExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part && y_part); });
});
} else {
// This is bottom-most breakdown.
callback(expr);
}
}
AndOfOrs::Key AndOfOrs::GetKey(const PrimExpr& expr) {
  // Keys are handed out in order of first appearance: the key for a new
  // expression is the number of expressions interned before it.
  const Key next_key{expr_to_key_.size()};
  auto [slot, is_new] = expr_to_key_.try_emplace(expr, next_key);
  if (is_new) {
    key_to_expr_[slot->second] = expr;
  }
  return slot->second;
}
PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const {
  // Look up the expression interned under `key`; an unknown key is a
  // fatal internal error.
  const auto it = key_to_expr_.find(key);
  ICHECK(it != key_to_expr_.end());
  const PrimExpr& interned = it->second;
  return interned;
}
PrimExpr AndOfOrs::AsPrimExpr() const {
  // Rebuild the expression as the AND of each chunk's OR, starting from
  // the identity elements (true for AND, false for OR).
  auto or_of_chunk = [this](const std::vector<Key>& chunk) {
    PrimExpr disjunction = Bool(false);
    for (Key term : chunk) {
      disjunction = disjunction || GetExpr(term);
    }
    return disjunction;
  };

  PrimExpr conjunction = Bool(true);
  for (const std::vector<Key>& chunk : chunks_) {
    conjunction = conjunction && or_of_chunk(chunk);
  }
  return conjunction;
}
void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
  // Attempt to rewrite (a || b).  If the rewrite is not structurally
  // different, both keys are left untouched.
  const PrimExpr joint = GetExpr(*a_ptr) || GetExpr(*b_ptr);
  const PrimExpr simplified = analyzer->rewrite_simplify(joint);
  if (ExprDeepEqual()(simplified, joint)) {
    return;
  }

  if (const auto* as_or = simplified.as<OrNode>()) {
    // Still an OR: store its two operands back into the pair.
    *a_ptr = GetKey(as_or->a);
    *b_ptr = GetKey(as_or->b);
    return;
  }

  // Collapsed to a single expression: `a` becomes the OR identity (false).
  *a_ptr = key_false_;
  *b_ptr = GetKey(simplified);
}
void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
  // Attempt to rewrite (a && b).  If the rewrite is not structurally
  // different, both keys are left untouched.
  const PrimExpr joint = GetExpr(*a_ptr) && GetExpr(*b_ptr);
  const PrimExpr simplified = analyzer->rewrite_simplify(joint);
  if (ExprDeepEqual()(simplified, joint)) {
    return;
  }

  if (const auto* as_and = simplified.as<AndNode>()) {
    // Still an AND: store its two operands back into the pair.
    *a_ptr = GetKey(as_and->a);
    *b_ptr = GetKey(as_and->b);
    return;
  }

  // Collapsed to a single expression: `a` becomes the AND identity (true).
  *a_ptr = key_true_;
  *b_ptr = GetKey(simplified);
}
void AndOfOrs::Simplify(Analyzer* analyzer) {
  // Run each pass in order.  After each one, clear out the true/false
  // placeholder keys it left behind before the next pass sees the chunks.
  using Pass = void (AndOfOrs::*)(Analyzer*);
  const Pass passes[] = {&AndOfOrs::SimplifyWithinChunks, &AndOfOrs::SimplifyAcrossChunks};
  for (Pass pass : passes) {
    (this->*pass)(analyzer);
    RemoveTrueFalse();
  }
}
void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) {
  // Try every unordered pair of terms inside each OR-group.  TrySimplifyOr
  // overwrites the pair in place when the analyzer finds a different form,
  // and never changes the group's length.
  for (std::vector<Key>& terms : chunks_) {
    const size_t count = terms.size();
    for (size_t first = 0; first + 1 < count; ++first) {
      for (size_t second = first + 1; second < count; ++second) {
        TrySimplifyOr(&terms[first], &terms[second], analyzer);
      }
    }
  }
}
void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) {
  // Drop repeated keys from an OR-group, keeping the first occurrence of
  // each.  OR is idempotent, so this does not change the group's meaning.
  // The single-difference rewrite below relies on it for two things:
  //   - it counts shared terms to decide whether exactly one term differs;
  //   - it assumes every term other than the distinct one is false.
  // Both are only sound when no key repeats.  Repeats may appear because
  // TrySimplifyOr/TrySimplifyAnd map their results through GetKey, which
  // can return a key that already occurs elsewhere in the group.
  auto remove_duplicate_keys = [](std::vector<Key>* chunk) {
    std::unordered_set<Key> seen;
    std::vector<Key> unique_keys;
    unique_keys.reserve(chunk->size());
    for (Key key : *chunk) {
      if (seen.insert(key).second) {
        unique_keys.push_back(key);
      }
    }
    *chunk = std::move(unique_keys);
  };

  for (size_t i_and = 0; i_and < chunks_.size(); i_and++) {
    for (size_t j_and = i_and + 1; j_and < chunks_.size(); j_and++) {
      auto& i_chunk = chunks_[i_and];
      auto& j_chunk = chunks_[j_and];
      remove_duplicate_keys(&i_chunk);
      remove_duplicate_keys(&j_chunk);

      if (i_chunk.size() == 1 && j_chunk.size() == 1) {
        // Two single-term groups: try to simplify (i && j) directly.
        TrySimplifyAnd(&i_chunk[0], &j_chunk[0], analyzer);
        continue;
      }

      std::unordered_set<Key> j_set(j_chunk.begin(), j_chunk.end());
      std::optional<size_t> i_distinct_index;
      for (size_t i = 0; i < i_chunk.size(); i++) {
        if (!j_set.count(i_chunk[i])) {
          i_distinct_index = i;
          break;
        }
      }

      if (!i_distinct_index.has_value()) {
        // I = (i_0 || i_1 || ... || i_N)
        // J = (i_0 || i_1 || ... || i_N || j_0 || ... || j_N)
        // I && J == I == I && true
        j_chunk = {key_true_};
        continue;
      }

      std::unordered_set<Key> i_set(i_chunk.begin(), i_chunk.end());
      std::optional<size_t> j_distinct_index;
      for (size_t j = 0; j < j_chunk.size(); j++) {
        if (!i_set.count(j_chunk[j])) {
          j_distinct_index = j;
          break;
        }
      }

      if (!j_distinct_index.has_value()) {
        // I = (i_0 || ... || i_N || j_0 || ... || j_N)
        // J = (j_0 || ... || j_N)
        // I && J == J == true && J
        i_chunk = {key_true_};
        continue;
      }

      if (i_chunk.size() == j_chunk.size()) {
        // Both groups are duplicate-free here, so this counts |I ∩ J|.
        size_t num_shared_exprs = 0;
        for (Key j_key : j_chunk) {
          if (i_set.count(j_key)) {
            ++num_shared_exprs;
          }
        }

        if (num_shared_exprs + 1 == i_chunk.size()) {
          // All but one of the expressions are shared.  If the AND
          // of the distinct expressions can be simplified, we can
          // replace.
          //
          //   (A or B) and (A or C) => A or (B and C)
          Key& key_i = i_chunk[i_distinct_index.value()];
          Key& key_j = j_chunk[j_distinct_index.value()];

          // When attempting to simplify (B and C), the analyzer may
          // assume that every shared term A is false.  Exclude the
          // distinct term by value, not by address.
          const Key distinct_i = key_i;
          PrimExpr known = Bool(true);
          for (Key key : i_chunk) {
            if (key != distinct_i) {
              known = known && analyzer->Simplify(!GetExpr(key));
            }
          }

          With<ConstraintContext> context(analyzer, known);
          TrySimplifyAnd(&key_i, &key_j, analyzer);
        }
      }
    }
  }
}
void AndOfOrs::RemoveTrueFalse() {
for (auto& chunk : chunks_) {
// Any occurrence of True inside an OR makes the entire expression True.
if (std::any_of(chunk.begin(), chunk.end(), [&](Key key) { return key == key_true_; })) {
chunk = {key_true_};
} else {
// Any occurrence of False inside an OR can be removed
chunk.erase(
std::remove_if(chunk.begin(), chunk.end(), [&](Key key) { return key == key_false_; }),
chunk.end());
}
}
// Any occurence of False inside an AND makes the entire expression False.
if (std::any_of(chunks_.begin(), chunks_.end(),
[&](const std::vector<Key>& chunk) { return chunk.size() == 0; })) {
chunks_ = {{}};
} else {
// Any occurrence of True inside an AND can be removed.
chunks_.erase(std::remove_if(chunks_.begin(), chunks_.end(),
[&](const std::vector<Key>& chunk) {
return chunk.size() == 1 && chunk[0] == key_true_;
}),
chunks_.end());
}
}
// Scoped guard for an Analyzer's rewrite-simplifier extensions.
// On construction it records the enabled extensions and clears
// RewriteSimplifier::kConvertBooleanToAndOfOrs.  On destruction it
// restores the recorded set.  SimplifyAsAndOfOrs holds one while it uses
// the analyzer, to prevent infinite recursion.
class DisableAndOfOrRecursion {
 public:
  explicit DisableAndOfOrRecursion(Analyzer* analyzer)
      : analyzer_(analyzer), cached_flags_(analyzer->rewrite_simplify.GetEnabledExtensions()) {
    const auto without_and_of_ors = cached_flags_ & ~RewriteSimplifier::kConvertBooleanToAndOfOrs;
    analyzer_->rewrite_simplify.SetEnabledExtensions(
        static_cast<RewriteSimplifier::Extension>(without_and_of_ors));
  }

  ~DisableAndOfOrRecursion() { analyzer_->rewrite_simplify.SetEnabledExtensions(cached_flags_); }

  // Non-copyable: exactly one guard restores the recorded flags.
  DisableAndOfOrRecursion(const DisableAndOfOrRecursion&) = delete;
  DisableAndOfOrRecursion& operator=(const DisableAndOfOrRecursion&) = delete;

 private:
  Analyzer* analyzer_;
  RewriteSimplifier::Extension cached_flags_;
};
} // namespace
PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer) {
  // Keep the analyzer from re-entering the AND-of-ORs conversion while it
  // is used below.  The previous flags are restored when `no_recursion`
  // goes out of scope.
  DisableAndOfOrRecursion no_recursion(analyzer);

  const PrimExpr pre_simplified = analyzer->Simplify(expr);
  AndOfOrs and_of_ors(pre_simplified);
  and_of_ors.Simplify(analyzer);
  return and_of_ors.AsPrimExpr();
}
} // namespace arith
} // namespace tvm