blob: fab3e66126002616fab4cda63ccde154eb09da40 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SparkFunctionConv.h"
#include <string>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <Poco/Logger.h>
#include <Poco/Types.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <DataTypes/IDataType.h>
#include <base/types.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
namespace local_engine
{
/// Validates the arguments of conv(num, from_base, to_base) and derives its result type.
///
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly three arguments are passed, and
/// ILLEGAL_TYPE_OF_ARGUMENT unless the second (from_base) and third (to_base) arguments are
/// integer types. The first argument's type is not checked here.
///
/// @return the first argument's type with any Nullable wrapper stripped, re-wrapped in Nullable,
///         because executeImpl may emit NULL for individual rows.
DB::DataTypePtr SparkFunctionConv::getReturnTypeImpl(const DB::DataTypes & arguments) const
{
    const size_t passed = arguments.size();
    if (passed != 3)
    {
        throw DB::Exception(
            DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
            "Number of arguments for function {} doesn't match: passed {}, should be 3.",
            getName(),
            passed);
    }

    const auto & from_base_type = arguments[1];
    if (!isInteger(from_base_type))
        throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument for function {} must be Int", getName());

    const auto & to_base_type = arguments[2];
    if (!isInteger(to_base_type))
        throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument for function {} must be Int", getName());

    const DB::DataTypePtr value_type = DB::removeNullable(arguments[0]);
    return std::make_shared<DB::DataTypeNullable>(value_type);
}
/// Parses the first `l` bytes of `nptr` as an unsigned 64-bit integer in radix `base`.
///
/// Adapted from mysql-server's my_strntoull_8bit (the parser used by CONV() in
/// sql/item_strfunc.cc). Parsing rules:
///  - leading whitespace is skipped;
///  - an optional '+' or '-' sign follows;
///  - then digits 0-9 / A-Z / a-z (case-insensitive) are consumed until the first byte that
///    is not a valid digit for `base`.
///
/// @param nptr    start of the input (need not be NUL-terminated).
/// @param l       number of bytes of input.
/// @param base    radix; callers must pass a value in [2, 36] (it is used as a divisor).
/// @param endptr  if non-null, receives one past the last consumed digit, or `nptr` when no
///                digit was consumed.
/// @param err     set to 0 on entry, EDOM when the input is empty or whitespace-only, and
///                ERANGE on overflow. Input with no leading digits (e.g. "xyz" or "-")
///                returns 0 with *err == 0.
/// @return the parsed value. On overflow, all bits set. A leading '-' negates the result
///         modulo 2^64 (two's-complement wrap).
static unsigned long long my_strntoull_8bit(const char *nptr,
    size_t l, int base, const char **endptr,
    int *err)
{
    constexpr unsigned long long max_value = ~static_cast<unsigned long long>(0);

    *err = 0; /* Initialize error indicator */
    const char * s = nptr;
    const char * const e = nptr + l;

    // Skip leading whitespace. The cast matters: passing a negative char (any byte >= 0x80
    // where char is signed) to std::isspace is undefined behaviour.
    while (s < e && std::isspace(static_cast<unsigned char>(*s)))
        ++s;

    // Empty or whitespace-only input is the only case reported as EDOM.
    if (s == e)
    {
        *err = EDOM;
        if (endptr != nullptr)
            *endptr = nptr;
        return 0;
    }

    bool negative = false;
    if (*s == '-')
    {
        negative = true;
        ++s;
    }
    else if (*s == '+')
    {
        ++s;
    }

    const char * const save = s;
    const auto ubase = static_cast<unsigned long long>(base);
    // i * base + c overflows iff i > cutoff, or i == cutoff and c > cutlim.
    const unsigned long long cutoff = max_value / ubase;
    const auto cutlim = static_cast<unsigned>(max_value % ubase);

    bool overflow = false;
    unsigned long long i = 0;
    for (; s != e; ++s)
    {
        auto c = static_cast<uint8_t>(*s);
        if (c >= '0' && c <= '9')
            c = static_cast<uint8_t>(c - '0');
        else if (c >= 'A' && c <= 'Z')
            c = static_cast<uint8_t>(c - 'A' + 10);
        else if (c >= 'a' && c <= 'z')
            c = static_cast<uint8_t>(c - 'a' + 10);
        else
            break;
        if (c >= base)
            break;

        // Once overflowed, keep consuming digits (so endptr covers them) but stop accumulating.
        if (i > cutoff || (i == cutoff && c > cutlim))
            overflow = true;
        else
            i = i * ubase + c;
    }

    // No digits consumed: not an error for this parser, just zero.
    if (s == save)
    {
        if (endptr != nullptr)
            *endptr = nptr;
        return 0;
    }

    if (endptr != nullptr)
        *endptr = s;
    if (overflow)
    {
        *err = ERANGE;
        return max_value;
    }
    // Unsigned negation wraps modulo 2^64, matching the MySQL original.
    return negative ? (0ULL - i) : i;
}
/// Writes `val` into `dst` as text in base |radix| and NUL-terminates it (ported from MySQL's ll2str).
///
///  - radix in [2, 36]: `val` is printed as its unsigned 64-bit bit pattern, so negative values
///    come out in two's complement (e.g. -1 in base 16 is "FFFFFFFFFFFFFFFF").
///  - radix in [-36, -2]: signed output; a negative `val` gets a leading '-' and its magnitude.
///  - any other radix: nothing is written and nullptr is returned.
///
/// Digits above 9 use upper-case letters when `upcase` is true, lower-case otherwise.
/// The caller must provide room for up to 66 bytes ('-', 64 binary digits, NUL).
/// @return pointer to the terminating NUL written into `dst`, or nullptr for an invalid radix.
static char * ll2str(int64_t val, char * dst, int radix, bool upcase)
{
    constexpr char upper_digits[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    constexpr char lower_digits[] = "0123456789abcdefghijklmnopqrstuvwxyz";

    // Validate before negating: -radix would overflow for INT_MIN.
    const bool signed_output = radix < 0;
    const bool radix_ok = signed_output ? (radix >= -36 && radix <= -2) : (radix >= 2 && radix <= 36);
    if (!radix_ok)
        return nullptr;

    const uint64_t base = static_cast<uint64_t>(signed_output ? -radix : radix);
    uint64_t magnitude = static_cast<uint64_t>(val);
    if (signed_output && val < 0)
    {
        *dst++ = '-';
        // Negate in unsigned arithmetic so INT64_MIN does not overflow (MySQL BUG#31799).
        magnitude = 0ULL - magnitude;
    }

    const char * const digits = upcase ? upper_digits : lower_digits;

    // Produce digits least-significant first, filling the scratch area from its end.
    char scratch[65];
    char * const scratch_end = scratch + sizeof(scratch);
    char * cursor = scratch_end;
    do
    {
        *--cursor = digits[magnitude % base];
        magnitude /= base;
    } while (magnitude != 0);

    while (cursor != scratch_end)
        *dst++ = *cursor++;
    *dst = '\0';
    return dst;
}
/// Evaluates conv(num, from_base, to_base) row by row.
///
/// For each row the bytes of `num` are parsed with my_strntoull_8bit in radix |from_base|.
/// The 64-bit result is then rendered by ll2str in `to_base` (upper-case digits):
///  - a positive to_base prints the value as an unsigned 64-bit pattern, so negative inputs
///    appear in two's complement;
///  - a negative to_base prints a signed value with a leading '-'.
///
/// A row yields NULL when:
///  - either base is outside [-36, -2] U [2, 36];
///  - `num` is empty or whitespace-only (EDOM from the parser);
///  - ll2str rejects the radix.
///
/// An input that overflows 64 bits is treated as all bits set (ERANGE is not checked).
/// Bases are read per row, so both constant and full base columns are handled.
/// NOTE(review): assumes the first argument is a string column — getReturnTypeImpl does not
/// check its type; for other columns getDataAt yields raw value bytes. TODO confirm.
DB::ColumnPtr SparkFunctionConv::executeImpl(
    const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const
{
    // Checking the range directly avoids abs(), which is undefined for INT_MIN.
    auto is_valid_base = [](int base) { return (base >= 2 && base <= 36) || (base >= -36 && base <= -2); };

    // Largest ll2str output: optional '-', up to 64 base-2 digits of a 64-bit value, and the NUL.
    // (e.g. INT64_MIN with to_base = -2 needs all 66 bytes.)
    constexpr size_t max_rendered_size = 1 + 64 + 1;
    char ans[max_rendered_size];

    const auto & num_column = *arguments[0].column;
    const auto & from_base_column = *arguments[1].column;
    const auto & to_base_column = *arguments[2].column;

    auto result = result_type->createColumn();
    result->reserve(input_rows_count);

    for (size_t i = 0; i < input_rows_count; ++i)
    {
        // Per-row reads: a constant column returns the same value for every row, and a full
        // column is no longer evaluated with row 0's base (nor read out of bounds when empty).
        const auto from_base = static_cast<int>(from_base_column.getInt(i));
        const auto to_base = static_cast<int>(to_base_column.getInt(i));
        if (!is_valid_base(from_base) || !is_valid_base(to_base))
        {
            result->insertData(nullptr, 1);
            continue;
        }

        // The parser always works on the magnitude of from_base; its sign has no further effect.
        const int from_radix = from_base < 0 ? -from_base : from_base;
        const auto value = num_column.getDataAt(i);
        const char * endptr = nullptr;
        int err = 0;
        const auto dec = static_cast<Int64>(my_strntoull_8bit(value.data, value.size, from_radix, &endptr, &err));
        if (err == EDOM)
        {
            result->insertData(nullptr, 1);
            continue;
        }

        char * const ret_ptr = ll2str(dec, ans, to_base, true);
        if (ret_ptr == nullptr)
            result->insertData(nullptr, 1);
        else
            result->insertData(ans, static_cast<size_t>(ret_ptr - ans));
    }
    return result;
}
/// Registers SparkFunctionConv with the function factory (`factory` is provided by the
/// REGISTER_FUNCTION macro). The SQL-visible name comes from the class declaration in
/// SparkFunctionConv.h, which is not shown here.
REGISTER_FUNCTION(SparkFunctionConv)
{
factory.registerFunction<SparkFunctionConv>();
}
}