blob: 2c619f949885ca30d9e0ca57fae27bb19fe3ef1a [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "exprs/agg-fn.h"
#include "codegen/llvm-codegen.h"
#include "exprs/anyval-util.h"
#include "exprs/scalar-expr.h"
#include "runtime/descriptors.h"
#include "runtime/lib-cache.h"
#include "common/names.h"
using namespace impala_udf;
namespace impala {
/// Builds an AggFn from the thrift AGGREGATE_EXPR node 'tnode'. The slot descriptors
/// describe where the aggregate's intermediate state and its final output live. The
/// function name is classified into one of the well-known aggregation ops; any name
/// not recognized below is classified as OTHER.
AggFn::AggFn(const TExprNode& tnode, const SlotDescriptor& intermediate_slot_desc,
    const SlotDescriptor& output_slot_desc)
  : Expr(tnode),
    is_merge_(tnode.agg_expr.is_merge_agg),
    intermediate_slot_desc_(intermediate_slot_desc),
    output_slot_desc_(output_slot_desc),
    arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs(
        ColumnType::FromThrift(tnode.agg_expr.arg_types))) {
  // The node must carry an aggregate function whose return type matches the node type.
  DCHECK(tnode.__isset.fn);
  DCHECK(tnode.fn.__isset.aggregate_fn);
  DCHECK_EQ(tnode.node_type, TExprNodeType::AGGREGATE_EXPR);
  DCHECK_EQ(ColumnType::FromThrift(tnode.type).type,
      ColumnType::FromThrift(fn_.ret_type).type);

  // Map the function name to its aggregation op. "sum_init_zero" is classified as SUM
  // and "ndv_no_finalize" as NDV.
  auto classify = [](const string& name) {
    if (name == "count") return COUNT;
    if (name == "min") return MIN;
    if (name == "max") return MAX;
    if (name == "sum" || name == "sum_init_zero") return SUM;
    if (name == "avg") return AVG;
    if (name == "ndv" || name == "ndv_no_finalize") return NDV;
    return OTHER;
  };
  agg_op_ = classify(fn_.name.function_name);
}
/// Initializes the input expressions (children) against 'row_desc', then resolves the
/// UDA entry points named in the aggregate function's thrift descriptor through the
/// LibCache.
///
/// The Init and Update symbols are mandatory. The Merge symbol is also mandatory unless
/// the function is analytic-only. The Serialize, GetValue, Remove and Finalize symbols
/// are loaded only when they are non-empty.
///
/// Returns an error status in three cases:
/// - a child fails to initialize;
/// - a mandatory symbol is missing (expected only for partially implemented builtins);
/// - a symbol lookup fails.
Status AggFn::Init(const RowDescriptor& row_desc, RuntimeState* state) {
  // Initialize all children (i.e. input exprs to this aggregate expr).
  for (ScalarExpr* input_expr : children()) {
    RETURN_IF_ERROR(input_expr->Init(row_desc, /*is_entry_point*/ false, state));
  }

  // Initialize the aggregate expressions' internals.
  const TAggregateFunction& aggregate_fn = fn_.aggregate_fn;
  DCHECK_EQ(intermediate_slot_desc_.type().type,
      ColumnType::FromThrift(aggregate_fn.intermediate_type).type);
  DCHECK_EQ(output_slot_desc_.type().type, ColumnType::FromThrift(fn_.ret_type).type);

  // The function's last-modified time is passed to every symbol lookup.
  // NOTE(review): presumably used by LibCache to detect a stale cached library; confirm.
  time_t mtime = fn_.last_modified_time;

  // Load the function pointers. Must have init() and update(), and merge() unless the
  // function is analytic-only.
  if (aggregate_fn.init_fn_symbol.empty() ||
      aggregate_fn.update_fn_symbol.empty() ||
      (aggregate_fn.merge_fn_symbol.empty() && !aggregate_fn.is_analytic_only_fn)) {
    // This path is only for partially implemented builtins.
    DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::BUILTIN);
    stringstream ss;
    ss << "Function " << fn_.name.function_name << " is not implemented.";
    return Status(ss.str());
  }
  RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
      fn_.hdfs_location, aggregate_fn.init_fn_symbol, mtime, &init_fn_, &cache_entry_));
  RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
      aggregate_fn.update_fn_symbol, mtime, &update_fn_, &cache_entry_));

  // Merge() is not defined for purely analytic function.
  if (!aggregate_fn.is_analytic_only_fn) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.merge_fn_symbol, mtime, &merge_fn_, &cache_entry_));
  }

  // Serialize(), GetValue(), Remove() and Finalize() are optional.
  if (!aggregate_fn.serialize_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.serialize_fn_symbol, mtime, &serialize_fn_, &cache_entry_));
  }
  if (!aggregate_fn.get_value_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.get_value_fn_symbol, mtime, &get_value_fn_, &cache_entry_));
  }
  if (!aggregate_fn.remove_fn_symbol.empty()) {
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.remove_fn_symbol, mtime, &remove_fn_, &cache_entry_));
  }
  if (!aggregate_fn.finalize_fn_symbol.empty()) {
    // Use the same 'aggregate_fn' alias as every other lookup in this function.
    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
        aggregate_fn.finalize_fn_symbol, mtime, &finalize_fn_, &cache_entry_));
  }
  return Status::OK();
}
/// Creates an AggFn from the thrift expression tree 'texpr', whose root node must be an
/// AGGREGATE_EXPR. The AggFn is allocated from the RuntimeState's object pool, its
/// input expressions are built with Expr::CreateTree(), and it is then initialized
/// with Init().
///
/// On success '*agg_fn' points at the new AggFn. On failure '*agg_fn' is nullptr and
/// an error status is returned.
Status AggFn::Create(const TExpr& texpr, const RowDescriptor& row_desc,
    const SlotDescriptor& intermediate_slot_desc, const SlotDescriptor& output_slot_desc,
    RuntimeState* state, AggFn** agg_fn) {
  *agg_fn = nullptr;
  // Reject a malformed (empty) tree up front; indexing nodes[0] would be undefined.
  if (UNLIKELY(texpr.nodes.empty())) {
    return Status("Empty thrift expression tree for AGGREGATE_EXPR");
  }
  ObjectPool* pool = state->obj_pool();
  const TExprNode& texpr_node = texpr.nodes[0];
  DCHECK_EQ(texpr_node.node_type, TExprNodeType::AGGREGATE_EXPR);
  if (!texpr_node.__isset.fn) {
    return Status("Function not set in thrift AGGREGATE_EXPR node");
  }
  AggFn* new_agg_fn =
      pool->Add(new AggFn(texpr_node, intermediate_slot_desc, output_slot_desc));
  RETURN_IF_ERROR(Expr::CreateTree(texpr, pool, new_agg_fn));
  Status status = new_agg_fn->Init(row_desc, state);
  if (UNLIKELY(!status.ok())) {
    // Close the partially initialized AggFn; this also closes its input expressions.
    new_agg_fn->Close();
    return status;
  }
  // Function-context indices are assigned separately for each input expression, each
  // starting from 0. NOTE(review): this presumably reflects each input expr being
  // evaluated through its own evaluator; confirm before sharing a counter across
  // children.
  for (ScalarExpr* input_expr : new_agg_fn->children()) {
    int fn_ctx_idx = 0;
    input_expr->AssignFnCtxIdx(&fn_ctx_idx);
  }
  *agg_fn = new_agg_fn;
  return Status::OK();
}
/// Returns the UDF type descriptor that corresponds to the intermediate slot's type.
FunctionContext::TypeDesc AggFn::GetIntermediateTypeDesc() const {
  const ColumnType& intermediate = intermediate_slot_desc_.type();
  return AnyValUtil::ColumnTypeToTypeDesc(intermediate);
}
/// Returns the UDF type descriptor that corresponds to the output slot's type.
FunctionContext::TypeDesc AggFn::GetOutputTypeDesc() const {
  const ColumnType& output = output_slot_desc_.type();
  return AnyValUtil::ColumnTypeToTypeDesc(output);
}
/// Loads the UDA's Merge() function when this is a merge aggregation, and its Update()
/// function otherwise, returning it through '*uda_fn'. The requested signature lists
/// each input expression's type followed by the intermediate type.
///
/// When the loaded function has an IR body, constant function attributes are inlined
/// and the function is finalized. If finalization yields nullptr, UDF_VERIFY_FAILED is
/// returned.
Status AggFn::CodegenUpdateOrMergeFunction(
    LlvmCodeGen* codegen, llvm::Function** uda_fn) {
  const TAggregateFunction& agg_desc = fn_.aggregate_fn;
  const string& symbol =
      is_merge_ ? agg_desc.merge_fn_symbol : agg_desc.update_fn_symbol;

  // One argument per input expression; the intermediate value goes last.
  vector<ColumnType> arg_types;
  for (ScalarExpr* child : children()) arg_types.push_back(child->type());
  arg_types.push_back(intermediate_type());

  RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, nullptr, arg_types,
      arg_types.size(), false, uda_fn, &cache_entry_));

  // A bare declaration has no IR body to specialize, so there is nothing left to do.
  if ((*uda_fn)->isDeclaration()) return Status::OK();

  // Inline constants into the IR body.
  // TODO: IMPALA-4785: we should also replace references to GetIntermediateType()
  // with constants.
  codegen->InlineConstFnAttrs(GetOutputTypeDesc(), arg_type_descs_, *uda_fn);
  *uda_fn = codegen->FinalizeFunction(*uda_fn);
  if (*uda_fn == nullptr) {
    return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location);
  }
  return Status::OK();
}
void AggFn::Close() {
// This also closes all the input expressions.
Expr::Close();
}
void AggFn::Close(const vector<AggFn*>& exprs) {
for (AggFn* expr : exprs) expr->Close();
}
/// Returns a human-readable description of this aggregate function. The output is the
/// aggregation op followed by the debug strings of the input expressions, for example
/// "AggFn(op=<op> <input0> <input1>)".
string AggFn::DebugString() const {
  stringstream out;
  out << "AggFn(op=" << agg_op_;
  for (ScalarExpr* input_expr : children()) {
    // Inputs are only space-separated. The single closing paren that matches
    // "AggFn(" is appended after the loop; appending one per input made the
    // parentheses unbalanced.
    out << " " << input_expr->DebugString();
  }
  out << ")";
  return out.str();
}
/// Returns the debug strings of 'agg_fns', space-separated and wrapped in brackets,
/// for example "[AggFn(...) AggFn(...)]". An empty list yields "[]".
string AggFn::DebugString(const vector<AggFn*>& agg_fns) {
  stringstream out;
  out << "[";
  // Range-for with a separator avoids comparing a signed int index against size().
  const char* separator = "";
  for (const AggFn* agg_fn : agg_fns) {
    out << separator << agg_fn->DebugString();
    separator = " ";
  }
  out << "]";
  return out.str();
}
}