blob: f924d559a508c383e22fe37c4ad8665399a10809 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeNumberBase.h>
#include <Functions/FunctionBinaryArithmetic.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/castTypeToEither.h>
/// Error codes thrown by SparkFunctionDivide::getReturnTypeImpl below. They are only declared here
/// (`extern`); their definitions live in another translation unit.
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
namespace local_engine
{
/// Floating-point division kernel used by sparkDivide.
///
/// The dividend is converted to the result type (by default the type chosen by
/// DB::NumberTraits::ResultOfFloatingPointDivision<A, B>) and then divided by `b`.
/// No zero-divisor check is done here; callers are responsible for choosing a non-zero divisor
/// for rows they intend to mark as NULL.
template <typename A, typename B>
struct SparkDivideFloatingImpl
{
    using ResultType = typename DB::NumberTraits::ResultOfFloatingPointDivision<A, B>::Type;

    /// Neither FixedString nor string/integer operand combinations are supported.
    static constexpr bool allow_fixed_string = false;
    static constexpr bool allow_string_integer = false;

    template <typename Result = ResultType>
    static inline NO_SANITIZE_UNDEFINED Result apply(A a, B b)
    {
        const auto dividend = static_cast<Result>(a);
        return dividend / b;
    }
};
/// sparkDivide(a, b): Spark-compatible division of two Float64 values.
///
/// The result type is Nullable(Float64). A row is NULL exactly when its divisor compares equal to
/// zero (+0.0 or -0.0); otherwise the value is a / b. A NaN divisor is not treated as zero, so such
/// rows are non-NULL and evaluate to NaN.
class SparkFunctionDivide : public DB::IFunction
{
public:
    static constexpr auto name = "sparkDivide";

    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionDivide>(); }

    SparkFunctionDivide() = default;
    ~SparkFunctionDivide() override = default;

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    /// NOTE: executeImpl relies on DB::IFunction's default constant handling, so it expects at most
    /// one ColumnConst argument; the all-constant branch in vectorImpl is only a defensive fallback.
    bool useDefaultImplementationForConstants() const override { return true; }

    bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    /// Validates the argument types and returns Nullable(Float64).
    /// @throws DB::Exception NUMBER_OF_ARGUMENTS_DOESNT_MATCH if there are not exactly two arguments.
    /// @throws DB::Exception ILLEGAL_TYPE_OF_ARGUMENT if either argument is not Float64.
    DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
    {
        if (arguments.size() != 2)
            throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 2", name);

        /// Spark guarantees that the input arguments of sparkDivide are float64
        /// by transforming "ia/ib" to "cast(ia as float64)/cast(ib as float64)"
        DB::WhichDataType left(arguments[0]);
        DB::WhichDataType right(arguments[1]);
        if (!left.isFloat64() || !right.isFloat64())
            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s arguments type must be float64", name);

        return DB::makeNullable(std::make_shared<const DB::DataTypeFloat64>());
    }

    /// Computes a / b for every row into a ColumnNullable whose null map flags zero divisors.
    /// A constant zero divisor short-circuits to a constant all-NULL column of input_rows_count rows.
    DB::ColumnPtr
    executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override
    {
        using L = Float64;
        using R = Float64;
        using T = Float64;

        const DB::ColumnVector<L> * col_left = nullptr;
        const DB::ColumnVector<R> * col_right = nullptr;
        const DB::ColumnVector<L> * const_col_left = checkAndGetColumnConstData<DB::ColumnVector<L>>(arguments[0].column.get());
        const DB::ColumnVector<R> * const_col_right = checkAndGetColumnConstData<DB::ColumnVector<R>>(arguments[1].column.get());

        L left_const_val = 0;
        if (const_col_left)
            left_const_val = const_col_left->getElement(0);
        else
            col_left = assert_cast<const DB::ColumnVector<L> *>(arguments[0].column.get());

        R right_const_val = 0;
        if (const_col_right)
        {
            right_const_val = const_col_right->getElement(0);
            /// Fast path: every row divides by zero, so the whole result is NULL.
            if (right_const_val == 0)
            {
                auto data_col = DB::ColumnVector<T>::create(1, 0);
                auto null_map_col = DB::ColumnVector<UInt8>::create(1, 1);
                return DB::ColumnConst::create(DB::ColumnNullable::create(std::move(data_col), std::move(null_map_col)), input_rows_count);
            }
        }
        else
            col_right = assert_cast<const DB::ColumnVector<R> *>(arguments[1].column.get());

        /// Both output columns start zero-filled; the "column / non-zero constant" branch of
        /// vectorImpl relies on this and never writes the null map.
        auto res_col = DB::ColumnVector<T>::create(input_rows_count, 0);
        auto res_null_map = DB::ColumnVector<UInt8>::create(input_rows_count, 0);
        DB::PaddedPODArray<T> & res_data = res_col->getData();
        DB::PaddedPODArray<UInt8> & res_null_map_data = res_null_map->getData();
        vector(col_left, col_right, left_const_val, right_const_val, res_data, res_null_map_data, input_rows_count);
        return DB::ColumnNullable::create(std::move(res_col), std::move(res_null_map));
    }

    /// Row-wise division kernel, compiled once per target (AVX2, SSE4.2, generic).
    /// A null `col_left` / `col_right` means that operand is the corresponding `*_const_val`.
    /// Rows with a zero divisor get null_map = 1; their value is computed with divisor 1 so the
    /// kernel never divides by zero.
    MULTITARGET_FUNCTION_AVX2_SSE42(
    MULTITARGET_FUNCTION_HEADER(static void NO_SANITIZE_UNDEFINED NO_INLINE),
    vectorImpl,
    MULTITARGET_FUNCTION_BODY(
    (const DB::ColumnVector<Float64> * col_left,
        const DB::ColumnVector<Float64> * col_right,
        Float64 left_const_val,
        Float64 right_const_val,
        DB::PaddedPODArray<Float64> & res_data,
        DB::PaddedPODArray<UInt8> & res_null_map_data,
        size_t input_rows_count) /// NOLINT
    {
        /// Hoist the raw data pointers out of the loops. The UInt8 null-map stores may alias any
        /// object (char-type aliasing rule), so indexing through the column / array objects would
        /// force the compiler to reload their data pointers on every iteration.
        Float64 * res = res_data.data();
        UInt8 * null_map = res_null_map_data.data();

        if (col_left && col_right)
        {
            const Float64 * lhs = col_left->getData().data();
            const Float64 * rhs = col_right->getData().data();
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                const Float64 l = lhs[i];
                const Float64 r = rhs[i];
                res[i] = SparkDivideFloatingImpl<Float64, Float64>::apply(l, r ? r : 1);
                null_map[i] = !r;
            }
        }
        else if (col_left)
        {
            const Float64 * lhs = col_left->getData().data();
            const Float64 r = right_const_val;
            /// r must not be zero because r = 0 is already processed in the fast path of executeImpl.
            /// The null map is already zero-filled, so it is left untouched here.
            for (size_t i = 0; i < input_rows_count; ++i)
                res[i] = SparkDivideFloatingImpl<Float64, Float64>::apply(lhs[i], r);
        }
        else if (col_right)
        {
            const Float64 * rhs = col_right->getData().data();
            const Float64 l = left_const_val;
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                const Float64 r = rhs[i];
                null_map[i] = !r;
                res[i] = SparkDivideFloatingImpl<Float64, Float64>::apply(l, r ? r : 1);
            }
        }
        else
        {
            /// Both operands constant. This should already be folded by the default constant handling;
            /// kept as a defensive fallback so the result is never a silent all-zero, non-NULL column.
            const Float64 r = right_const_val;
            const Float64 value = SparkDivideFloatingImpl<Float64, Float64>::apply(left_const_val, r ? r : 1);
            const UInt8 is_null = !r;
            for (size_t i = 0; i < input_rows_count; ++i)
            {
                res[i] = value;
                null_map[i] = is_null;
            }
        }
    }))

    /// Dispatches to the best vectorImpl variant supported by the running CPU:
    /// AVX2, then SSE4.2, then the generic build.
    static void NO_INLINE vector(
        const DB::ColumnVector<Float64> * col_left,
        const DB::ColumnVector<Float64> * col_right,
        Float64 left_const_val,
        Float64 right_const_val,
        DB::PaddedPODArray<Float64> & res_data,
        DB::PaddedPODArray<UInt8> & res_null_map_data,
        size_t input_rows_count)
    {
#if USE_MULTITARGET_CODE
        if (isArchSupported(DB::TargetArch::AVX2))
        {
            vectorImplAVX2(col_left, col_right, left_const_val, right_const_val, res_data, res_null_map_data, input_rows_count);
            return;
        }
        if (isArchSupported(DB::TargetArch::SSE42))
        {
            vectorImplSSE42(col_left, col_right, left_const_val, right_const_val, res_data, res_null_map_data, input_rows_count);
            return;
        }
#endif
        vectorImpl(col_left, col_right, left_const_val, right_const_val, res_data, res_null_map_data, input_rows_count);
    }

#if USE_EMBEDDED_COMPILER
    /// JIT is possible only for two arguments whose types, and the result type, have native LLVM representations.
    bool isCompilableImpl(const DB::DataTypes & arguments, const DB::DataTypePtr & result_type) const override
    {
        if (2 != arguments.size())
            return false;
        return canBeNativeType(*arguments[0]) && canBeNativeType(*arguments[1]) && canBeNativeType(*result_type);
    }

    /// Emits `left / right` and packs the quotient (field 0) and the null flag (field 1) into the
    /// native nullable value of the result type.
    llvm::Value *
    compileImpl(llvm::IRBuilderBase & builder, const DB::ValuesWithType & arguments, const DB::DataTypePtr & result_type) const override
    {
        assert(2 == arguments.size());
        auto * left = arguments[0].value;
        auto * right = arguments[1].value;

        /// `fcmp oeq` uses IEEE-754 equality: comparing with +0.0 also matches -0.0, and a NaN divisor
        /// compares unequal (non-NULL NaN result), matching the interpreted path.
        auto * zero = llvm::ConstantFP::get(right->getType(), 0.0);
        auto * is_null = builder.CreateFCmpOEQ(right, zero);
        auto * result_value = builder.CreateFDiv(left, right);

        auto * nullable_structure_type = toNativeType(builder, makeNullable(result_type));
        auto * nullable_structure_value = llvm::Constant::getNullValue(nullable_structure_type);
        auto * nullable_structure_with_result_value = builder.CreateInsertValue(nullable_structure_value, result_value, 0);
        return builder.CreateInsertValue(nullable_structure_with_result_value, is_null, 1);
    }
#endif // USE_EMBEDDED_COMPILER
};
}