blob: fa846a51cf0ad1fed62892d104410552a582e3c9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "config.h"
#include <cstring>
#include <base/arithmeticOverflow.h>
#if USE_EMBEDDED_COMPILER
#include <DataTypes/Native.h>
#include <llvm/IR/IRBuilder.h>
#endif
/// _BitInt is accepted by clang in C++ only as an extension; silence that warning just for the aliases below.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wbit-int-extension"
using BitInt128 = signed _BitInt(128);
using BitUInt128 = unsigned _BitInt(128);
#if defined(__x86_64__)
using BitInt256 = signed _BitInt(256);
using BitUInt256 = unsigned _BitInt(256);
#else
// up to version 18, clang supports large _BitInt sizes on x86 and x86-64;
// but on arm and aarch64, they are currently only supported up to 128 bits.
// https://stackoverflow.com/questions/78614816/why-am-i-getting-a-256-bit-arithmetic-error-unsigined-bitint-of-bit-sizes-gre
using BitInt256 = Int256;
using BitUInt256 = UInt256;
#endif
/// Restore the diagnostic state. This is a header, so an unmatched push would keep -Wbit-int-extension
/// silenced for the remainder of every translation unit that includes it.
#pragma clang diagnostic pop
namespace local_engine
{
/// Returns true when the two upper 64-bit limbs (items[2] and items[3]) of both operands are zero.
/// NOTE(review): assuming `items` is stored least-significant limb first (little-endian host), this means
/// both values lie in [0, 2^128) and can be narrowed to a 128-bit type without losing information.
static bool canCastLower(const Int256 & a, const Int256 & b)
{
    const bool a_high_clear = a.items[3] == 0 && a.items[2] == 0;
    const bool b_high_clear = b.items[3] == 0 && b.items[2] == 0;
    return a_high_clear && b_high_clear;
}
/// Returns true when the upper 64-bit limb (items[1]) of both operands is zero.
/// NOTE(review): assuming least-significant-limb-first storage, both values then lie in [0, 2^64).
static bool canCastLower(const Int128 & a, const Int128 & b)
{
    const bool any_high_set = a.items[1] != 0 || b.items[1] != 0;
    return !any_high_set;
}
/// Returns true when the upper 64-bit limb (items[1]) of both unsigned operands is zero.
/// NOTE(review): assuming least-significant-limb-first storage, both values then lie in [0, 2^64).
static bool canCastLower(const UInt128 & a, const UInt128 & b)
{
    const bool any_high_set = a.items[1] != 0 || b.items[1] != 0;
    return !any_high_set;
}
/// Reinterprets the object representation of a BitInt256 as an Int256.
/// The bytes are copied with std::memcpy instead of being read through a reinterpret_cast'ed reference:
/// accessing an object through a glvalue of an unrelated type violates strict aliasing, and the two types
/// may not share an alignment. Returning by value also means the result cannot dangle when `value` is a
/// temporary (callers pass the result of BitInt256 arithmetic directly).
/// Assumes both types use the same two's-complement byte layout, as the original cast did.
static Int256 toInt256(const BitInt256 & value)
{
    static_assert(sizeof(Int256) == sizeof(BitInt256), "Int256 and BitInt256 must have the same size");
    Int256 result;
    std::memcpy(&result, &value, sizeof(result));
    return result;
}
/// Reinterprets the object representation of an Int256 as a BitInt256.
/// Uses std::memcpy rather than a reinterpret_cast'ed reference to avoid a strict-aliasing violation and a
/// possible alignment mismatch; the result is returned by value so it never refers back into `value`.
/// Assumes both types use the same two's-complement byte layout, as the original cast did.
static BitInt256 toBitInt256(const Int256 & value)
{
    static_assert(sizeof(Int256) == sizeof(BitInt256), "Int256 and BitInt256 must have the same size");
    BitInt256 result;
    std::memcpy(&result, &value, sizeof(result));
    return result;
}
/// TODO(taiyang-li): remove all overflow checking in below codes because we have already checked overflow in SparkDecimalBinaryOperation
struct DecimalPlusImpl
{
template <typename T>
static bool apply(T a, T b, T & r)
{
r = a + b;
return true;
}
template <>
static bool apply(Int128 a, Int128 b, Int128 & r)
{
if (canCastLower(a, b))
{
UInt64 low_result;
if (!common::addOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
{
r = static_cast<Int128>(low_result);
chassert(r == a + b);
return true;
}
}
r = a + b;
return true;
}
template <>
static bool apply(Int256 a, Int256 b, Int256 & r)
{
if (canCastLower(a, b))
{
UInt128 low_result;
if (!common::addOverflow(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result))
{
r = static_cast<Int256>(low_result);
chassert(r == a + b);
return true;
}
}
r = toInt256(toBitInt256(a) + toBitInt256(b));
chassert(r == a + b);
return true;
}
#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = true;
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
{
return left->getType()->isIntegerTy() ? b.CreateAdd(left, right) : b.CreateFAdd(left, right);
}
#endif
};
struct DecimalMinusImpl
{
    /// Applies the subtraction r = a - b for Decimal operations.
    /// Overflow is neither checked nor reported here: every overload always returns true.
    template <typename T>
    static bool apply(T a, T b, T & r)
    {
        r = a - b;
        return true;
    }

    /// Int128: uses native Int64 subtraction when both operands lie in [0, 2^63), else full 128-bit subtraction.
    template <>
    static bool apply(Int128 a, Int128 b, Int128 & r)
    {
        if (canCastLower(a, b))
        {
            /// canCastLower only guarantees both operands are in [0, 2^64). Narrowing a value >= 2^63 to Int64 wraps it
            /// to a negative number, and the Int64 difference is then off by 2^64 without any overflow being detected
            /// (e.g. a = 2^63, b = 0 would yield -2^63). So the narrow path also requires both narrowed values to be
            /// non-negative; their difference then always fits in Int64.
            const auto low_a = static_cast<Int64>(a);
            const auto low_b = static_cast<Int64>(b);
            Int64 low_result;
            if (low_a >= 0 && low_b >= 0 && !common::subOverflow(low_a, low_b, low_result))
            {
                r = static_cast<Int128>(low_result);
                chassert(r == a - b);
                return true;
            }
        }
        r = a - b;
        return true;
    }

    /// Int256: uses native Int128 subtraction when both operands lie in [0, 2^127), else subtracts via BitInt256.
    template <>
    static bool apply(Int256 a, Int256 b, Int256 & r)
    {
        if (canCastLower(a, b))
        {
            /// canCastLower only guarantees both operands are in [0, 2^128). A value >= 2^127 would wrap to a negative
            /// Int128 and the difference would be off by 2^128, so require both narrowed values to be non-negative.
            const auto low_a = static_cast<Int128>(a);
            const auto low_b = static_cast<Int128>(b);
            Int128 low_result;
            if (low_a >= 0 && low_b >= 0 && !common::subOverflow(low_a, low_b, low_result))
            {
                r = static_cast<Int256>(low_result);
                chassert(r == a - b);
                return true;
            }
        }
        r = toInt256(toBitInt256(a) - toBitInt256(b));
        chassert(r == a - b);
        return true;
    }

#if USE_EMBEDDED_COMPILER
    static constexpr bool compilable = true;
    /// Emits an integer subtraction for integer-typed operands and a floating-point subtraction otherwise.
    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
    {
        return left->getType()->isIntegerTy() ? b.CreateSub(left, right) : b.CreateFSub(left, right);
    }
#endif
};
struct DecimalMultiplyImpl
{
    /// Applies the multiplication c = a * b for Decimal operations.
    /// Overflow is neither checked nor reported here: every overload always returns true.
    template <typename T>
    static bool apply(T a, T b, T & c)
    {
        c = a * b;
        return true;
    }

    /// Int128: when both operands lie in [0, 2^64) and their product does not overflow UInt64, multiply with
    /// native 64-bit arithmetic; otherwise use the full 128-bit multiplication. Always returns true.
    template <>
    static bool apply(Int128 a, Int128 b, Int128 & r)
    {
        if (canCastLower(a, b))
        {
            UInt64 low_result = 0;
            if (!common::mulOverflow(static_cast<UInt64>(a), static_cast<UInt64>(b), low_result))
            {
                r = static_cast<Int128>(low_result);
                chassert(r == a * b);
                return true;
            }
        }
        r = a * b;
        return true;
    }

    /// Int256: multiplies through BitInt256 (a native _BitInt(256) on x86-64, Int256 itself elsewhere).
    /// Always returns true.
    template <>
    static bool apply(Int256 a, Int256 b, Int256 & r)
    {
        /// Notice that we can't use common::mulOverflow here because it doesn't support checking overflow on Int128 multiplication.
        r = toInt256(toBitInt256(a) * toBitInt256(b));
        chassert(r == a * b);
        return true;
    }

#if USE_EMBEDDED_COMPILER
    static constexpr bool compilable = true;
    /// Emits an integer multiplication for integer-typed operands and a floating-point multiplication otherwise.
    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool)
    {
        return left->getType()->isIntegerTy() ? b.CreateMul(left, right) : b.CreateFMul(left, right);
    }
#endif
};
struct DecimalDivideImpl
{
    /// Computes r = a / b for Decimal operations.
    /// @returns false (leaving r unmodified) when b == 0, true otherwise.
    template <typename T>
    static bool apply(T a, T b, T & r)
    {
        if (b == 0)
            return false;
        r = a / b;
        return true;
    }

    /// Int128: divides with native 64-bit unsigned arithmetic when both operands lie in [0, 2^64).
    template <>
    static bool apply(Int128 a, Int128 b, Int128 & r)
    {
        if (b == 0)
            return false;
        if (canCastLower(a, b))
        {
            /// Both operands are in [0, 2^64): narrowing to UInt64 is lossless (narrowing to Int64 would not be for
            /// values >= 2^63), and for non-negative values unsigned and signed division give the same quotient.
            r = static_cast<Int128>(static_cast<UInt64>(a) / static_cast<UInt64>(b));
            chassert(r == a / b);
            return true;
        }
        r = a / b;
        return true;
    }

    /// UInt128: divides with native 64-bit arithmetic when both operands lie in [0, 2^64).
    template <>
    static bool apply(UInt128 a, UInt128 b, UInt128 & r)
    {
        if (b == 0)
            return false;
        if (canCastLower(a, b))
        {
            /// Both operands are in [0, 2^64), so narrowing to UInt64 is lossless.
            r = static_cast<UInt128>(static_cast<UInt64>(a) / static_cast<UInt64>(b));
            chassert(r == a / b);
            return true;
        }
        r = a / b;
        return true;
    }

    /// Int256: reuses the UInt128 specialization when both operands lie in [0, 2^128), else divides via BitInt256.
    template <>
    static bool apply(Int256 a, Int256 b, Int256 & r)
    {
        if (b == 0)
            return false;
        if (canCastLower(a, b))
        {
            /// Both operands are in [0, 2^128), so narrowing to UInt128 is lossless. b != 0 here, so the UInt128
            /// specialization cannot report failure and its return value can be ignored.
            UInt128 low_result;
            apply(static_cast<UInt128>(a), static_cast<UInt128>(b), low_result);
            r = static_cast<Int256>(low_result);
            chassert(r == a / b);
            return true;
        }
        r = toInt256(toBitInt256(a) / toBitInt256(b));
        chassert(r == a / b);
        return true;
    }

#if USE_EMBEDDED_COMPILER
    static constexpr bool compilable = true;
    /// Emits a division: signed/unsigned integer division for integer-typed operands, floating-point otherwise.
    /// (Previously this emitted a subtraction, which was a copy-paste bug.)
    /// The trailing flag is assumed to be the operands' signedness, following the ClickHouse arithmetic compile()
    /// convention — TODO confirm against the JIT caller.
    /// NOTE(review): integer division by zero is undefined behaviour in LLVM IR, whereas apply() reports a zero
    /// divisor by returning false; confirm the JIT caller guards zero divisors before relying on this path.
    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool is_signed)
    {
        if (left->getType()->isIntegerTy())
            return is_signed ? b.CreateSDiv(left, right) : b.CreateUDiv(left, right);
        return b.CreateFDiv(left, right);
    }
#endif
};
// ModuloImpl
struct DecimalModuloImpl
{
    /// Computes r = a % b for Decimal operations.
    /// @returns false (leaving r unmodified) when b == 0, true otherwise.
    template <typename T>
    static bool apply(T a, T b, T & r)
    {
        if (b == 0)
            return false;
        r = a % b;
        return true;
    }

    /// Int128: computes the remainder with native 64-bit unsigned arithmetic when both operands lie in [0, 2^64).
    template <>
    static bool apply(Int128 a, Int128 b, Int128 & r)
    {
        if (b == 0)
            return false;
        if (canCastLower(a, b))
        {
            /// Both operands are in [0, 2^64): narrowing to UInt64 is lossless, and for non-negative values the
            /// unsigned and signed remainders are equal.
            r = static_cast<Int128>(static_cast<UInt64>(a) % static_cast<UInt64>(b));
            chassert(r == a % b);
            return true;
        }
        r = a % b;
        return true;
    }

    /// Int256: computes the remainder as UInt128 when both operands lie in [0, 2^128), else via BitInt256.
    template <>
    static bool apply(Int256 a, Int256 b, Int256 & r)
    {
        if (b == 0)
            return false;
        if (canCastLower(a, b))
        {
            /// Both operands are in [0, 2^128), so narrowing to UInt128 is lossless.
            r = static_cast<Int256>(static_cast<UInt128>(a) % static_cast<UInt128>(b));
            chassert(r == a % b);
            return true;
        }
        r = toInt256(toBitInt256(a) % toBitInt256(b));
        chassert(r == a % b);
        return true;
    }

#if USE_EMBEDDED_COMPILER
    static constexpr bool compilable = true;
    /// Emits a remainder: signed/unsigned integer remainder for integer-typed operands, floating-point otherwise.
    /// (Previously this emitted a subtraction, which was a copy-paste bug.)
    /// The trailing flag is assumed to be the operands' signedness, following the ClickHouse arithmetic compile()
    /// convention — TODO confirm against the JIT caller.
    /// NOTE(review): integer remainder by zero is undefined behaviour in LLVM IR, whereas apply() reports a zero
    /// divisor by returning false; confirm the JIT caller guards zero divisors before relying on this path.
    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool is_signed)
    {
        if (left->getType()->isIntegerTy())
            return is_signed ? b.CreateSRem(left, right) : b.CreateURem(left, right);
        return b.CreateFRem(left, right);
    }
#endif
};
}