blob: a7501d6a282044623a3febb71c4f69e4605ce5b3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "AggregateSerializationUtils.h"
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionHelpers.h>
#include <IO/WriteBufferFromVector.h>
#include <IO/WriteHelpers.h>
#include <Common/Arena.h>
using namespace DB;
namespace local_engine
{
bool isFixedSizeStateAggregateFunction(const String& name)
{
static const std::set<String> function_set = {"min", "max", "sum", "count", "avg"};
return function_set.contains(name);
}
bool isFixedSizeArguments(const DataTypes& data_types)
{
return removeNullable(data_types.front())->isValueRepresentedByNumber();
}
/// Returns true when `function` is both a whitelisted fixed-state aggregate (by name) and is
/// applied to arguments accepted by isFixedSizeArguments().
bool isFixedSizeAggregateFunction(const DB::AggregateFunctionPtr& function)
{
    // The name lookup is cheap, so it runs first and skips the argument inspection on failure.
    if (!isFixedSizeStateAggregateFunction(function->getName()))
        return false;
    return isFixedSizeArguments(function->getArgumentTypes());
}
/// Copies the raw state bytes of every row of a ColumnAggregateFunction into a
/// FixedString(sizeOfData()) column that keeps the input column's name.
/// The input is returned unchanged when it is not a ColumnAggregateFunction, or when its
/// function is not accepted by isFixedSizeAggregateFunction().
DB::ColumnWithTypeAndName convertAggregateStateToFixedString(const DB::ColumnWithTypeAndName& col)
{
    const auto * agg_column = checkAndGetColumn<ColumnAggregateFunction>(&*col.column);
    if (agg_column == nullptr || !isFixedSizeAggregateFunction(agg_column->getAggregateFunction()))
        return col;

    const size_t width = agg_column->getAggregateFunction()->sizeOfData();
    auto fixed_type = std::make_shared<DataTypeFixedString>(width);
    auto fixed_column = fixed_type->createColumn();
    PaddedPODArray<UInt8> & bytes = assert_cast<ColumnFixedString &>(*fixed_column).getChars();

    const auto & states = agg_column->getData();
    // One exact reservation up front lets every append below skip the capacity check.
    bytes.reserve_exact(agg_column->size() * width);
    for (size_t row = 0; row < states.size(); ++row)
    {
        auto * state = states[row];
        bytes.insert_assume_reserved(state, state + width);
    }
    return DB::ColumnWithTypeAndName(std::move(fixed_column), fixed_type, col.name);
}
/// Serializes every state of a ColumnAggregateFunction, via IAggregateFunction::serialize,
/// into a String column that keeps the input column's name.
/// Columns that are not a ColumnAggregateFunction are returned unchanged.
DB::ColumnWithTypeAndName convertAggregateStateToString(const DB::ColumnWithTypeAndName& col)
{
    const auto *aggregate_col = checkAndGetColumn<ColumnAggregateFunction>(&*col.column);
    if (!aggregate_col)
    {
        return col;
    }
    // Loop-invariant: look the function up once instead of once per row.
    const auto function = aggregate_col->getAggregateFunction();
    auto res_type = std::make_shared<DataTypeString>();
    auto res_col = res_type->createColumn();
    auto & string_col = assert_cast<ColumnString &>(*res_col);
    PaddedPODArray<UInt8> & column_chars = string_col.getChars();
    IColumn::Offsets & column_offsets = string_col.getOffsets();
    WriteBufferFromVector<PaddedPODArray<UInt8>> value_writer(column_chars);
    column_offsets.reserve_exact(aggregate_col->size());
    for (const auto & item : aggregate_col->getData())
    {
        function->serialize(item, value_writer);
        // ColumnString stores each value followed by a terminating zero byte, and its offsets
        // point past that byte. Without it, getDataAt() would drop the last serialized byte.
        writeChar('\0', value_writer);
        column_offsets.emplace_back(value_writer.count());
    }
    // Shrink column_chars to exactly the bytes written, so chars.size() == offsets.back()
    // before the column leaves this function. Do not rely on the writer's destructor for this.
    value_writer.finalize();
    return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name);
}
/// Rebuilds a ColumnAggregateFunction of `type` from a column holding raw state bytes,
/// presumably the FixedString produced by convertAggregateStateToFixedString.
/// For each row, a state is allocated in the result column's arena, initialized with create(),
/// and then overwritten with the row's bytes.
/// `type` must be an AggregateFunction data type (checked by chassert in debug builds).
DB::ColumnWithTypeAndName convertFixedStringToAggregateState(const DB::ColumnWithTypeAndName & col, const DB::DataTypePtr & type)
{
    chassert(WhichDataType(type).isAggregateFunction());
    auto res_col = type->createColumn();
    const auto * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(type.get());
    ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*res_col);
    auto & arena = real_column.createOrGetArena();
    ColumnAggregateFunction::Container & vec = real_column.getData();
    const size_t rows = col.column->size();
    vec.reserve_exact(rows);
    const auto agg_function = agg_type->getFunction();
    const size_t size_of_state = agg_function->sizeOfData();
    const size_t align_of_state = agg_function->alignOfData();
    for (size_t i = 0; i < rows; ++i)
    {
        AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state);
        agg_function->create(place);
        const auto value = col.column->getDataAt(i);
        // The allocation is exactly size_of_state bytes. A wider input would overflow it,
        // so flag that in debug builds and never copy more than the state can hold.
        chassert(value.size <= size_of_state);
        const size_t bytes_to_copy = value.size < size_of_state ? value.size : size_of_state;
        memcpy(place, value.data, bytes_to_copy);
        vec.push_back(place);
    }
    return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name);
}
/// Returns a copy of `block` in which every ColumnAggregateFunction column is converted:
///   - to FixedString when isFixedSizeAggregateFunction() accepts its function;
///   - to a serialized String otherwise.
/// All other columns, including AggregateFunction-typed columns whose storage is not a plain
/// ColumnAggregateFunction, are copied through unchanged.
DB::Block convertAggregateStateInBlock(DB::Block& block)
{
    ColumnsWithTypeAndName columns;
    columns.reserve(block.columns());
    for (const auto & item : block.getColumnsWithTypeAndName())
    {
        const auto * aggregate_col = WhichDataType(item.type).isAggregateFunction()
            ? checkAndGetColumn<ColumnAggregateFunction>(&*item.column)
            : nullptr;
        // checkAndGetColumn returns nullptr when the storage is not a ColumnAggregateFunction
        // (e.g. a const-wrapped column). Pass it through instead of dereferencing null,
        // mirroring what the convert* helpers do for such input.
        if (!aggregate_col)
            columns.emplace_back(item);
        else if (isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction()))
            columns.emplace_back(convertAggregateStateToFixedString(item));
        else
            columns.emplace_back(convertAggregateStateToString(item));
    }
    return columns;
}
}