blob: eff52308f389e03f26da1d3cc80645e3fb3ad247 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_operator.h
* \brief Additional useful operators for integer.
*/
#ifndef TVM_ARITH_INT_OPERATOR_H_
#define TVM_ARITH_INT_OPERATOR_H_
#include <limits>
#include <utility>
namespace tvm {
namespace arith {
/*!
 * \brief Check whether applying an integer operator to x and y can overflow
 *        the domain [min_value, max_value].
 *
 *  This is the generic fallback used for operators that have no dedicated
 *  specialization; it always reports that no overflow can happen.
 *
 * \param x The left operand.
 * \param y The right operand.
 * \param min_value The minimum value of the domain.
 * \param max_value The maximum value of the domain.
 * \return Whether overflow can happen (always false in the generic case).
 * \tparam Op The integer operator.
 */
template <typename Op>
inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // No operator-specific rule is available here, so nothing is flagged.
  constexpr bool kMayOverflow = false;
  return kMayOverflow;
}
/*!
 * \brief Overflow check for x + y against the domain [min_value, max_value].
 *
 *  The sum is never formed; instead x is compared with the bound shifted by y.
 */
template <>
inline bool WillOverflow<tir::AddNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Adding a positive y can only exceed the upper bound.
  if (y > 0) return x > max_value - y;
  // Adding a negative y can only fall below the lower bound.
  if (y < 0) return x < min_value - y;
  // y == 0 leaves x unchanged.
  return false;
}
/*!
 * \brief Overflow check for x - y against the domain [min_value, max_value].
 *
 *  The difference is never formed; instead x is compared with the bound shifted by y.
 */
template <>
inline bool WillOverflow<tir::SubNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Subtracting a positive y can only fall below the lower bound.
  if (y > 0) return x < min_value + y;
  // Subtracting a negative y can only exceed the upper bound.
  if (y < 0) return x > max_value + y;
  // y == 0 leaves x unchanged.
  return false;
}
/*!
 * \brief Overflow check for x * y against the domain [min_value, max_value].
 *
 *  The product is never formed; x is compared with the bounds divided by y.
 *  y == -1 is handled on its own because dividing a bound equal to the
 *  smallest int64_t by -1 is itself an overflowing (undefined) operation.
 */
template <>
inline bool WillOverflow<tir::MulNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Multiplying by zero always yields zero.
  if (y == 0) return false;
  if (y == -1) {
    // x * -1 == -x, which is not representable for the smallest int64_t.
    if (x == std::numeric_limits<int64_t>::min()) return true;
    // Compare the negated value directly instead of computing bound / -1.
    const int64_t negated = -x;
    return negated < min_value || negated > max_value;
  }
  if (y > 0) {
    // Positive multiplier keeps the ordering of the bounds.
    return x < min_value / y || x > max_value / y;
  }
  // y < -1: a negative multiplier flips the ordering of the bounds.
  return x > min_value / y || x < max_value / y;
}
/*!
 * \brief Check for the modulo operator: only a zero divisor is reported.
 */
template <>
inline bool WillOverflow<tir::ModNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  const bool zero_divisor = (y == 0);
  return zero_divisor;
}
/*!
 * \brief Perform truncated division (rounding toward zero) of two integers.
 * \param x The dividend.
 * \param y The divisor.
 * \return The quotient x / y.
 */
inline int64_t truncdiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;
  return quotient;
}
/*!
 * \brief Compute the remainder that pairs with truncated division.
 * \param x The dividend.
 * \param y The divisor.
 * \return The remainder x % y.
 */
inline int64_t truncmod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  return remainder;
}
/*!
 * \brief Perform floor division (rounding toward negative infinity) of two integers.
 * \param x The dividend.
 * \param y The divisor.
 * \return The floored quotient.
 */
inline int64_t floordiv(int64_t x, int64_t y) {
  int64_t quotient = x / y;
  const int64_t remainder = x % y;
  // Truncation rounded the wrong way exactly when a non-zero remainder has
  // the opposite sign of the divisor; step the quotient down in that case.
  const bool needs_adjust = (remainder != 0) && ((remainder < 0) != (y < 0));
  if (needs_adjust) {
    --quotient;
  }
  return quotient;
}
/*!
 * \brief Compute the remainder that pairs with floor division.
 * \param x The dividend.
 * \param y The divisor.
 * \return The remainder, which is zero or carries the sign of y.
 */
inline int64_t floormod(int64_t x, int64_t y) {
  int64_t remainder = x % y;
  // A non-zero truncated remainder whose sign differs from the divisor
  // is shifted by one divisor to land on the floor-division remainder.
  if (remainder != 0 && ((remainder < 0) != (y < 0))) {
    remainder += y;
  }
  return remainder;
}
/*!
 * \brief Use the Extended Euclidean algorithm to solve a * x + b * y = gcd(a, b).
 *
 *  The returned GCD is non-negative for every input whose GCD magnitude is
 *  representable in int64_t, including negative b. The outputs always satisfy
 *  a * (*x) + b * (*y) == returned value.
 *
 * \param a The first coefficient.
 * \param b The second coefficient.
 * \param x Output: the solution for x.
 * \param y Output: the solution for y (set to 1 when b == 0).
 * \return The GCD of a and b (0 when both are 0).
 *
 * \note The smallest int64_t is not supported for a because its absolute
 *       value cannot be represented.
 */
inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
  // If a < 0, the problem is converted into
  //   |a| * (-x) + b * y = gcd(|a|, b)
  //
  // Initial conditions:
  //   a * 0 + b * 1 = b
  //   a * 1 + b * 0 = a
  int64_t s = 0, old_s = 1;
  int64_t r = b, old_r = a >= 0 ? a : -a;
  // Iteration (|r2| < |r1|):
  //   a * x1 + b * y1 = r1
  //   a * x2 + b * y2 = r2
  // With q = r1 / r2 the two equations give
  //   a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
  // Because |r3| < |r2|, the iteration eventually terminates.
  while (r != 0) {
    int64_t q = old_r / r;
    int64_t tmp = old_r;
    old_r = r;
    r = tmp - q * r;
    tmp = old_s;
    old_s = s;
    s = tmp - q * s;
  }
  // With a negative b the remainder sequence can stop on a negative value,
  // e.g. (4, -2) ends on -2. Flip the sign of the whole identity so that the
  // reported GCD is non-negative. The smallest int64_t is left untouched
  // because it cannot be negated.
  if (old_r < 0 && old_r != std::numeric_limits<int64_t>::min()) {
    old_r = -old_r;
    old_s = -old_s;
  }
  // Undo the |a| substitution made above.
  *x = a >= 0 ? old_s : -old_s;
  if (b != 0) {
    // Recover y from the identity a * x + b * y = gcd.
    *y = (old_r - (*x) * a) / b;
  } else {
    *y = 1;
  }
  return old_r;
}
/*!
 * \brief Take the GCD of a and b, accepting zero and negative operands.
 *
 *  Both operands are replaced by their absolute values first. If one operand
 *  is zero the other one is returned, so the result is 0 only when both are 0.
 *
 * \param a The first operand.
 * \param b The second operand.
 * \return The greatest common divisor of |a| and |b|.
 */
inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
  int64_t current = a < 0 ? -a : a;
  int64_t next = b < 0 ? -b : b;
  // Euclid's algorithm: replace (current, next) by (next, current mod next)
  // until the second value reaches zero.
  while (next != 0) {
    const int64_t remainder = current % next;
    current = next;
    next = remainder;
  }
  return current;
}
/*!
 * \brief Calculate the least common multiple of two values.
 *
 *  The result is computed as (a / gcd) * b, so the intermediate value never
 *  exceeds the magnitude of the result itself. No sign normalization is
 *  applied to the operands.
 *
 * \param a An integer number.
 * \param b An integer number.
 * \return The least common multiple, or 0 when both a and b are 0.
 */
inline int64_t LeastCommonMultiple(int64_t a, int64_t b) {
  // Only the GCD is needed; the Bezout coefficients are discarded.
  int64_t x, y;
  const int64_t gcd = ExtendedEuclidean(a, b, &x, &y);
  // The GCD is 0 only for a == b == 0; avoid dividing by zero in that case.
  if (gcd == 0) return 0;
  // gcd divides a exactly, so dividing first gives the same value as
  // (a * b) / gcd without forming the possibly overflowing product a * b.
  return (a / gcd) * b;
}
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_OPERATOR_H_