ARROW-11070: [C++][Compute] Implement power kernel
This is to resolve [ARROW-11070](https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11070).
Closes #9841 from rok/ARROW-11070
Lead-authored-by: Rok <rok@mihevc.org>
Co-authored-by: Yibo Cai <yibo.cai@arm.com>
Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index f4696fb..d169fd2 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -52,6 +52,7 @@
SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked")
SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
+SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
// ----------------------------------------------------------------------
// Set-related operations
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index f59426d..6032f65 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -204,6 +204,20 @@
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);
+/// \brief Raise the values of base array to the power of the exponent array values.
+/// Array values must be the same length. If either base or exponent is null the result
+/// will be null.
+///
+/// \param[in] left the base
+/// \param[in] right the exponent
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise base value raised to the power of exponent
+ARROW_EXPORT
+Result<Datum> Power(const Datum& left, const Datum& right,
+ ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
/// \brief Compare a numeric array with a scalar.
///
/// \param[in] left datum to compare, must be an Array
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index 7abaa1c..260721b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+#include <cmath>
+
#include "arrow/compute/kernels/common.h"
#include "arrow/util/int_util_internal.h"
#include "arrow/util/macros.h"
@@ -233,6 +235,70 @@
}
};
+struct Power {
+ ARROW_NOINLINE
+ static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
+ // right to left O(logn) power
+ uint64_t pow = 1;
+ while (exp) {
+ pow *= (exp & 1) ? base : 1;
+ base *= base;
+ exp >>= 1;
+ }
+ return pow;
+ }
+
+ template <typename T>
+ static enable_if_integer<T> Call(KernelContext* ctx, T base, T exp) {
+ if (exp < 0) {
+ ctx->SetStatus(
+ Status::Invalid("integers to negative integer powers are not allowed"));
+ return 0;
+ }
+ return static_cast<T>(IntegerPower(base, exp));
+ }
+
+ template <typename T>
+ static enable_if_floating_point<T> Call(KernelContext* ctx, T base, T exp) {
+ return std::pow(base, exp);
+ }
+};
+
+struct PowerChecked {
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_integer<T> Call(KernelContext* ctx, Arg0 base, Arg1 exp) {
+ if (exp < 0) {
+ ctx->SetStatus(
+ Status::Invalid("integers to negative integer powers are not allowed"));
+ return 0;
+ } else if (exp == 0) {
+ return 1;
+ }
+ // left to right O(logn) power with overflow checks
+ bool overflow = false;
+ uint64_t bitmask =
+ 1ULL << (63 - BitUtil::CountLeadingZeros(static_cast<uint64_t>(exp)));
+ T pow = 1;
+ while (bitmask) {
+ overflow |= MultiplyWithOverflow(pow, pow, &pow);
+ if (exp & bitmask) {
+ overflow |= MultiplyWithOverflow(pow, base, &pow);
+ }
+ bitmask >>= 1;
+ }
+ if (overflow) {
+ ctx->SetStatus(Status::Invalid("overflow"));
+ }
+ return pow;
+ }
+
+ template <typename T, typename Arg0, typename Arg1>
+ static enable_if_floating_point<T> Call(KernelContext* ctx, Arg0 base, Arg1 exp) {
+ static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
+ return std::pow(base, exp);
+ }
+};
+
// Generate a kernel given an arithmetic functor
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) {
@@ -359,6 +425,18 @@
"integer overflow is encountered."),
{"dividend", "divisor"}};
+const FunctionDoc pow_doc{
+ "Raise arguments to power element-wise",
+ ("Integer to negative integer power returns an error. However, integer overflow\n"
+ "wraps around. If either base or exponent is null the result will be null."),
+ {"base", "exponent"}};
+
+const FunctionDoc pow_checked_doc{
+ "Raise arguments to power element-wise",
+ ("An error is returned when integer to negative integer power is encountered,\n"
+ "or integer overflow is encountered."),
+ {"base", "exponent"}};
+
} // namespace
void RegisterScalarArithmetic(FunctionRegistry* registry) {
@@ -407,6 +485,15 @@
auto divide_checked =
MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(divide_checked)));
+
+ // ----------------------------------------------------------------------
+ auto power = MakeArithmeticFunction<Power>("power", &pow_doc);
+ DCHECK_OK(registry->AddFunction(std::move(power)));
+
+ // ----------------------------------------------------------------------
+ auto power_checked =
+ MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(power_checked)));
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index 4d4f14e..cd5f298 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -590,6 +590,114 @@
this->AssertBinop(Divide, MakeArray(min), MakeArray(-1), "[0]");
}
+TYPED_TEST(TestBinaryArithmeticFloating, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+ this->SetNansEqual(true);
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[3.4, 16, 0.64, 1.2, 0]", "[1, 0.5, 2, 4, 0]",
+ "[3.4, 4, 0.4096, 2.0736, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 1, 3.3, null, 2]", "[1, 4, 2, 5, 0.1]",
+ "[null, 1, 10.89, null, 1.07177346]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 10.0F, "[null, 1, 2.5, null, 2, 5]",
+ "[null, 10, 316.227766017, null, 100, 100000]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 1, 2.5, null, 2, 5]", 10.0F,
+ "[null, 1, 9536.74316406, null, 1024, 9765625]");
+ // Array with infinity
+ this->AssertBinop(Power, "[3.4, Inf, -Inf, 1.1, 100000]", "[1, 2, 3, Inf, 100000]",
+ "[3.4, Inf, -Inf, Inf, Inf]");
+ // Array with NaN
+ this->AssertBinop(Power, "[3.4, NaN, 2.0]", "[1, 2, 2.0]", "[3.4, NaN, 4.0]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 21.0F, 3.0F, 9261.0F);
+ // Divide by zero
+ this->AssertBinop(Power, "[0.0, 0.0]", "[-1.0, -3.0]", "[Inf, Inf]");
+ // Check overflow behaviour
+ this->AssertBinop(Power, max, 10, INFINITY);
+ }
+
+ // Edge cases - removing NaNs
+ this->AssertBinop(Power, "[1, NaN, 0, null, 1.2, -Inf, Inf, 1.1, 1, 0, 1, 0]",
+ "[NaN, 0, NaN, 1, null, 1, 2, -Inf, Inf, 0, 0, 42]",
+ "[1, 1, NaN, null, null, -Inf, Inf, 0, 1, 1, 1, 0]");
+}
+
+TYPED_TEST(TestBinaryArithmeticIntegral, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[3, 2, 6, 2]", "[1, 1, 2, 0]", "[3, 2, 36, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 2, 3, null, 20]", "[1, 6, 2, 5, 1]",
+ "[null, 64, 9, null, 20]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 3, "[null, 3, 4, null, 2]", "[null, 27, 81, null, 9]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 10, 3, null, 2]", 2, "[null, 100, 9, null, 4]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 4, 3, 64);
+ // Edge cases
+ this->AssertBinop(Power, "[0, 1, 0]", "[0, 0, 42]", "[1, 1, 0]");
+ }
+
+ // Overflow raises
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Power, MakeArray(max), MakeArray(10), "overflow");
+ // Disable overflow check
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Power, max, 10, 1);
+}
+
+TYPED_TEST(TestBinaryArithmeticSigned, Power) {
+ using CType = typename TestFixture::CType;
+ auto max = std::numeric_limits<CType>::max();
+
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+
+ // Empty arrays
+ this->AssertBinop(Power, "[]", "[]", "[]");
+ // Ordinary arrays
+ this->AssertBinop(Power, "[-3, 2, -6, 2]", "[3, 1, 2, 0]", "[-27, 2, 36, 1]");
+ // Array with nulls
+ this->AssertBinop(Power, "[null, 10, 127, null, -20]", "[1, 2, 1, 5, 1]",
+ "[null, 100, 127, null, -20]");
+ // Scalar exponentiated by array
+ this->AssertBinop(Power, 11, "[null, 1, null, 2]", "[null, 11, null, 121]");
+ // Array exponentiated by scalar
+ this->AssertBinop(Power, "[null, 1, 3, null, 2]", 3, "[null, 1, 27, null, 8]");
+ // Scalar exponentiated by scalar
+ this->AssertBinop(Power, 16, 1, 16);
+ // Edge cases
+ this->AssertBinop(Power, "[1, 0, -1, 2]", "[0, 42, 0, 1]", "[1, 0, 1, 2]");
+ // Divide by zero raises
+ this->AssertBinopRaises(Power, MakeArray(0), MakeArray(-1),
+ "integers to negative integer powers are not allowed");
+ }
+
+ // Overflow raises
+ this->SetOverflowCheck(true);
+ this->AssertBinopRaises(Power, MakeArray(max), MakeArray(10), "overflow");
+ // Disable overflow check
+ this->SetOverflowCheck(false);
+ this->AssertBinop(Power, max, 10, 1);
+}
+
TYPED_TEST(TestBinaryArithmeticFloating, Sub) {
this->AssertBinop(Subtract, "[]", "[]", "[]");
@@ -638,7 +746,7 @@
}
TEST(TestBinaryArithmetic, DispatchBest) {
- for (std::string name : {"add", "subtract", "multiply", "divide"}) {
+ for (std::string name : {"add", "subtract", "multiply", "divide", "power"}) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 715d503..b2ecb3b 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -272,6 +272,10 @@
+--------------------------+------------+--------------------+---------------------+
| divide_checked | Binary | Numeric | Numeric |
+--------------------------+------------+--------------------+---------------------+
+| power | Binary | Numeric | Numeric |
++--------------------------+------------+--------------------+---------------------+
+| power_checked | Binary | Numeric | Numeric |
++--------------------------+------------+--------------------+---------------------+
| multiply | Binary | Numeric | Numeric |
+--------------------------+------------+--------------------+---------------------+
| multiply_checked | Binary | Numeric | Numeric |
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index d6efc6a..da16ccd 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -53,6 +53,8 @@
multiply_checked
subtract
subtract_checked
+ power
+ power_checked
Comparisons
-----------