| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include <Columns/ColumnArray.h> |
| #include <Columns/ColumnConst.h> |
| #include <Columns/ColumnNullable.h> |
| #include <Columns/ColumnString.h> |
| #include <DataTypes/DataTypeNullable.h> |
| #include <DataTypes/DataTypeString.h> |
| #include <Functions/FunctionFactory.h> |
| #include <Functions/IFunction.h> |
| #include <Interpreters/Context_fwd.h> |
| #include <base/StringRef.h> |
| |
| using namespace DB; |
| |
| namespace DB |
| { |
| namespace ErrorCodes |
| { |
| extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
| extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
| } |
| } |
| |
| namespace local_engine |
| { |
| class SparkFunctionArrayJoin : public IFunction |
| { |
| public: |
| static constexpr auto name = "sparkArrayJoin"; |
| static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArrayJoin>(); } |
| SparkFunctionArrayJoin() = default; |
| ~SparkFunctionArrayJoin() override = default; |
| bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } |
| size_t getNumberOfArguments() const override { return 0; } |
| String getName() const override { return name; } |
| bool isVariadic() const override { return true; } |
| bool useDefaultImplementationForConstants() const override { return true; } |
| |
| DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override |
| { |
| auto data_type = std::make_shared<DataTypeString>(); |
| return makeNullable(data_type); |
| } |
| |
| ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override |
| { |
| if (arguments.size() != 2 && arguments.size() != 3) |
| throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName()); |
| auto res_col = ColumnString::create(); |
| auto null_col = ColumnUInt8::create(input_rows_count, 0); |
| PaddedPODArray<UInt8> & null_result = null_col->getData(); |
| if (input_rows_count == 0) |
| return ColumnNullable::create(std::move(res_col), std::move(null_col)); |
| |
| const ColumnArray * array_col = array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());; |
| if (!array_col) |
| throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName()); |
| |
| const ColumnNullable * array_nested_col = checkAndGetColumn<ColumnNullable>(&array_col->getData()); |
| const ColumnString * string_col; |
| if (array_nested_col) |
| string_col = checkAndGetColumn<ColumnString>(array_nested_col->getNestedColumnPtr().get()); |
| else |
| string_col = checkAndGetColumn<ColumnString>(&array_col->getData()); |
| const ColumnArray::Offsets & array_offsets = array_col->getOffsets(); |
| const ColumnString::Offsets & string_offsets = string_col->getOffsets(); |
| const ColumnString::Chars & string_data = string_col->getChars(); |
| |
| auto extractColumnString = [&](const ColumnPtr & col) -> const ColumnString * |
| { |
| const ColumnString * res = nullptr; |
| if (col->isConst()) |
| { |
| const ColumnConst * const_col = checkAndGetColumn<ColumnConst>(col.get()); |
| if (const_col) |
| res = checkAndGetColumn<ColumnString>(const_col->getDataColumnPtr().get()); |
| } |
| else |
| res = checkAndGetColumn<ColumnString>(col.get()); |
| return res; |
| }; |
| bool const_delim_col = arguments[1].column->isConst(); |
| bool const_null_replacement_col = false; |
| const ColumnString * delim_col = extractColumnString(arguments[1].column); |
| const ColumnString * null_replacement_col = nullptr; |
| if (arguments.size() == 3) |
| { |
| const_null_replacement_col = arguments[2].column->isConst(); |
| null_replacement_col = extractColumnString(arguments[2].column); |
| } |
| size_t current_offset = 0, array_pos = 0; |
| for (size_t i = 0; i < array_col->size(); ++i) |
| { |
| String res; |
| const StringRef delim = const_delim_col ? delim_col->getDataAt(0) : delim_col->getDataAt(i); |
| StringRef null_replacement = StringRef(nullptr, 0); |
| if (null_replacement_col) |
| { |
| null_replacement = const_null_replacement_col ? null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i); |
| } |
| size_t array_size = array_offsets[i] - current_offset; |
| size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1]; |
| size_t last_not_null_pos = 0; |
| for (size_t j = 0; j < array_size; ++j) |
| { |
| if (array_nested_col && array_nested_col->isNullAt(j + array_pos)) |
| { |
| if (null_replacement.data) |
| { |
| res += null_replacement.toString(); |
| if (j != array_size - 1) |
| res += delim.toString(); |
| } |
| else if (j == array_size - 1) |
| res = res.substr(0, last_not_null_pos); |
| } |
| else |
| { |
| const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1); |
| res += s.toString(); |
| last_not_null_pos = res.size(); |
| if (j != array_size - 1) |
| res += delim.toString(); |
| } |
| data_pos = string_offsets[j + array_pos]; |
| } |
| array_pos += array_size; |
| res_col->insertData(res.data(), res.length()); |
| current_offset = array_offsets[i]; |
| } |
| return ColumnNullable::create(std::move(res_col), std::move(null_col)); |
| } |
| }; |
| |
| REGISTER_FUNCTION(SparkArrayJoin) |
| { |
| factory.registerFunction<SparkFunctionArrayJoin>(); |
| } |
| } |