blob: b567859000e87d041cf87a8c8553b53b531fe811 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SparkFunctionNextDay.h"
#include <unordered_map>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/DateTimeTransforms.h>
#include <Functions/FunctionFactory.h>
#include <Functions/TransformDateTime64.h>
#include <Poco/Logger.h>
#include <Common/DateLUT.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <boost/algorithm/string/case_conv.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
}
}
namespace local_engine
{
DB::DataTypePtr SparkFunctionNextDay::getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const
{
if (arguments.size() != 2)
{
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Function {} requires exactly 2 arguments, {} provided",
getName(), arguments.size());
}
auto arg0_type = DB::removeNullable(arguments[0].type);
auto arg0_which_type = DB::WhichDataType(arg0_type);
if (!arg0_which_type.isDate() && !arg0_which_type.isDate32() && !arg0_which_type.isDateTime() && !arg0_which_type.isDateTime64())
{
throw DB::Exception(
DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument[0] of function {}",
arguments[0].type->getName(),
getName());
}
auto arg1_type = DB::removeNullable(arguments[1].type);
if (!DB::WhichDataType(arg1_type).isString())
{
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument[1] of function {}",
arguments[1].type->getName(), getName());
}
return DB::makeNullableSafe(arg0_type);
}
DB::ColumnPtr SparkFunctionNextDay::executeImpl(
const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const
{
if (DB::isColumnConst(*arguments[1].column))
{
return executeConst(arguments, result_type);
}
else
{
return executeGeneral(arguments, result_type);
}
}
struct NextDayConstTransformer
{
static constexpr auto name = "NextDay";
static const UInt8 week_mode = 0;
static inline Int64 execute(Int64 day, UInt8 next_weekday, const DateLUTImpl & time_zone)
{
auto from_week_day = DB::ToDayOfWeekImpl::execute(day, week_mode, time_zone);
return day + static_cast<Int64>(calDayDelta(from_week_day, next_weekday));
}
static inline UInt32 execute(UInt32 day, UInt8 next_weekday, const DateLUTImpl & time_zone)
{
auto from_week_day = DB::ToDayOfWeekImpl::execute(day, week_mode, time_zone);
return day + static_cast<UInt32>(calDayDelta(from_week_day, next_weekday));
}
static inline Int32 execute(Int32 day, UInt8 next_weekday, const DateLUTImpl & time_zone)
{
auto from_week_day = DB::ToDayOfWeekImpl::execute(day, week_mode, time_zone);
return day + static_cast<Int32>(calDayDelta(from_week_day, next_weekday));
}
static inline UInt16 execute(UInt16 day, UInt8 next_weekday, const DateLUTImpl & time_zone)
{
auto from_week_day = DB::ToDayOfWeekImpl::execute(day, week_mode, time_zone);
return day + static_cast<UInt16>(calDayDelta(from_week_day, next_weekday));
}
static UInt16 calDayDelta(UInt16 from_week_day, UInt16 next_weekday)
{
return next_weekday > from_week_day ? next_weekday - from_week_day : 7 + next_weekday - from_week_day;
}
};
template <typename DateType>
class NextDayConstImpl
{
public:
template<typename Transformer>
static void execute(const DB::IColumn & src_col, UInt16 next_weekday, DB::IColumn * dst_col, Transformer transformer = {})
{
const auto & time_zone = DateLUT::instance();
size_t rows = src_col.size();
for(size_t i = 0; i < rows; ++i)
{
DB::Field from_day_value;
src_col.get(i, from_day_value);
if (from_day_value.isNull())
{
dst_col->insertDefault();
}
else
{
dst_col->insert(transformer.execute(
static_cast<typename DateType::FieldType>(from_day_value.safeGet<typename DateType::FieldType>()),
next_weekday,
time_zone));
}
}
}
template<typename Transformer>
static void execute(const DB::IColumn & src_col, const DB::IColumn & weekday_col, DB::IColumn * dst_col, Transformer transformer = {})
{
size_t rows = src_col.size();
const auto & time_zone = DateLUT::instance();
for (size_t i = 0; i < rows; ++i)
{
DB::Field from_day_value;
DB::Field weekday_value;
src_col.get(i, from_day_value);
weekday_col.get(i, weekday_value);
if (from_day_value.isNull() || weekday_value.isNull())
{
dst_col->insertDefault();
}
else
{
auto next_weekday = SparkFunctionNextDay::getDayOfWeek(weekday_value.safeGet<String>());
if (next_weekday < 0) [[unlikely]]
{
dst_col->insertDefault();
}
else
{
dst_col->insert(transformer.execute(
static_cast<typename DateType::FieldType>(from_day_value.safeGet<typename DateType::FieldType>()),
next_weekday,
time_zone));
}
}
}
}
};
DB::ColumnPtr SparkFunctionNextDay::executeConst(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type) const
{
const auto * from_date_col = arguments[0].column.get();
auto nested_result_type = DB::removeNullable(result_type);
auto next_week_day = getDayOfWeek((*arguments[1].column)[0].safeGet<String>());
if (next_week_day == -1) [[unlikely]]
return result_type->createColumnConstWithDefaultValue(from_date_col->size());
auto to_date_col = result_type->createColumn();
to_date_col->reserve(from_date_col->size());
DB::WhichDataType ty_which(nested_result_type);
if (ty_which.isDate())
{
NextDayConstImpl<DB::DataTypeDate>::execute(
*from_date_col, next_week_day, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDate32())
{
NextDayConstImpl<DB::DataTypeDate32>::execute(
*from_date_col, next_week_day, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDateTime())
{
NextDayConstImpl<DB::DataTypeDateTime>::execute(
*from_date_col, next_week_day, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDateTime64())
{
NextDayConstImpl<DB::DataTypeDateTime64>::execute(
*from_date_col,
next_week_day,
to_date_col.get(),
DB::TransformDateTime64<NextDayConstTransformer>{
assert_cast<const DB::DataTypeDateTime64 *>(nested_result_type.get())->getScale()});
}
return std::move(to_date_col);
}
DB::ColumnPtr SparkFunctionNextDay::executeGeneral(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type) const
{
const auto * from_date_col = arguments[0].column.get();
const auto * weekday_col = arguments[1].column.get();
auto nested_result_type = DB::removeNullable(result_type);
auto to_date_col = result_type->createColumn();
to_date_col->reserve(from_date_col->size());
DB::WhichDataType ty_which(nested_result_type);
if (ty_which.isDate())
{
NextDayConstImpl<DB::DataTypeDate>::execute(
*from_date_col, *weekday_col, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDate32())
{
NextDayConstImpl<DB::DataTypeDate32>::execute(
*from_date_col, *weekday_col, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDateTime())
{
NextDayConstImpl<DB::DataTypeDateTime>::execute(
*from_date_col, *weekday_col, to_date_col.get(), NextDayConstTransformer{});
}
else if (ty_which.isDateTime64())
{
NextDayConstImpl<DB::DataTypeDateTime64>::execute(
*from_date_col,
*weekday_col,
to_date_col.get(),
DB::TransformDateTime64<NextDayConstTransformer>{
assert_cast<const DB::DataTypeDateTime64 *>(nested_result_type.get())->getScale()});
}
return std::move(to_date_col);
}
Int8 SparkFunctionNextDay::getDayOfWeek(const String & abbr)
{
String lower_abbr = abbr;
boost::to_lower(lower_abbr);
// week mode must be 0
static const std::unordered_map<String, Int8> abbr2day = {
{"mo", 1},
{"mon", 1},
{"monday", 1},
{"tu", 2},
{"tue", 2},
{"tuesday", 2},
{"we", 3},
{"wed", 3},
{"wednesday", 3},
{"th", 4},
{"thu", 4},
{"thursday", 4},
{"fr", 5},
{"fri", 5},
{"friday", 5},
{"sa", 6},
{"sat", 6},
{"saturday", 6},
{"su", 7},
{"sun", 7},
{"sunday", 7}
};
auto it = abbr2day.find(lower_abbr);
if (it != abbr2day.end()) [[likely]]
return it->second;
return -1;
}
REGISTER_FUNCTION(SparkFunctionNextDay)
{
factory.registerFunction<SparkFunctionNextDay>();
}
}