/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file int_constraints.cc
 * \brief The integer constraints data structures.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <unordered_map>
#include <utility>

#include "../tir/transforms/ir_util.h"

namespace tvm {
namespace arith {

/*!
 * \brief Flatten per-variable bounds plus extra relations into a list of conditions.
 *
 * For every variable v, in the order given by \p variables, this emits
 * `coef * v == e`, `coef * v >= e` and `coef * v <= e` for every e in the
 * variable's `equal`, `lower` and `upper` bounds respectively. \p relations
 * are appended verbatim after all bound conditions.
 *
 * \param variables The variables. Iterating over them (rather than over \p bounds)
 *        keeps the order of the generated conditions deterministic.
 * \param bounds Bounds for each variable; must have the same size as \p variables
 *        and contain an entry for every variable.
 * \param relations Additional conditions appended at the end.
 * \return The combined list of conditions.
 */
Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
                             const Array<PrimExpr>& relations) {
  Array<PrimExpr> res;
  // Iterate over `variables` instead of `bounds` so as to get rid of any
  // non-determinism in the order of the generated conditions.
  CHECK_EQ(variables.size(), bounds.size())
      << "AsConditions expects exactly one IntGroupBounds entry per variable";
  for (const Var& v : variables) {
    CHECK(bounds.count(v)) << "AsConditions: no IntGroupBounds found for variable " << v;
    const IntGroupBounds bnds = bounds[v];
    // All bounds constrain coef * v, not v itself.
    const PrimExpr lhs = bnds->coef * v;
    for (const PrimExpr& rhs : bnds->equal) {
      res.push_back(tir::EQ(lhs, rhs));
    }
    for (const PrimExpr& rhs : bnds->lower) {
      res.push_back(tir::GE(lhs, rhs));
    }
    for (const PrimExpr& rhs : bnds->upper) {
      res.push_back(tir::LE(lhs, rhs));
    }
  }
  for (const PrimExpr& e : relations) {
    res.push_back(e);
  }
  return res;
}

/*!
 * \brief Construct a set of bounds on the expression `coef * v`.
 * \param coef Integer coefficient multiplying the (implicit) variable.
 * \param lower Expressions e with `coef * v >= e`.
 * \param equal Expressions e with `coef * v == e`.
 * \param upper Expressions e with `coef * v <= e`.
 */
IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
                               Array<PrimExpr> upper) {
  CHECK(coef.dtype().is_int() || coef.dtype().is_uint())
      << "Coefficient in IntGroupBounds must be integers";
  auto n = make_object<IntGroupBoundsNode>();
  n->upper = std::move(upper);
  n->equal = std::move(equal);
  n->lower = std::move(lower);
  n->coef = std::move(coef);
  data_ = std::move(n);
}

/*!
 * \brief Build bounds with coef = 1 describing v in [r->min, r->min + r->extent).
 *
 * A range of extent one becomes a single equality `v == r->min`; any other range
 * becomes the pair `v >= r->min` and `v <= r->min + r->extent - 1`.
 */
IntGroupBounds IntGroupBounds::FromRange(const Range& r) {
  const PrimExpr unit_coef = tir::make_const(r->min.dtype(), 1);
  if (tir::is_one(r->extent)) {
    return IntGroupBounds(unit_coef, Array<PrimExpr>(), Array<PrimExpr>{r->min},
                          Array<PrimExpr>());
  }
  Analyzer analyzer;
  const PrimExpr last = analyzer.Simplify(r->min + r->extent - 1);
  return IntGroupBounds(unit_coef, Array<PrimExpr>{r->min}, Array<PrimExpr>(),
                        Array<PrimExpr>{last});
}

/*!
 * \brief Add the constraint "v lies in r" to these bounds on `coef * v`.
 *
 * The range endpoints are scaled by coef. The new bound is placed before the
 * existing ones in the resulting arrays.
 * NOTE(review): scaling both endpoints by coef keeps lower/upper in the right
 * direction only when coef is positive; confirm callers never use a negative coef.
 */
IntGroupBounds IntGroupBounds::operator+(const Range& r) {
  const IntGroupBoundsNode* self = operator->();
  const PrimExpr& coef = self->coef;
  Analyzer analyzer;

  Array<PrimExpr> lower;
  Array<PrimExpr> equal;
  Array<PrimExpr> upper;
  if (!tir::is_one(r->extent)) {
    lower.push_back(analyzer.Simplify(r->min * coef));
    upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef));
  } else {
    equal.push_back(analyzer.Simplify(r->min * coef));
  }

  // Keep all previously known bounds after the newly added one.
  auto append_all = [](Array<PrimExpr>* dst, const Array<PrimExpr>& src) {
    for (const PrimExpr& e : src) dst->push_back(e);
  };
  append_all(&equal, self->equal);
  append_all(&lower, self->lower);
  append_all(&upper, self->upper);
  return IntGroupBounds(coef, lower, equal, upper);
}

/*!
 * \brief Apply a variable substitution to the coefficient and every bound expression.
 * \param subst Mapping from variables to their replacement expressions.
 * \return New bounds with the substitution applied.
 */
IntGroupBounds IntGroupBounds::Substitute(const Map<Var, PrimExpr>& subst) const {
  const IntGroupBoundsNode* self = operator->();
  auto substitute_one = [&subst](const PrimExpr& expr) { return tir::Substitute(expr, subst); };
  PrimExpr new_coef = tir::Substitute(self->coef, subst);
  Array<PrimExpr> new_lower = tir::UpdateArray(self->lower, substitute_one);
  Array<PrimExpr> new_equal = tir::UpdateArray(self->equal, substitute_one);
  Array<PrimExpr> new_upper = tir::UpdateArray(self->upper, substitute_one);
  return IntGroupBounds(new_coef, new_lower, new_equal, new_upper);
}

/*!
 * \brief Derive a Range for the variable v from these bounds on `coef * v`.
 *
 * Every `equal` entry serves both as a lower and an upper bound. Every pair of
 * (lower, upper) bounds is tried, and the pair whose width can be proved to be
 * strictly smallest is kept; ties keep the earlier pair.
 *
 * \param vranges_addl Ranges of other variables the bound expressions may depend on.
 *        They are bound in the analyzer and used to over-approximate each candidate's width.
 * \return The chosen range for v; its extent may over-approximate the true one.
 *         If there is no lower bound or no upper bound, an undefined Range is returned.
 */
Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const {
  Analyzer analyzer;
  analyzer.Bind(vranges_addl);

  // Interval sets of the additional variables, used by EvalSet below.
  std::unordered_map<const VarNode*, IntSet> var_intsets;
  for (auto kv : vranges_addl) {
    var_intsets[kv.first.get()] = IntSet::FromRange(kv.second);
  }

  const Array<PrimExpr>& equal = operator->()->equal;
  const PrimExpr& coef = operator->()->coef;

  // An equality is simultaneously a lower and an upper bound.
  std::vector<PrimExpr> lowers(equal.begin(), equal.end());
  std::vector<PrimExpr> uppers(equal.begin(), equal.end());
  for (const auto& expr : operator->()->lower) {
    lowers.push_back(expr);
  }
  for (const auto& expr : operator->()->upper) {
    uppers.push_back(expr);
  }

  // Fast path: a single lower and a single upper bound on v itself (coef == 1)
  // directly give the inclusive interval [lower, upper].
  if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) {
    return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1));
  }

  // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the
  // pair with the minimal difference between the upper and the lower.
  // Note that the bounds are for v, not for v*coef

  // The lower bound of the best pair so far
  PrimExpr best_lower;
  // The difference between the upper and the lower of the best pair, maybe overapproximation
  PrimExpr best_diff_over;

  for (const PrimExpr& low : lowers) {
    for (const PrimExpr& upp : uppers) {
      PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
      // Since diff may depend on some other variables, we compute its overapproximation
      PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3);

      // low is the lower bound for v*coef, but we need the lower bound for v.
      // We use rounding-up division to compute it: with floor division,
      // floordiv(low + coef - 1, coef) equals ceil(low / coef) for any sign of low,
      // so a single formula suffices (assumes coef is positive -- TODO confirm).
      PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3);

      // Compute another difference which may be more precise (or not).
      PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3);
      PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3);

      // Keep diff_over_2 only if it is provably smaller than diff_over_1.
      PrimExpr diff_over =
          analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1;

      // If it is provable that the new one is strictly better than the current best one,
      // then replace it. Note that we are biased towards earlier pairs which should be simpler.
      if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) {
        best_lower = low_divided;
        best_diff_over = diff_over;
      }
    }
  }

  // No pair was examined: either no lower or no upper bound exists.
  if (!best_lower.defined()) {
    CHECK(!best_diff_over.defined());
    return Range();
  }
  return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1));
}

TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode);

// FFI constructor: arith.IntGroupBounds(coef, lower, equal, upper).
TVM_REGISTER_GLOBAL("arith.IntGroupBounds")
    .set_body_typed([](PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
                       Array<PrimExpr> upper) {
      return IntGroupBounds(std::move(coef), std::move(lower), std::move(equal),
                            std::move(upper));
    });

// FFI wrapper around IntGroupBounds::FromRange.
TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed([](Range range) {
  return IntGroupBounds::FromRange(range);
});

// FFI wrapper around IntGroupBounds::FindBestRange; the map of additional
// variable ranges is an optional second argument.
TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      CHECK(args.size() == 1 || args.size() == 2);
      IntGroupBounds bounds = args[0];
      if (args.size() == 2) {
        *ret = bounds.FindBestRange(args[1]);
      } else {
        *ret = bounds.FindBestRange();
      }
    });

// Printed as IntGroupBounds(coef=..., lower=..., equal=..., upper=...).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntGroupBoundsNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* bounds = static_cast<const IntGroupBoundsNode*>(ref.get());
      printer->stream << "IntGroupBounds(coef=" << bounds->coef << ", lower=" << bounds->lower
                      << ", equal=" << bounds->equal << ", upper=" << bounds->upper << ")";
    });

/*!
 * \brief Construct a set of integer constraints.
 * \param variables The integer variables; an undefined array is treated as empty.
 * \param ranges Ranges of variables; an undefined map is treated as empty.
 * \param relations The relations between variables; must be defined.
 */
IntConstraints::IntConstraints(Array<Var> variables, Map<Var, Range> ranges,
                               Array<PrimExpr> relations) {
  // Normalize undefined (null) containers to empty ones.
  if (!variables.defined()) variables = Array<Var>();
  if (!ranges.defined()) ranges = Map<Var, Range>();
  CHECK(relations.defined());
  for (const Var& var : variables) {
    CHECK(var.dtype().is_int() || var.dtype().is_uint())
        << "Variables in IntConstraints must be integers";
  }
  auto n = make_object<IntConstraintsNode>();
  n->relations = std::move(relations);
  n->ranges = std::move(ranges);
  n->variables = std::move(variables);
  data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(IntConstraintsNode);

// FFI constructor: arith.IntConstraints(variables, ranges, relations).
TVM_REGISTER_GLOBAL("arith.IntConstraints")
    .set_body_typed([](Array<Var> vars, Map<Var, Range> var_ranges, Array<PrimExpr> rels) {
      return IntConstraints(std::move(vars), std::move(var_ranges), std::move(rels));
    });

// Printed as IntConstraints(<variables>, <ranges>, <relations>).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntConstraintsNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* constraints = static_cast<const IntConstraintsNode*>(ref.get());
      std::ostream& os = printer->stream;
      os << "IntConstraints(" << constraints->variables << ", " << constraints->ranges << ", "
         << constraints->relations << ")";
    });

/*!
 * \brief Construct a transformation between two sets of integer constraints.
 * \param src The source constraints.
 * \param dst The destination constraints.
 * \param src_to_dst Mapping from source variables to expressions over destination variables.
 * \param dst_to_src Mapping from destination variables to expressions over source variables.
 */
IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst,
                                                 Map<Var, PrimExpr> src_to_dst,
                                                 Map<Var, PrimExpr> dst_to_src) {
  auto n = make_object<IntConstraintsTransformNode>();
  n->dst_to_src = std::move(dst_to_src);
  n->src_to_dst = std::move(src_to_dst);
  n->dst = std::move(dst);
  n->src = std::move(src);
  data_ = std::move(n);
}

/*!
 * \brief Compose this transform (A -> B) with \p other (B -> C), yielding A -> C.
 *
 * Requires `other->src` to be the very same object as this transform's `dst`.
 * The composed mappings are simplified using the ranges of the side they are expressed over.
 */
IntConstraintsTransform IntConstraintsTransform::operator+(
    const IntConstraintsTransform& other) const {
  CHECK(other->src.same_as(operator->()->dst));
  const IntConstraintsTransformNode* self = operator->();

  // C -> A: express each variable of C over B (other's inverse map), then over A
  // (this inverse map); simplify using the ranges of A.
  Analyzer src_analyzer;
  src_analyzer.Bind(self->src->ranges);
  Map<Var, PrimExpr> composed_dst_to_src;
  for (const auto& entry : other->dst_to_src) {
    PrimExpr over_src = Substitute(entry.second, self->dst_to_src);
    composed_dst_to_src.Set(entry.first, src_analyzer.Simplify(over_src));
  }

  // A -> C: express each variable of A over B (this forward map), then over C
  // (other's forward map); simplify using the ranges of C.
  Analyzer dst_analyzer;
  dst_analyzer.Bind(other->dst->ranges);
  Map<Var, PrimExpr> composed_src_to_dst;
  for (const auto& entry : self->src_to_dst) {
    PrimExpr over_dst = Substitute(entry.second, other->src_to_dst);
    composed_src_to_dst.Set(entry.first, dst_analyzer.Simplify(over_dst));
  }

  return IntConstraintsTransform(self->src, other->dst, composed_src_to_dst, composed_dst_to_src);
}

TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);

// FFI constructor: arith.IntConstraintsTransform(src, dst, src_to_dst, dst_to_src).
TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform")
    .set_body_typed([](IntConstraints src, IntConstraints dst, Map<Var, PrimExpr> src_to_dst,
                       Map<Var, PrimExpr> dst_to_src) {
      return IntConstraintsTransform(std::move(src), std::move(dst), std::move(src_to_dst),
                                     std::move(dst_to_src));
    });

// Printed with each component on its own tab-indented line.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& ref, ReprPrinter* printer) {
      const auto* transform = static_cast<const IntConstraintsTransformNode*>(ref.get());
      std::ostream& os = printer->stream;
      os << "IntConstraintsTransform(\n\t" << transform->src << "\n\t" << transform->dst << "\n\t"
         << transform->src_to_dst << "\n\t" << transform->dst_to_src << "\n)";
    });

}  // namespace arith
}  // namespace tvm
