blob: 96faa9d1dc1d0aebbe41ef3c68fb46e8027a5de2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Columns/ColumnArray.h>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
}
/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array.
class SparkArrayFlatten : public IFunction
{
public:
static constexpr auto name = "sparkArrayFlatten";
static FunctionPtr create(ContextPtr) { return std::make_shared<SparkArrayFlatten>(); }
size_t getNumberOfArguments() const override { return 1; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isArray(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array",
arguments[0]->getName(), getName());
DataTypePtr nested_type = arguments[0];
nested_type = checkAndGetDataType<DataTypeArray>(removeNullable(nested_type).get())->getNestedType();
return nested_type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
/** We create an array column with array elements as the most deep elements of nested arrays,
* and construct offsets by selecting elements of most deep offsets by values of ancestor offsets.
*
Example 1:
Source column: Array(Array(UInt8)):
Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
data: [1, 2, 3], [4, 5], [6], [7, 8]
offsets: 2, 4
data.data: 1 2 3 4 5 6 7 8
data.offsets: 3 5 6 8
Result column: Array(UInt8):
Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
data: 1 2 3 4 5 6 7 8
offsets: 5 8
Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one):
3 5 6 8
^ ^
Example 2:
Source column: Array(Array(Array(UInt8))):
Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]]
most deep data: 1 2 3 4
offsets1: 2 3
offsets2: 0 3 4
- ^ ^ - select by prev offsets
offsets3: 1 1 3 4
- ^ ^ - select by prev offsets
result offsets: 3, 4
result: Row 1: [1, 2, 3], Row2: [4]
*/
const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
if (!src_col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'",
arguments[0].column->getName());
const IColumn::Offsets & src_offsets = src_col->getOffsets();
ColumnArray::ColumnOffsets::MutablePtr result_offsets_column;
const IColumn::Offsets * prev_offsets = &src_offsets;
const IColumn * prev_data = &src_col->getData();
bool nullable = prev_data->isNullable();
// when array has null element, return null
if (nullable)
{
const ColumnNullable * nullable_column = checkAndGetColumn<ColumnNullable>(prev_data);
prev_data = nullable_column->getNestedColumnPtr().get();
for (size_t i = 0; i < nullable_column->size(); i++)
{
if (nullable_column->isNullAt(i))
{
auto res= nullable_column->cloneEmpty();
res->insertManyDefaults(input_rows_count);
return res;
}
}
}
if (isNothing(prev_data->getDataType()))
return prev_data->cloneResized(input_rows_count);
// only flatten one dimension
if (const ColumnArray * next_col = checkAndGetColumn<ColumnArray>(prev_data))
{
result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);
IColumn::Offsets & result_offsets = result_offsets_column->getData();
const IColumn::Offsets * next_offsets = &next_col->getOffsets();
for (size_t i = 0; i < input_rows_count; ++i)
result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
prev_data = &next_col->getData();
}
auto res = ColumnArray::create(
prev_data->getPtr(),
result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr());
if (nullable)
return makeNullable(res);
return res;
}
private:
String getName() const override
{
return name;
}
};
REGISTER_FUNCTION(SparkArrayFlatten)
{
factory.registerFunction<SparkArrayFlatten>();
}
}